Skip to content

graph

ZX graph construction, manipulation, and preparation for sampling.

ConnectedComponent dataclass

ConnectedComponent(
    graph: BaseGraph, output_indices: list[int]
)

A connected subgraph with its associated output indices.

build_sampling_graph

build_sampling_graph(
    built: GraphRepresentation,
    sample_detectors: bool = False,
) -> BaseGraph

Build a ZX graph for sampling from a GraphRepresentation.

This is the internal implementation of get_sampling_graph.

Source code in src/tsim/core/graph.py
(Source lines 205–309.)
def build_sampling_graph(
    built: GraphRepresentation, sample_detectors: bool = False
) -> BaseGraph:
    """Build a ZX graph for sampling from a GraphRepresentation.

    This is the internal implementation of get_sampling_graph.

    ``built.graph`` is copied (the original is never mutated). In the copy:
    first vertices that are still boundaries become X spiders, and the last
    vertices are aligned to a common row. The copy is then composed with its
    own adjoint. In the doubled diagram, the two vertices labelled ``rec[i]``
    (and likewise ``m[i]``) are connected and given constant phase 0.

    Args:
        built: Parsed circuit representation.
        sample_detectors: If False, a boundary output is attached to every
            measurement record (in record order), and all ``det``/``obs``
            annotation vertices are removed. If True, one vertex of every
            annotation pair is removed. A boundary output is then attached to
            each remaining detector (in index order), followed by each
            observable (in sorted key order).

    Returns:
        A new graph whose outputs are the boundary vertices described above.

    Raises:
        AssertionError: If a label does not occur on the expected number of
            vertices (two per ``rec``/``m``/annotation label in the doubled
            diagram, one per annotation after de-duplication).

    """
    g = built.graph.copy()

    # Un-initialized first vertices (still boundaries) become X spiders,
    # i.e. qubits prepared in the 0 state.
    for first in built.first_vertex.values():
        if g.type(first) == VertexType.BOUNDARY:
            g.set_type(first, VertexType.X)

    # Put every last vertex on the largest row found among them.
    last_vertices = list(built.last_vertex.values())
    if last_vertices:
        final_row = max(g.row(v) for v in last_vertices)
        for v in last_vertices:
            g.set_row(v, final_row)

    num_measurements = len(built.rec)
    g.set_outputs(
        tuple(v for v in g.vertices() if g.type(v) == VertexType.BOUNDARY)
    )

    # Double the diagram by composing it with its adjoint.
    g.compose(g.adjoint())
    g = g.copy()

    # Index vertices that carry exactly one phase variable by that label.
    labelled: dict[str, list[int]] = defaultdict(list)
    annotations: dict[str, list[int]] = defaultdict(list)
    for v in g.vertices():
        names = g._phaseVars[v]
        if len(names) != 1:
            continue
        name = next(iter(names))
        is_annotation = "det" in name or "obs" in name
        if is_annotation or "rec" in name or "m" in name:
            labelled[name].append(v)
        if is_annotation:
            annotations[name].append(v)

    def fuse_pair(label: str) -> int:
        """Join the two vertices carrying ``label``, zero their phases, return the first."""
        pair = labelled[label]
        assert len(pair) == 2
        first_v, second_v = pair
        if not g.connected(first_v, second_v):
            g.add_edge((first_v, second_v))
        g.set_phase(first_v, 0)
        g.set_phase(second_v, 0)
        return first_v

    outputs: list = []
    for i in range(num_measurements):
        anchor = fuse_pair(f"rec[{i}]")
        if not sample_detectors:
            out = g.add_vertex(VertexType.BOUNDARY, qubit=-1, row=i + 1, phase=0)
            g.add_edge((anchor, out))
            outputs.append(out)

    for i in range(len(built.silent_rec)):
        fuse_pair(f"m[{i}]")

    if sample_detectors:
        # Keep one vertex per detector/observable label. The last one collected
        # is dropped (presumably the duplicate coming from the adjoint copy).
        for pair in annotations.values():
            assert len(pair) == 2
            g.remove_vertex(pair.pop())

        det_labels = [f"det[{i}]" for i in range(len(built.detectors))]
        obs_labels = [f"obs[{i}]" for i in sorted(built.observables_dict)]
        for label in det_labels + obs_labels:
            remaining = annotations[label]
            assert len(remaining) == 1
            keep = remaining[0]
            boundary = g.add_vertex(
                VertexType.BOUNDARY,
                qubit=-2 if "det" in label else -2.5,
                row=g.row(keep),
            )
            g.add_edge((keep, boundary))
            g.set_phase(keep, 0)
            outputs.append(boundary)
    else:
        # Sampling raw measurements: annotation vertices are not needed.
        for pair in annotations.values():
            assert len(pair) == 2
            for v in pair:
                g.remove_vertex(v)

    g.set_outputs(tuple(outputs))
    return g

