cirq-core 1.6.0.dev20250501173104__py3-none-any.whl → 1.6.0.dev20250501231232__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cirq-core might be problematic. Click here for more details.
- cirq/_version.py +1 -1
- cirq/_version_test.py +1 -1
- cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +211 -107
- cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py +347 -3
- cirq/transformers/analytical_decompositions/two_qubit_to_cz.py +18 -18
- cirq/transformers/analytical_decompositions/two_qubit_to_fsim.py +18 -19
- cirq/transformers/analytical_decompositions/two_qubit_to_ms.py +8 -10
- cirq/transformers/analytical_decompositions/two_qubit_to_sqrt_iswap.py +26 -28
- cirq/transformers/drop_empty_moments.py +4 -2
- cirq/transformers/drop_negligible_operations.py +6 -4
- cirq/transformers/dynamical_decoupling.py +6 -4
- cirq/transformers/dynamical_decoupling_test.py +8 -6
- cirq/transformers/eject_phased_paulis.py +14 -12
- cirq/transformers/eject_z.py +8 -6
- cirq/transformers/expand_composite.py +5 -3
- cirq/transformers/gauge_compiling/sqrt_cz_gauge.py +3 -1
- cirq/transformers/heuristic_decompositions/two_qubit_gate_tabulation.py +4 -1
- cirq/transformers/insertion_sort.py +6 -4
- cirq/transformers/measurement_transformers.py +21 -21
- cirq/transformers/merge_k_qubit_gates.py +11 -9
- cirq/transformers/merge_k_qubit_gates_test.py +5 -3
- cirq/transformers/merge_single_qubit_gates.py +15 -13
- cirq/transformers/optimize_for_target_gateset.py +14 -12
- cirq/transformers/optimize_for_target_gateset_test.py +7 -3
- cirq/transformers/qubit_management_transformers.py +10 -8
- cirq/transformers/randomized_measurements.py +9 -7
- cirq/transformers/routing/initial_mapper.py +5 -3
- cirq/transformers/routing/line_initial_mapper.py +15 -13
- cirq/transformers/routing/mapping_manager.py +9 -9
- cirq/transformers/routing/route_circuit_cqc.py +17 -15
- cirq/transformers/routing/visualize_routed_circuit.py +7 -6
- cirq/transformers/stratify.py +13 -11
- cirq/transformers/synchronize_terminal_measurements.py +9 -9
- cirq/transformers/target_gatesets/compilation_target_gateset.py +19 -17
- cirq/transformers/target_gatesets/compilation_target_gateset_test.py +11 -7
- cirq/transformers/target_gatesets/cz_gateset.py +4 -2
- cirq/transformers/target_gatesets/sqrt_iswap_gateset.py +5 -3
- cirq/transformers/transformer_api.py +17 -15
- cirq/transformers/transformer_primitives.py +22 -20
- cirq/transformers/transformer_primitives_test.py +3 -1
- cirq/value/classical_data.py +26 -26
- cirq/value/condition.py +23 -21
- cirq/value/duration.py +11 -8
- cirq/value/linear_dict.py +22 -20
- cirq/value/periodic_value.py +4 -4
- cirq/value/probability.py +3 -1
- cirq/value/product_state.py +14 -12
- cirq/work/collector.py +7 -5
- cirq/work/observable_measurement.py +24 -22
- cirq/work/observable_measurement_data.py +9 -7
- cirq/work/observable_readout_calibration.py +4 -1
- cirq/work/observable_readout_calibration_test.py +4 -1
- cirq/work/observable_settings.py +4 -2
- cirq/work/pauli_sum_collector.py +8 -6
- {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/METADATA +1 -1
- {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/RECORD +59 -59
- {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/WHEEL +0 -0
- {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/licenses/LICENSE +0 -0
- {cirq_core-1.6.0.dev20250501173104.dist-info → cirq_core-1.6.0.dev20250501231232.dist-info}/top_level.txt +0 -0
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
|
|
15
15
|
"""Defines the API for circuit transformers in Cirq."""
|
|
16
16
|
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
17
19
|
import dataclasses
|
|
18
20
|
import enum
|
|
19
21
|
import functools
|
|
@@ -83,8 +85,8 @@ class _LoggerNode:
|
|
|
83
85
|
|
|
84
86
|
transformer_id: int
|
|
85
87
|
transformer_name: str
|
|
86
|
-
initial_circuit:
|
|
87
|
-
final_circuit:
|
|
88
|
+
initial_circuit: cirq.AbstractCircuit
|
|
89
|
+
final_circuit: cirq.AbstractCircuit
|
|
88
90
|
logs: List[Tuple[LogLevel, Tuple[str, ...]]] = dataclasses.field(default_factory=list)
|
|
89
91
|
nested_loggers: List[int] = dataclasses.field(default_factory=list)
|
|
90
92
|
|
|
@@ -116,7 +118,7 @@ class TransformerLogger:
|
|
|
116
118
|
self._logs: List[_LoggerNode] = []
|
|
117
119
|
self._stack: List[int] = []
|
|
118
120
|
|
|
119
|
-
def register_initial(self, circuit:
|
|
121
|
+
def register_initial(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
|
|
120
122
|
"""Register the beginning of a new transformer stage.
|
|
121
123
|
|
|
122
124
|
Args:
|
|
@@ -143,7 +145,7 @@ class TransformerLogger:
|
|
|
143
145
|
raise ValueError('No active transformer found.')
|
|
144
146
|
self._logs[self._stack[-1]].logs.append((level, args))
|
|
145
147
|
|
|
146
|
-
def register_final(self, circuit:
|
|
148
|
+
def register_final(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
|
|
147
149
|
"""Register the end of the currently active transformer stage.
|
|
148
150
|
|
|
149
151
|
Args:
|
|
@@ -195,13 +197,13 @@ class TransformerLogger:
|
|
|
195
197
|
class NoOpTransformerLogger(TransformerLogger):
|
|
196
198
|
"""All calls to this logger are a no-op"""
|
|
197
199
|
|
|
198
|
-
def register_initial(self, circuit:
|
|
200
|
+
def register_initial(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
|
|
199
201
|
pass
|
|
200
202
|
|
|
201
203
|
def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None:
|
|
202
204
|
pass
|
|
203
205
|
|
|
204
|
-
def register_final(self, circuit:
|
|
206
|
+
def register_final(self, circuit: cirq.AbstractCircuit, transformer_name: str) -> None:
|
|
205
207
|
pass
|
|
206
208
|
|
|
207
209
|
def show(self, level: LogLevel = LogLevel.INFO) -> None:
|
|
@@ -262,8 +264,8 @@ class TRANSFORMER(Protocol):
|
|
|
262
264
|
"""
|
|
263
265
|
|
|
264
266
|
def __call__(
|
|
265
|
-
self, circuit:
|
|
266
|
-
) ->
|
|
267
|
+
self, circuit: cirq.AbstractCircuit, *, context: Optional[TransformerContext] = None
|
|
268
|
+
) -> cirq.AbstractCircuit: ...
|
|
267
269
|
|
|
268
270
|
|
|
269
271
|
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
|
|
@@ -357,8 +359,8 @@ def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> A
|
|
|
357
359
|
|
|
358
360
|
@functools.wraps(method)
|
|
359
361
|
def method_with_logging(
|
|
360
|
-
self, circuit:
|
|
361
|
-
) ->
|
|
362
|
+
self, circuit: cirq.AbstractCircuit, **kwargs
|
|
363
|
+
) -> cirq.AbstractCircuit:
|
|
362
364
|
return _transform_and_log(
|
|
363
365
|
add_deep_support,
|
|
364
366
|
lambda circuit, **kwargs: method(self, circuit, **kwargs),
|
|
@@ -376,7 +378,7 @@ def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> A
|
|
|
376
378
|
default_context = _get_default_context(func)
|
|
377
379
|
|
|
378
380
|
@functools.wraps(func)
|
|
379
|
-
def func_with_logging(circuit:
|
|
381
|
+
def func_with_logging(circuit: cirq.AbstractCircuit, **kwargs) -> cirq.AbstractCircuit:
|
|
380
382
|
return _transform_and_log(
|
|
381
383
|
add_deep_support,
|
|
382
384
|
func,
|
|
@@ -401,10 +403,10 @@ def _get_default_context(func: TRANSFORMER) -> TransformerContext:
|
|
|
401
403
|
def _run_transformer_on_circuit(
|
|
402
404
|
add_deep_support: bool,
|
|
403
405
|
func: TRANSFORMER,
|
|
404
|
-
circuit:
|
|
406
|
+
circuit: cirq.AbstractCircuit,
|
|
405
407
|
extracted_context: Optional[TransformerContext],
|
|
406
408
|
**kwargs,
|
|
407
|
-
) ->
|
|
409
|
+
) -> cirq.AbstractCircuit:
|
|
408
410
|
mutable_circuit = None
|
|
409
411
|
if extracted_context and extracted_context.deep and add_deep_support:
|
|
410
412
|
batch_replace = []
|
|
@@ -429,10 +431,10 @@ def _transform_and_log(
|
|
|
429
431
|
add_deep_support: bool,
|
|
430
432
|
func: TRANSFORMER,
|
|
431
433
|
transformer_name: str,
|
|
432
|
-
circuit:
|
|
434
|
+
circuit: cirq.AbstractCircuit,
|
|
433
435
|
extracted_context: Optional[TransformerContext],
|
|
434
436
|
**kwargs,
|
|
435
|
-
) ->
|
|
437
|
+
) -> cirq.AbstractCircuit:
|
|
436
438
|
"""Helper to log initial and final circuits before and after calling the transformer."""
|
|
437
439
|
if extracted_context:
|
|
438
440
|
extracted_context.logger.register_initial(circuit, transformer_name)
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
|
|
15
15
|
"""Defines primitives for common transformer patterns."""
|
|
16
16
|
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
17
19
|
import bisect
|
|
18
20
|
import dataclasses
|
|
19
21
|
from collections import defaultdict
|
|
@@ -157,7 +159,7 @@ def _map_operations_impl(
|
|
|
157
159
|
"""
|
|
158
160
|
tags_to_ignore_set = set(tags_to_ignore)
|
|
159
161
|
|
|
160
|
-
def apply_map_func(op:
|
|
162
|
+
def apply_map_func(op: cirq.Operation, idx: int) -> List[cirq.Operation]:
|
|
161
163
|
if tags_to_ignore_set.intersection(op.tags):
|
|
162
164
|
return [op]
|
|
163
165
|
if deep and isinstance(op.untagged, circuits.CircuitOperation):
|
|
@@ -173,7 +175,7 @@ def _map_operations_impl(
|
|
|
173
175
|
).with_tags(*op.tags)
|
|
174
176
|
mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))]
|
|
175
177
|
op_qubits = set(op.qubits)
|
|
176
|
-
mapped_ops_qubits: Set[
|
|
178
|
+
mapped_ops_qubits: Set[cirq.Qid] = set()
|
|
177
179
|
has_overlapping_ops = False
|
|
178
180
|
for mapped_op in mapped_ops:
|
|
179
181
|
if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits):
|
|
@@ -194,9 +196,9 @@ def _map_operations_impl(
|
|
|
194
196
|
]
|
|
195
197
|
return mapped_ops
|
|
196
198
|
|
|
197
|
-
new_moments: List[List[
|
|
199
|
+
new_moments: List[List[cirq.Operation]] = []
|
|
198
200
|
for idx, moment in enumerate(circuit):
|
|
199
|
-
curr_moments: List[List[
|
|
201
|
+
curr_moments: List[List[cirq.Operation]] = [[]] if wrap_in_circuit_op else []
|
|
200
202
|
placement_cache = circuits.circuit._PlacementCache()
|
|
201
203
|
for op in moment:
|
|
202
204
|
mapped_ops = apply_map_func(op, idx)
|
|
@@ -305,21 +307,21 @@ class _MergedCircuit:
|
|
|
305
307
|
of a set to store operations to preserve insertion order.
|
|
306
308
|
"""
|
|
307
309
|
|
|
308
|
-
qubit_indexes: Dict[
|
|
310
|
+
qubit_indexes: Dict[cirq.Qid, List[int]] = dataclasses.field(
|
|
309
311
|
default_factory=lambda: defaultdict(lambda: [-1])
|
|
310
312
|
)
|
|
311
|
-
mkey_indexes: Dict[
|
|
313
|
+
mkey_indexes: Dict[cirq.MeasurementKey, List[int]] = dataclasses.field(
|
|
312
314
|
default_factory=lambda: defaultdict(lambda: [-1])
|
|
313
315
|
)
|
|
314
|
-
ckey_indexes: Dict[
|
|
316
|
+
ckey_indexes: Dict[cirq.MeasurementKey, List[int]] = dataclasses.field(
|
|
315
317
|
default_factory=lambda: defaultdict(lambda: [-1])
|
|
316
318
|
)
|
|
317
|
-
ops_by_index: List[Dict[
|
|
319
|
+
ops_by_index: List[Dict[cirq.Operation, int]] = dataclasses.field(default_factory=list)
|
|
318
320
|
|
|
319
321
|
def append_empty_moment(self) -> None:
|
|
320
322
|
self.ops_by_index.append({})
|
|
321
323
|
|
|
322
|
-
def add_op_to_moment(self, moment_index: int, op:
|
|
324
|
+
def add_op_to_moment(self, moment_index: int, op: cirq.Operation) -> None:
|
|
323
325
|
self.ops_by_index[moment_index][op] = 0
|
|
324
326
|
for q in op.qubits:
|
|
325
327
|
if moment_index > self.qubit_indexes[q][-1]:
|
|
@@ -331,7 +333,7 @@ class _MergedCircuit:
|
|
|
331
333
|
for ckey in protocols.control_keys(op):
|
|
332
334
|
bisect.insort(self.ckey_indexes[ckey], moment_index)
|
|
333
335
|
|
|
334
|
-
def remove_op_from_moment(self, moment_index: int, op:
|
|
336
|
+
def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None:
|
|
335
337
|
self.ops_by_index[moment_index].pop(op)
|
|
336
338
|
for q in op.qubits:
|
|
337
339
|
if self.qubit_indexes[q][-1] == moment_index:
|
|
@@ -344,8 +346,8 @@ class _MergedCircuit:
|
|
|
344
346
|
self.ckey_indexes[ckey].remove(moment_index)
|
|
345
347
|
|
|
346
348
|
def get_mergeable_ops(
|
|
347
|
-
self, op:
|
|
348
|
-
) -> Tuple[int, List[
|
|
349
|
+
self, op: cirq.Operation, op_qs: Set[cirq.Qid]
|
|
350
|
+
) -> Tuple[int, List[cirq.Operation]]:
|
|
349
351
|
# Find the index of previous moment which can be merged with `op`.
|
|
350
352
|
idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1)
|
|
351
353
|
idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)])
|
|
@@ -360,7 +362,7 @@ class _MergedCircuit:
|
|
|
360
362
|
left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits)
|
|
361
363
|
]
|
|
362
364
|
|
|
363
|
-
def get_cirq_circuit(self) ->
|
|
365
|
+
def get_cirq_circuit(self) -> cirq.Circuit:
|
|
364
366
|
return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index)
|
|
365
367
|
|
|
366
368
|
|
|
@@ -493,7 +495,7 @@ def merge_operations(
|
|
|
493
495
|
|
|
494
496
|
def merge_operations_to_circuit_op(
|
|
495
497
|
circuit: CIRCUIT_TYPE,
|
|
496
|
-
can_merge: Callable[[Sequence[
|
|
498
|
+
can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool],
|
|
497
499
|
*,
|
|
498
500
|
tags_to_ignore: Sequence[Hashable] = (),
|
|
499
501
|
merged_circuit_op_tag: str = "Merged connected component",
|
|
@@ -524,8 +526,8 @@ def merge_operations_to_circuit_op(
|
|
|
524
526
|
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
|
|
525
527
|
"""
|
|
526
528
|
|
|
527
|
-
def merge_func(op1:
|
|
528
|
-
def get_ops(op:
|
|
529
|
+
def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> Optional[cirq.Operation]:
|
|
530
|
+
def get_ops(op: cirq.Operation):
|
|
529
531
|
op_untagged = op.untagged
|
|
530
532
|
return (
|
|
531
533
|
[*op_untagged.circuit.all_operations()]
|
|
@@ -573,7 +575,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
|
|
|
573
575
|
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
|
|
574
576
|
"""
|
|
575
577
|
|
|
576
|
-
def can_merge(ops1: Sequence[
|
|
578
|
+
def can_merge(ops1: Sequence[cirq.Operation], ops2: Sequence[cirq.Operation]) -> bool:
|
|
577
579
|
return all(
|
|
578
580
|
protocols.num_qubits(op) <= k and protocols.has_unitary(op)
|
|
579
581
|
for op_list in [ops1, ops2]
|
|
@@ -659,7 +661,7 @@ def unroll_circuit_op(
|
|
|
659
661
|
"""
|
|
660
662
|
|
|
661
663
|
def map_func(m: circuits.Moment, _: int):
|
|
662
|
-
to_zip: List[
|
|
664
|
+
to_zip: List[cirq.AbstractCircuit] = []
|
|
663
665
|
for op in m:
|
|
664
666
|
op_untagged = op.untagged
|
|
665
667
|
if isinstance(op_untagged, circuits.CircuitOperation):
|
|
@@ -751,7 +753,7 @@ def unroll_circuit_op_greedy_frontier(
|
|
|
751
753
|
Copy of input circuit with (Tagged) CircuitOperation's expanded inline at qubit frontier.
|
|
752
754
|
"""
|
|
753
755
|
unrolled_circuit = circuit.unfreeze(copy=True)
|
|
754
|
-
frontier: Dict[
|
|
756
|
+
frontier: Dict[cirq.Qid, int] = defaultdict(lambda: 0)
|
|
755
757
|
idx = 0
|
|
756
758
|
while idx < len(unrolled_circuit):
|
|
757
759
|
for op in unrolled_circuit[idx].operations:
|
|
@@ -799,7 +801,7 @@ def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool =
|
|
|
799
801
|
"""
|
|
800
802
|
tags_to_xor = set(tags)
|
|
801
803
|
|
|
802
|
-
def map_func(op:
|
|
804
|
+
def map_func(op: cirq.Operation, _) -> cirq.Operation:
|
|
803
805
|
return (
|
|
804
806
|
op
|
|
805
807
|
if deep and isinstance(op, circuits.CircuitOperation)
|
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
15
17
|
from typing import Iterator, List, Optional
|
|
16
18
|
|
|
17
19
|
import pytest
|
|
@@ -721,7 +723,7 @@ def test_merge_operations_to_circuit_op_merges_connected_component():
|
|
|
721
723
|
''',
|
|
722
724
|
)
|
|
723
725
|
|
|
724
|
-
def can_merge(ops1: List[
|
|
726
|
+
def can_merge(ops1: List[cirq.Operation], ops2: List[cirq.Operation]) -> bool:
|
|
725
727
|
"""Artificial example where a CZ will absorb any merge-able operation."""
|
|
726
728
|
return any(o.gate == cirq.CZ for op_list in [ops1, ops2] for o in op_list)
|
|
727
729
|
|
cirq/value/classical_data.py
CHANGED
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
15
17
|
import abc
|
|
16
18
|
import enum
|
|
17
19
|
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
|
|
@@ -45,21 +47,21 @@ class MeasurementType(enum.IntEnum):
|
|
|
45
47
|
|
|
46
48
|
class ClassicalDataStoreReader(abc.ABC):
|
|
47
49
|
@abc.abstractmethod
|
|
48
|
-
def keys(self) -> Tuple[
|
|
50
|
+
def keys(self) -> Tuple[cirq.MeasurementKey, ...]:
|
|
49
51
|
"""Gets the measurement keys in the order they were stored."""
|
|
50
52
|
|
|
51
53
|
@property
|
|
52
54
|
@abc.abstractmethod
|
|
53
|
-
def records(self) -> Mapping[
|
|
55
|
+
def records(self) -> Mapping[cirq.MeasurementKey, List[Tuple[int, ...]]]:
|
|
54
56
|
"""Gets the a mapping from measurement key to measurement records."""
|
|
55
57
|
|
|
56
58
|
@property
|
|
57
59
|
@abc.abstractmethod
|
|
58
|
-
def channel_records(self) -> Mapping[
|
|
60
|
+
def channel_records(self) -> Mapping[cirq.MeasurementKey, List[int]]:
|
|
59
61
|
"""Gets the a mapping from measurement key to channel measurement records."""
|
|
60
62
|
|
|
61
63
|
@abc.abstractmethod
|
|
62
|
-
def get_int(self, key:
|
|
64
|
+
def get_int(self, key: cirq.MeasurementKey, index=-1) -> int:
|
|
63
65
|
"""Gets the integer corresponding to the measurement.
|
|
64
66
|
|
|
65
67
|
The integer is determined by summing the qubit-dimensional basis value
|
|
@@ -81,7 +83,7 @@ class ClassicalDataStoreReader(abc.ABC):
|
|
|
81
83
|
"""
|
|
82
84
|
|
|
83
85
|
@abc.abstractmethod
|
|
84
|
-
def get_digits(self, key:
|
|
86
|
+
def get_digits(self, key: cirq.MeasurementKey, index=-1) -> Tuple[int, ...]:
|
|
85
87
|
"""Gets the values of the qubits that were measured into this key.
|
|
86
88
|
|
|
87
89
|
For example, if the measurement of qubits [q0, q1] produces [0, 1],
|
|
@@ -107,7 +109,7 @@ class ClassicalDataStoreReader(abc.ABC):
|
|
|
107
109
|
class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC):
|
|
108
110
|
@abc.abstractmethod
|
|
109
111
|
def record_measurement(
|
|
110
|
-
self, key:
|
|
112
|
+
self, key: cirq.MeasurementKey, measurement: Sequence[int], qubits: Sequence[cirq.Qid]
|
|
111
113
|
):
|
|
112
114
|
"""Records a measurement.
|
|
113
115
|
|
|
@@ -122,7 +124,7 @@ class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC):
|
|
|
122
124
|
"""
|
|
123
125
|
|
|
124
126
|
@abc.abstractmethod
|
|
125
|
-
def record_channel_measurement(self, key:
|
|
127
|
+
def record_channel_measurement(self, key: cirq.MeasurementKey, measurement: int):
|
|
126
128
|
"""Records a channel measurement.
|
|
127
129
|
|
|
128
130
|
Args:
|
|
@@ -141,12 +143,10 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
|
|
|
141
143
|
def __init__(
|
|
142
144
|
self,
|
|
143
145
|
*,
|
|
144
|
-
_records: Optional[Dict[
|
|
145
|
-
_measured_qubits: Optional[
|
|
146
|
-
|
|
147
|
-
] = None,
|
|
148
|
-
_channel_records: Optional[Dict['cirq.MeasurementKey', List[int]]] = None,
|
|
149
|
-
_measurement_types: Optional[Dict['cirq.MeasurementKey', 'cirq.MeasurementType']] = None,
|
|
146
|
+
_records: Optional[Dict[cirq.MeasurementKey, List[Tuple[int, ...]]]] = None,
|
|
147
|
+
_measured_qubits: Optional[Dict[cirq.MeasurementKey, List[Tuple[cirq.Qid, ...]]]] = None,
|
|
148
|
+
_channel_records: Optional[Dict[cirq.MeasurementKey, List[int]]] = None,
|
|
149
|
+
_measurement_types: Optional[Dict[cirq.MeasurementKey, cirq.MeasurementType]] = None,
|
|
150
150
|
):
|
|
151
151
|
"""Initializes a `ClassicalDataDictionaryStore` object."""
|
|
152
152
|
if not _measurement_types:
|
|
@@ -165,40 +165,40 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
|
|
|
165
165
|
_measured_qubits = {}
|
|
166
166
|
if _channel_records is None:
|
|
167
167
|
_channel_records = {}
|
|
168
|
-
self._records: Dict[
|
|
169
|
-
self._measured_qubits: Dict[
|
|
168
|
+
self._records: Dict[cirq.MeasurementKey, List[Tuple[int, ...]]] = _records
|
|
169
|
+
self._measured_qubits: Dict[cirq.MeasurementKey, List[Tuple[cirq.Qid, ...]]] = (
|
|
170
170
|
_measured_qubits
|
|
171
171
|
)
|
|
172
|
-
self._channel_records: Dict[
|
|
173
|
-
self._measurement_types: Dict[
|
|
172
|
+
self._channel_records: Dict[cirq.MeasurementKey, List[int]] = _channel_records
|
|
173
|
+
self._measurement_types: Dict[cirq.MeasurementKey, cirq.MeasurementType] = (
|
|
174
174
|
_measurement_types
|
|
175
175
|
)
|
|
176
176
|
|
|
177
177
|
@property
|
|
178
|
-
def records(self) -> Mapping[
|
|
178
|
+
def records(self) -> Mapping[cirq.MeasurementKey, List[Tuple[int, ...]]]:
|
|
179
179
|
"""Gets the a mapping from measurement key to measurement records."""
|
|
180
180
|
return self._records
|
|
181
181
|
|
|
182
182
|
@property
|
|
183
|
-
def channel_records(self) -> Mapping[
|
|
183
|
+
def channel_records(self) -> Mapping[cirq.MeasurementKey, List[int]]:
|
|
184
184
|
"""Gets the a mapping from measurement key to channel measurement records."""
|
|
185
185
|
return self._channel_records
|
|
186
186
|
|
|
187
187
|
@property
|
|
188
|
-
def measured_qubits(self) -> Mapping[
|
|
188
|
+
def measured_qubits(self) -> Mapping[cirq.MeasurementKey, List[Tuple[cirq.Qid, ...]]]:
|
|
189
189
|
"""Gets the a mapping from measurement key to the qubits measured."""
|
|
190
190
|
return self._measured_qubits
|
|
191
191
|
|
|
192
192
|
@property
|
|
193
|
-
def measurement_types(self) -> Mapping[
|
|
193
|
+
def measurement_types(self) -> Mapping[cirq.MeasurementKey, cirq.MeasurementType]:
|
|
194
194
|
"""Gets the a mapping from measurement key to the measurement type."""
|
|
195
195
|
return self._measurement_types
|
|
196
196
|
|
|
197
|
-
def keys(self) -> Tuple[
|
|
197
|
+
def keys(self) -> Tuple[cirq.MeasurementKey, ...]:
|
|
198
198
|
return tuple(self._measurement_types.keys())
|
|
199
199
|
|
|
200
200
|
def record_measurement(
|
|
201
|
-
self, key:
|
|
201
|
+
self, key: cirq.MeasurementKey, measurement: Sequence[int], qubits: Sequence[cirq.Qid]
|
|
202
202
|
):
|
|
203
203
|
if len(measurement) != len(qubits):
|
|
204
204
|
raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.')
|
|
@@ -217,7 +217,7 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
|
|
|
217
217
|
measured_qubits.append(tuple(qubits))
|
|
218
218
|
self._records[key].append(tuple(measurement))
|
|
219
219
|
|
|
220
|
-
def record_channel_measurement(self, key:
|
|
220
|
+
def record_channel_measurement(self, key: cirq.MeasurementKey, measurement: int):
|
|
221
221
|
if key not in self._measurement_types:
|
|
222
222
|
self._measurement_types[key] = MeasurementType.CHANNEL
|
|
223
223
|
self._channel_records[key] = []
|
|
@@ -225,14 +225,14 @@ class ClassicalDataDictionaryStore(ClassicalDataStore):
|
|
|
225
225
|
raise ValueError(f"Measurement already logged to key {key}")
|
|
226
226
|
self._channel_records[key].append(measurement)
|
|
227
227
|
|
|
228
|
-
def get_digits(self, key:
|
|
228
|
+
def get_digits(self, key: cirq.MeasurementKey, index=-1) -> Tuple[int, ...]:
|
|
229
229
|
return (
|
|
230
230
|
self._records[key][index]
|
|
231
231
|
if self._measurement_types[key] == MeasurementType.MEASUREMENT
|
|
232
232
|
else (self._channel_records[key][index],)
|
|
233
233
|
)
|
|
234
234
|
|
|
235
|
-
def get_int(self, key:
|
|
235
|
+
def get_int(self, key: cirq.MeasurementKey, index=-1) -> int:
|
|
236
236
|
if key not in self._measurement_types:
|
|
237
237
|
raise KeyError(f'The measurement key {key} is not in {self._measurement_types}')
|
|
238
238
|
measurement_type = self._measurement_types[key]
|
cirq/value/condition.py
CHANGED
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
15
17
|
import abc
|
|
16
18
|
import dataclasses
|
|
17
19
|
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, TYPE_CHECKING
|
|
@@ -32,15 +34,15 @@ class Condition(abc.ABC):
|
|
|
32
34
|
|
|
33
35
|
@property
|
|
34
36
|
@abc.abstractmethod
|
|
35
|
-
def keys(self) -> Tuple[
|
|
37
|
+
def keys(self) -> Tuple[cirq.MeasurementKey, ...]:
|
|
36
38
|
"""Gets the control keys."""
|
|
37
39
|
|
|
38
40
|
@abc.abstractmethod
|
|
39
|
-
def replace_key(self, current:
|
|
41
|
+
def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
|
|
40
42
|
"""Replaces the control keys."""
|
|
41
43
|
|
|
42
44
|
@abc.abstractmethod
|
|
43
|
-
def resolve(self, classical_data:
|
|
45
|
+
def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
|
|
44
46
|
"""Resolves the condition based on the measurements."""
|
|
45
47
|
|
|
46
48
|
@property
|
|
@@ -48,24 +50,24 @@ class Condition(abc.ABC):
|
|
|
48
50
|
def qasm(self):
|
|
49
51
|
"""Returns the qasm of this condition."""
|
|
50
52
|
|
|
51
|
-
def _qasm_(self, args:
|
|
53
|
+
def _qasm_(self, args: cirq.QasmArgs, **kwargs) -> Optional[str]:
|
|
52
54
|
return self.qasm
|
|
53
55
|
|
|
54
|
-
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) ->
|
|
56
|
+
def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) -> cirq.Condition:
|
|
55
57
|
condition = self
|
|
56
58
|
for k in self.keys:
|
|
57
59
|
condition = condition.replace_key(k, mkp.with_measurement_key_mapping(k, key_map))
|
|
58
60
|
return condition
|
|
59
61
|
|
|
60
|
-
def _with_key_path_prefix_(self, path: Tuple[str, ...]) ->
|
|
62
|
+
def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> cirq.Condition:
|
|
61
63
|
condition = self
|
|
62
64
|
for k in self.keys:
|
|
63
65
|
condition = condition.replace_key(k, mkp.with_key_path_prefix(k, path))
|
|
64
66
|
return condition
|
|
65
67
|
|
|
66
68
|
def _with_rescoped_keys_(
|
|
67
|
-
self, path: Tuple[str, ...], bindable_keys: FrozenSet[
|
|
68
|
-
) ->
|
|
69
|
+
self, path: Tuple[str, ...], bindable_keys: FrozenSet[cirq.MeasurementKey]
|
|
70
|
+
) -> cirq.Condition:
|
|
69
71
|
condition = self
|
|
70
72
|
for key in self.keys:
|
|
71
73
|
for i in range(len(path) + 1):
|
|
@@ -85,14 +87,14 @@ class KeyCondition(Condition):
|
|
|
85
87
|
time of resolution.
|
|
86
88
|
"""
|
|
87
89
|
|
|
88
|
-
key:
|
|
90
|
+
key: cirq.MeasurementKey
|
|
89
91
|
index: int = -1
|
|
90
92
|
|
|
91
93
|
@property
|
|
92
94
|
def keys(self):
|
|
93
95
|
return (self.key,)
|
|
94
96
|
|
|
95
|
-
def replace_key(self, current:
|
|
97
|
+
def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
|
|
96
98
|
return KeyCondition(replacement) if self.key == current else self
|
|
97
99
|
|
|
98
100
|
def __str__(self):
|
|
@@ -103,7 +105,7 @@ class KeyCondition(Condition):
|
|
|
103
105
|
return f'cirq.KeyCondition({self.key!r}, {self.index})'
|
|
104
106
|
return f'cirq.KeyCondition({self.key!r})'
|
|
105
107
|
|
|
106
|
-
def resolve(self, classical_data:
|
|
108
|
+
def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
|
|
107
109
|
if self.key not in classical_data.keys():
|
|
108
110
|
raise ValueError(f'Measurement key {self.key} missing when testing classical control')
|
|
109
111
|
return classical_data.get_int(self.key, self.index) != 0
|
|
@@ -119,7 +121,7 @@ class KeyCondition(Condition):
|
|
|
119
121
|
def qasm(self):
|
|
120
122
|
raise ValueError('QASM is defined only for SympyConditions of type key == constant.')
|
|
121
123
|
|
|
122
|
-
def _qasm_(self, args:
|
|
124
|
+
def _qasm_(self, args: cirq.QasmArgs, **kwargs) -> Optional[str]:
|
|
123
125
|
args.validate_version('2.0', '3.0')
|
|
124
126
|
key_str = str(self.key)
|
|
125
127
|
if key_str not in args.meas_key_id_map:
|
|
@@ -162,7 +164,7 @@ class BitMaskKeyCondition(Condition):
|
|
|
162
164
|
- bitmask: Optional bitmask to apply before doing the comparison.
|
|
163
165
|
"""
|
|
164
166
|
|
|
165
|
-
key:
|
|
167
|
+
key: cirq.MeasurementKey = attrs.field(
|
|
166
168
|
converter=lambda x: (
|
|
167
169
|
x
|
|
168
170
|
if isinstance(x, measurement_key.MeasurementKey)
|
|
@@ -180,8 +182,8 @@ class BitMaskKeyCondition(Condition):
|
|
|
180
182
|
|
|
181
183
|
@staticmethod
|
|
182
184
|
def create_equal_mask(
|
|
183
|
-
key:
|
|
184
|
-
) ->
|
|
185
|
+
key: cirq.MeasurementKey, bitmask: int, *, index: int = -1
|
|
186
|
+
) -> BitMaskKeyCondition:
|
|
185
187
|
"""Creates a condition that evaluates (meas & bitmask) == bitmask."""
|
|
186
188
|
return BitMaskKeyCondition(
|
|
187
189
|
key, index, target_value=bitmask, equal_target=True, bitmask=bitmask
|
|
@@ -189,14 +191,14 @@ class BitMaskKeyCondition(Condition):
|
|
|
189
191
|
|
|
190
192
|
@staticmethod
|
|
191
193
|
def create_not_equal_mask(
|
|
192
|
-
key:
|
|
193
|
-
) ->
|
|
194
|
+
key: cirq.MeasurementKey, bitmask: int, *, index: int = -1
|
|
195
|
+
) -> BitMaskKeyCondition:
|
|
194
196
|
"""Creates a condition that evaluates (meas & bitmask) != bitmask."""
|
|
195
197
|
return BitMaskKeyCondition(
|
|
196
198
|
key, index, target_value=bitmask, equal_target=False, bitmask=bitmask
|
|
197
199
|
)
|
|
198
200
|
|
|
199
|
-
def replace_key(self, current:
|
|
201
|
+
def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
|
|
200
202
|
return BitMaskKeyCondition(replacement) if self.key == current else self
|
|
201
203
|
|
|
202
204
|
def __str__(self):
|
|
@@ -218,7 +220,7 @@ class BitMaskKeyCondition(Condition):
|
|
|
218
220
|
parameters = ', '.join(f'{f.name}={repr(values[f.name])}' for f in attrs.fields(type(self)))
|
|
219
221
|
return f'cirq.BitMaskKeyCondition({parameters})'
|
|
220
222
|
|
|
221
|
-
def resolve(self, classical_data:
|
|
223
|
+
def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
|
|
222
224
|
if self.key not in classical_data.keys():
|
|
223
225
|
raise ValueError(f'Measurement key {self.key} missing when testing classical control')
|
|
224
226
|
value = classical_data.get_int(self.key, self.index)
|
|
@@ -269,7 +271,7 @@ class SympyCondition(Condition):
|
|
|
269
271
|
# keep the former here.
|
|
270
272
|
)
|
|
271
273
|
|
|
272
|
-
def replace_key(self, current:
|
|
274
|
+
def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey):
|
|
273
275
|
return SympyCondition(self.expr.subs({str(current): sympy.Symbol(str(replacement))}))
|
|
274
276
|
|
|
275
277
|
def __str__(self):
|
|
@@ -278,7 +280,7 @@ class SympyCondition(Condition):
|
|
|
278
280
|
def __repr__(self):
|
|
279
281
|
return f'cirq.SympyCondition({proper_repr(self.expr)})'
|
|
280
282
|
|
|
281
|
-
def resolve(self, classical_data:
|
|
283
|
+
def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool:
|
|
282
284
|
missing = [str(k) for k in self.keys if k not in classical_data.keys()]
|
|
283
285
|
if missing:
|
|
284
286
|
raise ValueError(f'Measurement keys {missing} missing when testing classical control')
|
cirq/value/duration.py
CHANGED
|
@@ -11,8 +11,11 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
14
15
|
"""A typed time delta that supports picosecond accuracy."""
|
|
15
16
|
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
16
19
|
import datetime
|
|
17
20
|
from typing import AbstractSet, Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
|
18
21
|
|
|
@@ -98,7 +101,7 @@ class Duration:
|
|
|
98
101
|
def _parameter_names_(self) -> AbstractSet[str]:
|
|
99
102
|
return protocols.parameter_names(self._time_vals)
|
|
100
103
|
|
|
101
|
-
def _resolve_parameters_(self, resolver:
|
|
104
|
+
def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Duration:
|
|
102
105
|
return _duration_from_time_vals(
|
|
103
106
|
protocols.resolve_parameters(self._time_vals, resolver, recursive)
|
|
104
107
|
)
|
|
@@ -121,16 +124,16 @@ class Duration:
|
|
|
121
124
|
"""Returns the number of milliseconds that the duration spans."""
|
|
122
125
|
return self.total_picos() / 1000_000_000
|
|
123
126
|
|
|
124
|
-
def __add__(self, other) ->
|
|
127
|
+
def __add__(self, other) -> Duration:
|
|
125
128
|
other = _attempt_duration_like_to_duration(other)
|
|
126
129
|
if other is None:
|
|
127
130
|
return NotImplemented
|
|
128
131
|
return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals))
|
|
129
132
|
|
|
130
|
-
def __radd__(self, other) ->
|
|
133
|
+
def __radd__(self, other) -> Duration:
|
|
131
134
|
return self.__add__(other)
|
|
132
135
|
|
|
133
|
-
def __sub__(self, other) ->
|
|
136
|
+
def __sub__(self, other) -> Duration:
|
|
134
137
|
other = _attempt_duration_like_to_duration(other)
|
|
135
138
|
if other is None:
|
|
136
139
|
return NotImplemented
|
|
@@ -138,7 +141,7 @@ class Duration:
|
|
|
138
141
|
_add_time_vals(self._time_vals, [-x for x in other._time_vals])
|
|
139
142
|
)
|
|
140
143
|
|
|
141
|
-
def __rsub__(self, other) ->
|
|
144
|
+
def __rsub__(self, other) -> Duration:
|
|
142
145
|
other = _attempt_duration_like_to_duration(other)
|
|
143
146
|
if other is None:
|
|
144
147
|
return NotImplemented
|
|
@@ -146,17 +149,17 @@ class Duration:
|
|
|
146
149
|
_add_time_vals(other._time_vals, [-x for x in self._time_vals])
|
|
147
150
|
)
|
|
148
151
|
|
|
149
|
-
def __mul__(self, other) ->
|
|
152
|
+
def __mul__(self, other) -> Duration:
|
|
150
153
|
if not isinstance(other, (int, float, sympy.Expr)):
|
|
151
154
|
return NotImplemented
|
|
152
155
|
if other == 0:
|
|
153
156
|
return _duration_from_time_vals([0] * 4)
|
|
154
157
|
return _duration_from_time_vals([x * other for x in self._time_vals])
|
|
155
158
|
|
|
156
|
-
def __rmul__(self, other) ->
|
|
159
|
+
def __rmul__(self, other) -> Duration:
|
|
157
160
|
return self.__mul__(other)
|
|
158
161
|
|
|
159
|
-
def __truediv__(self, other) -> Union[
|
|
162
|
+
def __truediv__(self, other) -> Union[Duration, float]:
|
|
160
163
|
if isinstance(other, (int, float, sympy.Expr)):
|
|
161
164
|
new_time_vals = [x / other for x in self._time_vals]
|
|
162
165
|
return _duration_from_time_vals(new_time_vals)
|