cirq-core 1.1.0.dev20221220224914__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cirq/__init__.py +8 -0
- cirq/_compat.py +29 -4
- cirq/_compat_test.py +24 -26
- cirq/_version.py +32 -1
- cirq/_version_test.py +1 -1
- cirq/circuits/_block_diagram_drawer_test.py +4 -3
- cirq/circuits/circuit.py +109 -63
- cirq/circuits/circuit_operation.py +2 -3
- cirq/circuits/circuit_operation_test.py +4 -4
- cirq/circuits/circuit_test.py +11 -0
- cirq/circuits/frozen_circuit.py +13 -1
- cirq/circuits/frozen_circuit_test.py +5 -1
- cirq/circuits/moment.py +39 -14
- cirq/circuits/moment_test.py +7 -0
- cirq/circuits/text_diagram_drawer.py +1 -1
- cirq/circuits/text_diagram_drawer_test.py +3 -7
- cirq/contrib/acquaintance/bipartite.py +1 -1
- cirq/contrib/acquaintance/devices.py +2 -2
- cirq/contrib/acquaintance/executor.py +5 -2
- cirq/contrib/acquaintance/gates.py +3 -2
- cirq/contrib/acquaintance/permutation.py +13 -2
- cirq/contrib/acquaintance/testing.py +3 -5
- cirq/contrib/paulistring/recombine.py +3 -6
- cirq/contrib/qasm_import/_parser.py +17 -21
- cirq/contrib/qasm_import/_parser_test.py +30 -45
- cirq/contrib/qcircuit/qcircuit_test.py +3 -7
- cirq/contrib/quantum_volume/quantum_volume.py +3 -3
- cirq/contrib/quimb/mps_simulator.py +1 -1
- cirq/contrib/quimb/state_vector.py +2 -0
- cirq/contrib/quirk/quirk_gate.py +1 -0
- cirq/contrib/svg/svg.py +4 -7
- cirq/contrib/svg/svg_test.py +29 -1
- cirq/devices/grid_qubit.py +26 -28
- cirq/devices/grid_qubit_test.py +21 -5
- cirq/devices/line_qubit.py +10 -12
- cirq/devices/line_qubit_test.py +9 -2
- cirq/devices/named_topologies.py +1 -1
- cirq/devices/noise_model.py +4 -1
- cirq/devices/superconducting_qubits_noise_properties.py +1 -3
- cirq/experiments/n_qubit_tomography.py +1 -1
- cirq/experiments/qubit_characterizations.py +2 -2
- cirq/experiments/single_qubit_readout_calibration.py +1 -1
- cirq/experiments/t2_decay_experiment.py +1 -1
- cirq/experiments/xeb_simulation_test.py +2 -2
- cirq/interop/quirk/cells/testing.py +1 -1
- cirq/json_resolver_cache.py +1 -0
- cirq/linalg/__init__.py +2 -0
- cirq/linalg/decompositions_test.py +4 -4
- cirq/linalg/diagonalize_test.py +5 -6
- cirq/linalg/transformations.py +72 -9
- cirq/linalg/transformations_test.py +23 -7
- cirq/ops/__init__.py +4 -0
- cirq/ops/arithmetic_operation.py +4 -6
- cirq/ops/classically_controlled_operation.py +10 -3
- cirq/ops/clifford_gate.py +1 -7
- cirq/ops/common_channels.py +21 -15
- cirq/ops/common_gate_families.py +2 -3
- cirq/ops/common_gates.py +48 -11
- cirq/ops/common_gates_test.py +4 -0
- cirq/ops/controlled_gate.py +44 -18
- cirq/ops/controlled_operation.py +13 -5
- cirq/ops/dense_pauli_string.py +14 -19
- cirq/ops/diagonal_gate.py +3 -4
- cirq/ops/eigen_gate.py +8 -10
- cirq/ops/eigen_gate_test.py +6 -0
- cirq/ops/gate_operation.py +11 -6
- cirq/ops/gate_operation_test.py +11 -2
- cirq/ops/gateset.py +2 -1
- cirq/ops/gateset_test.py +38 -5
- cirq/ops/global_phase_op.py +28 -2
- cirq/ops/global_phase_op_test.py +21 -0
- cirq/ops/identity.py +1 -1
- cirq/ops/kraus_channel_test.py +2 -2
- cirq/ops/linear_combinations.py +7 -6
- cirq/ops/linear_combinations_test.py +26 -10
- cirq/ops/matrix_gates.py +8 -4
- cirq/ops/matrix_gates_test.py +25 -3
- cirq/ops/measure_util.py +13 -5
- cirq/ops/measure_util_test.py +8 -2
- cirq/ops/measurement_gate.py +1 -1
- cirq/ops/measurement_gate_test.py +9 -4
- cirq/ops/mixed_unitary_channel_test.py +4 -4
- cirq/ops/named_qubit.py +2 -4
- cirq/ops/parity_gates.py +5 -1
- cirq/ops/parity_gates_test.py +6 -0
- cirq/ops/pauli_gates.py +9 -9
- cirq/ops/pauli_string.py +4 -2
- cirq/ops/pauli_string_raw_types.py +4 -11
- cirq/ops/pauli_string_test.py +13 -13
- cirq/ops/pauli_sum_exponential.py +6 -1
- cirq/ops/qubit_manager.py +97 -0
- cirq/ops/qubit_manager_test.py +66 -0
- cirq/ops/raw_types.py +75 -33
- cirq/ops/raw_types_test.py +34 -0
- cirq/ops/three_qubit_gates.py +16 -10
- cirq/ops/three_qubit_gates_test.py +4 -2
- cirq/ops/two_qubit_diagonal_gate.py +3 -3
- cirq/ops/wait_gate.py +1 -1
- cirq/protocols/__init__.py +1 -0
- cirq/protocols/act_on_protocol.py +3 -3
- cirq/protocols/act_on_protocol_test.py +5 -5
- cirq/protocols/apply_channel_protocol.py +9 -8
- cirq/protocols/apply_mixture_protocol.py +8 -8
- cirq/protocols/apply_mixture_protocol_test.py +1 -1
- cirq/protocols/apply_unitary_protocol.py +66 -19
- cirq/protocols/apply_unitary_protocol_test.py +50 -0
- cirq/protocols/circuit_diagram_info_protocol.py +7 -9
- cirq/protocols/decompose_protocol.py +167 -125
- cirq/protocols/decompose_protocol_test.py +132 -2
- cirq/protocols/has_stabilizer_effect_protocol.py +2 -1
- cirq/protocols/inverse_protocol.py +2 -2
- cirq/protocols/json_serialization_test.py +3 -3
- cirq/protocols/json_test_data/Linspace.json +20 -7
- cirq/protocols/json_test_data/Linspace.repr +4 -1
- cirq/protocols/json_test_data/Points.json +19 -8
- cirq/protocols/json_test_data/Points.repr +4 -1
- cirq/protocols/json_test_data/Result.repr_inward +1 -1
- cirq/protocols/json_test_data/ResultDict.repr +1 -1
- cirq/protocols/json_test_data/ResultDict.repr_inward +1 -1
- cirq/protocols/json_test_data/TrialResult.repr_inward +1 -1
- cirq/protocols/json_test_data/XPowGate.json +13 -5
- cirq/protocols/json_test_data/XPowGate.repr +1 -1
- cirq/protocols/json_test_data/ZPowGate.json +13 -5
- cirq/protocols/json_test_data/ZPowGate.repr +1 -1
- cirq/protocols/json_test_data/ZipLongest.json +19 -0
- cirq/protocols/json_test_data/ZipLongest.repr +1 -0
- cirq/protocols/json_test_data/spec.py +1 -0
- cirq/protocols/kraus_protocol.py +3 -4
- cirq/protocols/measurement_key_protocol.py +3 -1
- cirq/protocols/mixture_protocol.py +3 -2
- cirq/protocols/phase_protocol.py +3 -3
- cirq/protocols/pow_protocol.py +1 -2
- cirq/protocols/qasm.py +4 -4
- cirq/protocols/qid_shape_protocol.py +8 -8
- cirq/protocols/resolve_parameters.py +8 -3
- cirq/protocols/resolve_parameters_test.py +3 -3
- cirq/protocols/unitary_protocol.py +19 -11
- cirq/protocols/unitary_protocol_test.py +37 -0
- cirq/qis/channels.py +1 -1
- cirq/qis/clifford_tableau.py +4 -5
- cirq/qis/quantum_state_representation.py +7 -9
- cirq/qis/states.py +21 -13
- cirq/qis/states_test.py +7 -0
- cirq/sim/clifford/clifford_simulator.py +3 -3
- cirq/sim/density_matrix_simulation_state.py +2 -1
- cirq/sim/density_matrix_simulator.py +1 -1
- cirq/sim/density_matrix_simulator_test.py +9 -5
- cirq/sim/density_matrix_utils.py +7 -32
- cirq/sim/mux.py +2 -2
- cirq/sim/simulation_state.py +58 -18
- cirq/sim/simulation_state_base.py +5 -2
- cirq/sim/simulation_state_test.py +121 -9
- cirq/sim/simulation_utils.py +59 -0
- cirq/sim/simulation_utils_test.py +32 -0
- cirq/sim/simulator.py +2 -1
- cirq/sim/simulator_base_test.py +3 -3
- cirq/sim/sparse_simulator.py +1 -1
- cirq/sim/sparse_simulator_test.py +5 -5
- cirq/sim/state_vector.py +7 -36
- cirq/sim/state_vector_simulation_state.py +18 -1
- cirq/sim/state_vector_simulator.py +3 -2
- cirq/sim/state_vector_simulator_test.py +24 -2
- cirq/sim/state_vector_test.py +46 -15
- cirq/study/__init__.py +1 -0
- cirq/study/flatten_expressions.py +2 -2
- cirq/study/resolver.py +2 -0
- cirq/study/resolver_test.py +1 -1
- cirq/study/result.py +1 -1
- cirq/study/sweeps.py +103 -9
- cirq/study/sweeps_test.py +64 -0
- cirq/testing/__init__.py +4 -0
- cirq/testing/circuit_compare.py +15 -18
- cirq/testing/consistent_act_on.py +4 -4
- cirq/testing/consistent_controlled_gate_op_test.py +1 -1
- cirq/testing/consistent_decomposition.py +11 -2
- cirq/testing/consistent_decomposition_test.py +8 -1
- cirq/testing/consistent_protocols.py +2 -0
- cirq/testing/consistent_protocols_test.py +8 -4
- cirq/testing/consistent_qasm.py +8 -15
- cirq/testing/consistent_specified_has_unitary.py +1 -1
- cirq/testing/consistent_unitary.py +85 -0
- cirq/testing/consistent_unitary_test.py +96 -0
- cirq/testing/equivalent_repr_eval.py +10 -10
- cirq/testing/json.py +3 -3
- cirq/testing/logs.py +1 -1
- cirq/testing/order_tester.py +4 -5
- cirq/testing/random_circuit.py +3 -5
- cirq/testing/sample_gates.py +79 -0
- cirq/testing/sample_gates_test.py +59 -0
- cirq/transformers/__init__.py +2 -0
- cirq/transformers/analytical_decompositions/__init__.py +8 -0
- cirq/transformers/analytical_decompositions/pauli_string_decomposition.py +130 -0
- cirq/transformers/analytical_decompositions/pauli_string_decomposition_test.py +58 -0
- cirq/transformers/analytical_decompositions/quantum_shannon_decomposition.py +230 -0
- cirq/transformers/analytical_decompositions/quantum_shannon_decomposition_test.py +112 -0
- cirq/transformers/analytical_decompositions/three_qubit_decomposition_test.py +1 -3
- cirq/transformers/analytical_decompositions/two_qubit_to_fsim.py +1 -1
- cirq/transformers/expand_composite.py +1 -1
- cirq/transformers/heuristic_decompositions/gate_tabulation_math_utils.py +4 -4
- cirq/transformers/measurement_transformers.py +4 -4
- cirq/transformers/merge_single_qubit_gates.py +17 -4
- cirq/transformers/routing/route_circuit_cqc.py +2 -2
- cirq/transformers/stratify.py +125 -62
- cirq/transformers/stratify_test.py +20 -16
- cirq/transformers/transformer_api.py +1 -1
- cirq/transformers/transformer_primitives.py +3 -2
- cirq/transformers/transformer_primitives_test.py +11 -0
- cirq/value/abc_alt.py +3 -2
- cirq/value/abc_alt_test.py +1 -0
- cirq/value/classical_data.py +10 -10
- cirq/value/digits.py +2 -2
- cirq/value/linear_dict.py +18 -19
- cirq/value/product_state.py +7 -6
- cirq/value/value_equality_attr.py +2 -2
- cirq/vis/heatmap.py +1 -1
- cirq/vis/heatmap_test.py +2 -2
- cirq/work/collector.py +2 -2
- cirq/work/observable_measurement_data.py +5 -5
- cirq/work/observable_readout_calibration.py +3 -1
- cirq/work/observable_settings.py +1 -1
- cirq/work/pauli_sum_collector.py +9 -8
- cirq/work/sampler.py +2 -0
- cirq/work/zeros_sampler.py +2 -2
- {cirq_core-1.1.0.dev20221220224914.dist-info → cirq_core-1.2.0.dist-info}/METADATA +7 -15
- {cirq_core-1.1.0.dev20221220224914.dist-info → cirq_core-1.2.0.dist-info}/RECORD +228 -214
- {cirq_core-1.1.0.dev20221220224914.dist-info → cirq_core-1.2.0.dist-info}/WHEEL +1 -1
- {cirq_core-1.1.0.dev20221220224914.dist-info → cirq_core-1.2.0.dist-info}/LICENSE +0 -0
- {cirq_core-1.1.0.dev20221220224914.dist-info → cirq_core-1.2.0.dist-info}/top_level.txt +0 -0
cirq/sim/simulator_base_test.py
CHANGED
|
@@ -34,7 +34,7 @@ class CountingState(cirq.qis.QuantumStateRepresentation):
|
|
|
34
34
|
self.measurement_count += 1
|
|
35
35
|
return [self.gate_count]
|
|
36
36
|
|
|
37
|
-
def kron(self
|
|
37
|
+
def kron(self, other: 'CountingState') -> 'CountingState':
|
|
38
38
|
return CountingState(
|
|
39
39
|
self.data,
|
|
40
40
|
self.gate_count + other.gate_count,
|
|
@@ -43,13 +43,13 @@ class CountingState(cirq.qis.QuantumStateRepresentation):
|
|
|
43
43
|
)
|
|
44
44
|
|
|
45
45
|
def factor(
|
|
46
|
-
self
|
|
46
|
+
self, axes: Sequence[int], *, validate=True, atol=1e-07
|
|
47
47
|
) -> Tuple['CountingState', 'CountingState']:
|
|
48
48
|
return CountingState(
|
|
49
49
|
self.data, self.gate_count, self.measurement_count, self.copy_count
|
|
50
50
|
), CountingState(self.data)
|
|
51
51
|
|
|
52
|
-
def reindex(self
|
|
52
|
+
def reindex(self, axes: Sequence[int]) -> 'CountingState':
|
|
53
53
|
return CountingState(self.data, self.gate_count, self.measurement_count, self.copy_count)
|
|
54
54
|
|
|
55
55
|
def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
|
cirq/sim/sparse_simulator.py
CHANGED
|
@@ -282,5 +282,5 @@ class SparseSimulatorStep(
|
|
|
282
282
|
# Dtype doesn't have a good repr, so we work around by invoking __name__.
|
|
283
283
|
return (
|
|
284
284
|
f'cirq.SparseSimulatorStep(sim_state={self._sim_state!r},'
|
|
285
|
-
f' dtype=np.{self._dtype
|
|
285
|
+
f' dtype=np.{np.dtype(self._dtype)!r})'
|
|
286
286
|
)
|
cirq/sim/sparse_simulator_test.py
CHANGED
|
@@ -255,7 +255,7 @@ def test_run_mixture(dtype: Type[np.complexfloating], split: bool):
|
|
|
255
255
|
simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
|
|
256
256
|
circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0))
|
|
257
257
|
result = simulator.run(circuit, repetitions=100)
|
|
258
|
-
assert 20 < sum(result.measurements['q(0)'])[0] < 80
|
|
258
|
+
assert 20 < sum(result.measurements['q(0)'])[0] < 80
|
|
259
259
|
|
|
260
260
|
|
|
261
261
|
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
|
|
@@ -265,8 +265,8 @@ def test_run_mixture_with_gates(dtype: Type[np.complexfloating], split: bool):
|
|
|
265
265
|
simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split, seed=23)
|
|
266
266
|
circuit = cirq.Circuit(cirq.H(q0), cirq.phase_flip(0.5)(q0), cirq.H(q0), cirq.measure(q0))
|
|
267
267
|
result = simulator.run(circuit, repetitions=100)
|
|
268
|
-
assert sum(result.measurements['q(0)'])[0] < 80
|
|
269
|
-
assert sum(result.measurements['q(0)'])[0] > 20
|
|
268
|
+
assert sum(result.measurements['q(0)'])[0] < 80
|
|
269
|
+
assert sum(result.measurements['q(0)'])[0] > 20
|
|
270
270
|
|
|
271
271
|
|
|
272
272
|
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
|
|
@@ -775,9 +775,9 @@ def test_sparse_simulator_repr():
|
|
|
775
775
|
# No equality so cannot use cirq.testing.assert_equivalent_repr
|
|
776
776
|
assert (
|
|
777
777
|
repr(step) == "cirq.SparseSimulatorStep(sim_state=cirq.StateVectorSimulationState("
|
|
778
|
-
"initial_state=np.array([[0j, (1+0j)], [0j, 0j]], dtype=np.complex64), "
|
|
778
|
+
"initial_state=np.array([[0j, (1+0j)], [0j, 0j]], dtype=np.dtype('complex64')), "
|
|
779
779
|
"qubits=(cirq.LineQubit(0), cirq.LineQubit(1)), "
|
|
780
|
-
"classical_data=cirq.ClassicalDataDictionaryStore()), dtype=np.complex64)"
|
|
780
|
+
"classical_data=cirq.ClassicalDataDictionaryStore()), dtype=np.dtype('complex64'))"
|
|
781
781
|
)
|
|
782
782
|
|
|
783
783
|
|
cirq/sim/state_vector.py
CHANGED
|
@@ -19,7 +19,7 @@ from typing import List, Mapping, Optional, Tuple, TYPE_CHECKING, Sequence
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
|
|
21
21
|
from cirq import linalg, qis, value
|
|
22
|
-
from cirq.sim import simulator
|
|
22
|
+
from cirq.sim import simulator, simulation_utils
|
|
23
23
|
|
|
24
24
|
if TYPE_CHECKING:
|
|
25
25
|
import cirq
|
|
@@ -35,7 +35,7 @@ class StateVectorMixin:
|
|
|
35
35
|
"""Inits StateVectorMixin.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
|
-
qubit_map: A map from the Qubits in the Circuit to the
|
|
38
|
+
qubit_map: A map from the Qubits in the Circuit to the index
|
|
39
39
|
of this qubit for a canonical ordering. This canonical ordering
|
|
40
40
|
is used to define the state (see the state_vector() method).
|
|
41
41
|
*args: Passed on to the class that this is mixed in with.
|
|
@@ -102,7 +102,7 @@ class StateVectorMixin:
|
|
|
102
102
|
and non-zero floats of the specified accuracy."""
|
|
103
103
|
return qis.dirac_notation(self.state_vector(), decimals, qid_shape=self._qid_shape)
|
|
104
104
|
|
|
105
|
-
def density_matrix_of(self, qubits: List['cirq.Qid'] = None) -> np.ndarray:
|
|
105
|
+
def density_matrix_of(self, qubits: Optional[List['cirq.Qid']] = None) -> np.ndarray:
|
|
106
106
|
r"""Returns the density matrix of the state.
|
|
107
107
|
|
|
108
108
|
Calculate the density matrix for the system on the qubits provided.
|
|
@@ -215,7 +215,8 @@ def sample_state_vector(
|
|
|
215
215
|
prng = value.parse_random_state(seed)
|
|
216
216
|
|
|
217
217
|
# Calculate the measurement probabilities.
|
|
218
|
-
probs =
|
|
218
|
+
probs = (state_vector * state_vector.conj()).real
|
|
219
|
+
probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)
|
|
219
220
|
|
|
220
221
|
# We now have the probability vector, correctly ordered, so sample over
|
|
221
222
|
# it. Note that we us ints here, since numpy's choice does not allow for
|
|
@@ -288,7 +289,8 @@ def measure_state_vector(
|
|
|
288
289
|
initial_shape = state_vector.shape
|
|
289
290
|
|
|
290
291
|
# Calculate the measurement probabilities and then make the measurement.
|
|
291
|
-
probs =
|
|
292
|
+
probs = (state_vector * state_vector.conj()).real
|
|
293
|
+
probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)
|
|
292
294
|
result = prng.choice(len(probs), p=probs)
|
|
293
295
|
###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))]
|
|
294
296
|
# Convert to individual qudit measurements.
|
|
@@ -321,34 +323,3 @@ def measure_state_vector(
|
|
|
321
323
|
assert out is not None
|
|
322
324
|
# We mutate and return out, so mypy cannot identify that the out cannot be None.
|
|
323
325
|
return measurement_bits, out
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray:
|
|
327
|
-
"""Returns the probabilities for a measurement on the given indices."""
|
|
328
|
-
tensor = np.reshape(state, qid_shape)
|
|
329
|
-
# Calculate the probabilities for measuring the particular results.
|
|
330
|
-
if len(indices) == len(qid_shape):
|
|
331
|
-
# We're measuring every qudit, so no need for fancy indexing
|
|
332
|
-
probs = np.abs(tensor) ** 2
|
|
333
|
-
probs = np.transpose(probs, indices)
|
|
334
|
-
probs = probs.reshape(-1)
|
|
335
|
-
else:
|
|
336
|
-
# Fancy indexing required
|
|
337
|
-
meas_shape = tuple(qid_shape[i] for i in indices)
|
|
338
|
-
probs = (
|
|
339
|
-
np.abs(
|
|
340
|
-
[
|
|
341
|
-
tensor[
|
|
342
|
-
linalg.slice_for_qubits_equal_to(
|
|
343
|
-
indices, big_endian_qureg_value=b, qid_shape=qid_shape
|
|
344
|
-
)
|
|
345
|
-
]
|
|
346
|
-
for b in range(np.prod(meas_shape, dtype=np.int64))
|
|
347
|
-
]
|
|
348
|
-
)
|
|
349
|
-
** 2
|
|
350
|
-
)
|
|
351
|
-
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))
|
|
352
|
-
|
|
353
|
-
# To deal with rounding issues, ensure that the probabilities sum to 1.
|
|
354
|
-
return probs / np.sum(probs)
|
cirq/sim/state_vector_simulation_state.py
CHANGED
|
@@ -11,6 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
14
15
|
"""Objects and methods for acting efficiently on a state vector."""
|
|
15
16
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
|
|
16
17
|
|
|
@@ -355,6 +356,22 @@ class StateVectorSimulationState(SimulationState[_BufferedStateVector]):
|
|
|
355
356
|
)
|
|
356
357
|
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
|
|
357
358
|
|
|
359
|
+
def add_qubits(self, qubits: Sequence['cirq.Qid']):
|
|
360
|
+
ret = super().add_qubits(qubits)
|
|
361
|
+
return (
|
|
362
|
+
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
|
|
363
|
+
if ret is NotImplemented
|
|
364
|
+
else ret
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
def remove_qubits(self, qubits: Sequence['cirq.Qid']):
|
|
368
|
+
ret = super().remove_qubits(qubits)
|
|
369
|
+
if ret is not NotImplemented:
|
|
370
|
+
return ret
|
|
371
|
+
extracted, remainder = self.factor(qubits, inplace=True)
|
|
372
|
+
remainder._state._state_vector *= extracted._state._state_vector.reshape((-1,))[0]
|
|
373
|
+
return remainder
|
|
374
|
+
|
|
358
375
|
def _act_on_fallback_(
|
|
359
376
|
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
|
|
360
377
|
) -> bool:
|
|
@@ -377,7 +394,7 @@ class StateVectorSimulationState(SimulationState[_BufferedStateVector]):
|
|
|
377
394
|
raise TypeError(
|
|
378
395
|
"Can't simulate operations that don't implement "
|
|
379
396
|
"SupportsUnitary, SupportsConsistentApplyUnitary, "
|
|
380
|
-
"SupportsMixture or is a measurement: {!r}"
|
|
397
|
+
f"SupportsMixture or is a measurement: {action!r}"
|
|
381
398
|
)
|
|
382
399
|
|
|
383
400
|
def __repr__(self) -> str:
|
cirq/sim/state_vector_simulator.py
CHANGED
|
@@ -20,6 +20,7 @@ import numpy as np
|
|
|
20
20
|
|
|
21
21
|
from cirq import _compat, ops, value, qis
|
|
22
22
|
from cirq.sim import simulator, state_vector, simulator_base
|
|
23
|
+
from cirq.protocols import qid_shape
|
|
23
24
|
|
|
24
25
|
if TYPE_CHECKING:
|
|
25
26
|
import cirq
|
|
@@ -31,7 +32,7 @@ TStateVectorStepResult = TypeVar('TStateVectorStepResult', bound='StateVectorSte
|
|
|
31
32
|
class SimulatesIntermediateStateVector(
|
|
32
33
|
Generic[TStateVectorStepResult],
|
|
33
34
|
simulator_base.SimulatorBase[
|
|
34
|
-
TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState'
|
|
35
|
+
TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState'
|
|
35
36
|
],
|
|
36
37
|
simulator.SimulatesAmplitudes,
|
|
37
38
|
metaclass=abc.ABCMeta,
|
|
@@ -172,7 +173,7 @@ class StateVectorTrialResult(
|
|
|
172
173
|
size = np.prod(shape, dtype=np.int64)
|
|
173
174
|
final = final.reshape(size)
|
|
174
175
|
if len([1 for e in final if abs(e) > 0.001]) < 16:
|
|
175
|
-
state_vector = qis.dirac_notation(final, 3)
|
|
176
|
+
state_vector = qis.dirac_notation(final, 3, qid_shape(substate.qubits))
|
|
176
177
|
else:
|
|
177
178
|
state_vector = str(final)
|
|
178
179
|
label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:'
|
cirq/sim/state_vector_simulator_test.py
CHANGED
|
@@ -35,9 +35,9 @@ def test_state_vector_trial_result_repr():
|
|
|
35
35
|
expected_repr = (
|
|
36
36
|
"cirq.StateVectorTrialResult("
|
|
37
37
|
"params=cirq.ParamResolver({'s': 1}), "
|
|
38
|
-
"measurements={'m': np.array([[1]], dtype=np.int32)}, "
|
|
38
|
+
"measurements={'m': np.array([[1]], dtype=np.dtype('int32'))}, "
|
|
39
39
|
"final_simulator_state=cirq.StateVectorSimulationState("
|
|
40
|
-
"initial_state=np.array([0j, (1+0j)], dtype=np.complex64), "
|
|
40
|
+
"initial_state=np.array([0j, (1+0j)], dtype=np.dtype('complex64')), "
|
|
41
41
|
"qubits=(cirq.NamedQubit('a'),), "
|
|
42
42
|
"classical_data=cirq.ClassicalDataDictionaryStore()))"
|
|
43
43
|
)
|
|
@@ -159,6 +159,28 @@ def test_str_big():
|
|
|
159
159
|
assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result)
|
|
160
160
|
|
|
161
161
|
|
|
162
|
+
def test_str_qudit():
|
|
163
|
+
qutrit = cirq.LineQid(0, dimension=3)
|
|
164
|
+
final_simulator_state = cirq.StateVectorSimulationState(
|
|
165
|
+
prng=np.random.RandomState(0),
|
|
166
|
+
qubits=[qutrit],
|
|
167
|
+
initial_state=np.array([0, 0, 1]),
|
|
168
|
+
dtype=np.complex64,
|
|
169
|
+
)
|
|
170
|
+
result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state)
|
|
171
|
+
assert "|2⟩" in str(result)
|
|
172
|
+
|
|
173
|
+
ququart = cirq.LineQid(0, dimension=4)
|
|
174
|
+
final_simulator_state = cirq.StateVectorSimulationState(
|
|
175
|
+
prng=np.random.RandomState(0),
|
|
176
|
+
qubits=[ququart],
|
|
177
|
+
initial_state=np.array([0, 1, 0, 0]),
|
|
178
|
+
dtype=np.complex64,
|
|
179
|
+
)
|
|
180
|
+
result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state)
|
|
181
|
+
assert "|1⟩" in str(result)
|
|
182
|
+
|
|
183
|
+
|
|
162
184
|
def test_pretty_print():
|
|
163
185
|
final_simulator_state = cirq.StateVectorSimulationState(
|
|
164
186
|
available_buffer=np.array([1]),
|
cirq/sim/state_vector_test.py
CHANGED
|
@@ -21,6 +21,7 @@ import numpy as np
|
|
|
21
21
|
|
|
22
22
|
import cirq
|
|
23
23
|
import cirq.testing
|
|
24
|
+
from cirq import linalg
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
def test_state_mixin():
|
|
@@ -172,7 +173,9 @@ def test_sample_no_indices_repetitions():
|
|
|
172
173
|
)
|
|
173
174
|
|
|
174
175
|
|
|
175
|
-
|
|
176
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
177
|
+
def test_measure_state_computational_basis(use_np_transpose: bool):
|
|
178
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
176
179
|
results = []
|
|
177
180
|
for x in range(8):
|
|
178
181
|
initial_state = cirq.to_valid_state_vector(x, 3)
|
|
@@ -183,7 +186,9 @@ def test_measure_state_computational_basis():
|
|
|
183
186
|
assert results == expected
|
|
184
187
|
|
|
185
188
|
|
|
186
|
-
|
|
189
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
190
|
+
def test_measure_state_reshape(use_np_transpose: bool):
|
|
191
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
187
192
|
results = []
|
|
188
193
|
for x in range(8):
|
|
189
194
|
initial_state = np.reshape(cirq.to_valid_state_vector(x, 3), [2] * 3)
|
|
@@ -194,7 +199,9 @@ def test_measure_state_reshape():
|
|
|
194
199
|
assert results == expected
|
|
195
200
|
|
|
196
201
|
|
|
197
|
-
|
|
202
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
203
|
+
def test_measure_state_partial_indices(use_np_transpose: bool):
|
|
204
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
198
205
|
for index in range(3):
|
|
199
206
|
for x in range(8):
|
|
200
207
|
initial_state = cirq.to_valid_state_vector(x, 3)
|
|
@@ -203,7 +210,9 @@ def test_measure_state_partial_indices():
|
|
|
203
210
|
assert bits == [bool(1 & (x >> (2 - index)))]
|
|
204
211
|
|
|
205
212
|
|
|
206
|
-
|
|
213
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
214
|
+
def test_measure_state_partial_indices_order(use_np_transpose: bool):
|
|
215
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
207
216
|
for x in range(8):
|
|
208
217
|
initial_state = cirq.to_valid_state_vector(x, 3)
|
|
209
218
|
bits, state = cirq.measure_state_vector(initial_state, [2, 1])
|
|
@@ -211,7 +220,9 @@ def test_measure_state_partial_indices_order():
|
|
|
211
220
|
assert bits == [bool(1 & (x >> 0)), bool(1 & (x >> 1))]
|
|
212
221
|
|
|
213
222
|
|
|
214
|
-
|
|
223
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
224
|
+
def test_measure_state_partial_indices_all_orders(use_np_transpose: bool):
|
|
225
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
215
226
|
for perm in itertools.permutations([0, 1, 2]):
|
|
216
227
|
for x in range(8):
|
|
217
228
|
initial_state = cirq.to_valid_state_vector(x, 3)
|
|
@@ -220,7 +231,9 @@ def test_measure_state_partial_indices_all_orders():
|
|
|
220
231
|
assert bits == [bool(1 & (x >> (2 - p))) for p in perm]
|
|
221
232
|
|
|
222
233
|
|
|
223
|
-
|
|
234
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
235
|
+
def test_measure_state_collapse(use_np_transpose: bool):
|
|
236
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
224
237
|
initial_state = np.zeros(8, dtype=np.complex64)
|
|
225
238
|
initial_state[0] = 1 / np.sqrt(2)
|
|
226
239
|
initial_state[2] = 1 / np.sqrt(2)
|
|
@@ -243,7 +256,9 @@ def test_measure_state_collapse():
|
|
|
243
256
|
assert bits == [False]
|
|
244
257
|
|
|
245
258
|
|
|
246
|
-
|
|
259
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
260
|
+
def test_measure_state_seed(use_np_transpose: bool):
|
|
261
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
247
262
|
n = 10
|
|
248
263
|
initial_state = np.ones(2**n) / 2 ** (n / 2)
|
|
249
264
|
|
|
@@ -262,7 +277,9 @@ def test_measure_state_seed():
|
|
|
262
277
|
np.testing.assert_allclose(state1, state2)
|
|
263
278
|
|
|
264
279
|
|
|
265
|
-
|
|
280
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
281
|
+
def test_measure_state_out_is_state(use_np_transpose: bool):
|
|
282
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
266
283
|
initial_state = np.zeros(8, dtype=np.complex64)
|
|
267
284
|
initial_state[0] = 1 / np.sqrt(2)
|
|
268
285
|
initial_state[2] = 1 / np.sqrt(2)
|
|
@@ -273,7 +290,9 @@ def test_measure_state_out_is_state():
|
|
|
273
290
|
assert state is initial_state
|
|
274
291
|
|
|
275
292
|
|
|
276
|
-
|
|
293
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
294
|
+
def test_measure_state_out_is_not_state(use_np_transpose: bool):
|
|
295
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
277
296
|
initial_state = np.zeros(8, dtype=np.complex64)
|
|
278
297
|
initial_state[0] = 1 / np.sqrt(2)
|
|
279
298
|
initial_state[2] = 1 / np.sqrt(2)
|
|
@@ -283,14 +302,18 @@ def test_measure_state_out_is_not_state():
|
|
|
283
302
|
assert out is state
|
|
284
303
|
|
|
285
304
|
|
|
286
|
-
|
|
305
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
306
|
+
def test_measure_state_not_power_of_two(use_np_transpose: bool):
|
|
307
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
287
308
|
with pytest.raises(ValueError, match='3'):
|
|
288
309
|
_, _ = cirq.measure_state_vector(np.array([1, 0, 0]), [1])
|
|
289
310
|
with pytest.raises(ValueError, match='5'):
|
|
290
311
|
cirq.measure_state_vector(np.array([0, 1, 0, 0, 0]), [1])
|
|
291
312
|
|
|
292
313
|
|
|
293
|
-
|
|
314
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
315
|
+
def test_measure_state_index_out_of_range(use_np_transpose: bool):
|
|
316
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
294
317
|
state = cirq.to_valid_state_vector(0, 3)
|
|
295
318
|
with pytest.raises(IndexError, match='-2'):
|
|
296
319
|
cirq.measure_state_vector(state, [-2])
|
|
@@ -298,14 +321,18 @@ def test_measure_state_index_out_of_range():
|
|
|
298
321
|
cirq.measure_state_vector(state, [3])
|
|
299
322
|
|
|
300
323
|
|
|
301
|
-
|
|
324
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
325
|
+
def test_measure_state_no_indices(use_np_transpose: bool):
|
|
326
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
302
327
|
initial_state = cirq.to_valid_state_vector(0, 3)
|
|
303
328
|
bits, state = cirq.measure_state_vector(initial_state, [])
|
|
304
329
|
assert [] == bits
|
|
305
330
|
np.testing.assert_almost_equal(state, initial_state)
|
|
306
331
|
|
|
307
332
|
|
|
308
|
-
|
|
333
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
334
|
+
def test_measure_state_no_indices_out_is_state(use_np_transpose: bool):
|
|
335
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
309
336
|
initial_state = cirq.to_valid_state_vector(0, 3)
|
|
310
337
|
bits, state = cirq.measure_state_vector(initial_state, [], out=initial_state)
|
|
311
338
|
assert [] == bits
|
|
@@ -313,7 +340,9 @@ def test_measure_state_no_indices_out_is_state():
|
|
|
313
340
|
assert state is initial_state
|
|
314
341
|
|
|
315
342
|
|
|
316
|
-
|
|
343
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
344
|
+
def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool):
|
|
345
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
317
346
|
initial_state = cirq.to_valid_state_vector(0, 3)
|
|
318
347
|
out = np.zeros_like(initial_state)
|
|
319
348
|
bits, state = cirq.measure_state_vector(initial_state, [], out=out)
|
|
@@ -323,7 +352,9 @@ def test_measure_state_no_indices_out_is_not_state():
|
|
|
323
352
|
assert out is not initial_state
|
|
324
353
|
|
|
325
354
|
|
|
326
|
-
|
|
355
|
+
@pytest.mark.parametrize('use_np_transpose', [False, True])
|
|
356
|
+
def test_measure_state_empty_state(use_np_transpose: bool):
|
|
357
|
+
linalg.can_numpy_support_shape = lambda s: use_np_transpose
|
|
327
358
|
initial_state = np.array([1.0])
|
|
328
359
|
bits, state = cirq.measure_state_vector(initial_state, [])
|
|
329
360
|
assert [] == bits
|
cirq/study/__init__.py
CHANGED
|
@@ -40,7 +40,7 @@ def flatten(val: Any) -> Tuple[Any, 'ExpressionMap']:
|
|
|
40
40
|
the name to avoid collision: `sympy.Symbol('<x + 1>_1')`.
|
|
41
41
|
|
|
42
42
|
This function also creates a dictionary mapping from expressions and symbols
|
|
43
|
-
in `val` to the new symbols in the flattened copy of `val`. E.g
|
|
43
|
+
in `val` to the new symbols in the flattened copy of `val`. E.g.
|
|
44
44
|
`cirq.ExpressionMap({sympy.Symbol('x')+1: sympy.Symbol('<x + 1>')})`. This
|
|
45
45
|
`ExpressionMap` can be used to transform a sweep over the symbols in `val`
|
|
46
46
|
to a sweep over the flattened symbols e.g. a sweep over `sympy.Symbol('x')`
|
|
@@ -200,7 +200,7 @@ class _ParamFlattener(resolver.ParamResolver):
|
|
|
200
200
|
self,
|
|
201
201
|
param_dict: Optional[resolver.ParamResolverOrSimilarType] = None,
|
|
202
202
|
*, # Force keyword args
|
|
203
|
-
get_param_name: Callable[[sympy.Expr], str
|
|
203
|
+
get_param_name: Optional[Callable[[sympy.Expr], str]] = None,
|
|
204
204
|
):
|
|
205
205
|
"""Initializes a new _ParamFlattener.
|
|
206
206
|
|
cirq/study/resolver.py
CHANGED
|
@@ -150,6 +150,8 @@ class ParamResolver:
|
|
|
150
150
|
# The following resolves common sympy expressions
|
|
151
151
|
# If sympy did its job and wasn't slower than molasses,
|
|
152
152
|
# we wouldn't need the following block.
|
|
153
|
+
if isinstance(value, sympy.Float):
|
|
154
|
+
return float(value)
|
|
153
155
|
if isinstance(value, sympy.Add):
|
|
154
156
|
summation = self.value_of(value.args[0], recursive)
|
|
155
157
|
for addend in value.args[1:]:
|
cirq/study/resolver_test.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
"""Tests for parameter resolvers."""
|
|
16
|
+
|
|
16
17
|
import fractions
|
|
17
18
|
|
|
18
19
|
import numpy as np
|
|
@@ -241,7 +242,6 @@ def test_custom_resolved_value():
|
|
|
241
242
|
assert r.value_of(b) == 'Baz'
|
|
242
243
|
|
|
243
244
|
|
|
244
|
-
@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
|
|
245
245
|
def test_custom_value_not_implemented():
|
|
246
246
|
class Bar:
|
|
247
247
|
def _resolved_value_(self):
|
cirq/study/result.py
CHANGED
|
@@ -140,7 +140,7 @@ class Result(abc.ABC):
|
|
|
140
140
|
basis = 2 ** np.arange(n, dtype=dtype)[::-1]
|
|
141
141
|
converted_dict[key] = np.sum(basis * bitstrings, axis=1)
|
|
142
142
|
|
|
143
|
-
# Use objects to
|
|
143
|
+
# Use objects to accommodate more than 64 qubits if needed.
|
|
144
144
|
dtype = object if any(bs.shape[1] > 63 for _, bs in measurements.items()) else np.int64
|
|
145
145
|
return pd.DataFrame(converted_dict, dtype=dtype)
|
|
146
146
|
|
cirq/study/sweeps.py
CHANGED
|
@@ -18,6 +18,7 @@ from typing import (
|
|
|
18
18
|
Iterable,
|
|
19
19
|
Iterator,
|
|
20
20
|
List,
|
|
21
|
+
Optional,
|
|
21
22
|
overload,
|
|
22
23
|
Sequence,
|
|
23
24
|
TYPE_CHECKING,
|
|
@@ -127,7 +128,7 @@ class Sweep(metaclass=abc.ABCMeta):
|
|
|
127
128
|
def __getitem__(self, val: slice) -> 'Sweep':
|
|
128
129
|
pass
|
|
129
130
|
|
|
130
|
-
def __getitem__(self, val):
|
|
131
|
+
def __getitem__(self, val: Union[int, slice]) -> Union[resolver.ParamResolver, 'Sweep']:
|
|
131
132
|
n = len(self)
|
|
132
133
|
if isinstance(val, int):
|
|
133
134
|
if val < -n or val >= n:
|
|
@@ -292,7 +293,7 @@ class Zip(Sweep):
|
|
|
292
293
|
self.sweeps = sweeps
|
|
293
294
|
|
|
294
295
|
def __eq__(self, other):
|
|
295
|
-
if
|
|
296
|
+
if type(other) is not Zip:
|
|
296
297
|
return NotImplemented
|
|
297
298
|
return self.sweeps == other.sweeps
|
|
298
299
|
|
|
@@ -327,7 +328,62 @@ class Zip(Sweep):
|
|
|
327
328
|
|
|
328
329
|
@classmethod
|
|
329
330
|
def _from_json_dict_(cls, sweeps, **kwargs):
|
|
330
|
-
return
|
|
331
|
+
return cls(*sweeps)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
class ZipLongest(Zip):
|
|
335
|
+
"""Iterate over constituent sweeps in parallel
|
|
336
|
+
|
|
337
|
+
Analogous to itertools.zip_longest.
|
|
338
|
+
Note that we iterate until all sweeps terminate,
|
|
339
|
+
so if the sweeps are different lengths, the
|
|
340
|
+
shorter sweeps will be filled by repeating their last value
|
|
341
|
+
until all sweeps have equal length.
|
|
342
|
+
|
|
343
|
+
Note that this is different from itertools.zip_longest,
|
|
344
|
+
which uses a fixed fill value.
|
|
345
|
+
|
|
346
|
+
Raises:
|
|
347
|
+
ValueError if an input sweep if completely empty.
|
|
348
|
+
"""
|
|
349
|
+
|
|
350
|
+
def __init__(self, *sweeps: Sweep) -> None:
|
|
351
|
+
super().__init__(*sweeps)
|
|
352
|
+
if any(len(sweep) == 0 for sweep in self.sweeps):
|
|
353
|
+
raise ValueError('All sweeps must be non-empty for ZipLongest')
|
|
354
|
+
|
|
355
|
+
def __eq__(self, other):
|
|
356
|
+
if not isinstance(other, ZipLongest):
|
|
357
|
+
return NotImplemented
|
|
358
|
+
return self.sweeps == other.sweeps
|
|
359
|
+
|
|
360
|
+
def __len__(self) -> int:
|
|
361
|
+
if not self.sweeps:
|
|
362
|
+
return 0
|
|
363
|
+
return max(len(sweep) for sweep in self.sweeps)
|
|
364
|
+
|
|
365
|
+
def __hash__(self) -> int:
|
|
366
|
+
return hash((self.__class__.__name__, tuple(self.sweeps)))
|
|
367
|
+
|
|
368
|
+
def __repr__(self) -> str:
|
|
369
|
+
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
|
|
370
|
+
return f'cirq_google.ZipLongest({sweeps_repr})'
|
|
371
|
+
|
|
372
|
+
def __str__(self) -> str:
|
|
373
|
+
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
|
|
374
|
+
return f'ZipLongest({sweeps_repr})'
|
|
375
|
+
|
|
376
|
+
def param_tuples(self) -> Iterator[Params]:
|
|
377
|
+
def _iter_and_repeat_last(one_iter: Iterator[Params]):
|
|
378
|
+
last = None
|
|
379
|
+
for last in one_iter:
|
|
380
|
+
yield last
|
|
381
|
+
while True:
|
|
382
|
+
yield last
|
|
383
|
+
|
|
384
|
+
iters = [_iter_and_repeat_last(sweep.param_tuples()) for sweep in self.sweeps]
|
|
385
|
+
for values in itertools.islice(zip(*iters), len(self)):
|
|
386
|
+
yield tuple(item for value in values for item in value)
|
|
331
387
|
|
|
332
388
|
|
|
333
389
|
class SingleSweep(Sweep):
|
|
@@ -366,9 +422,23 @@ class SingleSweep(Sweep):
|
|
|
366
422
|
class Points(SingleSweep):
|
|
367
423
|
"""A simple sweep with explicitly supplied values."""
|
|
368
424
|
|
|
369
|
-
def __init__(
|
|
370
|
-
|
|
425
|
+
def __init__(
|
|
426
|
+
self, key: 'cirq.TParamKey', points: Sequence[float], metadata: Optional[Any] = None
|
|
427
|
+
) -> None:
|
|
428
|
+
"""Creates a sweep on a variable with supplied values.
|
|
429
|
+
|
|
430
|
+
Args:
|
|
431
|
+
key: sympy.Symbol or equivalent to sweep across.
|
|
432
|
+
points: sequence of floating point values that represent
|
|
433
|
+
the values to sweep across. The length of the sweep
|
|
434
|
+
will be equivalent to the length of this sequence.
|
|
435
|
+
metadata: Optional metadata to attach to the sweep to
|
|
436
|
+
annotate the sweep or its variable.
|
|
437
|
+
|
|
438
|
+
"""
|
|
439
|
+
super().__init__(key)
|
|
371
440
|
self.points = points
|
|
441
|
+
self.metadata = metadata
|
|
372
442
|
|
|
373
443
|
def _tuple(self) -> Tuple[Union[str, sympy.Expr], Sequence[float]]:
|
|
374
444
|
return self.key, tuple(self.points)
|
|
@@ -380,25 +450,44 @@ class Points(SingleSweep):
|
|
|
380
450
|
return iter(self.points)
|
|
381
451
|
|
|
382
452
|
def __repr__(self) -> str:
|
|
383
|
-
|
|
453
|
+
metadata_repr = f', metadata={self.metadata!r}' if self.metadata is not None else ""
|
|
454
|
+
return f'cirq.Points({self.key!r}, {self.points!r}{metadata_repr})'
|
|
384
455
|
|
|
385
456
|
def _json_dict_(self) -> Dict[str, Any]:
|
|
457
|
+
if self.metadata is not None:
|
|
458
|
+
return protocols.obj_to_dict_helper(self, ["key", "points", "metadata"])
|
|
386
459
|
return protocols.obj_to_dict_helper(self, ["key", "points"])
|
|
387
460
|
|
|
388
461
|
|
|
389
462
|
class Linspace(SingleSweep):
|
|
390
463
|
"""A simple sweep over linearly-spaced values."""
|
|
391
464
|
|
|
392
|
-
def __init__(
|
|
465
|
+
def __init__(
|
|
466
|
+
self,
|
|
467
|
+
key: 'cirq.TParamKey',
|
|
468
|
+
start: float,
|
|
469
|
+
stop: float,
|
|
470
|
+
length: int,
|
|
471
|
+
metadata: Optional[Any] = None,
|
|
472
|
+
) -> None:
|
|
393
473
|
"""Creates a linear-spaced sweep for a given key.
|
|
394
474
|
|
|
395
475
|
For the given args, assigns to the list of values
|
|
396
476
|
start, start + (stop - start) / (length - 1), ..., stop
|
|
477
|
+
|
|
478
|
+
Args:
|
|
479
|
+
key: sympy.Symbol or equivalent to sweep across.
|
|
480
|
+
start: minimum value of linear sweep.
|
|
481
|
+
stop: maximum value of linear sweep.
|
|
482
|
+
length: number of points in the sweep.
|
|
483
|
+
metadata: Optional metadata to attach to the sweep to
|
|
484
|
+
annotate the sweep or its variable.
|
|
397
485
|
"""
|
|
398
|
-
super(
|
|
486
|
+
super().__init__(key)
|
|
399
487
|
self.start = start
|
|
400
488
|
self.stop = stop
|
|
401
489
|
self.length = length
|
|
490
|
+
self.metadata = metadata
|
|
402
491
|
|
|
403
492
|
def _tuple(self) -> Tuple[Union[str, sympy.Expr], float, float, int]:
|
|
404
493
|
return (self.key, self.start, self.stop, self.length)
|
|
@@ -415,12 +504,17 @@ class Linspace(SingleSweep):
|
|
|
415
504
|
yield self.start * (1 - p) + self.stop * p
|
|
416
505
|
|
|
417
506
|
def __repr__(self) -> str:
|
|
507
|
+
metadata_repr = f', metadata={self.metadata!r}' if self.metadata is not None else ""
|
|
418
508
|
return (
|
|
419
509
|
f'cirq.Linspace({self.key!r}, start={self.start!r}, '
|
|
420
|
-
f'stop={self.stop!r}, length={self.length!r})'
|
|
510
|
+
f'stop={self.stop!r}, length={self.length!r}{metadata_repr})'
|
|
421
511
|
)
|
|
422
512
|
|
|
423
513
|
def _json_dict_(self) -> Dict[str, Any]:
|
|
514
|
+
if self.metadata is not None:
|
|
515
|
+
return protocols.obj_to_dict_helper(
|
|
516
|
+
self, ["key", "start", "stop", "length", "metadata"]
|
|
517
|
+
)
|
|
424
518
|
return protocols.obj_to_dict_helper(self, ["key", "start", "stop", "length"])
|
|
425
519
|
|
|
426
520
|
|