Skip to content

exact_scalar

Exact scalar arithmetic for ZX-calculus phase computations.

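Implements exact arithmetic for complex numbers of the form

$$(a + b\,e^{i\pi/4} + c\,i + d\,e^{-i\pi/4}) \cdot 2^{\mathrm{power}}$$

Since $e^{-i\pi/4} = -e^{3i\pi/4} = -\omega^3$ for $\omega = e^{i\pi/4}$, this is the same set of numbers as the $(c_0 + c_1\omega + c_2\omega^2 + c_3\omega^3) \cdot 2^{\mathrm{power}}$ form used in the class documentation. The two forms differ only in the sign convention for the last coefficient.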

This representation enables exact computation of phases in ZX-calculus graphs without floating-point errors.

ExactScalarArray

ExactScalarArray(coeffs: Array, power: Array | None = None)

Bases: Module


              flowchart TD
              tsim.core.exact_scalar.ExactScalarArray[ExactScalarArray]

              

              click tsim.core.exact_scalar.ExactScalarArray href "" "tsim.core.exact_scalar.ExactScalarArray"
            

Exact scalar array for ZX-calculus phase arithmetic using dyadic representation.

Represents values of the form (c_0 + c_1·ω + c_2·ω² + c_3·ω³) × 2^power where ω = e^(iπ/4). This enables exact computation without floating-point errors.

Attributes:

| Name | Type | Description |
|---|---|---|
| `coeffs` | `Array` | Array of shape `(..., 4)` containing dyadic coefficients. |
| `power` | `Array` | Array of powers of 2 for scaling. |

The value represented is $(c_0 + c_1\omega + c_2\omega^2 + c_3\omega^3) \cdot 2^{\mathrm{power}}$ where $\omega = e^{i\pi/4}$.

Source code in `src/tsim/core/exact_scalar.py` (lines 62–72):
def __init__(self, coeffs: Array, power: Array | None = None):
    """Initialize from coefficients and optional power.

    The value represented is (c_0 + c_1*omega + c_2*omega^2 + c_3*omega^3) * 2^power
    where omega = e^{i*pi/4}.
    """
    self.coeffs = coeffs
    if power is None:
        self.power = jnp.zeros(coeffs.shape[:-1], dtype=jnp.int32)
    else:
        self.power = power

__mul__

__mul__(other: ExactScalarArray) -> ExactScalarArray

Element-wise multiplication.

Source code in `src/tsim/core/exact_scalar.py` (lines 74–78):
def __mul__(self, other: "ExactScalarArray") -> "ExactScalarArray":
    """Multiply two exact scalar arrays element by element.

    Coefficient vectors are combined with ``_scalar_mul``, and the
    power-of-two exponents add because 2^p * 2^q == 2^(p + q).
    """
    return ExactScalarArray(
        _scalar_mul(self.coeffs, other.coeffs),
        self.power + other.power,
    )

prod

prod(axis: int = -1) -> ExactScalarArray

Compute product along the specified axis using associative scan.

Returns identity (1+0i with power 0) for empty reductions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `axis` | `int` | The axis along which to compute the product. | `-1` |

Returns:

| Type | Description |
|---|---|
| `ExactScalarArray` | ExactScalarArray with the product computed along the axis. |

Source code in `src/tsim/core/exact_scalar.py` (lines 120–150):
def prod(self, axis: int = -1) -> "ExactScalarArray":
    """Compute the product along ``axis`` using an associative scan.

    ``axis`` indexes ``self.coeffs``, whose trailing axis (length 4) holds
    the four coefficient components of each scalar. That component axis is
    not a valid reduction axis. When ``axis`` resolves to it, as the default
    ``-1`` does, the product is taken along the last element axis instead.
    This is the same axis as ``axis=-2``. Every other value keeps its
    previous meaning.

    Returns the identity (coefficients ``[1, 0, 0, 0]``, power 0) for an
    empty reduction.

    Args:
        axis: The axis along which to compute the product.

    Returns:
        ExactScalarArray with the product computed along the axis.

    Raises:
        ValueError: If there is no element axis to reduce over, or ``axis``
            is out of range.

    """
    ndim = self.coeffs.ndim
    component_axis = ndim - 1
    if axis < 0:
        axis = ndim + axis
    if axis == component_axis:
        # _scalar_mul multiplies complete (..., 4) coefficient vectors (as in
        # __mul__), so scanning across the components themselves is
        # meaningless. Reduce over the last element axis instead.
        axis = component_axis - 1
    if not 0 <= axis < component_axis:
        raise ValueError(
            f"axis out of range for ExactScalarArray with coeffs shape "
            f"{self.coeffs.shape}"
        )

    if self.coeffs.shape[axis] == 0:
        # Product of an empty sequence is the identity: [1, 0, 0, 0] * 2^0.
        coeffs_shape = self.coeffs.shape[:axis] + self.coeffs.shape[axis + 1 :]
        result_coeffs = jnp.zeros(coeffs_shape, dtype=self.coeffs.dtype)
        result_coeffs = result_coeffs.at[..., 0].set(1)

        power_shape = self.power.shape[:axis] + self.power.shape[axis + 1 :]
        result_power = jnp.zeros(power_shape, dtype=self.power.dtype)

        return ExactScalarArray(result_coeffs, result_power)

    # The last entry of the inclusive scan is the full product along `axis`.
    scanned = lax.associative_scan(_scalar_mul, self.coeffs, axis=axis)
    result_coeffs = jnp.take(scanned, indices=-1, axis=axis)
    # Powers of two multiply by adding their exponents.
    result_power = jnp.sum(self.power, axis=axis)

    return ExactScalarArray(result_coeffs, result_power)

reduce

reduce() -> ExactScalarArray

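Normalize each element by halving its coefficients and incrementing its power while all four coefficients are even. Elements whose coefficients are all zero are left unchanged.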

Source code in `src/tsim/core/exact_scalar.py` (lines 80–104):
def reduce(self) -> "ExactScalarArray":
    """Normalize elements by pulling common factors of 2 into the power.

    An element is halved (coefficients ``// 2``, power ``+ 1``) while all
    four of its coefficients are even. All-zero elements are never halved,
    because zero is divisible by 2 indefinitely and would loop forever. The
    represented values are unchanged.
    """

    def _halvable(coeffs):
        # Per element: every component even AND at least one component nonzero.
        all_even = jnp.all(coeffs % 2 == 0, axis=-1)
        not_zero = jnp.any(coeffs != 0, axis=-1)
        return all_even & not_zero

    def _keep_going(state):
        return jnp.any(_halvable(state[0]))

    def _halve_once(state):
        coeffs, power = state
        mask = _halvable(coeffs)
        halved = jnp.where(mask[..., None], coeffs // 2, coeffs)
        bumped = jnp.where(mask, power + 1, power)
        return halved, bumped

    final_coeffs, final_power = jax.lax.while_loop(
        _keep_going, _halve_once, (self.coeffs, self.power)
    )
    return ExactScalarArray(final_coeffs, final_power)

sum

sum() -> ExactScalarArray

Sum elements along the last element axis (axis -2 of `coeffs`, i.e. axis -1 of `power`).

Aligns powers to the minimum power before summing.

Source code in `src/tsim/core/exact_scalar.py` (lines 106–118):
def sum(self) -> "ExactScalarArray":
    """Sum elements along the last element axis.

    That is axis -2 of ``coeffs`` and axis -1 of ``power``. Every element
    is rescaled to the smallest power along that axis, so all terms share a
    common exponent, and the coefficients are then added component-wise. An
    empty axis yields exact zero (all-zero coefficients, power 0).

    Returns:
        ExactScalarArray with the summed axis removed.
    """
    # TODO: check for overflow and potentially refactor sum routine to scan
    # the array and reduce scalars every couple steps

    if self.coeffs.shape[-2] == 0:
        # jnp.min has no identity for an empty reduction, so handle the empty
        # sum explicitly as 0 * 2^0 (mirrors prod's empty-identity handling).
        zero_coeffs = jnp.zeros(
            self.coeffs.shape[:-2] + self.coeffs.shape[-1:],
            dtype=self.coeffs.dtype,
        )
        zero_power = jnp.zeros(self.power.shape[:-1], dtype=self.power.dtype)
        return ExactScalarArray(zero_coeffs, zero_power)

    min_power = jnp.min(self.power, keepdims=True, axis=-1)
    # Non-negative per-element shift; the trailing None broadcasts it over
    # the 4 components.
    shift = (self.power - min_power)[..., None]
    aligned_coeffs = self.coeffs * 2**shift
    summed_coeffs = jnp.sum(aligned_coeffs, axis=-2)
    return ExactScalarArray(summed_coeffs, min_power.squeeze(-1))

to_complex

to_complex() -> jax.Array

Convert to complex number.

Source code in `src/tsim/core/exact_scalar.py` (lines 152–156):
def to_complex(self) -> jax.Array:
    """Evaluate the exact scalars as floating-point complex values.

    Returns:
        Element-wise ``_scalar_to_complex(coeffs) * 2.0**power``. This
        result is subject to ordinary floating-point rounding.
    """
    base_value = _scalar_to_complex(self.coeffs)
    two_to_power = jnp.pow(2.0, self.power)
    return base_value * two_to_power