cirq-core 1.1.0.dev20221219200817__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. cirq/__init__.py +8 -0
  2. cirq/_compat.py +29 -4
  3. cirq/_compat_test.py +24 -26
  4. cirq/_version.py +32 -1
  5. cirq/_version_test.py +1 -1
  6. cirq/circuits/_block_diagram_drawer_test.py +4 -3
  7. cirq/circuits/circuit.py +109 -63
  8. cirq/circuits/circuit_operation.py +2 -3
  9. cirq/circuits/circuit_operation_test.py +4 -4
  10. cirq/circuits/circuit_test.py +11 -0
  11. cirq/circuits/frozen_circuit.py +13 -1
  12. cirq/circuits/frozen_circuit_test.py +5 -1
  13. cirq/circuits/moment.py +39 -14
  14. cirq/circuits/moment_test.py +7 -0
  15. cirq/circuits/text_diagram_drawer.py +1 -1
  16. cirq/circuits/text_diagram_drawer_test.py +3 -7
  17. cirq/conftest.py +8 -0
  18. cirq/contrib/acquaintance/bipartite.py +1 -1
  19. cirq/contrib/acquaintance/devices.py +2 -2
  20. cirq/contrib/acquaintance/executor.py +5 -2
  21. cirq/contrib/acquaintance/gates.py +3 -2
  22. cirq/contrib/acquaintance/permutation.py +13 -2
  23. cirq/contrib/acquaintance/testing.py +3 -5
  24. cirq/contrib/paulistring/recombine.py +3 -6
  25. cirq/contrib/qasm_import/_parser.py +17 -21
  26. cirq/contrib/qasm_import/_parser_test.py +30 -45
  27. cirq/contrib/qcircuit/qcircuit_test.py +3 -7
  28. cirq/contrib/quantum_volume/quantum_volume.py +3 -3
  29. cirq/contrib/quimb/mps_simulator.py +1 -1
  30. cirq/contrib/quimb/state_vector.py +2 -0
  31. cirq/contrib/quirk/quirk_gate.py +1 -0
  32. cirq/contrib/svg/svg.py +4 -7
  33. cirq/contrib/svg/svg_test.py +29 -1
  34. cirq/devices/grid_qubit.py +26 -28
  35. cirq/devices/grid_qubit_test.py +21 -5
  36. cirq/devices/line_qubit.py +10 -12
  37. cirq/devices/line_qubit_test.py +9 -2
  38. cirq/devices/named_topologies.py +1 -1
  39. cirq/devices/noise_model.py +4 -1
  40. cirq/devices/superconducting_qubits_noise_properties.py +1 -3
  41. cirq/experiments/n_qubit_tomography.py +1 -1
  42. cirq/experiments/qubit_characterizations.py +2 -2
  43. cirq/experiments/single_qubit_readout_calibration.py +1 -1
  44. cirq/experiments/t2_decay_experiment.py +1 -1
  45. cirq/experiments/xeb_simulation_test.py +2 -2
  46. cirq/interop/quirk/cells/testing.py +1 -1
  47. cirq/json_resolver_cache.py +1 -0
  48. cirq/linalg/__init__.py +2 -0
  49. cirq/linalg/decompositions_test.py +4 -4
  50. cirq/linalg/diagonalize_test.py +5 -6
  51. cirq/linalg/transformations.py +72 -9
  52. cirq/linalg/transformations_test.py +23 -7
  53. cirq/ops/__init__.py +4 -0
  54. cirq/ops/arithmetic_operation.py +4 -6
  55. cirq/ops/classically_controlled_operation.py +10 -3
  56. cirq/ops/clifford_gate.py +1 -7
  57. cirq/ops/common_channels.py +21 -15
  58. cirq/ops/common_gate_families.py +2 -3
  59. cirq/ops/common_gates.py +48 -11
  60. cirq/ops/common_gates_test.py +4 -0
  61. cirq/ops/controlled_gate.py +44 -18
  62. cirq/ops/controlled_operation.py +13 -5
  63. cirq/ops/dense_pauli_string.py +14 -19
  64. cirq/ops/diagonal_gate.py +3 -4
  65. cirq/ops/eigen_gate.py +8 -10
  66. cirq/ops/eigen_gate_test.py +6 -0
  67. cirq/ops/gate_operation.py +11 -6
  68. cirq/ops/gate_operation_test.py +11 -2
  69. cirq/ops/gateset.py +2 -1
  70. cirq/ops/gateset_test.py +38 -5
  71. cirq/ops/global_phase_op.py +28 -2
  72. cirq/ops/global_phase_op_test.py +21 -0
  73. cirq/ops/identity.py +1 -1
  74. cirq/ops/kraus_channel_test.py +2 -2
  75. cirq/ops/linear_combinations.py +7 -6
  76. cirq/ops/linear_combinations_test.py +26 -10
  77. cirq/ops/matrix_gates.py +8 -4
  78. cirq/ops/matrix_gates_test.py +25 -3
  79. cirq/ops/measure_util.py +13 -5
  80. cirq/ops/measure_util_test.py +8 -2
  81. cirq/ops/measurement_gate.py +1 -1
  82. cirq/ops/measurement_gate_test.py +9 -4
  83. cirq/ops/mixed_unitary_channel_test.py +4 -4
  84. cirq/ops/named_qubit.py +2 -4
  85. cirq/ops/parity_gates.py +5 -1
  86. cirq/ops/parity_gates_test.py +6 -0
  87. cirq/ops/pauli_gates.py +9 -9
  88. cirq/ops/pauli_string.py +4 -2
  89. cirq/ops/pauli_string_raw_types.py +4 -11
  90. cirq/ops/pauli_string_test.py +13 -13
  91. cirq/ops/pauli_sum_exponential.py +6 -1
  92. cirq/ops/qubit_manager.py +97 -0
  93. cirq/ops/qubit_manager_test.py +66 -0
  94. cirq/ops/raw_types.py +75 -33
  95. cirq/ops/raw_types_test.py +34 -0
  96. cirq/ops/three_qubit_gates.py +16 -10
  97. cirq/ops/three_qubit_gates_test.py +4 -2
  98. cirq/ops/two_qubit_diagonal_gate.py +3 -3
  99. cirq/ops/wait_gate.py +1 -1
  100. cirq/protocols/__init__.py +1 -0
  101. cirq/protocols/act_on_protocol.py +3 -3
  102. cirq/protocols/act_on_protocol_test.py +5 -5
  103. cirq/protocols/apply_channel_protocol.py +9 -8
  104. cirq/protocols/apply_mixture_protocol.py +8 -8
  105. cirq/protocols/apply_mixture_protocol_test.py +1 -1
  106. cirq/protocols/apply_unitary_protocol.py +66 -19
  107. cirq/protocols/apply_unitary_protocol_test.py +50 -0
  108. cirq/protocols/circuit_diagram_info_protocol.py +7 -9
  109. cirq/protocols/decompose_protocol.py +167 -125
  110. cirq/protocols/decompose_protocol_test.py +132 -2
  111. cirq/protocols/has_stabilizer_effect_protocol.py +2 -1
  112. cirq/protocols/inverse_protocol.py +2 -2
  113. cirq/protocols/json_serialization_test.py +3 -3
  114. cirq/protocols/json_test_data/Linspace.json +20 -7
  115. cirq/protocols/json_test_data/Linspace.repr +4 -1
  116. cirq/protocols/json_test_data/Points.json +19 -8
  117. cirq/protocols/json_test_data/Points.repr +4 -1
  118. cirq/protocols/json_test_data/Result.repr_inward +1 -1
  119. cirq/protocols/json_test_data/ResultDict.repr +1 -1
  120. cirq/protocols/json_test_data/ResultDict.repr_inward +1 -1
  121. cirq/protocols/json_test_data/TrialResult.repr_inward +1 -1
  122. cirq/protocols/json_test_data/XPowGate.json +13 -5
  123. cirq/protocols/json_test_data/XPowGate.repr +1 -1
  124. cirq/protocols/json_test_data/ZPowGate.json +13 -5
  125. cirq/protocols/json_test_data/ZPowGate.repr +1 -1
  126. cirq/protocols/json_test_data/ZipLongest.json +19 -0
  127. cirq/protocols/json_test_data/ZipLongest.repr +1 -0
  128. cirq/protocols/json_test_data/spec.py +1 -0
  129. cirq/protocols/kraus_protocol.py +3 -4
  130. cirq/protocols/measurement_key_protocol.py +3 -1
  131. cirq/protocols/mixture_protocol.py +3 -2
  132. cirq/protocols/phase_protocol.py +3 -3
  133. cirq/protocols/pow_protocol.py +1 -2
  134. cirq/protocols/qasm.py +4 -4
  135. cirq/protocols/qid_shape_protocol.py +8 -8
  136. cirq/protocols/resolve_parameters.py +8 -3
  137. cirq/protocols/resolve_parameters_test.py +3 -3
  138. cirq/protocols/unitary_protocol.py +19 -11
  139. cirq/protocols/unitary_protocol_test.py +37 -0
  140. cirq/qis/channels.py +1 -1
  141. cirq/qis/clifford_tableau.py +4 -5
  142. cirq/qis/quantum_state_representation.py +7 -9
  143. cirq/qis/states.py +21 -13
  144. cirq/qis/states_test.py +7 -0
  145. cirq/sim/clifford/clifford_simulator.py +3 -3
  146. cirq/sim/density_matrix_simulation_state.py +2 -1
  147. cirq/sim/density_matrix_simulator.py +1 -1
  148. cirq/sim/density_matrix_simulator_test.py +9 -5
  149. cirq/sim/density_matrix_utils.py +7 -32
  150. cirq/sim/mux.py +2 -2
  151. cirq/sim/simulation_state.py +58 -18
  152. cirq/sim/simulation_state_base.py +5 -2
  153. cirq/sim/simulation_state_test.py +121 -9
  154. cirq/sim/simulation_utils.py +59 -0
  155. cirq/sim/simulation_utils_test.py +32 -0
  156. cirq/sim/simulator.py +2 -1
  157. cirq/sim/simulator_base_test.py +3 -3
  158. cirq/sim/sparse_simulator.py +1 -1
  159. cirq/sim/sparse_simulator_test.py +5 -5
  160. cirq/sim/state_vector.py +7 -36
  161. cirq/sim/state_vector_simulation_state.py +18 -1
  162. cirq/sim/state_vector_simulator.py +3 -2
  163. cirq/sim/state_vector_simulator_test.py +24 -2
  164. cirq/sim/state_vector_test.py +46 -15
  165. cirq/study/__init__.py +1 -0
  166. cirq/study/flatten_expressions.py +2 -2
  167. cirq/study/resolver.py +2 -0
  168. cirq/study/resolver_test.py +1 -1
  169. cirq/study/result.py +1 -1
  170. cirq/study/sweeps.py +103 -9
  171. cirq/study/sweeps_test.py +64 -0
  172. cirq/testing/__init__.py +4 -0
  173. cirq/testing/circuit_compare.py +15 -18
  174. cirq/testing/consistent_act_on.py +4 -4
  175. cirq/testing/consistent_controlled_gate_op_test.py +1 -1
  176. cirq/testing/consistent_decomposition.py +11 -2
  177. cirq/testing/consistent_decomposition_test.py +8 -1
  178. cirq/testing/consistent_protocols.py +2 -0
  179. cirq/testing/consistent_protocols_test.py +8 -4
  180. cirq/testing/consistent_qasm.py +8 -15
  181. cirq/testing/consistent_specified_has_unitary.py +1 -1
  182. cirq/testing/consistent_unitary.py +85 -0
  183. cirq/testing/consistent_unitary_test.py +96 -0
  184. cirq/testing/equivalent_repr_eval.py +10 -10
  185. cirq/testing/json.py +3 -3
  186. cirq/testing/logs.py +1 -1
  187. cirq/testing/order_tester.py +4 -5
  188. cirq/testing/random_circuit.py +3 -5
  189. cirq/testing/sample_gates.py +79 -0
  190. cirq/testing/sample_gates_test.py +59 -0
  191. cirq/transformers/__init__.py +2 -0
  192. cirq/transformers/analytical_decompositions/__init__.py +8 -0
  193. cirq/transformers/analytical_decompositions/pauli_string_decomposition.py +130 -0
  194. cirq/transformers/analytical_decompositions/pauli_string_decomposition_test.py +58 -0
  195. cirq/transformers/analytical_decompositions/quantum_shannon_decomposition.py +230 -0
  196. cirq/transformers/analytical_decompositions/quantum_shannon_decomposition_test.py +112 -0
  197. cirq/transformers/analytical_decompositions/three_qubit_decomposition_test.py +1 -3
  198. cirq/transformers/analytical_decompositions/two_qubit_to_fsim.py +1 -1
  199. cirq/transformers/expand_composite.py +1 -1
  200. cirq/transformers/heuristic_decompositions/gate_tabulation_math_utils.py +4 -4
  201. cirq/transformers/measurement_transformers.py +4 -4
  202. cirq/transformers/merge_single_qubit_gates.py +17 -4
  203. cirq/transformers/routing/route_circuit_cqc.py +2 -2
  204. cirq/transformers/stratify.py +125 -62
  205. cirq/transformers/stratify_test.py +20 -16
  206. cirq/transformers/transformer_api.py +1 -1
  207. cirq/transformers/transformer_primitives.py +3 -2
  208. cirq/transformers/transformer_primitives_test.py +11 -0
  209. cirq/value/abc_alt.py +3 -2
  210. cirq/value/abc_alt_test.py +1 -0
  211. cirq/value/classical_data.py +10 -10
  212. cirq/value/digits.py +2 -2
  213. cirq/value/linear_dict.py +18 -19
  214. cirq/value/product_state.py +7 -6
  215. cirq/value/value_equality_attr.py +2 -2
  216. cirq/vis/heatmap.py +1 -1
  217. cirq/vis/heatmap_test.py +2 -2
  218. cirq/work/collector.py +2 -2
  219. cirq/work/observable_measurement_data.py +5 -5
  220. cirq/work/observable_readout_calibration.py +3 -1
  221. cirq/work/observable_settings.py +1 -1
  222. cirq/work/pauli_sum_collector.py +9 -8
  223. cirq/work/sampler.py +2 -0
  224. cirq/work/zeros_sampler.py +2 -2
  225. {cirq_core-1.1.0.dev20221219200817.dist-info → cirq_core-1.2.0.dist-info}/METADATA +7 -15
  226. {cirq_core-1.1.0.dev20221219200817.dist-info → cirq_core-1.2.0.dist-info}/RECORD +229 -215
  227. {cirq_core-1.1.0.dev20221219200817.dist-info → cirq_core-1.2.0.dist-info}/WHEEL +1 -1
  228. {cirq_core-1.1.0.dev20221219200817.dist-info → cirq_core-1.2.0.dist-info}/LICENSE +0 -0
  229. {cirq_core-1.1.0.dev20221219200817.dist-info → cirq_core-1.2.0.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ class CountingState(cirq.qis.QuantumStateRepresentation):
34
34
  self.measurement_count += 1
35
35
  return [self.gate_count]
36
36
 
37
- def kron(self: 'CountingState', other: 'CountingState') -> 'CountingState':
37
+ def kron(self, other: 'CountingState') -> 'CountingState':
38
38
  return CountingState(
39
39
  self.data,
40
40
  self.gate_count + other.gate_count,
@@ -43,13 +43,13 @@ class CountingState(cirq.qis.QuantumStateRepresentation):
43
43
  )
44
44
 
45
45
  def factor(
46
- self: 'CountingState', axes: Sequence[int], *, validate=True, atol=1e-07
46
+ self, axes: Sequence[int], *, validate=True, atol=1e-07
47
47
  ) -> Tuple['CountingState', 'CountingState']:
48
48
  return CountingState(
49
49
  self.data, self.gate_count, self.measurement_count, self.copy_count
50
50
  ), CountingState(self.data)
51
51
 
52
- def reindex(self: 'CountingState', axes: Sequence[int]) -> 'CountingState':
52
+ def reindex(self, axes: Sequence[int]) -> 'CountingState':
53
53
  return CountingState(self.data, self.gate_count, self.measurement_count, self.copy_count)
54
54
 
55
55
  def copy(self, deep_copy_buffers: bool = True) -> 'CountingState':
@@ -282,5 +282,5 @@ class SparseSimulatorStep(
282
282
  # Dtype doesn't have a good repr, so we work around by invoking __name__.
283
283
  return (
284
284
  f'cirq.SparseSimulatorStep(sim_state={self._sim_state!r},'
285
- f' dtype=np.{self._dtype.__name__})'
285
+ f' dtype=np.{np.dtype(self._dtype)!r})'
286
286
  )
@@ -255,7 +255,7 @@ def test_run_mixture(dtype: Type[np.complexfloating], split: bool):
255
255
  simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
256
256
  circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0))