classify_direct

classify_direct(
    component: ConnectedComponent,
) -> tuple[int, bool] | None

Check if a component is directly determined by a single f-variable.

A component qualifies when its graph consists of exactly two vertices — one boundary output and one Z-spider — connected by a Hadamard edge, where the Z-spider carries a single f parameter and a constant phase of either 0 (no flip) or π (flip).

Parameters:

- component (ConnectedComponent, required): A connected component to classify.

Returns:

- tuple[int, bool] | None: (f_index, flip) if the fast path applies, otherwise None.

Source code in src/tsim/core/graph.py
(Source lines 68–124.)
def classify_direct(
    component: ConnectedComponent,
) -> tuple[int, bool] | None:
    """Check if a component is directly determined by a single f-variable.

    The fast path applies only when all of the following hold:

    * the component's graph has exactly one output and exactly two vertices;
    * the output has a single neighbor, which is a Z spider joined to it by a
      Hadamard edge;
    * that spider carries exactly one parameter, whose name starts with ``f``;
    * that parameter is the only one anywhere in the graph (scalar included);
    * the spider's constant phase is 0 (no flip) or 1, i.e. pi (flip).

    Args:
        component: A connected component to classify.

    Returns:
        ``(f_index, flip)`` if the fast path applies, otherwise ``None``.
        ``f_index`` is the integer suffix of the ``f`` parameter name.

    """
    graph = component.graph
    out_list = list(graph.outputs())
    if len(out_list) != 1 or len(list(graph.vertices())) != 2:
        return None

    output_vertex = out_list[0]
    adjacent = list(graph.neighbors(output_vertex))
    if len(adjacent) != 1:
        return None

    spider = adjacent[0]
    if graph.type(spider) != VertexType.Z:
        return None
    if graph.edge_type(graph.edge(output_vertex, spider)) != EdgeType.HADAMARD:
        return None

    spider_params = graph.get_params(spider)
    if len(spider_params) != 1:
        return None
    name = next(iter(spider_params))
    # The f-variable must be the sole parameter in the whole graph, otherwise
    # other vertices (or the scalar) also depend on sampled values.
    if not name.startswith("f") or get_params(graph) != {name}:
        return None

    phase = graph.phase(spider)
    if phase == 0:
        return int(name[1:]), False
    if phase == Fraction(1, 1):
        return int(name[1:]), True
    return None

connected_components

connected_components(
    g: BaseGraph,
) -> list[ConnectedComponent]

Return each connected component of g as its own ZX subgraph.

Each component is packaged inside a `ConnectedComponent` that contains the subgraph and the sorted list of output indices of that component, i.e. the positions in the original graph's output list of the outputs that belong to the component.

Source code in src/tsim/core/graph.py
(Source lines 35–65.)
def connected_components(g: BaseGraph) -> list[ConnectedComponent]:
    """Return each connected component of ``g`` as its own ZX subgraph.

    Every vertex of ``g`` ends up in exactly one returned component. For each
    component, ``output_indices`` lists (in ascending order) the positions in
    ``g.outputs()`` of the output vertices contained in that component.

    Args:
        g: The graph to split; it is only read here.

    Returns:
        One :class:`ConnectedComponent` per connected component of ``g``.

    """
    position_of = {vertex: idx for idx, vertex in enumerate(g.outputs())}
    # NOTE(review): assumes _collect_vertices adds the vertices it collects to
    # ``seen`` so that later iterations skip them — confirm in that helper.
    seen: set[Any] = set()
    result: list[ConnectedComponent] = []

    for start in list(g.vertices()):
        if start in seen:
            continue
        members = _collect_vertices(g, start, seen)
        result.append(
            ConnectedComponent(
                graph=_induced_subgraph(g, members),
                output_indices=sorted(
                    position_of[v] for v in members if v in position_of
                ),
            )
        )

    return result

evaluate_graph

evaluate_graph(
    g: GraphS, vals: dict[str, Fraction] | None = None
) -> np.ndarray

Evaluate a ZX graph to a tensor with given parameter values.

