tgoop.com/gonzo_ML/71
Last Update:
ACT: Adaptive Computation Time
Есть такая хорошая и малоизвестная тема под названием Adaptive Computation Time (ACT). Недавно стала чуть более известна благодаря Universal Transformer, но в целом всё равно для многих является экзотикой. Хочется рассказать.
Идея в том, что можно динамически решать, как долго проводить вычисления: сколько гонять их через один слой в случае RNN или трансформера, или когда остановиться при проходе через residual block (причём своё решение может быть принято для каждой отдельной части изображения). Для сложных элементов последовательности, возможно, есть смысл пообрабатывать их подольше, а простым может хватить и одного прохода/слоя.
Первоначальная работа:
Alex Graves,
“Adaptive Computation Time for Recurrent Neural Networks”,
https://arxiv.org/abs/1603.08983
Это тот самый Alex Graves, который придумал многомерные RNN, CTC, Grid-LSTM, работал над NTM и DNC, а также над много чем ещё интересным. Начинал в лаборатории Шмидхубера, сейчас работает в DeepMind. Рекомендую следить за его работами, если ещё не.
Итак,
#1. RNN with ACT
В обычных рекурретных сетях текущее состояние s_t зависит от предыдущего состояния s_{t-1} и от входа x_t:
s_t = S(s_{t-1}, x_t)
В случае ACT этот шаг модифицируется так, что рекуррентный слой может крутить внутри себя данные несколько раз, функция S теперь может применяться n раз, на n-ом шаге беря на вход результат n-1 шага:
s^n_t = S(s_{t-1}, x^1_t) для n=1
s^n_t = S(s^{n-1}_t, x^n_t) для других n.
Входной x также дополняется бинарным флажком про увеличение индекса входа, чтобы сеть могла различать повтор в последовательности от шагов последовательной обработки такой вот модифицированной функцией.
Вопрос в том, какой выбрать n, сколько шагов повторять эту процедуру.
Для определение необходимости останова заводится отдельный сигмоидальный halting unit, выход которого, h, определяет, что пора прекращать обработку. Обработка прекращается, когда сумма выходов halting unit за прошедшие шаги становится близка к единице (1 - epsilon, где epsilon -- это гиперпараметр, в работе выбран 0.01). Последнее значение h заменяется на остаток, чтобы сумма всех h была равна единице.
После этого состояние s_t определяется как взвешенная сумма состояний за разные временные шаги с весами, полученными описанным выше способом. Идейно похоже на механизм soft attention.
То есть, если halting unit выдавал на последовательных шагах значения: 0.1, 0.3, 0.4, 0.4, то получим массив весов: [0.1, 0.3, 0.4, 0.2] (последнее значение урезалось). И конечным состоянием слоя будет s_t = 0.1*s^1_t + 0.3*s^2_t + 0.4*s^3_t + 0.2*s^4_t.
Чтобы у сети не было стимула крутить данные внутри как можно дольше в функцию потерь добавляется так называемый ponder cost (с ещё одним гиперпараметром, time penalty -- весом, который надо подбирать), отражающий, как долго обрабатывался каждый отдельный элемент входной последовательности.
Этот ponder cost не является непрерывным в точках, где изменяется количество обработок элемента последовательности, но на это забивают и в остальном градиенты halting activations вычисляются нормально и всё обучается backprop’ом.
Проверяют на задачках, где это может сыграть: проверка чётности входного вектора (который подаётся целиком за один раз, а не как последовательность), логические функции над входом-вектором, сложение элементов последовательности (где каждый элемент представлен последовательностью цифр) и сортировка небольшого набора чисел.
Сети с ACT справляются ощутимо лучше, но надо подбирать параметр time penalty (здесь подбирали grid search’ем).
Ещё один интересный эксперимент -- предсказание следующего символа на датасете из Википедии. Получаются красивые графики, над каким символом сеть сколько (ponder time) думала.
(продолжение скоро)
Если обнаружите неточности или есть комментарии, пишите @che_shr_cat
BY gonzo-обзоры ML статей
Share with your friend now:
tgoop.com/gonzo_ML/71