MGT: графовая нейросеть для распознавания набросков

MGT — это архитектура нейросети, которая адаптирована для распознавания набросков. Модель обрабатывает наброски в виде графов. MGT выучивает геометрические и временные признаки рисунков. Предложенный подход протестировали на датасете Google QuickDraw. Результаты MGT сравнимы с подходом, основанном на CNN: 72.80% точность против 74.22%. При этом модель обошла подход, который основан на RNN. Исследователи заявляют, что ранее скетчи не представлялись в виде графов и не обрабатывались с помощью графовых архитектур.

Обучение представлениям рисунков является сложной задачей из-за высокой абстрактности набросков и разреженности линий. Существующие подходы фокусируются на статичности скетчей с помощью сверточных сетей (CNNs). Помимо этого, скетчи могут рассматриваться как последовательности линий. Тогда для их представления используют рекуррентные сети (RNNs). Multi-Graph Transformer (MGT) работает со скетчами как с графами. Графы захватывают локальную и глобальную структуру линий наброска.  

Как это работает 

Архитектура MGT вдохновлена трансформером. Каждый слой в модели состоит из двух компонентов:

  1. Multi-Graph Multi-Head Attention (MGMHA): подслой с механизмом внимания;
  2. Полносвязный подслой с кодированием позиции (positional encoding)
Визуализация составных частей MGT

Тестирование работы модели

Чтобы проверить работу MGT, исследователи взяли задачу распознавания скетчей на датасете Google QuickDraw. Google QuickDraw состоит из 414 тысяч изображений набросков. В качестве базовых моделей исследователи отобрали вариации state-of-the-art CNN архитектур и двунаправленные RNN. Ниже видно, что CNN-подходы выдавали более точные предсказания, чем MGT. Несмотря на это, MGT обошла RNN-подходы.

Таблица с результатами сравнений
Подписаться
Уведомить о
guest

0 Comments
Межтекстовые Отзывы
Посмотреть все комментарии

gogpt