Telegram Web
😁3🥴1🏆1
This media is not supported in your browser
VIEW IN TELEGRAM
22🕊9👍1
xLSTM: Extended Long Short-Term Memory
Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, Sepp Hochreiter
Статья: https://arxiv.org/abs/2405.04517
Код: https://github.com/NX-AI/xlstm

Новый улучшенный LSTM от автора старого LSTM! Вэлкам extended LSTM, xLSTM!

Цель авторов — понять, докуда можно прокачать языковое моделирование на LSTM, если отскейлить их до миллиардов параметров, перенять все передовые наработки из LLM и устранить известные узкие места.

LSTM оказались суперуспешны и выдержали проверку временем. “The most cited NN of the 20th century” (https://people.idsia.ch/~juergen/most-cited-neural-nets.html) как никак.

Напомню, что LSTM в оригинале (1997, https://direct.mit.edu/neco/article-abstract/9/8/1735/6109/Long-Short-Term-Memory) создавался для борьбы с затухающими (или взрывающимися) градиентами. В отличие от обычной RNN (где было только скрытое состояние h) в LSTM добавлена скалярная ячейка памяти (и здесь теперь одновременно есть c и h) и то что называлось constant error carousel (CEC) для обновления её состояния, изначально это было рекуррентное соединение с весом 1.0 на своё же прошлое состояние. В изначальной модели также были два управляющих процессом гейта, input и output. Input gate защищал память от влияния нерелевантных данных, output gate защищал другие ячейки от нерелевантного состояния текущей. Чуть позже (через три года, https://direct.mit.edu/neco/article-abstract/12/10/2451/6415/Learning-to-Forget-Continual-Prediction-with-LSTM) добавился forget gate для сбрасывания состояния памяти, когда надо.

Но у LSTM есть три основных ограничения:

1. Неспособность пересмотреть решение о сохранении данных внутри ячейки памяти. Это демонстрируют на простой задаче поиска ближайшего соседа (Nearest Neighbor Search), где сначала даётся референсный вектор, а далее сканируется последовательность других векторов, и модель должна найти наиболее похожий вектор и вернуть связанное с ним значение, когда последовательность закончится. Когда в последовательности попадается ещё более подходящий вектор, то модель не справляется.

2. Ограниченная память — всё надо впихнуть внутрь скаляра, который хранится в ячейке памяти LSTM. Это демонстрируют на задаче предсказания редкого токена (Rare Token Prediction), где перплексия на токенах из редкого бакета на Wikitext-103 особенно плоха.

3. Плохая параллелизация в силу последовательной обработки скрытых состояний ячейки между соседними временными отсчётами. Состояние зависит от предыдущего через hidden-hidden связи, это так называемый memory mixing.

Авторы предлагают Extended Long Short-Term Memory (xLSTM) с двумя основными модификациями базового уравнения, описывающего работу стандартного LSTM. Одно изменение -- экспоненциальные гейты, другое — новые структуры памяти.

Отсюда рождаются два новых члена семьи LSTM: sLSTM со скалярной памятью, скалярным обновлением и memory mixing, и mLSTM с матричной памятью, covariance update rule через outer product и без memory mixing. Оба варианта с экспоненциальным гейтингом.

В sLSTM добавляются экспоненциальные функции активации на input и forget gate. Также появляется отдельное состояние нормализатора, и чтобы от экспоненты всё не разнесло ещё и состояние стабилизатора. Память — по-прежнему скаляр.

Как и LSTM, sLSTM может иметь множество ячеек памяти (состояние памяти — вектор), где возможен memory mixing через рекуррентные соединения скрытого состояния (h) и гейтов со входами ячейки памяти. Также sLSTM может иметь множество голов с memory mixing внутри головы, но не между головами. Я так понял, что головы задаются структурой блочно-диагональной матрицы, через которую прогоняются все входы, и диагональные блоки задают отдельные головы.
🔥219👍7😐1
mLSTM (не путать с multiplicative LSTM из https://arxiv.org/abs/1609.07959) выглядит похитрее. Здесь память не скаляр, а матрица C размерности d×d. Авторы описывают работу с памятью в трансформерных терминах с query, key, value. Вытаскивание из памяти реализовано матричным умножением C на q. Для сохранения k,v (каждый из которых вектор размерности d) используется covariance update rule: C_t = C_{t−1} + v_t * k^⊤_t. Используется LayerNorm для k и v, чтобы среднее было нулевым. Как и в sLSTM здесь есть экспоненциальные гейты, отдельное состояние нормализатора и такая же стабилизация.

У mLSTM тоже может быть множество ячеек памяти. Memory mixing здесь нет и множественные головы равнозначны множественным ячейкам. И так как нет рекуррентных hidden-hidden соединений, то возможна формулировка вычислений в параллельной форме.

Если полученные новые варианты вставить в residual blocks (с pre-LayerNorm), то получаются xLSTM блоки, которые можно стыковать. Есть два варианта xLSTM блока: с post up-projection и pre up-projection. Первый (post, как в трансформерах) делает нелинейную суммаризацию прошлого в своём оригинальном пространстве, затем линейно преобразует его в пространство более высокой размерности, применяет там нелинейную функцию активации и переводит обратно в пространство поменьше. Второй (pre, как в SSM) сначала переводит в пространство размерности побольше, там делает суммаризацию, и переводит обратно. См. картинки. Для первого обычно используется sLSTM, для второго mLSTM (в высокоразмерном пространстве матричная память лучше работает).

В отличие от стандартных трансформеров относительно длины последовательности сложность по вычислениям у xLSTM линейная, и по памяти константа. Рекомендуют для edge вычислений, потому что память compressive. mLSTM хорошо параллелится, sLSTM не параллелится.

Главный вопрос — что это всё даёт. В экспериментальной части сосредоточились на языковом моделировании.

В экспериментах используется нотация xLSTM[a:b], где отношение a/b отвечает за количества mLSTM/sLSTM блоков. Так, если всего в xLSTM 48 блоков, то для xLSTM[7:1] это значит, что есть 42 mLSTM и 6 sLSTM.

Сначала проверили на задачах с формальными языками, где должна быть видна способность решать state tracking problems. Бейзлайны достойные: Llama, Mamba, Retention, Hyena, RWKV-4/5/6. Результаты подтверждают, что трансформеры и SSM фундаментально менее мощные, чем RNN (полезное видео для интересующихся https://www.youtube.com/watch?v=4-VXe1yPDjk и статья https://arxiv.org/abs/2404.08819; также см https://www.tgoop.com/gonzo_ML/1049 про иерархии Хомского). Также видно, что sLSTM обходит mLSTM.

Затем потестили на Multi-Query Associative Recall task, когда демонстрируются до 256 key,value пар в последовательности и они должны быть запомнены для последующего retrieval’а. Трансформер тут золотой стандарт, из всех нетрансформерных (Mamba, RWKV-5, RWKV-6, xLSTM[1:1], xLSTM[1:0]) лучше оказалась xLSTM[1:1].

На задачах из Long Range Arena (Retrieval, ListOps, Image, Pathfinder) чуть лучше Мамбы и ещё лучше RWKV.

Наконец языковое моделирование. Обучали на 15B токенов из SlimPajama, сравнивали со многими свежими моделями (GPT-3, Llama, H3, Mamba, Hyena, RWKV, RetNet, HGRN, HGRN2, GLA), сходу нет только какого-нибудь Griffin (но его официальной открытой имплементации так понимаю нет). Модели сопоставимы с GPT-3 с 350M параметров. По итоговой перплексии в топе xLSTM[1:0] и xLSTM[7:1].

Кривая скейлинга (до 2.7B) выглядит хорошо, лучше Llama, Mamba, RWKV-4.

Сделали абляции постепенно трансформируя ванильный LSTM в xLSTM. Экспоненциальные гейты и матричная память сильно добавляют качества.

Далее увеличили объём обучающих данных до 300B токенов (такие же числа использовались в Mamba и Griffin). В сравнении участвовали xLSTM, RWKV-4, Llama, Mamba как лучшие представители в своих классах. Обучали модели размеров 125M, 350M, 760M, 1.3B.

На обученном контексте в 2048 проверили экстраполяцию на большую длину, до размера 16384 у xLSTM всё хорошо.
👍12🔥84❤‍🔥1
Потом проверили перплексию и качество на различных downstream задачах. xLSTM почти везде лидирует. На языковых задачах из PALOMA (https://arxiv.org/abs/2312.10523) тоже лучше Мамбы, Ламы и RWKV-4.

Скейлится на 300B токенов тоже хорошо. Интересно, конечно, было бы что-то гигантское обучить, скажем 175B. Но я понимаю, что не у всех бюджеты как у OpenAI или Гугла.

Из главных ограничений пока скорость. sLSTM не параллелится, но быстрая имплементация всего в 1.5 раза медленнее параллельного mLSTM. Этот в свою очередь не оптимизирован и примерно в 4 раза медленнее FlashAttention или реализации через scan в Mamba. Но наверняка можно всё ускорить. Есть ещё некоторые тонкости, перечислены в разделе 5 статьи. По оптимизации и поиску хороших гиперпараметров явно ещё поле непаханое.

Официального кода вроде нет, но вот начались народные попытки воспроизвести, например, mLSTM (https://github.com/andrewgcodes/xlstm).

В общем имеем RNN, которая выглядит не хуже трансформеров и SSM. Ждём продолжений! И кода.

Long live RNN!
👍22🔥16
2025/07/14 19:52:00
Back to Top
HTML Embed Code: