Exact Gradient Computation for Spiking Neural Networks
Through Forward Propagation
Jane H. Lee1, Saeid Haghighatshoar2and Amin Karbasi3
1Department of Computer Science, Yale University
2SynSense, Zurich, Switzerland
3Department of Electrical Engineering, Yale University
March 13, 2023
Abstract
Spiking neural networks (SNN) have recently emerged as alternatives to traditional neural networks,
owing to energy efficiency benefits and capacity to better capture biological neuronal mechanisms. However,
the classic backpropagation algorithm for training traditional networks has been notoriously difficult
to apply to SNN due to the hard-thresholding and discontinuities at spike times. Therefore, a large
majority of prior work believes exact gradients for SNN w.r.t. their weights do not exist and has focused
on approximation methods to produce surrogate gradients. In this paper, (1) by applying the implicit
function theorem to SNN at the discrete spike times, we prove that, albeit being non-differentiable in time,
SNNs have well-defined gradients w.r.t. their weights, and (2) we propose a novel training algorithm, called
forward propagation (FP), that computes exact gradients for SNN. FP exploits the causality structure
between the spikes and allows us to parallelize computation forward in time. It can be used with other
algorithms that simulate the forward pass, and it also provides insight into why related algorithms
such as Hebbian learning and recently proposed surrogate gradient methods may perform well.
Keywords: spiking neural networks · exact gradients · neuromorphic computation
1 Introduction
While artificial neural networks have achieved state-of-the-art performance on various tasks, such as in natural language processing or computer vision, these networks are usually large, complex, and their computation consumes a lot of energy. Spiking neural networks (SNNs), inspired by biological neuronal mechanisms and sometimes referred to as the third generation of neural networks [34], have garnered considerable attention recently [42, 39, 7, 9, 12] as low-power alternatives. For instance, SNNs have been shown to yield 1-2 orders of magnitude energy saving over ANNs on emerging neuromorphic hardware [1, 10]. SNNs have other unique properties, owing to their ability to model biological mechanisms such as dendritic computations with temporally evolving potentials [20] or short-term plasticity, which allow them to even outperform ANNs in accuracy in some tasks [35]. The power of neuromorphic computing can even be seen in ANNs, e.g., [24] use rank-coding in ANNs inspired by the temporal encoding of information in SNNs. However, due to the discontinuous resetting of the membrane potential in spiking neurons, e.g., in Integrate-and-Fire (IF) or Leaky-Integrate-and-Fire (LIF) type neurons [6, 27], it is notoriously difficult to calculate gradients and train SNNs by conventional methods. For instance, [24] use the fact that "spike coding poses difficulties...and training that require ad hoc mitigation" and that "SNNs are particularly difficult to analyse mathematically" to motivate rank-coding for ANNs.
As such, many existing works on training SNNs do so without exact gradients, ranging from heuristic rules like Hebbian learning [26, 44] and STDP [31, 33] to SNN-ANN conversion [43, 13, 22] and surrogate gradient approximations [37].
In this work, by applying the implicit function theorem (IFT) at the firing times of the neurons in SNN,
we first show that under fairly general conditions, gradients of loss w.r.t. network weights are well-defined.
We do this by proving that the conditions for IFT are always satisfied at firing times. We then provide what
we call a forward-propagation (FP) algorithm which uses the causality structure in network firing times and
our IFT-based gradient calculations in order to calculate exact gradients of the loss w.r.t. network weights.
We call it forward propagation because intermediate calculations needed to calculate the final gradient are
actually done forward in time (or forward in layers for feed-forward networks). We highlight the following
features of our method:
• Our method can be applied in networks with arbitrary recurrent connections (up to self loops) and is agnostic to how the forward pass is implemented. We provide an implementation for computing the firing times in the forward pass, but as long as we can obtain accurate firing times and causality information (for instance, using existing libraries), we can calculate gradients.
• Our method can be seen as an extension of Hebbian learning, as it illustrates that the gradient w.r.t. a weight $W_{ji}$ connecting neuron $j$ to neuron $i$ is almost an average of the feeding kernel $y_{ji}$ between these neurons at the firing times. In the context of Hebbian learning (especially from a biological perspective), this is interpreted as the well-known fact that stronger feeding/activation amplifies the association between the neurons [8, 19].
• In our method, the smoothing kernels $y_{ji}$ arise naturally from applying the IFT at the firing times, resembling the smoothing kernels applied in surrogate gradient methods. As a result, (1) our method sheds some light on why surrogate gradient methods may work quite well, and (2) in our method, the smoothing kernels $y_{ji}$ vary according to the firing times between two neurons; thus, they can be seen as an adaptive version of the fixed smoothing kernels used in surrogate gradient methods.
• Most of the methods in the literature apply a time-quantized version of the neuron dynamics and convert the continuous-time system into a discrete-time system. While we derive results in the continuous-time regime, our IFT formulation is also applicable in these discrete-time scenarios. To do so, one needs to treat the weight parameters and all the time-quantized versions of the variables (such as the synaptic and membrane potential, etc.) as separate variables; a minimal discrete-time sketch is given after this list. The number of these state variables, however, grows proportionally to the simulation time and the precision of the time quantization, which is why the continuous-time regime is preferred.
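As a concrete illustration of this time-quantized view, the following minimal Python sketch (our own illustration, not the paper's implementation; the parameter names and the Euler discretization are assumptions) unrolls one step of a discrete-time LIF neuron, where every time step introduces its own copy of the synaptic and membrane state variables.

```python
import numpy as np

def lif_step(V, s, w_spikes_in, alpha, beta, theta, dt):
    """One Euler step of a time-quantized LIF neuron (illustrative sketch).

    V, s        : membrane potential and synaptic state from the previous step
    w_spikes_in : weighted sum of input spikes arriving in this step, sum_j W_ji * spike_j
    alpha, beta : inverse synaptic and membrane time constants (1/tau_s, 1/tau_n)
    theta       : firing threshold
    """
    s = (1.0 - dt * alpha) * s + w_spikes_in   # synaptic low-pass state
    V = (1.0 - dt * beta) * V + dt * s         # membrane low-pass state
    spike_out = float(V >= theta)              # hard-thresholding
    V -= spike_out * theta                     # immediate reset/drop by theta
    return V, s, spike_out

# Unrolling lif_step over T time steps creates T separate copies of (V, s),
# which is exactly the growth in state variables discussed above.
```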
1.1 Related Work
A review of learning in deep spiking networks can be found in [48, 40, 42, 49], with [42] also discussing developments in neuromorphic computing in both software (algorithms) and hardware. [37] focuses on surrogate gradient methods, which use smooth activation functions in place of the hard-thresholding for compatibility with the usual backpropagation and have been used to train SNNs in a variety of settings [16, 3, 23, 51, 47, 45].
A number of works explore backpropagation in SNNs [5, 25, 52]. The SpikeProp [5] framework assumes a linear relationship between the post-synaptic input and the resultant spiking time, which our framework does not rely on. The method in [25] and its RSNN version [52] are limited to a rate-coded loss that depends on spike counts. The continuous "spike time" representation of spikes in our framework is related to temporal coding [36], but the authors of [36], in the context of differentiating losses, largely ignore the discontinuities that occur at spike times, stating "the derivative...is discontinuous at such points [but] many feedforward ANNs use activation functions with a discontinuous first derivative". In contrast with [36], we prove that exact gradients can be calculated despite this discontinuity.
As mentioned in [50], applying methods from optimal control theory to compute exact gradients in hard-threshold spiking neural networks has been recognized before [46, 30, 29]. However, unlike our setting, these works consider a neuron with a two-sided threshold and provide specialized algorithms for specific loss functions. Most related to our work is the recent EventProp [50], which derives an algorithm for a continuous-time spiking neural network by applying the adjoint method (which can be seen as generalized backpropagation) together with the proper partial-derivative jumps. EventProp calculates the gradients by accumulating adjoint variables while computing adjoint state trajectories via simulating another continuous-time dynamical system with transition jumps in a backward pass, whereas our algorithm computes gradients with just firing-time and causality information. In particular, the only time we need to simulate continuous-time dynamics is in the forward pass.
2 Spiking Neural Networks
In this section, we first describe the precise models we use throughout the paper for the pre-synaptic and post-synaptic behaviors of spiking neurons. We then explain the dynamics of an SNN and the effect of spike generation.
2.1 Pre-Synaptic Model
For ease of presentation, a generic structure of an SNN is illustrated in Fig. 1 on the left. There are many different models for simulating the nonlinear dynamics of a spiking neuron (e.g., see [19]). In this paper, we adopt the Leaky-Integrate-and-Fire (LIF) model, which consists of three main steps.
2.1.1 Synaptic Dynamics
A generic neuron $i$ is stimulated through a collection of input neurons, its neighborhood $\mathcal{N}_i$. Each neuron $j \in \mathcal{N}_i$ has a synaptic connection to $i$ whose dynamics is modelled by a 1st-order low-pass RC circuit that smooths out the Dirac delta currents it receives from neuron $j$. Since this system is linear and time-invariant (LTI), it can be described by its impulse response

$$h^s_j(t) = e^{-\alpha_j t} u(t),$$

where $\alpha_j = \frac{1}{\tau^s_j}$ and $\tau^s_j = R^s_j C^s_j$ denotes the synaptic time constant of neuron $j$, and $u(t)$ denotes the Heaviside step function. Therefore, the output synaptic current $I_j(t)$ can be written as

$$I_j(t) = h^s_j(t) \star \sum_{f \in \mathcal{F}_j} \delta(t - f) = \sum_{f \in \mathcal{F}_j} h^s_j(t - f), \qquad (1)$$

where $\mathcal{F}_j$ is the set of output firing times of neuron $j$. Note that in Eq. (1) we used the fact that convolution with a Dirac delta function, $h^s_j(t) \star \delta(t - f) = h^s_j(t - f)$, is equivalent to a shift in time.
2.1.2 Neuron Dynamics
The synaptic currents of all stimulating neurons are weighted by $W_{ji}$, $j \in \mathcal{N}_i$, and build the weighted current that feeds the neuron. The dynamics of the neuron can be described by yet another 1st-order low-pass RC circuit with a time constant $\tau^n_i = R^n_i C^n_i$ and an impulse response $h^n_i(t) = e^{-\beta_i t} u(t)$, where $\beta_i = \frac{1}{\tau^n_i}$. The output of this system is the membrane potential $V_i(t)$.
2.1.3 Hard-thresholding and spike generation
The membrane potential $V_i(t)$ is compared with the firing threshold $\theta_i$ of neuron $i$, and a spike (a delta current) is produced by the neuron when $V_i(t)$ goes above $\theta_i$. Also, after spike generation, the membrane potential is immediately reset/dropped by $\theta_i$ (reset to zero).
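Putting the three steps together, here is a minimal Euler-discretized forward simulation sketch (our own illustration, not the paper's reference implementation; `dt`, `tau_n`, and `theta` are assumed parameters) that filters the weighted synaptic current through the membrane RC circuit and applies the threshold-and-reset rule.

```python
import numpy as np

def simulate_lif(I_weighted, dt, tau_n, theta):
    """Forward-simulate V_i(t) with hard threshold theta and reset by theta.

    I_weighted : array with the weighted input current sum_j W_ji * I_j(t) on a time grid
    Returns the membrane trace and the step indices at which spikes occurred.
    """
    V, trace, spike_idx = 0.0, [], []
    for k, I_k in enumerate(I_weighted):
        V += dt * (-V / tau_n + I_k)   # leaky integration of the input current
        if V >= theta:                 # hard-thresholding
            spike_idx.append(k)
            V -= theta                 # immediate reset/drop by theta
        trace.append(V)
    return np.array(trace), spike_idx
```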
2.2 Post-Synaptic Kernel Model
We call the model illustrated on the left of Fig. 1 the pre-synaptic model, as the spiking dynamics of the stimulating neurons $\mathcal{N}_i$ of a generic neuron $i$ appear before the synapse. In this paper, we will work with a modified but equivalent model in which we combine the synaptic and neuron dynamics and consider the effect of the spiking dynamics of $\mathcal{N}_i$ directly on the membrane potential, after it has been smoothed out by the synapse and neuron low-pass filters. We call this model the post-synaptic or kernel model of the SNN.
To derive this model, we simply use the fact that the only source of non-linearity in an SNN is the hard-thresholding during spike generation. In particular, the SNN dynamics from a stimulating neuron $j \in \mathcal{N}_i$ up to the membrane potential $V_i(t)$ is completely linear and can be described by the joint impulse response

$$h_{ji}(t) = h^s_j(t) \star h^n_i(t) = \int_{-\infty}^{\infty} h^s_j(\tau)\, h^n_i(t - \tau)\, d\tau = \int_0^t e^{-\alpha_j \tau} e^{-\beta_i (t - \tau)}\, d\tau = \frac{e^{-\alpha_j t} - e^{-\beta_i t}}{\beta_i - \alpha_j}\, u(t). \qquad (2)$$

Therefore, the whole effect of the spikes $\mathcal{F}_j$ of neuron $j \in \mathcal{N}_i$ on the membrane potential can be written in terms of the kernel

$$y_{ji}(t) = \sum_{f \in \mathcal{F}_j} h_{ji}(t - f).$$
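A direct way to evaluate this kernel numerically is through the closed form in Eq. (2); the sketch below (our own helper functions, assuming $\alpha_j \neq \beta_i$) computes $h_{ji}(t)$ and the summed kernel $y_{ji}(t)$ for a given list of firing times.

```python
import numpy as np

def h_ji(t, alpha_j, beta_i):
    """Joint impulse response of Eq. (2); assumes alpha_j != beta_i."""
    t = np.asarray(t, dtype=float)
    return (np.exp(-alpha_j * t) - np.exp(-beta_i * t)) / (beta_i - alpha_j) * (t >= 0)

def y_ji(t, firing_times, alpha_j, beta_i):
    """Smoothed kernel y_ji(t) = sum_{f in F_j} h_ji(t - f)."""
    t = np.asarray(t, dtype=float)
    return sum(h_ji(t - f, alpha_j, beta_i) for f in firing_times)
```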
We call this model post-synaptic since the effect of the dynamics of neuron $j \in \mathcal{N}_i$ on $V_i(t)$ is considered after being processed by the synapse and even the neuron $i$. Using linearity and applying superposition for linear systems, we can see that the effect of all spikes coming from all stimulating neurons $\mathcal{N}_i$ can be written as

$$V^*_i(t) = \sum_{j \in \mathcal{N}_i} W_{ji}\, y_{ji}(t), \qquad (3)$$

where $W_{ji}$ is the weight from neuron $j$ to $i$. We use $V^*_i(t)$ to denote the contribution to the membrane potential $V_i(t)$ obtained after neglecting the potential reset due to hard-thresholding and spike generation. Fig. 1 (right) illustrates the post-synaptic model of the SNN.
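Given kernel traces $y_{ji}(t)$ on a common time grid (computed, for instance, with a routine like the `y_ji` sketch above), Eq. (3) is just a weighted sum; a minimal sketch with assumed array shapes:

```python
import numpy as np

def v_star(W_i, Y):
    """Reset-free membrane contribution V*_i(t) = sum_j W_ji * y_ji(t), cf. Eq. (3).

    W_i : weights W_ji for all j in N_i, shape (J,)
    Y   : kernel traces y_ji(t) sampled on a common time grid, shape (J, T)
    """
    return np.asarray(W_i) @ np.asarray(Y)   # shape (T,)
```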
Remark 1. Our main motivation for using this equivalent model comes from the fact that even though the spikes are not differentiable functions, the effect of each stimulating neuron $j \in \mathcal{N}_i$ on neuron $i$ is written as a well-defined and (almost everywhere) differentiable kernel $y_{ji}(\cdot)$.
Remark 2 (Connection with surrogate gradients). Intuitively speaking, and as we will show rigorously in the following sections, the kernel model derived here immediately shows that SNNs have an intrinsic smoothing mechanism for their abrupt spiking inputs, through the low-pass impulse response $h_{ji}(t)$ between their neurons. As a result, one does not need the additional artificial smoothing that surrogate gradient methods introduce by modifying the neuron model in the backward gradient computation path. We will use this inherent smoothing to prove that SNNs indeed have well-defined gradients. Interestingly, our derivation of the exact gradient based on this inherent smoothing property intuitively explains why surrogate gradients, even though they are not exact, may be close to the exact gradients and yield similar training performance.
2.3 SNN Full Dynamics
In the post-synaptic kernel model, we already specified the effect of spikes from the stimulating neurons as in (3). To have a full picture of the SNN dynamics, we also need to specify the effect of spike generation. The following theorem completes this picture.
Figure 1: (Left) A generic structure of a spiking neural network: (i) spikes (trains of Dirac delta currents) $\mathcal{F}_j$ coming from a generic input neuron $j$ pass through the synaptic RC circuit with time constant $\tau^s_j = R^s_j C^s_j$ and build the synaptic current $I_j(t)$; (ii) the synaptic currents $I_j(t)$ are weighted by $W_{ji}$ and build the input current $\sum_j W_{ji} I_j(t)$; (iii) this current is filtered through neuron $i$, an RC circuit with time constant $\tau^n_i = R^n_i C^n_i$, and produces the membrane potential $V_i(t)$; (iv) the membrane potential $V_i(t)$ is compared with the threshold $\theta_i$, and a current spike is produced when it passes above $\theta_i$; then (v) the membrane potential is reset/dropped by $\theta_i$ immediately after spike generation. (Right) Post-synaptic kernel model of the SNN. In this model, neuron $j \in \mathcal{N}_i$ stimulates neuron $i$ through the smooth kernel $y_{ji}(t) = \sum_{g \in \mathcal{F}_j} h_{ji}(t - g)$ rather than the abrupt spiking signal $\sum_{g \in \mathcal{F}_j} \delta(t - g)$ adopted in the pre-synaptic model.
Theorem 1. Let $i$ be a generic neuron in the SNN and let $\mathcal{N}_i$ be the set of its stimulating neurons. Let $h^n_i(t)$ and $h^s_j(t)$ be the impulse responses of the neuron $i$ and the synapse $j \in \mathcal{N}_i$, respectively, and let $h_{ji}(t) = h^n_i(t) \star h^s_j(t)$. Then the membrane potential of neuron $i$ for all times $t$ is given by

$$V_i(t) = V^*_i(t) - \sum_{f \in \mathcal{F}_i} \theta_i\, h^n_i(t - f), \qquad (4)$$

where $V^*_i(t) = \sum_{j \in \mathcal{N}_i} W_{ji}\, y_{ji}(t)$ is given by (3), $y_{ji}(t) = \sum_{g \in \mathcal{F}_j} h_{ji}(t - g)$ denotes the smoothed kernel between neuron $i$ and $j \in \mathcal{N}_i$, and $\theta_i$ denotes the spike generation threshold of neuron $i$.
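Before the proof, a quick numerical sanity check can make Eq. (4) concrete: the self-contained Python sketch below (our own illustration; all parameter values, spike times, and names are assumptions, and the Euler step introduces a small discretization error) simulates the pre-synaptic model with explicit reset and compares it against the closed-form kernel expression.

```python
import numpy as np

# Illustrative parameters and input spike train (not taken from the paper).
dt, T = 1e-5, 0.06                                 # 10 us step, 60 ms horizon
tau_s, tau_n, theta, W = 5e-3, 10e-3, 0.004, 2.0   # one input neuron with weight W
alpha, beta = 1.0 / tau_s, 1.0 / tau_n
t = np.arange(0.0, T, dt)
input_spikes = [5e-3, 8e-3, 20e-3, 22e-3, 40e-3]

def h_s(x):    return np.exp(-alpha * x) * (x >= 0)   # synapse impulse response
def h_n(x):    return np.exp(-beta * x) * (x >= 0)    # neuron impulse response
def h_join(x):                                        # joint impulse response, Eq. (2)
    return (np.exp(-alpha * x) - np.exp(-beta * x)) / (beta - alpha) * (x >= 0)

# (a) Pre-synaptic model: Euler simulation with hard threshold and reset.
I = W * sum(h_s(t - f) for f in input_spikes)         # weighted synaptic current
V_sim, firing_times, V = np.zeros_like(t), [], 0.0
for k in range(t.size):
    V += dt * (-beta * V + I[k])                      # membrane RC dynamics
    if V >= theta:
        firing_times.append(t[k])
        V -= theta                                    # reset/drop by theta
    V_sim[k] = V

# (b) Post-synaptic kernel model, Eq. (4), using the firing times observed in (a).
V_star = W * sum(h_join(t - f) for f in input_spikes)
V_closed = V_star - theta * sum(h_n(t - f) for f in firing_times)

print("observed firing times:", firing_times)
print("max |V_sim - V_closed|:", np.abs(V_sim - V_closed).max())  # small (O(dt) Euler error)
```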
Proof. In the following, we provide a simple and intuitive proof. An alternative and more rigorous proof, by induction on the number of firing times of neuron $i$, is provided in Appendix 7.1.
Proof (i): We use the following simple result/computation trick from circuit theory: in an RC circuit, abruptly dropping the potential of the capacitor by $\theta_i$ at a specific firing time $f \in \mathcal{F}_i$ can be mimicked by adding a voltage source $-\theta_i u(t - f)$ in series with the capacitor. If we do this for all the firing times of the neuron, we obtain a linear RC circuit with two inputs: (i) the weighted synaptic current coming from the neurons $\mathcal{N}_i$, and (ii) the voltage sources $\{-\theta_i u(t - f) : f \in \mathcal{F}_i\}$. This is illustrated in Fig. 2.
The key observation is that although this new circuit is obtained after running the dynamics of the neuron and observing its firing times $\mathcal{F}_i$, as far as the membrane potential $V_i(t)$ is concerned, the two circuits are equivalent. Interestingly, after this modification, the new circuit is a completely linear circuit, and we can apply the superposition principle for linear circuits to write the response of the neuron as the summation of: (i) the response $V^{(1)}_i(t)$ due to the weighted synaptic current $I^s_i(t)$ in the input (as in the previous circuit), and (ii) the response $V^{(2)}_i(t)$ due to the Heaviside voltage sources $\{-\theta_i u(t - f) : f \in \mathcal{F}_i\}$. From (3), $V^{(1)}_i(t)$ is simply given by

$$V^{(1)}_i(t) = \sum_{j \in \mathcal{N}_i} W_{ji}\, y_{ji}(t).$$
The response of an RC circuit to a Heaviside voltage source $-\theta_i u(t - f)$ is given by $-\theta_i h^n_i(t - f)$, where $h^n_i(t)$ is the impulse response of neuron $i$ as before. We also used the time-invariance property (for shift