
compile

Compilation of ZX graphs into JAX-compatible data structures.

CompiledScalarGraphs

Bases: Module



JAX-compatible compiled representation of a list of scalar ZX graphs.

The scalar for each graph is a product of four term families, multiplied by a per-graph ScalarPrefactor (global phase, floatfactor, 2^power2, optional approximate complex floatfactor). All arrays are static-shaped so the whole struct can be traced under jax.jit.
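To illustrate why static shapes matter here, the pure-Python sketch below pads ragged per-graph term lists to a common width with the multiplicative identity, which leaves each per-graph product unchanged. The helper names `pad_terms` and `evaluate_scalars` are hypothetical, and padding with 1 is only an assumption about one way to obtain fixed shapes; it is not a description of how tsim lays out its term families.

```python
import math


def pad_terms(per_graph_terms, neutral=1 + 0j):
    """Pad ragged per-graph term lists to one common width.

    The neutral element is the multiplicative identity, so padding
    does not change the product of any row.
    """
    width = max((len(terms) for terms in per_graph_terms), default=0)
    return [list(terms) + [neutral] * (width - len(terms)) for terms in per_graph_terms]


def evaluate_scalars(padded_terms, prefactors):
    """Multiply each row's terms together, then apply the per-graph prefactor."""
    return [pre * math.prod(row) for row, pre in zip(padded_terms, prefactors)]


# Two graphs with a different number of terms each.
padded = pad_terms([[2 + 0j], [1j, 1j, 3 + 0j]])
scalars = evaluate_scalars(padded, prefactors=[0.5, 1.0])
```

With every row the same length, the same computation could be expressed over a rectangular array, which is the kind of layout a traced function can work with.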

compile_scalar_graphs

compile_scalar_graphs(
    g_list: list[BaseGraph], params: list[str]
) -> CompiledScalarGraphs

Compile a list of ZX graphs into a JAX-compatible structure for fast evaluation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| g_list | list[BaseGraph] | List of ZX graphs to compile (must be scalar graphs with no vertices) | required |
| params | list[str] | List of parameter names used by this circuit. Each parameter will correspond to columns in the jax.Arrays of the compiled circuit. | required |
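The link between parameter names and array columns can be sketched as follows. The name-to-index dictionary mirrors the `char_to_idx` mapping built in the source listing; the `parameter_row` helper and the 0/1 row encoding are hypothetical and only show how a fixed column order might be used, not the layout tsim actually emits.

```python
def parameter_row(term_params, params):
    """Return a 0/1 row with one column per entry of `params`.

    Column i is set to 1 when params[i] appears in `term_params`.
    """
    # Same construction as char_to_idx in compile_scalar_graphs.
    char_to_idx = {name: i for i, name in enumerate(params)}
    row = [0] * len(params)
    for name in term_params:
        row[char_to_idx[name]] = 1
    return row


params = ["a", "b", "c"]
row_b = parameter_row(["b"], params)
row_ac = parameter_row(["a", "c"], params)
```

Because the column order follows the order of `params`, every row has length `len(params)` regardless of how many parameters a given term depends on.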

Returns:

| Type | Description |
| --- | --- |
| CompiledScalarGraphs | CompiledScalarGraphs with all data in static-shaped JAX arrays |

Source code in src/tsim/compile/compile.py
def compile_scalar_graphs(
    g_list: list[BaseGraph], params: list[str]
) -> CompiledScalarGraphs:
    """Compile ZX-graph list into JAX-compatible structure for fast evaluation.

    Args:
        g_list: List of ZX-graphs to compile (must be scalar graphs with no vertices)
        params: List of parameter names used by this circuit. Each parameter will correspond to columns in
            the jax.Arrays of the compiled circuit.

    Returns:
        CompiledScalarGraphs with all data in static-shaped JAX arrays

    """
    for i, g in enumerate(g_list):
        n_vertices = len(list(g.vertices()))
        if n_vertices != 0:
            raise ValueError(
                f"Only scalar graphs can be compiled but graph {i} has {n_vertices} vertices"
            )
        if g.scalar.phasevars_pi and not g.scalar.is_zero:
            raise NotImplementedError(
                f"compile_scalar_graphs does not support Scalar.phasevars_pi "
                f"(graph {i} has phasevars_pi={sorted(g.scalar.phasevars_pi)!r})"
            )

    g_list = [g for g in g_list if not g.scalar.is_zero]

    n_params = len(params)
    num_graphs = len(g_list)
    char_to_idx = {char: i for i, char in enumerate(params)}

    return CompiledScalarGraphs(
        num_graphs=num_graphs,
        n_params=n_params,
        node_phases=_compile_node_phases(g_list, char_to_idx, n_params),
        halfpi_phases=_compile_halfpi_phases(g_list, char_to_idx, n_params),
        pi_products=_compile_pi_products(g_list, char_to_idx, n_params),
        phase_pairs=_compile_phase_pairs(g_list, char_to_idx, n_params),
        prefactor=_compile_prefactor(g_list),
    )
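The validation and filtering steps at the top of the listing can be exercised in isolation. The sketch below copies that logic into a hypothetical `validate_and_filter` helper and uses `SimpleNamespace` stand-ins (`make_graph` is also hypothetical) that expose only the attributes the checks read: `vertices()`, `scalar.is_zero`, and `scalar.phasevars_pi`. It is not tsim or PyZX code.

```python
from types import SimpleNamespace


def make_graph(n_vertices=0, is_zero=False, phasevars_pi=()):
    """Hypothetical stand-in for a graph, with just the fields the checks use."""
    scalar = SimpleNamespace(is_zero=is_zero, phasevars_pi=set(phasevars_pi))
    return SimpleNamespace(vertices=lambda: range(n_vertices), scalar=scalar)


def validate_and_filter(g_list):
    """Mirror the checks in the listing, then drop graphs whose scalar is zero."""
    for i, g in enumerate(g_list):
        n_vertices = len(list(g.vertices()))
        if n_vertices != 0:
            raise ValueError(
                f"Only scalar graphs can be compiled but graph {i} has {n_vertices} vertices"
            )
        # phasevars_pi is only rejected when the graph's scalar is non-zero.
        if g.scalar.phasevars_pi and not g.scalar.is_zero:
            raise NotImplementedError(f"graph {i} uses phasevars_pi")
    return [g for g in g_list if not g.scalar.is_zero]


graphs = [make_graph(), make_graph(is_zero=True, phasevars_pi={"x"}), make_graph()]
kept = validate_and_filter(graphs)
```

Two consequences follow from the listing: a zero-scalar graph is dropped even if it carries `phasevars_pi`, and `num_graphs` in the result can therefore be smaller than the length of the input list.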