Deep neural networks remain puzzling in many ways. Some of the open questions are as simple to state as, “How can Stochastic Gradient Descent (SGD) find good solutions to a complicated non-convex optimisation problem?”
The answers, however, are rarely straightforward. To tackle this and other puzzling fundamental questions, researchers at Facebook AI recently released a new study. It proposes a framework for addressing such questions and backs up its claims with theoretical proofs.
This study tries to address questions like:
- Why do neural networks generalise?
- How can networks trained with SGD fit both random noise and structured data, but prioritise structured models, even in the presence of massive noise?
- Why are flat minima related to good generalisation?
- Why does over-parameterisation lead to better generalisation?
- Why do lottery tickets exist?
An Overview Of Teacher-Student Model
To begin with, neural networks have gained popularity because of their ability to generalise. Generalisation means that a trained network can correctly classify data it has never seen before, as long as that data comes from the same distribution as the training data.
Getting networks to generalise better is usually the aim behind any algorithmic enhancement. One such pedagogical strategy is the teacher-student setup.
First, a ‘student’ neural network is given randomly selected input examples of concepts and is trained from those examples using traditional supervised learning methods to guess the correct concept labels.
In the second step, the ‘teacher’ network tests different examples on the student and observes which concept labels the student assigns to them, eventually converging on the smallest set of examples it needs to show the student for it to guess the intended concept. These examples end up being interpretable because they remain grounded in the concepts (via the student trained in step one).
A teacher can also pass on richer information than hard labels. For example, a dog classifier might say “German shepherd 0.9, Pug 0.1” instead of “German shepherd 1.0, Pug 0.0”. These “soft labels” are more informative than the original ones, telling the student that yes, this particular dog does slightly resemble a Pug.
So, when a classifier trained on lots of labelled data turns out to be too large, it can act as the teacher: the teacher and a smaller student network are fed the same data, and the student is trained on the teacher's outputs rather than on the original labels.
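Training a student on a teacher's soft labels is commonly known as knowledge distillation. As a rough illustration (a minimal sketch of that generic technique, not the Facebook AI study's code), a temperature-scaled softmax turns the teacher's raw scores into soft labels, and a KL-divergence loss measures how far the student's predictions are from them. The logits, the two class names and the temperature of 2.0 are made-up values chosen for illustration.

```python
import math

def softmax(logits, temperature=1.0):
    """Convert raw scores into probabilities; a higher temperature gives softer labels."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """Penalise the student for diverging from the teacher's soft labels.

    The temperature**2 factor follows the common convention of keeping
    gradient magnitudes comparable across temperatures.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return temperature ** 2 * kl_divergence(p_teacher, p_student)

# Hypothetical teacher logits for the classes [German shepherd, Pug].
teacher = [2.2, 0.0]
print(softmax(teacher))                          # roughly [0.9, 0.1], a soft label
print(distillation_loss(teacher, [2.2, 0.0]))    # student matches teacher: zero loss
print(distillation_loss(teacher, [0.0, 2.2]))    # student disagrees: positive loss
```

In practice this soft-label loss is often combined with the ordinary loss on the true labels, and the temperature is a tuning knob rather than a fixed value.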
Puzzling Aspects Of Deep Networks
In their theory, the researchers argue that salient weights are those lucky regions of the network that happen to overlap with some teacher nodes after initialisation and converge to them during optimisation. Using a teacher-student setting, they run experiments on various publicly available datasets to examine the effects of over-parameterisation, BatchNorm and other fundamental factors that shape training.
For instance, the Lottery Ticket phenomenon works like this: reset the “salient weights” (trained weights with large magnitude) to their values after initialisation but before optimisation, prune all other weights (often more than 90% of the total) and retrain the model. The pruned network can still perform well, whereas test performance is worse if the salient weights are re-initialised randomly instead.
According to the theory, this is because the salient weights already overlap with teacher nodes: even when their weights are reset and the other weights are pruned away, they can still converge to the same set of teacher nodes, and potentially achieve better performance due to less interference from other, irrelevant nodes.
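The lottery-ticket procedure can be sketched in a few lines. This is a simplified, hypothetical helper: it assumes all of a network's weights are flattened into one list and that pruning happens in a single round by global magnitude. The retraining step is left out, and published lottery-ticket experiments often prune per layer and over several rounds.

```python
def lottery_ticket_rewind(init_weights, trained_weights, keep_fraction=0.1):
    """Return (mask, rewound_weights) for one round of magnitude pruning.

    - Keep the `keep_fraction` of weights with the largest trained magnitude
      (the "salient weights").
    - Reset those kept weights to their values right after initialisation.
    - Zero out (prune) every other weight.
    """
    n = len(trained_weights)
    n_keep = max(1, round(keep_fraction * n))
    # Indices sorted by trained magnitude, largest first.
    order = sorted(range(n), key=lambda i: abs(trained_weights[i]), reverse=True)
    keep = set(order[:n_keep])
    mask = [1 if i in keep else 0 for i in range(n)]
    rewound = [init_weights[i] if mask[i] else 0.0 for i in range(n)]
    return mask, rewound

# Hypothetical weights of a tiny 10-parameter model.
init = [0.05, -0.20, 0.10, 0.30, -0.01, 0.07, -0.15, 0.02, 0.12, -0.09]
trained = [0.90, -0.02, 0.03, 1.40, -0.01, 0.04, -0.05, 0.00, 0.06, -0.03]
mask, rewound = lottery_ticket_rewind(init, trained, keep_fraction=0.2)
print(mask)     # [1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
print(rewound)  # [0.05, 0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

The re-initialisation control described above would apply the same mask to a fresh random draw instead of `init` before retraining.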
When fitting both structured and random data under gradient descent dynamics, the student nodes that happen to overlap substantially with teacher nodes move towards those teacher nodes and cover them.
With over-parameterisation, many student nodes are initialised randomly at each layer, so any teacher node is more likely to overlap significantly with some student node. This leads to fast convergence and offers an explanation for why over-parameterisation leads to better generalisation.
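The “luck” argument behind over-parameterisation can be illustrated with a toy Monte Carlo experiment. This is only an illustrative sketch, not the paper's analysis: it assumes each teacher and student node is represented by a random weight direction, and it counts a teacher node as “covered” when some student's cosine similarity with it exceeds a threshold. The dimension of 10, the threshold of 0.5 and the trial count are arbitrary choices.

```python
import math
import random

def random_unit_vector(dim, rng):
    """Sample a direction uniformly on the unit sphere via a normalised Gaussian."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def coverage_probability(n_students, dim=10, threshold=0.5, trials=300, seed=0):
    """Estimate P(some student direction has cosine > threshold with a teacher).

    Each trial draws one random teacher direction and up to `n_students`
    random student directions, mimicking random initialisation.
    """
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        teacher = random_unit_vector(dim, rng)
        for _ in range(n_students):
            student = random_unit_vector(dim, rng)
            cosine = sum(t * s for t, s in zip(teacher, student))
            if cosine > threshold:
                hits += 1
                break  # this teacher is already covered by a lucky student
    return hits / trials

for width in (2, 20, 200):
    print(width, coverage_probability(width))
```

As the number of randomly initialised students grows, the estimated probability that at least one of them starts close to the teacher rises towards 1, which matches the intuition behind the fast convergence described above.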
Finding a minimum of the loss is an intrinsic objective of almost all deep learning algorithms. Deep networks often converge to “flat minima”, where the Hessian of the loss has many small eigenvalues. Flat minima seem to be associated with good generalisation, while sharp minima often lead to poor generalisation.
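Flatness is usually quantified through the eigenvalues of the loss Hessian at a minimum. The sketch below uses two hypothetical two-parameter quadratic losses, both minimised at the origin, estimates their Hessians by central finite differences and compares the eigenvalues. It is a toy illustration of the definition, not a method from the study.

```python
import math

def hessian_2d(loss, w, h=1e-3):
    """Central finite-difference Hessian entries (hxx, hxy, hyy) of loss at w."""
    x, y = w
    f0 = loss(x, y)
    hxx = (loss(x + h, y) - 2 * f0 + loss(x - h, y)) / h ** 2
    hyy = (loss(x, y + h) - 2 * f0 + loss(x, y - h)) / h ** 2
    hxy = (loss(x + h, y + h) - loss(x + h, y - h)
           - loss(x - h, y + h) + loss(x - h, y - h)) / (4 * h ** 2)
    return hxx, hxy, hyy

def sym_2x2_eigenvalues(a, b, d):
    """Eigenvalues of the symmetric matrix [[a, b], [b, d]], smallest first."""
    mean = (a + d) / 2
    radius = math.sqrt(((a - d) / 2) ** 2 + b ** 2)
    return mean - radius, mean + radius

# Two hypothetical losses, both minimised at (0, 0).
def sharp_loss(x, y):
    return 0.5 * (50.0 * x ** 2 + 40.0 * y ** 2)

def flat_loss(x, y):
    return 0.5 * (0.5 * x ** 2 + 0.01 * y ** 2)

for name, loss in (("sharp", sharp_loss), ("flat", flat_loss)):
    eigs = sym_2x2_eigenvalues(*hessian_2d(loss, (0.0, 0.0)))
    print(name, [round(e, 3) for e in eigs])
```

The sharp loss should yield eigenvalues close to 40 and 50 and the flat one close to 0.01 and 0.5: small eigenvalues mean the loss barely changes when the weights are perturbed along those directions.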
The new theoretical framework, built on the teacher-student setting, can also be extended to understand the training dynamics of multi-layered ReLU networks.
In a curious case of implicit regularisation, the snapping behaviour (student nodes moving into teacher nodes) enforces winner-take-all: after optimisation, a teacher node is fully covered (explained) by a few student nodes, rather than being split among many student nodes because of over-parameterisation.
This explains why the same network, once trained with structured data, can generalise to the test set. Other observations from the study include:
- The final winning student nodes also have a good rank at the early stages of training, in particular after the first epoch.
- BatchNorm helps a lot, in particular for the CNN case with the GAUS dataset. Using BatchNorm accelerates the growth of accuracy.
- For CIFAR-10, the final evaluation accuracy learned by the student is often ∼1% higher than the teacher.
- Using a teacher-student setting, the researchers discovered a novel relationship between the gradient received by hidden student nodes and the activations of teacher nodes in deep ReLU networks.