257
257
  result = simulator.run(circuit, repetitions=100)
258
- assert 20 < sum(result.measurements['q(0)'])[0] < 80 # type: ignore
258
+ assert 20 < sum(result.measurements['q(0)'])[0] < 80
259
259
 
260
260
 
261
261
  @pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@@ -265,8 +265,8 @@ def test_run_mixture_with_gates(dtype: Type[np.complexfloating], split: bool):
265
265
  simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split, seed=23)
266
266
  circuit = cirq.Circuit(cirq.H(q0), cirq.phase_flip(0.5)(q0), cirq.H(q0), cirq.measure(q0))
267
267
  result = simulator.run(circuit, repetitions=100)
268
- assert sum(result.measurements['q(0)'])[0] < 80 # type: ignore
269
- assert sum(result.measurements['q(0)'])[0] > 20 # type: ignore
268
+ assert sum(result.measurements['q(0)'])[0] < 80
269
+ assert sum(result.measurements['q(0)'])[0] > 20
270
270
 
271
271
 
272
272
  @pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@@ -775,9 +775,9 @@ def test_sparse_simulator_repr():
775
775
  # No equality so cannot use cirq.testing.assert_equivalent_repr
776
776
  assert (
777
777
  repr(step) == "cirq.SparseSimulatorStep(sim_state=cirq.StateVectorSimulationState("
778
- "initial_state=np.array([[0j, (1+0j)], [0j, 0j]], dtype=np.complex64), "
778
+ "initial_state=np.array([[0j, (1+0j)], [0j, 0j]], dtype=np.dtype('complex64')), "
779
779
  "qubits=(cirq.LineQubit(0), cirq.LineQubit(1)), "
780
- "classical_data=cirq.ClassicalDataDictionaryStore()), dtype=np.complex64)"
780
+ "classical_data=cirq.ClassicalDataDictionaryStore()), dtype=np.dtype('complex64'))"
781
781
  )
