
[Figure: architecture schematic. A source input $x_R$ and a target input $x_T$, with shared features $x_s$, pass through a shared encoder and the source/target private encoders to produce the representations $z_s$, $z_{p_R}$, $z_{p_T}$ (block: "Handling heterogeneous feature spaces"); these feed layers $h^{p_R}_{w,l}$, $h^{s}_{w,l}$, $h^{p_T}_{w,l}$ that output $\hat{\mu}^R_w(x_R)$ and $\hat{\mu}^T_w(x_T)$ (block: "Sharing information between PO functions across source and target domains").]

Figure 1: Building block for handling the heterogeneous feature space of the source and target domains.
As illustrated in Figure 1, a source example $x_R$ is encoded to $[z_s \,\|\, z_{p_R}]$ and a target example $x_T$ to $[z_s \,\|\, z_{p_T}]$, where $\|$ denotes concatenation and both representations have size $D_s + D_p$. Note that an alternative approach would have been to use the domain-specific encoders $\phi_p$ only for the private features $x_{p_R}$ and $x_{p_T}$. However, passing the shared features through both types of encoders allows us to learn relationships between them that are shared across the different domains, as well as interactions that are domain-specific.
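As a concrete illustration, the following PyTorch sketch implements this encoding scheme under our own assumptions: the module name `HeterogeneousEncoder`, the single-hidden-layer MLP encoders, and the explicit split of each example into shared features `x_s` and private features `x_p` are illustrative, not details fixed by the paper.

```python
import torch
import torch.nn as nn

class HeterogeneousEncoder(nn.Module):
    """Sketch: shared encoder plus one private encoder per domain
    (source R, target T). Widths and MLP depth are assumptions."""

    def __init__(self, d_shared, d_priv_src, d_priv_tgt, d_s=16, d_p=16):
        super().__init__()
        # Shared encoder: sees the shared features x_s of either domain.
        self.phi_s = nn.Sequential(nn.Linear(d_shared, d_s), nn.ReLU())
        # Private encoders: see shared + domain-specific features, so that
        # domain-specific interactions with the shared features can be learned.
        self.phi_pR = nn.Sequential(nn.Linear(d_shared + d_priv_src, d_p), nn.ReLU())
        self.phi_pT = nn.Sequential(nn.Linear(d_shared + d_priv_tgt, d_p), nn.ReLU())

    def encode_source(self, x_s, x_pR):
        z_s = self.phi_s(x_s)
        z_pR = self.phi_pR(torch.cat([x_s, x_pR], dim=-1))
        return torch.cat([z_s, z_pR], dim=-1)  # [z_s || z_pR], size D_s + D_p

    def encode_target(self, x_s, x_pT):
        z_s = self.phi_s(x_s)
        z_pT = self.phi_pT(torch.cat([x_s, x_pT], dim=-1))
        return torch.cat([z_s, z_pT], dim=-1)  # [z_s || z_pT], size D_s + D_p
```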
To discourage redundancy and ensure that $z_p$ and $z_s$ encode different information from the input features, we propose using a regularization loss that enforces their orthogonality [39]:
$$\mathcal{L}^{z}_{\text{orth}} = \|\zeta_s^{\top} \zeta_{p_R}\|_F^2 + \|\zeta_s^{\top} \zeta_{p_T}\|_F^2 \qquad (3)$$
where $\zeta_{p_R}$, $\zeta_{p_T}$ and $\zeta_s$ are matrices whose rows are the private $z_{p_R}$, $z_{p_T}$ and shared $z_s$ representations for the source and target examples respectively, and $\|\cdot\|_F^2$ is the squared Frobenius norm.
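A minimal PyTorch sketch of Eq. (3) follows. Note that for the matrix products to be well-defined, $\zeta_s$ must be formed per domain (the rows of the shared-representation matrix must align with the rows of the corresponding private one), which is how it is written here:

```python
import torch

def orthogonality_loss(z_s_R: torch.Tensor, z_p_R: torch.Tensor,
                       z_s_T: torch.Tensor, z_p_T: torch.Tensor) -> torch.Tensor:
    """Eq. (3): squared Frobenius norms of the cross-Gram matrices
    between shared and private representations, computed per domain.

    z_s_R: (n_R, D_s) shared representations of the source batch.
    z_p_R: (n_R, D_p) private representations of the source batch.
    z_s_T, z_p_T: the analogous (n_T, D_s) and (n_T, D_p) target matrices.
    """
    # ||zeta_s^T zeta_pR||_F^2 + ||zeta_s^T zeta_pT||_F^2
    return (z_s_R.t() @ z_p_R).pow(2).sum() + (z_s_T.t() @ z_p_T).pow(2).sum()
```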
4.2 Sharing information between potential outcome response functions across domains

As treatment responses can vary between different patient populations, it is important to build a transfer approach that enables learning target-specific outcome functions while also sharing information from the source domain. We propose a building block for sharing information between PO functions across domains that is inspired by the FlexTENet architecture [14] and by work in multitask learning [41]: it has private layers (subspaces) for each domain as well as shared layers.
[Figure: the same architecture schematic as in Figure 1, here highlighting the layers that share information between PO functions.]

Figure 2: Building block for sharing information between PO functions across domains.
As shown in Figure 2, for each treatment $w \in \{0, 1\}$, we consider a model architecture for estimating its PO functions in the source and target domains, $\mu^R_w$ and $\mu^T_w$, that consists of $L$ layers, each having a shared and two private subspaces (one for each domain). For simplicity, we consider the same number of hidden dimensions for each shared and private subspace. Let $\tilde{h}^{p_R}_{w,l}$, $\tilde{h}^{p_T}_{w,l}$, $\tilde{h}^{s}_{w,l}$ be the inputs and $h^{p_R}_{w,l}$, $h^{p_T}_{w,l}$, $h^{s}_{w,l}$ the outputs of the $l$-th layer. For $l \geq 1$, similarly to [14], the inputs to the $(l+1)$-th layer are obtained as follows: $\tilde{h}^{p_R}_{w,l+1} = [h^{s}_{w,l} \,\|\, h^{p_R}_{w,l}]$, $\tilde{h}^{p_T}_{w,l+1} = [h^{s}_{w,l} \,\|\, h^{p_T}_{w,l}]$, $\tilde{h}^{s}_{w,l+1} = [h^{s}_{w,l}]$. For $l = 1$, we set $\tilde{h}^{p_R}_{w,1} = \Phi_R(x_R)$, $\tilde{h}^{p_T}_{w,1} = \Phi_T(x_T)$, and $\tilde{h}^{s}_{w,1} = \tilde{h}^{p_R}_{w,1}$ when using an example from the source domain or $\tilde{h}^{s}_{w,1} = \tilde{h}^{p_T}_{w,1}$ when using an example from the target domain, where $\Phi_R(\cdot)$ and $\Phi_T(\cdot)$ are input representations. When sharing the encoders from Section 4.1 for both treatments, we set $\Phi_R(x_R) = [z_s \,\|\, z_{p_R}]$ and $\Phi_T(x_T) = [z_s \,\|\, z_{p_T}]$. However, as we will see in Section 5.1, this input representation is CATE-learner specific and can be extended (see Section 5.2) by adding more representation layers to share information between PO functions within each domain. For the last layer $L$, we build $h^{s}_{w,L}$, $h^{p_R}_{w,L}$, $h^{p_T}_{w,L}$ to each have the same dimension as the potential outcome $y$.
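To make the recursion concrete, here is a minimal PyTorch sketch of the $L$-layer block for a single treatment arm $w$. The class name `SharedPrivatePOStack`, the equal hidden widths, the ReLU activations, and taking $\psi$ to be the identity (continuous outcomes) are our assumptions:

```python
import torch
import torch.nn as nn

class SharedPrivatePOStack(nn.Module):
    """Sketch of the L-layer building block for one treatment arm w.
    Each layer has a shared subspace and two private subspaces
    (source 'R', target 'T')."""

    def __init__(self, d_in, d_hidden, n_layers, d_out=1):
        super().__init__()
        self.shared, self.priv_R, self.priv_T = (nn.ModuleList() for _ in range(3))
        for l in range(n_layers):
            d_s_in = d_in if l == 0 else d_hidden       # h~_s,l
            d_p_in = d_in if l == 0 else 2 * d_hidden   # h~_p,l = [h_s || h_p]
            d_l_out = d_out if l == n_layers - 1 else d_hidden
            self.shared.append(nn.Linear(d_s_in, d_l_out))
            self.priv_R.append(nn.Linear(d_p_in, d_l_out))
            self.priv_T.append(nn.Linear(d_p_in, d_l_out))
        self.act = nn.ReLU()

    def forward(self, phi_x, domain):
        """phi_x: input representation Phi_R(x_R) or Phi_T(x_T); domain in {'R', 'T'}."""
        priv = self.priv_R if domain == "R" else self.priv_T
        h_s, h_p = phi_x, phi_x  # l = 1: both subspaces read Phi(x)
        n = len(self.shared)
        for l in range(n):
            # For l >= 2, the private input is the concatenation [h_s || h_p].
            h_tilde_p = h_p if l == 0 else torch.cat([h_s, h_p], dim=-1)
            h_s_new, h_p_new = self.shared[l](h_s), priv[l](h_tilde_p)
            if l < n - 1:  # no activation on the last layer's outputs
                h_s_new, h_p_new = self.act(h_s_new), self.act(h_p_new)
            h_s, h_p = h_s_new, h_p_new
        return h_s + h_p  # psi = identity: g_w = psi(h_p,L + h_s,L)
```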
Overall, let $g^R_w$, $g^T_w$ be the hypothesis functions estimating the potential outcomes in the source and target domains respectively, such that $g^R_w(\Phi_R(x_R)) = \psi(h^{p_R}_{w,L} + h^{s}_{w,L})$ and $g^T_w(\Phi_T(x_T)) = \psi(h^{p_T}_{w,L} + h^{s}_{w,L})$, where $\psi$ is the linear function for continuous outcomes and the sigmoid function for binary ones. This allows us to define the following loss function for estimating the PO:
$$\mathcal{L}_y = \sum_{i=1}^{N_R} \ell\big(y_i,\, g^R_{w_i}(\Phi^R_{w_i}(x^R_i))\big) + \sum_{i=1}^{N_T} \ell\big(y_i,\, g^T_{w_i}(\Phi^T_{w_i}(x^T_i))\big) \qquad (4)$$
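Eq. (4) could then be computed as in the sketch below, reusing the hypothetical `SharedPrivatePOStack` above with one instance per treatment; the choice of squared error for $\ell$ and the per-treatment masking are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def po_loss(stacks, phi_R, y_R, w_R, phi_T, y_T, w_T):
    """Eq. (4): sum of outcome losses over the source (R) and target (T) batches.

    stacks: dict mapping treatment w in {0, 1} to a SharedPrivatePOStack.
    phi_R, phi_T: input representations Phi_R(x_R), Phi_T(x_T).
    y_R, y_T: outcomes shaped like the stack output, e.g. (n, 1).
    w_R, w_T: observed treatments, integer tensors of shape (n,).
    """
    loss = phi_R.new_zeros(())
    for w in (0, 1):
        # Source examples that received treatment w.
        mask_R = (w_R == w)
        if mask_R.any():
            pred = stacks[w](phi_R[mask_R], domain="R")
            loss = loss + F.mse_loss(pred, y_R[mask_R], reduction="sum")
        # Target examples that received treatment w.
        mask_T = (w_T == w)
        if mask_T.any():
            pred = stacks[w](phi_T[mask_T], domain="T")
            loss = loss + F.mse_loss(pred, y_T[mask_T], reduction="sum")
    return loss
```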