NBDT: интерпретируемая нейросеть на основе решающих деревьев

Исследователи из UC Berkley предлагают интерпретируемую нейросеть, которая основана на архитектуре решающих деревьев и выдает сравнимые с state-of-the-art предсказания. Код проекта и предобученные модели доступны в открытом репозитории на GitHub.

Глубокое обучение применяется в сферах, которые требуют обоснованности предсказаний. Такими сферами являются финансовое моделирование или диагностирование болезней. Существующие исследования фокусируются на интерпретируемости обученных state-of-the-art нейросетей постфактум. До популяризации нейросетей решающие деревья являлись золотым стандартом как баланс между интерпретируемостью и точностью модели.

Предыдущие попытки скомбинировать решающие деревья и нейросети результировали в модели, которые:

  1. Менее точны, чем современные нейросети, даже на сравнительно маленьких датасетах (MNIST);
  2. Требует значительных изменений в архитектурах

Neural-Backed Decision Trees (NBDTs) выдают state-of-the-art результаты. При этом они не требуют изменений в архитектуре нейросети. NBDTs по точности отличаются от базовой нейросети не более чем на 1% на дататасетах CIFAR10, CIFAR100 и TinyImageNet. На ImageNet предложенная архитектура по точности отстает от EfficientNet на 2%.

Как это работает

Процесс обучения и инференса Neural-Backed Decision Tree состоит из четырех этапов:

  1. Сначала выстраивается иерархия для решающего дерева, которую называют Induced Hierarchy;
  2. Эта иерархия использует специальную функцию потерь Tree Supervision Loss;
  3. На инференсе данные сначала проходят через базовый блок нейросети (backbone), который идет до финального полносвязного слоя;
  4. Финальный полносвязный слой прогоняется как последовательность решающих правил, которые называют Embedded Decision Rules
Визуализация работы модели
Подписаться
Уведомить о
guest

0 Comments
Межтекстовые Отзывы
Посмотреть все комментарии

gogpt