Source code in src/tsim/core/graph.py
(Source lines 447–459.)
def evaluate_graph(g: GraphS, vals: dict[str, Fraction] | None = None) -> np.ndarray:
    """Evaluate a ZX graph to a tensor with given parameter values.

    Works on a copy of ``g``. For every vertex, the value of each of its
    parameters is added to its constant phase, and the parameters are cleared.
    The scalar is evaluated with the same values and multiplied into the
    resulting tensor.

    Args:
        g: A parametrized ZX graph (left untouched).
        vals: Mapping from parameter name to value. If ``None``, every
            parameter evaluates to ``Fraction(0, 1)``.

    Returns:
        The tensor of the substituted graph times the evaluated scalar.

    """
    values = defaultdict(lambda: Fraction(0, 1)) if vals is None else vals
    work = g.copy()  # type: ignore

    for vertex in work.vertices():
        total = work.phase(vertex)
        for name in work.get_params(vertex):
            total += values[name]
        work.set_phase(vertex, total, clearParams=True)

    factor = work.scalar.evaluate_scalar(values)
    # The scalar is applied explicitly below, so drop it from the graph
    # before converting to a tensor.
    work.scalar = Scalar()
    return work.to_tensor() * factor

get_params

get_params(g: BaseGraph) -> set[str]

Get all parameter variables that appear in the graph and its scalar.

Collects variables from:

- Vertex phases (g._phaseVars)
- Scalar phase variables (phasevars_pi, phasevars_pi_pair, phasevars_halfpi)
- Scalar phase pairs (phasepairs, via paramsA and paramsB)
- Scalar phase nodes (phasenodevars)

Parameters:

- g (BaseGraph, required): A ZX graph with parametrized phases.

Returns:

- set[str]: The set of all variable names (e.g., {'f0', 'f2', 'm1'}) that appear in the graph or its scalar.

Source code in src/tsim/core/graph.py
(Source lines 462–502.)
def get_params(g: BaseGraph) -> set[str]:
    """Get all parameter variables that appear in the graph and its scalar.

    Collects variables from:
    - Vertex phases (g._phaseVars)
    - Scalar phase variables (phasevars_pi, phasevars_pi_pair, phasevars_halfpi)
    - Scalar phase pairs (phasepairs with paramsA, paramsB)
    - Scalar phase nodes (phasenodevars)

    Args:
        g: A ZX graph with parametrized phases

    Returns:
        A new set of every variable name (e.g. {'f0', 'f2', 'm1'}) found in
        any of the sources above. The graph and its scalar are not modified.

    """
    # Gather every variable set first, then union them in one pass.
    var_sets = [g._phaseVars[v] for v in g.vertices()]

    scalar = g.scalar
    var_sets.append(scalar.phasevars_pi)
    var_sets.extend(vs for pair in scalar.phasevars_pi_pair for vs in pair)
    # phasevars_halfpi maps a coefficient (1 or 3) to a collection of var sets.
    var_sets.extend(
        vs for bucket in scalar.phasevars_halfpi.values() for vs in bucket
    )
    for spider_pair in scalar.phasepairs:
        var_sets += [spider_pair.paramsA, spider_pair.paramsB]
    var_sets.extend(scalar.phasenodevars)

    result: set[str] = set()
    for vs in var_sets:
        result |= vs
    return result

prepare_graph

prepare_graph(
    circuit: Circuit, *, sample_detectors: bool
) -> SamplingGraph

Prepare a circuit for compilation.

This function performs the graph preparation phase:

1. Parse the stim circuit into a ZX graph.
2. Build the sampling graph (compose with adjoint, add outputs).
3. Reduce the graph via zx.full_reduce, then compact its layout with squash_graph.
4. Transform the error basis via Gaussian elimination (e → f).
5. Clear the scalar (safe before stabilizer rank decomposition).

Parameters:

- circuit (Circuit, required): The quantum circuit to prepare.
- sample_detectors (bool, required): If True, prepare for detector sampling. If False, prepare for measurement sampling.

Returns:

- SamplingGraph: A SamplingGraph containing the reduced graph and error transformation.

