bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic. Click here for more details.
- bloqade/analysis/address/impls.py +5 -9
- bloqade/analysis/address/lattice.py +1 -1
- bloqade/analysis/fidelity/__init__.py +1 -0
- bloqade/analysis/fidelity/analysis.py +69 -0
- bloqade/device.py +130 -0
- bloqade/noise/__init__.py +2 -1
- bloqade/noise/fidelity.py +51 -0
- bloqade/noise/native/model.py +1 -2
- bloqade/noise/native/rewrite.py +5 -5
- bloqade/noise/native/stmts.py +40 -11
- bloqade/pyqrack/__init__.py +8 -2
- bloqade/pyqrack/base.py +24 -3
- bloqade/pyqrack/device.py +166 -0
- bloqade/pyqrack/noise/native.py +1 -2
- bloqade/pyqrack/qasm2/core.py +31 -15
- bloqade/pyqrack/qasm2/glob.py +28 -0
- bloqade/pyqrack/qasm2/uop.py +9 -1
- bloqade/pyqrack/reg.py +17 -49
- bloqade/pyqrack/squin/__init__.py +0 -0
- bloqade/pyqrack/squin/op.py +154 -0
- bloqade/pyqrack/squin/qubit.py +85 -0
- bloqade/pyqrack/squin/runtime.py +515 -0
- bloqade/pyqrack/squin/wire.py +69 -0
- bloqade/pyqrack/target.py +9 -2
- bloqade/pyqrack/task.py +30 -0
- bloqade/qasm2/_wrappers.py +11 -1
- bloqade/qasm2/dialects/core/stmts.py +15 -4
- bloqade/qasm2/dialects/expr/_emit.py +9 -8
- bloqade/qasm2/emit/base.py +4 -2
- bloqade/qasm2/emit/gate.py +0 -14
- bloqade/qasm2/emit/main.py +19 -15
- bloqade/qasm2/emit/target.py +2 -6
- bloqade/qasm2/glob.py +1 -1
- bloqade/qasm2/parse/lowering.py +124 -1
- bloqade/qasm2/passes/glob.py +3 -3
- bloqade/qasm2/passes/lift_qubits.py +26 -0
- bloqade/qasm2/passes/noise.py +6 -14
- bloqade/qasm2/passes/parallel.py +3 -3
- bloqade/qasm2/passes/py2qasm.py +1 -2
- bloqade/qasm2/passes/qasm2py.py +1 -2
- bloqade/qasm2/rewrite/desugar.py +6 -6
- bloqade/qasm2/rewrite/glob.py +9 -9
- bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
- bloqade/qasm2/rewrite/insert_qubits.py +34 -0
- bloqade/qasm2/rewrite/native_gates.py +54 -55
- bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
- bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
- bloqade/qasm2/types.py +3 -6
- bloqade/qbraid/schema.py +10 -12
- bloqade/squin/__init__.py +1 -1
- bloqade/squin/analysis/nsites/analysis.py +4 -6
- bloqade/squin/analysis/nsites/impls.py +2 -6
- bloqade/squin/analysis/schedule.py +1 -1
- bloqade/squin/groups.py +15 -7
- bloqade/squin/noise/__init__.py +27 -0
- bloqade/squin/noise/_dialect.py +3 -0
- bloqade/squin/noise/stmts.py +59 -0
- bloqade/squin/op/__init__.py +35 -5
- bloqade/squin/op/number.py +5 -0
- bloqade/squin/op/rewrite.py +46 -0
- bloqade/squin/op/stmts.py +23 -2
- bloqade/squin/op/types.py +14 -0
- bloqade/squin/qubit.py +79 -11
- bloqade/squin/rewrite/__init__.py +0 -0
- bloqade/squin/rewrite/measure_desugar.py +33 -0
- bloqade/squin/wire.py +31 -2
- bloqade/stim/emit/stim.py +1 -1
- bloqade/task.py +94 -0
- bloqade/visual/animation/base.py +25 -15
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/METADATA +8 -2
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/RECORD +73 -52
- bloqade/squin/op/complex.py +0 -6
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,8 +3,8 @@ from typing import Dict, List, Tuple, Iterable
|
|
|
3
3
|
from dataclasses import field, dataclass
|
|
4
4
|
|
|
5
5
|
from kirin import ir
|
|
6
|
-
from kirin.rewrite import abc as rewrite_abc
|
|
7
6
|
from kirin.dialects import py, ilist
|
|
7
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
8
8
|
from kirin.analysis.const import lattice
|
|
9
9
|
|
|
10
10
|
from bloqade.analysis import address
|
|
@@ -14,7 +14,7 @@ from bloqade.squin.analysis.schedule import StmtDag
|
|
|
14
14
|
|
|
15
15
|
class MergePolicyABC(abc.ABC):
|
|
16
16
|
@abc.abstractmethod
|
|
17
|
-
def __call__(self, node: ir.Statement) ->
|
|
17
|
+
def __call__(self, node: ir.Statement) -> RewriteResult:
|
|
18
18
|
pass
|
|
19
19
|
|
|
20
20
|
@classmethod
|
|
@@ -141,10 +141,10 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
141
141
|
group_numbers=group_numbers,
|
|
142
142
|
)
|
|
143
143
|
|
|
144
|
-
def __call__(self, node: ir.Statement) ->
|
|
144
|
+
def __call__(self, node: ir.Statement) -> RewriteResult:
|
|
145
145
|
|
|
146
146
|
if node not in self.group_numbers:
|
|
147
|
-
return
|
|
147
|
+
return RewriteResult()
|
|
148
148
|
|
|
149
149
|
group_number = self.group_numbers[node]
|
|
150
150
|
group = self.merge_groups[group_number]
|
|
@@ -154,12 +154,10 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
154
154
|
self.group_has_merged[group_number] = result.has_done_something
|
|
155
155
|
return result
|
|
156
156
|
|
|
157
|
-
if self.group_has_merged
|
|
157
|
+
if self.group_has_merged.setdefault(group_number, False):
|
|
158
158
|
node.delete()
|
|
159
159
|
|
|
160
|
-
return
|
|
161
|
-
has_done_something=self.group_has_merged[group_number]
|
|
162
|
-
)
|
|
160
|
+
return RewriteResult(has_done_something=self.group_has_merged[group_number])
|
|
163
161
|
|
|
164
162
|
def move_and_collect_qubit_list(
|
|
165
163
|
self, qargs: List[ir.SSAValue], node: ir.Statement
|
|
@@ -219,14 +217,14 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
219
217
|
ctrls.append(stmt.ctrls)
|
|
220
218
|
qargs.append(stmt.qargs)
|
|
221
219
|
else:
|
|
222
|
-
return
|
|
220
|
+
return RewriteResult(has_done_something=False)
|
|
223
221
|
|
|
224
222
|
ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
|
|
225
223
|
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
226
224
|
|
|
227
225
|
if ctrls_values is None or qargs_values is None:
|
|
228
226
|
# give up if we cannot determine the address or cannot move the qubits
|
|
229
|
-
return
|
|
227
|
+
return RewriteResult(has_done_something=False)
|
|
230
228
|
|
|
231
229
|
new_ctrls = ilist.New(values=ctrls_values)
|
|
232
230
|
new_qargs = ilist.New(values=qargs_values)
|
|
@@ -238,7 +236,7 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
238
236
|
|
|
239
237
|
node.delete()
|
|
240
238
|
|
|
241
|
-
return
|
|
239
|
+
return RewriteResult(has_done_something=True)
|
|
242
240
|
|
|
243
241
|
def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
|
|
244
242
|
return self.rewrite_group_u(node, group)
|
|
@@ -252,13 +250,13 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
252
250
|
elif isinstance(stmt, parallel.UGate):
|
|
253
251
|
qargs.append(stmt.qargs)
|
|
254
252
|
else:
|
|
255
|
-
return
|
|
253
|
+
return RewriteResult(has_done_something=False)
|
|
256
254
|
|
|
257
255
|
assert isinstance(node, (uop.UGate, parallel.UGate))
|
|
258
256
|
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
259
257
|
|
|
260
258
|
if qargs_values is None:
|
|
261
|
-
return
|
|
259
|
+
return RewriteResult(has_done_something=False)
|
|
262
260
|
|
|
263
261
|
new_qargs = ilist.New(values=qargs_values)
|
|
264
262
|
new_gate = parallel.UGate(
|
|
@@ -271,7 +269,7 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
271
269
|
new_gate.insert_before(node)
|
|
272
270
|
node.delete()
|
|
273
271
|
|
|
274
|
-
return
|
|
272
|
+
return RewriteResult(has_done_something=True)
|
|
275
273
|
|
|
276
274
|
def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
|
|
277
275
|
qargs = []
|
|
@@ -282,14 +280,14 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
282
280
|
elif isinstance(stmt, parallel.RZ):
|
|
283
281
|
qargs.append(stmt.qargs)
|
|
284
282
|
else:
|
|
285
|
-
return
|
|
283
|
+
return RewriteResult(has_done_something=False)
|
|
286
284
|
|
|
287
285
|
assert isinstance(node, (uop.RZ, parallel.RZ))
|
|
288
286
|
|
|
289
287
|
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
290
288
|
|
|
291
289
|
if qargs_values is None:
|
|
292
|
-
return
|
|
290
|
+
return RewriteResult(has_done_something=False)
|
|
293
291
|
|
|
294
292
|
new_qargs = ilist.New(values=qargs_values)
|
|
295
293
|
new_gate = parallel.RZ(
|
|
@@ -300,7 +298,7 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
300
298
|
new_gate.insert_before(node)
|
|
301
299
|
node.delete()
|
|
302
300
|
|
|
303
|
-
return
|
|
301
|
+
return RewriteResult(has_done_something=True)
|
|
304
302
|
|
|
305
303
|
def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
|
|
306
304
|
qargs = []
|
|
@@ -310,13 +308,13 @@ class SimpleMergePolicy(MergePolicyABC):
|
|
|
310
308
|
qargs_values = self.move_and_collect_qubit_list(qargs, node)
|
|
311
309
|
|
|
312
310
|
if qargs_values is None:
|
|
313
|
-
return
|
|
311
|
+
return RewriteResult(has_done_something=False)
|
|
314
312
|
|
|
315
313
|
new_node = uop.Barrier(qargs=qargs_values)
|
|
316
314
|
new_node.insert_before(node)
|
|
317
315
|
node.delete()
|
|
318
316
|
|
|
319
|
-
return
|
|
317
|
+
return RewriteResult(has_done_something=True)
|
|
320
318
|
|
|
321
319
|
|
|
322
320
|
class GreedyMixin(MergePolicyABC):
|
|
@@ -385,11 +383,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):
|
|
|
385
383
|
|
|
386
384
|
|
|
387
385
|
@dataclass
|
|
388
|
-
class UOpToParallelRule(
|
|
386
|
+
class UOpToParallelRule(RewriteRule):
|
|
389
387
|
merge_rewriters: Dict[ir.Block | None, MergePolicyABC]
|
|
390
388
|
|
|
391
|
-
def rewrite_Statement(self, node: ir.Statement) ->
|
|
389
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
392
390
|
merge_rewriter = self.merge_rewriters.get(
|
|
393
|
-
node.parent_block, lambda _:
|
|
391
|
+
node.parent_block, lambda _: RewriteResult()
|
|
394
392
|
)
|
|
395
393
|
return merge_rewriter(node)
|
bloqade/qasm2/types.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from kirin import types
|
|
2
|
+
from kirin.dialects import ilist
|
|
2
3
|
|
|
3
4
|
from bloqade.types import Qubit as Qubit, QubitType as QubitType
|
|
4
5
|
|
|
@@ -15,11 +16,7 @@ class Bit:
|
|
|
15
16
|
pass
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
|
|
19
|
-
"""Runtime representation of a quantum register."""
|
|
20
|
-
|
|
21
|
-
def __getitem__(self, index) -> Qubit:
|
|
22
|
-
raise NotImplementedError("cannot call __getitem__ outside of a kernel")
|
|
19
|
+
QReg = ilist.IList[Qubit, types.Any]
|
|
23
20
|
|
|
24
21
|
|
|
25
22
|
class CReg:
|
|
@@ -32,7 +29,7 @@ class CReg:
|
|
|
32
29
|
BitType = types.PyClass(Bit)
|
|
33
30
|
"""Kirin type for a classical bit."""
|
|
34
31
|
|
|
35
|
-
QRegType = types.
|
|
32
|
+
QRegType = ilist.IListType[QubitType, types.Any]
|
|
36
33
|
"""Kirin type for a quantum register."""
|
|
37
34
|
|
|
38
35
|
CRegType = types.PyClass(CReg)
|
bloqade/qbraid/schema.py
CHANGED
|
@@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"):
|
|
|
9
9
|
op_type: str = Field(init=False)
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
class CZ(Operation):
|
|
12
|
+
class CZ(Operation, frozen=True):
|
|
13
13
|
"""A CZ gate operation.
|
|
14
14
|
|
|
15
15
|
Fields:
|
|
@@ -22,7 +22,7 @@ class CZ(Operation):
|
|
|
22
22
|
participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...]
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
class GlobalRz(Operation):
|
|
25
|
+
class GlobalRz(Operation, frozen=True):
|
|
26
26
|
"""GlobalRz operation.
|
|
27
27
|
|
|
28
28
|
Fields:
|
|
@@ -34,7 +34,7 @@ class GlobalRz(Operation):
|
|
|
34
34
|
phi: float
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
class GlobalW(Operation):
|
|
37
|
+
class GlobalW(Operation, frozen=True):
|
|
38
38
|
"""GlobalW operation.
|
|
39
39
|
|
|
40
40
|
Fields:
|
|
@@ -48,7 +48,7 @@ class GlobalW(Operation):
|
|
|
48
48
|
phi: float
|
|
49
49
|
|
|
50
50
|
|
|
51
|
-
class LocalRz(Operation):
|
|
51
|
+
class LocalRz(Operation, frozen=True):
|
|
52
52
|
"""LocalRz operation.
|
|
53
53
|
|
|
54
54
|
Fields:
|
|
@@ -63,7 +63,7 @@ class LocalRz(Operation):
|
|
|
63
63
|
phi: float
|
|
64
64
|
|
|
65
65
|
|
|
66
|
-
class LocalW(Operation):
|
|
66
|
+
class LocalW(Operation, frozen=True):
|
|
67
67
|
"""LocalW operation.
|
|
68
68
|
|
|
69
69
|
Fields:
|
|
@@ -80,7 +80,7 @@ class LocalW(Operation):
|
|
|
80
80
|
phi: float
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
class Measurement(Operation):
|
|
83
|
+
class Measurement(Operation, frozen=True):
|
|
84
84
|
"""Measurement operation.
|
|
85
85
|
|
|
86
86
|
Fields:
|
|
@@ -95,9 +95,7 @@ class Measurement(Operation):
|
|
|
95
95
|
participants: Tuple[int, ...]
|
|
96
96
|
|
|
97
97
|
|
|
98
|
-
OperationType =
|
|
99
|
-
"OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement]
|
|
100
|
-
)
|
|
98
|
+
OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement
|
|
101
99
|
|
|
102
100
|
|
|
103
101
|
class ErrorModel(BaseModel, frozen=True, extra="forbid"):
|
|
@@ -106,7 +104,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"):
|
|
|
106
104
|
error_model_type: str = Field(init=False)
|
|
107
105
|
|
|
108
106
|
|
|
109
|
-
class PauliErrorModel(ErrorModel):
|
|
107
|
+
class PauliErrorModel(ErrorModel, frozen=True):
|
|
110
108
|
"""Pauli error model.
|
|
111
109
|
|
|
112
110
|
Fields:
|
|
@@ -131,7 +129,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for
|
|
|
131
129
|
survival_prob: Tuple[float, ...]
|
|
132
130
|
|
|
133
131
|
|
|
134
|
-
class CZError(ErrorOperation[ErrorModelType]):
|
|
132
|
+
class CZError(ErrorOperation[ErrorModelType], frozen=True):
|
|
135
133
|
"""CZError operation.
|
|
136
134
|
|
|
137
135
|
Fields:
|
|
@@ -149,7 +147,7 @@ class CZError(ErrorOperation[ErrorModelType]):
|
|
|
149
147
|
single_error: ErrorModelType
|
|
150
148
|
|
|
151
149
|
|
|
152
|
-
class SingleQubitError(ErrorOperation[ErrorModelType]):
|
|
150
|
+
class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True):
|
|
153
151
|
"""SingleQubitError operation.
|
|
154
152
|
|
|
155
153
|
Fields:
|
bloqade/squin/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from . import op as op, wire as wire, qubit as qubit
|
|
1
|
+
from . import op as op, wire as wire, noise as noise, qubit as qubit
|
|
2
2
|
from .groups import wired as wired, kernel as kernel
|
|
@@ -15,9 +15,7 @@ class NSitesAnalysis(Forward[Sites]):
|
|
|
15
15
|
keys = ["op.nsites"]
|
|
16
16
|
lattice = Sites
|
|
17
17
|
|
|
18
|
-
# Take a page from
|
|
19
|
-
# I can get the data I want from the SizedTrait
|
|
20
|
-
# and go from there
|
|
18
|
+
# Take a page from how constprop works in Kirin
|
|
21
19
|
|
|
22
20
|
## This gets called before the registry look up
|
|
23
21
|
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
|
|
@@ -25,11 +23,11 @@ class NSitesAnalysis(Forward[Sites]):
|
|
|
25
23
|
if method is not None:
|
|
26
24
|
return method(self, frame, stmt)
|
|
27
25
|
elif stmt.has_trait(HasSites):
|
|
28
|
-
has_sites_trait = stmt.
|
|
26
|
+
has_sites_trait = stmt.get_present_trait(HasSites)
|
|
29
27
|
sites = has_sites_trait.get_sites(stmt)
|
|
30
28
|
return (NumberSites(sites=sites),)
|
|
31
29
|
elif stmt.has_trait(FixedSites):
|
|
32
|
-
sites_trait = stmt.
|
|
30
|
+
sites_trait = stmt.get_present_trait(FixedSites)
|
|
33
31
|
return (NumberSites(sites=sites_trait.data),)
|
|
34
32
|
else:
|
|
35
33
|
return (NoSites(),)
|
|
@@ -37,7 +35,7 @@ class NSitesAnalysis(Forward[Sites]):
|
|
|
37
35
|
# For when no implementation is found for the statement
|
|
38
36
|
def eval_stmt_fallback(
|
|
39
37
|
self, frame: ForwardFrame[Sites], stmt: ir.Statement
|
|
40
|
-
) -> tuple[Sites, ...]: # some form of
|
|
38
|
+
) -> tuple[Sites, ...]: # some form of Sites will go back into the frame
|
|
41
39
|
return tuple(
|
|
42
40
|
(
|
|
43
41
|
self.lattice.top()
|
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from kirin import ir, interp
|
|
1
|
+
from kirin import interp
|
|
4
2
|
|
|
5
3
|
from bloqade.squin import op
|
|
6
4
|
|
|
@@ -52,9 +50,7 @@ class SquinOp(interp.MethodTable):
|
|
|
52
50
|
|
|
53
51
|
if isinstance(op_sites, NumberSites):
|
|
54
52
|
n_sites = op_sites.sites
|
|
55
|
-
|
|
56
|
-
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
|
|
57
|
-
return (NumberSites(sites=n_sites + n_controls),)
|
|
53
|
+
return (NumberSites(sites=n_sites + stmt.n_controls),)
|
|
58
54
|
else:
|
|
59
55
|
return (NoSites(),)
|
|
60
56
|
|
bloqade/squin/groups.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from kirin import ir, passes
|
|
2
2
|
from kirin.prelude import structural_no_opt
|
|
3
3
|
from kirin.dialects import ilist
|
|
4
|
-
|
|
5
|
-
from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass
|
|
4
|
+
from kirin.rewrite.walk import Walk
|
|
6
5
|
|
|
7
6
|
from . import op, wire, qubit
|
|
7
|
+
from .op.rewrite import PyMultToSquinMult
|
|
8
|
+
from .rewrite.measure_desugar import MeasureDesugarRule
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
@ir.dialect_group(structural_no_opt.union([op, qubit]))
|
|
@@ -12,27 +13,34 @@ def kernel(self):
|
|
|
12
13
|
fold_pass = passes.Fold(self)
|
|
13
14
|
typeinfer_pass = passes.TypeInfer(self)
|
|
14
15
|
ilist_desugar_pass = ilist.IListDesugar(self)
|
|
15
|
-
|
|
16
|
+
measure_desugar_pass = Walk(MeasureDesugarRule())
|
|
17
|
+
py_mult_to_mult_pass = PyMultToSquinMult(self)
|
|
16
18
|
|
|
17
|
-
def run_pass(method, *, fold=True, typeinfer=True):
|
|
19
|
+
def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
|
|
18
20
|
method.verify()
|
|
19
21
|
if fold:
|
|
20
22
|
fold_pass.fixpoint(method)
|
|
21
23
|
|
|
24
|
+
py_mult_to_mult_pass(method)
|
|
25
|
+
|
|
22
26
|
if typeinfer:
|
|
23
27
|
typeinfer_pass(method)
|
|
28
|
+
measure_desugar_pass.rewrite(method.code)
|
|
29
|
+
|
|
24
30
|
ilist_desugar_pass(method)
|
|
25
|
-
|
|
31
|
+
|
|
26
32
|
if typeinfer:
|
|
27
33
|
typeinfer_pass(method) # fix types after desugaring
|
|
28
|
-
method.
|
|
34
|
+
method.verify_type()
|
|
29
35
|
|
|
30
36
|
return run_pass
|
|
31
37
|
|
|
32
38
|
|
|
33
39
|
@ir.dialect_group(structural_no_opt.union([op, wire]))
|
|
34
40
|
def wired(self):
|
|
41
|
+
py_mult_to_mult_pass = PyMultToSquinMult(self)
|
|
42
|
+
|
|
35
43
|
def run_pass(method):
|
|
36
|
-
|
|
44
|
+
py_mult_to_mult_pass(method)
|
|
37
45
|
|
|
38
46
|
return run_pass
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Put all the proper wrappers here
|
|
2
|
+
|
|
3
|
+
from kirin.lowering import wraps as _wraps
|
|
4
|
+
|
|
5
|
+
from bloqade.squin.op.types import Op
|
|
6
|
+
|
|
7
|
+
from . import stmts as stmts
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@_wraps(stmts.PauliError)
|
|
11
|
+
def pauli_error(basis: Op, p: float) -> Op: ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@_wraps(stmts.PPError)
|
|
15
|
+
def pp_error(op: Op, p: float) -> Op: ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@_wraps(stmts.Depolarize)
|
|
19
|
+
def depolarize(n_qubits: int, p: float) -> Op: ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@_wraps(stmts.PauliChannel)
|
|
23
|
+
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@_wraps(stmts.QubitLoss)
|
|
27
|
+
def qubit_loss(p: float) -> Op: ...
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from kirin import ir, types
|
|
2
|
+
from kirin.decl import info, statement
|
|
3
|
+
|
|
4
|
+
from bloqade.squin.op.types import OpType
|
|
5
|
+
|
|
6
|
+
from ._dialect import dialect
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@statement
|
|
10
|
+
class NoiseChannel(ir.Statement):
|
|
11
|
+
pass
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@statement(dialect=dialect)
|
|
15
|
+
class PauliError(NoiseChannel):
|
|
16
|
+
basis: ir.SSAValue = info.argument(OpType)
|
|
17
|
+
p: ir.SSAValue = info.argument(types.Float)
|
|
18
|
+
result: ir.ResultValue = info.result(OpType)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@statement(dialect=dialect)
|
|
22
|
+
class PPError(NoiseChannel):
|
|
23
|
+
"""
|
|
24
|
+
Pauli Product Error
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
op: ir.SSAValue = info.argument(OpType)
|
|
28
|
+
p: ir.SSAValue = info.argument(types.Float)
|
|
29
|
+
result: ir.ResultValue = info.result(OpType)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@statement(dialect=dialect)
|
|
33
|
+
class Depolarize(NoiseChannel):
|
|
34
|
+
"""
|
|
35
|
+
Apply an n-qubit depolarizing error to qubits
|
|
36
|
+
NOTE: For Stim, this can only accept 1 or 2 qubits.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
n_qubits: int = info.attribute(types.Int)
|
|
40
|
+
p: ir.SSAValue = info.argument(types.Float)
|
|
41
|
+
result: ir.ResultValue = info.result(OpType)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@statement(dialect=dialect)
|
|
45
|
+
class PauliChannel(NoiseChannel):
|
|
46
|
+
# NOTE:
|
|
47
|
+
# 1-qubit 3 params px, py, pz
|
|
48
|
+
# 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz
|
|
49
|
+
# TODO add validation for params (maybe during lowering via custom lower?)
|
|
50
|
+
n_qubits: int = info.attribute()
|
|
51
|
+
params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)])
|
|
52
|
+
result: ir.ResultValue = info.result(OpType)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@statement(dialect=dialect)
|
|
56
|
+
class QubitLoss(NoiseChannel):
|
|
57
|
+
# NOTE: qubit loss error (not supported by Stim)
|
|
58
|
+
p: ir.SSAValue = info.argument(types.Float)
|
|
59
|
+
result: ir.ResultValue = info.result(OpType)
|
bloqade/squin/op/__init__.py
CHANGED
|
@@ -2,25 +2,47 @@ from kirin import ir as _ir
|
|
|
2
2
|
from kirin.prelude import structural_no_opt as _structural_no_opt
|
|
3
3
|
from kirin.lowering import wraps as _wraps
|
|
4
4
|
|
|
5
|
-
from . import stmts as stmts, types as types
|
|
5
|
+
from . import stmts as stmts, types as types, rewrite as rewrite
|
|
6
6
|
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
|
|
7
7
|
from ._dialect import dialect as dialect
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
@_wraps(stmts.Kron)
|
|
11
|
-
def kron(lhs: types.Op, rhs: types.Op
|
|
11
|
+
def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@_wraps(stmts.Mult)
|
|
15
|
+
def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@_wraps(stmts.Scale)
|
|
19
|
+
def scale(op: types.Op, factor: complex) -> types.Op: ...
|
|
12
20
|
|
|
13
21
|
|
|
14
22
|
@_wraps(stmts.Adjoint)
|
|
15
|
-
def adjoint(op: types.Op
|
|
23
|
+
def adjoint(op: types.Op) -> types.Op: ...
|
|
16
24
|
|
|
17
25
|
|
|
18
26
|
@_wraps(stmts.Control)
|
|
19
|
-
def control(op: types.Op, *, n_controls: int
|
|
27
|
+
def control(op: types.Op, *, n_controls: int) -> types.Op:
|
|
28
|
+
"""
|
|
29
|
+
Create a controlled operator.
|
|
30
|
+
|
|
31
|
+
Note that when considering atom loss, the operator will not be applied if
|
|
32
|
+
any of the controls has been lost.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
op: The operator to apply under the control.
|
|
36
|
+
n_controls: The number of qubits to be used as controls.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
Operator
|
|
40
|
+
"""
|
|
41
|
+
...
|
|
20
42
|
|
|
21
43
|
|
|
22
44
|
@_wraps(stmts.Identity)
|
|
23
|
-
def identity(*,
|
|
45
|
+
def identity(*, sites: int) -> types.Op: ...
|
|
24
46
|
|
|
25
47
|
|
|
26
48
|
@_wraps(stmts.Rot)
|
|
@@ -75,6 +97,14 @@ def spin_n() -> types.Op: ...
|
|
|
75
97
|
def spin_p() -> types.Op: ...
|
|
76
98
|
|
|
77
99
|
|
|
100
|
+
@_wraps(stmts.U3)
|
|
101
|
+
def u(theta: float, phi: float, lam: float) -> types.Op: ...
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@_wraps(stmts.PauliString)
|
|
105
|
+
def pauli_string(*, string: str) -> types.Op: ...
|
|
106
|
+
|
|
107
|
+
|
|
78
108
|
# stdlibs
|
|
79
109
|
@_ir.dialect_group(_structural_no_opt.add(dialect))
|
|
80
110
|
def op(self):
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Rewrite py.binop.mult to Mult stmt"""
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.passes import Pass
|
|
5
|
+
from kirin.rewrite import Walk
|
|
6
|
+
from kirin.dialects import py
|
|
7
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
8
|
+
|
|
9
|
+
from .stmts import Mult, Scale
|
|
10
|
+
from .types import OpType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _PyMultToSquinMult(RewriteRule):
|
|
14
|
+
|
|
15
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
16
|
+
if not isinstance(node, py.Mult):
|
|
17
|
+
return RewriteResult()
|
|
18
|
+
|
|
19
|
+
lhs_is_op = node.lhs.type.is_subseteq(OpType)
|
|
20
|
+
rhs_is_op = node.rhs.type.is_subseteq(OpType)
|
|
21
|
+
|
|
22
|
+
if not lhs_is_op and not rhs_is_op:
|
|
23
|
+
return RewriteResult()
|
|
24
|
+
|
|
25
|
+
if lhs_is_op and rhs_is_op:
|
|
26
|
+
mult = Mult(node.lhs, node.rhs)
|
|
27
|
+
node.replace_by(mult)
|
|
28
|
+
return RewriteResult(has_done_something=True)
|
|
29
|
+
|
|
30
|
+
if lhs_is_op:
|
|
31
|
+
scale = Scale(node.lhs, node.rhs)
|
|
32
|
+
node.replace_by(scale)
|
|
33
|
+
return RewriteResult(has_done_something=True)
|
|
34
|
+
|
|
35
|
+
if rhs_is_op:
|
|
36
|
+
scale = Scale(node.rhs, node.lhs)
|
|
37
|
+
node.replace_by(scale)
|
|
38
|
+
return RewriteResult(has_done_something=True)
|
|
39
|
+
|
|
40
|
+
return RewriteResult()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PyMultToSquinMult(Pass):
|
|
44
|
+
|
|
45
|
+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
|
|
46
|
+
return Walk(_PyMultToSquinMult()).rewrite(mt.code)
|
bloqade/squin/op/stmts.py
CHANGED
|
@@ -2,8 +2,8 @@ from kirin import ir, types, lowering
|
|
|
2
2
|
from kirin.decl import info, statement
|
|
3
3
|
|
|
4
4
|
from .types import OpType
|
|
5
|
+
from .number import NumberType
|
|
5
6
|
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
|
|
6
|
-
from .complex import Complex
|
|
7
7
|
from ._dialect import dialect
|
|
8
8
|
|
|
9
9
|
|
|
@@ -54,7 +54,7 @@ class Scale(CompositeOp):
|
|
|
54
54
|
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
|
|
55
55
|
is_unitary: bool = info.attribute(default=False)
|
|
56
56
|
op: ir.SSAValue = info.argument(OpType)
|
|
57
|
-
factor: ir.SSAValue = info.argument(
|
|
57
|
+
factor: ir.SSAValue = info.argument(NumberType)
|
|
58
58
|
result: ir.ResultValue = info.result(OpType)
|
|
59
59
|
|
|
60
60
|
|
|
@@ -103,6 +103,15 @@ class ConstantUnitary(ConstantOp):
|
|
|
103
103
|
)
|
|
104
104
|
|
|
105
105
|
|
|
106
|
+
@statement(dialect=dialect)
|
|
107
|
+
class U3(PrimitiveOp):
|
|
108
|
+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
|
|
109
|
+
theta: ir.SSAValue = info.argument(types.Float)
|
|
110
|
+
phi: ir.SSAValue = info.argument(types.Float)
|
|
111
|
+
lam: ir.SSAValue = info.argument(types.Float)
|
|
112
|
+
result: ir.ResultValue = info.result(OpType)
|
|
113
|
+
|
|
114
|
+
|
|
106
115
|
@statement(dialect=dialect)
|
|
107
116
|
class PhaseOp(PrimitiveOp):
|
|
108
117
|
"""
|
|
@@ -138,6 +147,18 @@ class PauliOp(ConstantUnitary):
|
|
|
138
147
|
pass
|
|
139
148
|
|
|
140
149
|
|
|
150
|
+
@statement(dialect=dialect)
|
|
151
|
+
class PauliString(ConstantUnitary):
|
|
152
|
+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
|
|
153
|
+
string: str = info.attribute()
|
|
154
|
+
|
|
155
|
+
def verify(self) -> None:
|
|
156
|
+
if not set("XYZ").issuperset(self.string):
|
|
157
|
+
raise ValueError(
|
|
158
|
+
f"Invalid Pauli string: {self.string}. Must be a combination of 'X', 'Y', and 'Z'."
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
|
|
141
162
|
@statement(dialect=dialect)
|
|
142
163
|
class X(PauliOp):
|
|
143
164
|
pass
|
bloqade/squin/op/types.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import overload
|
|
2
|
+
|
|
1
3
|
from kirin import types
|
|
2
4
|
|
|
3
5
|
|
|
@@ -6,5 +8,17 @@ class Op:
|
|
|
6
8
|
def __matmul__(self, other: "Op") -> "Op":
|
|
7
9
|
raise NotImplementedError("@ can only be used within a squin kernel program")
|
|
8
10
|
|
|
11
|
+
@overload
|
|
12
|
+
def __mul__(self, other: "Op") -> "Op": ...
|
|
13
|
+
|
|
14
|
+
@overload
|
|
15
|
+
def __mul__(self, other: complex) -> "Op": ...
|
|
16
|
+
|
|
17
|
+
def __mul__(self, other) -> "Op":
|
|
18
|
+
raise NotImplementedError("@ can only be used within a squin kernel program")
|
|
19
|
+
|
|
20
|
+
def __rmul__(self, other: complex) -> "Op":
|
|
21
|
+
raise NotImplementedError("@ can only be used within a squin kernel program")
|
|
22
|
+
|
|
9
23
|
|
|
10
24
|
OpType = types.PyClass(Op)
|