cirq-core 1.6.0.dev20250501173104__py3-none-any.whl → 1.6.0.dev20250501231232__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cirq-core might be problematic.

Files changed (59):
  1. cirq/_version.py +1 -1
  2. cirq/_version_test.py +1 -1
  3. cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +211 -107
  4. cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py +347 -3
  5. cirq/transformers/analytical_decompositions/two_qubit_to_cz.py +18 -18
  6. cirq/transformers/analytical_decompositions/two_qubit_to_fsim.py +18 -19
  7. cirq/transformers/analytical_decompositions/two_qubit_to_ms.py +8 -10
  8. cirq/transformers/analytical_decompositions/two_qubit_to_sqrt_iswap.py +26 -28
  9. cirq/transformers/drop_empty_moments.py +4 -2
  10. cirq/transformers/drop_negligible_operations.py +6 -4
  11. cirq/transformers/dynamical_decoupling.py +6 -4
  12. cirq/transformers/dynamical_decoupling_test.py +8 -6
  13. cirq/transformers/eject_phased_paulis.py +14 -12
  14. cirq/transformers/eject_z.py +8 -6
  15. cirq/transformers/expand_composite.py +5 -3
  16. cirq/transformers/gauge_compiling/sqrt_cz_gauge.py +3 -1
  17. cirq/transformers/heuristic_decompositions/two_qubit_gate_tabulation.py +4 -1
  18. cirq/transformers/insertion_sort.py +6 -4
  19. cirq/transformers/measurement_transformers.py +21 -21
  20. cirq/transformers/merge_k_qubit_gates.py +11 -9
  21. cirq/transformers/merge_k_qubit_gates_test.py +5 -3
  22. cirq/transformers/merge_single_qubit_gates.py +15 -13
  23. cirq/transformers/optimize_for_target_gateset.py +14 -12
  24. cirq/transformers/optimize_for_target_gateset_test.py +7 -3
  25. cirq/transformers/qubit_management_transformers.py +10 -8
  26. cirq/transformers/randomized_measurements.py +9 -7
  27. cirq/transformers/routing/initial_mapper.py +5 -3
  28. cirq/transformers/routing/line_initial_mapper.py +15 -13
  29. cirq/transformers/routing/mapping_manager.py +9 -9
  30. cirq/transformers/routing/route_circuit_cqc.py +17 -15
  31. cirq/transformers/routing/visualize_routed_circuit.py +7 -6
  32. cirq/transformers/stratify.py +13 -11
  33. cirq/transformers/synchronize_terminal_measurements.py +9 -9
  34. cirq/transformers/target_gatesets/compilation_target_gateset.py +19 -17
  35. cirq/transformers/target_gatesets/compilation_target_gateset_test.py +11 -7
  36. cirq/transformers/target_gatesets/cz_gateset.py +4 -2
  37. cirq/transformers/target_gatesets/sqrt_iswap_gateset.py +5 -3
  38. cirq/transformers/transformer_api.py +17 -15
  39. cirq/transformers/transformer_primitives.py +22 -20
  40. cirq/transformers/transformer_primitives_test.py +3 -1
  41. cirq/value/classical_data.py +26 -26
  42. cirq/value/condition.py +23 -21
  43. cirq/value/duration.py +11 -8
  44. cirq/value/linear_dict.py +22 -20
  45. cirq/value/periodic_value.py +4 -4
  46. cirq/value/probability.py +3 -1
  47. cirq/value/product_state.py +14 -12
  48. cirq/work/collector.py +7 -5
  49. cirq/work/observable_measurement.py +24 -22
  50. cirq/work/observable_measurement_data.py +9 -7
  51. cirq/work/observable_readout_calibration.py +4 -1
  52. cirq/work/observable_readout_calibration_test.py +4 -1
  53. cirq/work/observable_settings.py +4 -2
  54. cirq/work/pauli_sum_collector.py +8 -6
  55. {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/METADATA +1 -1
  56. {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/RECORD +59 -59
  57. {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/WHEEL +0 -0
  58. {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/licenses/LICENSE +0 -0
  59. {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,8 @@
14
14
 
15
15
  """Defines the API for circuit transformers in Cirq."""
16
16
 
17
+ from __future__ import annotations
18
+
17
19
  import dataclasses
18
20
  import enum
19
21
  import functools
@@ -83,8 +85,8 @@ class _LoggerNode:
83
85
 
84
86
  transformer_id: int
85
87
  transformer_name: str
86
- initial_circuit: 'cirq.AbstractCircuit'
87
- final_circuit: 'cirq.AbstractCircuit'
88
+ initial_circuit: cirq.AbstractCircuit
89
+ final_circuit: cirq.AbstractCircuit
88
90
  logs: List[Tuple[LogLevel, Tuple[str, ...]]] = dataclasses.field(default_factory=list)
89
91
  nested_loggers: List[int] = dataclasses.field(default_factory=list)
90
92
 
@@ -116,7 +118,7 @@ class TransformerLogger:
116
118
  self._logs: List[_LoggerNode] = []
117
119
  self._stack: List[int] = []
118
120
 
119
- def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
121
+ def register_initial(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
120
122
  """Register the beginning of a new transformer stage.
121
123
 
122
124
  Args:
@@ -143,7 +145,7 @@ class TransformerLogger:
143
145
  raise ValueError('No active transformer found.')
144
146
  self._logs[self._stack[-1]].logs.append((level, args))
145
147
 
146
- def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
148
+ def register_final(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
147
149
  """Register the end of the currently active transformer stage.
148
150
 
149
151
  Args:
@@ -195,13 +197,13 @@ class TransformerLogger:
195
197
  class NoOpTransformerLogger(TransformerLogger):
196
198
  """All calls to this logger are a no-op"""
197
199
 
198
- def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
200
+ def register_initial(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
199
201
  pass
200
202
 
201
203
  def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None:
202
204
  pass
203
205
 
204
- def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
206
+ def register_final(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
205
207
  pass
206
208
 
207
209
  def show(self, level: LogLevel = LogLevel.INFO) -> None:
@@ -262,8 +264,8 @@ class TRANSFORMER(Protocol):
262
264
  """
263
265
 
264
266
  def __call__(
265
- self, circuit: 'cirq.AbstractCircuit', *, context: Optional[TransformerContext] = None
266
- ) -> 'cirq.AbstractCircuit': ...
267
+ self, circuit: cirq.AbstractCircuit, *, context: Optional[TransformerContext] = None
268
+ ) -> cirq.AbstractCircuit: ...
267
269
 
268
270
 
269
271
  _TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
@@ -357,8 +359,8 @@ def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> A
357
359
 
358
360
  @functools.wraps(method)
359
361
  def method_with_logging(
360
- self, circuit: 'cirq.AbstractCircuit', **kwargs
361
- ) -> 'cirq.AbstractCircuit':
362
+ self, circuit: cirq.AbstractCircuit, **kwargs
363
+ ) -> cirq.AbstractCircuit:
362
364
  return _transform_and_log(
363
365
  add_deep_support,
364
366
  lambda circuit, **kwargs: method(self, circuit, **kwargs),
@@ -376,7 +378,7 @@ def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> A
376
378
  default_context = _get_default_context(func)
377
379
 
378
380
  @functools.wraps(func)
379
- def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
381
+ def func_with_logging(circuit: cirq.AbstractCircuit, **kwargs) -> cirq.AbstractCircuit:
380
382
  return _transform_and_log(
381
383
  add_deep_support,
382
384
  func,
@@ -401,10 +403,10 @@ def _get_default_context(func: TRANSFORMER) -> TransformerContext:
401
403
  def _run_transformer_on_circuit(
402
404
  add_deep_support: bool,
403
405
  func: TRANSFORMER,
404
- circuit: 'cirq.AbstractCircuit',
406
+ circuit: cirq.AbstractCircuit,
405
407
  extracted_context: Optional[TransformerContext],
406
408
  **kwargs,
407
- ) -> 'cirq.AbstractCircuit':
409
+ ) -> cirq.AbstractCircuit:
408
410
  mutable_circuit = None
409
411
  if extracted_context and extracted_context.deep and add_deep_support:
410
412
  batch_replace = []
@@ -429,10 +431,10 @@ def _transform_and_log(
429
431
  add_deep_support: bool,
430
432
  func: TRANSFORMER,
431
433
  transformer_name: str,
432
- circuit: 'cirq.AbstractCircuit',
434
+ circuit: cirq.AbstractCircuit,
433
435
  extracted_context: Optional[TransformerContext],
434
436
  **kwargs,
435
- ) -> 'cirq.AbstractCircuit':
437
+ ) -> cirq.AbstractCircuit:
436
438
  """Helper to log initial and final circuits before and after calling the transformer."""
437
439
  if extracted_context:
438
440
  extracted_context.logger.register_initial(circuit, transformer_name)
@@ -14,6 +14,8 @@
14
14
 
15
15
  """Defines primitives for common transformer patterns."""
16
16
 
17
+ from __future__ import annotations
18
+
17
19
  import bisect
18
20
  import dataclasses
19
21
  from collections import defaultdict
@@ -157,7 +159,7 @@ def _map_operations_impl(
157
159
  """
158
160
  tags_to_ignore_set = set(tags_to_ignore)
159
161
 
160
- def apply_map_func(op: 'cirq.Operation', idx: int) -> List['cirq.Operation']:
162
+ def apply_map_func(op: cirq.Operation, idx: int) -> List[cirq.Operation]:
161
163
  if tags_to_ignore_set.intersection(op.tags):
162
164
  return [op]
163
165
  if deep and isinstance(op.untagged, circuits.CircuitOperation):
@@ -173,7 +175,7 @@ def _map_operations_impl(
173
175
  ).with_tags(*op.tags)
174
176
  mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))]
175
177
  op_qubits = set(op.qubits)
176
- mapped_ops_qubits: Set['cirq.Qid'] = set()
178
+ mapped_ops_qubits: Set[cirq.Qid] = set()
177
179
  has_overlapping_ops = False
178
180
  for mapped_op in mapped_ops:
179
181
  if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits):
@@ -194,9 +196,9 @@ def _map_operations_impl(
194
196
  ]
195
197
  return mapped_ops
196
198
 
197
- new_moments: List[List['cirq.Operation']] = []
199
+ new_moments: List[List[cirq.Operation]] = []
198
200
  for idx, moment in enumerate(circuit):
199
- curr_moments: List[List['cirq.Operation']] = [[]] if wrap_in_circuit_op else []
201
+ curr_moments: List[List[cirq.Operation]] = [[]] if wrap_in_circuit_op else []
200
202
  placement_cache = circuits.circuit._PlacementCache()
201
203
  for op in moment:
202
204
  mapped_ops = apply_map_func(op, idx)
@@ -305,21 +307,21 @@ class _MergedCircuit:
305
307
  of a set to store operations to preserve insertion order.
306
308
  """
307
309
 
308
- qubit_indexes: Dict['cirq.Qid', List[int]] = dataclasses.field(
310
+ qubit_indexes: Dict[cirq.Qid, List[int]] = dataclasses.field(
309
311
  default_factory=lambda: defaultdict(lambda: [-1])
310
312
  )
311
- mkey_indexes: Dict['cirq.MeasurementKey', List[int]] = dataclasses.field(
313
+ mkey_indexes: Dict[cirq.MeasurementKey, List[int]] = dataclasses.field(
312
314
  default_factory=lambda: defaultdict(lambda: [-1])
313
315
  )
314
- ckey_indexes: Dict['cirq.MeasurementKey', List[int]] = dataclasses.field(
316
+ ckey_indexes: Dict[cirq.MeasurementKey, List[int]] = dataclasses.field(
315
317
  default_factory=lambda: defaultdict(lambda: [-1])
316
318
  )
317
- ops_by_index: List[Dict['cirq.Operation', int]] = dataclasses.field(default_factory=list)
319
+ ops_by_index: List[Dict[cirq.Operation, int]] = dataclasses.field(default_factory=list)
318
320
 
319
321
  def append_empty_moment(self) -> None:
320
322
  self.ops_by_index.append({})
321
323
 
322
- def add_op_to_moment(self, moment_index: int, op: 'cirq.Operation') -> None:
324
+ def add_op_to_moment(self, moment_index: int, op: cirq.Operation) -> None:
323
325
  self.ops_by_index[moment_index][op] = 0
324
326
  for q in op.qubits:
325
327
  if moment_index > self.qubit_indexes[q][-1]:
@@ -331,7 +333,7 @@ class _MergedCircuit:
331
333
  for ckey in protocols.control_keys(op):
332
334
  bisect.insort(self.ckey_indexes[ckey], moment_index)
333
335
 
334
- def remove_op_from_moment(self, moment_index: int, op: 'cirq.Operation') -> None:
336
+ def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None:
335
337
  self.ops_by_index[moment_index].pop(op)
336
338
  for q in op.qubits:
337
339
  if self.qubit_indexes[q][-1] == moment_index:
@@ -344,8 +346,8 @@ class _MergedCircuit:
344
346
  self.ckey_indexes[ckey].remove(moment_index)
345
347
 
346
348
  def get_mergeable_ops(
347
- self, op: 'cirq.Operation', op_qs: Set['cirq.Qid']
348
- ) -> Tuple[int, List['cirq.Operation']]:
349
+ self, op: cirq.Operation, op_qs: Set[cirq.Qid]
350
+ ) -> Tuple[int, List[cirq.Operation]]:
349
351
  # Find the index of previous moment which can be merged with `op`.
350
352
  idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1)
351
353
  idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)])
@@ -360,7 +362,7 @@ class _MergedCircuit:
360
362
  left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits)
361
363
  ]
362
364
 
363
- def get_cirq_circuit(self) -> 'cirq.Circuit':
365
+ def get_cirq_circuit(self) -> cirq.Circuit:
364
366
  return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index)
365
367
 
366
368
 
@@ -493,7 +495,7 @@ def merge_operations(
493
495
 
494
496
  def merge_operations_to_circuit_op(
495
497
  circuit: CIRCUIT_TYPE,
496
- can_merge: Callable[[Sequence['cirq.Operation'], Sequence['cirq.Operation']], bool],
498
+ can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool],
497
499
  *,
498
500
  tags_to_ignore: Sequence[Hashable] = (),
499
501
  merged_circuit_op_tag: str = "Merged connected component",
@@ -524,8 +526,8 @@ def merge_operations_to_circuit_op(
524
526
  Copy of input circuit with valid connected components wrapped in tagged circuit operations.
525
527
  """
526
528
 
527
- def merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation') -> Optional['cirq.Operation']:
528
- def get_ops(op: 'cirq.Operation'):
529
+ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> Optional[cirq.Operation]:
530
+ def get_ops(op: cirq.Operation):
529
531
  op_untagged = op.untagged
530
532
  return (
531
533
  [*op_untagged.circuit.all_operations()]
@@ -573,7 +575,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
573
575
  Copy of input circuit with valid connected components wrapped in tagged circuit operations.
574
576
  """
575
577
 
576
- def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']) -> bool:
578
+ def can_merge(ops1: Sequence[cirq.Operation], ops2: Sequence[cirq.Operation]) -> bool:
577
579
  return all(
578
580
  protocols.num_qubits(op) <= k and protocols.has_unitary(op)
579
581
  for op_list in [ops1, ops2]
@@ -659,7 +661,7 @@ def unroll_circuit_op(
659
661
  """
660
662
 
661
663
  def map_func(m: circuits.Moment, _: int):
662
- to_zip: List['cirq.AbstractCircuit'] = []
664
+ to_zip: List[cirq.AbstractCircuit] = []
663
665
  for op in m:
664
666
  op_untagged = op.untagged
665
667
  if isinstance(op_untagged, circuits.CircuitOperation):
@@ -751,7 +753,7 @@ def unroll_circuit_op_greedy_frontier(
751
753
  Copy of input circuit with (Tagged) CircuitOperation's expanded inline at qubit frontier.
752
754
  """
753
755
  unrolled_circuit = circuit.unfreeze(copy=True)
754
- frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0)
756
+ frontier: Dict[cirq.Qid, int] = defaultdict(lambda: 0)
755
757
  idx = 0
756
758
  while idx < len(unrolled_circuit):
757
759
  for op in unrolled_circuit[idx].operations:
@@ -799,7 +801,7 @@ def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool =
799
801
  """
800
802
  tags_to_xor = set(tags)
801
803
 
802
- def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation':
804
+ def map_func(op: cirq.Operation, _) -> cirq.Operation:
803
805
  return (
804
806
  op
805
807
  if deep and isinstance(op, circuits.CircuitOperation)
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
15
17
  from typing import Iterator, List, Optional
16
18
 
17
19
  import pytest
@@ -721,7 +723,7 @@ def test_merge_operations_to_circuit_op_merges_connected_component():
721
723
  ''',
722
724
  )
723
725
 
724
- def can_merge(ops1: List['cirq.Operation'], ops2: List['cirq.Operation']) -> bool:
726
+ def can_merge(ops1: List[cirq.Operation], ops2: List[cirq.Operation]) -> bool:
725
727
  """Artificial example where a CZ will absorb any merge-able operation."""
726
728
  return any(o.gate == cirq.CZ for op_list in [ops1, ops2] for o in op_list)
727
729
 
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
15
17
  import abc
16
18
  import enum
17
19
  from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
@@ -45,21 +47,21 @@ class MeasurementType(enum.IntEnum):
45
47
 
46
48
  class ClassicalDataStoreReader(abc.ABC):
47
49
  @abc.abstractmethod
48
- def keys(self) -> Tuple['cirq.MeasurementKey', ...]:
50
+ def keys(self) -> Tuple[cirq.MeasurementKey, ...]:
49
51
  """Gets the measurement keys in the order they were stored."""
50
52
 
51
53
  @property
52
54
  @abc.abstractmethod
53
- def records(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]:
55
+ def records(self) -> Mapping[cirq.MeasurementKey, List[Tuple[int, ...]]]:
54
56
  """Gets the a mapping from measurement key to measurement records."""
55
57
 
56
58
  @property
57
59
  @abc.abstractmethod
58
- def channel_records(self) -> Mapping['cirq.MeasurementKey', List[int]]:
60
+ def channel_records(self) -> Mapping[cirq.MeasurementKey, List[int]]:
59
61
  """Gets the a mapping from measurement key to channel measurement records."""
60
62
 
61
63
  @abc.abstractmethod
62
- def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int:
64
+ def get_int(self, key: cirq.MeasurementKey, index=-1) -> int:
63
65
  """Gets the integer corresponding to the measurement.
64
66
 
65
67
  The integer is determined by summing the qubit-dimensional basis value
@@ -81,7 +83,7 @@ class ClassicalDataStoreReader(abc.ABC):
81
83
  """
82
84
 
83
85
  @abc.abstractmethod
84
- def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]:
86
+ def get_digits(self, key: cirq.MeasurementKey, index=-1) -> Tuple[int, ...]:
85
87
  """Gets the values of the qubits that were measured into this key.
86
88
 
87
89
  For example, if the measurement of qubits [q0, q1] produces [0, 1],
@@ -107,7 +109,7 @@ class ClassicalDataStoreReader(abc.ABC):
107
109
  class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC):
108
110
  @abc.abstractmethod
109
111
  def record_measurement(
110
- self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid']
112
+ self, key: cirq.MeasurementKey, measurement: Sequence[int], qubits: Sequence[cirq.Qid]
111
113
  ):
112
114
  """Records a measurement.
113
115
 
@@ -122,7 +124,7 @@ class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC):
122
124
  """
123
125
 
124
126
  @abc.abstractmethod
125
- def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int):
127
+ def record_channel_measurement(self, key: cirq.MeasurementKey, measurement: int):
126
128
  """Records a channel measurement.
127
129
 
128
130
  Args:
@@ -141,12 +143,10 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
141
143
  def __init__(
142
144
  self,
143
145
  *,
144
- _records: Optional[Dict['cirq.MeasurementKey', List[Tuple[int, ...]]]] = None,
145
- _measured_qubits: Optional[
146
- Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]]
147
- ] = None,
148
- _channel_records: Optional[Dict['cirq.MeasurementKey', List[int]]] = None,
149
- _measurement_types: Optional[Dict['cirq.MeasurementKey', 'cirq.MeasurementType']] = None,
146
+ _records: Optional[Dict[cirq.MeasurementKey, List[Tuple[int, ...]]]] = None,
147
+ _measured_qubits: Optional[Dict[cirq.MeasurementKey, List[Tuple[cirq.Qid, ...]]]] = None,
148
+ _channel_records: Optional[Dict[cirq.MeasurementKey, List[int]]] = None,
149
+ _measurement_types: Optional[Dict[cirq.MeasurementKey, cirq.MeasurementType]] = None,
150
150
  ):
151
151
  """Initializes a `ClassicalDataDictionaryStore` object."""
152
152
  if not _measurement_types:
@@ -165,40 +165,40 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
165
165
  _measured_qubits = {}
166
166
  if _channel_records is None:
167
167
  _channel_records = {}
168
- self._records: Dict['cirq.MeasurementKey', List[Tuple[int, ...]]] = _records
169
- self._measured_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = (
168
+ self._records: Dict[cirq.MeasurementKey, List[Tuple[int, ...]]] = _records
169
+ self._measured_qubits: Dict[cirq.MeasurementKey, List[Tuple[cirq.Qid, ...]]] = (
170
170
  _measured_qubits
171
171
  )
172
- self._channel_records: Dict['cirq.MeasurementKey', List[int]] = _channel_records
173
- self._measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = (
172
+ self._channel_records: Dict[cirq.MeasurementKey, List[int]] = _channel_records
173
+ self._measurement_types: Dict[cirq.MeasurementKey, cirq.MeasurementType] = (
174
174
  _measurement_types
175
175
  )
176
176
 
177
177
  @property
178
- def records(self) -> Mapping['cirq.MeasurementKey', List[Tuple[int, ...]]]:
178
+ def records(self) -> Mapping[cirq.MeasurementKey, List[Tuple[int, ...]]]:
179
179
  """Gets the a mapping from measurement key to measurement records."""
180
180
  return self._records
181
181
 
182
182
  @property
183
- def channel_records(self) -> Mapping['cirq.MeasurementKey', List[int]]:
183
+ def channel_records(self) -> Mapping[cirq.MeasurementKey, List[int]]:
184
184
  """Gets the a mapping from measurement key to channel measurement records."""
185
185
  return self._channel_records
186
186
 
187
187
  @property
188
- def measured_qubits(self) -> Mapping['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]]:
188
+ def measured_qubits(self) -> Mapping[cirq.MeasurementKey, List[Tuple[cirq.Qid, ...]]]:
189
189
  """Gets the a mapping from measurement key to the qubits measured."""
190
190
  return self._measured_qubits
191
191
 
192
192
  @property
193
- def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']:
193
+ def measurement_types(self) -> Mapping[cirq.MeasurementKey, cirq.MeasurementType]:
194
194
  """Gets the a mapping from measurement key to the measurement type."""
195
195
  return self._measurement_types
196
196
 
197
- def keys(self) -> Tuple['cirq.MeasurementKey', ...]:
197
+ def keys(self) -> Tuple[cirq.MeasurementKey, ...]:
198
198
  return tuple(self._measurement_types.keys())
199
199
 
200
200
  def record_measurement(
201
- self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid']
201
+ self, key: cirq.MeasurementKey, measurement: Sequence[int], qubits: Sequence[cirq.Qid]
202
202
  ):
203
203
  if len(measurement) != len(qubits):
204
204
  raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.')
@@ -217,7 +217,7 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
217
217
  measured_qubits.append(tuple(qubits))
218
218
  self._records[key].append(tuple(measurement))
219
219
 
220
- def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int):
220
+ def record_channel_measurement(self, key: cirq.MeasurementKey, measurement: int):
221
221
  if key not in self._measurement_types:
222
222
  self._measurement_types[key] = MeasurementType.CHANNEL
223
223
  self._channel_records[key] = []
@@ -225,14 +225,14 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
225
225
  raise ValueError(f"Measurement already logged to key {key}")
226
226
  self._channel_records[key].append(measurement)
227
227
 
228
- def get_digits(self, key: 'cirq.MeasurementKey', index=-1) -> Tuple[int, ...]:
228
+ def get_digits(self, key: cirq.MeasurementKey, index=-1) -> Tuple[int, ...]:
229
229
  return (
230
230
  self._records[key][index]
231
231
  if self._measurement_types[key] == MeasurementType.MEASUREMENT
232
232
  else (self._channel_records[key][index],)
233
233
  )
234
234
 
235
- def get_int(self, key: 'cirq.MeasurementKey', index=-1) -> int:
235
+ def get_int(self, key: cirq.MeasurementKey, index=-1) -> int:
236
236
  if key not in self._measurement_types:
237
237
  raise KeyError(f'The measurement key {key} is not in {self._measurement_types}')
238
238
  measurement_type = self._measurement_types[key]
cirq/value/condition.py CHANGED
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
15
17
  import abc
16
18
  import dataclasses
17
19
  from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING
@@ -32,15 +34,15 @@ class Condition(abc.ABC):
32
34
 
33
35
  @property
34
36
  @abc.abstractmethod
35
- def keys(self) -> Tuple['cirq.MeasurementKey', ...]:
37
+ def keys(self) -> Tuple[cirq.MeasurementKey, ...]:
36
38
  """Gets the control keys."""
37
39
 
38
40
  @abc.abstractmethod
39
- def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
41
+ def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
40
42
  """Replaces the control keys."""
41
43
 
42
44
  @abc.abstractmethod
43
- def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
45
+ def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
44
46
  """Resolves the condition based on the measurements."""
45
47
 
46
48
  @property
@@ -48,24 +50,24 @@ class Condition(abc.ABC):
48
50
  def qasm(self):
49
51
  """Returns the qasm of this condition."""
50
52
 
51
- def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
53
+ def _qasm_(self, args: cirq.QasmArgs, **kwargs) -> Optional[str]:
52
54
  return self.qasm
53
55
 
54
- def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) -> 'cirq.Condition':
56
+ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) -> cirq.Condition:
55
57
  condition = self
56
58
  for k in self.keys:
57
59
  condition = condition.replace_key(k, mkp.with_measurement_key_mapping(k, key_map))
58
60
  return condition
59
61
 
60
- def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'cirq.Condition':
62
+ def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> cirq.Condition:
61
63
  condition = self
62
64
  for k in self.keys:
63
65
  condition = condition.replace_key(k, mkp.with_key_path_prefix(k, path))
64
66
  return condition
65
67
 
66
68
  def _with_rescoped_keys_(
67
- self, path: Tuple[str, ...], bindable_keys: FrozenSet['cirq.MeasurementKey']
68
- ) -> 'cirq.Condition':
69
+ self, path: Tuple[str, ...], bindable_keys: FrozenSet[cirq.MeasurementKey]
70
+ ) -> cirq.Condition:
69
71
  condition = self
70
72
  for key in self.keys:
71
73
  for i in range(len(path) + 1):
@@ -85,14 +87,14 @@ class KeyCondition(Condition):
85
87
  time of resolution.
86
88
  """
87
89
 
88
- key: 'cirq.MeasurementKey'
90
+ key: cirq.MeasurementKey
89
91
  index: int = -1
90
92
 
91
93
  @property
92
94
  def keys(self):
93
95
  return (self.key,)
94
96
 
95
- def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
97
+ def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
96
98
  return KeyCondition(replacement) if self.key == current else self
97
99
 
98
100
  def __str__(self):
@@ -103,7 +105,7 @@ class KeyCondition(Condition):
103
105
  return f'cirq.KeyCondition({self.key!r}, {self.index})'
104
106
  return f'cirq.KeyCondition({self.key!r})'
105
107
 
106
- def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
108
+ def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
107
109
  if self.key not in classical_data.keys():
108
110
  raise ValueError(f'Measurement key {self.key} missing when testing classical control')
109
111
  return classical_data.get_int(self.key, self.index) != 0
@@ -119,7 +121,7 @@ class KeyCondition(Condition):
119
121
  def qasm(self):
120
122
  raise ValueError('QASM is defined only for SympyConditions of type key == constant.')
121
123
 
122
- def _qasm_(self, args: 'cirq.QasmArgs', **kwargs) -> Optional[str]:
124
+ def _qasm_(self, args: cirq.QasmArgs, **kwargs) -> Optional[str]:
123
125
  args.validate_version('2.0', '3.0')
124
126
  key_str = str(self.key)
125
127
  if key_str not in args.meas_key_id_map:
@@ -162,7 +164,7 @@ class BitMaskKeyCondition(Condition):
162
164
  - bitmask: Optional bitmask to apply before doing the comparison.
163
165
  """
164
166
 
165
- key: 'cirq.MeasurementKey' = attrs.field(
167
+ key: cirq.MeasurementKey = attrs.field(
166
168
  converter=lambda x: (
167
169
  x
168
170
  if isinstance(x, measurement_key.MeasurementKey)
@@ -180,8 +182,8 @@ class BitMaskKeyCondition(Condition):
180
182
 
181
183
  @staticmethod
182
184
  def create_equal_mask(
183
- key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
184
- ) -> 'BitMaskKeyCondition':
185
+ key: cirq.MeasurementKey, bitmask: int, *, index: int = -1
186
+ ) -> BitMaskKeyCondition:
185
187
  """Creates a condition that evaluates (meas & bitmask) == bitmask."""
186
188
  return BitMaskKeyCondition(
187
189
  key, index, target_value=bitmask, equal_target=True, bitmask=bitmask
@@ -189,14 +191,14 @@ class BitMaskKeyCondition(Condition):
189
191
 
190
192
  @staticmethod
191
193
  def create_not_equal_mask(
192
- key: 'cirq.MeasurementKey', bitmask: int, *, index: int = -1
193
- ) -> 'BitMaskKeyCondition':
194
+ key: cirq.MeasurementKey, bitmask: int, *, index: int = -1
195
+ ) -> BitMaskKeyCondition:
194
196
  """Creates a condition that evaluates (meas & bitmask) != bitmask."""
195
197
  return BitMaskKeyCondition(
196
198
  key, index, target_value=bitmask, equal_target=False, bitmask=bitmask
197
199
  )
198
200
 
199
- def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
201
+ def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
200
202
  return BitMaskKeyCondition(replacement) if self.key == current else self
201
203
 
202
204
  def __str__(self):
@@ -218,7 +220,7 @@ class BitMaskKeyCondition(Condition):
218
220
  parameters = ', '.join(f'{f.name}={repr(values[f.name])}' for f in attrs.fields(type(self)))
219
221
  return f'cirq.BitMaskKeyCondition({parameters})'
220
222
 
221
- def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
223
+ def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
222
224
  if self.key not in classical_data.keys():
223
225
  raise ValueError(f'Measurement key {self.key} missing when testing classical control')
224
226
  value = classical_data.get_int(self.key, self.index)
@@ -269,7 +271,7 @@ class SympyCondition(Condition):
269
271
  # keep the former here.
270
272
  )
271
273
 
272
- def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'):
274
+ def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
273
275
  return SympyCondition(self.expr.subs({str(current): sympy.Symbol(str(replacement))}))
274
276
 
275
277
  def __str__(self):
@@ -278,7 +280,7 @@ class SympyCondition(Condition):
278
280
  def __repr__(self):
279
281
  return f'cirq.SympyCondition({proper_repr(self.expr)})'
280
282
 
281
- def resolve(self, classical_data: 'cirq.ClassicalDataStoreReader') -> bool:
283
+ def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
282
284
  missing = [str(k) for k in self.keys if k not in classical_data.keys()]
283
285
  if missing:
284
286
  raise ValueError(f'Measurement keys {missing} missing when testing classical control')
cirq/value/duration.py CHANGED
@@ -11,8 +11,11 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
14
15
  """A typed time delta that supports picosecond accuracy."""
15
16
 
17
+ from __future__ import annotations
18
+
16
19
  import datetime
17
20
  from typing import AbstractSet, Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
18
21
 
@@ -98,7 +101,7 @@ class Duration:
98
101
  def _parameter_names_(self) -> AbstractSet[str]:
99
102
  return protocols.parameter_names(self._time_vals)
100
103
 
101
- def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Duration':
104
+ def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Duration:
102
105
  return _duration_from_time_vals(
103
106
  protocols.resolve_parameters(self._time_vals, resolver, recursive)
104
107
  )
@@ -121,16 +124,16 @@ class Duration:
121
124
  """Returns the number of milliseconds that the duration spans."""
122
125
  return self.total_picos() / 1000_000_000
123
126
 
124
- def __add__(self, other) -> 'Duration':
127
+ def __add__(self, other) -> Duration:
125
128
  other = _attempt_duration_like_to_duration(other)
126
129
  if other is None:
127
130
  return NotImplemented
128
131
  return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals))
129
132
 
130
- def __radd__(self, other) -> 'Duration':
133
+ def __radd__(self, other) -> Duration:
131
134
  return self.__add__(other)
132
135
 
133
- def __sub__(self, other) -> 'Duration':
136
+ def __sub__(self, other) -> Duration:
134
137
  other = _attempt_duration_like_to_duration(other)
135
138
  if other is None:
136
139
  return NotImplemented
@@ -138,7 +141,7 @@ class Duration:
138
141
  _add_time_vals(self._time_vals, [-x for x in other._time_vals])
139
142
  )
140
143
 
141
- def __rsub__(self, other) -> 'Duration':
144
+ def __rsub__(self, other) -> Duration:
142
145
  other = _attempt_duration_like_to_duration(other)
143
146
  if other is None:
144
147
  return NotImplemented
@@ -146,17 +149,17 @@ class Duration:
146
149
  _add_time_vals(other._time_vals, [-x for x in self._time_vals])
147
150
  )
148
151
 
149
- def __mul__(self, other) -> 'Duration':
152
+ def __mul__(self, other) -> Duration:
150
153
  if not isinstance(other, (int, float, sympy.Expr)):
151
154
  return NotImplemented
152
155
  if other == 0:
153
156
  return _duration_from_time_vals([0] * 4)
154
157
  return _duration_from_time_vals([x * other for x in self._time_vals])
155
158
 
156
- def __rmul__(self, other) -> 'Duration':
159
+ def __rmul__(self, other) -> Duration:
157
160
  return self.__mul__(other)
158
161
 
159
- def __truediv__(self, other) -> Union['Duration', float]:
162
+ def __truediv__(self, other) -> Union[Duration, float]:
160
163
  if isinstance(other, (int, float, sympy.Expr)):
161
164
  new_time_vals = [x / other for x in self._time_vals]
162
165
  return _duration_from_time_vals(new_time_vals)