Created
November 7, 2025 17:50
-
-
Save Epivalent/bd3e8d05be39267d8e71cf8f4a88f614 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # THRML unified sampler (clean rewrite) | |
| # - Perfect matching on planar cubic bridgeless graphs via 2-face-coloring + NAE vertex constraints. | |
| # - Face K-coloring (default K=4) on the dual graph via Potts (penalize equal colors on adjacent faces). | |
| # Supports two input modes: | |
| # (1) Lambda term -> primal rotation system -> dual faces, Tutte embedding for plotting | |
| # (2) Adjacency matrix (0/1 square text file) of the dual graph for direct face K-coloring | |
| # Includes: | |
| # * Block Gibbs with optional block partition via greedy vertex-coloring (--block-coloring) | |
| # * Optional anneal ladder (--anneal, --beta-start/end, --phases, --sweeps-per-phase) | |
| # * Adaptive enforcement: --enforce-proper (coloring) / --enforce-perfect (matching) | |
| # * Plotting with adaptive radius and optional edge downsampling for big graphs | |
| import math | |
| import argparse | |
| import time | |
| from datetime import datetime | |
| from collections import defaultdict, deque | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| # ---- lambda-term utilities (provided by user's lambda_map7.py) ---- | |
| from lambda_map7 import ( | |
| tokenize, Parser, | |
| build_rotation_with_labels, | |
| enumerate_faces, face_containing_edge, | |
| tutte_embedding_with_largest_face, tutte_embedding_with_face, tutte_embedding_with_face_aligned, | |
| ) | |
| # ---- THRML imports ---- | |
| from thrml.block_management import Block | |
| from thrml.block_sampling import BlockGibbsSpec, sample_states, SamplingSchedule | |
| from thrml.models.discrete_ebm import CategoricalEBMFactor, CategoricalGibbsConditional | |
| from thrml.factor import FactorSamplingProgram | |
| from thrml.pgm import CategoricalNode | |
| # -------------------- Logging helpers -------------------- | |
| _t0 = time.time() | |
| def _now(): return datetime.now().strftime("%H:%M:%S") | |
| def _elapsed(): return f"{(time.time()-_t0):7.2f}s" | |
| def log(msg, enable=True): | |
| if enable: | |
| print(f"[{_now()} | {_elapsed()}] {msg}", flush=True) | |
| # -------------------- Geometry / constraints (lambda-term path) -------------------- | |
def he2f_map(rotation):
    """Assign every directed half-edge of a rotation system to a face.

    The successor of half-edge (u, v) is (v, w), where w is the neighbour that
    precedes u in v's cyclic order ``rotation[v]``. Starting from each
    half-edge not yet traced (in ``rotation`` iteration order), successors are
    followed until an already-traced half-edge is reached; the half-edges
    collected form one face. Faces are numbered in discovery order.

    Args:
        rotation: mapping vertex -> list of neighbours in cyclic order.

    Returns:
        (he2f, faces_list): ``he2f[(u, v)]`` is the face id of half-edge
        (u, v); ``faces_list[fid]`` lists that face's half-edges in walk order.
    """
    def successor_of(u, v):
        ring = rotation[v]
        return (v, ring[(ring.index(u) - 1) % len(ring)])

    successor = {
        (u, v): successor_of(u, v)
        for u, nbrs in rotation.items()
        for v in nbrs
    }

    traced = set()
    he2f = {}
    faces_list = []
    half_edges = ((u, v) for u, nbrs in rotation.items() for v in nbrs)
    for start in half_edges:
        if start in traced:
            continue
        cycle = []
        cur = start
        while cur not in traced:
            traced.add(cur)
            cycle.append(cur)
            cur = successor[cur]
        face_id = len(faces_list)
        faces_list.append(cycle)
        he2f.update(dict.fromkeys(cycle, face_id))
    return he2f, faces_list
def vertex_face_triples(rotation):
    """For each primal vertex, return the face ids of its three outgoing half-edges.

    Faces are traced with ``he2f_map``. For vertex v with rotation
    ``rotation[v] = [u1, u2, u3]`` the triple is
    ``(face(v, u1), face(v, u2), face(v, u3))``, i.e. in rotation order.

    Args:
        rotation: mapping vertex -> list of neighbours in cyclic order.

    Returns:
        (V_triples, faces_list): ``V_triples[v]`` is the 3-tuple of face ids;
        ``faces_list`` is the face list produced by ``he2f_map``.

    Raises:
        ValueError: if some vertex does not have exactly three neighbours.
    """
    he2f, faces_list = he2f_map(rotation)
    V_triples = {}
    for v, nbrs in rotation.items():
        if len(nbrs) != 3:
            raise ValueError("Graph must be cubic for this model")
        V_triples[v] = tuple(he2f[(v, u)] for u in nbrs)
    return V_triples, faces_list
def dual_edges_from_rotation(rotation):
    """Return the dual graph's edges as unique face-id pairs.

    Every primal edge reported by ``enumerate_faces`` contributes the pair of
    faces on its two sides, ordered with the smaller id first. Each pair is
    kept once, in first-seen order. An edge with the same face on both sides
    yields a pair (f, f).
    """
    _, _, edge_faces = enumerate_faces(rotation)
    unique_pairs = {}  # dict as an insertion-ordered set
    for (_u, _v), (lf, rf) in edge_faces.items():
        pair = (lf, rf) if lf < rf else (rf, lf)
        unique_pairs.setdefault(pair, None)
    return list(unique_pairs)
def tutte_embedding_for_term(rotation, root_var_edge):
    """Compute a Tutte embedding of the term's map, then recompute it aligned.

    The outer face is the face ``face_containing_edge`` reports for
    ``root_var_edge``, or the largest face when it reports None. From that
    first embedding, the outer-boundary vertex with the largest y coordinate
    (top_v) and the vertex preceding it on the boundary walk (left_v) are
    passed as the edge (left_v, top_v) to
    ``tutte_embedding_with_face_aligned`` for a stable orientation.

    Returns:
        (pos, outer_face) from the aligned embedding.
    """
    fid_guess, _ = face_containing_edge(rotation, root_var_edge)
    if fid_guess is not None:
        pos, outer_face = tutte_embedding_with_face(rotation, fid_guess)
    else:
        pos, outer_face = tutte_embedding_with_largest_face(rotation)

    # Boundary vertices in walk order, each kept once (first occurrence wins).
    boundary_walk = enumerate_faces(rotation)[1][outer_face]
    boundary = list(dict.fromkeys(u for (u, _v) in boundary_walk))

    top_idx = max(range(len(boundary)), key=lambda i: pos[boundary[i]][1])
    top_v = boundary[top_idx]
    left_v = boundary[(top_idx - 1) % len(boundary)]

    aligned_pos, aligned_face = tutte_embedding_with_face_aligned(
        rotation, outer_face, (left_v, top_v)
    )
    return aligned_pos, aligned_face
