Preprint. Under review. arXiv:2210.09506v1 [cs.LG] 18 Oct 2022

NO PAIRS LEFT BEHIND: IMPROVING METRIC LEARNING WITH REGULARIZED TRIPLET OBJECTIVE
A. Ali Heydari
Department of Applied Mathematics
Health Sciences Research Institute
University of California, Merced
aheydari@ucmerced.edu
Naghmeh Rezaei, Daniel J. McDuff, Javier L. Prieto
Google
{naghmehr, dmcduff, xaviprieto}@google.com

Work done while at Google as a Research Intern.
Corresponding Author.
ABSTRACT
We propose a novel formulation of the triplet objective function that improves
metric learning without additional sample mining or overhead costs. Our approach
aims to explicitly regularize the distance between the positive and negative samples
in a triplet with respect to the anchor-negative distance. As an initial validation,
we show that our method (called No Pairs Left Behind [NPLB]) improves upon
the traditional and current state-of-the-art triplet objective formulations on standard benchmark datasets. To show the effectiveness and potential of NPLB on
real-world complex data, we evaluate our approach on a large-scale healthcare
dataset (UK Biobank), demonstrating that the embeddings learned by our model
significantly outperform all other current representations on tested downstream
tasks. Additionally, we provide a new model-agnostic single-time health risk definition that, when used in tandem with the learned representations, achieves the most
accurate prediction of subjects’ future health complications. Our results indicate
that NPLB is a simple, yet effective framework for improving existing deep metric
learning models, showcasing the potential implications of metric learning in more
complex applications, especially in the biological and healthcare domains.
1 INTRODUCTION
Metric learning is the task of encoding similarity-based embeddings in which similar samples are mapped close together in space and dissimilar ones far apart (Xing et al., 2002; Wang et al., 2019; Roth et al., 2020). Deep metric learning (DML) has shown success in many domains, including computer vision (Hermans et al., 2017; Vinyals et al., 2016; Wang et al., 2018b) and natural language processing (Reimers & Gurevych, 2019; Mueller & Thyagarajan, 2016; Benajiba et al., 2019). Many DML models utilize paired samples to learn useful embeddings based on distance comparisons. The most common architectures among these techniques are the Siamese (Bromley et al., 1993) and triplet networks (Hoffer & Ailon, 2015). The main components of these models are (1) the strategies for constructing training tuples and (2) the objectives that the model must minimize. Though many studies have focused on improving sampling strategies (Wu et al., 2017; Ge, 2018; Shrivastava et al., 2016; Kalantidis et al., 2020; Zhu et al., 2021), modifying the objective function has attracted less attention. Given that learning representations with triplets very often yields better results than learning with pairs using the same network (Hoffer & Ailon, 2015; Balntas et al., 2016), our work focuses on improving triplet-based DML through a simple yet effective modification of the traditional objective.
Modifying DML loss functions often requires mining additional samples, identifying new quantities (e.g., identifying class centers iteratively throughout training (He et al., 2018)), or computing quantities with costly overheads (Balntas et al., 2016), which may limit their applications. In this work, we aim to provide an easy and intuitive modification of the traditional triplet loss that is motivated by metric learning on more complex datasets and by the notions of density and uniformity of each class. Our proposed variation of the triplet loss leverages all pairwise distances between the existing samples in traditional triplets (positive, negative, and anchor) to encourage denser clusters and better separability
between classes. This makes it possible to improve existing triplet-based DML architectures using implementations available in standard deep learning (DL) libraries (e.g., TensorFlow), enabling wider usage of the methods and improvements presented in this work.
Many ML algorithms are developed for and tested on datasets such as MNIST (LeCun, 1998) or ImageNet (Deng et al., 2009), which often lack the intricacies and nuances of data in other fields, such as health-related domains (Lee & Yoon, 2017). Unfortunately, this can have direct consequences when we try to understand how ML can help improve care for patients (e.g., diagnosis or prognosis). In this work, we demonstrate that DML algorithms can be effective in learning embeddings from complex healthcare datasets. We provide a novel DML objective function and show that our model's learned embeddings improve downstream tasks, such as classifying subjects and predicting future health risk from a single time point. More specifically, we build upon the DML-learned embeddings to formulate a new mathematical definition of patient health risk using a single time point, which, to the best of our knowledge, does not currently exist. To show the effectiveness of our model and health risk definition, we evaluate our methodology on a large-scale complex public dataset, the UK Biobank (UKB) (Bycroft et al., 2018), demonstrating the implications of our work for both
healthcare and the ML community. In summary, our most important contributions are as follows. 1) We present a novel triplet objective function that improves model learning without any additional sample mining or computational overhead. 2) We demonstrate the effectiveness of our approach on a large-scale complex public dataset (UK Biobank) and on conventional benchmarking datasets (MNIST and Fashion MNIST (Xiao et al., 2017)), which shows the potential of DML in domains that have traditionally received less consideration. 3) We provide a novel definition of patient health risk from a single time point, demonstrating the real-world impact of our approach by predicting currently healthy subjects' future risks using only a single lab visit, a challenging but crucial task in healthcare.
2 BACKGROUND AND RELATED WORK
Contrastive learning aims to minimize the distance between two samples if they belong to the same class (i.e., are similar). As a result, contrastive models require two samples as input before calculating the loss and updating their parameters. This can be thought of as passing two samples through two parallel models with tied weights, hence the name Siamese or Twin networks (Bromley et al., 1993). Triplet networks (Hoffer & Ailon, 2015) build upon this idea to rank positive and negative samples based on an anchor value, thus requiring the model to produce mappings for all three before the optimization step (hence the name triplets).
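For concreteness, a minimal sketch of this tied-weights (triplet) forward pass is shown below. It is illustrative only: the encoder architecture, dimensions, and names (e.g., `Encoder`) are placeholders of ours, not the networks used in this paper.

```python
# Minimal triplet-network skeleton: one encoder with tied weights embeds
# the anchor, positive, and negative before the loss is computed.
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, in_dim: int, emb_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

encoder = Encoder(in_dim=784, emb_dim=64)          # the same weights embed all three inputs
anchor, positive, negative = (torch.randn(32, 784) for _ in range(3))
phi_a, phi_p, phi_n = encoder(anchor), encoder(positive), encoder(negative)
```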
Modification of Triplet Loss: Due to their success and importance, triplet networks have attracted increasing attention in recent years. Though the majority of proposed improvements focus on the sampling and selection of the triplets, some studies (Balntas et al., 2016; Zhao et al., 2019; Kim & Park, 2021; Nguyen et al., 2022) have proposed modifications of the traditional triplet loss introduced in Hoffer & Ailon (2015). Similar to our work, Multi-level Distance Regularization (MDR) (Kim & Park, 2021) seeks to regularize the DML loss function. MDR regularizes the pairwise distances between embedding vectors into multiple levels based on their similarity. The goal of MDR is to disturb the optimization of the pairwise distances among examples, discouraging positive pairs from getting too close and negative pairs from getting too distant. A drawback of MDR is the choice of the hyperparameter that balances the regularization term (although the authors suggest a fixed value ($\lambda = 0.1$) which improved results on all tested datasets). Our approach does not require additional hyperparameters, since the regularization is done based on other pairwise distances. Most related to our work, Balntas et al. (2016) modified the traditional objective by explicitly accounting for the distance between the positive and negative pairs (which the traditional triplet function does not consider) and applied their model to learn local feature descriptors using shallow convolutional neural networks. They introduce the idea of the "in-triplet hard negative": swapping the anchor and positive sample if the positive sample is closer to the negative sample than the anchor is, thus improving on the performance of traditional triplet networks (we refer to this approach as Distance Swap; a minimal sketch is given below). Though this method uses the distance between the positive and negative samples to choose the anchor, it does not explicitly enforce the model to regularize the distance between the two, which was the main issue with the original formulation. Our work addresses this pitfall by using the notions of local density and uniformity (defined later in §3) to explicitly enforce the regularization of the distance between the positive and negative pairs using the distance between the anchors and the
negatives. As a result, our approach ensures better inter-class separability while encouraging denser intra-class embeddings.

Figure 1: Visual comparisons between a traditional triplet loss (left), a Distance Swap triplet loss (middle), and our proposed No Pairs Left Behind objective (right) on a toy example. In this figure, $\phi$ refers to a learned operator, $\delta_+ = d(\phi(p_i), \phi(a_i))$, $\delta_- = d(\phi(n_i), \phi(a_i))$, and $\epsilon$ denotes the margin. For this toy example, the network $\phi(\cdot)$ trained on the traditional objective (left) is only concerned with satisfying $\delta_- > \delta_+ + \epsilon$, potentially mapping a positive sample close to negative samples $n_j, n_q$, which is not desirable. A similar case could happen for the Distance Swap variant as well (middle). Our proposed objective seeks to avoid this by explicitly forming a dependence between the distance of the positive-negative pair and $\delta_-$. This regularization results in denser mappings when samples are similar and better-separated mappings when they are dissimilar, as shown in §4. We describe our formulation and modification in §3.
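To make the comparison concrete, here is a minimal PyTorch sketch of the Distance Swap objective as we read Balntas et al. (2016): the effective negative distance becomes the minimum of the anchor-negative and positive-negative distances. The function name and margin default are our own choices, not code from that paper.

```python
import torch
import torch.nn.functional as F

def distance_swap_loss(phi_a, phi_p, phi_n, margin=1.0):
    """Triplet loss with the in-triplet hard-negative swap (Balntas et al., 2016).

    Using min(d(a, n), d(p, n)) as the negative distance is equivalent to
    swapping anchor and positive whenever the positive sample lies closer
    to the negative than the anchor does.
    """
    d_ap = F.pairwise_distance(phi_a, phi_p)
    d_an = F.pairwise_distance(phi_a, phi_n)
    d_pn = F.pairwise_distance(phi_p, phi_n)
    d_neg = torch.minimum(d_an, d_pn)  # the in-triplet hard negative
    return F.relu(d_ap - d_neg + margin).mean()
```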
Deep Learned Embeddings for Healthcare: Recent years have seen an increase in the number of DL models for Electronic Health Records (EHR), with several methods aiming to produce rich embeddings that better represent patients (Rajkomar et al., 2018; Choi et al., 2016b; Tran et al., 2015; Nguyen et al., 2017; Choi et al., 2016a; Pham et al., 2017). Though most studies in this area consider temporal components, DeepPatient (Miotto et al., 2016) does not explicitly account for time, making it an appropriate model for comparison with our representation learning approach given our goal of predicting patients' health risks using a single snapshot. DeepPatient is an unsupervised DL model that seeks to learn general deep representations by employing three stacked denoising autoencoders that learn hierarchical regularities and dependencies by reconstructing a masked input of EHR features. We hypothesize that learning patient reconstructions alone (even with masked features) does not help discriminate between patients based on their similarities. We aim to address this by employing a deep metric learning approach that learns similarity-based embeddings.
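For intuition only, the sketch below shows one masked-reconstruction (denoising autoencoder) stage in the spirit of DeepPatient; the layer size, activation, and mask rate here are our illustrative assumptions, not the configuration of Miotto et al. (2016).

```python
import torch
import torch.nn as nn

class DenoisingAEStage(nn.Module):
    """One denoising-autoencoder stage: reconstruct the full EHR feature
    vector from a randomly masked copy of the input. DeepPatient stacks
    three such stages; all hyperparameters below are placeholders."""
    def __init__(self, in_dim: int, hidden_dim: int = 500, mask_rate: float = 0.05):
        super().__init__()
        self.mask_rate = mask_rate
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.Sigmoid())
        self.decoder = nn.Linear(hidden_dim, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mask = (torch.rand_like(x) > self.mask_rate).float()  # randomly drop features
        return self.decoder(self.encoder(x * mask))           # reconstruct the unmasked x
```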
Predicting Patients' Future Health Risks: Assessing patients' health risk using EHR remains a crucial yet challenging task in epidemiology and public health (Li et al., 2015). One example of such challenges is clinically silent conditions, where patients fall within "normal" or "borderline" ranges for specific known blood work markers while being at risk of developing chronic conditions and co-morbidities that will reduce quality of life and cause mortality later on (Li et al., 2015). Therefore, early and accurate assessment of health risk can tremendously improve patient care, especially for those who may appear "healthy" and do not show severe symptoms. Current approaches for assessing future health complications tie the definition of health risk to multiple time points (Hirooka et al., 2021; Chowdhury & Tomal, 2022; Razavian et al., 2016; Kamal et al., 2020; Cohen et al., 2021; Che et al., 2017). Despite the obvious appeal of such approaches, the use of many visits for modeling and defining risk simply ignores a large portion of patients who do not return for subsequent check-ups, especially those with lower incomes and those without adequate access to healthcare (Kullgren et al., 2010; Taani et al., 2020; Nishi et al., 2019). Given the importance of addressing these issues, we propose a mathematical definition (built upon DML) based on a single time point, which can be used to predict patient health risk from a single lab visit.
3 METHODS
Main Idea of No Pairs Left Behind (NPLB): The main idea behind our approach is to ensure that, during optimization, the distance between positive samples $p_i$ and negative samples $n_i$ is considered and regularized with respect to the anchors $a_i$ (i.e., explicitly introducing a notion of the distance $d(p_i, n_i)$ that depends on $d(a_i, n_i)$). We visualize this idea in Fig. 1. The mathematical intuition behind our approach can be described by considering in-class local density and uniformity, as introduced in Rojas-Thomas & Santos (2021) as an unsupervised clustering evaluation metric.
Given a metric learning model $\phi$, define the local density of a class $c_k$ as $LD(c_k) = \min\{d(\phi(p_i), \phi(p_j))\}$ for $i \neq j$ and $p_i \in c_k$, and let the average density $AD(c_k)$ be the average local density over all points in the class. An ideal operator $\phi$ would produce embeddings that are compact while well separated from other classes, i.e., the in-class embeddings are uniform. This notion of uniformity is proportional to the difference between the local and average density of each class:

$$\mathrm{Unif}(c_k) = \begin{cases} \dfrac{|LD(c_k) - AD(c_k)|}{AD(c_k) + \xi} & \text{if } |c_k| > 1 \\ 0 & \text{otherwise,} \end{cases}$$

for $0 < \xi \ll 1$. However, computing the density and uniformity of classes is only possible post hoc, once all labels are present, and is not feasible during training if the triplets are mined in a self-supervised manner. To reduce the complexity and allow for general use, we utilize proxies for the mentioned quantities to regularize the triplet objective using the notion of uniformity. We take the distance between positive and negative pairs to be inversely proportional to the local density of a class. Similarly, the distance between anchors and negative pairs is closely related to the average density, given that a triplet model maps positive pairs inside an $\epsilon$-ball of the anchor ($\epsilon$ being the margin). In this sense, the uniformity of a class is inversely proportional to $|d(\phi(p_i), \phi(n_i)) - d(\phi(a_i), \phi(n_i))|$.
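For concreteness, a post-hoc computation of these quantities for a single class could look like the following sketch (a helper of our own, following the definitions above: each point's local density is its nearest in-class neighbor distance, so `nearest.min()` recovers $LD(c_k)$ and `nearest.mean()` gives $AD(c_k)$; `xi` plays the role of $\xi$).

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

def uniformity(class_embeddings: np.ndarray, xi: float = 1e-8) -> float:
    """Post-hoc Unif(c_k) for one class of embeddings, shape (n, dim)."""
    if len(class_embeddings) <= 1:
        return 0.0                                  # the |c_k| <= 1 case
    dists = squareform(pdist(class_embeddings))     # pairwise Euclidean distances
    np.fill_diagonal(dists, np.inf)                 # exclude i == j
    nearest = dists.min(axis=1)                     # per-point local density
    ld, ad = nearest.min(), nearest.mean()          # LD(c_k), AD(c_k)
    return abs(ld - ad) / (ad + xi)
```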
NPLB Objective: Let $\phi(\cdot)$ denote an operator and $T$ be the set of triplets of the form $(p_i, a_i, n_i)$ (positive, anchor, and negative tensors) sampled from a mini-batch $B$ of size $N$. For ease of notation, we write $\phi(q_i)$ as $\phi_q$. Given a margin $\epsilon$ (a hyperparameter), the traditional objective function for a triplet network is shown in Eq. (1):

$$\mathcal{L}_{Triplet} = \frac{1}{N} \sum_{(p_i, a_i, n_i) \in T} \big[ d(\phi_a, \phi_p) - d(\phi_a, \phi_n) + \epsilon \big]_+ \tag{1}$$
with $[\cdot]_+ = \max\{\cdot, 0\}$ and $d(\cdot, \cdot)$ being the Euclidean distance. Minimizing Eq. (1) only ensures that the negative samples fall outside of an $\epsilon$-ball around the anchor $a_i$ while bringing the positive sample $p_i$ inside this ball (illustrated in Fig. 1), satisfying $d(\phi_a, \phi_n) > d(\phi_a, \phi_p) + \epsilon$. However, this objective does not explicitly account for the distance between positive and negative samples, which can impede performance, especially when there is high in-class variability. Motivated by our main idea of having denser and more uniform in-class embeddings, we add a simple regularization term to address the issues described above, as shown in Eq. (2):
$$\mathcal{L}_{NPLB} = \frac{1}{N} \sum_{(p_i, a_i, n_i) \in T} \big[ d(\phi_a, \phi_p) - d(\phi_a, \phi_n) + \epsilon \big]_+ + \big[ d(\phi_p, \phi_n) - d(\phi_a, \phi_n) \big]^p, \tag{2}$$
where $p \in \mathbb{N}$ and NPLB stands for "No Pairs Left Behind." The regularization term in Eq. (2) enforces positive and negative samples to be roughly as far apart as the corresponding anchor-negative pairs, while still minimizing the positives' distance to the anchor values. However, if one is not careful, this approach could result in the model learning to map $n_i$ such that $d(\phi_a, \phi_n) > \max\{d(\phi_a, \phi_p) + \epsilon,\, d(\phi_p, \phi_n)\}$, which would zero out the triplet term while leaving the regularizer unbounded below, resulting in a minimization problem with no lower bound.¹ To avert such issues, we restrict $p = 2$ (or, more generally, $p \equiv 0 \pmod{2}$), as in Eq. (3):
$$\mathcal{L}_{NPLB} = \frac{1}{N} \sum_{(p_i, a_i, n_i) \in T} \big[ d(\phi_a, \phi_p) - d(\phi_a, \phi_n) + \epsilon \big]_+ + \big[ d(\phi_p, \phi_n) - d(\phi_a, \phi_n) \big]^2. \tag{3}$$
Note that this formulation does not require mining any additional samples or any complex computation, since it uses only the existing samples to regularize the embedded space. Moreover, both terms are non-negative, so

$$\mathcal{L}_{NPLB} = 0 \iff \big[ d(\phi_p, \phi_n) - d(\phi_a, \phi_n) \big]^2 = \big[ d(\phi_a, \phi_p) - d(\phi_a, \phi_n) + \epsilon \big]_+ = 0,$$

which, considering only the real domain, is possible if and only if $d(\phi_p, \phi_n) = d(\phi_a, \phi_n)$ and $d(\phi_a, \phi_n) \geq d(\phi_a, \phi_p) + \epsilon$, explicitly enforcing separation between negative and positive pairs.
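Since Eq. (3) uses only the three embeddings already produced for each triplet, it is straightforward to express in any standard DL library. Below is a minimal PyTorch sketch of ours (not the authors' released code); `margin` plays the role of $\epsilon$.

```python
import torch
import torch.nn.functional as F

def nplb_loss(phi_a, phi_p, phi_n, margin=1.0):
    """No Pairs Left Behind objective, Eq. (3): the usual triplet hinge
    plus a squared term tying d(p, n) to d(a, n). Inputs are embedded
    anchors/positives/negatives of shape (batch, emb_dim)."""
    d_ap = F.pairwise_distance(phi_a, phi_p)
    d_an = F.pairwise_distance(phi_a, phi_n)
    d_pn = F.pairwise_distance(phi_p, phi_n)
    triplet = F.relu(d_ap - d_an + margin)   # [d(a,p) - d(a,n) + eps]_+
    regularizer = (d_pn - d_an) ** 2         # [d(p,n) - d(a,n)]^2, i.e., p = 2
    return (triplet + regularizer).mean()
```

Because the regularizer is built from the triplet's own pairwise distances, no hyperparameter beyond the existing margin is introduced, so this can serve as a drop-in replacement for a standard triplet margin loss.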
¹The mentioned pitfall can be realized by taking $p = 1$, i.e.,

$$\mathcal{L}(p_i, a_i, n_i) = \frac{1}{N} \sum_{(p_i, a_i, n_i) \in T} \big[ d(\phi_a, \phi_p) - d(\phi_a, \phi_n) + \epsilon \big]_+ + \big[ d(\phi_p, \phi_n) - d(\phi_a, \phi_n) \big].$$

In this case, the model can learn to map $n_i$ and $a_i$ such that $d(\phi_a, \phi_n) > C$, where $C = \max\{d(\phi_p, \phi_n),\, d(\phi_a, \phi_p) + \epsilon\}$, resulting in $\mathcal{L} < 0$. For instance, fixing $d(\phi_a, \phi_p) = 0$ and $d(\phi_p, \phi_n) = 1$ and letting $d(\phi_a, \phi_n) \to \infty$ zeroes the hinge term while driving the regularizer, and hence $\mathcal{L}$, to $-\infty$.
Figure 2: Visual comparisons between a traditional triplet loss, Distance Swap, and the proposed NPLB objective on the train (top row) and test (bottom row) sets of (A) MNIST and (B) Fashion MNIST. (A) To evaluate the feasibility of the proposed triplet loss on general datasets, we trained the same network (described in Appendix J) under identical conditions on the MNIST dataset, with the only difference being the loss function used. (B) UMAP-reduced embeddings of Fashion MNIST trained with the three triplet loss versions. Our results indicate both quantitative and qualitative improvements in the embeddings, as shown above and in Table 1. (See Figs. S5 and S6 for higher-resolution versions.)
4 VALIDATION OF NPLB ON STANDARD DATASETS
Prior to testing our methodology on healthcare data, we validate our derivations and intuition on common benchmark datasets, namely MNIST and Fashion MNIST. To assess the improvement gained from the proposed objective, we refrained from using more advanced triplet construction techniques and followed the most common approach of constructing triplets offline using the labels (a sketch of this scheme follows below). We utilized the same architecture and training settings for all experiments, with the only difference per dataset being the objective function (see Appendix J for details on each architecture). After training, we evaluated our approach quantitatively by assessing the classification accuracy of embeddings produced by optimizing the traditional triplet, Distance Swap, and our proposed NPLB objectives. The results on MNIST and Fashion MNIST are presented in Table 1, showing that our approach improves classification. We also assessed the embeddings qualitatively: given the simplicity of MNIST, we designed our model to produce two-dimensional embeddings, which we visualized directly. For Fashion MNIST, we generated embeddings in $\mathbb{R}^{64}$ and used Uniform Manifold Approximation and Projection (UMAP) (McInnes et al., 2018) to reduce the dimensions for visualization, as shown in Fig. 2. Our results demonstrate that networks trained on our proposed NPLB objective produce embeddings that are denser and better separated in space, as desired.
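The offline, label-based triplet construction can be sketched as follows (details such as sampling one triplet per anchor are our assumption; the paper's exact scheme may differ).

```python
import numpy as np

def make_triplets(X: np.ndarray, y: np.ndarray, seed: int = 0):
    """Offline triplet construction from labels: for each anchor, draw a
    positive from the same class and a negative from a different class."""
    rng = np.random.default_rng(seed)
    triplets = []
    for i in range(len(X)):
        same = np.flatnonzero((y == y[i]) & (np.arange(len(y)) != i))
        diff = np.flatnonzero(y != y[i])
        triplets.append((X[rng.choice(same)], X[i], X[rng.choice(diff)]))
    return triplets  # (positive, anchor, negative) tuples, as in Sec. 3
```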
Table 1: Comparison of state-of-the-art (SOTA) triplet losses with our proposed objective function. We present (weighted) F1 scores for classifying MNIST and Fashion MNIST embeddings (visualized in Fig. 2) using XGBoost on five random train-test splits (we randomly split the data into train and test sets (80-20) five times and calculated the mean and standard deviation of the scores). We note that the improved performance of the NPLB-trained model was consistent across different classifiers (as shown for a different dataset in §5).
Trad. Triplet Loss MDR Distance Swap NPLB Objective (Ours)
MNIST 0.9859 ±0.0009 0.9886 ±0.0009 0.9891 ±0.0003 0.9954 ±0.0003
Fashion MNIST 0.9394 ±0.001 0.9557 ±0.001 0.9536 ±0.001 0.9664 ±0.001
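A sketch of the evaluation protocol described in the caption is given below; classifier hyperparameters are library defaults, which the table does not specify, and `Z` holds the learned embeddings.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from xgboost import XGBClassifier

def embedding_f1(Z: np.ndarray, y: np.ndarray, n_splits: int = 5):
    """Weighted F1 over five random 80-20 splits, as in Table 1."""
    scores = []
    for seed in range(n_splits):
        Z_tr, Z_te, y_tr, y_te = train_test_split(
            Z, y, test_size=0.2, random_state=seed, stratify=y)
        clf = XGBClassifier().fit(Z_tr, y_tr)       # default hyperparameters
        scores.append(f1_score(y_te, clf.predict(Z_te), average="weighted"))
    return float(np.mean(scores)), float(np.std(scores))
```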
5 IMPROVING PATIENT REPRESENTATION LEARNING
In this section, we aim to demonstrate the potential and implications of our approach on a more
complex dataset in three steps: First, we show that deep metric learning improves upon current
state-of-the-art patient embedding models (§5.1). Next, we provide a comparison between NPLB,