Surrogate Gap Minimization Improves Sharpness-Aware Training

by Juntang Zhuang, et al.

The recently proposed Sharpness-Aware Minimization (SAM) improves generalization by minimizing a perturbed loss, defined as the maximum loss within a neighborhood in the parameter space. However, we show that both sharp and flat minima can have a low perturbed loss, implying that SAM does not always prefer flat minima. Instead, we define a surrogate gap, a measure equivalent to the dominant eigenvalue of the Hessian at a local minimum when the radius of the neighborhood (used to derive the perturbed loss) is small. The surrogate gap is easy to compute and feasible for direct minimization during training. Based on these observations, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM), a novel improvement over SAM with negligible computational overhead. Conceptually, GSAM consists of two steps: 1) a gradient descent step, as in SAM, to minimize the perturbed loss, and 2) an ascent step in the orthogonal direction (after gradient decomposition) to minimize the surrogate gap without affecting the perturbed loss. GSAM seeks a region with both small loss (by step 1) and low sharpness (by step 2), giving rise to a model with strong generalization. Theoretically, we show that GSAM converges and that it generalizes provably better than SAM. Empirically, GSAM consistently improves generalization (e.g., +3.2% over SAM and +5.4% over AdamW in ImageNet top-1 accuracy for ViT-B/32). Code is released at <>.
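To make the two steps concrete, here is a minimal NumPy sketch, not the released implementation. Writing f for the training loss, f_p for the perturbed loss and h(w) = f_p(w) - f(w) for the surrogate gap, a second-order Taylor expansion at a local minimum gives h ≈ (rho^2 / 2) * sigma_max, where sigma_max is the dominant Hessian eigenvalue. That is why h tracks sharpness for a small radius rho. The sketch makes several simplifying assumptions: a toy two-dimensional quadratic loss, a fixed rho and ascent weight alpha, and SAM's one-step approximation of the inner maximization. The helper names (perturbed_gradient, surrogate_gap, decompose, gsam_step) are hypothetical.

```python
import numpy as np

# Toy quadratic loss f(w) = 0.5 * w^T A w with one sharp (10) and one flat (1)
# direction. This loss is an illustrative assumption, not a setup from the paper.
A = np.diag([10.0, 1.0])

def loss(w):
    return 0.5 * w @ A @ w

def grad(w):
    return A @ w

def perturbed_gradient(grad_fn, w, rho):
    """Return (grad f(w), grad f_p(w)), approximating the inner max over
    ||delta|| <= rho by one normalized ascent step, as in SAM."""
    g = grad_fn(w)
    norm = np.linalg.norm(g)
    if norm == 0.0:
        return g, g  # stationary point: no ascent direction to follow
    w_adv = w + rho * g / norm
    return g, grad_fn(w_adv)

def surrogate_gap(loss_fn, grad_fn, w, rho):
    """h(w) = f_p(w) - f(w) under the same one-step approximation.
    Note: at an exact stationary point this approximation returns 0 and
    cannot see the curvature that the exact inner max would capture."""
    g = grad_fn(w)
    norm = np.linalg.norm(g)
    if norm == 0.0:
        return 0.0
    return loss_fn(w + rho * g / norm) - loss_fn(w)

def decompose(g, g_p):
    """Split g into components parallel and orthogonal to g_p."""
    denom = g_p @ g_p
    if denom == 0.0:
        return np.zeros_like(g), g
    g_par = (g @ g_p / denom) * g_p
    return g_par, g - g_par

def gsam_step(grad_fn, w, lr=0.1, rho=0.05, alpha=0.1):
    """One GSAM-style update: descend along grad f_p (step 1) and ascend
    along the part of grad f orthogonal to grad f_p (step 2)."""
    g, g_p = perturbed_gradient(grad_fn, w, rho)
    _, g_perp = decompose(g, g_p)
    return w - lr * (g_p - alpha * g_perp)

w0 = np.array([1.0, 1.0])
w1 = gsam_step(grad, w0)
print(loss(w0), loss(w1), surrogate_gap(loss, grad, w0, 0.05))
```

The ascent term moves w by +lr * alpha * g_perp. Because g_perp is orthogonal to grad f_p, f_p is unchanged to first order. Meanwhile f rises by about lr * alpha * ||g_perp||^2 (since g · g_perp = ||g_perp||^2), so the surrogate gap h = f_p - f shrinks. Setting alpha = 0 recovers a plain SAM step.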