def dual_bfs_bitstring(labels, rotation, outer_face):
    """Summarise a face labelling as a string by BFS over the dual graph.

    The dual adjacency is built from ``enumerate_faces`` (one link per primal
    edge in each direction). Faces are visited breadth-first from
    ``outer_face``. The result is "0" for the root followed by
    ``hex(int(labels[f]))[2:]`` for every other reached face in discovery
    order. That is one hex digit per face only when 0 <= label < 16; faces
    unreachable from the root are omitted.
    """
    _, _, edge_faces = enumerate_faces(rotation)
    neighbours = defaultdict(list)
    for (_u, _v), (left, right) in edge_faces.items():
        neighbours[left].append(right)
        neighbours[right].append(left)

    visited = {outer_face}
    frontier = deque([outer_face])
    discovery = []
    while frontier:
        face = frontier.popleft()
        for nxt in neighbours[face]:
            if nxt in visited:
                continue
            visited.add(nxt)
            discovery.append(nxt)
            frontier.append(nxt)

    digits = [hex(int(labels[f]))[2:] for f in discovery]
    return "0" + "".join(digits)
| # -------------------- Adjacency input (dual graph) -------------------- | |
def parse_adj_matrix(path):
    """Read a square 0/1 adjacency matrix from a text file.

    Each non-blank line is one row. Entries may be separated by any run of
    whitespace (spaces or tabs) or by commas. A line with no separators is
    read as one digit per column (e.g. ``0110``). A leading UTF-8 BOM is
    ignored.

    The result has its diagonal zeroed and is symmetrised: ``A[i, j] == 1``
    when either ``A[i, j]`` or ``A[j, i]`` was 1 in the file.

    Args:
        path: path of a UTF-8 text file.

    Returns:
        numpy uint8 array of shape (n, n).

    Raises:
        ValueError: if the file has no rows, a row's length differs from the
            number of rows, or an entry is not an integer 0/1.
    """
    import numpy as onp
    rows = []
    # utf-8-sig decodes plain UTF-8 identically but also drops a BOM that
    # would otherwise make the first token unparsable.
    with open(path, "r", encoding="utf-8-sig") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            # str.split() with no argument splits on any whitespace run, so
            # tab-separated rows work too.
            toks = line.replace(",", " ").split()
            if len(toks) == 1:
                # No separators: compact form, one character per column.
                toks = list(toks[0])
            rows.append([int(x) for x in toks])
    if not rows:
        raise ValueError("Empty adjacency file")
    n = len(rows)
    for r in rows:
        if len(r) != n:
            raise ValueError("Adjacency must be square")
        for v in r:
            if v not in (0, 1):
                raise ValueError("Adjacency entries must be 0/1")
    A = onp.array(rows, dtype=onp.uint8)
    onp.fill_diagonal(A, 0)
    A = ((A + A.T) > 0).astype(onp.uint8)  # symmetrize and binarize
    return A
def adj_edges(A):
    """List the edges of a square adjacency matrix as (i, j) pairs with i < j.

    Only the strict upper triangle is read, so for an asymmetric ``A`` the
    entry ``A[i, j]`` with i < j decides. Any nonzero entry counts as an edge.
    Pairs are returned in row-major order as plain Python ints.
    """
    n = A.shape[0]
    upper = np.triu(np.asarray(A)[:n, :n], k=1)
    rows, cols = np.nonzero(upper)
    return [(int(i), int(j)) for i, j in zip(rows, cols)]
def greedy_vertex_coloring(A):
    """Partition vertices into colour classes by greedy colouring (for Gibbs blocks).

    Vertices are visited in order of decreasing degree, with ties broken by
    lower index. Each vertex gets the smallest colour not already used by a
    coloured vertex u with ``A[v, u]`` nonzero.

    Args:
        A: square adjacency matrix (numpy-like).

    Returns:
        List of colour classes. ``classes[c]`` lists, as Python ints and in
        visiting order, the vertices given colour c. When ``A`` is symmetric
        every class is an independent set.
    """
    import numpy as onp
    A = onp.asarray(A)
    n = A.shape[0]
    deg = A.sum(axis=1).astype(int)
    # Stable sort: the default quicksort is not stable, so tie order (and thus
    # the block partition) would not be guaranteed reproducible.
    order = onp.argsort(-deg, kind="stable")
    color_of = [-1] * n
    classes = []
    for v in (int(x) for x in order):
        # Vectorised neighbour lookup instead of scanning all n vertices.
        neighbours = onp.flatnonzero(A[v, :n])
        forbidden = {color_of[u] for u in neighbours if color_of[u] != -1}
        c = 0
        while c in forbidden:
            c += 1
        color_of[v] = c
        if c == len(classes):
            classes.append([v])
        else:
            classes[c].append(v)
    return classes
def generic_layout(A, seed=0):
    """Return 2-D positions {node: (x, y)} for the graph with adjacency ``A``.

    Uses ``networkx.spring_layout`` (seeded with ``seed``) when networkx is
    importable. If anything in that path raises, nodes are instead placed
    evenly on the unit circle, starting at angle 0.
    """
    try:
        import networkx as nx
        layout = nx.spring_layout(nx.from_numpy_array(A), seed=seed, dim=2)
        return {int(node): (float(xy[0]), float(xy[1])) for node, xy in layout.items()}
    except Exception:
        # Best-effort fallback: deterministic circular layout.
        import numpy as onp
        count = A.shape[0]
        thetas = onp.linspace(0, 2 * onp.pi, count, endpoint=False)
        radius = 1.0
        return {
            idx: (float(radius * onp.cos(thetas[idx])), float(radius * onp.sin(thetas[idx])))
            for idx in range(count)
        }
| # -------------------- Plotting -------------------- | |
def adaptive_radius(span, n_vertices, base_frac=0.012, min_frac=0.004):
    """Return a node radius in data units that shrinks like 1/sqrt(n).

    The fraction of ``span`` is ``base_frac * sqrt(30 / n)``, clamped to
    [min_frac, base_frac]; n is ``int(n_vertices)`` with a floor of 1.
    """
    count = int(n_vertices)
    if count < 1:
        count = 1
    raw = base_frac * math.sqrt(30.0 / count)
    clamped = min(base_frac, max(min_frac, raw))
    return float(clamped) * float(span)
