fbpx
  • Google Brain оптимизировали время обучения BERT с 3-х дней до 76 минут

    google brain bert ai

    Команда Google Brain представила оптимайзер LAMB, способный сократить время предобучения языковой модели BERT с 3-х дней до 76 минут (вычисления проводились на TPU).

    Transfer learning и проблема BERT

    BERT (Bidirectional Encoder Representations from Transformers) — алгоритм для решения задачи transfer learning в обработке естественного языка. Задача transfer learning сводится к тому, что для предсказания следующего слова необходимо хранить информацию о контексте предыдущих слов (память). BERT был представлен в 2018 году исследователями из Google.

    Результаты BERT в сравнении с существующими архитектурами
    Результаты BERT в сравнении с существующими архитектурами

    BERT изначально показывал лучшие результаты в экспериментах в сравнении с конкурирующими архитектурами, однако, полная версия модели обучается за 1 миллион итераций. Это делает модель высоко ресурсозатратной и вводит ограничение на ее самостоятельное предобучение. Чтобы частично решить проблему масштабирования модели, исследователи предлагают заменить AdamW оптимайзер, используемый для обучения BERT, на LAMB (Layer-wise Adaptive Moments optimizer for Batch training).

    Общая задача нейронной сети сводится к поиску глобального минимума на плоскости. Оптимайзер нужен нейронной сети, чтобы решать, в какую сторону делать шаг на каждой из итераций обучения нейросети.

    LAMB оптимайзер

    Значительное преимущество LAMB в том, что единственный гиперпараметр, который он принимает на вход — это learning rate. Использование LAMB помогает уменьшить размер батчей до 65536, не теряя в точности модели. Псевдокод, описывающий алгоритм работы LAMB представлен на картинке ниже.

    lamp optimizer

    C помощью LAMB исследователи смогли сократить количество требуемых для обучения итераций с 1 миллиона до 8599 (с размером батчей равным 65536/32768).

    Результаты работы BERT с LAMB оптимайзером

    В экспериментах, проведенных для тестирования работы оптимизированной модели, использовался TPUv3 Pod с 1024 чипами. Это необходимо учитывать, чтобы адекватно оценивать масштабы ресурсозатратности обучения BERT.

    В качестве бейзлайна исследователи использовали F1 меру полной модели (BERT-Large). В экспериментах были использованы те же сеты данных, что и для опубликованной изначально модели: Wikipedia, содержащий 2.5 миллиарда слов, и BooksCorpus, содержащий 800 миллионов слов.

    Предобучение BERT состоит из двух этапов:

    1. На первых 90% эпохах размер последовательностей слов составляет 128;
    2. На последних 10% эпох размер последовательностей слов увеличивается до 512.

    Обучение масштабированной версии BERT закончилось в 8599 итераций с размером батча равным 64 тыс. (для этапа 1) и 32 тыс. (для этапа 2).

    В результате обучение завершилось за 76 минут. Так, исследователям получилось оптимизировать эффективность модели на 101.8%.