Source code in src/tsim/core/graph.py
(Source lines 517–557.)
def prepare_graph(circuit: Circuit, *, sample_detectors: bool) -> SamplingGraph:
    """Prepare a circuit for compilation.

    This function performs the graph preparation phase:
    1. Parse the stim circuit into a ZX graph
    2. Build the sampling graph (compose with adjoint, add outputs)
    3. Reduce the graph via zx.full_reduce and compact its layout (squash_graph)
    4. Transform error basis via Gaussian elimination (e → f)
    5. Clear the scalar (safe before stabilizer rank decomposition)

    Args:
        circuit: The quantum circuit to prepare.
        sample_detectors: If True, prepare for detector sampling.
            If False, prepare for measurement sampling.

    Returns:
        A SamplingGraph containing the reduced graph and error transformation.

    """
    representation = parse_stim_circuit(circuit._stim_circ)

    # The sampling graph is the circuit diagram composed with its adjoint.
    sampling = build_sampling_graph(representation, sample_detectors=sample_detectors)

    zx.full_reduce(sampling, paramSafe=True)
    squash_graph(sampling)

    # Re-express e-parameters in a reduced, linearly independent f basis.
    sampling, error_transform = transform_error_basis(
        sampling, num_e=representation.num_error_bits
    )

    # Normalization is computed separately, so any scalar terms kept here would
    # only cancel out later; drop them now.
    sampling.scalar = Scalar()

    return SamplingGraph(
        graph=sampling,
        error_transform=error_transform,
        channel_probs=representation.channel_probs,
        num_outputs=len(sampling.outputs()),
        num_detectors=len(representation.detectors),
    )

scale_horizontally

scale_horizontally(g: BaseGraph, scale: float) -> None

Scale horizontal positions of graph vertices by a factor of scale.

Parameters:

- g (BaseGraph, required): A ZX graph.
- scale (float, required): The factor to scale the graph by.
Source code in src/tsim/core/graph.py
(Source lines 505–514.)
def scale_horizontally(g: BaseGraph, scale: float) -> None:
    """Multiply the row (horizontal) coordinate of every vertex by ``scale``.

    Qubit coordinates are left unchanged; the graph is modified in place.

    Args:
        g: A ZX graph
        scale: The factor to scale the graph by

    """
    new_rows = {vertex: g.row(vertex) * scale for vertex in g.vertices()}
    for vertex, new_row in new_rows.items():
        g.set_row(vertex, new_row)

squash_graph

squash_graph(g: BaseGraph) -> None

Compact the graph by placing vertices underneath their output connections.

Starting from output vertices, each vertex is placed directly underneath (same row, qubit - 1) its already-placed neighbor. Positions are assigned via BFS traversal from outputs, ensuring no (qubit, row) collisions.

Source code in src/tsim/core/graph.py
(Source lines 384–444.)
def squash_graph(g: BaseGraph) -> None:
    """Compact the graph by placing vertices underneath their output connections.

    Starting from output vertices, each vertex is placed directly underneath
    (same row, qubit - 1) its already-placed neighbor. Positions are assigned
    via BFS traversal from outputs, ensuring no (qubit, row) collisions.

    Output vertices keep their qubit coordinate but receive consecutive rows
    ``0 .. len(outputs) - 1``. If the spot underneath a vertex is taken, the
    nearest free row at the same qubit level is used. Rows ``row + k`` and
    ``row - k`` are tried for ``k = 1, 2, ...``, and rows below 0 are never
    used. Finally, a vertex whose only neighbor is an output is moved to the
    far side of that output (qubit + 1, same row).

    Args:
        g: A ZX graph; vertex positions are modified in place.

    """
    outputs = list(g.outputs())
    if not outputs:
        return

    # Outputs get consecutive rows but keep their qubit coordinate. The
    # occupied set must therefore record their *actual* (truncated) positions,
    # or BFS could place a vertex right on top of an output.
    occupied: set[tuple[int, int]] = set()
    for row, v in enumerate(outputs):
        g.set_row(v, row)
        occupied.add((int(g.qubit(v)), row))
    placed: set[Any] = set(outputs)

    # BFS from outputs
    queue: deque[Any] = deque(outputs)

    while queue:
        current = queue.popleft()
        target_qubit = int(g.qubit(current)) - 1
        base_row = int(g.row(current))

        for neighbor in g.neighbors(current):
            if neighbor in placed:
                continue

            # Try to place directly underneath: same row, qubit - 1
            target_row = base_row
            if (target_qubit, target_row) in occupied:
                # Search outward from base_row. ``occupied`` is finite, so the
                # increasing direction always yields a free row eventually:
                # this loop cannot fall through into a collision.
                offset = 1
                while True:
                    if (target_qubit, base_row + offset) not in occupied:
                        target_row = base_row + offset
                        break
                    if (
                        base_row - offset >= 0
                        and (target_qubit, base_row - offset) not in occupied
                    ):
                        target_row = base_row - offset
                        break
                    offset += 1

            g.set_qubit(neighbor, target_qubit)
            g.set_row(neighbor, target_row)
            occupied.add((target_qubit, target_row))
            placed.add(neighbor)
            queue.append(neighbor)

    # A vertex that hangs only off an output is drawn beside it instead.
    for v in g.outputs():
        neighbors = list(g.neighbors(v))
        if neighbors and len(list(g.neighbors(neighbors[0]))) == 1:
            g.set_qubit(neighbors[0], g.qubit(v) + 1)
            g.set_row(neighbors[0], g.row(v))