def draw_term_coloring(rotation, faces_list, pos, labels, title, out_path,
                       palette=None, mono=False, mono_color="#9370DB", alpha=0.30):
    """Draw the primal map with faces shaded by label and save it to ``out_path``.

    With ``mono`` only faces whose label is 1 are filled (with ``mono_color``).
    Otherwise face fid is filled with ``palette[int(labels[fid]) % len(palette)]``,
    defaulting to a 4-colour palette. All edges from ``enumerate_faces`` are
    drawn in grey and every vertex of ``rotation`` as a black disc whose
    radius comes from ``adaptive_radius``. Saved at 180 dpi; the figure is
    closed afterwards.
    """
    import numpy as onp

    fig = plt.figure(figsize=(8.6, 8.6))
    ax = plt.gca()
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])

    if not mono and palette is None:
        palette = ["#4c78a8", "#f58518", "#e45756", "#72b7b2"]

    for fid, cyc in enumerate(faces_list):
        label = int(labels[fid])
        if mono:
            if label != 1:
                continue
            fill_color = mono_color
        else:
            fill_color = palette[label % len(palette)]
        corners = onp.array([pos[u] for (u, _) in cyc])
        ax.fill(corners[:, 0], corners[:, 1], alpha=alpha, color=fill_color, zorder=0.1)

    edge_list, _, _ = enumerate_faces(rotation)
    for u, v in edge_list:
        (x1, y1), (x2, y2) = pos[u], pos[v]
        ax.plot([x1, x2], [y1, y2], lw=1.0, alpha=0.45, color="#444", zorder=0.8)

    xs = [p[0] for p in pos.values()]
    ys = [p[1] for p in pos.values()]
    span = max(max(xs) - min(xs), max(ys) - min(ys))
    outer_r = adaptive_radius(span, len(rotation))
    inner_r = 0.6 * outer_r
    for vertex in rotation:
        center = pos[vertex]
        x, y = center
        ax.add_patch(plt.Circle((x, y), outer_r, fc="black", ec="black", lw=1.0, zorder=2.0))
        ax.add_patch(plt.Circle((x, y), inner_r, fc=(1, 1, 1, 0), ec="none", zorder=2.1))

    ax.set_title(title, fontsize=10)
    plt.tight_layout()
    plt.savefig(out_path, dpi=180)
    plt.close(fig)
def draw_term_matching(rotation, faces_list, pos, spins, matching, title, out_path,
                       face_alpha=0.30, face_color="#9370DB"):
    """Draw the primal map with spin-1 faces shaded and matching edges as bands.

    Args:
        rotation: rotation system; its edges come from ``enumerate_faces``.
        faces_list: faces as lists of half-edges (u, v); face fid is shaded
            when ``int(spins[fid]) == 1``.
        pos: mapping vertex -> (x, y).
        spins: per-face labels indexed by face id.
        matching: iterable of (u, v) vertex pairs, drawn as translucent red
            bands about 4 points wide.
        title, out_path: figure title and output path (saved at 180 dpi).
        face_alpha, face_color: fill style for shaded faces.

    The figure is always closed, even if drawing or saving raises.
    """
    import numpy as onp

    def data_width(ax, lw_pts):
        """Convert a length in points to data units (RMS of the x/y scales)."""
        px = lw_pts * ax.figure.dpi / 72.0
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        x_per_px = (xlim[1] - xlim[0]) / ax.bbox.width
        y_per_px = (ylim[1] - ylim[0]) / ax.bbox.height
        return px * math.sqrt((x_per_px ** 2 + y_per_px ** 2) / 2.0)

    fig = plt.figure(figsize=(8.6, 8.6))
    try:
        ax = plt.gca()
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        for fid, cyc in enumerate(faces_list):
            if int(spins[fid]) == 1:
                pts = onp.array([pos[u] for (u, _) in cyc])
                ax.fill(pts[:, 0], pts[:, 1], alpha=face_alpha, color=face_color, zorder=0.1)
        edges, _, _ = enumerate_faces(rotation)
        for (u, v) in edges:
            x1, y1 = pos[u]
            x2, y2 = pos[v]
            ax.plot([x1, x2], [y1, y2], lw=1.0, alpha=0.45, color="#444", zorder=0.8)
        # Band width is measured after the faces/edges above have set the limits.
        w = 4.0 * data_width(ax, 1.0)
        for (u, v) in matching:
            p = onp.array(pos[u])
            q = onp.array(pos[v])
            d = q - p
            L = onp.hypot(d[0], d[1]) + 1e-12  # avoid division by zero for coincident ends
            n = onp.array([-d[1], d[0]]) / L   # unit normal to the edge
            A = p + n * (w / 2)
            B = p - n * (w / 2)
            C = q - n * (w / 2)
            D = q + n * (w / 2)
            quad = onp.vstack([A, B, C, D])
            ax.fill(quad[:, 0], quad[:, 1], alpha=0.35, color="#ff4d4d", zorder=1.2)
        ax.set_title(title, fontsize=10)
        plt.tight_layout()
        plt.savefig(out_path, dpi=180)
    finally:
        plt.close(fig)
def draw_adj_coloring(A, pos, labels, title, out_path, palette=None, alpha=0.9, node_size=None, max_edges=None):
    """Plot a coloured graph given by adjacency matrix ``A`` and save it.

    Edges (nonzero ``A[i, j]`` with i < j, visited row by row) are drawn as
    faint grey lines. When ``max_edges`` is set, only the first ``max_edges``
    of them are drawn. Node i is a disc at ``pos[i]`` filled with
    ``palette[int(labels[i]) % len(palette)]`` (4-colour default palette).
    The disc radius is ``adaptive_radius(span, n)``, or ``node_size * span``
    when ``node_size`` is given. span is the larger of the x/y extents of
    ``pos``, or 1.0 if that is zero. Saved at 180 dpi; the figure is closed
    afterwards.
    """
    if palette is None:
        palette = ["#4c78a8", "#f58518", "#e45756", "#72b7b2"]
    fig = plt.figure(figsize=(8.6, 8.6))
    ax = plt.gca()
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    n = A.shape[0]

    def upper_edges():
        for i in range(n):
            for j in range(i + 1, n):
                if A[i, j]:
                    yield i, j

    for drawn, (i, j) in enumerate(upper_edges()):
        if max_edges is not None and drawn >= max_edges:
            break
        (x1, y1), (x2, y2) = pos[i], pos[j]
        ax.plot([x1, x2], [y1, y2], lw=1.0, alpha=0.25, color="#444", zorder=0.5)

    xs = [p[0] for p in pos.values()]
    ys = [p[1] for p in pos.values()]
    span = max(max(xs) - min(xs), max(ys) - min(ys)) or 1.0
    R = adaptive_radius(span, n) if node_size is None else float(node_size) * span
    r = 0.6 * R
    for i in range(n):
        x, y = pos[i]
        col = palette[int(labels[i]) % len(palette)]
        ax.add_patch(plt.Circle((x, y), R, fc=col, ec="#222", lw=0.8, alpha=alpha, zorder=1.0))
        ax.add_patch(plt.Circle((x, y), r, fc=(1, 1, 1, 0), ec="none", zorder=1.1))

    ax.set_title(title, fontsize=10)
    plt.tight_layout()
    plt.savefig(out_path, dpi=180)
    plt.close(fig)
