cirq-core 1.2.0.dev20230717232332__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158) hide show
  1. cirq/__init__.py +5 -0
  2. cirq/_compat.py +26 -11
  3. cirq/_compat_test.py +37 -3
  4. cirq/_version.py +31 -1
  5. cirq/_version_test.py +1 -1
  6. cirq/circuits/circuit.py +106 -32
  7. cirq/circuits/circuit_operation.py +2 -2
  8. cirq/circuits/circuit_operation_test.py +1 -1
  9. cirq/circuits/circuit_test.py +109 -3
  10. cirq/circuits/frozen_circuit.py +80 -5
  11. cirq/circuits/frozen_circuit_test.py +47 -2
  12. cirq/circuits/qasm_output_test.py +9 -9
  13. cirq/conftest.py +1 -2
  14. cirq/contrib/acquaintance/devices.py +1 -1
  15. cirq/contrib/hacks/disable_validation_test.py +1 -1
  16. cirq/contrib/noise_models/noise_models.py +1 -2
  17. cirq/contrib/paulistring/clifford_optimize.py +1 -1
  18. cirq/contrib/paulistring/clifford_target_gateset_test.py +4 -4
  19. cirq/contrib/qcircuit/qcircuit_pdf.py +1 -1
  20. cirq/contrib/quimb/density_matrix.py +2 -3
  21. cirq/contrib/quimb/grid_circuits.py +3 -3
  22. cirq/contrib/quimb/state_vector.py +3 -5
  23. cirq/contrib/routing/utils.py +1 -2
  24. cirq/contrib/svg/svg.py +4 -6
  25. cirq/devices/grid_qubit.py +49 -38
  26. cirq/devices/grid_qubit_test.py +1 -3
  27. cirq/devices/insertion_noise_model.py +21 -1
  28. cirq/devices/insertion_noise_model_test.py +6 -0
  29. cirq/devices/line_qubit.py +67 -40
  30. cirq/devices/named_topologies.py +8 -14
  31. cirq/devices/noise_properties.py +1 -1
  32. cirq/devices/noise_utils.py +7 -5
  33. cirq/devices/noise_utils_test.py +7 -0
  34. cirq/experiments/fidelity_estimation_test.py +1 -1
  35. cirq/experiments/qubit_characterizations.py +6 -5
  36. cirq/experiments/random_quantum_circuit_generation.py +1 -1
  37. cirq/experiments/random_quantum_circuit_generation_test.py +28 -1
  38. cirq/experiments/readout_confusion_matrix.py +6 -6
  39. cirq/experiments/xeb_fitting.py +3 -5
  40. cirq/experiments/xeb_fitting_test.py +2 -2
  41. cirq/experiments/xeb_sampling.py +1 -1
  42. cirq/interop/quirk/url_to_circuit.py +40 -38
  43. cirq/json_resolver_cache.py +2 -0
  44. cirq/linalg/decompositions.py +6 -5
  45. cirq/ops/__init__.py +2 -0
  46. cirq/ops/classically_controlled_operation.py +1 -1
  47. cirq/ops/clifford_gate.py +9 -9
  48. cirq/ops/clifford_gate_test.py +3 -4
  49. cirq/ops/common_channels.py +2 -5
  50. cirq/ops/common_channels_test.py +3 -5
  51. cirq/ops/common_gates_test.py +7 -7
  52. cirq/ops/controlled_operation_test.py +2 -2
  53. cirq/ops/dense_pauli_string.py +3 -0
  54. cirq/ops/eigen_gate_test.py +1 -3
  55. cirq/ops/fourier_transform.py +1 -2
  56. cirq/ops/fsim_gate.py +1 -1
  57. cirq/ops/gate_features_test.py +2 -2
  58. cirq/ops/gate_operation_test.py +1 -2
  59. cirq/ops/greedy_qubit_manager.py +86 -0
  60. cirq/ops/greedy_qubit_manager_test.py +98 -0
  61. cirq/ops/linear_combinations.py +1 -1
  62. cirq/ops/named_qubit.py +55 -18
  63. cirq/ops/parity_gates.py +65 -18
  64. cirq/ops/parity_gates_test.py +41 -2
  65. cirq/ops/pauli_gates.py +2 -2
  66. cirq/ops/pauli_string.py +3 -4
  67. cirq/ops/pauli_string_raw_types_test.py +3 -3
  68. cirq/ops/pauli_string_test.py +3 -4
  69. cirq/ops/random_gate_channel_test.py +3 -3
  70. cirq/ops/raw_types.py +1 -1
  71. cirq/ops/raw_types_test.py +5 -5
  72. cirq/ops/three_qubit_gates.py +12 -8
  73. cirq/protocols/act_on_protocol_test.py +9 -9
  74. cirq/protocols/apply_channel_protocol.py +9 -6
  75. cirq/protocols/apply_unitary_protocol_test.py +1 -1
  76. cirq/protocols/equal_up_to_global_phase_protocol_test.py +2 -2
  77. cirq/protocols/has_stabilizer_effect_protocol.py +52 -6
  78. cirq/protocols/has_stabilizer_effect_protocol_test.py +21 -8
  79. cirq/protocols/has_unitary_protocol_test.py +1 -3
  80. cirq/protocols/json_serialization.py +6 -6
  81. cirq/protocols/json_serialization_test.py +7 -14
  82. cirq/protocols/json_test_data/InsertionNoiseModel.json +91 -0
  83. cirq/protocols/json_test_data/InsertionNoiseModel.repr +4 -0
  84. cirq/protocols/json_test_data/OpIdentifier.json +45 -10
  85. cirq/protocols/json_test_data/OpIdentifier.repr +7 -1
  86. cirq/protocols/json_test_data/spec.py +4 -0
  87. cirq/protocols/measurement_key_protocol_test.py +1 -1
  88. cirq/protocols/unitary_protocol_test.py +13 -16
  89. cirq/qis/clifford_tableau.py +7 -8
  90. cirq/qis/measures.py +1 -1
  91. cirq/qis/states.py +2 -3
  92. cirq/sim/__init__.py +2 -0
  93. cirq/sim/classical_simulator.py +107 -0
  94. cirq/sim/classical_simulator_test.py +207 -0
  95. cirq/sim/clifford/clifford_simulator_test.py +7 -7
  96. cirq/sim/clifford/stabilizer_simulation_state.py +2 -2
  97. cirq/sim/clifford/stabilizer_state_ch_form.py +7 -7
  98. cirq/sim/density_matrix_simulation_state.py +19 -4
  99. cirq/sim/density_matrix_simulator_test.py +5 -13
  100. cirq/sim/simulation_state_test.py +13 -14
  101. cirq/sim/simulator_test.py +6 -9
  102. cirq/sim/state_vector_simulation_state.py +1 -1
  103. cirq/study/resolver.py +41 -41
  104. cirq/study/resolver_test.py +13 -12
  105. cirq/testing/__init__.py +4 -1
  106. cirq/testing/circuit_compare.py +1 -1
  107. cirq/testing/circuit_compare_test.py +11 -11
  108. cirq/testing/consistent_controlled_gate_op.py +15 -1
  109. cirq/testing/consistent_controlled_gate_op_test.py +12 -3
  110. cirq/testing/consistent_decomposition.py +0 -1
  111. cirq/testing/consistent_protocols.py +6 -1
  112. cirq/testing/consistent_protocols_test.py +5 -10
  113. cirq/testing/consistent_qasm.py +2 -4
  114. cirq/testing/consistent_qasm_test.py +2 -3
  115. cirq/testing/consistent_specified_has_unitary_test.py +1 -3
  116. cirq/testing/equals_tester.py +1 -1
  117. cirq/testing/equals_tester_test.py +5 -5
  118. cirq/testing/equivalent_repr_eval_test.py +1 -3
  119. cirq/testing/gate_features_test.py +6 -6
  120. cirq/testing/order_tester_test.py +1 -3
  121. cirq/testing/random_circuit_test.py +1 -3
  122. cirq/transformers/__init__.py +3 -0
  123. cirq/transformers/analytical_decompositions/__init__.py +1 -0
  124. cirq/transformers/analytical_decompositions/three_qubit_decomposition.py +1 -2
  125. cirq/transformers/analytical_decompositions/three_qubit_decomposition_test.py +2 -5
  126. cirq/transformers/analytical_decompositions/two_qubit_state_preparation.py +38 -0
  127. cirq/transformers/analytical_decompositions/two_qubit_state_preparation_test.py +18 -0
  128. cirq/transformers/expand_composite_test.py +4 -4
  129. cirq/transformers/heuristic_decompositions/gate_tabulation_math_utils.py +1 -1
  130. cirq/transformers/heuristic_decompositions/two_qubit_gate_tabulation.py +1 -2
  131. cirq/transformers/merge_k_qubit_gates_test.py +2 -2
  132. cirq/transformers/qubit_management_transformers.py +177 -0
  133. cirq/transformers/qubit_management_transformers_test.py +250 -0
  134. cirq/transformers/routing/route_circuit_cqc.py +23 -4
  135. cirq/transformers/routing/route_circuit_cqc_test.py +42 -0
  136. cirq/transformers/stratify.py +10 -11
  137. cirq/transformers/target_gatesets/compilation_target_gateset_test.py +10 -10
  138. cirq/transformers/target_gatesets/cz_gateset_test.py +8 -10
  139. cirq/transformers/transformer_primitives.py +138 -28
  140. cirq/value/abc_alt_test.py +4 -4
  141. cirq/value/duration.py +68 -37
  142. cirq/value/duration_test.py +2 -0
  143. cirq/value/measurement_key_test.py +1 -1
  144. cirq/value/product_state.py +4 -8
  145. cirq/value/value_equality_attr.py +12 -5
  146. cirq/vis/heatmap.py +7 -4
  147. cirq/vis/heatmap_test.py +14 -4
  148. cirq/vis/histogram.py +4 -4
  149. cirq/vis/state_histogram.py +10 -6
  150. cirq/vis/state_histogram_test.py +2 -0
  151. cirq/work/observable_measurement_data_test.py +1 -1
  152. cirq/work/observable_measurement_test.py +2 -2
  153. cirq/work/zeros_sampler.py +1 -1
  154. {cirq_core-1.2.0.dev20230717232332.dist-info → cirq_core-1.3.0.dist-info}/METADATA +11 -19
  155. {cirq_core-1.2.0.dev20230717232332.dist-info → cirq_core-1.3.0.dist-info}/RECORD +158 -150
  156. {cirq_core-1.2.0.dev20230717232332.dist-info → cirq_core-1.3.0.dist-info}/WHEEL +1 -1
  157. {cirq_core-1.2.0.dev20230717232332.dist-info → cirq_core-1.3.0.dist-info}/LICENSE +0 -0
  158. {cirq_core-1.2.0.dev20230717232332.dist-info → cirq_core-1.3.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from cirq.sim import simulation_state
22
22
  from cirq.testing import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla
23
23
 
24
24
 
25
- class DummyQuantumState(cirq.QuantumStateRepresentation):
25
+ class ExampleQuantumState(cirq.QuantumStateRepresentation):
26
26
  def copy(self, deep_copy_buffers=True):
27
27
  pass
28
28
 
@@ -33,9 +33,9 @@ class DummyQuantumState(cirq.QuantumStateRepresentation):
33
33
  return self
34
34
 
35
35
 
36
- class DummySimulationState(cirq.SimulationState):
36
+ class ExampleSimulationState(cirq.SimulationState):
37
37
  def __init__(self, qubits=cirq.LineQubit.range(2)):
38
- super().__init__(state=DummyQuantumState(), qubits=qubits)
38
+ super().__init__(state=ExampleQuantumState(), qubits=qubits)
39
39
 
40
40
  def _act_on_fallback_(
41
41
  self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
@@ -73,13 +73,13 @@ class Composite(cirq.Gate):
73
73
 
74
74
 
75
75
  def test_measurements():
76
- args = DummySimulationState()
76
+ args = ExampleSimulationState()
77
77
  args.measure([cirq.LineQubit(0)], "test", [False], {})
78
78
  assert args.log_of_measurement_results["test"] == [5]
79
79
 
80
80
 
81
81
  def test_decompose():
82
- args = DummySimulationState()
82
+ args = ExampleSimulationState()
83
83
  assert simulation_state.strat_act_on_from_apply_decompose(
84
84
  Composite(), args, [cirq.LineQubit(0)]
85
85
  )
@@ -91,14 +91,14 @@ def test_decompose_for_gate_allocating_qubits_raises():
91
91
  anc = cirq.NamedQubit("anc")
92
92
  yield cirq.CNOT(*qubits, anc)
93
93
 
94
- args = DummySimulationState()
94
+ args = ExampleSimulationState()
95
95
 
96
96
  with pytest.raises(TypeError, match="add_qubits but not remove_qubits"):
97
97
  simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
98
98
 
99
99
 
100
100
  def test_mapping():
101
- args = DummySimulationState()
101
+ args = ExampleSimulationState()
102
102
  assert list(iter(args)) == cirq.LineQubit.range(2)
103
103
  r1 = args[cirq.LineQubit(0)]
104
104
  assert args is r1
@@ -109,7 +109,7 @@ def test_mapping():
109
109
  def test_swap_bad_dimensions():
110
110
  q0 = cirq.LineQubit(0)
111
111
  q1 = cirq.LineQid(1, 3)
112
- args = DummySimulationState()
112
+ args = ExampleSimulationState()
113
113
  with pytest.raises(ValueError, match='Cannot swap different dimensions'):
114
114
  args.swap(q0, q1)
115
115
 
@@ -117,14 +117,14 @@ def test_swap_bad_dimensions():
117
117
  def test_rename_bad_dimensions():
118
118
  q0 = cirq.LineQubit(0)
119
119
  q1 = cirq.LineQid(1, 3)
120
- args = DummySimulationState()
120
+ args = ExampleSimulationState()
121
121
  with pytest.raises(ValueError, match='Cannot rename to different dimensions'):
122
122
  args.rename(q0, q1)
123
123
 
124
124
 
125
125
  def test_transpose_qubits():
126
126
  q0, q1, q2 = cirq.LineQubit.range(3)
127
- args = DummySimulationState()
127
+ args = ExampleSimulationState()
128
128
  assert args.transpose_to_qubit_order((q1, q0)).qubits == (q1, q0)
129
129
  with pytest.raises(ValueError, match='Qubits do not match'):
130
130
  args.transpose_to_qubit_order((q0, q2))
@@ -133,7 +133,7 @@ def test_transpose_qubits():
133
133
 
134
134
 
135
135
  def test_field_getters():
136
- args = DummySimulationState()
136
+ args = ExampleSimulationState()
137
137
  assert args.prng is np.random
138
138
  assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}
139
139
 
@@ -164,9 +164,8 @@ def test_delegating_gate_channel(exp):
164
164
  control_circuit = cirq.Circuit(cirq.H(q))
165
165
  control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))
