What is Knowledge Distillation?
A fascinating and effective strategy for helping machine learning models learn better.
Knowledge distillation. In popular machine learning discussions, it is not nearly as well known as CNNs or image classification. Yet it is an important method that can help models such as CNNs classify images more accurately. That matters because accuracy is one of the primary measures of a machine learning model's success, especially on high-stakes tasks such as detecting skin cancer (another common topic).
In essence, knowledge distillation starts with a teacher model, or baseline (usually a larger model), which is trained on a dataset until it reaches good accuracy. Its trained weights are then saved and kept fixed, and the teacher's knowledge is distilled, or passed on, to a student model (usually a smaller one). The student trains on the same or a similar dataset, aiming for higher accuracy than it would reach by training on its own. The knowledge is transferred through the loss function: the student minimizes a loss that rewards matching both the teacher's logits (often softened with a temperature) and the class labels of the dataset (Hinton et al., 2015).
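The loss described above can be sketched for a single training example as follows. This is a minimal illustration in plain Python, not a definitive implementation: the function names, the temperature of 4.0, and the weighting factor alpha are assumptions chosen for the example, and a real project would use a deep learning framework and operate on batches.

```python
import math

def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a higher temperature gives a softer distribution."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label,
                      temperature=4.0, alpha=0.5):
    """Weighted sum of a soft-target term and a hard-label term (hypothetical helper)."""
    # Soft term: KL divergence between the softened teacher and student distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    # Hard term: cross-entropy of the student (temperature 1) against the true label.
    ce = -math.log(softmax(student_logits)[label])
    # Multiplying the soft term by T^2 keeps its gradient scale comparable
    # when the temperature changes.
    return alpha * temperature ** 2 * kl + (1 - alpha) * ce
```

For instance, distillation_loss([2.0, 0.5, -1.0], [5.0, 1.0, -2.0], label=0) returns one scalar that the student's optimizer would try to reduce. Setting alpha closer to 1 leans more on the teacher, and closer to 0 leans more on the labels.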
Knowledge can also be distilled from an ensemble of models into a single smaller model. For example, several models are first trained on a dataset, and their predictions are then averaged, which hopefully gives higher accuracy than any individual model achieves. The ensemble's averaged outputs then act as the teacher signal for a smaller model, which can approach, and in some cases match or exceed, the accuracy of the bulky ensemble on that dataset.
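One way to picture the ensemble case is to average the teachers' predicted class probabilities and hand that average to the student as its soft target. The sketch below assumes an arithmetic mean and a shared temperature; both are illustrative choices, and the function name ensemble_soft_targets is hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a higher temperature gives a softer distribution."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_soft_targets(logits_per_teacher, temperature=4.0):
    """Average the softened class probabilities of several teachers for one example."""
    probs = [softmax(logits, temperature) for logits in logits_per_teacher]
    num_teachers = len(probs)
    num_classes = len(probs[0])
    # Arithmetic mean per class; the result is still a valid probability distribution.
    return [sum(p[c] for p in probs) / num_teachers for c in range(num_classes)]
```

The returned list can stand in for a single teacher's softened output when computing the student's loss.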
This matters because a model deployed in a consumer app, for example, needs to be small and lightweight. Bulky, complex models are often impractical there, even though on their own they generally reach higher accuracy than smaller models. With knowledge distillation, a small model can come close to the accuracy of a complex one while remaining easy to deploy.
However, there are some details to consider when implementing knowledge distillation. In some cases the method does little to improve a student's accuracy. For example, if the teacher is too large or complex, the student may lack the capacity to mimic it and so cannot absorb what the teacher has learned. Likewise, if the dataset is very challenging, such as ImageNet, distillation may not help much either, because the student may not have the capacity to learn the task (Cho & Hariharan, 2019).
Though not a widely discussed topic in machine learning, knowledge distillation can be a very effective way to increase a model's accuracy. It has caveats to look out for, but it is a fascinating topic to explore, both as a way to improve accuracy and as a practical route to deploying smaller, more accurate models in products that customers can use.
Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
Cho, J. H., & Hariharan, B. (2019). On the efficacy of knowledge distillation. In Proceedings of the IEEE/CVF international conference on computer vision (pp. 4794–4802).