tgoop.com/quant_prune_distill/419
Last Update:
Метод
Архитектура следующая - есть трансформерная тушка и несколько голов, каждая из которых предсказывает k-ый следующий токен (для головы с индексом k). Если я правильно понял, эти головы на самом деле преобразуют эмбеддинг перед подачей в unembedding матрицу (из размерности модели в размер словаря), а сама unembedding матрица общая для всех токенов.
Обучают на стандартный кроссэнтропийный лосс.
Дабы расход памяти не взрывался от тяжелых матриц логитов, авторы предлагают делать backward по каждой голове в отдельности (в LigerKernel на этапе обучения логиты считают чанками и делают backprop на них, к слову).
Эксперименты
Обучают семейство моделей размером от 300M to 13B параметров на датасете из ~100B токенов какого-то кода. Валидируют на MBPP, HumanEval, APPS
- сравнительно простых задачах про код. Пробуют обучать на сырых байтах и словаре из 32к токенов.
На маленьких моделях предсказание нескольких токенов вперед работает плохо, но начиная с какого-то размера (~3B) становится лучше по бенчам.
4 головы отпимальны по качеству для словаря в 32к токенов (8 для байтов).
Далее метод проверяют в сценарии дообучения и сравнивают 3 варианта:
⚡️Дообучение на предсказание 1 токена вперед, для модели обученной предсказывать 1 токен вперед
⚡️Дообучение на предсказание 1 токена вперед, для модели обученной предсказывать 4 токена вперед
⚡️Дообучение на предсказание 4 токена вперед, для модели обученной предсказывать 4 токена вперед
Оказывается, что второй вариант работает лучше всего почему-то.
Multi-token prediction работает не очень на multiple-choice задачах. Вероятно потому, что там требуется выдать всего один или немного токенов.
Потом тестируются на синтетике - Induction Heads, арифметике многочленов и наблюдают некоторый прирост качества, который объясняют тем, что в таких задачах полезно смотреть слегка наперед.
Очевидный практический плюс от многотокенного предсказания - ускорение 🚤 инференса в 3 раза на BPE токенах и около 6 на байтах.
Вывод
Mutli-token prediction выглядит как естественная и рабочая история. Тем более что в нашумевшем DeepSeek-V3 (где использовалась модифицированная версия метода с трансфорнеыми блоками на каждый новый токен) данная стратегия тоже отлично завелась. Вероятно, она будет стандартной в будущих моделях. Ждем 🦙-4, Qwen-3?
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/419