diff --git a/source/qdk_package/qdk/_interpreter.py b/source/qdk_package/qdk/_interpreter.py index debdd1dc75..8324259256 100644 --- a/source/qdk_package/qdk/_interpreter.py +++ b/source/qdk_package/qdk/_interpreter.py @@ -159,6 +159,15 @@ def _get_default_context() -> Context: return _default_context +def _get_context_or_default(obj: Any) -> Context: + """Returns context associated with given object, if available. + Otherwise falls back to the default context. + """ + if hasattr(obj, "_qdk_context"): + return getattr(obj, "_qdk_context") + return _get_default_context() + + # --------------------------------------------------------------------------- # Functions accessing global context, for compatibility. # --------------------------------------------------------------------------- diff --git a/source/qdk_package/qdk/test_utils.py b/source/qdk_package/qdk/test_utils.py new file mode 100644 index 0000000000..02d83f86c7 --- /dev/null +++ b/source/qdk_package/qdk/test_utils.py @@ -0,0 +1,70 @@ +"""Helper functions for testing Q# code.""" + +from typing import Any + +from qdk._interpreter import _get_context_or_default + +from ._context import Context + + +def dump_operation_on_state( + op: Any, + num_qubits: int, + initial_state: list[float] | None = None, + context: Context | None = None, +) -> list[complex]: + """Returns statevector after applying operation to the given state. + + Uses big-endian convention for basis-state numbering. + + Args: + op: Q# callable from ``Context.code`` or a string that evaluates to + a Q# callable. The callable must have signature + ``(Qubit[] => Unit)``. + num_qubits: Number of qubits the operation acts on. + initial_state: Initial state given by list of `2**num_qubits` real amplitudes. + If the list is shorter, it will be padded with zeros. + If not provided, the initial state is zero state (|00..0>). + context: `qdk.Context` from which the operation was created (optional). If + not provided, will attempt to infer it from `op` and then fall back to + default context. + + Returns: + The state vector as a list of `2**num_qubits` complex numbers. + """ + context = context or _get_context_or_default(op) + if initial_state is None: + initial_state = [1.0] # |00..0> state. + if type(op) is str: + op = context.eval(op) + + if not hasattr(context.code, "_DumpOperationOnState"): + context.eval(""" + operation _DumpOperationOnState( + op : (Qubit[] => Unit), + num_qubits : Int, + initial_state : Double[] + ) : Unit { + use qubits = Qubit[num_qubits]; + if (Length(initial_state) > 1) { + Std.StatePreparation.PreparePureStateD(initial_state, qubits); + } + op(qubits); + Std.Diagnostics.DumpRegister(qubits); + ResetAll(qubits); + } + """) + + result = context.run( + context.code._DumpOperationOnState, + 1, # shots + op, + num_qubits, + initial_state, + save_events=True, + )[0] + state = result["events"][-1].state_dump().get_dict() + statevector = [0.0] * (2**num_qubits) + for index, amplitude in state.items(): + statevector[index] = amplitude + return statevector diff --git a/source/qdk_package/tests/test_test_utils.py b/source/qdk_package/tests/test_test_utils.py new file mode 100644 index 0000000000..957692e660 --- /dev/null +++ b/source/qdk_package/tests/test_test_utils.py @@ -0,0 +1,116 @@ +import math + +from qdk import code, qsharp, Context +from qdk.test_utils import dump_operation_on_state + + +def _assert_states_close(state1: list[complex], state2: list[complex]): + assert len(state1) == len(state2) + for i in range(len(state1)): + assert abs(state1[i] - state2[i]) < 1e-9 + + +def test_dump_operation_on_state(): + qsharp.eval(""" + operation MyOp1(q: Qubit[]) : Unit { + H(q[0]); + CNOT(q[0], q[1]); + Z(q[1]); + } + """) + + vector = dump_operation_on_state(code.MyOp1, num_qubits=2) + s = 0.5**0.5 + _assert_states_close(vector, [s, 0, 0, -s]) + + vector = dump_operation_on_state("MyOp1", num_qubits=2) + _assert_states_close(vector, [s, 0, 0, -s]) + + +def test_dump_operation_on_state_with_two_registers(): + qsharp.eval(""" + operation MyOp2(q1: Qubit[], q2: Qubit[]) : Unit { + H(q1[0]); + CNOT(q1[0], q2[0]); + } + + operation MyOp2_TestHelper(q: Qubit[]) : Unit { + let n = Length(q); + MyOp2(q[0..n/2-1], q[n/2..n-1]); + } + """) + + vector = dump_operation_on_state(code.MyOp2_TestHelper, num_qubits=4) + s = 0.5**0.5 + _assert_states_close(vector, [s, 0, 0, 0, 0, 0, 0, 0, 0, 0, s, 0, 0, 0, 0, 0]) + + +def test_dump_operation_on_state_with_partial_trace(): + qsharp.eval(""" + operation MyOp3(q1: Qubit[], q2: Qubit[]) : Unit { + H(q1[0]); + H(q2[0]); + } + + operation MyOp3_TestHelper(q: Qubit[]) : Unit { + use q2 = Qubit[2]; + MyOp3(q, q2); + ResetAll(q2); + } + """) + + vector = dump_operation_on_state(code.MyOp3_TestHelper, num_qubits=2) + s = 0.5**0.5 + _assert_states_close(vector, [s, 0, s, 0]) + + +def test_dump_operation_on_state_with_initial_state(): + qsharp.eval(""" + operation MyOp4(q: Qubit[]) : Unit is Adj { + CNOT(q[0], q[1]); + H(q[0]); + } + """) + + s = 0.5**0.5 + vector = dump_operation_on_state( + code.MyOp4, num_qubits=2, initial_state=[s, 0, 0, s] + ) + _assert_states_close(vector, [1, 0, 0, 0]) + + +def test_dump_operation_on_state_with_parameters(): + ctx = Context() + ctx.eval(""" + operation MyOp5(qs: Qubit[], angle: Double) : Unit is Adj { + for q in qs { + Rx(angle, q); + } + } + + operation MyOp5_TestHelper(angle: Double) : (Qubit[] => Unit) { + MyOp5(_, angle) + } + """) + + vector = dump_operation_on_state( + ctx.code.MyOp5_TestHelper(0.3), num_qubits=2, context=ctx + ) + c = math.cos(0.3 / 2) + s = math.sin(0.3 / 2) + _assert_states_close(vector, [c * c, -1j * c * s, -1j * c * s, -(s * s)]) + + +def test_dump_operation_on_state_with_parameterized_callable(): + qsharp.eval(""" + operation MyOp5(qs: Qubit[], angle: Double) : Unit is Adj { + for q in qs { + Rx(angle, q); + } + } + """) + + vector = dump_operation_on_state("MyOp5(_, 0.3)", num_qubits=2) + c = math.cos(0.3 / 2) + s = math.sin(0.3 / 2) + _assert_states_close(vector, [c * c, -1j * c * s, -1j * c * s, -(s * s)])