Дообучение модели из Hugging Face на TPU: PyTorch, TorchAX и LoRA без полного перехода на JAX

https://dev.to/gde/fine-tune-any-huggingface-model-like-gemma-on-tpus-with-torchax-5g21 На Dev.to вышел разбор того, как совместить привычный стек PyTorch и каталог Hugging Face с Google TPU через TorchAX и LoRA, не переписывая модель целиком под JAX. Ниже — сжатое резюме поста и тех решений, которые автор выделяет как опорные для практики дообучения открытых весов.
От каталога весов до сохранённого адаптера: зачем это про инструменты вокруг моделей
Речь идёт о связке для fine-tuning: загрузка модели с Hugging Face, адаптер LoRA, обучение через примитивы JAX и обёртку torchax.train.make_train_step — путь от весов в каталоге до сохранённого адаптера с возможностью перезагрузки. В анонсе обещаны блоки про evaluation, сценарии save/reload и ссылка на Colab — это задаёт рамку как про ML-инжиниринг вокруг открытых моделей, а не про синтаксис языка отдельно от ИИ-стека.
Окружение: бесплатный Colab, TPU и установка зависимостей
На Dev.to в тексте описан бесплатный Google Colab с доступом к TPU v2-8 и порядка 15 GB высокоскоростной памяти; такого объёма достаточно для дообучения 1B модели в режиме LoRA (эту оценку даёт сам пост). Альтернатива полному JAX-рерайту — TorchAX: PyTorch-модель и цикл обучения через JAX-примитивы без полной переписки модели под JAX.
Как предпосылки указаны Python 3.10+, базовые знания PyTorch и Hugging Face Transformers, аккаунт Colab (для LoRA достаточно free tier). Установка сведена к примерам pip: отдельный набор для TPU (jax[tpu], torchax, transformers, flax, peft, datasets, optax и др.) и пометка для GPU с jax[cuda12]. Отмечено, что в Colab после установки пакетов имеет смысл перезапустить runtime, чтобы не держать в памяти предзагруженный JAX.
Какую модель и данные берёт пример
Базовая модель в приводимом коде — google/gemma-3-1b-it, датасет — databricks/databricks-dolly-15k (в описании — 15k пар instruction/response, семь категорий). Для обучения берётся подвыборка: shuffle(seed=42).select(range(2200)), затем train_test_split(test_size=200, seed=42); в показанном фрагменте токенизация с max_length=512, batch_size=2.
Конфигурация PEFT LoRA: r=8, lora_alpha=16, lora_dropout=0.0, целевые модули внимания q_proj, k_proj, v_proj, o_proj. Вывод print_trainable_parameters(): обучаемых 5 767 168 параметров из 2 619 206 656 (0,22% trainable).
Evaluation и известный подводный камень с маской внимания
До обучения считается средний loss на eval_dataloader с ограничением max_batches=50, выводится perplexity как экспонента от среднего loss (в коде есть верхняя отсечка для perplexity). Дополнительно — качественное сравнение ответов модели до и после дообучения. После обучения — табличное сравнение loss и perplexity «до/после» и снова качественная генерация.
Для Gemma со sliding window attention в связке torchax/JAX из батча убирается attention_mask: с паддинг-маской в этом сценарии возникают NaN; маскирование паддинга переносится на labels (значение -100).
Обучение: optax, первый JIT-шаг и ориентиры по времени
Оптимизатор в примере — optax с расписанием warmup_cosine_decay_schedule (peak_value 1e-4, warmup_steps 50, decay_steps 500), цепочка с clip_by_global_norm(1.0) и AdamW (weight_decay=0.01). Перед загрузкой модели с HF рекомендуется tx.disable_temporarily(); для стабильности bfloat16 — tx.enable_accuracy_mode() до tx.enable_globally().
Ориентиры по времени: первый шаг около 30–60 с (JIT), далее около 1–3 с на шаг; суммарно примерно 20–40 мин для 2000 сэмплов с LoRA на free Colab TPU.
Сохранение, перезагрузка и запас по памяти для полного дообучения
Сохранение описано через конвертацию состояния в CPU-тензоры и model.save_pretrained(..., safe_serialization=False) — safe_serialization=False обходит конфликт safetensors/torchax при последующей загрузке. Перезагрузка LoRA: снова tx.disable_temporarily(), базовая модель с HF, затем peft.PeftModel.from_pretrained(..., torch_device="cpu") с тем же мотивом, после чего reloaded_model.to(device).
В таблице материала для full fine-tuning приведены ориентиры памяти примерно 18–20 GB против примерно 5–7 GB для LoRA; для укладки в free Colab при full режиме предлагаются AdaFactor, уменьшенная длина последовательности (256), batch 1 с накоплением градиента, gradient checkpointing.
Отдельной юридической секции про условия использования весов Gemma в основном тексте материала на Dev.to нет: для лицензий и ограничений он не заменяет официальную документацию Google и карточку модели на Hugging Face.
В разделе Resources перечислены ноутбуки на GitHub agemagician/torchax-huggingface, в том числе полный туториал обучения, quickstart и ноутбук по инференсу — с прямыми ссылками на Colab в оригинале.
Источники
- Fine-Tune Any HuggingFace Model like Gemma on TPUs with TorchAX — публикация на Dev.to. Опубликовано: 2026-04-27 (UTC). Дата доступа: 2026-04-28. Dev.to
- Colab (из раздела Resources того же поста): https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_training_tutorial.ipynb — дата доступа: 2026-04-28 (UTC).
- Colab (quickstart): https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_training_quickstart.ipynb — дата доступа: 2026-04-28 (UTC).
- Colab (инференс, часть серии): https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_huggingface_tutorial.ipynb — дата доступа: 2026-04-28 (UTC).