
From English to Español: Model Training

Welcome back! In the first part of our series, we meticulously prepared our dataset, trained a shared BPE tokenizer, and serialized everything into TFRecord files. Now comes the exciting part: building our own Transformer model from the ground up using TensorFlow. We'll assemble all the essential components (positional embeddings, self-attention, and cross-attention) and then train the model to translate from English to Spanish.

Think of a Transformer as a team of expert linguists in a room. To translate a sentence, they don't just look at words in isolation. They need to understand the context, the word order, and the intricate relationships between them. Our model will do the same, and we'll build each part of this "expert team" ourselves.

Recap: Our Configuration

We already defined our Config class in Part 1, but here's a quick refresher on the hyperparameters that drive the model architecture itself:

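class Config:
    MIX_VOCAB_SIZE = 16000  # Shared English+Spanish vocabulary
    EMBED_DIM = 512         # d_model, the "richness" of meaning per token
    FF_DIM = 1024           # Workspace size for the feed-forward layer
    NUM_HEADS = 8           # Parallel attention "experts"
    NUM_ENCODER_LAYERS = 6
    NUM_DECODER_LAYERS = 6
    DROPOUT_RATE = 0.1
    MAX_LENGTH = 96         # Longest sequence the positional table covers
    BATCH_SIZE = 128
    EPOCHS = 50
    WARMUP_STEPS = 4000     # Linear warm-up length for the learning-rate schedule
    IS_TRAINING = True      # Used later: True trains, False loads saved weights

config = Config()           # The instance passed around in the code below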

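Think of EMBED_DIM as the number of adjectives you can use to describe a token's meaning. NUM_HEADS is like having multiple specialists, each helping the model capture a different kind of context. Most of these values (d_model = 512, 8 heads, 6 encoder and 6 decoder layers, dropout 0.1, 4000 warm-up steps) mirror the base Transformer from the paper; the main departure is FF_DIM, which is half the paper's 2048.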
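
As a rough sanity check on model size, the sketch below counts the trainable parameters these hyperparameters imply. It assumes Keras's default layer layouts (biased query/key/value/output projections inside MultiHeadAttention, a scale and an offset vector per LayerNorm) and the weight tying described later; `count_params` is a hypothetical helper, not part of the original code.

```python
def count_params(vocab, d_model, ff_dim, enc_layers, dec_layers):
    """Rough trainable-parameter count for the tied-embedding Transformer.

    Assumes Keras defaults: MultiHeadAttention has biased Q/K/V/output
    projections (each d_model x d_model overall), LayerNormalization has
    gamma and beta vectors, and positional encodings / dropout add no weights.
    """
    mha = 4 * (d_model * d_model + d_model)            # Q, K, V, output projections
    ffn = (d_model * ff_dim + ff_dim) + (ff_dim * d_model + d_model)
    layer_norm = 2 * d_model                           # gamma + beta
    encoder_layer = mha + ffn + 2 * layer_norm
    decoder_layer = 2 * mha + ffn + 3 * layer_norm     # self-attn + cross-attn
    shared_embedding = vocab * d_model                 # reused by the output projection
    return (shared_embedding
            + enc_layers * encoder_layer
            + dec_layers * decoder_layer)


tied = count_params(16000, 512, 1024, 6, 6)
saved_by_tying = 2 * 16000 * 512  # separate decoder embedding + output matrix
print(f"tied: {tied:,} params, tying saves {saved_by_tying:,}")
```

Under these assumptions the tied model has about 39.7M parameters; without tying it would have about 56.1M (plus a 16,000-entry bias if the output projection were an ordinary Dense layer).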

Stop! Words Have an Order: The Positional Embedding Layer

What it does: Adds a position-aware signal to every token embedding so the model can tell "dog bites man" apart from "man bites dog".

Why it does it: Self-attention is permutation-equivariant: shuffle the input tokens and the attention output shuffles the same way. Without an explicit position signal, the network literally can't perceive order. We follow the paper's sinusoidal formulation: even dimensions use sin, odd dimensions use cos, with frequencies that decay geometrically across the embedding axis. Each position becomes a unique combination of waves, and the encoding at position pos + k is a fixed linear function (a rotation) of the encoding at pos, so the network can also reason about relative offsets (e.g. "the token 3 steps back"). The formula we apply is:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
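
To make the relative-offset claim concrete, here is a plain-Python sketch (no TensorFlow) that builds encoding vectors with the formula above and checks that shifting by k positions is the same as rotating each (sin, cos) pair by an angle that depends only on k. The helper names `sinusoid` and `shift_by` are illustrative, not from the original code.

```python
import math


def sinusoid(pos, d_model):
    """Positional encoding for one position: sin on even dims, cos on odd dims."""
    vec = []
    for j in range(d_model):
        angle = pos / (10000 ** ((2 * (j // 2)) / d_model))
        vec.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
    return vec


def shift_by(vec, k, d_model):
    """Map PE(pos) to PE(pos + k) using only k: rotate each (sin, cos) pair."""
    out = []
    for j in range(0, d_model, 2):
        w = 1 / (10000 ** (j / d_model))        # frequency of this pair
        s, c = vec[j], vec[j + 1]
        out.append(s * math.cos(w * k) + c * math.sin(w * k))  # sin(a + wk)
        out.append(c * math.cos(w * k) - s * math.sin(w * k))  # cos(a + wk)
    return out


d = 16
for pos, k in [(0, 1), (5, 3), (17, 10)]:
    shifted = shift_by(sinusoid(pos, d), k, d)
    target = sinusoid(pos + k, d)
    assert all(abs(a - b) < 1e-9 for a, b in zip(shifted, target))

print(sinusoid(0, 8))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Because the rotation for a given k is the same linear map at every position, attention layers can in principle learn to look "k tokens back" regardless of absolute position.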

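# Imports used by the code in this part
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import (
    Dense, Dropout, Embedding, Layer, LayerNormalization, MultiHeadAttention,
)
from tensorflow.keras.optimizers import Adam


class PositionalEmbedding(Layer):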
    def __init__(self, sequence_length, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.embed_dim = embed_dim
        self.pos_encoding = self.positional_encoding(sequence_length, embed_dim)

    def get_angles(self, pos, i, d_model):
        angles = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
        return pos * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            np.arange(position)[:, np.newaxis],
            np.arange(d_model)[np.newaxis, :],
            d_model
        )
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
        pos_encoding = angle_rads[np.newaxis, ...]
        return tf.cast(pos_encoding, dtype=tf.float32)

    def call(self, inputs):
        if getattr(inputs, 'sparse', False) or isinstance(inputs, tf.SparseTensor):
            inputs = tf.sparse.to_dense(inputs)

        length = tf.shape(inputs)[1]
        return inputs + self.pos_encoding[:, :length, :]

We pre-compute the full (1, MAX_LENGTH, embed_dim) table once at construction time and slice it to the batch's actual length in call, so input sequences must not be longer than MAX_LENGTH. Because the encoding is fixed rather than learned, it adds zero trainable parameters. The sparse-tensor check at the top of call is a small defensive guard: if Keras ever hands the embeddings through as a sparse tensor, adding the dense positional table to it would fail.

Padding Masks: Don't Attend to <pad>

What it does: Builds a per-batch (batch, 1, seq) mask that is 1 at real-token positions and 0 at <pad> positions, so the attention softmax can zero those positions out.

Why it does it: Because our batches are length-bucketed and padded with <pad> (id 0), every batch has positions that carry no meaning. If we let attention see them, the softmax leaks probability mass onto padding, gradients flow into junk positions and the model wastes capacity learning to ignore them. Worse, padding can dominate when most of a batch is short.

We wrap this in a tiny Layer so it shows up cleanly in the model graph and serializes with model.save:

class PaddingMask(Layer):
    """Returns (batch, 1, seq) mask: 1 for real tokens, 0 for <pad> (id 0)."""
    def call(self, inputs):
        return tf.cast(tf.not_equal(inputs, 0), tf.int32)[:, tf.newaxis, :]

The mask gets reused in three places: encoder self-attention (enc_padding_mask), decoder cross-attention (also enc_padding_mask, since the decoder attends to the encoder's outputs), and decoder self-attention (dec_padding_mask, combined with the causal mask).
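
A tiny plain-Python illustration of why the mask matters: it builds the 1/0 padding mask for one sentence (pad id 0, as in our data) and applies it to one row of attention scores by pushing masked scores to a large negative number before the softmax, which is similar in spirit to what Keras attention does with an attention_mask. The token ids and scores are made up for illustration.

```python
import math


def padding_mask(token_ids, pad_id=0):
    """1 for real tokens, 0 for <pad>; like PaddingMask without the extra axis."""
    return [0 if t == pad_id else 1 for t in token_ids]


def masked_softmax(scores, mask, large_negative=1e9):
    """Softmax over one row of attention scores; masked positions get ~0 weight."""
    adjusted = [s if m else s - large_negative for s, m in zip(scores, mask)]
    top = max(adjusted)
    exps = [math.exp(a - top) for a in adjusted]
    total = sum(exps)
    return [e / total for e in exps]


ids = [57, 912, 0, 0]              # two real tokens followed by two pads
scores = [1.0, 2.0, 3.0, 3.0]      # hypothetical query-key scores
mask = padding_mask(ids)
unmasked = masked_softmax(scores, [1, 1, 1, 1])
masked = masked_softmax(scores, mask)
print(mask)                                # [1, 1, 0, 0]
print(round(sum(unmasked[2:]), 3))         # 0.799: most weight leaks onto pads
print([round(p, 3) for p in masked])       # [0.269, 0.731, 0.0, 0.0]
```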

The Understanding Unit: The Transformer Encoder

What it does: Takes the embedded source sentence and progressively rewrites each token's vector so that it encodes both what the token means and how it relates to every other token in the sentence.

Why it does it: A standalone word embedding for "bank" is the same in "river bank" and "savings bank". The encoder's job is to contextualize that embedding using the rest of the sentence so the decoder has a disambiguated, sentence-aware representation to translate from.

Each TransformerEncoder layer has two sub-layers, and each sub-layer is wrapped with Dropout + residual + LayerNorm (the classic post-norm pattern, LayerNorm(x + Dropout(Sublayer(x)))):

  1. Multi-Head Self-Attention: every token looks at every other token and produces a weighted blend of their representations. The "multi-head" part runs NUM_HEADS = 8 of these attentions in parallel, each in its own key_dim = embed_dim / num_heads = 64-dimensional subspace. Different heads often end up specializing: some attend to syntactic neighbors, others to long-range dependencies, others to specific token types.
  2. Position-wise Feed-Forward Network: a 2-layer MLP (512 → 1024 → 512, ReLU in the middle) applied independently to each position. This is where each token does its private "thinking" on whatever attention surfaced for it.

The reason key_dim = embed_dim // num_heads and not embed_dim itself: the paper splits d_model evenly across heads (d_k = d_v = d_model / h = 64), so 8 heads at 64 dims is roughly the same compute as one big 512-dim head but with the bonus of multiple specialized attention patterns.
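
The split-into-heads idea can be sketched in plain Python. This is a deliberately simplified model, not what Keras's MultiHeadAttention does internally: it skips the learned query/key/value/output projections, slices each d_model vector into num_heads chunks of key_dim, runs scaled dot-product attention, softmax(QKᵀ/√d_k)V, inside each chunk, and concatenates the results. The last check also demonstrates the shuffle property of self-attention described earlier.

```python
import math


def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are lists of vectors."""
    d_k = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        weights = softmax(scores)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


def multi_head_self_attention(x, num_heads):
    """Slice each vector into num_heads chunks, attend per chunk, concatenate.

    Simplification: real layers first apply learned Q/K/V projections and a
    final output projection; those are omitted here.
    """
    key_dim = len(x[0]) // num_heads
    heads = []
    for h in range(num_heads):
        chunk = [row[h * key_dim:(h + 1) * key_dim] for row in x]
        heads.append(attention(chunk, chunk, chunk))
    # concatenate the per-head outputs back into d_model-sized vectors
    return [[val for head in heads for val in head[t]] for t in range(len(x))]


# 4 tokens, d_model = 8, 2 heads of key_dim 4 (toy sizes instead of 512 / 8 / 64)
x = [[math.sin(8 * t + c) for c in range(8)] for t in range(4)]
out = multi_head_self_attention(x, num_heads=2)

# shuffling the tokens shuffles the outputs the same way
perm = [2, 0, 3, 1]
out_perm = multi_head_self_attention([x[i] for i in perm], num_heads=2)
print(all(abs(a - b) < 1e-12
          for i, row in zip(perm, out_perm) for a, b in zip(out[i], row)))  # True
```

With the real sizes (d_model = 512, 8 heads), each head works in a 64-dimensional slice, and the score matrices cost 8 × (T × T × 64) multiply-adds in total, the same as T × T × 512 for a single 512-dimensional head.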

The three supporting components matter just as much:

  • Residual connections (inputs + attn_output): without these, gradients have a much harder time flowing through a 6-layer-deep stack. They also let each layer learn an additive refinement rather than rebuilding the representation from scratch.
  • LayerNorm: keeps each token's activations on a stable scale across layers, which is what allows the network to be this deep without going haywire.
  • Dropout: applied to the attention and FFN outputs before the residual add, with rate 0.1, to regularize.

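
The residual + LayerNorm + Dropout wrapper that both layer types use can be written as one small function. The plain-Python sketch below leaves LayerNorm's learned scale and offset at their initial values of 1 and 0 and treats dropout as the identity (its behavior at inference time); `sublayer_block` is a hypothetical name.

```python
import math


def layer_norm(vec, eps=1e-6):
    """Normalize one token vector to zero mean and (almost) unit variance.

    The learned scale (gamma) and offset (beta) are left at their initial
    values of 1 and 0, so they are omitted.
    """
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def sublayer_block(x, sublayer):
    """Post-norm wrapper: LayerNorm(x + Dropout(Sublayer(x))).

    Dropout is treated as the identity, which is its behavior at inference.
    """
    residual_sum = [a + b for a, b in zip(x, sublayer(x))]
    return layer_norm(residual_sum)


token = [0.5, -1.0, 2.0, 3.5]
# a toy "sublayer" whose outputs are 10x larger than its inputs
out = sublayer_block(token, lambda v: [10 * a for a in v])
print([round(v, 3) for v in out])  # mean ~0, variance ~1 despite the large outputs
```
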
class TransformerEncoder(Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, drop_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)
        self.ffn = Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(drop_rate)
        self.dropout2 = Dropout(drop_rate)

    def call(self, inputs, padding_mask=None, training=False):
        attn_output = self.att(inputs, inputs, attention_mask=padding_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

The Writing Unit: The Transformer Decoder

What it does: Generates the Spanish translation one token at a time, conditioning on (a) the encoder's contextualized representation of the English sentence and (b) the Spanish tokens it has already produced.

Why it does it: Translation isn't a 1:1 word mapping. The decoder needs to pick its next word given both the full source meaning and its own running output (so it stays grammatically consistent and doesn't repeat itself). Stacking 6 decoder layers lets it refine that choice through multiple rounds of "look at the source" + "look at what I've written so far".

Each TransformerDecoder layer has three sub-layers (vs the encoder's two), each wrapped with residual + LayerNorm + Dropout:

  1. Masked (Causal) Multi-Head Self-Attention: like the encoder's self-attention, but the attention scores are masked so position t can only see positions 0…t. The causal mask is the whole reason we can train the decoder in parallel across an entire target sentence: every position learns "predict the next token" from a non-cheating context.
  2. Encoder-Decoder Cross-Attention: queries come from the decoder, keys/values come from the encoder's output. This is where source information actually flows into the translation: for each Spanish position being generated, the decoder asks "which English tokens matter most right now?"
  3. Position-wise Feed-Forward Network: same 512 → 1024 → 512 MLP as in the encoder.
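
To see why the causal mask makes parallel training possible, it helps to look at how a target sentence is typically split for teacher forcing. The sketch below assumes the decoder inputs and labels come from shifting the target by one token; the `[start]`/`[end]` markers and the helper name are illustrative assumptions, not taken from the original pipeline.

```python
def teacher_forcing_pair(target_tokens):
    """Hypothetical helper: decoder input is the target minus its last token,
    the label is the target minus its first token (shifted left by one)."""
    return target_tokens[:-1], target_tokens[1:]


target = ["[start]", "me", "voy", "a", "casa", "[end]"]
decoder_inputs, labels = teacher_forcing_pair(target)
for t, label in enumerate(labels):
    visible = decoder_inputs[: t + 1]   # what the causal mask lets position t see
    print(f"sees {visible} -> predicts {label!r}")
```

Every position gets its own "predict the next token" exercise in the same forward pass, and none of them can peek at the answer.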

Two masks are combined in the decoder's self-attention. The causal mask is a lower-triangular matrix of 1s, where 1 means "this position may be attended to":

1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

We build it with tf.linalg.band_part (a one-liner for lower-triangular matrices). The decoder padding mask is 1 for real target tokens, 0 for <pad>. We combine them with tf.minimum so a position is attendable only if it's both non-future and non-padding. The encoder padding mask is passed separately into cross-attention so the decoder doesn't waste attention on padded source positions either.
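
Here is a plain-Python version of the same mask arithmetic for a single 4-token target whose last position is padding. It mirrors what tf.linalg.band_part(ones, -1, 0) and tf.minimum compute in the decoder below, without the batch axis; the function names are hypothetical.

```python
def causal_mask(n):
    """Lower-triangular 1s: row t may attend to columns 0..t."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def combine(causal, key_padding):
    """Element-wise minimum: attendable only if non-future AND non-padding.
    key_padding is broadcast across rows, like the (batch, 1, seq) mask."""
    return [[min(c, p) for c, p in zip(row, key_padding)] for row in causal]


ids = [14, 203, 7, 0]                       # last position is <pad> (id 0)
key_padding = [0 if t == 0 else 1 for t in ids]
for row in combine(causal_mask(len(ids)), key_padding):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 0]
```

The padded query row (the last one) still attends to real tokens; its output is meaningless, but the masked loss ignores padded positions, so it never affects training.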

class TransformerDecoder(Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, drop_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att1 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)
        self.att2 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)
        self.ffn = Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.layernorm3 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(drop_rate)
        self.dropout2 = Dropout(drop_rate)
        self.dropout3 = Dropout(drop_rate)

    def get_causal_attention_mask(self, inputs):
        batch_size = tf.shape(inputs)[0]
        sequence_length = tf.shape(inputs)[1]

        mask = tf.linalg.band_part(
            tf.ones((sequence_length, sequence_length), dtype=tf.int32),
            -1,
            0
        )

        mask = mask[tf.newaxis, :, :]
        return tf.tile(mask, [batch_size, 1, 1])

    def call(self, inputs, encoder_outputs,
             encoder_padding_mask=None, decoder_padding_mask=None, training=False):
        causal_mask = self.get_causal_attention_mask(inputs)
        if decoder_padding_mask is not None:
            combined_mask = tf.minimum(causal_mask, decoder_padding_mask)
        else:
            combined_mask = causal_mask

        att1_output = self.att1(inputs, inputs, attention_mask=combined_mask)
        att1_output = self.dropout1(att1_output, training=training)
        out1 = self.layernorm1(inputs + att1_output)

        att2_output = self.att2(out1, encoder_outputs, attention_mask=encoder_padding_mask)
        att2_output = self.dropout2(att2_output, training=training)
        out2 = self.layernorm2(out1 + att2_output)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        return self.layernorm3(out2 + ffn_output)

A Smart Study Plan: Custom Learning Rate Scheduler

What it does: Computes a learning rate that grows linearly for the first warmup_steps, then decays as 1/sqrt(step), all scaled by 1/sqrt(d_model).

Why it does it: Transformer training is notoriously sensitive at the start: the weights are random, attention distributions are noise, and a "normal" learning rate can destabilize training badly enough that it never recovers. The warm-up gives Adam time to build reliable second-moment estimates, and the 1/sqrt(step) decay lets it eventually settle into a good minimum without overshooting.

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
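
To get a feel for the numbers, here is a plain-Python mirror of CustomSchedule (same formula, no TensorFlow; `transformer_lr` is a hypothetical name). With d_model = 512 and 4000 warm-up steps, the peak learning rate works out to 1/sqrt(512 × 4000) ≈ 7.0e-4, reached exactly at the end of warm-up.

```python
import math


def transformer_lr(step, d_model=512, warmup_steps=4000):
    """d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), as in CustomSchedule."""
    arg1 = step ** -0.5 if step > 0 else float("inf")  # rsqrt(0) is inf in TF too
    arg2 = step * warmup_steps ** -1.5
    return d_model ** -0.5 * min(arg1, arg2)


for step in [0, 1000, 4000, 16000, 100000]:
    print(f"step {step:>6}: lr = {transformer_lr(step):.2e}")
```

The rate is 0 at step 0, rises linearly until step 4000, and after that quadrupling the step count halves the rate.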

Assembling the Full Model

Now we wire everything together, following three additional details from the paper:

  1. Embedding scaling: token embeddings are multiplied by sqrt(d_model) before the positional encoding is added. Keras initializes Embedding weights with small values (uniform in [-0.05, 0.05] by default), while the sinusoidal positional vectors have entries in [-1, 1]. Without scaling, the positional signal would dominate the semantic signal at layer 0.
  2. Weight tying: instead of three independent matrices (encoder embedding, decoder embedding, output projection), we use a single shared embedding for all three. With a shared BPE vocabulary this is a natural fit, since the same token means the same thing on both the input and output side. This cuts roughly 16M parameters (vocab_size × d_model × 2) out of our model while typically improving generalization, because both the input lookup and the output softmax now train the same semantic space.
  3. Logits, not probabilities: the final layer outputs raw logits (no softmax). We apply log-softmax inside the loss function instead, which is cheaper and numerically more stable than applying softmax here and then taking the log.
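
A quick back-of-the-envelope check of the scaling argument, assuming Keras's default uniform initializer in [-0.05, 0.05] for the Embedding layer: its entries have a standard deviation of 0.05/√3 ≈ 0.029, which becomes about 0.65 after multiplying by √512 ≈ 22.6. That is on the same order as the positional encoding, whose entries have a root-mean-square of exactly √0.5 ≈ 0.71 because every (sin, cos) pair satisfies sin² + cos² = 1.

```python
import math

d_model = 512
init_std = 0.05 / math.sqrt(3)          # std of Uniform(-0.05, 0.05)
scaled_std = init_std * math.sqrt(d_model)


def pe_rms(pos, d_model):
    """Root-mean-square of the sinusoidal encoding at one position."""
    vals = []
    for j in range(d_model):
        angle = pos / (10000 ** ((2 * (j // 2)) / d_model))
        vals.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
    return math.sqrt(sum(v * v for v in vals) / d_model)


print(f"embedding std before scaling: {init_std:.3f}")              # 0.029
print(f"embedding std after scaling:  {scaled_std:.3f}")            # 0.653
print(f"positional encoding RMS:      {pe_rms(7, d_model):.3f}")    # 0.707
```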

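The TiedOutputProjection layer below is just a matmul against the (transposed) shared embedding matrix, so it adds no weights of its own. The get_config / from_config boilerplate lets Keras serialize the layer's configuration. Note that from_config deserializes a fresh Embedding object, so the tie is only guaranteed when you rebuild the model with build_transformer and load weights into it, which is what the training code below does (it saves weights only):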

class TiedOutputProjection(Layer):
    """Output projection that reuses the embedding matrix (weight tying)."""
    def __init__(self, embedding_layer, **kwargs):
        super().__init__(**kwargs)
        self.embedding_layer = embedding_layer

    def call(self, inputs):
        return tf.matmul(inputs, self.embedding_layer.embeddings, transpose_b=True)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_layer": tf.keras.saving.serialize_keras_object(self.embedding_layer)
        })
        return config

    @classmethod
    def from_config(cls, config):
        embedding_layer_config = config.pop("embedding_layer")
        embedding_layer = tf.keras.saving.deserialize_keras_object(embedding_layer_config)
        return cls(embedding_layer=embedding_layer, **config)


def build_transformer(config):
    encoder_inputs = Input(shape=(None,), dtype="int64", name="encoder_inputs")
    decoder_inputs = Input(shape=(None,), dtype="int64", name="decoder_inputs")

    enc_padding_mask = PaddingMask(name="enc_padding_mask")(encoder_inputs)
    dec_padding_mask = PaddingMask(name="dec_padding_mask")(decoder_inputs)

    shared_embedding = Embedding(
        config.MIX_VOCAB_SIZE,
        config.EMBED_DIM,
        name="shared_embedding",
    )
    scale = config.EMBED_DIM ** 0.5  # plain Python float

    # Encoder tower
    x = shared_embedding(encoder_inputs)
    x = x * scale
    x = PositionalEmbedding(config.MAX_LENGTH, config.EMBED_DIM)(x)
    x = Dropout(config.DROPOUT_RATE)(x)
    for _ in range(config.NUM_ENCODER_LAYERS):
        x = TransformerEncoder(
            config.EMBED_DIM, config.NUM_HEADS, config.FF_DIM, config.DROPOUT_RATE
        )(x, padding_mask=enc_padding_mask)
    encoder_outputs = x

    # Decoder tower
    y = shared_embedding(decoder_inputs)
    y = y * scale
    y = PositionalEmbedding(config.MAX_LENGTH, config.EMBED_DIM)(y)
    y = Dropout(config.DROPOUT_RATE)(y)
    for _ in range(config.NUM_DECODER_LAYERS):
        y = TransformerDecoder(
            config.EMBED_DIM, config.FF_DIM, config.NUM_HEADS, config.DROPOUT_RATE
        )(
            y, encoder_outputs,
            encoder_padding_mask=enc_padding_mask,
            decoder_padding_mask=dec_padding_mask,
        )

    decoder_outputs = TiedOutputProjection(
        shared_embedding, name="output_projection"
    )(y)
    return Model([encoder_inputs, decoder_inputs], decoder_outputs, name="transformer")

Masked Loss with Label Smoothing

Remember that our batches are padded with <pad> (id 0) so every sentence in a batch has the same length. If we naively compute loss over those padding positions, we'd be telling the model to "predict pad", thus wasting capacity and producing misleading metrics.

The fix is a mask: build a boolean mask that's True for real tokens and False for pads, then sum the loss only over real positions and divide by the count of real positions.

While we're touching the loss, the paper also recommends label smoothing with ε = 0.1. Instead of training the model to assign 100% probability to the correct token (and 0% to everything else), we redistribute a small ε slice of probability mass uniformly across the vocabulary. This makes the model less overconfident (better calibrated) and, although it hurts perplexity, usually improves BLEU.

Since the model outputs logits, we compute the cross-entropy manually with log_softmax, which is more numerically stable than turning the logits into softmax probabilities first and then taking their log.

LABEL_SMOOTHING = 0.1


def masked_loss(real, pred):
    # real: (batch, seq)   pred: (batch, seq, vocab) -- LOGITS
    real = tf.cast(real, tf.int32)

    vocab_size = tf.shape(pred)[-1]
    real_one_hot = tf.one_hot(real, depth=vocab_size, dtype=pred.dtype)

    smooth_targets = real_one_hot * (1.0 - LABEL_SMOOTHING) + \
                     LABEL_SMOOTHING / tf.cast(vocab_size, pred.dtype)

    log_probs = tf.nn.log_softmax(pred, axis=-1)
    loss_per_token = -tf.reduce_sum(smooth_targets * log_probs, axis=-1)

    mask = tf.cast(tf.not_equal(real, 0), loss_per_token.dtype)
    loss_per_token *= mask
    return tf.reduce_sum(loss_per_token) / tf.reduce_sum(mask)


def masked_accuracy(real, pred):
    pred_ids = tf.argmax(pred, axis=2)
    real = tf.cast(real, pred_ids.dtype)
    is_correct = tf.equal(real, pred_ids)

    mask = tf.math.logical_not(tf.math.equal(real, 0))
    is_correct = tf.math.logical_and(mask, is_correct)

    is_correct = tf.cast(is_correct, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(is_correct) / tf.reduce_sum(mask)
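
For intuition, here is a plain-Python mirror of masked_loss for a single sequence (same formula, lists instead of tensors; the toy 4-token vocabulary and the logits are made up). It shows two properties: with smoothing, the target for the correct token becomes 1 − ε + ε/V (0.9 + 0.1/16000 = 0.90000625 for our vocabulary), and changing the logits at a padded position has no effect on the loss.

```python
import math


def log_softmax(logits):
    top = max(logits)
    log_total = math.log(sum(math.exp(z - top) for z in logits)) + top
    return [z - log_total for z in logits]


def masked_smoothed_ce(labels, logits, epsilon=0.1, pad_id=0):
    """Average label-smoothed cross-entropy over non-pad positions."""
    vocab = len(logits[0])
    total, count = 0.0, 0
    for label, row in zip(labels, logits):
        if label == pad_id:
            continue  # same effect as multiplying by a 0 mask entry
        targets = [epsilon / vocab + (1 - epsilon if v == label else 0.0)
                   for v in range(vocab)]
        total += -sum(t * lp for t, lp in zip(targets, log_softmax(row)))
        count += 1
    return total / count


labels = [2, 3, 0]                       # last position is <pad>
logits = [[0.1, 0.2, 2.0, 0.3],
          [0.0, 0.5, 0.1, 1.5],
          [9.0, -3.0, 4.0, 0.0]]         # junk logits at the pad position
loss = masked_smoothed_ce(labels, logits)
logits[2] = [0.0, 0.0, 0.0, 0.0]         # change only the padded position
assert masked_smoothed_ce(labels, logits) == loss
print(round(loss, 4))
```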

Compiling the Model

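We build and compile the transformer, and print its summary(), inside strategy.scope() (strategy being the tf.distribute strategy created earlier in the series) so that the model's variables are mirrored across all available GPUs. The optimizer is Adam with the paper's settings (beta_1=0.9, beta_2=0.98, epsilon=1e-9), which tend to work better than the Keras defaults with a Transformer-style warm-up schedule. We pass config.WARMUP_STEPS to the schedule explicitly; otherwise the configured value would be silently ignored in favor of the default.

with strategy.scope():
    transformer = build_transformer(config)
    learning_rate = CustomSchedule(config.EMBED_DIM, config.WARMUP_STEPS)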
    optimizer = Adam(
        learning_rate,
        beta_1=0.9,
        beta_2=0.98,
        epsilon=1e-9
    )

    transformer.compile(
        optimizer=optimizer,
        loss=masked_loss,
        metrics=[masked_accuracy]
    )

    transformer.summary()

The BLEU Score

When we train a model to translate from English to Spanish, it might seem natural to check how "accurate" it is: how often its predictions exactly match the reference translations. But for translation, raw accuracy can be very misleading.

Here's why: in language, there's rarely just one correct way to say something. For example, the English sentence "I'm going home." could be translated as either:

  • "Voy a casa."
  • or "Me voy a casa."

Both are perfectly valid. But if our model produces the second one while the reference is the first, token-level accuracy would count it as wrong even though the meaning is identical.

That's why translation tasks use a more language-aware metric: the BLEU score (Bilingual Evaluation Understudy). Instead of requiring a position-by-position match, it counts how many of the model's n-grams (sequences of 1, 2, 3, or 4 consecutive tokens) also appear in the reference, clips those counts so a repeated word can't be rewarded more often than it occurs in the reference, and combines the four precisions with a geometric mean. It also applies a brevity penalty so the model can't cheat by producing very short outputs.

A BLEU score lies between 0 and 1 (often reported on a 0–100 scale). Higher is better; a score of 1 (or 100) would be a perfect match against the reference.
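
The core of BLEU fits in a few lines. The sketch below is a simplified, unsmoothed, single-reference, sentence-level version (real evaluations usually compute corpus-level BLEU with a tool such as sacreBLEU, which also standardizes tokenization), and it assumes lowercased, whitespace-split tokens. Applied to the example above, the valid paraphrase "me voy a casa ." still scores about 0.67 against the reference "voy a casa .", even though a position-by-position comparison would match none of its tokens.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(hyp, ref, max_n=4):
    """Unsmoothed BLEU: clipped n-gram precisions, geometric mean, brevity penalty."""
    precisions = []
    for n in range(1, max_n + 1):
        hyp_counts, ref_counts = ngrams(hyp, n), ngrams(ref, n)
        total = sum(hyp_counts.values())
        clipped = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        if total == 0 or clipped == 0:
            return 0.0            # no smoothing: any empty precision zeroes BLEU
        precisions.append(clipped / total)
    log_mean = sum(math.log(p) for p in precisions) / max_n
    # penalize hypotheses shorter than the reference
    bp = 1.0 if len(hyp) > len(ref) else math.exp(1 - len(ref) / len(hyp))
    return bp * math.exp(log_mean)


ref = "voy a casa .".split()
print(round(sentence_bleu("me voy a casa .".split(), ref), 3))   # 0.669
print(round(sentence_bleu(ref, "me voy a casa .".split()), 3))   # 0.779 (brevity penalty)
```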

The Training Process: Let the Learning Begin!

Training is where our model learns from the dataset. We'll use two Keras callbacks:

  • ModelCheckpoint saves the best weights based on validation loss, so we never lose our best model to later overfitting.
  • EarlyStopping halts training if validation loss fails to improve for 7 consecutive epochs, and restore_best_weights=True rolls the model back to the weights from its best epoch automatically. A larger patience makes it less likely that training stops during a temporary plateau in validation loss, at the cost of running more epochs.

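
The patience logic is simple enough to simulate. The plain-Python sketch below approximates EarlyStopping with min_delta=0 (it is not the Keras source): it returns the epoch where training would stop and the epoch whose weights restore_best_weights would bring back. The validation-loss values are invented to show a temporary plateau.

```python
def simulate_early_stopping(val_losses, patience):
    """Return (stop_epoch, best_epoch) for a min-mode monitor with min_delta=0."""
    best, best_epoch, wait = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


# validation loss plateaus for two epochs, then improves again
val_curve = [3.0, 2.5, 2.6, 2.55, 2.4]
print(simulate_early_stopping(val_curve, patience=2))  # (3, 1): stops on the plateau
print(simulate_early_stopping(val_curve, patience=3))  # (4, 4): survives it
```
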
model_name = "transformer_model_binary.weights.h5"

if config.IS_TRAINING:
    checkpoints = ModelCheckpoint(
        model_name,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=True,
    )

    early_stop = EarlyStopping(
        patience=7,
        monitor="val_loss",
        mode="min",
        restore_best_weights=True,
    )

    history = transformer.fit(
        train_ds,
        epochs=config.EPOCHS,
        validation_data=valid_ds,
        callbacks=[checkpoints, early_stop],
    )

else:
    try:
        transformer.load_weights(model_name)
    except (FileNotFoundError, tf.errors.NotFoundError):
        print(f"Error: weights file {model_name} not found. Set config.IS_TRAINING = True first.")

Grading Our Model: Plotting the Results

After training, plot the loss and accuracy curves:

  • Loss Plot: we want both training and validation loss to decrease and stabilize. If the validation loss starts climbing while training loss keeps dropping, that's overfitting.
  • Accuracy Plot: we want both curves to rise and converge. A widening gap is another overfitting signal.

import matplotlib.pyplot as plt

# `history` only exists if we trained in this session (config.IS_TRAINING)
if config.IS_TRAINING:
    plt.figure(figsize=(10, 6))
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss Over Epochs')
    plt.xlabel('Epoch'); plt.ylabel('Loss')
    plt.legend(); plt.grid(True); plt.show()

    plt.figure(figsize=(10, 6))
    plt.plot(history.history['masked_accuracy'], label='Training Accuracy')
    plt.plot(history.history['val_masked_accuracy'], label='Validation Accuracy')
    plt.title('Model Accuracy Over Epochs')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.legend(); plt.grid(True); plt.show()

And there you have it! You've successfully built and trained a complete Transformer model from scratch. In the next part, we'll put our model to the test and see how well it can translate new sentences. Stay tuned!