Just Train Twice: Improving Group Robustness without Training Group Information

07/19/2021
by   Evan Zheran Liu, et al.

Standard training via empirical risk minimization (ERM) can produce models that achieve high accuracy on average but low accuracy on certain groups, especially in the presence of spurious correlations between the input and label. Prior approaches that achieve high worst-group accuracy, like group distributionally robust optimization (group DRO), require expensive group annotations for each training point, whereas approaches that do not use such group annotations typically achieve unsatisfactory worst-group accuracy. In this paper, we propose a simple two-stage approach, JTT, that first trains a standard ERM model for several epochs, and then trains a second model that upweights the training examples that the first model misclassified. Intuitively, this upweights examples from groups on which standard ERM models perform poorly, leading to improved worst-group performance. Averaged over four image classification and natural language processing tasks with spurious correlations, JTT closes 75% of the worst-group accuracy gap between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters.
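As a rough illustration of the two-stage procedure described in the abstract, here is a minimal PyTorch sketch. It is not the authors' released implementation; the names `make_model`, `train_set`, `T`, `lambda_up`, and `final_epochs` are placeholders, and in the paper the upweighting factor and the number of identification epochs are hyperparameters tuned on a small group-annotated validation set.

```python
# Minimal sketch of JTT: (1) train an identification model with plain ERM,
# (2) collect the training examples it misclassifies, (3) retrain from
# scratch with those examples upweighted by lambda_up.
# Assumptions: `train_set` yields (x, y) pairs and `make_model()` returns a
# fresh classifier; both are hypothetical stand-ins, not from the paper.

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def erm_train(model, loader, epochs, lr=1e-3):
    """Standard ERM training with cross-entropy loss."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model

def jtt(make_model, train_set, T, lambda_up, final_epochs):
    # Stage 1: identification model, trained with plain ERM for T epochs.
    id_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    id_model = erm_train(make_model(), id_loader, T)

    # Error set: training points the identification model gets wrong,
    # used as a proxy for examples from the worst-performing groups.
    id_model.eval()
    weights = []
    with torch.no_grad():
        for x, y in DataLoader(train_set, batch_size=256):  # keeps dataset order
            wrong = id_model(x).argmax(dim=1) != y
            weights.extend(lambda_up if w else 1.0 for w in wrong.tolist())

    # Stage 2: retrain from scratch, sampling error-set points lambda_up
    # times more often (equivalent in expectation to upweighting them).
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    final_loader = DataLoader(train_set, batch_size=64, sampler=sampler)
    return erm_train(make_model(), final_loader, final_epochs)
```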


