CAUSALLY-GUIDED REGULARIZATION OF GRAPH ATTENTION IMPROVES GENERALIZABILITY

Alexander Wu∗
MIT
alexwu@mit.edu

Thomas Markovich†‡
Twitter Cortex
tmarkovich@twitter.com

Bonnie Berger∗§
MIT
bab@mit.edu

Nils Hammerla†
Twitter Cortex
nhammerla@twitter.com

Rohit Singh∗‡
MIT
rsingh@mit.edu

∗Computer Sci. and Artificial Intelligence Lab, Cambridge, MA 02139
†Twitter Cortex
‡Co-corresponding authors
§Also with the Dept. of Mathematics, MIT
ABSTRACT
Graph attention networks estimate the relational importance of node neighbors to aggregate relevant information over local neighborhoods for a prediction task. However, the inferred attentions are vulnerable to spurious correlations and connectivity in the training data, hampering the generalizability of models. We introduce CAR, a general-purpose regularization framework for graph attention networks. Embodying a causal inference approach based on invariant prediction, CAR aligns the attention mechanism with the causal effects of active interventions on graph connectivity in a scalable manner. CAR is compatible with a variety of graph attention architectures, and we show that it systematically improves generalizability on various node classification tasks. Our ablation studies indicate that CAR hones in on the aspects of graph structure most pertinent to the prediction (e.g., homophily), and does so more effectively than alternative approaches. Finally, we also show that CAR enhances the interpretability of attention coefficients by accentuating node-neighbor relations that point to causal hypotheses.
1 INTRODUCTION
Graphs encode rich relational information that can be leveraged in learning tasks across a wide variety of domains. Graph neural networks (GNNs) can learn powerful node, edge or graph-level representations by aggregating a node's representations with those of its neighbors. The specifics of a GNN's neighborhood aggregation scheme are critical to its effectiveness on a prediction task. For instance, graph convolutional networks (GCNs) aggregate information via a simple averaging or max-pooling of neighbor features (Kipf & Welling, 2017; Hamilton et al., 2017). GCNs can perform poorly in many real-world scenarios where uninformative or noisy connections exist between nodes. Graph-based attention mechanisms combat these issues by quantifying the relevance of node-neighbor relations and softly selecting neighbors in the aggregation step accordingly (Velickovic et al., 2018; Brody et al., 2022; Shi et al., 2021). This process of attending to select neighbors has contributed to significant performance gains for GNNs across a variety of tasks (Zhou et al., 2018; Veličković, 2022). Similar to the use of attention in natural language processing and computer vision, attention in graph settings also enables the interpretability of model predictions via the examination of attention coefficients (Serrano & Smith, 2019).
However, graph attention mechanisms can be prone to spurious edges and correlations that mislead
them in how they attend to node neighbors, which manifests as a failure to generalize to unseen data
(Knyazev et al., 2019). One approach to improve GNNs’ generalizability is to regularize attention
coefficients in order to make them more robust to spurious correlations/connections in the training
data. Previous work has focused on $L_0$ regularization of attention coefficients to enforce sparsity (Ye & Ji, 2021) or has co-optimized a link prediction task using attention (Kim & Oh, 2021). Since these regularization strategies are formulated independently of the primary prediction task, they align the attention mechanism with some intrinsic property of the input graph without regard for the training objective.
We take a different approach and consider the question: "What is the importance of a specific edge to the prediction task?" Our answer comes from the perspective of regularization: we introduce CAR, a causal attention regularization framework that is broadly suitable for graph attention network architectures (Figure 1). Intuitively, an edge in the input graph is important to a prediction task if removing it leads to substantial degradation in the prediction performance of the GNN. The key conceptual advance of this work is to scalably leverage active interventions on node neighborhoods (i.e., deletion of specific edges) to align graph attention training with the causal impact of these interventions on task performance. Theoretically, our approach is motivated by the invariant prediction framework for causal inference (Peters et al., 2016; Wu et al., 2022). While some efforts have previously been made to infuse notions of causality into GNNs, these causal approaches have been largely limited to using causal effects from pre-trained models as features for a separate model (Feng et al., 2021; Knyazev et al., 2019) or decoupling causal from non-causal effects (Sui et al., 2021).
We apply CAR to three graph attention architectures across eight node classification tasks, finding that it consistently improves test loss and accuracy. CAR is able to fine-tune graph attention by improving its alignment with task-specific homophily. Correspondingly, we find that as graph heterophily increases, the margin of CAR's outperformance widens. In contrast, a non-causal approach that directly regularizes with respect to label similarity generalizes less well. On the ogbn-arxiv network, we investigate the citations up/down-weighted by CAR and find that they broadly group into three intuitive themes. Our causal approach can thus enhance the interpretability of attention coefficients, and we provide a qualitative analysis of this improved interpretability. We also present preliminary results demonstrating the applicability of CAR to graph pruning tasks. Due to the size of industrially relevant graphs, it is common to use GCNs or sampling-based approaches on them. There, attention coefficients learned by CAR on sampled subnetworks may guide graph rewiring of the full network to improve the results obtained with convolutional techniques.
2 METHODS
2.1 GRAPH ATTENTION NETWORKS
Attention mechanisms have been effectively used in many domains by enabling models to dynamically attend to the specific parts of an input that are relevant to a prediction task (Chaudhari et al., 2021). In graph settings, attention mechanisms compute the relevance of edges in the graph for a prediction task. A neighbor aggregation operator then uses this information to weight the contribution of each edge (Lee et al., 2019a; Li et al., 2016; Lee et al., 2019b).
The approach for computing attention is similar in many graph attention mechanisms. A graph attention layer takes as input a set of node features $\mathbf{h} = \{h_1, \ldots, h_N\}$, $h_i \in \mathbb{R}^F$, where $N$ is the number of nodes. The graph attention layer uses these node features to compute attention coefficients for each edge: $\alpha_{ij} = a(h_i, h_j)$, where $a : \mathbb{R}^{F'} \times \mathbb{R}^{F'} \to (0, 1)$ is the attention mechanism function, and the attention coefficient $\alpha_{ij}$ for an edge indicates the importance of node $i$'s input features to node $j$. For a node $j$, these attention coefficients are then used to compute a linear combination of its neighbors' features: $h'_j = \sum_{i \in \mathcal{N}(j)} \alpha_{ij} W h_i$, subject to $\sum_{i \in \mathcal{N}(j)} \alpha_{ij} = 1$. For multi-headed attention, each of the $K$ heads first independently calculates its own attention coefficients $\alpha^{(k)}_{ij}$ with its head-specific attention mechanism $a^{(k)}(\cdot, \cdot)$, after which the head-specific outputs are averaged.
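To make the aggregation concrete, the following is a minimal single-head sketch in PyTorch, assuming a GAT-style additive scoring function; the tensor names (`h`, `W`, `a_src`, `a_dst`) are illustrative, not any particular library's API.

```python
import torch
import torch.nn.functional as F

def attention_aggregate(h, edge_index, W, a_src, a_dst):
    """Single-head sketch: alpha_ij = a(h_i, h_j); h'_j = sum_{i in N(j)} alpha_ij * W h_i.

    h          : [N, F] node features
    edge_index : [2, E] directed edges (i -> j); row 0 = source i, row 1 = target j
    W          : [F, F_out] shared linear transform
    a_src, a_dst : [F_out] parameters of a GAT-style additive scoring function
    """
    src, dst = edge_index
    z = h @ W                                            # transformed features, [N, F_out]
    # Unnormalized score e_ij from source and target features (additive form).
    e = F.leaky_relu((z[src] * a_src).sum(-1) + (z[dst] * a_dst).sum(-1))
    # Softmax over each target node's incoming edges, so that sum_i alpha_ij = 1.
    e = e - e.max()                                      # numerical stability
    num = torch.exp(e)
    denom = torch.zeros(h.size(0)).scatter_add_(0, dst, num)
    alpha = num / denom[dst].clamp(min=1e-16)            # attention coefficients, [E]
    # Weighted aggregation of neighbor features for each target node j.
    out = torch.zeros(h.size(0), z.size(1)).scatter_add_(
        0, dst.unsqueeze(-1).expand(-1, z.size(1)), alpha.unsqueeze(-1) * z[src])
    return out, alpha
```

Multi-head attention would run $K$ copies of this computation with head-specific parameters and average the outputs.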
In this paper, we focus on three widely used graph attention architectures: the original graph attention network (GAT) (Velickovic et al., 2018), a modified version of this original network (GATv2) (Brody et al., 2022), and the Graph Transformer network (Shi et al., 2021). The three architectures and their equations for computing attention are presented in Appendix A.1.
Figure 1: Schematic of CAR: Graph attention networks learn the relative importance of each node-neighbor
for a given prediction task. However, their inferred attention coefficients can be miscalibrated due to noise,
spurious correlations, or confounders (e.g., node size here). Our causal approach directly intervenes on a
sampled subset of edges and supervises an auxiliary task that aligns an edge’s causal importance to the task
with its attention coefficient.
2.2 CAUSAL ATTENTION REGULARIZATION: AN INVARIANT PREDICTION FORMULATION
CAR is motivated by the invariant prediction (IP) formulation of causal inference (Peters et al., 2016; Wu et al., 2022). The central insight of this formulation is that, given sub-models that each contain a different set of predictor variables, the underlying causal model of a system comprises the set of all sub-models for which the predicted class distributions are equivalent, up to a noise term. This approach is capable of providing statistically rigorous estimates for both the causal effect strength of predictor variables and confidence intervals. With CAR, our core insight is that the graph structure itself, in addition to the set of node features, comprises the set of predictor variables. This is equivalent to the intuition that edges relevant to a particular task should not only be assigned high attention coefficients but should also be important to the predictive accuracy of the model (Figure 1). The removal of these relevant edges from the graph should cause the predictions that rely on them to substantially worsen.
We leverage the residual formulation of IP to formalize this intuition. This formulation assumes that we can generate sub-models for different sets of predictor variables, each corresponding to a separate experiment $e \in \mathcal{E}$. For each sub-model $S^e$, we compute the predictions $Y^e = g(G^e_S, X^e_S, \epsilon^e)$, where $G$ is the graph structure, $X$ is the set of features associated with $G$, $\epsilon$ is the noise distribution, and $S$ is the set of predictor variables corresponding to $S^e$. We next compute the residuals $R = Y - Y^e$. IP requires that we perform a hypothesis test on the means of the residuals, with the generic approach being to perform an F-test for each sub-model against the null hypothesis. The relevant assumptions ($\epsilon^e \sim F_\epsilon$ and $\epsilon^e \perp S^e$ for all $e \in \mathcal{E}$) are satisfied if and only if the conditionals $Y^e \mid S^e$ and $Y^f \mid S^f$ are identical for all experiments $e, f \in \mathcal{E}$.
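As a toy illustration of this residual test, the sketch below fits one sub-model per experiment and applies a one-way ANOVA F-test to the residuals across experiments. Here `fit_submodel` is a hypothetical stand-in for training a GNN on one intervened graph; the full IP procedure involves further checks beyond this comparison of means.

```python
import numpy as np
from scipy.stats import f_oneway

def invariance_not_rejected(experiments, y_true, fit_submodel, alpha=0.05):
    """Sketch of the IP-style residual check: if the predictor set is causal,
    residual means should be indistinguishable across experiments e in E.

    experiments  : list of (graph, features) pairs, one per experiment e
    fit_submodel : hypothetical callable (graph, features) -> predictions Y^e
    """
    residuals = []
    for graph, feats in experiments:
        y_pred = fit_submodel(graph, feats)              # Y^e = g(G_S^e, X_S^e, eps^e)
        residuals.append(np.asarray(y_true) - np.asarray(y_pred))  # R = Y - Y^e
    # One-way F-test; null hypothesis: equal residual means across experiments.
    f_stat, p_value = f_oneway(*residuals)
    return p_value > alpha   # True: invariance (a candidate causal sub-model) not rejected
```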
We use an edge intervention-based strategy that corresponds precisely to this IP-based formulation. However, we differ from the standard IP formulation in how we estimate the final causal model. While IP provides a method to explicitly construct an estimator of the true causal model (by taking the intersection of all models for which the null hypothesis was rejected), we rely on intervention-guided regularization of graph attention coefficients as a way to aggregate sub-models while balancing model complexity and runtime considerations. In our setting, each sub-model corresponds to a set of edge interventions and, thus, a slightly different graph structure. The same GNN architecture is trained on each of these sub-models. Given a set of experiments $\mathcal{E} = \{e\}$ with sub-models $S^e$, outputs $Y^e$ and errors $\epsilon^e$, we regularize the attention coefficients to align with sub-model errors, thus learning a GNN architecture primarily from causal sub-models. Incorporating this regularization as an auxiliary task, we seek to minimize the following loss:
$$\mathcal{L} = \mathcal{L}_p + \lambda \mathcal{L}_c \qquad (1)$$

The full loss function $\mathcal{L}$ consists of the loss associated with the prediction, $\mathcal{L}_p$; the loss associated with the causal attention task, $\mathcal{L}_c$; and a causal regularization strength hyperparameter $\lambda$ that mediates the contribution of the regularization loss to the objective. For the prediction loss, we have
$\mathcal{L}_p = \frac{1}{N} \sum_{n=1}^{N} \ell_p(\hat{y}^{(n)}, y^{(n)})$, where $N$ is the size of the training set, $\ell_p(\cdot, \cdot)$ corresponds to the loss function for the given prediction task, $\hat{y}^{(n)}$ is the prediction for entity $n$, and $y^{(n)}$ is the ground truth value for entity $n$. We seek to align the attention coefficient for an edge with the causal effect of removing that edge through the use of the following loss function:

$$\mathcal{L}_c = \frac{1}{R} \sum_{r=1}^{R} \frac{1}{|S^{(r)}|} \sum_{(n,i,j) \in S^{(r)}} \ell_c\!\left(\alpha^{(n)}_{ij}, c^{(n)}_{ij}\right) \qquad (2)$$
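A minimal PyTorch sketch of Equations 1 and 2, assuming $\ell_p$ is cross-entropy and $\ell_c$ is a squared-error penalty between attention coefficients and causal effects; both choices, and the structure of `interventions`, are illustrative assumptions rather than the paper's prescribed forms.

```python
import torch
import torch.nn.functional as F

def car_loss(logits, labels, interventions, lam):
    """Total loss L = L_p + lambda * L_c (Eq. 1), with L_c averaged over
    R rounds of edge interventions (Eq. 2).

    interventions : list of R rounds; each round is a dict with
        'alpha' : [|S^(r)|] attention coefficients of the intervened edges
        'c'     : [|S^(r)|] causal effects of deleting those edges
    """
    L_p = F.cross_entropy(logits, labels)                # prediction loss ell_p
    L_c = logits.new_zeros(())                           # scalar accumulator
    for r in interventions:
        # ell_c chosen here as squared error over the round's edge set;
        # mse_loss's mean reduction supplies the 1/|S^(r)| factor.
        L_c = L_c + F.mse_loss(r['alpha'], r['c'])
    L_c = L_c / max(len(interventions), 1)               # 1/R
    return L_p + lam * L_c
```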
In Equation 2, $n$ represents a single entity for which we aim to make a prediction. For a node prediction task, the entity $n$ corresponds to a node; in a graph prediction task, $n$ corresponds to a graph. In this paper, we assume that all edges are directed and, if necessary, decompose an undirected edge into two directed edges. In each mini-batch, we generate $R$ separate sub-models, each of which consists of a set of edge interventions $S^{(r)}$, $r = 1, \ldots, R$. Each edge intervention in $S^{(r)}$ is represented by a tuple $(n, i, j)$, which denotes a selected edge $(i, j)$ for an entity $n$. Note that in a node classification task, $n$ is the same as $j$ (i.e., the node with the incoming edge). More details on the edge intervention procedure and causal effect calculations can be found in the sections below. The causal effect $c^{(n)}_{ij}$ scores the impact of deleting edge $(i, j)$ through a likelihood ratio test. This causal effect is compared to the edge's attention coefficient $\alpha^{(n)}_{ij}$ via the loss function $\ell_c(\cdot, \cdot)$. A detailed algorithm for CAR is provided in Algorithm 1.
Algorithm 1 CAR Framework

Input: Training set $D_{train}$, validation set $D_{val}$, model $M$, regularization strength $\lambda$
repeat
    for each mini-batch $B_k = \{n^{(k)}_j\}_{j=1}^{b_k}$ do
        Prediction loss: $\mathcal{L}_p \leftarrow \frac{1}{|B_k|} \sum_{n \in B_k} \ell_p(\hat{y}^{(n)}, y^{(n)})$
        procedure EDGE INTERVENTION
            Causal attention loss: $\mathcal{L}_c \leftarrow 0$
            for round $r \leftarrow 1$ to $R$ do
                Set of edge interventions $S^{(r)} \leftarrow \{\}$
                for each entity $n^{(k)}_j$ do
                    Sample edge $(i, j) \in E_{n^{(k)}_j}$    ▷ $E_{n^{(k)}_j}$: set of edges related to entity $n^{(k)}_j$
                    if $(i, j)$ independent of $S^{(r)}$ then    ▷ See "Edge intervention procedure"
                        $S^{(r)} \leftarrow S^{(r)} \cup (n^{(k)}_j, i, j)$    ▷ Add edge to set of edge interventions
                        Compute causal effect $c^{(n)}_{ij} \leftarrow \sigma\big(\rho^{(n)}_{ij}(d^{(n)} - 1)\big)$    ▷ Equation 5
                    end if
                end for
                $\mathcal{L}_c \leftarrow \mathcal{L}_c + \frac{1}{R} \cdot \frac{1}{|S^{(r)}|} \sum_{(n,i,j) \in S^{(r)}} \ell_c\big(\alpha^{(n)}_{ij}, c^{(n)}_{ij}\big)$    ▷ Equation 3
            end for
        end procedure
        Total loss: $\mathcal{L} = \mathcal{L}_p + \lambda \mathcal{L}_c$
        Update model parameters to minimize $\mathcal{L}$
    end for
until convergence criterion    ▷ We use convergence of the validation prediction loss.
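Equation 5 is not reproduced in this excerpt, so the sketch below is a hedged reconstruction of the causal-effect line above: a log-likelihood ratio $\rho$ comparing the model's confidence in the true label with and without the edge, squashed through a sigmoid. The helpers `graph.remove_edge` and `graph.in_degree`, and the reading of $d^{(n)}$ as the in-degree of entity $n$, are all assumptions.

```python
import torch

def causal_effect(model, graph, feats, n, edge, label):
    """Hypothetical realization of c_ij^(n): how much deleting edge (i, j)
    hurts the model's likelihood of the true label for entity n.
    The sigmoid(rho * (d - 1)) form mirrors the algorithm's Equation-5 line,
    but d^(n) (here: in-degree of n) and the exact form are assumptions.
    """
    with torch.no_grad():
        p_full = model(graph, feats)[n].softmax(-1)[label]
        g_del = graph.remove_edge(*edge)                 # hypothetical helper
        p_del = model(g_del, feats)[n].softmax(-1)[label]
    rho = torch.log(p_full.clamp(min=1e-12)) - torch.log(p_del.clamp(min=1e-12))
    d = graph.in_degree(n)                               # hypothetical helper
    return torch.sigmoid(rho * (d - 1))
```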
Edge intervention procedure. We sample a set of edges in each round $r$ such that the prediction for each entity will be affected by at most one edge intervention in that round, ensuring effect independence. For example, in a node classification task for a model with one GNN layer, a round of edge interventions entails removing only a single incoming edge for each node being classified. In the graph property prediction case, only one edge will be removed from each graph per round. Because a model with one GNN layer only aggregates information over a 1-hop neighborhood, the removal of each edge will only affect the model's prediction for that edge's target node. To select a set of edges in the $L$-layer GNN case, edges are sampled from the 1-hop neighborhood of each node being classified, and sampled edges that lie in more than one target node's $L$-hop neighborhood are removed from consideration as intervention candidates. A sketch of one possible implementation follows.
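This sketch enforces the sampling constraint for an $L$-layer model: it picks at most one candidate incoming edge per target node and discards candidates that could affect another selected target's prediction. The BFS-based neighborhood computation and the function names are our assumptions of one reasonable realization.

```python
import random
from collections import defaultdict, deque

def khop_neighborhood(adj, node, L):
    """Nodes within L hops of `node`, via BFS over an adjacency dict."""
    seen, frontier = {node}, deque([(node, 0)])
    while frontier:
        u, depth = frontier.popleft()
        if depth == L:
            continue
        for v in adj.get(u, ()):
            if v not in seen:
                seen.add(v)
                frontier.append((v, depth + 1))
    return seen

def sample_interventions(edges, targets, adj, L):
    """One round r: at most one incoming-edge deletion per target node, kept
    only if it cannot affect any other selected target's prediction."""
    incoming = defaultdict(list)
    for i, j in edges:
        incoming[j].append((i, j))
    selected = []        # tuples (n, i, j), as in S^(r)
    claimed = set()      # union of L-hop neighborhoods of selected targets
    for n in targets:
        if not incoming[n]:
            continue
        i, j = random.choice(incoming[n])    # sample from n's 1-hop neighborhood
        hood = khop_neighborhood(adj, n, L)
        # Reject if this edge touches a previously claimed neighborhood, or a
        # previously selected edge falls inside this target's neighborhood.
        if (i in claimed or j in claimed
                or any(u in hood or v in hood for _, u, v in selected)):
            continue
        selected.append((n, i, j))
        claimed |= hood
    return selected
```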