Online Training Through Time for Spiking Neural
Networks
Mingqing Xiao1, Qingyan Meng2,3, Zongpeng Zhang4, Di He1, Zhouchen Lin1,5,6
1Key Lab. of Machine Perception (MoE), School of Intelligence Science and Technology,
Peking University
2The Chinese University of Hong Kong, Shenzhen
3Shenzhen Research Institute of Big Data
4Center for Data Science, Academy for Advanced Interdisciplinary Studies, Peking University
5Institute for Artificial Intelligence, Peking University
6Peng Cheng Laboratory, China
{mingqing_xiao, dihe, zlin}@pku.edu.cn, qingyanmeng@link.cuhk.edu.cn,
zongpeng.zhang98@gmail.com
Abstract
Spiking neural networks (SNNs) are promising brain-inspired energy-efficient
models. Recent progress in training methods has enabled successful deep SNNs
on large-scale tasks with low latency. Particularly, backpropagation through time
(BPTT) with surrogate gradients (SG) is popularly used to enable models to achieve
high performance in a very small number of time steps. However, it is at the cost of
large memory consumption for training, lack of theoretical clarity for optimization,
and inconsistency with the online property of biological learning rules and rules on
neuromorphic hardware. Other works connect the spike representations of SNNs
with equivalent artificial neural network formulation and train SNNs by gradients
from equivalent mappings to ensure descent directions. But they fail to achieve low
latency and are also not online. In this work, we propose online training through
time (OTTT) for SNNs, which is derived from BPTT to enable forward-in-time
learning by tracking presynaptic activities and leveraging instantaneous loss and
gradients. Meanwhile, we theoretically analyze and prove that the gradients of
OTTT can provide a similar descent direction for optimization as gradients from
equivalent mapping between spike representations under both feedforward and
recurrent conditions. OTTT only requires constant training memory costs agnostic
to time steps, avoiding the significant memory costs of BPTT for GPU training.
Furthermore, the update rule of OTTT is in the form of three-factor Hebbian
learning, which could pave a path for online on-chip learning. With OTTT, it is the
first time that the two mainstream supervised SNN training methods, BPTT with
SG and spike representation-based training, are connected, and meanwhile it is in a
biologically plausible form. Experiments on CIFAR-10, CIFAR-100, ImageNet,
and CIFAR10-DVS demonstrate the superior performance of our method on large-
scale static and neuromorphic datasets in a small number of time steps. Our code
is available at https://github.com/pkuxmq/OTTT-SNN.
1 Introduction
Spiking neural networks (SNNs) are regarded as the third generation of neural network models [1] and have gained increasing attention in recent years [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]. SNNs
are composed of brain-inspired spiking neurons that imitate biological neurons to transmit spikes
Corresponding author.
36th Conference on Neural Information Processing Systems (NeurIPS 2022).
arXiv:2210.04195v2 [cs.NE] 31 Dec 2022
between each other. This allows event-based computation and enables efficient computation on
neuromorphic hardware with low energy consumption [14, 15, 16].
However, the supervised training of SNNs is challenging due to the non-differentiable neuron model
with discrete spike-generation procedures. Several kinds of methods are proposed to tackle the
problem, and recent progress has empirically obtained successful results. Backpropagation through
time (BPTT) with surrogate gradients (SG) is one of the mainstream methods which enables the
training of deep SNNs with high performance on large-scale datasets (e.g., ImageNet) with extremely
low latency (e.g., 4-6 time steps) [6, 10, 11, 13]. These approaches unfold the iterative expression
of spiking neurons, backpropagate the errors through time [17], and use surrogate derivatives to
approximate the gradient of the spiking function [3, 4, 18, 19, 20, 21, 22, 23]. As a result, during
training, they suffer from significant memory costs proportional to the number of time steps, and the
optimization with approximated surrogate gradients is not well guaranteed theoretically. Another
branch of works builds the closed-form formulation for the spike representation of neurons, e.g. the
(weighted) firing rate or spiking time, which is similar to conventional artificial neural networks
(ANNs). Then SNNs can be either optimized by calculating the gradients from the equivalent
mappings between spike representations [2, 24, 25, 26, 9, 27], or converted from a trained equivalent
ANN counterpart [28, 29, 30, 31, 32, 7, 33, 8, 34]. The optimization of these methods is clearer
than surrogate gradients. However, they require a larger number of time steps compared to BPTT
with SG. Therefore, they suffer from high latency, and more energy consumption is required if the
representation is rate-based. Another critical point for both methods is that they are indeed inconsistent
with biological online learning, which is also the learning rule on neuromorphic hardware [15].
In this work, we develop a novel approach for training SNNs to achieve high performance with
low latency, and maintain the online learning property to pave a path for learning on neuromorphic
chips. We call our method online training through time (OTTT). We first derive OTTT from the
commonly used BPTT with SG method by analyzing the temporal dependency and proposing to
track the presynaptic activities in order to decouple this dependency. With the instantaneous loss
calculation, OTTT can perform forward-in-time learning, i.e. calculations are done online in time
without computing backward through time. Then we theoretically analyze the gradients of OTTT
and the gradients of spike representation-based methods. We show that they have similar expressions and
prove that they can provide a similar descent direction for the optimization problem formulated
by spike representation. For the feedforward network condition, gradients are easily calculated and
analyzed. For the recurrent network condition, we follow the framework in [12] that weighted firing
rates will converge to an equilibrium state and gradients can be calculated by implicit differentiation.
With this formulation, the gradients correspond to an approximation of gradients calculated by
implicit differentiation, which can be proved to be a descent direction for the optimization problem
as well [35, 36]. In this way, a connection between OTTT and spike representation-based methods is
bridged. Finally, we show that OTTT is in the form of the three-factor Hebbian learning rule [37], which
could pave a path for online learning on neuromorphic chips. Our contributions include:
1.
We propose online training through time (OTTT) for SNNs, which enables forward-in-time
learning and only requires constant training memory agnostic to time steps, avoiding the
large training memory costs of backpropagation through time (BPTT).
2.
We theoretically analyze and connect the gradients of OTTT and gradients based on spike
representations, and prove the descent guarantee for optimization under both feedforward
and recurrent conditions.
3.
We show that OTTT is in the form of three-factor Hebbian learning rule, which could pave a
path for on-chip online learning. With OTTT, it is the first time that a connection between
BPTT with SG, spike representation-based methods, and biological three-factor Hebbian
learning is bridged.
4.
We conduct extensive experiments on CIFAR-10, CIFAR-100, ImageNet, and CIFAR10-
DVS, which demonstrate the superior results of our methods on large-scale static and
neuromorphic datasets in a small number of time steps.
2 Related Work
SNN Training Methods.
As for the supervised training of SNNs, there are two main research directions. One direction is to build a connection between spike representations (e.g. firing rates) of SNNs
with equivalent ANN-like closed-form mappings. With the connection, SNNs can be converted from
ANNs [28, 29, 30, 31, 32, 7, 33, 8, 34, 38], or SNNs can be optimized by gradients calculated from
equivalent mappings [2, 24, 25, 26, 9, 27]. Variants following this direction also include [12], which
connects feedback SNNs with equilibrium states following fixed-point equations instead of closed-form
mappings. These methods have a clearer descent direction for the optimization problem, but
require a relatively large number of time steps, suffering from high latency and usually more energy
consumption with rate-based representation. Another direction is to directly calculate gradients based
on the SNN computation. They follow the BPTT framework, and deal with the non-differentiable
problem of spiking functions by applying surrogate gradients [3, 4, 18, 19, 20, 21, 6, 23, 10, 11, 13],
or computing gradients with respect to spiking times based on the linear assumption [39, 40], or
combining both [22]. BPTT with SG can achieve extremely low latency. However, it requires
large training memory to maintain the computational graph unfolded along time, and it remains
unknown why surrogate gradients work well. [10] empirically observed that surrogate gradients
have a high similarity with numerical gradients, but a theoretical explanation is still lacking. And gradients
based on spiking times suffer from the “dead neuron” problem [3], so they should be combined with
SG in practice [40, 22]. Meanwhile, methods in both directions are inconsistent with biological
online learning, i.e. forward-in-time learning, which would pave a path for learning on neuromorphic hardware.
Differently, our proposed method avoids the above problems and maintains the online property.
Online Training of Neural Networks.
In the research of recurrent neural networks (RNNs), there
are several alternatives for BPTT to enable online learning. Particularly, real time recurrent learning
(RTRL) [41] proposes to propagate partial derivatives of hidden states over parameters through
time to enable forward-in-time calculation of gradients. Several recent works improve the memory
costs of RTRL with approximations for more practical usage [42, 43, 44]. Another work proposes to
update parameters online based on decoupled gradients with regularization at each time step [45].
However, these are all for RNNs and not tailored to SNNs. Several online training methods are
proposed for SNNs [46, 47, 48], which are derived in the spirit of RTRL and simplified for SNNs.
[49] leverages local losses and ignores temporal dependencies for online local training of SNNs,
and [50] directly applies the method in [45] to train SNNs. However, these methods also leverage
surrogate gradients without providing a theoretical explanation for optimization. Meanwhile, [46, 49]
use feedback alignment [51], [47] is limited to single-layer recurrent SNNs, and [48] requires much
larger memory costs for eligibility traces, so they cannot scale to large-scale tasks. [50] requires a
specially designed neuron model and more computation for parameter regularization, and also does
not consider large tasks. Differently, our work explains the descent direction under both feedforward
and recurrent conditions with convergent inputs, and is efficient and scalable to large-scale tasks
including ImageNet classification.
3 Preliminaries
3.1 Spiking Neural Networks
Spiking neurons are brain-inspired models that transmit information by spike trains. Each neuron
maintains a membrane potential $u$ and integrates input spike trains, which will generate a spike once
$u$ exceeds a threshold. We consider the commonly used leaky integrate and fire (LIF) model, which
describes the dynamics of the membrane potential as:
$$\tau_m \frac{du}{dt} = -(u - u_{\text{rest}}) + R \cdot I(t), \quad u < V_{\text{th}}, \tag{1}$$
where $I$ is the input current, $V_{\text{th}}$ is the threshold, and $R$ and $\tau_m$ are the resistance and time constant,
respectively. A spike is generated when $u$ reaches $V_{\text{th}}$ at time $t^f$, and $u$ is reset to the resting potential
$u = u_{\text{rest}}$, which is usually set to be zero. The output spike train is defined using the Dirac delta
function: $s(t) = \sum_{t^f} \delta(t - t^f)$.
A spiking neural network is composed of connected spiking neurons with connection coefficients.
We consider a simple current model $I_i(t) = \sum_j w_{ij} s_j(t) + b_i$, where the subscript $i$ represents the
$i$-th neuron, $w_{ij}$ is the weight from neuron $j$ to neuron $i$, and $b_i$ is a bias. The discrete computational
form is:
$$\begin{cases} u_i[t+1] = \lambda \left( u_i[t] - V_{\text{th}} s_i[t] \right) + \sum_j w_{ij} s_j[t] + b_i, \\ s_i[t+1] = H\left( u_i[t+1] - V_{\text{th}} \right), \end{cases} \tag{2}$$
where $H(x)$ is the Heaviside step function, $s_i[t]$ is the spike train of neuron $i$ at discrete time step $t$,
and $\lambda < 1$ is a leaky term (typically taken as $1 - \frac{1}{\tau_m}$). The constants $R$, $\tau_m$, and the time step size are
absorbed into the weights and bias. The reset operation is implemented by subtraction.
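The discrete update of Eq. (2) can be simulated directly. The sketch below is a minimal NumPy illustration, not code from the paper: the constants ($\lambda = 0.9$, $V_{\text{th}} = 1$, input $x = 0.5$) are arbitrary, and a scalar input current stands in for $\sum_j w_{ij} s_j[t] + b_i$.

```python
import numpy as np

def lif_step(u, x, lam=0.9, v_th=1.0):
    """One discrete LIF update (Eq. 2) with subtract-reset.
    u: membrane potentials; x: summed input currents (W s + b)."""
    s = (u >= v_th).astype(float)        # s[t] = H(u[t] - V_th)
    u_next = lam * (u - v_th * s) + x    # leak, subtract-reset, integrate
    s_next = (u_next >= v_th).astype(float)
    return u_next, s_next

# Drive one neuron with a constant current and record its spike train.
u, spikes = np.zeros(1), []
for t in range(20):
    u, s = lif_step(u, x=np.array([0.5]))
    spikes.append(int(s[0]))
```

With this sub-threshold input the neuron fires roughly every other step, since each spike subtracts $V_{\text{th}}$ from the accumulated potential rather than resetting it to zero.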
3.2 Previous SNN Training Methods
Spike Representation.
The (weighted) firing rate or first spiking time of spiking neurons can be
proved to follow ANN-like closed-form transformations [7, 26, 9, 12, 27]. We focus on the weighted
firing rate [12, 27], which has a connection with OTTT in this work. Define the weighted firing rates and
weighted average inputs in the discrete condition as
$$a[t] = \frac{\sum_{\tau=1}^{t} \lambda^{t-\tau} s[\tau]}{\sum_{\tau=1}^{t} \lambda^{t-\tau}}, \quad \hat{x}[t] = \frac{\sum_{\tau=0}^{t} \lambda^{t-\tau} x[\tau]}{\sum_{\tau=0}^{t} \lambda^{t-\tau}}.$$
Given convergent weighted average inputs $\hat{x}[t] \to x^*$, it can be proved that $a[t] \to a^* = \sigma\left(\frac{1}{V_{\text{th}}} x^*\right)$ with
bounded random error, where $\sigma$ is a clamp function ($\sigma(x) = \min(\max(0, x), 1)$) in the discrete
condition and a ReLU function in the continuous condition. For feedforward networks, the closed-form
mapping between successive layers is established based on the weighted firing rate after time $T$:
$a^{l+1}[T] \approx \sigma\left(\frac{1}{V_{\text{th}}} \left(W^l a^l[T] + b^{l+1}\right)\right)$, and gradients are calculated with such spike representation:
$$\frac{\partial L}{\partial W^l} = \frac{\partial L}{\partial a^N[T]} \prod_{i=N-1}^{l+1} \frac{\partial a^{i+1}[T]}{\partial a^i[T]} \frac{\partial a^{l+1}[T]}{\partial W^l}.$$
For the recurrent condition, $a[t]$ will converge to an
equilibrium state following an implicit fixed-point equation, e.g. $a^* = \sigma\left(\frac{1}{V_{\text{th}}} \left(W a^* + F x^* + b\right)\right)$
for a single-layer network with input connections $F$ and contractive recurrent connections $W$,
and gradients can be calculated based on implicit differentiation [12]. Let $a^* = f_\theta(a^*)$ denote the
fixed-point equation ($\theta$ are parameters). We have $\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial a[T]} \left(I - J_{f_\theta}\big|_{a[T]}\right)^{-1} \frac{\partial f_\theta(a[T])}{\partial \theta}$, where
$J_{f_\theta}\big|_{a[T]} = \frac{\partial f_\theta(a[T])}{\partial a[T]}$ is the Jacobian of $f_\theta$ at $a[T]$.
BPTT with SG.
BPTT unfolds the iterative update equation in Eq. (2) and backpropagates along
the computational graph as shown in Fig. 1(a), 1(c). The gradients with $T$ time steps are calculated
by²:
$$\frac{\partial L}{\partial W^l} = \sum_{t=1}^{T} \frac{\partial L}{\partial s^{l+1}[t]} \frac{\partial s^{l+1}[t]}{\partial u^{l+1}[t]} \left( \frac{\partial u^{l+1}[t]}{\partial W^l} + \sum_{\tau < t} \prod_{i=t-1}^{\tau} \left( \frac{\partial u^{l+1}[i+1]}{\partial u^{l+1}[i]} + \frac{\partial u^{l+1}[i+1]}{\partial s^{l+1}[i]} \frac{\partial s^{l+1}[i]}{\partial u^{l+1}[i]} \right) \frac{\partial u^{l+1}[\tau]}{\partial W^l} \right), \tag{3}$$
where $W^l$ is the weight from layer $l$ to $l+1$ and $L$ is the loss. The non-differentiable terms $\frac{\partial s^l[t]}{\partial u^l[t]}$
will be replaced by surrogate derivatives, e.g. derivatives of rectangular or sigmoid functions [4]:
$$\frac{\partial s}{\partial u} = \frac{1}{a_1} \operatorname{sign}\left( |u - V_{\text{th}}| < \frac{a_1}{2} \right) \quad \text{or} \quad \frac{\partial s}{\partial u} = \frac{1}{a_2} \frac{e^{(V_{\text{th}} - u)/a_2}}{\left( 1 + e^{(V_{\text{th}} - u)/a_2} \right)^2},$$
where $a_1$ and $a_2$ are hyperparameters.
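The two surrogate derivatives above translate directly into code; the default values for the hyperparameters $a_1$ and $a_2$ below are illustrative choices, not taken from the paper.

```python
import numpy as np

def rect_surrogate(u, v_th=1.0, a1=1.0):
    """Rectangular surrogate: (1/a1) * 1{|u - V_th| < a1/2}."""
    return (np.abs(u - v_th) < a1 / 2).astype(float) / a1

def sigmoid_surrogate(u, v_th=1.0, a2=0.25):
    """Sigmoid surrogate: (1/a2) * e^((V_th - u)/a2) / (1 + e^((V_th - u)/a2))^2,
    the derivative of the soft step 1 / (1 + exp((V_th - u)/a2))."""
    e = np.exp((v_th - u) / a2)
    return e / (a2 * (1.0 + e) ** 2)
```

Both peak at $u = V_{\text{th}}$ (value $1/a_1$ and $1/(4 a_2)$, respectively) and vanish far from the threshold, which is why gradients can flow through the otherwise zero-derivative Heaviside spike function.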
4 Online Training Through Time for SNNs
This section contains four sub-sections. In Section 4.1, we introduce our proposed OTTT by
decoupling the temporal dependency of BPTT. Then in Section 4.2, we further connect the gradients
of OTTT and spike representation-based methods, and prove that OTTT can provide a descent
direction for optimization, which is not guaranteed by BPTT with SG. In Section 4.3, we discuss the
relationship between OTTT and the three-factor Hebbian learning rule. Implementation details are
presented in Section 4.4.
4.1 Derivation of Online Training Through Time
Decouple temporal dependency.
As shown in Fig. 1(c), BPTT has to maintain the computational
graph of previous time steps to backpropagate through time. We will decouple such temporal dependency
to enable online gradient calculation, as illustrated in Fig. 1(d).
We first focus on the feedforward network condition. In this setting, all temporal dependencies lie in
the dynamics of each spiking neuron, i.e. $\frac{\partial u^{l+1}[i+1]}{\partial u^{l+1}[i]}$ and $\frac{\partial u^{l+1}[i+1]}{\partial s^{l+1}[i]} \frac{\partial s^{l+1}[i]}{\partial u^{l+1}[i]}$ in Eq. (3). We consider
the case that we do not apply surrogate derivatives to $\frac{\partial s^{l+1}[i]}{\partial u^{l+1}[i]}$ in such temporal dependency. Since the
derivative of the Heaviside step function is 0 almost everywhere, we have $\frac{\partial u^{l+1}[i+1]}{\partial s^{l+1}[i]} \frac{\partial s^{l+1}[i]}{\partial u^{l+1}[i]} \equiv 0$³.
² Note that we follow the numerator layout convention for derivatives, i.e. $\nabla_\theta L = \left( \frac{\partial L}{\partial \theta} \right)^\top$ is the gradient
with the same dimension as $\theta$.
³ Note that this is consistent with some released implementations of BPTT with SG methods, which detach the
neuron reset operation from the computational graph and do not backpropagate gradients along this path [23, 11].
[Figure 1 graphic: panels (a) BPTT Forward, (b) OTTT Forward, (c) BPTT Backward, (d) OTTT Backward.]
Figure 1: Illustration of the forward and backward procedures of BPTT and OTTT.
Then the dependency only includes $\frac{\partial u^{l+1}[i+1]}{\partial u^{l+1}[i]}$, which equals $\lambda I$. Therefore, we have²:
$$\frac{\partial L}{\partial W^l} = \sum_{t=1}^{T} \frac{\partial L}{\partial s^{l+1}[t]} \frac{\partial s^{l+1}[t]}{\partial u^{l+1}[t]} \left( \sum_{\tau \le t} \lambda^{t-\tau} \frac{\partial u^{l+1}[\tau]}{\partial W^l} \right), \quad \nabla_{W^l} L = \sum_{t=1}^{T} g_{u^{l+1}}[t] \left( \sum_{\tau \le t} \lambda^{t-\tau} s^l[\tau] \right)^\top, \tag{4}$$
where $g_{u^{l+1}}[t] = \left( \frac{\partial L}{\partial s^{l+1}[t]} \frac{\partial s^{l+1}[t]}{\partial u^{l+1}[t]} \right)^\top$ is the gradient for $u^{l+1}[t]$. Based on Eq. (4), we can track
the presynaptic activities $\hat{a}^l[t] = \sum_{\tau \le t} \lambda^{t-\tau} s^l[\tau]$ for each neuron during the forward procedure by
$\hat{a}^l[t+1] = \lambda \hat{a}^l[t] + s^l[t+1]$, so that when given $g_{u^{l+1}}[t]$, gradients at each time step can be
calculated independently by $\nabla_{W^l} L[t] = g_{u^{l+1}}[t] \, \hat{a}^l[t]^\top$ without backpropagation through $\frac{\partial u^{l+1}[i+1]}{\partial u^{l+1}[i]}$.
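The online rule above can be sketched as follows. The final check verifies that the recursive trace update exactly reproduces the explicit double sum in Eq. (4); the spike trains and error signals here are random placeholders, since in the real method $g_{u^{l+1}}[t]$ comes from the instantaneous loss.

```python
import numpy as np

rng = np.random.default_rng(0)
lam, n_pre, n_post, T = 0.9, 4, 3, 10
s_hist = (rng.random((T, n_pre)) < 0.3).astype(float)   # toy presynaptic spikes s^l[t]
g_hist = rng.standard_normal((T, n_post))                # placeholder g_{u^{l+1}}[t]

# Online form: keep only a running trace, O(1) memory in T.
a_hat, grad_W = np.zeros(n_pre), np.zeros((n_post, n_pre))
for t in range(T):
    a_hat = lam * a_hat + s_hist[t]          # trace update: a_hat[t] = lam a_hat[t-1] + s[t]
    grad_W += np.outer(g_hist[t], a_hat)     # per-step gradient g_u[t] a_hat[t]^T

# Eq. (4) form: explicit double sum over time, for comparison.
grad_ref = sum(np.outer(g_hist[t],
                        sum(lam ** (t - tau) * s_hist[tau] for tau in range(t + 1)))
               for t in range(T))
assert np.allclose(grad_W, grad_ref)
```

Only `a_hat` and the accumulated gradient persist across steps, which is why OTTT's training memory is constant in the number of time steps.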
As for the recurrent network condition, there are additional temporal dependencies due to the feedback
connections between neurons. If there is a feedback connection from layer $l_2$ to $l_1$ ($l_2 \ge l_1$), there
would be terms such as $\frac{\partial u^{l_1}[i+1]}{\partial s^{l_2}[i]} \frac{\partial s^{l_2}[i]}{\partial u^{l_2}[i]} \frac{\partial u^{l_2}[i]}{\partial u^{l_1}[i]}$ in the calculation of gradients (note that Eq. (3) omits
feedback connections for simplicity). We also consider not applying surrogate derivatives to $\frac{\partial s^{l_2}[i]}{\partial u^{l_2}[i]}$ in
the temporal dependency, so that gradients are not calculated along this path. Similar to the feedforward
condition, we can derive that the gradients of a general weight $W^{l_i l_j}$ from any layer $l_i$ to any layer
$l_j$ can be calculated by $\nabla_{W^{l_i l_j}} L[t] = g_{u^{l_j}}[t] \, \hat{a}^{l_i}[t]^\top$ at each time step. A theoretical explanation
for optimization will be presented in Section 4.2.
Instantaneous Loss and Gradient.
Calculating online gradients, e.g. the above $g_{u^{l+1}}[t]$ for
$\nabla_{W^l} L[t]$, requires instantaneous computation of the loss at each time step. A previous typical loss for
SNNs is based on the firing rate, e.g. $L_{fr} = \mathcal{L}\left( \frac{1}{T} \sum_{t=1}^{T} s^N[t], \, y \right)$, where $y$ is the label, $s^N[t]$ is the
spike at the last layer, and $\mathcal{L}$ can take the cross-entropy loss. This loss depends on all time steps and does
not support online gradients. We leverage the instantaneous loss and calculate the above $g_{u^{l+1}}[t]$ as:
$$L[t] = \frac{1}{T} \mathcal{L}\left( s^N[t], y \right), \quad g_{u^{l+1}}[t] = \left( \frac{\partial L[t]}{\partial s^N[t]} \prod_{i=N-1}^{l+1} \frac{\partial s^{i+1}[t]}{\partial s^i[t]} \frac{\partial s^{l+1}[t]}{\partial u^{l+1}[t]} \right)^\top. \tag{5}$$
The total loss $L := \sum_{t=1}^{T} L[t]$ is an upper bound of $L_{fr}$ when $\mathcal{L}$ is a convex function such as
cross-entropy. Then gradients can be calculated independently at each time step, as shown in Fig. 1(d).
We apply surrogate derivatives for $\frac{\partial s^l[t]}{\partial u^l[t]}$ in this calculation, which will be explained in Section 4.2.
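At the output layer, the instantaneous loss and its gradient can be computed online at each step. The sketch below applies softmax cross-entropy directly to the binary output spike vector; treating the spikes as logits is one plausible instantiation of $\mathcal{L}(s^N[t], y)$ for illustration, not necessarily the paper's exact readout.

```python
import numpy as np

def instantaneous_loss_grad(s_out, y, T):
    """L[t] = (1/T) * CE(softmax(s_out), y); returns (L[t], dL[t]/ds_out)."""
    z = s_out - s_out.max()                  # stabilized logits (spikes as logits)
    p = np.exp(z) / np.exp(z).sum()          # softmax over output neurons
    loss = -np.log(p[y]) / T
    grad = p.copy()
    grad[y] -= 1.0                           # d CE / d logits = p - onehot(y)
    return loss, grad / T

# Accumulate the total loss L = sum_t L[t] online over T steps (toy spikes).
rng = np.random.default_rng(1)
T, total = 8, 0.0
for t in range(T):
    s_out = (rng.random(5) < 0.4).astype(float)   # toy output spikes s^N[t]
    loss_t, g_t = instantaneous_loss_grad(s_out, y=2, T=T)
    total += loss_t
```

Each `g_t` is exactly the $\frac{\partial L[t]}{\partial s^N[t]}$ term that seeds $g_{u^{l+1}}[t]$ in Eq. (5), available the moment step $t$ finishes.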
Since gradients are calculated instantaneously at each time step, OTTT does not require maintaining
the unfolded computational graph and only requires constant training memory costs agnostic to the number of time
steps. Note that the instantaneous gradients of OTTT differ from those of BPTT with the instantaneous
loss for multi-layer or recurrent networks, as we do not consider future influence in the instantaneous
calculation: BPTT considers terms such as $\frac{\partial L[t']}{\partial u^N[t']} \frac{\partial u^N[t']}{\partial u^N[t]} \prod_{i=N-1}^{l} \frac{\partial u^{i+1}[t]}{\partial u^i[t]}$ ($t' > t$) for $u^l[t]$, while
we do not. The equivalence of OTTT and BPTT only holds for the last layer, and we do not seek
exact equivalence to BPTT with SG, which is theoretically unclear, but will build the connection
with spike representations and prove the descent guarantee. Also, note that the tracked presynaptic
activities are similar to the biologically plausible “eligibility traces” in the literature [46, 47, 52], and
we will provide a more solid theoretical grounding for optimization in Section 4.2.