782
782
 
783
783
 
cirq/sim/state_vector.py CHANGED
@@ -19,7 +19,7 @@ from typing import List, Mapping, Optional, Tuple, TYPE_CHECKING, Sequence
19
19
  import numpy as np
20
20
 
21
21
  from cirq import linalg, qis, value
22
- from cirq.sim import simulator
22
+ from cirq.sim import simulator, simulation_utils
23
23
 
24
24
  if TYPE_CHECKING:
25
25
  import cirq
@@ -35,7 +35,7 @@ class StateVectorMixin:
35
35
  """Inits StateVectorMixin.
36
36
 
37
37
  Args:
38
- qubit_map: A map from the Qubits in the Circuit to the the index
38
+ qubit_map: A map from the Qubits in the Circuit to the index
39
39
  of this qubit for a canonical ordering. This canonical ordering
40
40
  is used to define the state (see the state_vector() method).
41
41
  *args: Passed on to the class that this is mixed in with.
@@ -102,7 +102,7 @@ class StateVectorMixin:
102
102
  and non-zero floats of the specified accuracy."""
103
103
  return qis.dirac_notation(self.state_vector(), decimals, qid_shape=self._qid_shape)
104
104
 
105
- def density_matrix_of(self, qubits: List['cirq.Qid'] = None) -> np.ndarray:
105
+ def density_matrix_of(self, qubits: Optional[List['cirq.Qid']] = None) -> np.ndarray:
106
106
  r"""Returns the density matrix of the state.
107
107
 
108
108
  Calculate the density matrix for the system on the qubits provided.
@@ -215,7 +215,8 @@ def sample_state_vector(
215
215
  prng = value.parse_random_state(seed)
216
216
 
217
217
  # Calculate the measurement probabilities.
218
- probs = _probs(state_vector, indices, shape)
218
+ probs = (state_vector * state_vector.conj()).real
219
+ probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)
219
220
 
220
221
  # We now have the probability vector, correctly ordered, so sample over
221
222
  # it. Note that we us ints here, since numpy's choice does not allow for
@@ -288,7 +289,8 @@ def measure_state_vector(
288
289
  initial_shape = state_vector.shape
289
290
 
290
291
  # Calculate the measurement probabilities and then make the measurement.
291
- probs = _probs(state_vector, indices, shape)
292
+ probs = (state_vector * state_vector.conj()).real
293
+ probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)
292
294
  result = prng.choice(len(probs), p=probs)
293
295
  ###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))]
294
296
  # Convert to individual qudit measurements.
@@ -321,34 +323,3 @@ def measure_state_vector(
321
323
  assert out is not None
322
324
  # We mutate and return out, so mypy cannot identify that the out cannot be None.
323
325
  return measurement_bits, out
324
-
325
-
326
- def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray:
327
- """Returns the probabilities for a measurement on the given indices."""
328
- tensor = np.reshape(state, qid_shape)
329
- # Calculate the probabilities for measuring the particular results.
330
- if len(indices) == len(qid_shape):
331
- # We're measuring every qudit, so no need for fancy indexing
332
- probs = np.abs(tensor) ** 2
333
- probs = np.transpose(probs, indices)
334
- probs = probs.reshape(-1)
335
- else:
336
- # Fancy indexing required
337
- meas_shape = tuple(qid_shape[i] for i in indices)
338
- probs = (
339
- np.abs(
340
- [
341
- tensor[
342
- linalg.slice_for_qubits_equal_to(
343
- indices, big_endian_qureg_value=b, qid_shape=qid_shape
344
- )
345
- ]
346
- for b in range(np.prod(meas_shape, dtype=np.int64))
347
- ]
348
- )
349
- ** 2
350
- )
351
- probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))
352
-
353
- # To deal with rounding issues, ensure that the probabilities sum to 1.
354
- return probs / np.sum(probs)
@@ -11,6 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
14
15
  """Objects and methods for acting efficiently on a state vector."""
15
16
  from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
16
17
 
@@ -355,6 +356,22 @@ class StateVectorSimulationState(SimulationState[_BufferedStateVector]):
355
356
  )
356
357
  super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)
357
358
 
359
+ def add_qubits(self, qubits: Sequence['cirq.Qid']):
360
+ ret = super().add_qubits(qubits)
361
+ return (
362
+ self.kronecker_product(type(self)(qubits=qubits), inplace=True)
363
+ if ret is NotImplemented
364
+ else ret
365
+ )
366
+
367
+ def remove_qubits(self, qubits: Sequence['cirq.Qid']):
368
+ ret = super().remove_qubits(qubits)
369
+ if ret is not NotImplemented:
370
+ return ret
371
+ extracted, remainder = self.factor(qubits, inplace=True)
372
+ remainder._state._state_vector *= extracted._state._state_vector.reshape((-1,))[0]
373
+ return remainder
374
+
358
375
  def _act_on_fallback_(
359
376
  self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
360
377
  ) -> bool:
@@ -377,7 +394,7 @@ class StateVectorSimulationState(SimulationState[_BufferedStateVector]):
377
394
  raise TypeError(
378
395
  "Can't simulate operations that don't implement "
379
396
  "SupportsUnitary, SupportsConsistentApplyUnitary, "
380
- "SupportsMixture or is a measurement: {!r}".format(action)
397
+ f"SupportsMixture or is a measurement: {action!r}"
381
398
  )
382
399
 
383
400
  def __repr__(self) -> str:
@@ -20,6 +20,7 @@ import numpy as np
20
20
 
21
21
  from cirq import _compat, ops, value, qis
22
22
  from cirq.sim import simulator, state_vector, simulator_base
23
+ from cirq.protocols import qid_shape
23
24
 
24
25
  if TYPE_CHECKING:
25
26
  import cirq
@@ -31,7 +32,7 @@ TStateVectorStepResult = TypeVar('TStateVectorStepResult', bound='StateVectorSte
31
32
  class SimulatesIntermediateStateVector(
32
33
  Generic[TStateVectorStepResult],
33
34
  simulator_base.SimulatorBase[
34
- TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState',
35
+ TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState'
35
36
  ],
36
37
  simulator.SimulatesAmplitudes,
37
38
  metaclass=abc.ABCMeta,
@@ -172,7 +173,7 @@ class StateVectorTrialResult(
172
173
  size = np.prod(shape, dtype=np.int64)
173
174
  final = final.reshape(size)
174
175
  if len([1 for e in final if abs(e) > 0.001]) < 16:
175
- state_vector = qis.dirac_notation(final, 3)
176
+ state_vector = qis.dirac_notation(final, 3, qid_shape(substate.qubits))
176
177
  else:
177
178
  state_vector = str(final)
178
179
  label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:'
@@ -35,9 +35,9 @@ def test_state_vector_trial_result_repr():
35
35
  expected_repr = (
36
36
  "cirq.StateVectorTrialResult("
37
37
  "params=cirq.ParamResolver({'s': 1}), "
38
- "measurements={'m': np.array([[1]], dtype=np.int32)}, "
38
+ "measurements={'m': np.array([[1]], dtype=np.dtype('int32'))}, "
39
39
  "final_simulator_state=cirq.StateVectorSimulationState("
40
- "initial_state=np.array([0j, (1+0j)], dtype=np.complex64), "
40
+ "initial_state=np.array([0j, (1+0j)], dtype=np.dtype('complex64')), "
41
41
  "qubits=(cirq.NamedQubit('a'),), "
42
42
  "classical_data=cirq.ClassicalDataDictionaryStore()))"
43
43
  )
@@ -159,6 +159,28 @@ def test_str_big():
159
159
  assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result)
160
160
 
161
161
 
162
+ def test_str_qudit():
163
+ qutrit = cirq.LineQid(0, dimension=3)
164
+ final_simulator_state = cirq.StateVectorSimulationState(
165
+ prng=np.random.RandomState(0),
166
+ qubits=[qutrit],
167
+ initial_state=np.array([0, 0, 1]),
168
+ dtype=np.complex64,
169
+ )
170
+ result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state)
171
+ assert "|2⟩" in str(result)
172
+
173
+ ququart = cirq.LineQid(0, dimension=4)
174
+ final_simulator_state = cirq.StateVectorSimulationState(
175
+ prng=np.random.RandomState(0),
176
+ qubits=[ququart],
177
+ initial_state=np.array([0, 1, 0, 0]),
178
+ dtype=np.complex64,
179
+ )
180
+ result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state)
181
+ assert "|1⟩" in str(result)
182
+
183
+
162
184
  def test_pretty_print():
163
185
  final_simulator_state = cirq.StateVectorSimulationState(
164
186
  available_buffer=np.array([1]),
@@ -21,6 +21,7 @@ import numpy as np
21
21
 
22
22
  import cirq
23
23
  import cirq.testing
24
+ from cirq import linalg
24
25
 
25
26
 
26
27
  def test_state_mixin():
@@ -172,7 +173,9 @@ def test_sample_no_indices_repetitions():
172
173
  )
173
174
 
174
175
 
175
- def test_measure_state_computational_basis():
176
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
177
+ def test_measure_state_computational_basis(use_np_transpose: bool):
178
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
176
179
  results = []
177
180
  for x in range(8):
178
181
  initial_state = cirq.to_valid_state_vector(x, 3)
@@ -183,7 +186,9 @@ def test_measure_state_computational_basis():
183
186
  assert results == expected
184
187
 
185
188
 
186
- def test_measure_state_reshape():
189
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
190
+ def test_measure_state_reshape(use_np_transpose: bool):
191
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
187
192
  results = []
188
193
  for x in range(8):
189
194
  initial_state = np.reshape(cirq.to_valid_state_vector(x, 3), [2] * 3)
@@ -194,7 +199,9 @@ def test_measure_state_reshape():
194
199
  assert results == expected
195
200
 
196
201
 
197
- def test_measure_state_partial_indices():
202
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
203
+ def test_measure_state_partial_indices(use_np_transpose: bool):
204
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
198
205
  for index in range(3):
199
206
  for x in range(8):
200
207
  initial_state = cirq.to_valid_state_vector(x, 3)
@@ -203,7 +210,9 @@ def test_measure_state_partial_indices():
203
210
  assert bits == [bool(1 & (x >> (2 - index)))]
204
211
 
205
212
 
206
- def test_measure_state_partial_indices_order():
213
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
214
+ def test_measure_state_partial_indices_order(use_np_transpose: bool):
215
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
207
216
  for x in range(8):
208
217
  initial_state = cirq.to_valid_state_vector(x, 3)
209
218
  bits, state = cirq.measure_state_vector(initial_state, [2, 1])
@@ -211,7 +220,9 @@ def test_measure_state_partial_indices_order():
211
220
  assert bits == [bool(1 & (x >> 0)), bool(1 & (x >> 1))]
212
221
 
213
222
 
214
- def test_measure_state_partial_indices_all_orders():
223
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
224
+ def test_measure_state_partial_indices_all_orders(use_np_transpose: bool):
225
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
215
226
  for perm in itertools.permutations([0, 1, 2]):
216
227
  for x in range(8):
217
228
  initial_state = cirq.to_valid_state_vector(x, 3)
@@ -220,7 +231,9 @@ def test_measure_state_partial_indices_all_orders():
220
231
  assert bits == [bool(1 & (x >> (2 - p))) for p in perm]
221
232
 
222
233
 
223
- def test_measure_state_collapse():
234
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
235
+ def test_measure_state_collapse(use_np_transpose: bool):
236
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
224
237
  initial_state = np.zeros(8, dtype=np.complex64)
225
238
  initial_state[0] = 1 / np.sqrt(2)
226
239
  initial_state[2] = 1 / np.sqrt(2)
@@ -243,7 +256,9 @@ def test_measure_state_collapse():
243
256
  assert bits == [False]
244
257
 
245
258
 
246
- def test_measure_state_seed():
259
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
260
+ def test_measure_state_seed(use_np_transpose: bool):
261
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
247
262
  n = 10
248
263
  initial_state = np.ones(2**n) / 2 ** (n / 2)
249
264
 
@@ -262,7 +277,9 @@ def test_measure_state_seed():
262
277
  np.testing.assert_allclose(state1, state2)
263
278
 
264
279
 
265
- def test_measure_state_out_is_state():
280
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
281
+ def test_measure_state_out_is_state(use_np_transpose: bool):
282
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
266
283
  initial_state = np.zeros(8, dtype=np.complex64)
267
284
  initial_state[0] = 1 / np.sqrt(2)
268
285
  initial_state[2] = 1 / np.sqrt(2)
@@ -273,7 +290,9 @@ def test_measure_state_out_is_state():
273
290
  assert state is initial_state
274
291
 
275
292
 
276
- def test_measure_state_out_is_not_state():
293
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
294
+ def test_measure_state_out_is_not_state(use_np_transpose: bool):
295
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
277
296
  initial_state = np.zeros(8, dtype=np.complex64)
278
297
  initial_state[0] = 1 / np.sqrt(2)
279
298
  initial_state[2] = 1 / np.sqrt(2)
@@ -283,14 +302,18 @@ def test_measure_state_out_is_not_state():
283
302
  assert out is state
284
303
 
285
304
 
286
- def test_measure_state_not_power_of_two():
305
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
306
+ def test_measure_state_not_power_of_two(use_np_transpose: bool):
307
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
287
308
  with pytest.raises(ValueError, match='3'):
288
309
  _, _ = cirq.measure_state_vector(np.array([1, 0, 0]), [1])
289
310
  with pytest.raises(ValueError, match='5'):
290
311
  cirq.measure_state_vector(np.array([0, 1, 0, 0, 0]), [1])
291
312
 
292
313
 
293
- def test_measure_state_index_out_of_range():
314
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
315
+ def test_measure_state_index_out_of_range(use_np_transpose: bool):
316
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
294
317
  state = cirq.to_valid_state_vector(0, 3)
295
318
  with pytest.raises(IndexError, match='-2'):
296
319
  cirq.measure_state_vector(state, [-2])
@@ -298,14 +321,18 @@ def test_measure_state_index_out_of_range():
298
321
  cirq.measure_state_vector(state, [3])
299
322
 
300
323
 
301
- def test_measure_state_no_indices():
324
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
325
+ def test_measure_state_no_indices(use_np_transpose: bool):
326
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
302
327
  initial_state = cirq.to_valid_state_vector(0, 3)
303
328
  bits, state = cirq.measure_state_vector(initial_state, [])
304
329
  assert [] == bits
305
330
  np.testing.assert_almost_equal(state, initial_state)
306
331
 
307
332
 
308
- def test_measure_state_no_indices_out_is_state():
333
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
334
+ def test_measure_state_no_indices_out_is_state(use_np_transpose: bool):
335
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
309
336
  initial_state = cirq.to_valid_state_vector(0, 3)
310
337
  bits, state = cirq.measure_state_vector(initial_state, [], out=initial_state)
311
338
  assert [] == bits
@@ -313,7 +340,9 @@ def test_measure_state_no_indices_out_is_state():
313
340
  assert state is initial_state
314
341
 
315
342
 
316
- def test_measure_state_no_indices_out_is_not_state():
343
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
344
+ def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool):
345
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
317
346
  initial_state = cirq.to_valid_state_vector(0, 3)
318
347
  out = np.zeros_like(initial_state)
319
348
  bits, state = cirq.measure_state_vector(initial_state, [], out=out)
@@ -323,7 +352,9 @@ def test_measure_state_no_indices_out_is_not_state():
323
352
  assert out is not initial_state
324
353
 
325
354
 
326
- def test_measure_state_empty_state():
355
+ @pytest.mark.parametrize('use_np_transpose', [False, True])
356
+ def test_measure_state_empty_state(use_np_transpose: bool):
357
+ linalg.can_numpy_support_shape = lambda s: use_np_transpose
327
358
  initial_state = np.array([1.0])
328
359
  bits, state = cirq.measure_state_vector(initial_state, [])
329
360
  assert [] == bits
cirq/study/__init__.py CHANGED
@@ -38,6 +38,7 @@ from cirq.study.sweeps import (
38
38
  Sweep,
39
39
  UnitSweep,
40
40
  Zip,
41
+ ZipLongest,
41
42
  dict_to_product_sweep,
42
43
  dict_to_zip_sweep,
43
44
  )
@@ -40,7 +40,7 @@ def flatten(val: Any) -> Tuple[Any, 'ExpressionMap']:
40
40
  the name to avoid collision: `sympy.Symbol('<x + 1>_1')`.
41
41
 
42
42
  This function also creates a dictionary mapping from expressions and symbols
43
- in `val` to the new symbols in the flattened copy of `val`. E.g
43
+ in `val` to the new symbols in the flattened copy of `val`. E.g.
44
44
  `cirq.ExpressionMap({sympy.Symbol('x')+1: sympy.Symbol('<x + 1>')})`. This
45
45
  `ExpressionMap` can be used to transform a sweep over the symbols in `val`
46
46
  to a sweep over the flattened symbols e.g. a sweep over `sympy.Symbol('x')`
@@ -200,7 +200,7 @@ class _ParamFlattener(resolver.ParamResolver):
200
200
  self,
201
201
  param_dict: Optional[resolver.ParamResolverOrSimilarType] = None,
202
202
  *, # Force keyword args
203
- get_param_name: Callable[[sympy.Expr], str,] = None,
203
+ get_param_name: Optional[Callable[[sympy.Expr], str]] = None,
204
204
  ):
205
205
  """Initializes a new _ParamFlattener.