| # -------------------- THRML model construction -------------------- | |
def build_prog_matching(V_triples, nodes, penalty=1.0, outer_node=None, clamp_penalty=25.0):
    """Build the 2-state face model with a not-all-equal constraint per vertex.

    For every triple (f1, f2, f3) in ``V_triples`` a ternary factor couples
    ``nodes[f1], nodes[f2], nodes[f3]``. Its weight is ``-penalty`` on the
    all-0 and all-1 configurations and 0 elsewhere, which penalises them.
    When ``outer_node`` is given, a unary factor with weight
    ``-clamp_penalty`` on state 1 softly prefers state 0, breaking the
    global 0/1 symmetry. Every node is its own free block.

    Returns:
        (prog, free_blocks, K) with K == 2.
    """
    K = 2
    corners = [(f1, f2, f3) for (f1, f2, f3) in V_triples.values()]
    a_nodes = [nodes[f1] for (f1, _, _) in corners]
    b_nodes = [nodes[f2] for (_, f2, _) in corners]
    c_nodes = [nodes[f3] for (_, _, f3) in corners]

    W = (jnp.zeros((len(corners), K, K, K), dtype=jnp.float32)
         .at[:, 0, 0, 0].set(-penalty)
         .at[:, 1, 1, 1].set(-penalty))
    factors = [CategoricalEBMFactor([Block(a_nodes), Block(b_nodes), Block(c_nodes)], W)]

    if outer_node is not None:
        bias = jnp.zeros((1, K), dtype=jnp.float32).at[:, 1].set(-clamp_penalty)
        factors.append(CategoricalEBMFactor([Block([outer_node])], bias))

    free_blocks = [Block([node]) for node in nodes]
    spec = BlockGibbsSpec(free_blocks, [])
    sampler = CategoricalGibbsConditional(K)
    prog = FactorSamplingProgram(spec, [sampler] * len(spec.free_blocks), factors, [])
    return prog, free_blocks, K
def build_prog_faceK(nodes, dual_pairs, K=4, penalty_equal=1.0, outer_node=None,
                     clamp_color=0, clamp_penalty=25.0, neighbor_node=None, neighbor_color=1, neighbor_penalty=10.0,
                     color_blocks=None):
    """Build a K-state Potts model that penalises equal colours on adjacent faces.

    One pairwise factor per (a, b) in ``dual_pairs`` has weight
    ``-penalty_equal`` on the diagonal (equal colours) and 0 elsewhere.
    Optional unary clamps break colour symmetry: ``outer_node`` gets
    ``-clamp_penalty`` on every colour except ``clamp_color``, and
    ``neighbor_node`` gets ``-neighbor_penalty`` on every colour except
    ``neighbor_color``. Free blocks are the given ``color_blocks`` (lists of
    node indices) or, if falsy, one block per node.

    Returns:
        (prog, free_blocks, K).
    """
    a_nodes = [nodes[a] for (a, b) in dual_pairs]
    b_nodes = [nodes[b] for (a, b) in dual_pairs]

    diag = jnp.arange(K)
    W = jnp.zeros((len(dual_pairs), K, K), dtype=jnp.float32).at[:, diag, diag].set(-penalty_equal)
    factors = [CategoricalEBMFactor([Block(a_nodes), Block(b_nodes)], W)]

    def unary_clamp(node, keep_color, strength):
        # -strength everywhere except keep_color, which stays at 0.
        weights = jnp.full((1, K), -strength, dtype=jnp.float32).at[:, keep_color].set(0.0)
        return CategoricalEBMFactor([Block([node])], weights)

    if outer_node is not None:
        factors.append(unary_clamp(outer_node, clamp_color, clamp_penalty))
    if neighbor_node is not None:
        factors.append(unary_clamp(neighbor_node, neighbor_color, neighbor_penalty))

    if color_blocks:
        free_blocks = [Block([nodes[i] for i in cls]) for cls in color_blocks]
    else:
        free_blocks = [Block([node]) for node in nodes]
    spec = BlockGibbsSpec(free_blocks, [])
    sampler = CategoricalGibbsConditional(K)
    prog = FactorSamplingProgram(spec, [sampler] * len(spec.free_blocks), factors, [])
    return prog, free_blocks, K
| # -------------------- Sampling helpers -------------------- | |
def normalize_results(out):
    """Unwrap the states tensor from a ``sample_states`` result.

    A list/tuple is reduced to its last element; if that element is itself a
    list/tuple, its first element is returned. Anything else passes through.
    """
    if not isinstance(out, (list, tuple)):
        return out
    tail = out[-1]
    return tail[0] if isinstance(tail, (list, tuple)) else tail
def run_thrml_sampling(key, prog, free_blocks, all_nodes, n_chains, n_warmup, n_samples,
                       steps_per_sample=2, jit_chains=False, trace=False, init_states=None,
                       n_categories=4):
    """Run ``n_chains`` independent sampling chains of a THRML program.

    Args:
        key: JAX PRNG key.
        prog: program passed to ``sample_states``.
        free_blocks: free blocks of ``prog``; one initial state is built per block.
        all_nodes: nodes observed on every sample (as a single ``Block``).
        n_chains, n_warmup, n_samples, steps_per_sample: schedule settings.
        jit_chains: try running all chains at once with ``jax.jit(jax.vmap(...))``.
            On any exception, log it and fall back to a sequential loop.
        trace: forwarded to ``log``.
        init_states: optional per-block initial states. Entry ``bi`` may be
            None (random start for that block), a 1-D array (shared by all
            chains), or an array whose first axis is broadcast to ``n_chains``.
        n_categories: random starts are drawn uniformly from
            ``[0, n_categories)``. Defaults to 4 for backward compatibility;
            pass the model's K (e.g. 2 for the matching model) so random
            starts are valid categories.

    Returns:
        numpy array of samples; the sequential path stacks per-chain results
        into (chains, n_samples, total_nodes).
    """
    k_init, k_samp = jax.random.split(key, 2)
    init_per_block = []
    for bi, block in enumerate(free_blocks):
        k_init, sub = jax.random.split(k_init, 2)
        if init_states and bi < len(init_states) and init_states[bi] is not None:
            arr = jnp.asarray(init_states[bi])
            if arr.ndim == 1:
                arr = arr[None, :]  # one state shared by every chain
            if arr.shape[0] != n_chains:
                arr = jnp.broadcast_to(arr, (n_chains, arr.shape[1]))
        else:
            arr = jax.random.randint(sub, (n_chains, len(block.nodes)), 0, n_categories,
                                     dtype=jnp.uint8)
        init_per_block.append(arr.astype(jnp.uint8))
    schedule = SamplingSchedule(n_warmup=n_warmup, n_samples=n_samples, steps_per_sample=steps_per_sample)
    all_block = Block(all_nodes)
    keys = jax.random.split(k_samp, n_chains)  # one key per chain, shared by both paths
    if jit_chains:
        log("Compiling jit/vmap chain function...", trace)

        def one_chain_sample(subkey, *init_blocks):
            out = sample_states(subkey, prog, schedule, list(init_blocks), [], [all_block])
            return normalize_results(out)

        try:
            vmapped = jax.jit(jax.vmap(one_chain_sample, in_axes=(0,) + (0,) * len(init_per_block)))
            samples = vmapped(keys, *init_per_block)
            return np.array(samples)
        except Exception as e:
            log(f"jit/vmap failed ({e}); falling back to sequential loop.", trace)
    results = []
    for ci in range(n_chains):
        init_state_chain = [arr[ci] for arr in init_per_block]
        out = sample_states(keys[ci], prog, schedule, init_state_chain, [], [all_block])
        results.append(np.array(normalize_results(out)))
    return np.stack(results, axis=0)  # (chains, n_samples, total_nodes)