transform_error_basis

transform_error_basis(
    g: BaseGraph, num_e: int | None = None
) -> tuple[BaseGraph, np.ndarray]

Transform phase variables from the original 'e' basis to a reduced 'f' basis.

This function finds a linearly independent basis for the phase variables across all vertices and transforms them accordingly. The original variables (e0, e1, ...) are mapped to a smaller set (f0, f1, ...) where each f_i corresponds to a linear combination of original e variables.

Parameters:

- g (BaseGraph, required): A ZX graph with phase variables attached to vertices.
- num_e (int | None, default None): Total number of e-variables. If provided, the returned matrix will have exactly this many columns (padded with zeros if needed). If None, the matrix will have only the columns that appear in the graph.

Returns:

- tuple[BaseGraph, ndarray]: A tuple containing:
  - The modified graph (same object, mutated in place).
  - A binary matrix of shape (num_f, num_e) where entry [i, j] = 1 means f_i depends on e_j. For example, if row 0 is [0, 1, 0, 1], then f0 = e1 XOR e3.

Source code in src/tsim/core/graph.py
(Source lines 312–381.)
def transform_error_basis(
    g: BaseGraph, num_e: int | None = None
) -> tuple[BaseGraph, np.ndarray]:
    """Transform phase variables from the original 'e' basis to a reduced 'f' basis.

    Every vertex that carries phase variables contributes one row to a binary
    matrix over the ``e`` variables. That matrix is split by ``find_basis``
    into ``transform @ basis``. Each vertex's variables are then replaced by
    the ``f`` variables selected by its row of ``transform``.

    Vertices that are the sole neighbor of an output are listed first, in
    output order. This makes f0, f1, ... tend to follow output order, which
    helps the direct-component fast path avoid a column reindex at sample time.

    Args:
        g: A ZX graph with phase variables attached to vertices.
        num_e: Total number of e-variables. If provided, the returned matrix
            will have exactly this many columns (padded with zeros if needed).
            If None, the matrix will have only the columns that appear in the graph.

    Returns:
        A tuple containing:
            - The modified graph (same object, mutated in place)
            - A binary matrix of shape (num_f, num_e) where entry [i, j] = 1
              means f_i depends on e_j. For example, if row 0 is [0, 1, 0, 1],
              then f0 = e1 XOR e3.

    Raises:
        AssertionError: If a vertex carries a variable not of the form ``e<int>``.

    """
    phase_vars = g._phaseVars

    # Output-attached parametrized vertices first, in output order.
    prioritized = []
    for out in g.outputs():
        adjacent = list(g.neighbors(out))
        if len(adjacent) == 1 and phase_vars.get(adjacent[0]):
            prioritized.append(adjacent[0])

    prioritized_set = set(prioritized)
    remaining = [
        v for v in g.vertices() if v not in prioritized_set and phase_vars.get(v)
    ]
    ordered = prioritized + remaining

    if not ordered:
        g.scalar = Scalar()
        width = 0 if num_e is None else num_e
        return g, np.zeros((0, width), dtype=np.uint8)

    # Convert each vertex's variable names ("e<k>") to integer column indices.
    rows_of_indices = []
    for v in ordered:
        idxs = []
        for name in phase_vars[v]:
            assert (
                name.startswith("e") and name[1:].isdigit()
            ), f"unexpected phase var {name!r}"
            idxs.append(int(name[1:]))
        rows_of_indices.append(idxs)

    width = max(max(idxs) for idxs in rows_of_indices) + 1
    if num_e is not None and num_e > width:
        width = num_e

    error_matrix = np.zeros((len(rows_of_indices), width), dtype=np.uint8)
    for r, idxs in enumerate(rows_of_indices):
        error_matrix[r, idxs] = 1

    # find_basis returns (basis, transform) with error_matrix = transform @ basis.
    basis, transform = find_basis(error_matrix)

    for v, coeffs in zip(ordered, transform, strict=True):
        phase_vars[v] = {f"f{j}" for j in np.flatnonzero(coeffs)}

    return g, basis