exact_scalar

Exact scalar arithmetic for ZX-calculus phase computations.

Implements exact arithmetic for complex numbers of the form

(a + b*e^(i*pi/4) + c*i + d*e^(-i*pi/4)) * 2^power

This representation enables exact computation of phases in ZX-calculus graphs without floating-point errors.
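The four basis elements above can be rewritten in powers of ω = e^(iπ/4), the form used by ExactScalarArray. The short derivation below uses only ω⁴ = −1. Which sign convention the stored coefficients follow is not stated on this page, so the mapping to (c_0, ..., c_3) is an assumption.

```latex
\begin{aligned}
\omega &= e^{i\pi/4}, \qquad \omega^{2} = i, \qquad \omega^{4} = -1,\\
e^{-i\pi/4} &= \omega^{-1} = \omega^{7} = \omega^{4}\,\omega^{3} = -\omega^{3},\\
a + b\,e^{i\pi/4} + c\,i + d\,e^{-i\pi/4} &= a + b\,\omega + c\,\omega^{2} - d\,\omega^{3}.
\end{aligned}
```

Under this reading the two forms describe the same set of numbers, with c_0 = a, c_1 = b, c_2 = c and c_3 = −d.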

ExactScalarArray

ExactScalarArray(coeffs: Array, power: Array | None = None)

Bases: Module



Exact scalar array for ZX-calculus phase arithmetic using dyadic representation.

Represents values of the form (c_0 + c_1·ω + c_2·ω² + c_3·ω³) × 2^power where ω = e^(iπ/4). This enables exact computation without floating-point errors.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| coeffs | Array | Array of shape (..., 4) containing dyadic coefficients. |
| power | Array | Array of powers of 2 for scaling. |

The value represented is (c_0 + c_1*omega + c_2*omega^2 + c_3*omega^3) * 2^power where omega = e^{i*pi/4}.
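To make the roles of coeffs and power concrete, here is a minimal pure-Python sketch that evaluates the formula above for a single scalar. The helper name `evaluate` is hypothetical and not part of tsim; it uses floating point only to display the value, whereas the class itself keeps the coefficients and power exact.

```python
import cmath

# omega = e^{i*pi/4}, the primitive 8th root of unity used as the basis.
OMEGA = cmath.exp(1j * cmath.pi / 4)


def evaluate(coeffs, power=0):
    """Hypothetical helper: numeric value of
    (c_0 + c_1*omega + c_2*omega^2 + c_3*omega^3) * 2^power."""
    c0, c1, c2, c3 = coeffs
    value = c0 + c1 * OMEGA + c2 * OMEGA**2 + c3 * OMEGA**3
    return value * 2.0**power


one = evaluate([1, 0, 0, 0])                  # the number 1
imag_unit = evaluate([0, 0, 1, 0])            # omega^2 = i
inv_sqrt2 = evaluate([0, 1, 0, -1], power=-1)  # (omega - omega^3) / 2 = 1/sqrt(2)
```

The last line shows why the power matters: ω − ω³ equals √2, so 1/√2 is stored as the integer coefficients (0, 1, 0, −1) with power −1.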

Source code in src/tsim/core/exact_scalar.py
def __init__(self, coeffs: Array, power: Array | None = None):
    """Initialize from coefficients and optional power.

    The value represented is (c_0 + c_1*omega + c_2*omega^2 + c_3*omega^3) * 2^power
    where omega = e^{i*pi/4}.
    """
    self.coeffs = coeffs
    if power is None:
        self.power = jnp.zeros(coeffs.shape[:-1], dtype=jnp.int32)
    else:
        self.power = power

__mul__

__mul__(other: ExactScalarArray) -> ExactScalarArray

Element-wise multiplication. The coefficient arrays are combined with _scalar_mul and the powers of 2 are added.

Source code in src/tsim/core/exact_scalar.py
def __mul__(self, other: "ExactScalarArray") -> "ExactScalarArray":
    """Element-wise multiplication."""
    new_coeffs = _scalar_mul(self.coeffs, other.coeffs)
    new_power = self.power + other.power
    return ExactScalarArray(new_coeffs, new_power)
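The helper _scalar_mul is not shown on this page. As a rough illustration of what multiplying two coefficient vectors involves, the sketch below multiplies polynomials in ω and folds higher powers back using ω⁴ = −1. It assumes integer coefficients in the basis 1, ω, ω², ω³; the name `mul_coeffs` is hypothetical, and the real helper is presumably vectorized over JAX arrays.

```python
def mul_coeffs(a, b):
    """Hypothetical sketch: multiply two length-4 coefficient lists in Z[omega].

    Each product term omega^i * omega^j = omega^(i+j) is reduced with
    omega^4 = -1, so exponents 4..6 wrap around with a sign flip.
    """
    out = [0, 0, 0, 0]
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            k = i + j
            sign = -1 if k >= 4 else 1
            out[k % 4] += sign * ai * bj
    return out


sqrt2 = [0, 1, 0, -1]             # omega - omega^3 = sqrt(2)
squared = mul_coeffs(sqrt2, sqrt2)  # exactly [2, 0, 0, 0]
```

Squaring the representation of √2 gives exactly 2, whereas the same computation in floating point carries a rounding error.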

prod

prod(axis: int = -1) -> ExactScalarArray

Compute product along the specified axis using associative scan.

Returns identity (1+0i with power 0) for empty reductions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| axis | int | The axis along which to compute the product. | -1 |

Returns:

| Type | Description |
| --- | --- |
| ExactScalarArray | ExactScalarArray with the product computed along the axis. |

Source code in src/tsim/core/exact_scalar.py
def prod(self, axis: int = -1) -> "ExactScalarArray":
    """Compute product along the specified axis using associative scan.

    Returns identity (1+0i with power 0) for empty reductions.

    Args:
        axis: The axis along which to compute the product.

    Returns:
        ExactScalarArray with the product computed along the axis.

    """
    if axis < 0:
        axis += self.power.ndim

    if self.coeffs.shape[axis] == 0:
        # Product of empty sequence is identity: [1, 0, 0, 0] * 2^0
        coeffs_shape = self.coeffs.shape[:axis] + self.coeffs.shape[axis + 1 :]
        result_coeffs = jnp.zeros(coeffs_shape, dtype=self.coeffs.dtype)
        result_coeffs = result_coeffs.at[..., 0].set(1)
        return ExactScalarArray(result_coeffs)

    scanned_power, scanned_coeffs = lax.associative_scan(
        _scalar_mul_with_power, (self.power, self.coeffs), axis=axis
    )
    result_power = jnp.take(scanned_power, indices=-1, axis=axis)
    result_coeffs = jnp.take(scanned_coeffs, indices=-1, axis=axis)

    return ExactScalarArray(result_coeffs, result_power)
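The reduction pattern in prod (run an inclusive scan with an associative operator, then keep the last element) can be sketched in plain Python. Here `itertools.accumulate` is a sequential stand-in for `lax.associative_scan`, and `mul_coeffs`, `mul_scalar` and `prod_scalars` are hypothetical names. The (power, coeffs) tuple order follows the listing above; the multiplication rule itself is an assumption, since _scalar_mul_with_power is not shown.

```python
from itertools import accumulate


def mul_coeffs(a, b):
    """Multiply length-4 coefficient lists in Z[omega], using omega^4 = -1."""
    out = [0, 0, 0, 0]
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            k = i + j
            out[k % 4] += (-1 if k >= 4 else 1) * ai * bj
    return out


def mul_scalar(x, y):
    """Associative operator on (power, coeffs) pairs: powers add, coeffs multiply."""
    (pa, ca), (pb, cb) = x, y
    return (pa + pb, mul_coeffs(ca, cb))


def prod_scalars(items):
    """Product of a list of (power, coeffs) pairs via an inclusive scan."""
    if not items:
        # Empty product is the identity: [1, 0, 0, 0] * 2^0.
        return (0, [1, 0, 0, 0])
    # The last element of an inclusive scan is the full reduction.
    return list(accumulate(items, mul_scalar))[-1]


sqrt2 = (0, [0, 1, 0, -1])
half = (-1, [1, 0, 0, 0])
result = prod_scalars([sqrt2, sqrt2, half])  # sqrt(2) * sqrt(2) * 1/2 = 1
```

The result is stored as (−1, [2, 0, 0, 0]), that is 2 · 2⁻¹, because this sketch does not normalize the representation after multiplying.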

sum

sum(axis: int = -1) -> ExactScalarArray

Sum elements along the specified axis using normalized pairwise adds.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| axis | int | The axis along which to sum. | -1 |

Returns:

| Type | Description |
| --- | --- |
| ExactScalarArray | ExactScalarArray with the sum computed along the axis. |

Source code in src/tsim/core/exact_scalar.py
def sum(self, axis: int = -1) -> "ExactScalarArray":
    """Sum elements along the specified axis using normalized pairwise adds.

    Args:
        axis: The axis along which to sum.

    Returns:
        ExactScalarArray with the sum computed along the axis.

    """
    if axis < 0:
        axis += self.power.ndim

    scanned_power, scanned_coeffs = lax.associative_scan(
        _scalar_add_with_power, (self.power, self.coeffs), axis=axis
    )
    result_power = jnp.take(scanned_power, indices=-1, axis=axis)
    result_coeffs = jnp.take(scanned_coeffs, indices=-1, axis=axis)
    return ExactScalarArray(result_coeffs, result_power)
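Unlike multiplication, addition cannot simply combine powers: two terms must first be brought to a common power of 2. The sketch below shows one plausible reading of "normalized pairwise adds" in plain Python. The helper _scalar_add_with_power is not shown on this page, so the alignment and normalization rules here, and the names `normalize`, `add_scalar` and `sum_scalars`, are assumptions for illustration only.

```python
from functools import reduce


def normalize(power, coeffs):
    """Pull common factors of 2 out of the coefficients into the power."""
    while any(coeffs) and all(c % 2 == 0 for c in coeffs):
        coeffs = [c // 2 for c in coeffs]
        power += 1
    return power, coeffs


def add_scalar(x, y):
    """Add two (power, coeffs) pairs exactly.

    Both terms are rescaled to the smaller power so the coefficients stay
    integers, then added component-wise and normalized.
    """
    (pa, ca), (pb, cb) = x, y
    p = min(pa, pb)
    ca = [c * 2 ** (pa - p) for c in ca]
    cb = [c * 2 ** (pb - p) for c in cb]
    return normalize(p, [u + v for u, v in zip(ca, cb)])


def sum_scalars(items):
    """Sum a non-empty list of (power, coeffs) pairs."""
    return reduce(add_scalar, items)


one = (0, [1, 0, 0, 0])
half = (-1, [1, 0, 0, 0])
three_halves = add_scalar(one, half)    # 1 + 1/2 = 3 * 2^-1
back_to_one = sum_scalars([half, half])  # 1/2 + 1/2 = 1 * 2^0
```

Normalizing after each add keeps the coefficients small, which matters when they are held in fixed-width integers.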

to_complex

to_complex() -> jax.Array

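Convert to a complex-valued array. The coefficients are evaluated numerically and scaled by 2^power, so the result is a floating-point approximation of the exact value.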

Source code in src/tsim/core/exact_scalar.py
def to_complex(self) -> jax.Array:
    """Convert to complex number."""
    c_val = _scalar_to_complex(self.coeffs)
    scale = jnp.pow(2.0, self.power)
    return c_val * scale
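The helper _scalar_to_complex is not shown on this page. Assuming the ω-basis definition given for the class, with ω = e^(iπ/4), any such conversion has to agree with the following real and imaginary parts, where p stands for power:

```latex
\begin{aligned}
\omega &= \frac{1 + i}{\sqrt{2}}, \qquad \omega^{2} = i, \qquad \omega^{3} = \frac{-1 + i}{\sqrt{2}},\\
z &= \left(c_0 + c_1\,\omega + c_2\,\omega^{2} + c_3\,\omega^{3}\right) 2^{p},\\
\operatorname{Re} z &= 2^{p}\left(c_0 + \frac{c_1 - c_3}{\sqrt{2}}\right),\\
\operatorname{Im} z &= 2^{p}\left(c_2 + \frac{c_1 + c_3}{\sqrt{2}}\right).
\end{aligned}
```

For example, coefficients (0, 1, 0, −1) with p = 0 give Re z = 2/√2 = √2 and Im z = 0.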