Scaling description of generalization with number of parameters in deep learning
We describe how the generalization performance of fixed-depth fully-connected deep neural networks evolves with their number of parameters N. In the setup where the number of data points is larger than the input dimension, we observe that, for large N, increasing N at fixed depth reduces the fluctuations of the output function f_N induced by initial conditions, with ||f_N - f̅_N|| ∼ N^{-1/4}, where f̅_N denotes an average over initial conditions. We explain this asymptotic behavior in terms of the fluctuations of the so-called Neural Tangent Kernel, which controls the dynamics of the output function. For the task of classification, we predict that these fluctuations increase the true test error ϵ as ϵ_N - ϵ_∞ ∼ N^{-1/2} + O(N^{-3/4}). This prediction is consistent with our empirical results on the MNIST dataset, and it explains in a concrete case the puzzling observation that the predictive power of deep networks improves as the number of fitting parameters grows. This asymptotic description breaks down at a so-called jamming transition, which takes place at a critical N = N^*, below which the training error is non-zero. In the absence of regularization, we observe an apparent divergence ||f_N|| ∼ (N - N^*)^{-α} and provide a simple argument suggesting α = 1, consistent with empirical observations. This result leads to a plausible explanation for the cusp in test error known to occur at N^*. Overall, our analysis suggests that once models are averaged, the optimal model complexity is reached just beyond the point where the data can be perfectly fitted, a result of practical importance that needs to be tested in a wide range of architectures and datasets.
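As an illustration of how a scaling such as ||f_N - f̅_N|| ∼ N^{-1/4} might be checked numerically, the sketch below measures the fluctuations around an ensemble average and fits the exponent on a log-log scale. It is a toy model, not the networks or the MNIST experiments described above: the outputs are synthetic (a shared mean function plus seed-dependent Gaussian noise whose amplitude is set to N^{-1/4} by hand), and the names `fit_power_law_exponent`, `n_seeds` and `n_test` are hypothetical.

```python
import numpy as np


def fit_power_law_exponent(n_params, values):
    """Return the least-squares slope of log(values) against log(n_params)."""
    slope, _intercept = np.polyfit(np.log(n_params), np.log(values), 1)
    return slope


rng = np.random.default_rng(0)
n_params = np.array([1e3, 1e4, 1e5, 1e6])  # hypothetical parameter counts N
n_seeds, n_test = 20, 2000                 # hypothetical ensemble and test-set sizes

fluctuations = []
for n in n_params:
    # Assumption: each "network" output is a shared mean function plus
    # seed-dependent noise of amplitude N^{-1/4}. Real networks would be
    # trained from different initial conditions instead.
    f_mean = np.sin(np.linspace(0.0, 3.0, n_test))
    outputs = f_mean + n ** -0.25 * rng.standard_normal((n_seeds, n_test))

    # Ensemble average over initial conditions (the analogue of f̅_N).
    ensemble_avg = outputs.mean(axis=0)

    # Root-mean-square deviation of individual outputs from the average.
    fluctuations.append(np.sqrt(np.mean((outputs - ensemble_avg) ** 2)))

exponent = fit_power_law_exponent(n_params, np.array(fluctuations))
print(f"fitted exponent: {exponent:.3f}")
```

With measured outputs in place of the synthetic ones, the same fit would give an empirical estimate of the exponent to compare against the predicted -1/4; the finite ensemble size only rescales the fluctuations by a constant factor and so should not bias the fitted slope.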