indicated its value in each feature dimension (equivalent to concatenated one-hot feature vectors). As such, the model receives disentangled feature information in the input, though in principle it can learn to disentangle feature information given one-hot encodings for each unique item.
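For illustration, a minimal sketch of such a concatenated one-hot item encoding; the feature names and vocabulary sizes below are placeholders rather than the dataset's actual dimensions:

```python
import numpy as np

# Hypothetical feature vocabularies; the actual feature dimensions are defined by the dataset.
FEATURE_SIZES = {"shape": 4, "color": 4, "texture": 4}

def encode_item(feature_values):
    """Concatenate one one-hot block per feature dimension."""
    blocks = []
    for name, size in FEATURE_SIZES.items():
        block = np.zeros(size)
        block[feature_values[name]] = 1.0
        blocks.append(block)
    return np.concatenate(blocks)

# Example: an item with shape index 2, color index 0, texture index 3
vec = encode_item({"shape": 2, "color": 0, "texture": 3})
```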
Label-based order encoding. Using position-based order encodings, models trained with sequences up to length L encounter an out-of-distribution problem when tested on longer sequences, as position encodings beyond L are unfamiliar to the model. We introduce label-based encoding, which instead pairs items in each sequence with ascending random integer labels to communicate order information (Fig 1B). This allows models to encode longer sequences of tokens with familiar labels seen during training. In our model, these labels were embedded with learnable weights, and we contrast the random label encoding method with sinusoidal and learnable encodings based on item positions. A concurrent work also explored the random position method and tested it with other types of encodings (Anonymous, 2022). In all reported results, we pre-generated item labels sampled from a range up to the maximum generalization length (50) for all sequences in the dataset, and these labels were shared across training steps and model seeds. In practice, the labels for each sequence can be sampled online and from a larger range to accommodate generalization to even longer sequences.
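A minimal sketch of how such ascending random integer labels could be drawn for a sequence; the function name is our own, and the label range of 50 follows the maximum generalization length stated above:

```python
import numpy as np

MAX_LABEL = 50  # maximum generalization length, per the text

def sample_order_labels(seq_len, max_label=MAX_LABEL, rng=None):
    """Draw seq_len distinct integer labels from [0, max_label) and sort them ascending,
    so order is conveyed through labels the model has already seen during training."""
    rng = np.random.default_rng() if rng is None else rng
    labels = rng.choice(max_label, size=seq_len, replace=False)
    return np.sort(labels)

# Example: order labels for a length-8 input sequence
labels = sample_order_labels(8)
```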
Model. The main model architecture is shown in Fig 1B. Each input sequence consisted of a task token and the paired item and label tokens, with the EOS token serving as the first query for tokens in the output sequence. The input tokens were first embedded to the model's latent representational space through a set of embedding layers depending on the token type (task, item, or label). The item and label embeddings were then added to form a composite item embedding. These embedded tokens were fed into a causal transformer, which contained one or two layers of alternating future-masked attention sublayers and MLP sublayers. Residual connections and layer normalization were applied after each sublayer as in Vaswani et al. (2017). We tested architectural variations in the number of attention heads in different layers of the model while controlling for the total number of learnable parameters (see detailed hyperparameters in Appendix B). The state of the query token at the output of the causal transformer was passed through two linear heads to predict the next output token (the task token, or an item and its associated label).
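The sketch below illustrates this architecture in PyTorch under simplifying assumptions: it is not the authors' code, the hyperparameter defaults are placeholders, and a single item-class head stands in for the per-feature prediction described in the text.

```python
import torch
import torch.nn as nn

class SortingTransformer(nn.Module):
    """Sketch: per-token-type embeddings, a causal transformer, and two output heads.
    Post-sublayer layer norm (the PyTorch default) matches the Vaswani et al. setup."""

    def __init__(self, n_item_classes, n_labels, n_tasks,
                 d_model=64, n_heads=1, n_layers=2):
        super().__init__()
        # Separate embedding tables depending on token type (task, item, or label).
        self.task_emb = nn.Embedding(n_tasks, d_model)
        self.item_emb = nn.Embedding(n_item_classes, d_model)
        self.label_emb = nn.Embedding(n_labels, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        # Two linear heads on the output states: the next item (or task/EOS) and its label.
        self.item_head = nn.Linear(d_model, n_item_classes)
        self.label_head = nn.Linear(d_model, n_labels)

    def forward(self, task_ids, item_ids, label_ids):
        # Composite item embedding: item and label embeddings are summed.
        items = self.item_emb(item_ids) + self.label_emb(label_ids)
        x = torch.cat([self.task_emb(task_ids).unsqueeze(1), items], dim=1)
        # Future-masked (causal) self-attention over the full sequence.
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1))
        h = self.encoder(x, mask=mask)
        return self.item_head(h), self.label_head(h)
```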
Training and evaluation. The models were trained using full teacher forcing (i.e., the model was always fed the correct tokens) on all sequences of lengths 5 to 25 in the dataset (∼46k sequences) and evaluated for length generalization on sequences of lengths 26 to 50 (∼54k sequences). We trained models in both single-task and multi-task settings. In both cases, the output sequence consisted of the correctly ordered items and their labels given the task being trained, followed by an EOS token. In single-task learning, we did not include the task token in training or evaluation. In multi-task learning, the task token was used and the models were trained to first output the task token before predicting the output sequence. The training sequences used in multi-task learning remained the same sequences of lengths 5–25, but each sequence corresponded to a different output sequence under each task. The models were trained using softmax cross-entropy loss on the prediction of feature classes, labels, and task/EOS categories for tokens in the output sequence. Item prediction accuracy was computed as the average feature prediction accuracy, i.e., if the model predicted 2/3 features correctly, its token-level item accuracy was 2/3. Training stopped at 32k gradient updates for single-task models and 38k gradient updates for multi-task models. Below, we report both token-level and sequence-level accuracy, under both teacher forcing and top-1 rollout (i.e., greedy decoding). Results were aggregated over four random seeds for each task type × architecture pair. Unless otherwise specified, results were taken from the checkpoint with the highest generalization accuracy within each seed. Error shades and error bars indicate the standard error of the mean across models.
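A minimal sketch of the token-level item accuracy described above (per-token feature accuracy averaged over tokens), assuming predictions and targets are given as integer class arrays:

```python
import numpy as np

def token_item_accuracy(pred_features, true_features):
    """Token-level item accuracy: the fraction of correctly predicted features per token,
    averaged over tokens. Both arrays have shape (n_tokens, n_features)."""
    per_token = (pred_features == true_features).mean(axis=1)  # e.g. 2/3 if 2 of 3 features match
    return per_token.mean()

# Example: 2 tokens x 3 features; the first token has 2/3 features correct.
pred = np.array([[1, 0, 2], [3, 1, 0]])
true = np.array([[1, 0, 1], [3, 1, 0]])
acc = token_item_accuracy(pred, true)  # (2/3 + 1) / 2
```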
3 Results
3.1 Single-task learning
Two-layer models with label encoding learn the SORT task and generalize to longer sequences.
We first trained the model with the SORT[SHAPE,COLOR,TEXTURE] task. Using our label encoding
method, models with two single-headed layers (indicated as [1,1]) were able to achieve near-ceiling
accuracy on training sequences and generalize to longer sequences (Fig 2; also see quantitative
results in Appendix C). The predictions of the EOS token were also highly accurate in these models
(see Fig S1A in Appendix A.1). Item prediction was more accurate than label prediction in this
task, reflecting that the models represented item feature information more accurately in order to sort