cirq-core 1.2.0.dev20230717225858__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. cirq/__init__.py +5 -0
  2. cirq/_compat.py +26 -11
  3. cirq/_compat_test.py +37 -3
  4. cirq/_version.py +31 -1
  5. cirq/_version_test.py +1 -1
  6. cirq/circuits/circuit.py +106 -32
  7. cirq/circuits/circuit_operation.py +2 -2
  8. cirq/circuits/circuit_operation_test.py +1 -1
  9. cirq/circuits/circuit_test.py +109 -3
  10. cirq/circuits/frozen_circuit.py +80 -5
  11. cirq/circuits/frozen_circuit_test.py +47 -2
  12. cirq/circuits/qasm_output_test.py +9 -9
  13. cirq/conftest.py +1 -2
  14. cirq/contrib/acquaintance/devices.py +1 -1
  15. cirq/contrib/hacks/disable_validation_test.py +1 -1
  16. cirq/contrib/noise_models/noise_models.py +1 -2
  17. cirq/contrib/paulistring/clifford_optimize.py +1 -1
  18. cirq/contrib/paulistring/clifford_target_gateset_test.py +4 -4
  19. cirq/contrib/qcircuit/qcircuit_pdf.py +1 -1
  20. cirq/contrib/quimb/density_matrix.py +2 -3
  21. cirq/contrib/quimb/grid_circuits.py +3 -3
  22. cirq/contrib/quimb/state_vector.py +3 -5
  23. cirq/contrib/routing/utils.py +1 -2
  24. cirq/contrib/svg/svg.py +4 -6
  25. cirq/devices/grid_qubit.py +49 -38
  26. cirq/devices/grid_qubit_test.py +1 -3
  27. cirq/devices/insertion_noise_model.py +21 -1
  28. cirq/devices/insertion_noise_model_test.py +6 -0
  29. cirq/devices/line_qubit.py +67 -40
  30. cirq/devices/named_topologies.py +8 -14
  31. cirq/devices/noise_properties.py +1 -1
  32. cirq/devices/noise_utils.py +7 -5
  33. cirq/devices/noise_utils_test.py +7 -0
  34. cirq/experiments/fidelity_estimation_test.py +1 -1
  35. cirq/experiments/qubit_characterizations.py +6 -5
  36. cirq/experiments/random_quantum_circuit_generation.py +1 -1
  37. cirq/experiments/random_quantum_circuit_generation_test.py +28 -1
  38. cirq/experiments/readout_confusion_matrix.py +6 -6
  39. cirq/experiments/xeb_fitting.py +3 -5
  40. cirq/experiments/xeb_fitting_test.py +2 -2
  41. cirq/experiments/xeb_sampling.py +1 -1
  42. cirq/interop/quirk/url_to_circuit.py +40 -38
  43. cirq/json_resolver_cache.py +2 -0
  44. cirq/linalg/decompositions.py +6 -5
  45. cirq/ops/__init__.py +2 -0
  46. cirq/ops/classically_controlled_operation.py +1 -1
  47. cirq/ops/clifford_gate.py +9 -9
  48. cirq/ops/clifford_gate_test.py +3 -4
  49. cirq/ops/common_channels.py +2 -5
  50. cirq/ops/common_channels_test.py +3 -5
  51. cirq/ops/common_gates_test.py +7 -7
  52. cirq/ops/controlled_operation_test.py +2 -2
  53. cirq/ops/dense_pauli_string.py +3 -0
  54. cirq/ops/eigen_gate_test.py +1 -3
  55. cirq/ops/fourier_transform.py +1 -2
  56. cirq/ops/fsim_gate.py +1 -1
  57. cirq/ops/gate_features_test.py +2 -2
  58. cirq/ops/gate_operation_test.py +1 -2
  59. cirq/ops/greedy_qubit_manager.py +86 -0
  60. cirq/ops/greedy_qubit_manager_test.py +98 -0
  61. cirq/ops/linear_combinations.py +1 -1
  62. cirq/ops/named_qubit.py +55 -18
  63. cirq/ops/parity_gates.py +65 -18
  64. cirq/ops/parity_gates_test.py +41 -2
  65. cirq/ops/pauli_gates.py +2 -2
  66. cirq/ops/pauli_string.py +3 -4
  67. cirq/ops/pauli_string_raw_types_test.py +3 -3
  68. cirq/ops/pauli_string_test.py +3 -4
  69. cirq/ops/random_gate_channel_test.py +3 -3
  70. cirq/ops/raw_types.py +1 -1
  71. cirq/ops/raw_types_test.py +5 -5
  72. cirq/ops/three_qubit_gates.py +12 -8
  73. cirq/protocols/act_on_protocol_test.py +9 -9
  74. cirq/protocols/apply_channel_protocol.py +9 -6
  75. cirq/protocols/apply_unitary_protocol_test.py +1 -1
  76. cirq/protocols/equal_up_to_global_phase_protocol_test.py +2 -2
  77. cirq/protocols/has_stabilizer_effect_protocol.py +52 -6
  78. cirq/protocols/has_stabilizer_effect_protocol_test.py +21 -8
  79. cirq/protocols/has_unitary_protocol_test.py +1 -3
  80. cirq/protocols/json_serialization.py +6 -6
  81. cirq/protocols/json_serialization_test.py +7 -14
  82. cirq/protocols/json_test_data/InsertionNoiseModel.json +91 -0
  83. cirq/protocols/json_test_data/InsertionNoiseModel.repr +4 -0
  84. cirq/protocols/json_test_data/OpIdentifier.json +45 -10
  85. cirq/protocols/json_test_data/OpIdentifier.repr +7 -1
  86. cirq/protocols/json_test_data/spec.py +4 -0
  87. cirq/protocols/measurement_key_protocol_test.py +1 -1
  88. cirq/protocols/unitary_protocol_test.py +13 -16
  89. cirq/qis/clifford_tableau.py +7 -8
  90. cirq/qis/measures.py +1 -1
  91. cirq/qis/states.py +2 -3
  92. cirq/sim/__init__.py +2 -0
  93. cirq/sim/classical_simulator.py +107 -0
  94. cirq/sim/classical_simulator_test.py +207 -0
  95. cirq/sim/clifford/clifford_simulator_test.py +7 -7
  96. cirq/sim/clifford/stabilizer_simulation_state.py +2 -2
  97. cirq/sim/clifford/stabilizer_state_ch_form.py +7 -7
  98. cirq/sim/density_matrix_simulation_state.py +19 -4
  99. cirq/sim/density_matrix_simulator_test.py +5 -13
  100. cirq/sim/simulation_state_test.py +13 -14
  101. cirq/sim/simulator_test.py +6 -9
  102. cirq/sim/state_vector_simulation_state.py +1 -1
  103. cirq/study/resolver.py +41 -41
  104. cirq/study/resolver_test.py +13 -12
  105. cirq/testing/__init__.py +4 -1
  106. cirq/testing/circuit_compare.py +1 -1
  107. cirq/testing/circuit_compare_test.py +11 -11
  108. cirq/testing/consistent_controlled_gate_op.py +15 -1
  109. cirq/testing/consistent_controlled_gate_op_test.py +12 -3
  110. cirq/testing/consistent_decomposition.py +0 -1
  111. cirq/testing/consistent_protocols.py +6 -1
  112. cirq/testing/consistent_protocols_test.py +5 -10
  113. cirq/testing/consistent_qasm.py +2 -4
  114. cirq/testing/consistent_qasm_test.py +2 -3
  115. cirq/testing/consistent_specified_has_unitary_test.py +1 -3
  116. cirq/testing/equals_tester.py +1 -1
  117. cirq/testing/equals_tester_test.py +5 -5
  118. cirq/testing/equivalent_repr_eval_test.py +1 -3
  119. cirq/testing/gate_features_test.py +6 -6
  120. cirq/testing/order_tester_test.py +1 -3
  121. cirq/testing/random_circuit_test.py +1 -3
  122. cirq/transformers/__init__.py +3 -0
  123. cirq/transformers/analytical_decompositions/__init__.py +1 -0
  124. cirq/transformers/analytical_decompositions/three_qubit_decomposition.py +1 -2
  125. cirq/transformers/analytical_decompositions/three_qubit_decomposition_test.py +2 -5
  126. cirq/transformers/analytical_decompositions/two_qubit_state_preparation.py +38 -0
  127. cirq/transformers/analytical_decompositions/two_qubit_state_preparation_test.py +18 -0
  128. cirq/transformers/expand_composite_test.py +4 -4
  129. cirq/transformers/heuristic_decompositions/gate_tabulation_math_utils.py +1 -1
  130. cirq/transformers/heuristic_decompositions/two_qubit_gate_tabulation.py +1 -2
  131. cirq/transformers/merge_k_qubit_gates_test.py +2 -2
  132. cirq/transformers/qubit_management_transformers.py +177 -0
  133. cirq/transformers/qubit_management_transformers_test.py +250 -0
  134. cirq/transformers/routing/route_circuit_cqc.py +23 -4
  135. cirq/transformers/routing/route_circuit_cqc_test.py +42 -0
  136. cirq/transformers/stratify.py +10 -11
  137. cirq/transformers/target_gatesets/compilation_target_gateset_test.py +10 -10
  138. cirq/transformers/target_gatesets/cz_gateset_test.py +8 -10
  139. cirq/transformers/transformer_primitives.py +138 -28
  140. cirq/value/abc_alt_test.py +4 -4
  141. cirq/value/duration.py +68 -37
  142. cirq/value/duration_test.py +2 -0
  143. cirq/value/measurement_key_test.py +1 -1
  144. cirq/value/product_state.py +4 -8
  145. cirq/value/value_equality_attr.py +12 -5
  146. cirq/vis/heatmap.py +7 -4
  147. cirq/vis/heatmap_test.py +14 -4
  148. cirq/vis/histogram.py +4 -4
  149. cirq/vis/state_histogram.py +10 -6
  150. cirq/vis/state_histogram_test.py +2 -0
  151. cirq/work/observable_measurement_data_test.py +1 -1
  152. cirq/work/observable_measurement_test.py +2 -2
  153. cirq/work/zeros_sampler.py +1 -1
  154. {cirq_core-1.2.0.dev20230717225858.dist-info → cirq_core-1.3.0.dist-info}/METADATA +11 -19
  155. {cirq_core-1.2.0.dev20230717225858.dist-info → cirq_core-1.3.0.dist-info}/RECORD +158 -150
  156. {cirq_core-1.2.0.dev20230717225858.dist-info → cirq_core-1.3.0.dist-info}/WHEEL +1 -1
  157. {cirq_core-1.2.0.dev20230717225858.dist-info → cirq_core-1.3.0.dist-info}/LICENSE +0 -0
  158. {cirq_core-1.2.0.dev20230717225858.dist-info → cirq_core-1.3.0.dist-info}/top_level.txt +0 -0
