A walk-through of AxiomMath/axplorer — an open-source PatternBoost implementation that uses a small transformer plus hand-written local search to find extremal combinatorial structures.
Keep a population of the best-known solutions to a discrete extremal problem. Train a small decoder-only transformer to imitate them via next-token prediction. Sample new candidates from the model, repair them with a hand-written local search, score, dedupe, keep the top — repeat. The transformer learns what good solutions look like; local search keeps things valid; the rolling top-k is the curriculum. No RL, no policy gradient — just supervised learning on a moving target.
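In sketch form the outer loop is tiny. Below, a toy stand-in (not the repo's API): the transformer is replaced by a random sampler and the problem by "binary strings with no two adjacent ones", but the shape of the loop — sample, repair, merge, dedupe, keep top-k — is the same:

```python
import random

def patternboost_loop(sample_from_model, local_search, score,
                      population, pop_size, epochs, num_samples):
    for _ in range(epochs):
        # (train the model on `population` here -- omitted in this sketch)
        candidates = [local_search(sample_from_model()) for _ in range(num_samples)]
        merged = {tuple(c): c for c in population + candidates}  # dedupe on features
        population = sorted(merged.values(), key=score, reverse=True)[:pop_size]
    return population

random.seed(0)

def sample():                 # stands in for transformer sampling
    return [random.randint(0, 1) for _ in range(8)]

def repair(bits):             # greedy repair: drop a 1 whenever two are adjacent
    out = []
    for b in bits:
        out.append(0 if (b and out and out[-1]) else b)
    return out

best = patternboost_loop(sample, repair, score=sum,
                         population=[[0] * 8], pop_size=10,
                         epochs=3, num_samples=50)
print(best[0])  # best known solution after 3 epochs
```

Note that the repair step guarantees every member of the population is valid, so the score never has to be defended against garbage — exactly the division of labor described above.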
The technique is from the 2024 PatternBoost paper by Charton, Ellenberg, Wagner & Williamson; axplorer is a clean ~2,400-line re-implementation suitable for adapting to new problems.
Per epoch, in order:
1. Run `max_steps` gradient updates on the current population (next-token cross-entropy).
2. Sample `num_samples_from_model` new sequences with temperature sampling. Temperature auto-bumps by `inc_temp` if the model starts producing duplicates.
3. Repair each sample with local search; with `--always_search`, valid ones get a greedy improvement pass too.
4. Keep the top `pop_size` by score (combined with the previous population, optionally deduped via canonical features).

The training loop itself is the boring part — it really is just a textbook decoder-only training step:
src/trainer.py · L31–L66

```python
def train(model, args, loader, optim, test_dataset, current_best_loss=None):
    best_loss = current_best_loss or float("inf")
    curr_loss = 0
    for step in range(args.max_steps):
        if step % 100 == 0:
            t0 = time.time()
        batch = loader.next()
        batch = [t.to(args.device) for t in batch]
        X, Y = batch[0], batch[1]
        _, loss, _ = model(X, Y)
        model.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()
        curr_loss += loss.item()
```
The interesting work is everywhere else: how an object is encoded, scored, repaired.
Adding a problem means writing src/envs/<problem>.py with two classes. The cleanest example to read is cycle.py — Turán-style maximum-edges-in-a-C₄-free-graph, 182 lines.
**`DataPoint` subclass — one candidate solution**

src/envs/environment.py · L11–L26 (base class)

```python
class DataPoint(ABC):
    def __init__(self):
        super().__init__()
        self.score = -1
        self.features = ""

    @abstractmethod
    def calc_score(self):
        pass

    @abstractmethod
    def calc_features(self):
        pass

    def local_search(self, improve_with_local_search):
        return
```
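To see the contract in action, here is a toy subclass (a hypothetical problem, not one of the repo's environments), with the base class inlined so the snippet runs on its own:

```python
from abc import ABC, abstractmethod

class DataPoint(ABC):              # base class, inlined for self-containment
    def __init__(self):
        super().__init__()
        self.score = -1
        self.features = ""

    @abstractmethod
    def calc_score(self):
        pass

    @abstractmethod
    def calc_features(self):
        pass

    def local_search(self, improve_with_local_search):
        return

class NoConsecutiveDataPoint(DataPoint):
    """Hypothetical toy: subsets of {0..N-1} with no two consecutive integers,
    maximizing subset size."""
    def __init__(self, elements):
        super().__init__()
        self.data = sorted(set(elements))

    def calc_score(self):
        valid = all(b - a > 1 for a, b in zip(self.data, self.data[1:]))
        self.score = len(self.data) if valid else -1   # invalid objects score -1

    def calc_features(self):
        # canonical string, used to dedupe the population
        self.features = ",".join(map(str, self.data))

p = NoConsecutiveDataPoint([0, 2, 4, 7])
p.calc_score(); p.calc_features()
print(p.score, p.features)  # → 4 0,2,4,7
```

The same two-method pattern — a scorer that returns -1 on invalid objects and a canonical feature string for dedup — is exactly what `cycle.py` implements for graphs.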
For the C₄-free problem, the score is just edge count if the graph is valid, else -1:
src/envs/cycle.py · L26–L37

```python
def calc_score(self):
    if len(self.cycles) > 0:
        self.score = -1
    else:
        self.score = self.data.sum().item() // 2

def calc_features(self):
    w = []
    for i in range(self.N):
        for j in range(i + 1, self.N):
            w.append(self.data[i, j])
    self.features = ",".join(map(str, w))
```
Why this matters: calc_score is the only part the user must get mathematically correct. Everything else is forgiving. program.md explicitly halts the new-environment workflow at a “stop here, have the user audit calc_score” gate — a wrong scorer silently corrupts every downstream epoch.
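One cheap way to run that audit is differential testing: check the scorer's validity test against an independent brute-force implementation on small random instances. A sketch for the C₄ case (both helpers are hypothetical, not the repo's code — the fast check uses the fact that a 4-cycle exists iff two vertices share at least two common neighbours):

```python
import itertools, random

def has_c4_matrix(adj):
    """Fast check: a C4 exists iff two vertices share >= 2 common neighbours."""
    n = len(adj)
    return any(
        sum(adj[i][k] and adj[j][k] for k in range(n)) >= 2
        for i in range(n) for j in range(i + 1, n)
    )

def has_c4_bruteforce(adj):
    """Independent check: enumerate labelled 4-cycles directly."""
    n = len(adj)
    return any(
        adj[a][b] and adj[b][c] and adj[c][d] and adj[d][a]
        for a, b, c, d in itertools.permutations(range(n), 4)
    )

random.seed(0)
for _ in range(50):                      # random graphs on 7 vertices
    n = 7
    adj = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            adj[i][j] = adj[j][i] = int(random.random() < 0.3)
    assert has_c4_matrix(adj) == has_c4_bruteforce(adj)
print("the two C4 checks agree on 50 random graphs")
```

Fifty random graphs won't prove the scorer correct, but a disagreement on any of them catches the silent-corruption failure mode before it poisons an epoch.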
The hand-written escape hatch — local_search — is what turns the model's noisy samples back into valid objects:
src/envs/cycle.py · L115–L128

```python
def local_search(self, improve_with_local_search):
    # here I start from a dirty graph, so we need to compute 4-cycles first
    self._cycles_computation()
    # first step of local search: remove edges greedily until there is no 4-cycle
    self._remove_edges_greedily()
    # second step: add edges greedily while avoiding 4-cycles
    if improve_with_local_search:
        self._add_edges_greedily()
    self.cycles = []
    if self.MAKE_OBJECT_CANONICAL:
        self.data = sort_graph_based_on_degree(self.data)
    self.calc_features()
    self.calc_score()
```
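The greedy helpers themselves aren't shown above. A self-contained sketch of the same remove-then-add idea (hypothetical code, using a common-neighbour test for 4-cycles rather than the repo's `_cycles_computation`):

```python
def find_c4(adj):
    """Return four vertices i-x-j-y forming a 4-cycle, or None if C4-free."""
    n = len(adj)
    for i in range(n):
        for j in range(i + 1, n):
            common = [k for k in range(n) if adj[i][k] and adj[j][k]]
            if len(common) >= 2:
                return i, common[0], j, common[1]
    return None

def local_search_c4(adj):
    n = len(adj)
    # phase 1: while a 4-cycle exists, delete one of its edges
    while (cyc := find_c4(adj)) is not None:
        i, x, _, _ = cyc
        adj[i][x] = adj[x][i] = 0
    # phase 2: greedily add back any edge that keeps the graph C4-free
    for i in range(n):
        for j in range(i + 1, n):
            if not adj[i][j]:
                adj[i][j] = adj[j][i] = 1
                if find_c4(adj) is not None:
                    adj[i][j] = adj[j][i] = 0   # revert: edge created a 4-cycle
    return adj

n = 5
adj = local_search_c4([[int(i != j) for j in range(n)] for i in range(n)])  # start from K5
edges = sum(map(sum, adj)) // 2
print(edges)  # → 6, the known optimum ex(5, C4): two triangles sharing a vertex
```

Phase 1 always terminates (each iteration deletes an edge), and phase 2 leaves the graph maximal C4-free — any sample from the model, however noisy, comes out valid.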
**`Environment` subclass — wires the `DataPoint` to a tokenizer**

src/envs/cycle.py · L139–L169

```python
class SquareEnvironment(BaseEnvironment):
    # this problem lives in N^2, therefore k=2
    # (i, j) or (j, i) represents the same edge, therefore are_coordinates_symmetric=True
    k = 2
    are_coordinates_symmetric = True
    data_class = SquareDataPoint

    def __init__(self, params):
        super().__init__(params)
        self.data_class.MAKE_OBJECT_CANONICAL = params.make_object_canonical
        encoding_augmentation = (
            random_symmetry_adj_matrix if params.augment_data_representation else None
        )
        if params.encoding_tokens == "single_integer":
            self.tokenizer = SparseTokenizerSingleInteger(
                self.data_class, params.N, self.k,
                self.are_coordinates_symmetric, self.SPECIAL_SYMBOLS,
                encoding_augmentation=encoding_augmentation,
            )
        elif params.encoding_tokens == "sequence_k_tokens":
            self.tokenizer = SparseTokenizerSequenceKTokens(...)
        elif params.encoding_tokens == "adjacency":
            self.tokenizer = DenseTokenizer(...)
        else:
            raise ValueError(f"Invalid encoding: {params.encoding_tokens}")
```
And one line in the registry:
src/envs/__init__.py

```python
from src.envs.cycle import SquareEnvironment
from src.envs.isosceles import IsoscelesEnvironment
from src.envs.sphere import SphereEnvironment

ENVS = {"square": SquareEnvironment, "isosceles": IsoscelesEnvironment, "sphere": SphereEnvironment}
```
The encoding choice is set per-experiment via --encoding_tokens; all three live in tokenizers.py (260 lines). The pivot point is how k-tuples (edges, points, triples…) get turned into transformer tokens:
| encoding | tokens per tuple | vocab | seq length | good for |
|---|---|---|---|---|
| `single_integer` | 1 | ~N^k | short | sparse objects |
| `sequence_k_tokens` | k | ~N | longer | large N, sparse |
| `adjacency` | — | 2pow2base | O(N²) | dense graphs |
The single-integer encoding is the most natural one to read — it just enumerates index tuples and assigns each a unique token id:
src/envs/tokenizers.py · L8–L22

```python
def generate_index_tuples(N, k, are_coordinates_symmetric):
    if k == 1:
        yield from range(N)
    else:
        if are_coordinates_symmetric:
            yield from combinations(range(N), k)
        else:
            yield from product(range(N), repeat=k)


def count_index_tuples(N, k, are_coordinates_symmetric):
    if are_coordinates_symmetric:
        return math.comb(N, k)
    else:
        return N**k
```
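Given that enumeration, a single-integer token id is just a tuple's position in it. A hypothetical round-trip for N=5 edges (the real `SparseTokenizerSingleInteger` also handles special symbols and augmentation):

```python
import math
from itertools import combinations, product

def generate_index_tuples(N, k, are_coordinates_symmetric):
    if k == 1:
        yield from range(N)
    elif are_coordinates_symmetric:
        yield from combinations(range(N), k)
    else:
        yield from product(range(N), repeat=k)

# Edges of a graph on N=5 vertices: each unordered pair gets one token id.
N, k = 5, 2
tuple_to_token = {t: i for i, t in enumerate(generate_index_tuples(N, k, True))}
token_to_tuple = {i: t for t, i in tuple_to_token.items()}

print(tuple_to_token[(0, 1)], tuple_to_token[(3, 4)], len(tuple_to_token))  # → 0 9 10
```

With symmetric coordinates the vocabulary has `math.comb(N, k)` tuple tokens (10 here) instead of N² = 25 — the source of the "~N^k vocab, short sequences" trade-off in the table above.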
The transformer in model.py (194 lines) is a small GPT-style decoder. Causal self-attention, GELU MLP, learned embeddings, optionally no positional embeddings (set --no_positional for permutation-invariant problems — graphs, point sets where the order of edges/points is meaningless):
src/models/model.py · L48–L72

```python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        return self.c_proj(self.act(self.c_fc(x)))


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x, past_kv=None):
        attn_out, present_kv = self.attn(self.ln_1(x), past_kv=past_kv)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return x, present_kv
```
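`CausalSelfAttention` itself isn't shown above. Stripped of batching, multiple heads, and the KV cache, the computation it performs reduces to the following (a plain-numpy sketch of standard causal attention, not the repo's code):

```python
import numpy as np

def causal_self_attention(x, Wq, Wk, Wv, Wo):
    T, C = x.shape
    q, k, v = x @ Wq, x @ Wk, x @ Wv           # project to queries/keys/values
    scores = q @ k.T / np.sqrt(C)
    mask = np.triu(np.ones((T, T), dtype=bool), 1)
    scores[mask] = -np.inf                     # position t sees only positions <= t
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)         # row-wise softmax over the visible past
    return w @ v @ Wo

rng = np.random.default_rng(0)
T, C = 4, 8
x = rng.standard_normal((T, C))
W = [rng.standard_normal((C, C)) * 0.1 for _ in range(4)]
out = causal_self_attention(x, *W)
print(out.shape)  # → (4, 8)
```

The causal mask is what makes next-token training and autoregressive sampling consistent: perturbing the last input row cannot change the outputs at earlier positions.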
Default knobs are tiny (n_layer=4, n_embd=256-ish). The whole point is that the heavy lifting is done by the population + local search, not by parameter count.
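Back-of-envelope arithmetic makes "tiny" concrete (vocab and context sizes below are illustrative guesses, not the repo's values; layout assumes the bias-free attention + 4x MLP blocks shown above):

```python
n_layer, n_embd = 4, 256
vocab_size, block_size = 1024, 512            # illustrative guesses

per_block = (4 * n_embd * n_embd              # attention: Wq, Wk, Wv, Wo
             + 8 * n_embd * n_embd)           # MLP: n_embd -> 4*n_embd -> n_embd
embeddings = vocab_size * n_embd + block_size * n_embd
total = n_layer * per_block + embeddings
print(f"{total / 1e6:.1f}M parameters")       # → 3.5M parameters
```

A few million parameters — three orders of magnitude below even small LLMs, and trainable per-epoch on a single GPU.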
| env_name | object | constraint | maximize |
|---|---|---|---|
| `square` | graph on N vertices | no 4-cycle (C₄) | edges |
| `isosceles` | points in [N]² | no isosceles triangle | points |
| `sphere` | points in [N]³ | no 5 cospherical points | points |
All three are extremal-combinatorics problems known to be hard for combinatorial search alone — the kind of problem where PatternBoost beat prior bounds in the original paper.
| file | lines | role |
|---|---|---|
| `train.py` | 208 | CLI entry, arg parsing, per-epoch orchestration |
| `src/trainer.py` | 74 | Inner SGD loop + checkpointing |
| `src/datasets.py` | 213 | Tokenize population → batches |
| `src/evaluator.py` | 123 | Sample from model, decode to objects |
| `src/envs/environment.py` | 148 | Base classes + parallel scoring |
| `src/envs/tokenizers.py` | 260 | Three encoding strategies |
| `cycle.py` / `isosceles.py` / `sphere.py` | 182 / 319 / 316 | Concrete problems |
| `src/models/model.py` | 194 | Decoder-only transformer |
| `new_envs.ipynb` | — | Tutorial for adding your own problem |
~2,400 lines of Python, readable in an afternoon.
If you've seen AlphaTensor, FunSearch, or any evolutionary search with a learned proposal distribution, axplorer is the same template at small scale: a learned proposal distribution (the transformer), a hand-written evaluator-and-repair step (local search), and selection pressure (the rolling top-k).
Everything else — the model, the tokenizer, the data loader — is interchangeable plumbing. That's why the repo is small, and why "implement your problem" is a one-file change.