166
166
 
167
- with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
168
- # TODO: This test should pass once we extend support to DensityMatrixSimulator.
169
- assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
167
+ assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
168
+ assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
170
169
 
171
170
 
172
171
  @pytest.mark.parametrize('num_ancilla', [1, 2, 3])
@@ -134,8 +134,7 @@ def test_intermediate_simulator():
134
134
 
135
135
  simulator.simulate_moment_steps.side_effect = steps
136
136
  circuit = mock.Mock(cirq.Circuit)
137
- param_resolver = mock.Mock(cirq.ParamResolver)
138
- param_resolver.param_dict = {}
137
+ param_resolver = cirq.ParamResolver({})
139
138
  qubit_order = mock.Mock(cirq.QubitOrder)
140
139
  result = simulator.simulate(
141
140
  program=circuit, param_resolver=param_resolver, qubit_order=qubit_order, initial_state=2
@@ -163,9 +162,7 @@ def test_intermediate_sweeps():
163
162
 
164
163
  simulator.simulate_moment_steps.side_effect = steps
165
164
  circuit = mock.Mock(cirq.Circuit)
166
- param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)]
167
- for resolver in param_resolvers:
168
- resolver.param_dict = {}
165
+ param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})]
169
166
  qubit_order = mock.Mock(cirq.QubitOrder)
