Proximal Mean Field Learning in Shallow Neural Networks
Recent mean field interpretations of learning dynamics in over-parameterized neural networks offer theoretical insights into the empirical success of first-order optimization algorithms in finding global minima of the nonconvex risk landscape. In this paper, we explore applying mean field learning dynamics as a computational algorithm, rather than as an analytical tool. Specifically, we design a Sinkhorn-regularized proximal algorithm to approximate the distributional flow from the learning dynamics in the mean field regime over weighted point clouds. In this setting, a contractive fixed-point recursion computes the time-varying weights, numerically realizing the interacting Wasserstein gradient flow of the parameter distribution supported over the neuronal ensemble. An appealing aspect of the proposed algorithm is that the measure-valued recursions allow meshless computation. We demonstrate the proposed computational framework of interacting weighted particle evolution on binary and multi-class classification. Our algorithm performs gradient descent on the free energy associated with the risk functional.
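To give a rough sense of what a Sinkhorn-type fixed-point recursion over weighted point clouds looks like, the sketch below runs the generic textbook Sinkhorn iteration for entropy-regularized optimal transport between two small weighted point clouds. It is an illustration under stated assumptions and not the proximal recursion of this paper: the function name `sinkhorn`, the regularization value `eps`, the cost normalization, and the random point clouds are all hypothetical choices made for the example.

```python
import numpy as np

def sinkhorn(a, b, C, eps=0.5, max_iter=2000, tol=1e-9):
    """Generic Sinkhorn iteration (illustrative, not the paper's recursion).

    a, b : nonnegative weight vectors of the two point clouds (each sums to 1)
    C    : pairwise cost matrix of shape (len(a), len(b))
    eps  : entropic regularization strength (assumed value)
    Returns the regularized transport plan P.
    """
    K = np.exp(-C / eps)              # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iter):
        u_prev = u
        v = b / (K.T @ u)             # scale columns to match b
        u = a / (K @ v)               # scale rows to match a
        if np.max(np.abs(u - u_prev)) < tol:
            break                     # scaling vectors have stopped changing
    return u[:, None] * K * v[None, :]

# Two small weighted point clouds in the plane (synthetic data for the example).
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
y = rng.normal(size=(6, 2))
a = np.full(5, 1.0 / 5)
b = np.full(6, 1.0 / 6)

# Squared Euclidean cost, normalized to [0, 1] to keep the kernel well conditioned.
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C = C / C.max()

P = sinkhorn(a, b, C)
print("row marginal error:", np.abs(P.sum(axis=1) - a).max())
print("col marginal error:", np.abs(P.sum(axis=0) - b).max())
```

The iteration only ever touches the weights attached to the points, never a grid over the underlying space, which is the sense in which computations of this kind are meshless.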