AlphaNet: Improved Training of Supernets with Alpha-Divergence

Abstract

Weight-sharing neural architecture search (NAS) is an effective technique for automating efficient neural architecture design. Weight-sharing NAS builds a supernet that assembles all the architectures as its sub-networks and jointly trains the supernet with the sub-networks. The success of weight-sharing NAS heavily relies on distilling the knowledge of the supernet to the sub-networks. However, we find that the widely used distillation divergence, i.e., KL divergence, may lead to student sub-networks that over-estimate or under-estimate the uncertainty of the teacher supernet, leading to inferior performance of the sub-networks. In this work, we propose to improve supernet training with a more generalized alpha-divergence. By adaptively selecting the alpha-divergence, we simultaneously prevent both the over-estimation and the under-estimation of the uncertainty of the teacher model. We apply the proposed alpha-divergence-based supernet training to both slimmable neural networks and weight-sharing NAS, and demonstrate significant improvements. Specifically, our discovered model family, AlphaNet, outperforms prior-art models, including BigNAS, Once-for-All networks, and AttentiveNAS, across a wide range of FLOPs regimes. We achieve ImageNet top-1 accuracy of 80.0% with only 444M FLOPs.
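To make the idea of replacing KL divergence with an alpha-divergence concrete, the following is a minimal sketch in plain Python. It uses one common parameterization of the alpha-divergence between a teacher distribution p and a student distribution q, which reduces to KL(p || q) as alpha approaches 1 and to KL(q || p) as alpha approaches 0. The abstract does not spell out the adaptive selection rule; taking the larger of a negative-alpha and a positive-alpha divergence is an assumption made here for illustration, as are the function names and the default values `alpha_neg=-1.0` and `alpha_pos=1.0`. It is not presented as the authors' implementation.

```python
import math


def softmax(logits):
    """Convert raw logits into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with q_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def alpha_divergence(p, q, alpha):
    """D_alpha(p || q) = 1 / (alpha * (alpha - 1)) * sum_i q_i * ((p_i / q_i)^alpha - 1).

    Assumes all entries of p and q are strictly positive (true for softmax outputs).
    """
    if abs(alpha - 1.0) < 1e-8:
        return kl_divergence(p, q)  # limit alpha -> 1
    if abs(alpha) < 1e-8:
        return kl_divergence(q, p)  # limit alpha -> 0
    total = sum(qi * ((pi / qi) ** alpha - 1.0) for pi, qi in zip(p, q))
    return total / (alpha * (alpha - 1.0))


def adaptive_alpha_divergence(p, q, alpha_neg=-1.0, alpha_pos=1.0):
    """Hypothetical adaptive rule (an assumption, see lead-in): use whichever of
    the two divergences is larger, so the student is penalized both when it
    under-estimates and when it over-estimates the teacher's uncertainty."""
    return max(alpha_divergence(p, q, alpha_neg), alpha_divergence(p, q, alpha_pos))


if __name__ == "__main__":
    teacher = softmax([2.0, 1.0, 0.1])   # stands in for the supernet output
    student = softmax([0.5, 0.5, 0.5])   # stands in for a sub-network output
    print("KL(teacher || student):", kl_divergence(teacher, student))
    print("adaptive alpha-divergence:", adaptive_alpha_divergence(teacher, student))
```

In a training loop, a loss of this shape would take the place of the KL term in the distillation objective, with the teacher probabilities treated as constants. A practical implementation would operate on batched tensors and would likely need some safeguard against very large probability ratios when alpha is negative; both are omitted here.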

Publication
In Conference on Machine Learning
Meng Li
Staff Research Scientist

I am currently a staff research scientist and tech lead on the Meta On-Device AI team, where I focus on researching and productizing efficient AI algorithms and hardware for next-generation AR/VR devices. I received my Ph.D. degree from the Department of Electrical and Computer Engineering at the University of Texas at Austin under the supervision of Prof. David Z. Pan, and my bachelor's degree from Peking University under the supervision of Prof. Ru Huang and Prof. Runsheng Wang. My research interests include efficient and secure AI algorithms and systems.