206
206
 
cirq/study/resolver.py CHANGED
@@ -150,6 +150,8 @@ class ParamResolver:
150
150
  # The following resolves common sympy expressions
151
151
  # If sympy did its job and wasn't slower than molasses,
152
152
  # we wouldn't need the following block.
153
+ if isinstance(value, sympy.Float):
154
+ return float(value)
153
155
  if isinstance(value, sympy.Add):
154
156
  summation = self.value_of(value.args[0], recursive)
155
157
  for addend in value.args[1:]:
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  """Tests for parameter resolvers."""
16
+
16
17
  import fractions
17
18
 
18
19
  import numpy as np
@@ -241,7 +242,6 @@ def test_custom_resolved_value():
241
242
  assert r.value_of(b) == 'Baz'
242
243
 
243
244
 
244
- @pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
245
245
  def test_custom_value_not_implemented():
246
246
  class Bar:
247
247
  def _resolved_value_(self):
cirq/study/result.py CHANGED
@@ -140,7 +140,7 @@ class Result(abc.ABC):
140
140
  basis = 2 ** np.arange(n, dtype=dtype)[::-1]
141
141
  converted_dict[key] = np.sum(basis * bitstrings, axis=1)
142
142
 
143
- # Use objects to accomodate more than 64 qubits if needed.
143
+ # Use objects to accommodate more than 64 qubits if needed.
144
144
  dtype = object if any(bs.shape[1] > 63 for _, bs in measurements.items()) else np.int64