170
167
  results = simulator.simulate_sweep(
171
168
  program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2
@@ -435,7 +432,7 @@ def test_monte_carlo_on_unknown_channel():
435
432
 
436
433
 
437
434
  def test_iter_definitions():
438
- dummy_trial_result = SimulationTrialResult(params={}, measurements={}, final_simulator_state=[])
435
+ mock_trial_result = SimulationTrialResult(params={}, measurements={}, final_simulator_state=[])
439
436
 
440
437
  class FakeNonIterSimulatorImpl(
441
438
  SimulatesAmplitudes, SimulatesExpectationValues, SimulatesFinalState
@@ -472,7 +469,7 @@ def test_iter_definitions():
472
469
  qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
473
470
  initial_state: Any = None,
474
471
  ) -> List[SimulationTrialResult]:
475
- return [dummy_trial_result]
472
+ return [mock_trial_result]
476
473
 
477
474
  non_iter_sim = FakeNonIterSimulatorImpl()
478
475
  q0 = cirq.LineQubit(0)
@@ -488,9 +485,9 @@ def test_iter_definitions():
488
485
  ev_iter = non_iter_sim.simulate_expectation_values_sweep_iter(circuit, obs, params)
489
486
  assert next(ev_iter) == [1.0]
490
487
 
491
- assert non_iter_sim.simulate_sweep(circuit, params) == [dummy_trial_result]
488
+ assert non_iter_sim.simulate_sweep(circuit, params) == [mock_trial_result]
492
489
  state_iter = non_iter_sim.simulate_sweep_iter(circuit, params)
493
- assert next(state_iter) == dummy_trial_result
490
+ assert next(state_iter) == mock_trial_result
494
491
 
495
492
 
496
493
  def test_missing_iter_definitions():
@@ -387,7 +387,7 @@ class StateVectorSimulationState(SimulationState[_BufferedStateVector]):
387
387
  for strat in strats:
388
388
  result = strat(action, self, qubits)
389
389
  if result is False:
390
- break # coverage: ignore
390
+ break # pragma: no cover
391
391
  if result is True:
392
392
  return True
393
393
  assert result is NotImplemented, str(result)
cirq/study/resolver.py CHANGED
@@ -36,8 +36,11 @@ document(
36
36
  ParamResolverOrSimilarType, """Something that can be used to turn parameters into values."""
37
37
  )
38
38
 
39
+ # Used to mark values that are not found in a dict.
40
+ _NOT_FOUND = object()
41
+
39
42
  # Used to mark values that are being resolved recursively to detect loops.
40
- _RecursionFlag = object()
43
+ _RECURSION_FLAG = object()
41
44
 
42
45
 
43
46
  def _is_param_resolver_or_similar_type(obj: Any):
@@ -72,7 +75,7 @@ class ParamResolver:
72
75
 
73
76
  self._param_hash: Optional[int] = None
74
77
  self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
75
- for key in self.param_dict:
78
+ for key in self._param_dict:
76
79
  if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
77
80
  raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
78
81
  self._deep_eval_map: ParamDictType = {}
@@ -120,32 +123,30 @@ class ParamResolver:
120
123
  if v is not NotImplemented:
121
124
  return v
122
125
 
123
- # Handles 2 cases:
124
- # Input is a string and maps to a number in the dictionary
125
- # Input is a symbol and maps to a number in the dictionary
126
- # In both cases, return it directly.
127
- if value in self.param_dict:
128
- # Note: if the value is in the dictionary, it will be a key type
129
- # Add a cast to make mypy happy.
130
- param_value = self.param_dict[cast('cirq.TParamKey', value)]
126
+ # Handle string or symbol
127
+ if isinstance(value, (str, sympy.Symbol)):
128
+ string = value if isinstance(value, str) else value.name
129
+ symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value)
130
+ param_value = self._param_dict.get(string, _NOT_FOUND)
131
+ if param_value is _NOT_FOUND:
132
+ param_value = self._param_dict.get(symbol, _NOT_FOUND)
133
+ if param_value is _NOT_FOUND:
134
+ # Symbol or string cannot be resolved if not in param dict; return as symbol.
135
+ return symbol
131
136
  v = _resolve_value(param_value)
132
137
  if v is not NotImplemented:
133
138
  return v
139
+ if isinstance(param_value, str):
140
+ param_value = sympy.Symbol(param_value)
141
+ elif not isinstance(param_value, sympy.Basic):
142
+ return value # type: ignore[return-value]
143
+ if recursive:
144
+ param_value = self._value_of_recursive(value)
145
+ return param_value # type: ignore[return-value]
134
146
 
135
- # Input is a string and is not in the dictionary.
136
- # Treat it as a symbol instead.
137
- if isinstance(value, str):
138
- # If the string is in the param_dict as a value, return it.
139
- # Otherwise, try using the symbol instead.
140
- return self.value_of(sympy.Symbol(value), recursive)
141
-
142
- # Input is a symbol (sympy.Symbol('a')) and its string maps to a number
143
- # in the dictionary ({'a': 1.0}). Return it.
144
- if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
145
- param_value = self.param_dict[value.name]
146
- v = _resolve_value(param_value)
147
- if v is not NotImplemented:
148
- return v
147
+ if not isinstance(value, sympy.Basic):
148
+ # No known way to resolve this variable, return unchanged.
149
+ return value
149
150
 
150
151
  # The following resolves common sympy expressions
151
152
  # If sympy did its job and wasn't slower than molasses,
@@ -171,10 +172,6 @@ class ParamResolver:
171
172
  return np.float_power(cast(complex, base), cast(complex, exponent))
172
173
  return np.power(cast(complex, base), cast(complex, exponent))
173
174
 
174
- if not isinstance(value, sympy.Basic):
175
- # No known way to resolve this variable, return unchanged.
176
- return value
177
-
178
175
  # Input is either a sympy formula or the dictionary maps to a
179
176
  # formula. Use sympy to resolve the value.
180
177
  # Note that sympy.subs() is slow, so we want to avoid this and
@@ -186,7 +183,7 @@ class ParamResolver:
186
183
  # Note that a sympy.SympifyError here likely means
187
184
  # that one of the expressions was not parsable by sympy
188
185
  # (such as a function returning NotImplemented)
189
- v = value.subs(self.param_dict, simultaneous=True)
186
+ v = value.subs(self._param_dict, simultaneous=True)
190
187
 
191
188
  if v.free_symbols:
192
189
  return v
@@ -197,23 +194,26 @@ class ParamResolver:
197
194
  else:
198
195
  return float(v)
199
196
 
197
+ return self._value_of_recursive(value)
198
+
199
+ def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex':
200
200
  # Recursive parameter resolution. We can safely assume that value is a
201
201
  # single symbol, since combinations are handled earlier in the method.
202
202
  if value in self._deep_eval_map:
203
203
  v = self._deep_eval_map[value]
204
- if v is not _RecursionFlag:
205
- return v
206
- raise RecursionError('Evaluation of {value} indirectly contains itself.')
204
+ if v is _RECURSION_FLAG:
205
+ raise RecursionError('Evaluation of {value} indirectly contains itself.')
206
+ return v
207
207
 
208
208
  # There isn't a full evaluation for 'value' yet. Until it's ready,
209
209
  # map value to None to identify loops in component evaluation.
210
- self._deep_eval_map[value] = _RecursionFlag # type: ignore
210
+ self._deep_eval_map[value] = _RECURSION_FLAG # type: ignore
211
211
 
212
212
  v = self.value_of(value, recursive=False)
213
213
  if v == value:
214
214
  self._deep_eval_map[value] = v
215
215
  else:
216
- self._deep_eval_map[value] = self.value_of(v, recursive)
216
+ self._deep_eval_map[value] = self.value_of(v, recursive=True)
217
217
  return self._deep_eval_map[value]
218
218
 
219
219
  def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver':
@@ -224,17 +224,17 @@ class ParamResolver:
224
224
  new_dict.update(
225
225
  {k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc]
226
226
  )
227
- if recursive and self.param_dict:
227
+ if recursive and self._param_dict:
228
228
  new_resolver = ParamResolver(cast(ParamDictType, new_dict))
229
229
  # Resolve down to single-step mappings.
230
230
  return ParamResolver()._resolve_parameters_(new_resolver, recursive=True)
231
231
  return ParamResolver(cast(ParamDictType, new_dict))
232
232
 
233
233
  def __iter__(self) -> Iterator[Union[str, sympy.Expr]]:
234
- return iter(self.param_dict)
234
+ return iter(self._param_dict)
235
235
 
236
236
  def __bool__(self) -> bool:
237
- return bool(self.param_dict)
237
+ return bool(self._param_dict)
238
238
 
239
239
  def __getitem__(
240
240
  self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex']
@@ -243,13 +243,13 @@ class ParamResolver:
243
243
 
244
244
  def __hash__(self) -> int:
245
245
  if self._param_hash is None:
246
- self._param_hash = hash(frozenset(self.param_dict.items()))
246
+ self._param_hash = hash(frozenset(self._param_dict.items()))
247
247
  return self._param_hash
248
248
 
249
249
  def __eq__(self, other):
250
250
  if not isinstance(other, ParamResolver):
251
251
  return NotImplemented
252
- return self.param_dict == other.param_dict
252
+ return self._param_dict == other._param_dict
253
253
 
254
254
  def __ne__(self, other):
255
255
  return not self == other
@@ -257,7 +257,7 @@ class ParamResolver:
257
257
  def __repr__(self) -> str:
258
258
  param_dict_repr = (
259
259
  '{'
260
- + ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()])
260
+ + ', '.join(f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items())
261
261
  + '}'
262
262
  )
263
263
  return f'cirq.ParamResolver({param_dict_repr})'
@@ -265,7 +265,7 @@ class ParamResolver:
265
265
  def _json_dict_(self) -> Dict[str, Any]:
266
266
  return {
267
267
  # JSON requires mappings to have keys of basic types.
268
- 'param_dict': list(self.param_dict.items())
268
+ 'param_dict': list(self._param_dict.items())
269
269
  }
270
270
 
271
271
  @classmethod
@@ -53,10 +53,10 @@ def test_value_of_transformed_types(val, resolved):
53
53
 
54
54
  @pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)])
55
55
  def test_value_of_substituted_types(val, resolved):
56
- _assert_consistent_resolution(val, resolved, True)
56
+ _assert_consistent_resolution(val, resolved)
57
57
 
58
58
 
59
- def _assert_consistent_resolution(v, resolved, subs_called=False):
59
+ def _assert_consistent_resolution(v, resolved):
60
60
  """Asserts that parameter resolution works consistently.
61
61
 
62
62
  The ParamResolver.value_of method can resolve any Sympy expression -
@@ -70,7 +70,7 @@ def _assert_consistent_resolution(v, resolved, subs_called=False):
70
70
  Args:
71
71
  v: the value to resolve
72
72
  resolved: the expected resolution result
73
- subs_called: if True, it is expected that the slow subs method is called
73
+
74
74
  Raises:
75
75
  AssertionError in case resolution assertion fail.
76
76
  """
@@ -93,9 +93,7 @@ def _assert_consistent_resolution(v, resolved, subs_called=False):
93
93
  # symbol based resolution
94
94
  s = SubsAwareSymbol('a')
95
95
  assert r.value_of(s) == resolved, f"expected {resolved}, got {r.value_of(s)}"
96
- assert (
97
- subs_called == s.called
98
- ), f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
96
+ assert not s.called, f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
99
97
  assert isinstance(
100
98
  r.value_of(s), type(resolved)
101
99
  ), f"expected {type(resolved)} got {type(r.value_of(s))}"
@@ -243,15 +241,18 @@ def test_custom_resolved_value():
243
241
 
244
242
 
245
243
  def test_custom_value_not_implemented():
246
- class Bar:
244
+ class BarImplicit:
245
+ pass
246
+
247
+ class BarExplicit:
247
248
  def _resolved_value_(self):
248
249
  return NotImplemented
249
250
 
250
- b = sympy.Symbol('b')
251
- bar = Bar()
252
- r = cirq.ParamResolver({b: bar})
253
- with pytest.raises(sympy.SympifyError):
254
- _ = r.value_of(b)
251
+ for cls in [BarImplicit, BarExplicit]:
252
+ b = sympy.Symbol('b')
253
+ bar = cls()
254
+ r = cirq.ParamResolver({b: bar})
255
+ assert r.value_of(b) == b
255
256
 
256
257
 
257
258
  def test_compose():
cirq/testing/__init__.py CHANGED
@@ -30,7 +30,10 @@ from cirq.testing.consistent_act_on import assert_all_implemented_act_on_effects
30
30
 
31
31
  from cirq.testing.consistent_channels import assert_consistent_channel, assert_consistent_mixture
32
32
 
33
- from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical
33
+ from cirq.testing.consistent_controlled_gate_op import (
34
+ assert_controlled_and_controlled_by_identical,
35
+ assert_controlled_unitary_consistent,
36
+ )
34
37
 