def pick_best_sample(samples_array, score_fn):
    """Return the lowest-scoring state over all chains and samples.

    Args:
        samples_array: 3-D array (chains, samples, nodes).
        score_fn: maps one state vector ``samples_array[c, s]`` to a score;
            lower is better.

    Returns:
        (best_vec, best_score): the first vector (chain-major, then sample)
        attaining the minimum score, and that score as a float.

    Raises:
        ValueError: if ``samples_array`` is not 3-D or contains no samples.
    """
    if samples_array.ndim != 3:
        raise ValueError(
            f"samples_array must be 3-D (chains, samples, nodes); got shape {samples_array.shape}")
    n_chains, n_samples, _ = samples_array.shape
    if n_chains == 0 or n_samples == 0:
        raise ValueError("pick_best_sample needs at least one sample")
    best = None
    best_score = None
    for c in range(n_chains):
        for s in range(n_samples):
            vec = samples_array[c, s]
            sc = score_fn(vec)
            if best is None or sc < best_score:
                best, best_score = vec, sc
    return best, float(best_score)
def _anneal_block_states(state, free_blocks, nodes):
    """Split a full state vector (ordered like ``nodes``) into per-block vectors.

    Values are gathered by node identity, so blocks whose node order differs
    from ``nodes`` (e.g. colour-class blocks) get their own nodes' values.
    If a block node is not found among ``nodes`` by identity, that block falls
    back to the positional slice it would occupy if the blocks were
    concatenated in order.
    """
    state = np.asarray(state)
    index_of = {id(node): i for i, node in enumerate(nodes)}
    parts = []
    offset = 0
    for b in free_blocks:
        members = list(b.nodes)
        try:
            idx = [index_of[id(node)] for node in members]
        except KeyError:
            idx = list(range(offset, offset + len(members)))
        parts.append(state[np.asarray(idx, dtype=np.intp)])
        offset += len(members)
    return parts


def _anneal_random_states(key, free_blocks, K):
    """Draw one uniform random uint8 vector in [0, K) per free block."""
    keys = jax.random.split(key, max(len(free_blocks), 1))
    return [jax.random.randint(k, (len(b.nodes),), 0, K, dtype=jnp.uint8)
            for k, b in zip(keys, free_blocks)]


def anneal_ladder(key, build_prog_fn, nodes, phases, betas, sweeps_per_phase,
                  steps_per_sample=1, jit_chains=False, trace=False, precompile=False):
    """Warm-start a sampler by running short single-chain phases at increasing β.

    Phase i rebuilds the program with ``build_prog_fn(betas[i])`` and runs
    ``sweeps_per_phase`` warm-up sweeps from the previous phase's final state.
    The first phase starts from a random state in the model's own range
    [0, K). States are moved between phases by node identity, so this is
    correct even when the free blocks are not in ``nodes`` order.

    Args:
        key: JAX PRNG key.
        build_prog_fn: callable beta -> (prog, free_blocks, K).
        nodes: all model nodes; the order of the observed state vector.
        phases: number of phases (uses ``betas[0:phases]``).
        betas: per-phase β values.
        sweeps_per_phase: warm-up sweeps per phase.
        steps_per_sample, jit_chains, trace: forwarded to ``run_thrml_sampling``.
        precompile: run one tiny sampling call first (result discarded).

    Returns:
        (init_states, free_blocks, K). init_states is a list with one
        (1, block_size) array per free block of the last phase's program, or
        None when ``phases <= 0`` (nothing was sampled).
    """
    prog, free_blocks, K = build_prog_fn(betas[0])
    if precompile:
        schedule_pc = SamplingSchedule(n_warmup=1, n_samples=1, steps_per_sample=steps_per_sample)
        all_block_pc = Block(nodes)
        _ = sample_states(key, prog, schedule_pc,
                          [jnp.zeros((len(b.nodes),), dtype=jnp.uint8) for b in free_blocks],
                          [], [all_block_pc])
    state_for_next = None
    k = key
    for i in range(phases):
        beta_i = betas[i]
        log(f"Anneal phase {i+1}/{phases} at β={beta_i}", trace)
        if i > 0:
            prog, free_blocks, K = build_prog_fn(beta_i)
        k, sub, sub_init = jax.random.split(k, 3)
        if state_for_next is None:
            init_states = _anneal_random_states(sub_init, free_blocks, K)
        else:
            init_states = _anneal_block_states(state_for_next, free_blocks, nodes)
        samples = run_thrml_sampling(sub, prog, free_blocks, nodes, n_chains=1,
                                     n_warmup=sweeps_per_phase, n_samples=1,
                                     steps_per_sample=steps_per_sample, jit_chains=jit_chains, trace=trace,
                                     init_states=init_states)
        state_for_next = samples[0, -1]  # last sample of the single chain, in `nodes` order
    if state_for_next is None:
        return None, free_blocks, K
    init_states = [vec[None, :] for vec in _anneal_block_states(state_for_next, free_blocks, nodes)]
    return init_states, free_blocks, K
