nic hoffs

Teaching a GPT to Do Modular Addition

from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.nn.optim import AdamW
from tinygrad.nn.state import get_parameters
from tqdm import trange
from math import prod
import matplotlib.pyplot as plt
import numpy as np
mod = 113
train_test_ratio = .3
ds_len = mod * mod
# [[0,1,2,..,mod,0,1,2,...mod] mod times]
a = (
    Tensor.arange(mod, dtype=dtypes.int)
    .repeat((mod, 1))
    .flatten(0, -1)
    .unsqueeze(0)
)

# [[0,0,0,...,1,1,1,...,112,112,112] mod times]
b = (
    Tensor.arange(mod, dtype=dtypes.int)
    .unsqueeze(-1)
    .repeat((1, mod))
    .flatten(0, -1)
    .unsqueeze(0)
)

# [[113, 113, 113,...,113, 113] mod times]
equals = Tensor.full((ds_len), mod).unsqueeze(0)

# [[0+0, 1+0, 2+0, ..., 112+0], [0+1, 1+1, 2+1, ..., 112+112]]
sum = a + b
products = sum.div(mod).floor() * mod

# [[0, 1, 2, ..., 112], [1, 2, 3, ..., 113], ...]
targets = sum - products

ds = a.cat(b, equals, dim=0).T

indices = Tensor.randint(
    ds_len,
    low=0,
    high=ds_len,
)

ds_shuffled = ds[indices].cast(dtypes.float)
targets_shuffled = (
    targets[:, indices].cast(dtypes.float).reshape(prod(targets.shape), 1)
)

train_cutoff = int(train_test_ratio * ds_len)


x_train = ds_shuffled[:train_cutoff]
y_train = targets_shuffled[:train_cutoff]
x_test = ds_shuffled[train_cutoff:]
y_test = targets_shuffled[train_cutoff:]
class TransformerBlock:
    def __init__(self, embed_dim, head_dim, num_heads):
        self.embed_dim = embed_dim
        self.head_dim = head_dim
        self.num_heads = num_heads

        self.q = Tensor.normal(embed_dim, embed_dim)
        self.k = Tensor.normal(embed_dim, embed_dim)
        self.v = Tensor.normal(embed_dim, embed_dim)

        self.head_out = Tensor.normal(num_heads * head_dim, embed_dim)

        self.ff1 = Tensor.normal(embed_dim, 4 * embed_dim)
        self.ff2 = Tensor.normal(4 * embed_dim, embed_dim)

    def attn(self, x):
        bsz = x.shape[0]
        q, k, v = [
            x.linear(proj)
            .reshape(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
            for proj in (self.q, self.k, self.v)
        ]
        return (
            q.scaled_dot_product_attention(k, v)
            .transpose(1, 2)
            .reshape(bsz, -1, self.num_heads * self.head_dim)
            .linear(self.head_out)
        )

    def mlp(self, x):
        return x.linear(self.ff1).relu().linear(self.ff2)

    def __call__(self, x):
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x


class GPT:
    def __init__(self, num_layers=1, embed_dim=128, vocab_size=113, context_length=3, num_heads=4):
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.num_heads = num_heads

        self.tok_embed = Tensor.normal(vocab_size, embed_dim)
        self.pos_embed = Tensor.normal(context_length, embed_dim)

        self.blocks = [
            TransformerBlock(embed_dim, embed_dim // num_heads, num_heads)
            for _ in range(num_layers)
        ]

        self.out = Tensor.normal(embed_dim, vocab_size - 1)

    def __call__(self, x):
        # input shape (B,T,C)
        bsz = x.shape[0]
        pos = (
            Tensor.arange(self.context_length)
            .one_hot(self.context_length)
            .cast(dtypes.float)[: x.shape[1]]
            .expand((bsz, None, None))
        )
        x = x.one_hot(self.vocab_size).linear(self.tok_embed) + pos.linear(
            self.pos_embed
        )
        x = x.sequential(self.blocks)
        x = x.reshape(-1, x.shape[-1]).linear(self.out)
        return x.reshape((bsz, -1, x.shape[-1]))
def loss_fn(logits: Tensor, labels):
    log_probs = logits.log_softmax(axis=-1).cast(dtypes.float64)
    correct = log_probs.gather(dim=-1, index=labels,)[:, 0]
    return -correct.mean()

def train(
    model,
    X_train,
    Y_train,
    X_test,
    Y_test,
    optim,
    steps=10000,  # Adjust this as per the actual training epochs needed
    lossfn=lambda out, y: out.sparse_categorical_crossentropy(y),
    allow_jit=True,
):
    def train_step(x, y):
        out = model(x)[:, -1]
        loss = lossfn(out, y)
        loss.backward()
        optim.step()
        optim.zero_grad()
        return loss.realize()

    def test_step(x, y):
        out = model(x)[:, -1]
        optim.zero_grad()
        loss = lossfn(out, y)
        return loss.realize()

    if allow_jit:
        train_step = TinyJit(train_step)

    train_losses = []
    test_losses = []
    with Tensor.train():
        for i in (t := trange(steps)):
            train_loss = train_step(X_train, Y_train)
            test_loss = test_step(X_test, Y_test)
            if test_loss.numpy() < 0.005:
              break
            train_losses.append(train_loss.numpy())
            test_losses.append(test_loss.numpy())

            t.set_description(
                f"train loss: {train_loss.numpy():.2f}, test loss: {test_loss.numpy():.2f}"
            )
    return train_losses, test_losses
model = GPT()

optimizer = AdamW(get_parameters(model), lr=1e-3, b1=0.9, b2=0.98, weight_decay=1.0)

train_losses, test_losses = train(
        model,
        x_train,
        y_train,
        x_test,
        y_test,
        optimizer,
        steps=50000,
        lossfn=loss_fn,
)
train loss: 0.00, test loss: 0.01:  40%|███▉      | 19885/50000 [15:16<23:08, 21.69it/s]
plt.plot(np.log(train_losses), label="train")
plt.plot(np.log(test_losses), label="test")
plt.legend()
plt.show()

This is the famous loss curve (forgive me for the scale of y-axis). I believe the strange periodic behavior is due to numerical instability, but the idea is there -- we see that the train loss plummets quickly, while the test loss remains constant and only plummets after tens of thousands of epochs of training. markov_chain

import pickle

with open('train.pkl', 'wb') as file:
    pickle.dump(train_losses, file)

with open('test.pkl', 'wb') as file:
    pickle.dump(test_losses, file)