Go Back

Quickbooks: Teach a GPT to do modular addition... watch what happens

October 31, 2024 by Nicholas Hoffs

```bash !python3 -m pip install git+https://github.com/tinygrad/tinygrad.git ``` ```text Collecting git+https://github.com/tinygrad/tinygrad.git Cloning https://github.com/tinygrad/tinygrad.git to /tmp/pip-req-build-ug2m_50j Running command git clone --filter=blob:none --quiet https://github.com/tinygrad/tinygrad.git /tmp/pip-req-build-ug2m_50j Resolved https://github.com/tinygrad/tinygrad.git to commit f8a623b3863d9212b2ad76b79e3689c0182ab70b Preparing metadata (setup.py) ... [?25l[?25hdone Building wheels for collected packages: tinygrad Building wheel for tinygrad (setup.py) ... [?25l[?25hdone Created wheel for tinygrad: filename=tinygrad-0.9.2-py3-none-any.whl size=1042905 sha256=98446f548ad53faf744fa6f09c677b2893ce73f81928ba8fe4ed29e203b2e690 Stored in directory: /tmp/pip-ephem-wheel-cache-aiwqkikf/wheels/86/f2/16/d5a5b26c57c97399f2a5776383dc8d69a9340af61421b55699 Successfully built tinygrad Installing collected packages: tinygrad Successfully installed tinygrad-0.9.2 ``` ```python from tinygrad import Tensor, dtypes, TinyJit from tinygrad.nn.optim import AdamW from tinygrad.nn.state import get_parameters from tqdm import trange from math import prod import matplotlib.pyplot as plt import numpy as np ``` ```python mod = 113 train_test_ratio = .3 ds_len = mod * mod ``` ```python # [[0,1,2,..,mod,0,1,2,...mod] mod times] a = ( Tensor.arange(mod, dtype=dtypes.int) .repeat((mod, 1)) .flatten(0, -1) .unsqueeze(0) ) # [[0,0,0,...,1,1,1,...,112,112,112] mod times] b = ( Tensor.arange(mod, dtype=dtypes.int) .unsqueeze(-1) .repeat((1, mod)) .flatten(0, -1) .unsqueeze(0) ) # [[113, 113, 113,...,113, 113] mod times] equals = Tensor.full((ds_len), mod).unsqueeze(0) # [[0+0, 1+0, 2+0, ..., 112+0], [0+1, 1+1, 2+1, ..., 112+112]] sum = a + b products = sum.div(mod).floor() * mod # [[0, 1, 2, ..., 112], [1, 2, 3, ..., 113], ...] targets = sum - products ds = a.cat(b, equals, dim=0).T indices = Tensor.randint( ds_len, low=0, high=ds_len, ) ds_shuffled = ds[indices].cast(dtypes.float) targets_shuffled = ( targets[:, indices].cast(dtypes.float).reshape(prod(targets.shape), 1) ) train_cutoff = int(train_test_ratio * ds_len) x_train = ds_shuffled[:train_cutoff] y_train = targets_shuffled[:train_cutoff] x_test = ds_shuffled[train_cutoff:] y_test = targets_shuffled[train_cutoff:] ``` ```python class TransformerBlock: def __init__(self, embed_dim, head_dim, num_heads): self.embed_dim = embed_dim self.head_dim = head_dim self.num_heads = num_heads self.q = Tensor.normal(embed_dim, embed_dim) self.k = Tensor.normal(embed_dim, embed_dim) self.v = Tensor.normal(embed_dim, embed_dim) self.head_out = Tensor.normal(num_heads * head_dim, embed_dim) self.ff1 = Tensor.normal(embed_dim, 4 * embed_dim) self.ff2 = Tensor.normal(4 * embed_dim, embed_dim) def attn(self, x): bsz = x.shape[0] q, k, v = [ x.linear(proj) .reshape(bsz, -1, self.num_heads, self.head_dim) .transpose(1, 2) for proj in (self.q, self.k, self.v) ] return ( q.scaled_dot_product_attention(k, v) .transpose(1, 2) .reshape(bsz, -1, self.num_heads * self.head_dim) .linear(self.head_out) ) def mlp(self, x): return x.linear(self.ff1).relu().linear(self.ff2) def __call__(self, x): x = x + self.attn(x) x = x + self.mlp(x) return x class GPT: def __init__(self, num_layers=1, embed_dim=128, vocab_size=113, context_length=3, num_heads=4): self.num_layers = num_layers self.embed_dim = embed_dim self.vocab_size = vocab_size self.context_length = context_length self.num_heads = num_heads self.tok_embed = Tensor.normal(vocab_size, embed_dim) self.pos_embed = Tensor.normal(context_length, embed_dim) self.blocks = [ TransformerBlock(embed_dim, embed_dim // num_heads, num_heads) for _ in range(num_layers) ] self.out = Tensor.normal(embed_dim, vocab_size - 1) def __call__(self, x): # input shape (B,T,C) bsz = x.shape[0] pos = ( Tensor.arange(self.context_length) .one_hot(self.context_length) .cast(dtypes.float)[: x.shape[1]] .expand((bsz, None, None)) ) x = x.one_hot(self.vocab_size).linear(self.tok_embed) + pos.linear( self.pos_embed ) x = x.sequential(self.blocks) x = x.reshape(-1, x.shape[-1]).linear(self.out) return x.reshape((bsz, -1, x.shape[-1])) ``` ```python def loss_fn(logits: Tensor, labels): log_probs = logits.log_softmax(axis=-1).cast(dtypes.float64) correct = log_probs.gather(dim=-1, index=labels,)[:, 0] return -correct.mean() def train( model, X_train, Y_train, X_test, Y_test, optim, steps=10000, # Adjust this as per the actual training epochs needed lossfn=lambda out, y: out.sparse_categorical_crossentropy(y), allow_jit=True, ): def train_step(x, y): out = model(x)[:, -1] loss = lossfn(out, y) loss.backward() optim.step() optim.zero_grad() return loss.realize() def test_step(x, y): out = model(x)[:, -1] optim.zero_grad() loss = lossfn(out, y) return loss.realize() if allow_jit: train_step = TinyJit(train_step) train_losses = [] test_losses = [] with Tensor.train(): for i in (t := trange(steps)): train_loss = train_step(X_train, Y_train) test_loss = test_step(X_test, Y_test) if test_loss.numpy() < 0.005: break train_losses.append(train_loss.numpy()) test_losses.append(test_loss.numpy()) t.set_description( f"train loss: {train_loss.numpy():.2f}, test loss: {test_loss.numpy():.2f}" ) return train_losses, test_losses ``` ```python model = GPT() optimizer = AdamW(get_parameters(model), lr=1e-3, b1=0.9, b2=0.98, weight_decay=1.0) train_losses, test_losses = train( model, x_train, y_train, x_test, y_test, optimizer, steps=50000, lossfn=loss_fn, ) ``` train loss: 0.00, test loss: 0.01: 40%|███▉ | 19885/50000 [15:16<23:08, 21.69it/s] ```python plt.plot(np.log(train_losses), label="train") plt.plot(np.log(test_losses), label="test") plt.legend() plt.show() ``` This is the famous loss curve (forgive me for the scale of y-axis). I believe the strange periodic behavior is due to numerical instability, but the idea is there -- we see that the train loss plummets quickly, while the test loss remains constant and only plummets after tens of thousands of epochs of training. ![png](grokking/grokking_loss.png) ```python import pickle with open('train.pkl', 'wb') as file: pickle.dump(train_losses, file) with open('test.pkl', 'wb') as file: pickle.dump(test_losses, file) ```