| # -------------------- CLI & main -------------------- | |
# Default lambda term used as the --term value when none is supplied on the
# command line. The raw-string literal contains non-ASCII Greek letters; keep
# it byte-for-byte as is.
DEF_TERM = r"λx. λy. λz. λw. λu. λv. λt. λp. λq. x (λr. λs. y (λm. λn. z (λo. w (λϕ. u (λψ. λω. v (λa. t (λb. p (λc. q (r ((s m) (n (o (ϕ ψ))))) (ω (a (b c)))))))))))"
| def main(): | |
| ap = argparse.ArgumentParser(description="THRML sampler: perfect matching (2-color+NAE) or face 4-coloring (Potts)." ) | |
| ap.add_argument("--task", choices=["matching","face4"], default="matching") | |
| ap.add_argument("--term", type=str, default=DEF_TERM, help="Planar lambda term (quoted)." ) | |
| ap.add_argument("--adj-file", type=str, default="", help="Square 0/1 adjacency for the dual graph (face coloring)." ) | |
| ap.add_argument("--K", type=int, default=4, help="Number of face colors for --task face4." ) | |
| # penalties (pre-β); we scale by β internally | |
| ap.add_argument("--penalty", type=float, default=1.0, help="Penalty per violated constraint (pre-β)." ) | |
| ap.add_argument("--clamp-penalty", type=float, default=25.0, help="Penalty for symmetry clamps (pre-β)." ) | |
| ap.add_argument("--neighbor-penalty", type=float, default=10.0, help="Penalty for neighbor clamp (pre-β)." ) | |
| # anneal | |
| ap.add_argument("--anneal", action="store_true", help="Use β ladder preconditioning." ) | |
| ap.add_argument("--beta-start", type=float, default=2.0) | |
| ap.add_argument("--beta-end", type=float, default=8.0) | |
| ap.add_argument("--phases", type=int, default=4) | |
| ap.add_argument("--sweeps-per-phase", type=int, default=40) | |
| # final sampling | |
| ap.add_argument("--beta", type=float, default=6.0, help="β if no anneal; ignored if --anneal (we use beta-end)." ) | |
| ap.add_argument("--chains", type=int, default=1) | |
| ap.add_argument("--warmup", type=int, default=150) | |
| ap.add_argument("--samples", type=int, default=40) | |
| ap.add_argument("--steps-per-sample", type=int, default=2) | |
| ap.add_argument("--jit-chains", action="store_true") | |
| ap.add_argument("--seed", type=int, default=0) | |
| # adaptive | |
| ap.add_argument("--enforce-proper", action="store_true", help="For face4: adapt β/penalty/sweeps to reach target conflicts." ) | |
| ap.add_argument("--target-conflicts", type=int, default=0) | |
| ap.add_argument("--enforce-perfect", action="store_true", help="For matching: adapt to drive NAE violations to 0." ) | |
| ap.add_argument("--max-rounds", type=int, default=6) | |
| ap.add_argument("--beta-mult", type=float, default=1.4) | |
| ap.add_argument("--penalty-mult", type=float, default=1.25) | |
| ap.add_argument("--warmup-mult", type=float, default=1.5) | |
| ap.add_argument("--steps-inc", type=int, default=1) | |
| ap.add_argument("--restarts", type=int, default=0) | |
| # plotting | |
| ap.add_argument("--png", type=str, default="thrml_out.png") | |
| ap.add_argument("--energy", type=str, default="thrml_energy.png") | |
| ap.add_argument("--trace", action="store_true") | |
| ap.add_argument("--precompile", type=int, default=0) | |
| ap.add_argument("--mono", action="store_true", help="Term face-coloring: shade faces of color==1 only." ) | |
| ap.add_argument("--skip-plot", action="store_true") | |
| ap.add_argument("--plot-max-edges", type=int, default=None, help="Adjacency plotting: cap edges for speed." ) | |
| ap.add_argument("--block-coloring", action="store_true", help="Adjacency: use greedy independent-set blocks for faster sampling." ) | |
| args = ap.parse_args() | |
| log("Start", args.trace) | |
| key = jax.random.key(args.seed) | |
| # -------------- Adjacency path (face coloring) -------------- | |
| if args.adj_file: | |
| if args.task != "face4": | |
| raise SystemExit("--adj-file is only supported with --task face4.") | |
| log(f"Reading adjacency from {args.adj_file} ...", args.trace) | |
| A = parse_adj_matrix(args.adj_file) | |
| n = A.shape[0] | |
| log(f"Adjacency loaded: n={n}, edges={int(A.sum()//2)}", args.trace) | |
| K = int(args.K) | |
| nodes = [CategoricalNode() for _ in range(n)] | |
| # symmetry clamps: fix node 0 to color 0; fix one neighbor to color 1 | |
| neighbor_idx = None | |
| for j in range(1, n): | |
| if A[0, j]: | |
| neighbor_idx = j; break | |
| pairs = adj_edges(A) | |
| classes = greedy_vertex_coloring(A) if args.block_coloring else None | |
| def make_prog(beta_scale: float, penalty_scale: float): | |
| return build_prog_faceK( | |
| nodes, pairs, K=K, | |
| penalty_equal=penalty_scale * beta_scale, | |
| outer_node=nodes[0], clamp_color=0, clamp_penalty=args.clamp_penalty * beta_scale, | |
| neighbor_node=(nodes[neighbor_idx] if neighbor_idx is not None else None), | |
| neighbor_color=1, neighbor_penalty=args.neighbor_penalty * beta_scale, | |
| color_blocks=classes | |
| ) | |
| def conflicts_of(vec): | |
| # number of monochromatic edges | |
| return sum(1 for (a,b) in pairs if int(vec[a]) == int(vec[b])) | |
| # Prep β/penalty based on anneal or fixed | |
| beta_cur = (args.beta_end if args.anneal else args.beta) | |
| penalty_cur = args.penalty | |
| warmup_cur = args.warmup | |
| steps_cur = args.steps_per_sample | |
| init_states=None; free_blocks=None; _K=None | |
| if args.anneal: | |
| betas = np.linspace(args.beta_start, args.beta_end, num=args.phases).tolist() | |
| log(f"Anneal ladder: betas={betas}, sweeps_per_phase={args.sweeps_per_phase}", args.trace) | |
| start_ann = time.time() | |
| def make_for_anneal(b): return make_prog(b, penalty_cur) | |
| init_states, free_blocks, _K = anneal_ladder(key, make_for_anneal, nodes, args.phases, betas, | |
| sweeps_per_phase=args.sweeps_per_phase, | |
| steps_per_sample=1, jit_chains=False, trace=args.trace, precompile=False) | |
| log(f"Anneal finished in {time.time()-start_ann:.2f}s", args.trace) | |
| else: | |
| prog, free_blocks, _K = make_prog(beta_cur, penalty_cur) | |
| # Adaptive loop (if requested) | |
| best_vec=None; best_conf=None | |
| rounds = 1 if not args.enforce_proper else args.max_rounds | |
| for ri in range(rounds): | |
| if args.enforce_proper and ri > 0: | |
| beta_cur *= args.beta_mult | |
| penalty_cur *= args.penalty_mult | |
| warmup_cur = int(max(warmup_cur * args.warmup_mult, warmup_cur + 1)) | |
| steps_cur += args.steps_inc | |
| log(f"[adaptive] round {ri+1}: β={beta_cur:.3g}, penalty={penalty_cur:.3g}, warmup={warmup_cur}, steps={steps_cur}", args.trace) | |
| prog, free_blocks, _K = make_prog(beta_cur, penalty_cur) | |
| n_tries = max(1, args.restarts+1) | |
| cand_vec=None; cand_conf=None | |
| for ti in range(n_tries): | |
| log(f"Sampling: round={ri+1}/{rounds}, try={ti+1}/{n_tries} | chains={args.chains}, warmup={warmup_cur}, samples={args.samples}, steps={steps_cur} ...", args.trace) | |
| start_samp = time.time() | |
| samples = run_thrml_sampling(key, prog, free_blocks, nodes, n_chains=args.chains, | |
| n_warmup=warmup_cur, n_samples=args.samples, | |
| steps_per_sample=steps_cur, | |
| jit_chains=args.jit_chains, trace=args.trace, | |
| init_states=init_states) | |
| log(f"Sampling finished in {time.time()-start_samp:.2f}s", args.trace) | |
| samples_array = np.array(samples) | |
| vec, conf = pick_best_sample(samples_array, conflicts_of) | |
| if (cand_vec is None) or (conf < cand_conf): | |
| cand_vec, cand_conf = vec, conf | |
| init_states = None # subsequent try restarts randomly | |
| if (best_vec is None) or (cand_conf < best_conf): | |
| best_vec, best_conf = cand_vec, cand_conf | |
| log(f"[adaptive] best conflicts this round: {cand_conf}", args.trace) | |
| if best_conf <= args.target_conflicts: | |
| log(f"[adaptive] target reached: conflicts={best_conf}", args.trace) | |
| break | |
| final_labels = best_vec | |
| conflicts = int(best_conf) | |
| proper = (conflicts <= args.target_conflicts) | |
| # Plot + trace (plot can be skipped or edge-limited for speed) | |
| pos = generic_layout(A, seed=args.seed) | |
| title = (f"THRML face-{K} coloring (adj): β≈{beta_cur:.2f}, penalty≈{penalty_cur:.2f} | " | |
| f"conflicts={conflicts} | proper={proper}") | |
| if not args.skip_plot: | |
| draw_adj_coloring(A, pos, final_labels, title, args.png, max_edges=args.plot_max_edges) | |
| log(f"Graph saved: {args.png}", args.trace) | |
| try: | |
| mean_conf = [] | |
| for sidx in range(samples_array.shape[1]): | |
| cur = samples_array[:, sidx, :] | |
| cs = [] | |
| for ch in range(cur.shape[0]): | |
| cs.append(sum(1 for (a,b) in pairs if int(cur[ch,a])==int(cur[ch,b]))) | |
| mean_conf.append(np.mean(cs)) | |
| fig = plt.figure(figsize=(6,3.2)); ax = plt.gca() | |
| ax.plot(mean_conf, lw=1.2); ax.set_xlabel("sample index"); ax.set_ylabel("mean conflicts") | |
| ax.set_title("Sampling trace (adj face coloring)") | |
| plt.tight_layout(); plt.savefig(args.energy, dpi=160); plt.close(fig) | |
| log(f"Energy trace saved: {args.energy}", args.trace) | |
| except Exception: | |
| pass | |
| print("---- Results (adjacency) ----") | |
| print(f"Nodes: {n}, Edges: {int(A.sum()//2)}") | |
| print(f"Conflicts (final): {int(conflicts)}, Proper coloring: {proper}") | |
| print(f"Final β≈{beta_cur:.3g}, penalty≈{penalty_cur:.3g}, warmup={warmup_cur}, steps={steps_cur}") | |
| if not args.skip_plot: | |
| print(f"[OK] Saved graph to '{args.png}' and energy trace to '{args.energy}'.") | |
| log("Done.", args.trace) | |
| return | |
| # -------------- Lambda-term path -------------- | |
| log("Parsing lambda term...", args.trace) | |
| toks = tokenize(args.term); ast = Parser(toks).parse() | |
| log("Parsed term.", args.trace) | |
| log("Building rotation system...", args.trace) | |
| rotation, kinds, tree_edges, root_tree, var_edges, root_var_edge, edge_labels, outer_param, outer_body_str, root_body_id, root_var_id = build_rotation_with_labels(ast, "vbp", "paf") | |
| log(f"Rotation built: vertices={len(rotation)}", args.trace) | |
| log("Enumerating faces & constraints...", args.trace) | |
| V_triples, faces_list = vertex_face_triples(rotation) | |
| n_faces = len(faces_list) | |
| log(f"Faces={n_faces}, vertex-triples={len(V_triples)}", args.trace) | |
| log("Computing Tutte embedding...", args.trace) | |
| pos, outer_face = tutte_embedding_for_term(rotation, root_var_edge) | |
| log(f"Embedding done. outer_face={outer_face}", args.trace) | |
| log("Creating THRML nodes...", args.trace) | |
| nodes = [CategoricalNode() for _ in range(n_faces)] | |
| log("Nodes created.", args.trace) | |
| if args.task == "matching": | |
| # Program factory (β scales penalties) | |
| def make_prog(beta_scale: float, penalty_scale: float): | |
| return build_prog_matching( | |
| V_triples, nodes, | |
| penalty=penalty_scale * beta_scale, | |
| outer_node=nodes[outer_face], | |
| clamp_penalty=args.clamp_penalty * beta_scale | |
| ) | |
| def score_of(vec): | |
| # NAE violations (we want 0) | |
| violated = 0 | |
| for (a,b,c) in V_triples.values(): | |
| s0,s1,s2 = int(vec[a]), int(vec[b]), int(vec[c]) | |
| if s0==s1==s2: violated += 1 | |
| return violated | |
| target = 0 | |
| enforce = args.enforce_perfect | |
| else: # face4 on lambda-term dual | |
| K = int(args.K) if args.K else 4 | |
| pairs = dual_edges_from_rotation(rotation) | |
| # Break symmetry by clamping outer_face to 0 and one neighbor to 1 | |
| neighbor = None | |
| for (a,b) in pairs: | |
| if a == outer_face: neighbor = nodes[b]; break | |
| if b == outer_face: neighbor = nodes[a]; break | |
| def make_prog(beta_scale: float, penalty_scale: float): | |
| return build_prog_faceK( | |
| nodes, pairs, K=K, | |
| penalty_equal=penalty_scale * beta_scale, | |
| outer_node=nodes[outer_face], clamp_color=0, clamp_penalty=args.clamp_penalty * beta_scale, | |
| neighbor_node=neighbor, neighbor_color=1, neighbor_penalty=args.neighbor_penalty * beta_scale, | |
| color_blocks=None | |
| ) | |
| def score_of(vec): | |
| return sum(1 for (a,b) in pairs if int(vec[a])==int(vec[b])) | |
| target = args.target_conflicts | |
| enforce = args.enforce_proper | |
| # Initialize via anneal if requested | |
| beta_cur = (args.beta_end if args.anneal else args.beta) | |
| penalty_cur = args.penalty | |
| warmup_cur = args.warmup | |
| steps_cur = args.steps_per_sample | |
| init_states=None; free_blocks=None; _K=None | |
| if args.anneal: | |
| betas = np.linspace(args.beta_start, args.beta_end, num=args.phases).tolist() | |
| log(f"Anneal ladder: betas={betas}, sweeps_per_phase={args.sweeps_per_phase}", args.trace) | |
| start_ann = time.time() | |
| def make_for_anneal(b): return make_prog(b, penalty_cur) | |
| init_states, free_blocks, _K = anneal_ladder(key, make_for_anneal, nodes, args.phases, betas, | |
| sweeps_per_sample=1, jit_chains=False, trace=args.trace, precompile=False) | |
| log(f"Anneal finished in {time.time()-start_ann:.2f}s", args.trace) | |
| else: | |
| prog, free_blocks, _K = make_prog(beta_cur, penalty_cur) | |
| # Adaptive loop | |
| best_vec=None; best_score=None | |
| rounds = 1 if not enforce else args.max_rounds | |
| for ri in range(rounds): | |
| if enforce and ri > 0: | |
| beta_cur *= args.beta_mult | |
| penalty_cur *= args.penalty_mult | |
| warmup_cur = int(max(warmup_cur * args.warmup_mult, warmup_cur + 1)) | |
| steps_cur += args.steps_inc | |
| log(f"[adaptive] round {ri+1}: β={beta_cur:.3g}, penalty={penalty_cur:.3g}, warmup={warmup_cur}, steps={steps_cur}", args.trace) | |
| prog, free_blocks, _K = make_prog(beta_cur, penalty_cur) | |
| n_tries = max(1, args.restarts+1) | |
| cand_vec=None; cand_score=None | |
| for ti in range(n_tries): | |
| log(f"Sampling: round={ri+1}/{rounds}, try={ti+1}/{n_tries} | chains={args.chains}, warmup={warmup_cur}, samples={args.samples}, steps={steps_cur} ...", args.trace) | |
| start_samp = time.time() | |
| samples = run_thrml_sampling(key, prog, free_blocks, nodes, n_chains=args.chains, | |
| n_warmup=warmup_cur, n_samples=args.samples, | |
| steps_per_sample=steps_cur, | |
| jit_chains=args.jit_chains, trace=args.trace, | |
| init_states=init_states) | |
| log(f"Sampling finished in {time.time()-start_samp:.2f}s", args.trace) | |
| samples_array = np.array(samples) | |
| vec, sc = pick_best_sample(samples_array, score_of) | |
| if (cand_vec is None) or (sc < cand_score): | |
| cand_vec, cand_score = vec, sc | |
| init_states = None # randomize next try | |
| if (best_vec is None) or (cand_score < best_score): | |
| best_vec, best_score = cand_vec, cand_score | |
| log(f"[adaptive] best score this round: {cand_score}", args.trace) | |
| if best_score <= target: | |
| log(f"[adaptive] target reached: score={best_score}", args.trace) | |
| break | |
| final_labels = best_vec | |
| if args.task == "matching": | |
| violated = int(best_score) | |
| edges_all, faces_all, edge_faces = enumerate_faces(rotation) | |
| matching = set() | |
| for (u,vx) in edges_all: | |
| lf, rf = edge_faces[(u,vx)] | |
| if int(final_labels[lf]) == int(final_labels[rf]): | |
| matching.add(tuple(sorted((u,vx)))) | |
| deg = {vv:0 for vv in rotation} | |
| for (u,vx) in matching: | |
| deg[u]+=1; deg[vx]+=1 | |
| is_perfect = all(d == 1 for d in deg.values()) | |
| bitstring = dual_bfs_bitstring(final_labels, rotation, outer_face) | |
| title = (f"THRML matching (term): β≈{beta_cur:.2f}, penalty≈{penalty_cur:.2f} | " | |
| f"violations={violated} | perfect={is_perfect} | bitstring: {bitstring}") | |
| if not args.skip_plot: | |
| draw_term_matching(rotation, faces_list, pos, final_labels, matching, title, args.png) | |
| log(f"Graph saved: {args.png}", args.trace) | |
| # violations trace of last batch | |
| try: | |
| mean_viol = [] | |
| for sidx in range(samples_array.shape[1]): | |
| cur = samples_array[:, sidx, :] | |
| vs = [] | |
| for ch in range(cur.shape[0]): | |
| cnt = 0 | |
| for (a,b,c) in V_triples.values(): | |
| if int(cur[ch,a])==int(cur[ch,b])==int(cur[ch,c]): cnt += 1 | |
| vs.append(cnt) | |
| mean_viol.append(np.mean(vs)) | |
| fig = plt.figure(figsize=(6,3.2)); ax = plt.gca() | |
| ax.plot(mean_viol, lw=1.2); ax.set_xlabel("sample index"); ax.set_ylabel("mean NAE violations") | |
| ax.set_title("Sampling trace (matching)") | |
| plt.tight_layout(); plt.savefig(args.energy, dpi=160); plt.close(fig) | |
| log(f"Energy trace saved: {args.energy}", args.trace) | |
| except Exception: | |
| pass | |
| print("---- Results (term/matching) ----") | |
| print(f"Faces: {len(faces_list)}, Vertices: {len(rotation)}") | |
| print(f"NAE violations (final): {int(violated)}, Perfect matching: {is_perfect}") | |
| print(f"Matching edges: {len(matching)}") | |
| print(f"Root face: {outer_face}") | |
| if not args.skip_plot: | |
| print(f"[OK] Saved graph to '{args.png}' and energy trace to '{args.energy}'.") | |
| else: | |
| K = int(args.K) if args.K else 4 | |
| pairs = dual_edges_from_rotation(rotation) | |
| conflicts = int(best_score) | |
| proper = (conflicts <= target) | |
| title = (f"THRML face-{K} coloring (term): β≈{beta_cur:.2f}, penalty≈{penalty_cur:.2f} | " | |
| f"conflicts={conflicts} | proper={proper}") | |
| if not args.skip_plot: | |
| draw_term_coloring(rotation, faces_list, pos, final_labels, title, args.png, mono=args.mono) | |
| log(f"Graph saved: {args.png}", args.trace) | |
| try: | |
| mean_conf = [] | |
| for sidx in range(samples_array.shape[1]): | |
| cur = samples_array[:, sidx, :] | |
| cs = [] | |
| for ch in range(cur.shape[0]): | |
| cs.append(sum(1 for (a,b) in pairs if int(cur[ch,a])==int(cur[ch,b]))) | |
| mean_conf.append(np.mean(cs)) | |
| fig = plt.figure(figsize=(6,3.2)); ax = plt.gca() | |
| ax.plot(mean_conf, lw=1.2); ax.set_xlabel("sample index"); ax.set_ylabel("mean conflicts") | |
| ax.set_title("Sampling trace (term face coloring)") | |
| plt.tight_layout(); plt.savefig(args.energy, dpi=160); plt.close(fig) | |
| log(f"Energy trace saved: {args.energy}", args.trace) | |
| except Exception: | |
| pass | |
| print("---- Results (term/faceK) ----") | |
| print(f"Faces: {len(faces_list)}, Vertices: {len(rotation)}") | |
| print(f"Conflicts (final): {int(conflicts)}, Proper coloring: {proper}") | |
| print(f"Root face: {outer_face}") | |
| if not args.skip_plot: | |
| print(f"[OK] Saved graph to '{args.png}' and energy trace to '{args.energy}'.") | |
| log("Done.", args.trace) | |
| # Script entry point: call main() only when this file is executed directly, not when it is imported as a module. | |
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment