Bayesian Federated Learning via Predictive Distribution Distillation
In most existing federated learning algorithms, each round consists of minimizing a loss function at each client to learn an optimal model for that client, followed by aggregating these client models at the server. Such point estimation of the model parameters ignores the uncertainty in the models estimated at each client. In many situations, however, especially in limited-data settings, it is beneficial to account for this uncertainty to obtain more accurate and robust predictions. Uncertainty also provides useful information for other important tasks, such as active learning and out-of-distribution (OOD) detection. We present a framework for Bayesian federated learning in which each client infers the posterior predictive distribution using its training data, and we present various ways to aggregate these client-specific predictive distributions at the server. Since communicating and aggregating predictive distributions can be challenging and expensive, our approach distills each client's predictive distribution into a single deep neural network. This lets us carry advances in standard federated learning over to Bayesian federated learning. Unlike some recent works that have tried to estimate the model uncertainty of each client, our work does not make restrictive assumptions, such as a particular form for the client's posterior distribution. We evaluate our approach on classification, active learning, and OOD detection in federated settings, where it outperforms various existing federated learning baselines.
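To make the contrast between a point estimate and a posterior predictive distribution concrete, here is a minimal sketch in plain Python. It is not the paper's implementation: the function names, the toy logits, and the weighted-average aggregation rule are all assumptions chosen for illustration, and the distillation step into a single network is omitted. It shows only that averaging class probabilities over several posterior samples can reveal uncertainty that any single sampled model hides, and one simple way a server could combine client predictive distributions.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def posterior_predictive(sample_logits):
    # Monte Carlo estimate of the predictive distribution for one input:
    # average the class probabilities produced by each posterior sample.
    probs = [softmax(logits) for logits in sample_logits]
    n = len(probs)
    return [sum(p[k] for p in probs) / n for k in range(len(probs[0]))]

def aggregate_predictives(client_probs, weights=None):
    # Hypothetical aggregation rule: weighted average of the clients'
    # predictive distributions (weights could be client data sizes).
    if weights is None:
        weights = [1.0] * len(client_probs)
    total = sum(weights)
    num_classes = len(client_probs[0])
    return [
        sum(w * p[k] for w, p in zip(weights, client_probs)) / total
        for k in range(num_classes)
    ]

def entropy(p):
    # Shannon entropy in nats; higher means more predictive uncertainty.
    return -sum(q * math.log(q) for q in p if q > 0)

# Toy data (assumed): two posterior samples that disagree on one input.
samples = [[2.0, 0.0], [0.0, 2.0]]
point_estimate = softmax(samples[0])       # confident single model
predictive = posterior_predictive(samples)  # averaged over the posterior

# Server side: combine two client predictives, weighting by data size.
global_predictive = aggregate_predictives(
    [predictive, [0.9, 0.1]], weights=[1.0, 3.0]
)
```

In this toy case each sampled model is individually confident, yet the two disagree, so the averaged predictive is close to uniform and has higher entropy than the point estimate. That entropy is the kind of signal that tasks such as active learning and OOD detection can use.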