Исследователи из Google AI предложили метод для обучения модели на множестве функций потерь одновременно. Loss-conditional обучение помогает при отборе оптимального распределения коэффициентов в функции потерь.
В части задач машинного обучения оценку модели нельзя выразить в единственном числе. Например, модель для сжатия изображения должна одновременно минимизировать размер сжатого изображения и максимизировать его качество. Часто невозможно одновременно оптимизировать все интересующие переменные, потому как они противоречат друг другу или из-за ограничений в обучении модели.
Ограничения взвешенной суммы в функции потерь
Стандартным подходом для обучения модели, которая оптимизирует несколько характеристик, является минимизация функции потерь, в которой все параметры суммируются с определенными весами. В случае с сжатием изображений функция потерь включала бы в себя два параметра, которые отражали бы качество изображения и уровень сжатия. Веса параметров в функции потерь влияют на результат обучения модели.
Если необходимо сравнить разное распределение весов в функции потерь, принято обучать несколько моделей с разными функциями потерь. Такой подход требует траты ресурсов на обучение и инференс нескольких моделей. Чтобы решить эту проблему, исследователи предлагают обучать одну модель. Модель учитывает функции потерь с разным распределением весов.
Loss-Conditional обучение
Идея метода заключается в том, что бы обучить одну модель, которая покрывает все возможные варианты распределения коэффициентов для параметров функции потерь. Такой формат обучения позволяет сократить требуемые ресурсы на обучение и тестирование моделей.
Обучение, которое обусловлено функцией потерь, состоит из двух шагов:
- Модель обучается на распределении функций потерь, а не на единственной функции потерь;
- Выходы модели соотносятся с вектором коэффициентов для параметров функции потерь
Так, во время инференса модели можно менять вектор с распределением весов параметров функции потерь. Это позволяет смещаться в пространств моделей с разными весами для параметров функции потерь.