
As shown in the first row of Fig. 1, where Mask R-CNN [17] with a ResNet50 [22] backbone is trained using different batch sizes, we compare the gradient variances of different network modules, including the backbone, FPN, RPN, detection head, and mask head. We see that when the batch size is small (32 in the first figure), the gradient variances of the different modules are similar throughout the training process. When the batch size increases from 256 to 1024 (2nd~4th figures), the gradient variances of different modules become misaligned and their gap enlarges during training; training fails when the batch size reaches 1024. More importantly, the gradient variances in the RPN, FPN, detection head, and mask head are significantly smaller than that in the backbone, and they change sharply in the late stage of training (the two middle figures). We find that such misalignment undesirably burdens large-batch training, leading to severe performance drops and even training failure. More observations on various visual tasks and networks can be found in Appendix A.2.
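To make this measurement concrete, the sketch below shows one way to estimate per-module gradient variance in PyTorch over a set of micro-batches, using Var[g] ≈ E[‖g_b‖²] − ‖E[g_b]‖² as a proxy. The `groups` dictionary (mapping module names such as 'backbone' or 'rpn' to their parameter lists) and the `loss_fn` callable are hypothetical placeholders; this is an illustrative measurement, not the paper's exact protocol.

```python
import torch

def per_module_grad_variance(model, loss_fn, micro_batches, groups):
    """Estimate the gradient variance of each module group.

    `groups` maps a group name (e.g. 'backbone', 'fpn', 'rpn') to a list of
    parameters.  The variance proxy is Var[g] ~ E[||g_b||^2] - ||E[g_b]||^2,
    computed over the per-micro-batch gradients g_b.
    """
    sums = {name: None for name in groups}
    sq_sums = {name: 0.0 for name in groups}
    n = len(micro_batches)

    for batch in micro_batches:
        model.zero_grad()
        loss_fn(model, batch).backward()          # gradient of one micro-batch
        for name, params in groups.items():
            g = torch.cat([p.grad.flatten() for p in params if p.grad is not None])
            sums[name] = g.clone() if sums[name] is None else sums[name] + g
            sq_sums[name] += g.pow(2).sum().item()

    variance = {}
    for name in groups:
        mean_g = sums[name] / n                   # E[g_b]
        variance[name] = sq_sums[name] / n - mean_g.pow(2).sum().item()
    return variance
```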
The above empirical analysis naturally inspires us to design AGVM, a simple yet effective method for training dense visual predictors with multiple modules using a very large batch size. AGVM directly modulates the misaligned gradient variances, making them consistent across different network modules throughout training. As shown in the second row of Fig. 1, AGVM significantly outperforms recent large-batch training approaches on four different visual prediction tasks with batch sizes ranging from 32 to 2048. For example, AGVM enables us to train an object detector with a huge batch size of 1536 (where prior arts may fail), reducing training time by more than 35× compared to the regular training setup.
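As a rough illustration of this idea (not the exact AGVM update rule, which is detailed later), one could rescale each module's gradient so that its variance matches that of an anchor module such as the backbone; the variance estimates here are assumed to come from a routine like the one sketched above, and the scaling rule is our own simplification.

```python
import torch

def align_module_gradients(groups, variances, anchor="backbone", eps=1e-12):
    """Rescale each module group's gradients so that their variance roughly
    matches that of the anchor module.  `variances` maps group names to
    per-group gradient variance estimates."""
    target = variances[anchor]
    for name, params in groups.items():
        if name == anchor:
            continue
        # Scaling a gradient by s scales its variance by s^2, hence the sqrt.
        scale = (target / (variances[name] + eps)) ** 0.5
        for p in params:
            if p.grad is not None:
                p.grad.mul_(scale)
```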
This work makes three main contributions. Firstly, we carefully design AGVM, which is, to our knowledge, the first large-batch optimization method for various dense prediction networks and tasks. We evaluate AGVM on different architectures (e.g., CNNs and Transformers), solvers (e.g., SGD and AdamW), and tasks (e.g., object detection, instance segmentation, semantic segmentation, and panoptic segmentation). Secondly, we provide a convergence analysis of AGVM, showing that it converges to a stable point in a general non-convex optimization setting. We also conduct an empirical analysis that reveals an important insight: the inconsistency of effective batch size between different modules aggravates the gradient variance misalignment when the batch size is large, leading to performance drops and even training failure. We believe this insight may facilitate future research on large-scale training of complicated vision systems. Thirdly, extensive experiments are conducted to evaluate AGVM, which achieves many new state-of-the-art results for large-batch training. For example, AGVM demonstrates more stable generalization performance than prior arts under an extremely large batch size (i.e., 10k). In particular, it enables training the widely-used Faster R-CNN+ResNet50 within 4 minutes without performance drop. More importantly, AGVM can train a detector with one billion parameters within just 3.5 hours, reducing the training time by 20.9×, while achieving a top-ranking 62.2 mAP on the COCO dataset.
2 Preliminary and Notation
Let $S=\{(x_i, y_i)\}_{i=1}^{n}$ denote a dataset with $n$ training samples, where $x_i$ and $y_i$ represent a data point and its label, respectively. We can estimate the value of a loss function $L:\mathbb{R}^d\rightarrow\mathbb{R}$ using a mini-batch of randomly sampled examples, obtaining $l(w_t) = \frac{1}{b}\sum_{j\in S_t} L(w_t,(x_j, y_j))$, where $S_t$ denotes the mini-batch at the $t$-th iteration with batch size $|S_t|=b$ and $w_t$ represents the parameters of a deep neural network. We can apply stochastic gradient descent (SGD), one of the most representative algorithms, to update the parameters $w_t$. The SGD update with learning rate $\eta_t$ is

$$w_{t+1} = w_t - \eta_t \nabla l(w_t), \quad (1)$$

where $\nabla l(w_t)$ represents the gradient of the loss function with respect to $w_t$.
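For reference, Eq. (1) corresponds to the following single update step, written out by hand in PyTorch; the model, loss function, and batch are placeholders.

```python
import torch

def sgd_step(model, loss_fn, batch, lr):
    """One step of Eq. (1): w_{t+1} = w_t - eta_t * grad l(w_t)."""
    xs, ys = batch
    model.zero_grad()
    loss = loss_fn(model(xs), ys)        # l(w_t), averaged over the b mini-batch samples
    loss.backward()                      # computes grad l(w_t)
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.add_(p.grad, alpha=-lr)  # w <- w - eta_t * grad
    return loss.item()
```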
Layerwise Scaling Ratio. In large-batch training, You et al. [8] observe that the ratio between the norm of the layer weights and the norm of the gradients is unstable (i.e., it oscillates a lot), leading to training failure. You et al. [8] present the LARS algorithm, which adopts a layerwise scaling ratio, $\|w_t^{(i)}\| / \|\nabla l(w_t^{(i)})+\lambda w_t^{(i)}\|$, to modify the magnitude of the gradient of the $i$-th layer, $\nabla l(w_t^{(i)})$, where $w_t^{(i)}$ and $\lambda$ denote the parameters of the $i$-th layer and the weight decay coefficient, respectively.
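In code, the layerwise scaling ratio amounts to the following per-layer rescaling (a minimal sketch rather than a full LARS optimizer; momentum and the learning-rate schedule are omitted, and the small `eps` added for numerical stability is our own assumption):

```python
import torch

def lars_scaled_gradient(w, grad, weight_decay, eps=1e-12):
    """Apply the layerwise scaling ratio ||w|| / ||grad + lambda * w|| to one layer."""
    g = grad + weight_decay * w           # grad l(w^(i)) + lambda * w^(i)
    ratio = w.norm() / (g.norm() + eps)   # layerwise scaling ratio
    return ratio * g                      # rescaled update direction for the i-th layer
```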
Furthermore, LAMB [6] improves LARS by combining the AdamW optimizer with the layerwise scaling ratio. It can be formulated as $r_t = m_t/(\sqrt{v_t}+\epsilon)$, where $m_t=\beta_1 m_{t-1} + (1-\beta_1)\nabla l(w_t)$