145
145
  return pd.DataFrame(converted_dict, dtype=dtype)
146
146
 
cirq/study/sweeps.py CHANGED
@@ -18,6 +18,7 @@ from typing import (
18
18
  Iterable,
19
19
  Iterator,
20
20
  List,
21
+ Optional,
21
22
  overload,
22
23
  Sequence,
23
24
  TYPE_CHECKING,
@@ -127,7 +128,7 @@ class Sweep(metaclass=abc.ABCMeta):
127
128
  def __getitem__(self, val: slice) -> 'Sweep':
128
129
  pass
129
130
 
130
- def __getitem__(self, val):
131
+ def __getitem__(self, val: Union[int, slice]) -> Union[resolver.ParamResolver, 'Sweep']:
131
132
  n = len(self)
132
133
  if isinstance(val, int):
133
134
  if val < -n or val >= n:
@@ -292,7 +293,7 @@ class Zip(Sweep):
292
293
  self.sweeps = sweeps
293
294
 
294
295
  def __eq__(self, other):
295
- if not isinstance(other, Zip):
296
+ if type(other) is not Zip:
296
297
  return NotImplemented
297
298
  return self.sweeps == other.sweeps
298
299
 
@@ -327,7 +328,62 @@ class Zip(Sweep):
327
328
 
328
329
  @classmethod
329
330
  def _from_json_dict_(cls, sweeps, **kwargs):
330
- return Zip(*sweeps)
331
+ return cls(*sweeps)
332
+
333
+
334
+ class ZipLongest(Zip):
335
+ """Iterate over constituent sweeps in parallel
336
+
337
+ Analogous to itertools.zip_longest.
338
+ Note that we iterate until all sweeps terminate,
339
+ so if the sweeps are different lengths, the
340
+ shorter sweeps will be filled by repeating their last value
341
+ until all sweeps have equal length.
342
+
343
+ Note that this is different from itertools.zip_longest,
344
+ which uses a fixed fill value.
345
+
346
+ Raises:
347
+ ValueError if an input sweep if completely empty.
348
+ """
349
+
350
+ def __init__(self, *sweeps: Sweep) -> None:
351
+ super().__init__(*sweeps)
352
+ if any(len(sweep) == 0 for sweep in self.sweeps):
353
+ raise ValueError('All sweeps must be non-empty for ZipLongest')
354
+
355
+ def __eq__(self, other):
356
+ if not isinstance(other, ZipLongest):
357
+ return NotImplemented
358
+ return self.sweeps == other.sweeps
359
+
360
+ def __len__(self) -> int:
361
+ if not self.sweeps:
362
+ return 0
363
+ return max(len(sweep) for sweep in self.sweeps)
364
+
365
+ def __hash__(self) -> int:
366
+ return hash((self.__class__.__name__, tuple(self.sweeps)))
367
+
368
+ def __repr__(self) -> str:
369
+ sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
370
+ return f'cirq_google.ZipLongest({sweeps_repr})'
371
+
372
+ def __str__(self) -> str:
373
+ sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
374
+ return f'ZipLongest({sweeps_repr})'
375
+
376
+ def param_tuples(self) -> Iterator[Params]:
377
+ def _iter_and_repeat_last(one_iter: Iterator[Params]):
378
+ last = None
379
+ for last in one_iter:
380
+ yield last
381
+ while True:
382
+ yield last
383
+
384
+ iters = [_iter_and_repeat_last(sweep.param_tuples()) for sweep in self.sweeps]
385
+ for values in itertools.islice(zip(*iters), len(self)):
386
+ yield tuple(item for value in values for item in value)
331
387
 
332
388
 
333
389
  class SingleSweep(Sweep):
@@ -366,9 +422,23 @@ class SingleSweep(Sweep):
366
422
  class Points(SingleSweep):
367
423
  """A simple sweep with explicitly supplied values."""
368
424
 
369
- def __init__(self, key: 'cirq.TParamKey', points: Sequence[float]) -> None:
370
- super(Points, self).__init__(key)
425
+ def __init__(
426
+ self, key: 'cirq.TParamKey', points: Sequence[float], metadata: Optional[Any] = None
427
+ ) -> None:
428
+ """Creates a sweep on a variable with supplied values.
429
+
430
+ Args:
431
+ key: sympy.Symbol or equivalent to sweep across.
432
+ points: sequence of floating point values that represent
433
+ the values to sweep across. The length of the sweep
434
+ will be equivalent to the length of this sequence.
435
+ metadata: Optional metadata to attach to the sweep to
436
+ annotate the sweep or its variable.
437
+
438
+ """
439
+ super().__init__(key)
371
440
  self.points = points
441
+ self.metadata = metadata
372
442
 
373
443
  def _tuple(self) -> Tuple[Union[str, sympy.Expr], Sequence[float]]:
374
444
  return self.key, tuple(self.points)
@@ -380,25 +450,44 @@ class Points(SingleSweep):
380
450
  return iter(self.points)
381
451
 
382
452
  def __repr__(self) -> str:
383
- return f'cirq.Points({self.key!r}, {self.points!r})'
453
+ metadata_repr = f', metadata={self.metadata!r}' if self.metadata is not None else ""
454
+ return f'cirq.Points({self.key!r}, {self.points!r}{metadata_repr})'
384
455
 
385
456
  def _json_dict_(self) -> Dict[str, Any]:
457
+ if self.metadata is not None:
458
+ return protocols.obj_to_dict_helper(self, ["key", "points", "metadata"])
386
459
  return protocols.obj_to_dict_helper(self, ["key", "points"])
387
460
 
388
461
 
389
462
  class Linspace(SingleSweep):
390
463
  """A simple sweep over linearly-spaced values."""
391
464
 
392
- def __init__(self, key: 'cirq.TParamKey', start: float, stop: float, length: int) -> None:
465
+ def __init__(
466
+ self,
467
+ key: 'cirq.TParamKey',
468
+ start: float,
469
+ stop: float,
470
+ length: int,
471
+ metadata: Optional[Any] = None,
472
+ ) -> None:
393
473
  """Creates a linear-spaced sweep for a given key.
394
474
 
395
475
  For the given args, assigns to the list of values
396
476
  start, start + (stop - start) / (length - 1), ..., stop
477
+
478
+ Args:
479
+ key: sympy.Symbol or equivalent to sweep across.
480
+ start: minimum value of linear sweep.
481
+ stop: maximum value of linear sweep.
482
+ length: number of points in the sweep.
483
+ metadata: Optional metadata to attach to the sweep to
484
+ annotate the sweep or its variable.
397
485
  """
398
- super(Linspace, self).__init__(key)
486
+ super().__init__(key)
399
487
  self.start = start
400
488
  self.stop = stop
401
489
  self.length = length
490
+ self.metadata = metadata
402
491
 
403
492
  def _tuple(self) -> Tuple[Union[str, sympy.Expr], float, float, int]:
404
493
  return (self.key, self.start, self.stop, self.length)
@@ -415,12 +504,17 @@ class Linspace(SingleSweep):
415
504
  yield self.start * (1 - p) + self.stop * p
416
505
 
417
506
  def __repr__(self) -> str:
507
+ metadata_repr = f', metadata={self.metadata!r}' if self.metadata is not None else ""
418
508
  return (
419
509
  f'cirq.Linspace({self.key!r}, start={self.start!r}, '
420
- f'stop={self.stop!r}, length={self.length!r})'
510
+ f'stop={self.stop!r}, length={self.length!r}{metadata_repr})'
421
511
  )
422
512
 
423
513
  def _json_dict_(self) -> Dict[str, Any]:
514
+ if self.metadata is not None:
515
+ return protocols.obj_to_dict_helper(
516
+ self, ["key", "start", "stop", "length", "metadata"]
517
+ )
424
518
  return protocols.obj_to_dict_helper(self, ["key", "start", "stop", "length"])
425
519
 
426
520