Join-Chain Network: A Logical Reasoning View of
the Multi-head Attention in Transformer
Jianyi Zhang
Duke University
jianyi.zhang@duke.edu
Yiran Chen
Duke University
yiran.chen@duke.edu
Jianshu Chen
Tencent AI Lab
jianshuchen@tencent.com
Abstract—Developing neural architectures that are capable
of logical reasoning has become increasingly important for a
wide range of applications (e.g., natural language processing).
Towards this grand objective, we propose a symbolic reasoning
architecture that chains many join operators together to model
output logical expressions. In particular, we demonstrate that
such an ensemble of join chains can express a broad subset of
“tree-structured” first-order logical expressions, named $\mathcal{FOE}_T$,
which is particularly useful for modeling natural languages. To
endow it with differentiable learning capability, we closely examine various neural operators for approximating the symbolic join-
chains. Interestingly, we find that the widely used multi-head self-
attention module in transformer can be understood as a special
neural operator that implements the union bound of the join
operator in probabilistic predicate space. Our analysis not only
provides a new perspective on the mechanism of the pretrained
models such as BERT for natural language understanding, but
also suggests several important future improvement directions.
Index Terms—Logical reasoning, multi-head attention, NLP
I. INTRODUCTION
Developing logical systems that can naturally process symbolic rules is an important task for AI, since such systems serve as foundational models with wide applications in language understanding and reasoning. Traditional approaches such as inductive logic programming (ILP) [2], [3] can learn logical rules from a collection of positive and negative examples. However, the exponentially large search space of logical rules limits the scalability of traditional ILP. Considering that deep neural networks (DNNs) have achieved great success in many applications such as image classification, machine translation, and speech recognition due to their powerful expressiveness, a natural question is whether we can leverage this expressive power to design the next generation of logical systems. Several previous attempts [4], [5] have been made in this direction. However, most of them are heuristic and lack clear interpretability. Developing interpretable neural architectures that are capable of logical reasoning has therefore become increasingly important.
Another trend in the recent development of AI models is the wide use of the multi-head attention mechanism [6]. Nowadays, multi-head attention has become a critical component of many foundational language and vision models, such as BERT [7] and ViT [8]. Multiple parallel attention heads strengthen
the expressive power of a model, since they capture diverse information from different representation subspaces at different positions, which yields multiple latent features depicting the input data from different perspectives.
In our work, to develop a more interpretable neural architecture for logical reasoning, we identify a key operation in the calculation of a logic predicate, which we name the join operation. Based on this operation, we can convert the calculation of a logic predicate into a process of conducting join operations recursively. We also notice that this process requires skip-connection operations to pass the necessary intermediate outcomes forward. Based on these observations, we design a new framework that contains several kinds of neural operators with different functions to implement this recursive process with neural networks. We adopt the same skip-connection operation as ResNet. Interestingly, for the join operator, which carries out the key part of the calculation, we find a strong connection between its operation and the mechanism of multi-head attention. Hence, we argue that the widely adopted multi-head attention module can be understood as a special neural operator that implements the union bound of the join operator in probabilistic predicate space. This finding not only provides a suitable module for our logical reasoning network, but also suggests a new way to understand the popular multi-head attention module, which helps explain its great success in language understanding. Our findings further suggest several potential directions for improving the transformer [6], which sheds light on the design of future large pretrained language models [7], [8].
II. JOIN-CHAIN NETWORK
We adopt the following example to introduce our method.
$$P(x) = \exists y_1 \exists y_2 \exists y_3 \exists y_4 \; \Big[\bigwedge_{t=1}^{4} P_t(y_t)\Big] \wedge P_0(x) \wedge W_{(0,1)}(x, y_1) \wedge W_{(0,2)}(x, y_2) \wedge W_{(1,3)}(y_1, y_3) \wedge W_{(1,4)}(y_1, y_4). \tag{1}$$

$P(x)$ is a first-order logic predicate. To derive the value of $P(x)$ for a given $x$, we can divide the calculation into the following three steps.
Step 1: We first calculate $P_{3,1}(y_1) = \exists y_3\, W_{(1,3)}(y_1, y_3) \wedge P_3(y_3)$ and $P_{4,1}(y_1) = \exists y_4\, W_{(1,4)}(y_1, y_4) \wedge P_4(y_4)$.
Fig. 1. Skeleton of the join-chain network
Step 2: We define $P_{\mathrm{new},1}(y_1) \triangleq P_1(y_1) \wedge P_{3,1}(y_1) \wedge P_{4,1}(y_1)$. In this step, we calculate $P_{1,0}(x) \triangleq \exists y_1\, W_{(0,1)}(x, y_1) \wedge P_{\mathrm{new},1}(y_1)$ and $P_{2,0}(x) \triangleq \exists y_2\, W_{(0,2)}(x, y_2) \wedge P_2(y_2)$.

Step 3: In the final step, we calculate $P_{1,0}(x) \wedge P_{2,0}(x) \wedge P_0(x)$; it is easy to see that $P(x) = P_{1,0}(x) \wedge P_{2,0}(x) \wedge P_0(x)$.
Based on the above steps, we identify a key operation in the calculation of $P(x)$ and name it the join operation:

$$\text{join operation:}\qquad P'(x) = \exists y\, W(x, y) \wedge P(y).$$
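To make the join operation concrete, here is a minimal sketch (ours, not part of the paper) that evaluates it symbolically over a finite domain $\{x_1, \ldots, x_S\}$, representing the binary predicate $W$ as an $S \times S$ Boolean matrix and the unary predicate $P$ as a length-$S$ Boolean vector; the helper name `join` is hypothetical.

```python
import numpy as np

def join(W: np.ndarray, P: np.ndarray) -> np.ndarray:
    """Symbolic join: P'(x) = EXISTS y. W(x, y) AND P(y).

    W is an (S, S) Boolean matrix with W[s, t] = W(x_{s+1}, x_{t+1});
    P is a length-S Boolean vector with P[t] = P(x_{t+1}).
    Returns the length-S Boolean vector of P'(x_s) values.
    """
    # Conjunction is elementwise AND; the existential quantifier is an OR over y.
    return np.any(W & P[None, :], axis=1)

# Toy domain with S = 3 constants x_1, x_2, x_3.
W = np.array([[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]], dtype=bool)   # W[s, t] = W(x_{s+1}, x_{t+1})
P = np.array([0, 0, 1], dtype=bool)      # only P(x_3) holds
print(join(W, P))                        # [False  True False]: only x_2 has a W-successor in P
```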
We can transform the calculation of $P(x)$ in Eq. (1) into a recursive process of join operations. Hence, it is necessary to include a module in our network that calculates the join operation. Since the join operations are conducted recursively over multiple steps, our network should have multiple layers. Besides, since there are multiple join operations to conduct in each step, we need to include several join-operation modules in each layer. It is also worth noting that some unary predicates, such as $P_1(y_1)$ and $P_2(y_2)$, are not used in Step 1 but are used thereafter in Step 2; hence a skip connection is also necessary for our network. Based on these observations, and to conduct the calculation following the steps above, we design a new network, which we name the join-chain network. We visualize its skeleton in Figure 1.
As shown in Figure 1, the join operators conduct join operations on the inputs of each layer. The skip connection in each layer preserves the inputs for future use. The aggregation step after the join operators combines the inputs carried by the skip connection with the results of the join operators; this reflects the operations we need to perform after the joins in each step. At the end of the framework, we adopt a feed-forward network (FFN), which implements the final step described above.
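As a purely symbolic illustration of the skeleton in Figure 1 (our own sketch, using hard Boolean tensors instead of learned parameters and reusing the hypothetical `join` helper defined above), the example in Eq. (1) is evaluated with two layers of join operators, a skip connection that carries unused predicates forward, conjunctive aggregation after each layer, and a final step that plays the role of the FFN.

```python
import numpy as np

def join(W, P):
    # P'(x) = EXISTS y. W(x, y) AND P(y) over a finite domain (see the earlier sketch).
    return np.any(W & P[None, :], axis=1)

def eval_example(P0, P1, P2, P3, P4, W01, W02, W13, W14):
    """Evaluate P(x) from Eq. (1) with two join-chain layers, mirroring Figure 1.

    Unary predicates are length-S Boolean vectors and binary predicates are
    (S, S) Boolean matrices over a shared finite domain of S constants.
    """
    # Layer 1 (two join operators, Step 1); the skip connection keeps P0, P1, P2.
    P31 = join(W13, P3)            # EXISTS y3. W_(1,3)(y1, y3) AND P3(y3)
    P41 = join(W14, P4)            # EXISTS y4. W_(1,4)(y1, y4) AND P4(y4)
    Pnew1 = P1 & P31 & P41         # aggregation after layer 1 (Step 2)

    # Layer 2 (two join operators, Step 2); the skip connection keeps P0.
    P10 = join(W01, Pnew1)         # EXISTS y1. W_(0,1)(x, y1) AND P_new,1(y1)
    P20 = join(W02, P2)            # EXISTS y2. W_(0,2)(x, y2) AND P2(y2)

    # Final stage (Step 3): the remaining conjunction, handled by the FFN in Figure 1.
    return P10 & P20 & P0

# Toy check on a domain of S = 3 constants with random Boolean predicates.
rng = np.random.default_rng(0)
S = 3
unary = [rng.random(S) < 0.5 for _ in range(5)]        # P0, ..., P4
binary = [rng.random((S, S)) < 0.5 for _ in range(4)]  # W01, W02, W13, W14
print(eval_example(*unary, *binary))
```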
For the example in Eq. (1), consider the set $\mathcal{N} = \{x, y_1, y_2, y_3, y_4\}$ as the node set and the set of index pairs $\{(0,1), (0,2), (1,3), (1,4)\}$ as the edge set $\mathcal{E}$. We can draw the graph $\mathcal{G} = (\mathcal{N}, \mathcal{E})$ as in Figure 2, and it is easy to see that $\mathcal{G}$ is a tree. As proved in the following section, our framework can calculate predicates whose structure can be visualized as such a tree. Based on previous research, we can derive a dependency tree for every sentence. If we translate the dependency tree into a logical form, most sentences can be represented by a logic predicate [9] with a tree structure similar to the example in Figure 2. Hence our framework has very wide applications in NLP tasks.

Fig. 2. The tree structure of the example in Eq. (1)
III. LOGICAL EXPRESSIVENESS
Every first-order formula is logically equivalent to some formula in prenex normal form. In our study, we are interested in the case where there is only one variable $x$ that is not bound by any quantifier. For convenience, we denote the free variable $x$ as $y_0$, and the other bound variables are denoted as $y_1, y_2, \cdots, y_T$. We assume the prenex normal form has the formulation in Eq. (2).
$$P(x) = \exists y_1 \exists y_2 \ldots \exists y_T \; \bigvee_{m=1}^{M} \Big[ \bigwedge_{\bar{m}=1}^{\bar{M}} \hat{P}^m_{\bar{m}}\big(y_{i(m,\bar{m})}\big) \wedge \bigwedge_{m'=1}^{M'} \hat{W}^m_{m'}\big(y_{n(m,m')}, y_{j(m,m')}\big) \wedge Q_m \Big] \tag{2}$$
Each $\hat{P}^m_{\bar{m}}$ is a unary predicate and each $\hat{W}^m_{m'}$ is a binary predicate. $Q_m$ is a propositional constant. $i(m,\bar{m})$, $n(m,m')$, and $j(m,m')$ are three index-mapping functions, which map $(m,\bar{m})$ or $(m,m')$ to values in $\{0, 1, 2, \cdots, T\}$. This means $y_{i(m,\bar{m})}, y_{n(m,m')}, y_{j(m,m')} \in \{y_0, y_1, \ldots, y_T\}$. Moreover, we assume $n(m,m') < j(m,m')$.
Theorem 3.1: For every predicate $P(x)$ in Eq. (2), there exists a $P(y_0)$ of the following form that is logically equivalent to $P(x)$:

$$P(y_0) = \exists y_1 \exists y_2 \ldots \exists y_T \; \bigvee_{m=1}^{M} \Big[ \bigwedge_{t \in \mathcal{N}_m} P^m_t(y_t) \wedge \bigwedge_{(t_p, t_c) \in \mathcal{E}_m} W^m_{(t_p, t_c)}\big(y_{t_p}, y_{t_c}\big) \wedge Q_m \Big], \tag{3}$$

where each $\mathcal{N}_m \subset \{0, 1, 2, \cdots, T\}$ is a set of indices with $0 \in \mathcal{N}_m$, and each $\mathcal{E}_m$ is a set of index pairs. Each $\mathcal{N}_m$ and $\mathcal{E}_m$ form a graph $\mathcal{G}_m = (\mathcal{N}_m, \mathcal{E}_m)$.
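For concreteness (our own rewriting, not part of the original text), the example in Eq. (1) is already an instance of the form in Eq. (3) with $M = 1$, $\mathcal{N}_1 = \{0, 1, 2, 3, 4\}$, $\mathcal{E}_1 = \{(0,1), (0,2), (1,3), (1,4)\}$, $P_0$ read as the unary predicate attached to $y_0 = x$, and the propositional constant $Q_1$ taken to be trivially true:

$$P(y_0) = \exists y_1 \exists y_2 \exists y_3 \exists y_4 \; \Big[ \bigwedge_{t \in \mathcal{N}_1} P_t(y_t) \wedge \bigwedge_{(t_p, t_c) \in \mathcal{E}_1} W_{(t_p, t_c)}\big(y_{t_p}, y_{t_c}\big) \wedge Q_1 \Big].$$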
Then we provide the following definitions and assumption to show the logical expressiveness of our join-chain network.
Definition 3.1 ($\mathcal{FOE}_T(\mathcal{I}_1, \mathcal{I}_2)$): $\mathcal{FOE}_T(\mathcal{I}_1, \mathcal{I}_2)$ is the set of all predicates $P(y_0)$ that can be transformed into the formulation in Eq. (3) and satisfy the following requirements:
• Each $W^m_{(t_p, t_c)}(y_{t_p}, y_{t_c}) \in \mathcal{I}_2$, and $P^m_t(y_t) = P^m_{j_1}(y_t) \wedge P^m_{j_2}(y_t) \wedge \cdots \wedge P^m_{j_t}(y_t)$, where $\{P^m_{j_1}, P^m_{j_2}, \cdots, P^m_{j_t}\} \subset \mathcal{I}_1$.
• Each graph $\mathcal{G}_m = (\mathcal{N}_m, \mathcal{E}_m)$ is a tree.
Definition 3.2: Each graph $\mathcal{G}_m = (\mathcal{N}_m, \mathcal{E}_m)$, $m \in \{1, 2, \cdots, M\}$, is a tree. We denote the heights of these trees as $L_1, L_2, \cdots, L_M$, respectively, and let $L_{\max} = \max\{L_1, L_2, \cdots, L_M\}$. We denote the numbers of leaf nodes of the trees as $H_1, H_2, \cdots, H_M$, respectively, and let $H_{\mathrm{sum}} = H_1 + H_2 + \cdots + H_M$. We define the height of the predicate as $L = L_{\max}$ and the width of the predicate as $H = H_{\mathrm{sum}}$.
Definition 3.3 ($\mathcal{FOE}_T^{\{\bar{L}, \bar{H}\}}(\mathcal{I}_1, \mathcal{I}_2)$): $\mathcal{FOE}_T^{\{\bar{L}, \bar{H}\}}(\mathcal{I}_1, \mathcal{I}_2)$ is the set of predicates $P(x) \in \mathcal{FOE}_T(\mathcal{I}_1, \mathcal{I}_2)$ with height $L \le \bar{L}$ and width $H \le \bar{H}$.
Assumption 3.1: For any $W \in \mathcal{I}_2$, there exists some function $f$ such that $W = f(\mathcal{I}_s)$, where $\mathcal{I}_s \subset \mathcal{I}_1$.
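To give a feel for Assumption 3.1, the sketch below (our own construction, not the paper's) derives a binary predicate from unary features of its two arguments using an illustrative bilinear score and threshold, much as an attention matrix is computed from the token representations that encode $\mathcal{I}_1$; all function names and numbers here are hypothetical.

```python
import numpy as np

def binary_from_unary(U: np.ndarray, M: np.ndarray, thresh: float = 0.5) -> np.ndarray:
    """Illustrate W = f(I_s): derive a binary predicate from unary features.

    U is an (S, d) matrix whose s-th row stacks the values of d unary
    predicates (a subset I_s of I_1) at constant x_s; M is a (d, d) scoring
    matrix. W[s, t] holds when the bilinear score of (x_s, x_t) exceeds thresh.
    """
    scores = U @ M @ U.T      # pairwise scores, analogous to Q K^T in attention
    return scores > thresh    # threshold into a Boolean binary predicate

# Toy example: 4 constants, 3 unary features each.
rng = np.random.default_rng(0)
U = rng.random((4, 3))
M = rng.random((3, 3))
print(binary_from_unary(U, M))
```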
Now we can provide the theorem that shows the logical expressiveness.
Theorem 3.2: Under Assumption 3.1, a join-chain network with an $\bar{H}$-head, $\bar{L}$-layer self-attention block can express all the predicates in $\mathcal{FOE}_T^{\{\bar{L}, \bar{H}\}}(\mathcal{I}_1, \mathcal{I}_2)$ if the input to the join-chain network is $\mathcal{I}_1$.
Let us look back at the example in Eq. (1) to understand the definitions and the theorem. There is one tree, i.e., $M = 1$. The node set of the graph is $\mathcal{N} = \{0, 1, 2, 3, 4\}$, and the edge set $\mathcal{E}$ is $\{(0,1), (0,2), (1,3), (1,4)\}$. The graph $(\mathcal{N}, \mathcal{E})$ is the tree visualized in Figure 2. The height of $P(y_0)$ is 2. The width is 3, since there are 3 leaf nodes, i.e., $y_3$, $y_4$, and $y_2$. According to Theorem 3.2, $P(x)$ can be expressed by a join-chain network with 3 heads and 2 layers.
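The height and width in Definition 3.2 can be read off mechanically from the edge set. The short sketch below (ours, with a hypothetical helper tree_stats) does this for the example tree, confirming height 2 and width 3, and hence 2 layers and 3 heads by Theorem 3.2.

```python
def tree_stats(edges, root=0):
    """Return (height, number_of_leaves) for a tree given as parent-child edges."""
    children = {}
    for parent, child in edges:
        children.setdefault(parent, []).append(child)

    def depth(node):
        kids = children.get(node, [])
        return 0 if not kids else 1 + max(depth(k) for k in kids)

    nodes = {v for edge in edges for v in edge}
    leaves = [n for n in nodes if n not in children]   # nodes with no children
    return depth(root), len(leaves)

print(tree_stats([(0, 1), (0, 2), (1, 3), (1, 4)]))    # (2, 3): height 2, width 3
```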
IV. RETHINK MULTI-HEAD ATTENTION AS JOIN OPERATION
To design the join operator with differentiable learning capability, we study various neural operators for approximating the
symbolic join operations. Interestingly, we find that the widely
adopted multi-head attention mechanism can be understood as
a join operator.
First, we denote the domain of all the predicates, including all the binary predicates $W(x, y)$ and unary predicates $P(y)$, as $\{x_1, x_2, x_3, \ldots, x_S\}$, which means $x, y \in \{x_1, x_2, x_3, \ldots, x_S\}$. In the multi-head attention mechanism, the core part is the product between the self-attention matrix $A$ and the value tensor $V$:

$$Z = AV. \tag{4}$$
To calculate the join operation between $W(x, y)$ and $P(y)$, we note that

$$\exists y\, W(x, y) \wedge P(y) = \bigvee_{s=1}^{S} W(x, x_s) \wedge P(x_s). \tag{5}$$
Then, if multiplication is understood as the conjunction operation $\wedge$ and addition as the disjunction $\vee$, we can regard the value tensor $V$ as $[P(x_1), P(x_2), \ldots, P(x_S)]$ and the self-attention matrix $A$ as $\{W(x_s, x_{s'})\}$. Based on this, the $s$-th element $z_s$ of the tensor $Z$ is the value of $\exists y\, W(x_s, y) \wedge P(y)$. Hence the tensor $Z$ is exactly the result of the join operation between $W(x, y)$ and $P(y)$. Generally speaking, the self-attention matrix $A$ learns all the values of the binary predicate $W(x, y)$, and the value tensor $V$ learns all the values of the unary predicate $P(y)$. For each attention head in each layer, the self-attention matrix $A$ learns one binary predicate $W(x, y)$. Hence, the number of leaf nodes in Figure 2 reflects the importance of having multiple heads in self-attention.
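The correspondence between Eq. (4) and Eq. (5) can be checked numerically. With hard 0/1 values in place of learned soft attention weights (our own illustration, not an experiment from the paper), the ordinary matrix product $AV$ accumulates the conjunctions $W(x_s, y) \wedge P(y)$ over $y$, and clipping the sum at 1 recovers the disjunction, i.e., the union bound of the join.

```python
import numpy as np

S = 4
rng = np.random.default_rng(1)
A = rng.random((S, S)) < 0.4      # self-attention matrix, read as W(x_s, x_{s'})
V = rng.random(S) < 0.5           # value tensor, read as [P(x_1), ..., P(x_S)]

# Symbolic join: z_s = EXISTS y. W(x_s, y) AND P(y).
z_logic = np.any(A & V[None, :], axis=1)

# Attention-style computation: multiply-accumulate, then clip (the union bound).
z_attn = np.minimum(A.astype(float) @ V.astype(float), 1.0)

print(z_logic.astype(int))        # Boolean join values as 0/1
print(z_attn)                     # the same pattern: sums >= 1 collapse to 1
assert np.array_equal(z_logic, z_attn.astype(bool))
```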
V. DISCUSSION
Our work provides a novel understanding of multi-head attention from the logical reasoning view. Based on previous work on semantic parsing, such as dependency trees and lambda dependency-based compositional semantics [9], [10], most sentences can be represented in logical forms that have tree structures similar to the example in Eq. (1). Hence, our work provides a novel explanation of why multi-head attention has achieved such great success in the recent development of NLP. Furthermore, the logical reasoning view also offers suggestions on how to improve the design of transformers. Since the logical expressions of most sentences in NLP have tree structures similar to the example in Eq. (1), shown in Figure 2, the number of join operations decreases as the calculation proceeds. This means the number of heads could decrease as the layers move farther from the inputs, which provides a new insight into how to compress the multi-head attention blocks in transformers. Besides, it is worth noting that skip connections are also heavily used in transformers. Our work provides a new interpretation of the use of skip connections in transformers, different from their original motivation in residual learning. Moreover, Assumption 3.1 also suggests a potential way to augment the transformer with external knowledge: we could incorporate additional commonsense knowledge into the self-attention block to boost the logical reasoning as well as the inference ability of the transformer.
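One concrete way to act on the head-count observation is a per-layer head schedule, wider near the inputs and narrower near the output. The fragment below is a purely hypothetical configuration sketch; the class name and numbers are ours, not an established API or a result from the paper.

```python
# Hypothetical head schedule for a 6-layer encoder: many heads near the inputs
# (leaf-level joins), few near the output (root-level joins). Illustrative only.
heads_per_layer = [12, 12, 8, 8, 4, 2]

for layer_idx, n_heads in enumerate(heads_per_layer):
    # A real implementation would build an attention block with n_heads here,
    # e.g. blocks.append(SelfAttentionBlock(d_model, num_heads=n_heads)).
    print(f"layer {layer_idx}: {n_heads} attention heads")
```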
Our work suggests many interesting future directions. One is to design more efficient neural operators for the join operation.