@@ -104,6 +104,134 @@ def map_moments(
104
104
  )
105
105
 
106
106
 
107
+ def _map_operations_impl(
108
+ circuit: CIRCUIT_TYPE,
109
+ map_func: Callable[[ops.Operation, int], ops.OP_TREE],
110
+ *,
111
+ deep: bool = False,
112
+ raise_if_add_qubits=True,
113
+ tags_to_ignore: Sequence[Hashable] = (),
114
+ wrap_in_circuit_op: bool = True,
115
+ ) -> CIRCUIT_TYPE:
116
+ """Applies local transformations, by calling `map_func(op, moment_index)` for each operation.
117
+
118
+ This method provides a fast, iterative implementation for the two `map_operations_*` variants
119
+ exposed as public transformer primitives. The high level idea for the iterative implementation
120
+ is to
121
+ 1) For each operation `op`, find the corresponding mapped operation(s) `mapped_ops`. The
122
+ set of mapped operations can be either wrapped in a circuit operation or not, depending
123
+ on the value of flag `wrap_in_circuit_op` and whether the mapped operations will end up
124
+ occupying more than one moment or not.
125
+ 2) Use the `get_earliest_accommodating_moment_index` infrastructure built for `cirq.Circuit`
126
+ construction to determine the index at which the mapped operations should be inserted.
127
+ This step takes care of the nuances that arise due to (a) preserving moment structure
128
+ and (b) mapped operations spanning across multiple moments (these both are trivial when
129
+ `op` is mapped to a single `mapped_op` that acts on the same set of qubits).
130
+
131
+ By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is
132
+ True.
133
+
134
+ Args:
135
+ circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
136
+ map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the
137
+ resulting optree spans more than 1 moment, it's either wrapped in a tagged circuit
138
+ operation and inserted in-place in the same moment (if `wrap_in_circuit_op` is True)
139
+ OR the mapped operations are inserted directly in the circuit, preserving moment
140
+ strucutre. The effect is equivalent to (but much faster) a two-step approach of first
141
+ wrapping the operations in a circuit operation and then calling `cirq.unroll_circuit_op`
142
+ to unroll the corresponding circuit ops.
143
+ deep: If true, `map_func` will be recursively applied to circuits wrapped inside
144
+ any circuit operations contained within `circuit`.
145
+ raise_if_add_qubits: Set to True by default. If True, raises ValueError if
146
+ `map_func(op, idx)` adds operations on qubits outside of `op.qubits`.
147
+ tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
148
+ tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
149
+ satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
150
+ wrap_in_circuit_op: If True, the mapped operations will be wrapped in a tagged circuit
151
+ operation and inserted in-place if they occupy more than one moment.
152
+
153
+ Raises:
154
+ ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and
155
+ `raise_if_add_qubits is True`.
156
+
157
+ Returns:
158
+ Copy of input circuit with mapped operations.
159
+ """
160
+ tags_to_ignore_set = set(tags_to_ignore)
161
+
162
+ def apply_map_func(op: 'cirq.Operation', idx: int) -> List['cirq.Operation']:
163
+ if tags_to_ignore_set.intersection(op.tags):
164
+ return [op]
165
+ if deep and isinstance(op.untagged, circuits.CircuitOperation):
166
+ op = op.untagged.replace(
167
+ circuit=_map_operations_impl(
168
+ op.untagged.circuit,
169
+ map_func,
170
+ deep=deep,
171
+ raise_if_add_qubits=raise_if_add_qubits,
172
+ tags_to_ignore=tags_to_ignore,
173
+ wrap_in_circuit_op=wrap_in_circuit_op,
174
+ )
175
+ ).with_tags(*op.tags)
176
+ mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))]
177
+ op_qubits = set(op.qubits)
178
+ mapped_ops_qubits: Set['cirq.Qid'] = set()
179
+ has_overlapping_ops = False
180
+ for mapped_op in mapped_ops:
181
+ if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits):
182
+ raise ValueError(
183
+ f"Mapped operations {mapped_ops} should act on a subset "
184
+ f"of qubits of the original operation {op}"
185
+ )
186
+ if mapped_ops_qubits.intersection(mapped_op.qubits):
187
+ has_overlapping_ops = True
188
+ mapped_ops_qubits = mapped_ops_qubits.union(mapped_op.qubits)
189
+ if wrap_in_circuit_op and has_overlapping_ops:
190
+ # Mapped operations should be wrapped in a `CircuitOperation` only iff they occupy more
191
+ # than one moment, i.e. there are at least two operations that share a qubit.
192
+ mapped_ops = [
193
+ circuits.CircuitOperation(circuits.FrozenCircuit(mapped_ops)).with_tags(
194
+ MAPPED_CIRCUIT_OP_TAG
195
+ )
196
+ ]
197
+ return mapped_ops
198
+
199
+ new_moments: List[List['cirq.Operation']] = []
200
+
201
+ # Keep track of the latest time index for each qubit, measurement key, and control key.
202
+ qubit_time_index: Dict['cirq.Qid', int] = {}
203
+ measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
204
+ control_time_index: Dict['cirq.MeasurementKey', int] = {}
205
+
206
+ # New mapped operations in the current moment should be inserted after `last_moment_time_index`.
207
+ last_moment_time_index = -1
208
+
209
+ for idx, moment in enumerate(circuit):
210
+ if wrap_in_circuit_op:
211
+ new_moments.append([])
212
+ for op in moment:
213
+ mapped_ops = apply_map_func(op, idx)
214
+
215
+ for mapped_op in mapped_ops:
216
+ # Identify the earliest moment that can accommodate this op.
217
+ placement_index = circuits.circuit.get_earliest_accommodating_moment_index(
218
+ mapped_op, qubit_time_index, measurement_time_index, control_time_index
219
+ )
220
+ placement_index = max(placement_index, last_moment_time_index + 1)
221
+ new_moments.extend([[] for _ in range(placement_index - len(new_moments) + 1)])
222
+ new_moments[placement_index].append(mapped_op)
223
+ for qubit in mapped_op.qubits:
224
+ qubit_time_index[qubit] = placement_index
225
+ for key in protocols.measurement_key_objs(mapped_op):
226
+ measurement_time_index[key] = placement_index
227
+ for key in protocols.control_keys(mapped_op):
228
+ control_time_index[key] = placement_index
229
+
230
+ last_moment_time_index = len(new_moments) - 1
231
+
232
+ return _create_target_circuit_type([circuits.Moment(moment) for moment in new_moments], circuit)
233
+
234
+
107
235
  def map_operations(
108
236
  circuit: CIRCUIT_TYPE,
109
237
  map_func: Callable[[ops.Operation, int], ops.OP_TREE],
@@ -139,29 +267,13 @@ def map_operations(
139
267
  Returns:
140
268
  Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation).
141
269
  """
142
-
143
- def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
144
- if not set(op.tags).isdisjoint(tags_to_ignore):
145
- return op
146
- c = circuits.FrozenCircuit(map_func(op, idx))
147
- if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits):
148
- raise ValueError(
149
- f"Mapped operations {c.all_operations()} should act on a subset "
150
- f"of qubits of the original operation {op}"
151
- )
152
- if len(c) <= 1:
153
- # Either empty circuit or all operations act in the same moment;
154
- # So, we don't need to wrap them in a circuit_op.
155
- return c[0].operations if c else []
156
- circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG)
157
- return circuit_op
158
-
159
- return map_moments(
270
+ return _map_operations_impl(
160
271
  circuit,
161
- lambda m, i: circuits.Circuit(apply_map(op, i) for op in m.operations).moments
162
- or [circuits.Moment()],
272
+ map_func,
163
273
  deep=deep,
274
+ raise_if_add_qubits=raise_if_add_qubits,
164
275
  tags_to_ignore=tags_to_ignore,
276
+ wrap_in_circuit_op=True,
165
277
  )
166
278
 
167
279
 
@@ -191,15 +303,13 @@ def map_operations_and_unroll(
191
303
  Returns:
192
304
  Copy of input circuit with mapped operations, unrolled in a moment preserving way.
193
305
  """
194
- return unroll_circuit_op(
195
- map_operations(
196
- circuit,
197
- map_func,
198
- deep=deep,
199
- raise_if_add_qubits=raise_if_add_qubits,
200
- tags_to_ignore=tags_to_ignore,
201
- ),
306
+ return _map_operations_impl(
307
+ circuit,
308
+ map_func,
202
309
  deep=deep,
310
+ raise_if_add_qubits=raise_if_add_qubits,
311
+ tags_to_ignore=tags_to_ignore,
312
+ wrap_in_circuit_op=False,
203
313
  )
204
314
 
205
315
 
@@ -140,7 +140,7 @@ def test_classcell_in_namespace():
140
140
  class _(metaclass=ABCMetaImplementAnyOneOf):
141
141
  def other_method(self):
142
142
  # Triggers __classcell__ to be added to the class namespace
143
- super() # coverage: ignore
143
+ super() # pragma: no cover
144
144
 
145
145
 
146
146
  def test_two_alternatives():
@@ -170,17 +170,17 @@ def test_two_alternatives():
170
170
  return 'alt1'
171
171
 
172
172
  def alt2(self) -> NoReturn:
173
- raise RuntimeError # coverage: ignore
173
+ raise RuntimeError # pragma: no cover
174
174
 
175
175
  class TwoAlternativesOverride(TwoAlternatives):
176
176
  def my_method(self, arg, kw=99) -> str:
177
177
  return 'override'
178
178
 
179
179
  def alt1(self) -> NoReturn:
180
- raise RuntimeError # coverage: ignore
180
+ raise RuntimeError # pragma: no cover
181
181
 
182
182
  def alt2(self) -> NoReturn:
183
- raise RuntimeError # coverage: ignore
183
+ raise RuntimeError # pragma: no cover
184
184
 
185
185
  class TwoAlternativesForceSecond(TwoAlternatives):
186
186
  def _do_alt1_with_my_method(self):
cirq/value/duration.py CHANGED
@@ -13,14 +13,14 @@
13
13
  # limitations under the License.
14
14
  """A typed time delta that supports picosecond accuracy."""
15
15
 
16
- from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
16
+ from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union, List
17
17
  import datetime
18
18
 
19
19
  import sympy
20
20
  import numpy as np
21
21
 
22
22
  from cirq import protocols
23
- from cirq._compat import proper_repr
23
+ from cirq._compat import proper_repr, cached_method
24
24
  from cirq._doc import document
25
25
 
26
26
  if TYPE_CHECKING:
@@ -79,48 +79,53 @@ class Duration:
79
79
  >>> print(cirq.Duration(micros=1.5 * sympy.Symbol('t')))
80
80
  (1500.0*t) ns
81
81
  """
82
+ self._time_vals: List[_NUMERIC_INPUT_TYPE] = [0, 0, 0, 0]
83
+ self._multipliers = [1, 1000, 1000_000, 1000_000_000]
82
84
  if value is not None and value != 0:
83
85
  if isinstance(value, datetime.timedelta):
84
86
  # timedelta has microsecond resolution.
85
- micros += int(value / datetime.timedelta(microseconds=1))
87
+ self._time_vals[2] = int(value / datetime.timedelta(microseconds=1))
86
88
  elif isinstance(value, Duration):
87
- picos += value._picos
89
+ self._time_vals = value._time_vals
88
90
  else:
89
91
  raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.')
90
-
91
- val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
92
- self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val
92
+ input_vals = [picos, nanos, micros, millis]
93
+ self._time_vals = _add_time_vals(self._time_vals, input_vals)
93
94
 
94
95
  def _is_parameterized_(self) -> bool:
95
- return protocols.is_parameterized(self._picos)
96
+ return protocols.is_parameterized(self._time_vals)
96
97
 
97
98
  def _parameter_names_(self) -> AbstractSet[str]:
98
- return protocols.parameter_names(self._picos)
99
+ return protocols.parameter_names(self._time_vals)
99
100
 
100
101
  def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Duration':
101
- return Duration(picos=protocols.resolve_parameters(self._picos, resolver, recursive))
102
+ return _duration_from_time_vals(
103
+ protocols.resolve_parameters(self._time_vals, resolver, recursive)
104
+ )
102
105
 
106
+ @cached_method
103
107
  def total_picos(self) -> _NUMERIC_OUTPUT_TYPE:
104
108
  """Returns the number of picoseconds that the duration spans."""
105
- return self._picos
109
+ val = sum(a * b for a, b in zip(self._time_vals, self._multipliers))
110
+ return float(val) if isinstance(val, np.number) else val
106
111
 
107
112
  def total_nanos(self) -> _NUMERIC_OUTPUT_TYPE:
108
113
  """Returns the number of nanoseconds that the duration spans."""
109
- return self._picos / 1000
114
+ return self.total_picos() / 1000
110
115
 
111
116
  def total_micros(self) -> _NUMERIC_OUTPUT_TYPE:
112
117
  """Returns the number of microseconds that the duration spans."""
113
- return self._picos / 1000_000
118
+ return self.total_picos() / 1000_000
114
119
 
115
120
  def total_millis(self) -> _NUMERIC_OUTPUT_TYPE:
116
121
  """Returns the number of milliseconds that the duration spans."""
117
- return self._picos / 1000_000_000
122
+ return self.total_picos() / 1000_000_000
118
123
 
119
124
  def __add__(self, other) -> 'Duration':
120
125
  other = _attempt_duration_like_to_duration(other)
121
126
  if other is None:
122
127
  return NotImplemented
123
- return Duration(picos=self._picos + other._picos)
128
+ return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals))
124
129
 
125
130
  def __radd__(self, other) -> 'Duration':
126
131
  return self.__add__(other)
@@ -129,29 +134,36 @@ class Duration:
129
134
  other = _attempt_duration_like_to_duration(other)
130
135
  if other is None:
131
136
  return NotImplemented
132
- return Duration(picos=self._picos - other._picos)
137
+ return _duration_from_time_vals(
138
+ _add_time_vals(self._time_vals, [-x for x in other._time_vals])
139
+ )
133
140
 
134
141
  def __rsub__(self, other) -> 'Duration':
135
142
  other = _attempt_duration_like_to_duration(other)
136
143
  if other is None:
137
144
  return NotImplemented
138
- return Duration(picos=other._picos - self._picos)
145
+ return _duration_from_time_vals(
146
+ _add_time_vals(other._time_vals, [-x for x in self._time_vals])
147
+ )
139
148
 
140
149
  def __mul__(self, other) -> 'Duration':
141
150
  if not isinstance(other, (int, float, sympy.Expr)):
142
151
  return NotImplemented
143
- return Duration(picos=self._picos * other)
152
+ if other == 0:
153
+ return _duration_from_time_vals([0] * 4)
154
+ return _duration_from_time_vals([x * other for x in self._time_vals])
144
155
 
145
156
  def __rmul__(self, other) -> 'Duration':
146
157
  return self.__mul__(other)
147
158
 
148
159
  def __truediv__(self, other) -> Union['Duration', float]:
149
160
  if isinstance(other, (int, float, sympy.Expr)):
150
- return Duration(picos=self._picos / other)
161
+ new_time_vals = [x / other for x in self._time_vals]
162
+ return _duration_from_time_vals(new_time_vals)
151
163
 
152
164
  other_duration = _attempt_duration_like_to_duration(other)
153
165
  if other_duration is not None:
154
- return self._picos / other_duration._picos
166
+ return self.total_picos() / other_duration.total_picos()
155
167
 
156
168
  return NotImplemented
157
169
 
@@ -159,56 +171,57 @@ class Duration:
159
171
  other = _attempt_duration_like_to_duration(other)
160
172
  if other is None:
161
173
  return NotImplemented
162
- return self._picos == other._picos
174
+ return self.total_picos() == other.total_picos()
163
175
 
164
176
  def __ne__(self, other):
165
177
  other = _attempt_duration_like_to_duration(other)
166
178
  if other is None:
167
179
  return NotImplemented
168
- return self._picos != other._picos
180
+ return self.total_picos() != other.total_picos()
169
181
 
170
182
  def __gt__(self, other):
171
183
  other = _attempt_duration_like_to_duration(other)
172
184
  if other is None:
173
185
  return NotImplemented
174
- return self._picos > other._picos
186
+ return self.total_picos() > other.total_picos()
175
187
 
176
188
  def __lt__(self, other):
177
189
  other = _attempt_duration_like_to_duration(other)
178
190
  if other is None:
179
191
  return NotImplemented
180
- return self._picos < other._picos
192
+ return self.total_picos() < other.total_picos()
181
193
 
182
194
  def __ge__(self, other):
183
195
  other = _attempt_duration_like_to_duration(other)
184
196
  if other is None:
185
197
  return NotImplemented
186
- return self._picos >= other._picos
198
+ return self.total_picos() >= other.total_picos()
187
199
 
188
200
  def __le__(self, other):
189
201
  other = _attempt_duration_like_to_duration(other)
190
202
  if other is None:
191
203
  return NotImplemented
192
- return self._picos <= other._picos
204
+ return self.total_picos() <= other.total_picos()
193
205
 
194
206
  def __bool__(self):
195
- return bool(self._picos)
207
+ return bool(self.total_picos())
196
208
 
197
209
  def __hash__(self):
198
- if isinstance(self._picos, (int, float)) and self._picos % 1000000 == 0:
199
- return hash(datetime.timedelta(microseconds=self._picos / 1000000))
200
- return hash((Duration, self._picos))
210
+ if isinstance(self.total_picos(), (int, float)) and self.total_picos() % 1000000 == 0:
211
+ return hash(datetime.timedelta(microseconds=self.total_picos() / 1000000))
212
+ return hash((Duration, self.total_picos()))
201
213
 
202
214
  def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
215
+ picos = self.total_picos()
203
216
  if (
204
- isinstance(self._picos, sympy.Mul)
205
- and len(self._picos.args) == 2
206
- and isinstance(self._picos.args[0], (sympy.Integer, sympy.Float))
217
+ isinstance(picos, sympy.Mul)
218
+ and len(picos.args) == 2
219
+ and isinstance(picos.args[0], (sympy.Integer, sympy.Float))
207
220
  ):
208
- scale = self._picos.args[0]
209
- rest = self._picos.args[1]
221
+ scale = picos.args[0]
222
+ rest = picos.args[1]
210
223
  else:
211
- scale = self._picos
224
+ scale = picos
212
225
  rest = 1
213
226
 
214
227
  if scale % 1000_000_000 == 0:
@@ -234,7 +247,7 @@ class Duration:
234
247
  return amount * rest, unit, suffix
235
248
 
236
249
  def __str__(self) -> str:
237
- if self._picos == 0:
250
+ if self.total_picos() == 0:
238
251
  return 'Duration(0)'
239
252
  amount, _, suffix = self._decompose_into_amount_unit_suffix()
240
253
  if not isinstance(amount, (int, float, sympy.Symbol)):
@@ -257,3 +270,21 @@ def _attempt_duration_like_to_duration(value: Any) -> Optional[Duration]:
257
270
  if isinstance(value, (int, float)) and value == 0:
258
271
  return Duration()
259
272
  return None
273
+
274
+
275
+ def _add_time_vals(
276
+ val1: List[_NUMERIC_INPUT_TYPE], val2: List[_NUMERIC_INPUT_TYPE]
277
+ ) -> List[_NUMERIC_INPUT_TYPE]:
278
+ ret: List[_NUMERIC_INPUT_TYPE] = []
279
+ for i in range(4):
280
+ if val1[i] and val2[i]:
281
+ ret.append(val1[i] + val2[i])
282
+ else:
283
+ ret.append(val1[i] or val2[i])
284
+ return ret
285
+
286
+
287
+ def _duration_from_time_vals(time_vals: List[_NUMERIC_INPUT_TYPE]):
288
+ ret = Duration()
289
+ ret._time_vals = time_vals
290
+ return ret
@@ -168,9 +168,11 @@ def test_sub():
168
168
  def test_mul():
169
169
  assert Duration(picos=2) * 3 == Duration(picos=6)
170
170
  assert 4 * Duration(picos=3) == Duration(picos=12)
171
+ assert 0 * Duration(picos=10) == Duration()
171
172
 
172
173
  t = sympy.Symbol('t')
173
174
  assert t * Duration(picos=3) == Duration(picos=3 * t)
175
+ assert 0 * Duration(picos=t) == Duration(picos=0)
174
176
 
175
177
  with pytest.raises(TypeError):
176
178
  _ = Duration() * Duration()
@@ -43,7 +43,7 @@ def test_eq_and_hash():
43
43
  self.some_str = some_str
44
44
 
45
45
  def __str__(self):
46
- return self.some_str # coverage: ignore
46
+ return self.some_str # pragma: no cover
47
47
 
48
48
  mkey = cirq.MeasurementKey('key')
49
49
  assert mkey == 'key'
@@ -61,8 +61,7 @@ class ProductState:
61
61
 
62
62
  def __init__(self, states=None):
63
63
  if states is None:
64
- # coverage: ignore
65
- states = dict()
64
+ states = dict() # pragma: no cover
66
65
 
67
66
  object.__setattr__(self, 'states', states)
68
67
 
@@ -200,8 +199,7 @@ class _XEigenState(_PauliEigenState):
200
199
  return np.array([1, 1]) / np.sqrt(2)
201
200
  elif self.eigenvalue == -1:
202
201
  return np.array([1, -1]) / np.sqrt(2)
203
- # coverage: ignore
204
- raise ValueError(f"Bad eigenvalue: {self.eigenvalue}")
202
+ raise ValueError(f"Bad eigenvalue: {self.eigenvalue}") # pragma: no cover
205
203
 
206
204
  def stabilized_by(self) -> Tuple[int, 'cirq.Pauli']:
207
205
  # Prevent circular import from `value.value_equality`
@@ -218,8 +216,7 @@ class _YEigenState(_PauliEigenState):
218
216
  return np.array([1, 1j]) / np.sqrt(2)
219
217
  elif self.eigenvalue == -1:
220
218
  return np.array([1, -1j]) / np.sqrt(2)
221
- # coverage: ignore
222
- raise ValueError(f"Bad eigenvalue: {self.eigenvalue}")
219
+ raise ValueError(f"Bad eigenvalue: {self.eigenvalue}") # pragma: no cover
223
220
 
224
221
  def stabilized_by(self) -> Tuple[int, 'cirq.Pauli']:
225
222
  from cirq import ops
@@ -235,8 +232,7 @@ class _ZEigenState(_PauliEigenState):
235
232
  return np.array([1, 0])
236
233
  elif self.eigenvalue == -1:
237
234
  return np.array([0, 1])
238
- # coverage: ignore
239
- raise ValueError(f"Bad eigenvalue: {self.eigenvalue}")
235
+ raise ValueError(f"Bad eigenvalue: {self.eigenvalue}") # pragma: no cover
240
236
 
241
237
  def stabilized_by(self) -> Tuple[int, 'cirq.Pauli']:
242
238
  from cirq import ops
@@ -17,7 +17,7 @@ from typing import Any, Callable, Optional, overload, Union
17
17
 
18
18
  from typing_extensions import Protocol
19
19
 
20
- from cirq import protocols
20
+ from cirq import protocols, _compat
21
21
 
22
22
 
23
23
  class _SupportsValueEquality(Protocol):
@@ -50,8 +50,7 @@ class _SupportsValueEquality(Protocol):
50
50
  Returns:
51
51
  Any type supported by `cirq.approx_eq()`.
52
52
  """
53
- # coverage: ignore
54
- return self._value_equality_values_()
53
+ return self._value_equality_values_() # pragma: no cover
55
54
 
56
55
  def _value_equality_values_cls_(self) -> Any:
57
56
  """Automatically implemented by the `cirq.value_equality` decorator.
@@ -222,13 +221,21 @@ def value_equality(
222
221
  )
223
222
  else:
224
223
  setattr(cls, '_value_equality_values_cls_', lambda self: cls)
225
- setattr(cls, '__hash__', None if unhashable else _value_equality_hash)
224
+ cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter)
225
+ setattr(cls, '_value_equality_values_', cached_values_getter)
226
+ setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash))
226
227
  setattr(cls, '__eq__', _value_equality_eq)
227
228
  setattr(cls, '__ne__', _value_equality_ne)
228
229
 
229
230
  if approximate:
230
231
  if not hasattr(cls, '_value_equality_approximate_values_'):
231
- setattr(cls, '_value_equality_approximate_values_', values_getter)
232
+ setattr(cls, '_value_equality_approximate_values_', cached_values_getter)
233
+ else:
234
+ approx_values_getter = getattr(cls, '_value_equality_approximate_values_')
235
+ cached_approx_values_getter = (
236
+ approx_values_getter if unhashable else _compat.cached_method(approx_values_getter)
237
+ )
238
+ setattr(cls, '_value_equality_approximate_values_', cached_approx_values_getter)
232
239
  setattr(cls, '_approx_eq_', _value_equality_approx_eq)
233
240
 
234
241
  return cls
cirq/vis/heatmap.py CHANGED
@@ -15,6 +15,7 @@ import copy
15
15
  from dataclasses import astuple, dataclass
16
16
  from typing import (
17
17
  Any,
18
+ cast,
18
19
  Dict,
19
20
  List,
20
21
  Mapping,
@@ -217,7 +218,7 @@ class Heatmap:
217
218
  )
218
219
  position = self._config['colorbar_position']
219
220
  orien = 'vertical' if position in ('left', 'right') else 'horizontal'
220
- colorbar = ax.figure.colorbar(
221
+ colorbar = cast(plt.Figure, ax.figure).colorbar(
221
222
  mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {})
222
223
  )
223
224
  colorbar_ax.tick_params(axis='y', direction='out')
@@ -230,15 +231,15 @@ class Heatmap:
230
231
  ax: plt.Axes,
231
232
  ) -> None:
232
233
  """Writes annotations to the center of cells. Internal."""
233
- for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()):
234
+ for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolor()):
234
235
  # Calculate the center of the cell, assuming that it is a square
235
236
  # centered at (x=col, y=row).
236
237
  if not annotation:
237
238
  continue
238
239
  x, y = center
239
- face_luminance = vis_utils.relative_luminance(facecolor)
240
+ face_luminance = vis_utils.relative_luminance(facecolor) # type: ignore
240
241
  text_color = 'black' if face_luminance > 0.4 else 'white'
241
- text_kwargs = dict(color=text_color, ha="center", va="center")
242
+ text_kwargs: Dict[str, Any] = dict(color=text_color, ha="center", va="center")
242
243
  text_kwargs.update(self._config.get('annotation_text_kwargs', {}))
243
244
  ax.text(x, y, annotation, **text_kwargs)
244
245
 
@@ -295,6 +296,7 @@ class Heatmap:
295
296
  show_plot = not ax
296
297
  if not ax:
297
298
  fig, ax = plt.subplots(figsize=(8, 8))
299
+ ax = cast(plt.Axes, ax)
298
300
  original_config = copy.deepcopy(self._config)
299
301
  self.update_config(**kwargs)
300
302
  collection = self._plot_on_axis(ax)
@@ -381,6 +383,7 @@ class TwoQubitInteractionHeatmap(Heatmap):
381
383
  show_plot = not ax
382
384
  if not ax:
383
385
  fig, ax = plt.subplots(figsize=(8, 8))
386
+ ax = cast(plt.Axes, ax)
384
387
  original_config = copy.deepcopy(self._config)
385
388
  self.update_config(**kwargs)
386
389
  qubits = set([q for qubits in self._value_map.keys() for q in qubits])
cirq/vis/heatmap_test.py CHANGED
@@ -14,6 +14,7 @@
14
14
  """Tests for Heatmap."""
15
15
 
16
16
  import pathlib
17
+ import shutil
17
18
  import string
18
19
  from tempfile import mkdtemp
19
20
 
@@ -33,6 +34,14 @@ def ax():
33
34
  return figure.add_subplot(111)
34
35
 
35
36
 
37
+ def test_default_ax():
38
+ row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
39
+ test_value_map = {
40
+ grid_qubit.GridQubit(row, col): np.random.random() for (row, col) in row_col_list
41
+ }
42
+ _, _ = heatmap.Heatmap(test_value_map).plot()
43
+
44
+
36
45
  @pytest.mark.parametrize('tuple_keys', [True, False])
37
46
  def test_cells_positions(ax, tuple_keys):
38
47
  row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
@@ -60,6 +69,8 @@ def test_two_qubit_heatmap(ax):
60
69
  title = "Two Qubit Interaction Heatmap"
61
70
  heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax)
62
71
  assert ax.get_title() == title
72
+ # Test default axis
73
+ heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot()
63
74
 
64
75
 
65
76
  def test_invalid_args():
@@ -104,10 +115,8 @@ def test_cell_colors(ax, colormap_name):
104
115
  col = int(round(np.mean([v[0] for v in vertices])))
105
116
  value = test_row_col_map[(row, col)]
106
117
  color_scale = (value - vmin) / (vmax - vmin)
107
- if color_scale < 0.0:
108
- color_scale = 0.0
109
- if color_scale > 1.0:
110
- color_scale = 1.0
118
+ color_scale = max(color_scale, 0.0)
119
+ color_scale = min(color_scale, 1.0)
111
120
  expected_color = np.array(colormap(color_scale))
112
121
  assert np.all(np.isclose(facecolor, expected_color))
113
122
 
@@ -309,6 +318,7 @@ def test_colorbar(ax, position, size, pad):
309
318
 
310
319
  plt.close(fig1)
311
320
  plt.close(fig2)
321
+ shutil.rmtree(tmp_dir)
312
322
 
313
323
 
314
324
  @pytest.mark.usefixtures('closefigures')