35
38
  from cirq.testing.consistent_decomposition import (
36
39
  assert_decompose_ends_at_default_gateset,
@@ -218,7 +218,7 @@ def _first_differing_moment_index(
218
218
  for i, (m1, m2) in enumerate(itertools.zip_longest(circuit1, circuit2)):
219
219
  if m1 != m2:
220
220
  return i
221
- return None # coverage: ignore
221
+ return None # pragma: no cover
222
222
 
223
223
 
224
224
  def assert_circuits_have_same_unitary_given_final_permutation(
@@ -484,7 +484,7 @@ def test_assert_has_consistent_qid_shape():
484
484
 
485
485
  class ConsistentOp(cirq.Operation):
486
486
  def with_qubits(self, *qubits):
487
- raise NotImplementedError # coverage: ignore
487
+ raise NotImplementedError # pragma: no cover
488
488
 
489
489
  @property
490
490
  def qubits(self):
@@ -496,47 +496,47 @@ def test_assert_has_consistent_qid_shape():
496
496
  def _qid_shape_(self):
497
497
  return (1, 2, 3, 4)
498
498
 
499
- # The 'coverage: ignore' comments in the InconsistentOp classes is needed
499
+ # The 'pragma: no cover' comments in the InconsistentOp classes is needed
500
500
  # because test_assert_has_consistent_qid_shape may only need to check two of
501
501
  # the three methods before finding an inconsistency and throwing an error.
502
502
  class InconsistentOp1(cirq.Operation):
503
503
  def with_qubits(self, *qubits):
504
- raise NotImplementedError # coverage: ignore
504
+ raise NotImplementedError # pragma: no cover
505
505
 
506
506
  @property
507
507
  def qubits(self):
508
508
  return cirq.LineQubit.range(2)
509
509
 
510
510
  def _num_qubits_(self):
511
- return 4 # coverage: ignore
511
+ return 4 # pragma: no cover
512
512
 
513
513
  def _qid_shape_(self):
514
- return (1, 2, 3, 4) # coverage: ignore
514
+ return (1, 2, 3, 4) # pragma: no cover
515
515
 
516
516
  class InconsistentOp2(cirq.Operation):
517
517
  def with_qubits(self, *qubits):
518
- raise NotImplementedError # coverage: ignore
518
+ raise NotImplementedError # pragma: no cover
519
519
 
520
520
  @property
521
521
  def qubits(self):
522
- return cirq.LineQubit.range(4) # coverage: ignore
522
+ return cirq.LineQubit.range(4) # pragma: no cover
523
523
 
524
524
  def _num_qubits_(self):
525
525
  return 2
526
526
 
527
527
  def _qid_shape_(self):
528
- return (1, 2, 3, 4) # coverage: ignore
528
+ return (1, 2, 3, 4) # pragma: no cover
529
529
 
530
530
  class InconsistentOp3(cirq.Operation):
531
531
  def with_qubits(self, *qubits):
532
- raise NotImplementedError # coverage: ignore
532
+ raise NotImplementedError # pragma: no cover
533
533
 
534
534
  @property
535
535
  def qubits(self):
536
- return cirq.LineQubit.range(4) # coverage: ignore
536
+ return cirq.LineQubit.range(4) # pragma: no cover
537
537
 
538
538
  def _num_qubits_(self):
539
- return 4 # coverage: ignore
539
+ return 4 # pragma: no cover
540
540
 
541
541
  def _qid_shape_(self):
542
542
  return 1, 2
@@ -14,7 +14,8 @@
14
14
 
15
15
  from typing import Sequence, Optional, Union, Collection
16
16
 
17
- from cirq import devices, ops
17
+ from cirq import devices, ops, protocols
18
+ import numpy as np
18
19
 
19
20
 
20
21
  def assert_controlled_and_controlled_by_identical(
@@ -34,6 +35,19 @@ def assert_controlled_and_controlled_by_identical(
34
35
  _assert_gate_consistent(gate, num_control, control_value)
35
36
 
36
37
 
38
+ def assert_controlled_unitary_consistent(gate: ops.Gate):
39
+ """Checks that unitary of ControlledGate(gate) is consistent with gate.controlled()."""
40
+
41
+ u_orig = protocols.unitary(ops.ControlledGate(gate))
42
+ u_controlled = protocols.unitary(gate.controlled())
43
+ np.testing.assert_allclose(
44
+ u_orig,
45
+ u_controlled,
46
+ atol=1e-6,
47
+ err_msg=f"Unitary for gate.controlled() is inconsistent for {gate=}",
48
+ )
49
+
50
+
37
51
  def _assert_gate_consistent(
38
52
  gate: ops.Gate,
39
53
  num_controls: int,
@@ -23,8 +23,7 @@ from cirq.ops import control_values as cv
23
23
 
24
24
 
25
25
  class GoodGate(cirq.EigenGate, cirq.testing.SingleQubitGate):
26
- def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
27
- # coverage: ignore
26
+ def _eigen_components(self) -> List[Tuple[float, np.ndarray]]: # pragma: no cover
28
27
  return [(0, np.diag([1, 0])), (1, np.diag([0, 1]))]
29
28
 
30
29
 
@@ -41,7 +40,6 @@ class BadGateOperation(cirq.GateOperation):
41
40
 
42
41
  class BadGate(cirq.EigenGate, cirq.testing.SingleQubitGate):
43
42
  def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
44
- # coverage: ignore
45
43
  return [(0, np.diag([1, 0])), (1, np.diag([0, 1]))]
46
44
 
47
45
  def on(self, *qubits: 'cirq.Qid') -> 'cirq.Operation':
@@ -76,3 +74,14 @@ def test_assert_controlled_and_controlled_by_identical():
76
74
  cirq.testing.assert_controlled_and_controlled_by_identical(
77
75
  GoodGate(), num_controls=[1, 2], control_values=[(1,), (1, 1, 1)]
78
76
  )
77
+
78
+
79
+ def test_assert_controlled_unitary_consistent():
80
+ cirq.testing.assert_controlled_and_controlled_by_identical(
81
+ GoodGate(exponent=0.5, global_shift=1 / 3)
82
+ )
83
+
84
+ with pytest.raises(AssertionError):
85
+ cirq.testing.assert_controlled_and_controlled_by_identical(
86
+ BadGate(exponent=0.5, global_shift=1 / 3)
87
+ )
@@ -54,7 +54,6 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase:
54
54
  if ignoring_global_phase:
55
55
  lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8)
56
56
  else:
57
- # coverage: ignore
58
57
  np.testing.assert_allclose(actual, expected, atol=1e-8)
59
58
 
60
59
 
@@ -36,7 +36,10 @@ from cirq.testing.consistent_pauli_expansion import (
36
36
  from cirq.testing.consistent_resolve_parameters import assert_consistent_resolve_parameters
37
37
  from cirq.testing.consistent_specified_has_unitary import assert_specifies_has_unitary_if_unitary
38
38
  from cirq.testing.equivalent_repr_eval import assert_equivalent_repr
39
- from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical
39
+ from cirq.testing.consistent_controlled_gate_op import (
40
+ assert_controlled_and_controlled_by_identical,
41
+ assert_controlled_unitary_consistent,
42
+ )
40
43
  from cirq.testing.consistent_unitary import assert_unitary_is_consistent
41
44
 
42
45
 
@@ -167,6 +170,8 @@ def _assert_meets_standards_helper(
167
170
  assert_eigen_shifts_is_consistent_with_eigen_components(val)
168
171
  if isinstance(val, ops.Gate) and protocols.has_mixture(val):
169
172
  assert_controlled_and_controlled_by_identical(val)
173
+ if protocols.has_unitary(val):
174
+ assert_controlled_unitary_consistent(val)
170
175
 
171
176
 
172
177
  def assert_commutes_magic_method_consistent_with_unitaries(
@@ -87,8 +87,7 @@ class GoodGate(cirq.testing.SingleQubitGate):
87
87
  def __pow__(self, exponent: Union[float, sympy.Expr]) -> 'GoodGate':
88
88
  new_exponent = cirq.mul(self.exponent, exponent, NotImplemented)
89
89
  if new_exponent is NotImplemented:
90
- # coverage: ignore
91
- return NotImplemented
90
+ return NotImplemented # pragma: no cover
92
91
  return GoodGate(phase_exponent=self.phase_exponent, exponent=new_exponent)
93
92
 
94
93
  def __repr__(self):
@@ -114,8 +113,7 @@ class GoodGate(cirq.testing.SingleQubitGate):
114
113
 
115
114
  def __eq__(self, other):
116
115
  if not isinstance(other, type(self)):
117
- # coverage: ignore
118
- return NotImplemented
116
+ return NotImplemented # pragma: no cover
119
117
  return self._identity_tuple() == other._identity_tuple()
120
118
 
121
119
 
@@ -132,8 +130,7 @@ class BadGateParameterNames(GoodGate):
132
130
  class BadGateApplyUnitaryToTensor(GoodGate):
133
131
  def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> Union[np.ndarray, NotImplementedType]:
134
132
  if self.exponent != 1 or cirq.is_parameterized(self):
135
- # coverage: ignore
136
- return NotImplemented
133
+ return NotImplemented # pragma: no cover
137
134
 
138
135
  zero = cirq.slice_for_qubits_equal_to(args.axes, 0)
139
136
  one = cirq.slice_for_qubits_equal_to(args.axes, 1)
@@ -154,8 +151,7 @@ class BadGateDecompose(GoodGate):
154
151
  z = cirq.Z(q) ** self.phase_exponent
155
152
  x = cirq.X(q) ** (2 * self.exponent)
156
153
  if cirq.is_parameterized(z):
157
- # coverage: ignore
158
- return NotImplemented
154
+ return NotImplemented # pragma: no cover
159
155
  return z**-1, x, z
160
156
 
161
157
 
@@ -176,8 +172,7 @@ class BadGateRepr(GoodGate):
176
172
  def __repr__(self):
177
173
  args = [f'phase_exponent={2 * self.phase_exponent!r}']
178
174
  if self.exponent != 1:
179
- # coverage: ignore
180
- args.append(f'exponent={proper_repr(self.exponent)}')
175
+ args.append(f'exponent={proper_repr(self.exponent)}') # pragma: no cover
181
176
  return f"BadGateRepr({', '.join(args)})"
182
177
 
183
178
 
@@ -27,8 +27,7 @@ def assert_qasm_is_consistent_with_unitary(val: Any):
27
27
  # Only test if qiskit is installed.
28
28
  try:
29
29
  import qiskit
30
- except ImportError:
31
- # coverage: ignore
30
+ except ImportError: # pragma: no cover
32
31
  warnings.warn(
33
32
  "Skipped assert_qasm_is_consistent_with_unitary because "
34
33
  "qiskit isn't installed to verify against."
@@ -101,8 +100,7 @@ qreg q[{num_qubits}];
101
100
  )
102
101
 
103
102
 
104
- def assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, unitary):
105
- # coverage: ignore
103
+ def assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, unitary): # pragma: no cover
106
104
  try:
107
105
  # We don't want to require qiskit as a dependency but
108
106
  # if Qiskit is installed, test QASM output against it.