bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (74) hide show
  1. bloqade/analysis/address/impls.py +5 -9
  2. bloqade/analysis/address/lattice.py +1 -1
  3. bloqade/analysis/fidelity/__init__.py +1 -0
  4. bloqade/analysis/fidelity/analysis.py +69 -0
  5. bloqade/device.py +130 -0
  6. bloqade/noise/__init__.py +2 -1
  7. bloqade/noise/fidelity.py +51 -0
  8. bloqade/noise/native/model.py +1 -2
  9. bloqade/noise/native/rewrite.py +5 -5
  10. bloqade/noise/native/stmts.py +40 -11
  11. bloqade/pyqrack/__init__.py +8 -2
  12. bloqade/pyqrack/base.py +24 -3
  13. bloqade/pyqrack/device.py +166 -0
  14. bloqade/pyqrack/noise/native.py +1 -2
  15. bloqade/pyqrack/qasm2/core.py +31 -15
  16. bloqade/pyqrack/qasm2/glob.py +28 -0
  17. bloqade/pyqrack/qasm2/uop.py +9 -1
  18. bloqade/pyqrack/reg.py +17 -49
  19. bloqade/pyqrack/squin/__init__.py +0 -0
  20. bloqade/pyqrack/squin/op.py +154 -0
  21. bloqade/pyqrack/squin/qubit.py +85 -0
  22. bloqade/pyqrack/squin/runtime.py +515 -0
  23. bloqade/pyqrack/squin/wire.py +69 -0
  24. bloqade/pyqrack/target.py +9 -2
  25. bloqade/pyqrack/task.py +30 -0
  26. bloqade/qasm2/_wrappers.py +11 -1
  27. bloqade/qasm2/dialects/core/stmts.py +15 -4
  28. bloqade/qasm2/dialects/expr/_emit.py +9 -8
  29. bloqade/qasm2/emit/base.py +4 -2
  30. bloqade/qasm2/emit/gate.py +0 -14
  31. bloqade/qasm2/emit/main.py +19 -15
  32. bloqade/qasm2/emit/target.py +2 -6
  33. bloqade/qasm2/glob.py +1 -1
  34. bloqade/qasm2/parse/lowering.py +124 -1
  35. bloqade/qasm2/passes/glob.py +3 -3
  36. bloqade/qasm2/passes/lift_qubits.py +26 -0
  37. bloqade/qasm2/passes/noise.py +6 -14
  38. bloqade/qasm2/passes/parallel.py +3 -3
  39. bloqade/qasm2/passes/py2qasm.py +1 -2
  40. bloqade/qasm2/passes/qasm2py.py +1 -2
  41. bloqade/qasm2/rewrite/desugar.py +6 -6
  42. bloqade/qasm2/rewrite/glob.py +9 -9
  43. bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
  44. bloqade/qasm2/rewrite/insert_qubits.py +34 -0
  45. bloqade/qasm2/rewrite/native_gates.py +54 -55
  46. bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
  47. bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
  48. bloqade/qasm2/types.py +3 -6
  49. bloqade/qbraid/schema.py +10 -12
  50. bloqade/squin/__init__.py +1 -1
  51. bloqade/squin/analysis/nsites/analysis.py +4 -6
  52. bloqade/squin/analysis/nsites/impls.py +2 -6
  53. bloqade/squin/analysis/schedule.py +1 -1
  54. bloqade/squin/groups.py +15 -7
  55. bloqade/squin/noise/__init__.py +27 -0
  56. bloqade/squin/noise/_dialect.py +3 -0
  57. bloqade/squin/noise/stmts.py +59 -0
  58. bloqade/squin/op/__init__.py +35 -5
  59. bloqade/squin/op/number.py +5 -0
  60. bloqade/squin/op/rewrite.py +46 -0
  61. bloqade/squin/op/stmts.py +23 -2
  62. bloqade/squin/op/types.py +14 -0
  63. bloqade/squin/qubit.py +79 -11
  64. bloqade/squin/rewrite/__init__.py +0 -0
  65. bloqade/squin/rewrite/measure_desugar.py +33 -0
  66. bloqade/squin/wire.py +31 -2
  67. bloqade/stim/emit/stim.py +1 -1
  68. bloqade/task.py +94 -0
  69. bloqade/visual/animation/base.py +25 -15
  70. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/METADATA +8 -2
  71. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/RECORD +73 -52
  72. bloqade/squin/op/complex.py +0 -6
  73. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/WHEEL +0 -0
  74. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -3,8 +3,8 @@ from typing import Dict, List, Tuple, Iterable
3
3
  from dataclasses import field, dataclass
4
4
 
5
5
  from kirin import ir
6
- from kirin.rewrite import abc as rewrite_abc
7
6
  from kirin.dialects import py, ilist
7
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
8
8
  from kirin.analysis.const import lattice
9
9
 
10
10
  from bloqade.analysis import address
@@ -14,7 +14,7 @@ from bloqade.squin.analysis.schedule import StmtDag
14
14
 
15
15
  class MergePolicyABC(abc.ABC):
16
16
  @abc.abstractmethod
17
- def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
17
+ def __call__(self, node: ir.Statement) -> RewriteResult:
18
18
  pass
19
19
 
20
20
  @classmethod
@@ -141,10 +141,10 @@ class SimpleMergePolicy(MergePolicyABC):
141
141
  group_numbers=group_numbers,
142
142
  )
143
143
 
144
- def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
144
+ def __call__(self, node: ir.Statement) -> RewriteResult:
145
145
 
146
146
  if node not in self.group_numbers:
147
- return rewrite_abc.RewriteResult()
147
+ return RewriteResult()
148
148
 
149
149
  group_number = self.group_numbers[node]
150
150
  group = self.merge_groups[group_number]
@@ -154,12 +154,10 @@ class SimpleMergePolicy(MergePolicyABC):
154
154
  self.group_has_merged[group_number] = result.has_done_something
155
155
  return result
156
156
 
157
- if self.group_has_merged[group_number]:
157
+ if self.group_has_merged.setdefault(group_number, False):
158
158
  node.delete()
159
159
 
160
- return rewrite_abc.RewriteResult(
161
- has_done_something=self.group_has_merged[group_number]
162
- )
160
+ return RewriteResult(has_done_something=self.group_has_merged[group_number])
163
161
 
164
162
  def move_and_collect_qubit_list(
165
163
  self, qargs: List[ir.SSAValue], node: ir.Statement
@@ -219,14 +217,14 @@ class SimpleMergePolicy(MergePolicyABC):
219
217
  ctrls.append(stmt.ctrls)
220
218
  qargs.append(stmt.qargs)
221
219
  else:
222
- return rewrite_abc.RewriteResult(has_done_something=False)
220
+ return RewriteResult(has_done_something=False)
223
221
 
224
222
  ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
225
223
  qargs_values = self.move_and_collect_qubit_list(qargs, node)
226
224
 
227
225
  if ctrls_values is None or qargs_values is None:
228
226
  # give up if we cannot determine the address or cannot move the qubits
229
- return rewrite_abc.RewriteResult(has_done_something=False)
227
+ return RewriteResult(has_done_something=False)
230
228
 
231
229
  new_ctrls = ilist.New(values=ctrls_values)
232
230
  new_qargs = ilist.New(values=qargs_values)
@@ -238,7 +236,7 @@ class SimpleMergePolicy(MergePolicyABC):
238
236
 
239
237
  node.delete()
240
238
 
241
- return rewrite_abc.RewriteResult(has_done_something=True)
239
+ return RewriteResult(has_done_something=True)
242
240
 
243
241
  def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
244
242
  return self.rewrite_group_u(node, group)
@@ -252,13 +250,13 @@ class SimpleMergePolicy(MergePolicyABC):
252
250
  elif isinstance(stmt, parallel.UGate):
253
251
  qargs.append(stmt.qargs)
254
252
  else:
255
- return rewrite_abc.RewriteResult(has_done_something=False)
253
+ return RewriteResult(has_done_something=False)
256
254
 
257
255
  assert isinstance(node, (uop.UGate, parallel.UGate))
258
256
  qargs_values = self.move_and_collect_qubit_list(qargs, node)
259
257
 
260
258
  if qargs_values is None:
261
- return rewrite_abc.RewriteResult(has_done_something=False)
259
+ return RewriteResult(has_done_something=False)
262
260
 
263
261
  new_qargs = ilist.New(values=qargs_values)
264
262
  new_gate = parallel.UGate(
@@ -271,7 +269,7 @@ class SimpleMergePolicy(MergePolicyABC):
271
269
  new_gate.insert_before(node)
272
270
  node.delete()
273
271
 
274
- return rewrite_abc.RewriteResult(has_done_something=True)
272
+ return RewriteResult(has_done_something=True)
275
273
 
276
274
  def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
277
275
  qargs = []
@@ -282,14 +280,14 @@ class SimpleMergePolicy(MergePolicyABC):
282
280
  elif isinstance(stmt, parallel.RZ):
283
281
  qargs.append(stmt.qargs)
284
282
  else:
285
- return rewrite_abc.RewriteResult(has_done_something=False)
283
+ return RewriteResult(has_done_something=False)
286
284
 
287
285
  assert isinstance(node, (uop.RZ, parallel.RZ))
288
286
 
289
287
  qargs_values = self.move_and_collect_qubit_list(qargs, node)
290
288
 
291
289
  if qargs_values is None:
292
- return rewrite_abc.RewriteResult(has_done_something=False)
290
+ return RewriteResult(has_done_something=False)
293
291
 
294
292
  new_qargs = ilist.New(values=qargs_values)
295
293
  new_gate = parallel.RZ(
@@ -300,7 +298,7 @@ class SimpleMergePolicy(MergePolicyABC):
300
298
  new_gate.insert_before(node)
301
299
  node.delete()
302
300
 
303
- return rewrite_abc.RewriteResult(has_done_something=True)
301
+ return RewriteResult(has_done_something=True)
304
302
 
305
303
  def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
306
304
  qargs = []
@@ -310,13 +308,13 @@ class SimpleMergePolicy(MergePolicyABC):
310
308
  qargs_values = self.move_and_collect_qubit_list(qargs, node)
311
309
 
312
310
  if qargs_values is None:
313
- return rewrite_abc.RewriteResult(has_done_something=False)
311
+ return RewriteResult(has_done_something=False)
314
312
 
315
313
  new_node = uop.Barrier(qargs=qargs_values)
316
314
  new_node.insert_before(node)
317
315
  node.delete()
318
316
 
319
- return rewrite_abc.RewriteResult(has_done_something=True)
317
+ return RewriteResult(has_done_something=True)
320
318
 
321
319
 
322
320
  class GreedyMixin(MergePolicyABC):
@@ -385,11 +383,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):
385
383
 
386
384
 
387
385
  @dataclass
388
- class UOpToParallelRule(rewrite_abc.RewriteRule):
386
+ class UOpToParallelRule(RewriteRule):
389
387
  merge_rewriters: Dict[ir.Block | None, MergePolicyABC]
390
388
 
391
- def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
389
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
392
390
  merge_rewriter = self.merge_rewriters.get(
393
- node.parent_block, lambda _: rewrite_abc.RewriteResult()
391
+ node.parent_block, lambda _: RewriteResult()
394
392
  )
395
393
  return merge_rewriter(node)
bloqade/qasm2/types.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from kirin import types
2
+ from kirin.dialects import ilist
2
3
 
3
4
  from bloqade.types import Qubit as Qubit, QubitType as QubitType
4
5
 
@@ -15,11 +16,7 @@ class Bit:
15
16
  pass
16
17
 
17
18
 
18
- class QReg:
19
- """Runtime representation of a quantum register."""
20
-
21
- def __getitem__(self, index) -> Qubit:
22
- raise NotImplementedError("cannot call __getitem__ outside of a kernel")
19
+ QReg = ilist.IList[Qubit, types.Any]
23
20
 
24
21
 
25
22
  class CReg:
@@ -32,7 +29,7 @@ class CReg:
32
29
  BitType = types.PyClass(Bit)
33
30
  """Kirin type for a classical bit."""
34
31
 
35
- QRegType = types.PyClass(QReg)
32
+ QRegType = ilist.IListType[QubitType, types.Any]
36
33
  """Kirin type for a quantum register."""
37
34
 
38
35
  CRegType = types.PyClass(CReg)
bloqade/qbraid/schema.py CHANGED
@@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"):
9
9
  op_type: str = Field(init=False)
10
10
 
11
11
 
12
- class CZ(Operation):
12
+ class CZ(Operation, frozen=True):
13
13
  """A CZ gate operation.
14
14
 
15
15
  Fields:
@@ -22,7 +22,7 @@ class CZ(Operation):
22
22
  participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...]
23
23
 
24
24
 
25
- class GlobalRz(Operation):
25
+ class GlobalRz(Operation, frozen=True):
26
26
  """GlobalRz operation.
27
27
 
28
28
  Fields:
@@ -34,7 +34,7 @@ class GlobalRz(Operation):
34
34
  phi: float
35
35
 
36
36
 
37
- class GlobalW(Operation):
37
+ class GlobalW(Operation, frozen=True):
38
38
  """GlobalW operation.
39
39
 
40
40
  Fields:
@@ -48,7 +48,7 @@ class GlobalW(Operation):
48
48
  phi: float
49
49
 
50
50
 
51
- class LocalRz(Operation):
51
+ class LocalRz(Operation, frozen=True):
52
52
  """LocalRz operation.
53
53
 
54
54
  Fields:
@@ -63,7 +63,7 @@ class LocalRz(Operation):
63
63
  phi: float
64
64
 
65
65
 
66
- class LocalW(Operation):
66
+ class LocalW(Operation, frozen=True):
67
67
  """LocalW operation.
68
68
 
69
69
  Fields:
@@ -80,7 +80,7 @@ class LocalW(Operation):
80
80
  phi: float
81
81
 
82
82
 
83
- class Measurement(Operation):
83
+ class Measurement(Operation, frozen=True):
84
84
  """Measurement operation.
85
85
 
86
86
  Fields:
@@ -95,9 +95,7 @@ class Measurement(Operation):
95
95
  participants: Tuple[int, ...]
96
96
 
97
97
 
98
- OperationType = TypeVar(
99
- "OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement]
100
- )
98
+ OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement
101
99
 
102
100
 
103
101
  class ErrorModel(BaseModel, frozen=True, extra="forbid"):
@@ -106,7 +104,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"):
106
104
  error_model_type: str = Field(init=False)
107
105
 
108
106
 
109
- class PauliErrorModel(ErrorModel):
107
+ class PauliErrorModel(ErrorModel, frozen=True):
110
108
  """Pauli error model.
111
109
 
112
110
  Fields:
@@ -131,7 +129,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for
131
129
  survival_prob: Tuple[float, ...]
132
130
 
133
131
 
134
- class CZError(ErrorOperation[ErrorModelType]):
132
+ class CZError(ErrorOperation[ErrorModelType], frozen=True):
135
133
  """CZError operation.
136
134
 
137
135
  Fields:
@@ -149,7 +147,7 @@ class CZError(ErrorOperation[ErrorModelType]):
149
147
  single_error: ErrorModelType
150
148
 
151
149
 
152
- class SingleQubitError(ErrorOperation[ErrorModelType]):
150
+ class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True):
153
151
  """SingleQubitError operation.
154
152
 
155
153
  Fields:
bloqade/squin/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from . import op as op, wire as wire, qubit as qubit
1
+ from . import op as op, wire as wire, noise as noise, qubit as qubit
2
2
  from .groups import wired as wired, kernel as kernel
@@ -15,9 +15,7 @@ class NSitesAnalysis(Forward[Sites]):
15
15
  keys = ["op.nsites"]
16
16
  lattice = Sites
17
17
 
18
- # Take a page from const prop in Kirin,
19
- # I can get the data I want from the SizedTrait
20
- # and go from there
18
+ # Take a page from how constprop works in Kirin
21
19
 
22
20
  ## This gets called before the registry look up
23
21
  def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
@@ -25,11 +23,11 @@ class NSitesAnalysis(Forward[Sites]):
25
23
  if method is not None:
26
24
  return method(self, frame, stmt)
27
25
  elif stmt.has_trait(HasSites):
28
- has_sites_trait = stmt.get_trait(HasSites)
26
+ has_sites_trait = stmt.get_present_trait(HasSites)
29
27
  sites = has_sites_trait.get_sites(stmt)
30
28
  return (NumberSites(sites=sites),)
31
29
  elif stmt.has_trait(FixedSites):
32
- sites_trait = stmt.get_trait(FixedSites)
30
+ sites_trait = stmt.get_present_trait(FixedSites)
33
31
  return (NumberSites(sites=sites_trait.data),)
34
32
  else:
35
33
  return (NoSites(),)
@@ -37,7 +35,7 @@ class NSitesAnalysis(Forward[Sites]):
37
35
  # For when no implementation is found for the statement
38
36
  def eval_stmt_fallback(
39
37
  self, frame: ForwardFrame[Sites], stmt: ir.Statement
40
- ) -> tuple[Sites, ...]: # some form of Shape will go back into the frame
38
+ ) -> tuple[Sites, ...]: # some form of Sites will go back into the frame
41
39
  return tuple(
42
40
  (
43
41
  self.lattice.top()
@@ -1,6 +1,4 @@
1
- from typing import cast
2
-
3
- from kirin import ir, interp
1
+ from kirin import interp
4
2
 
5
3
  from bloqade.squin import op
6
4
 
@@ -52,9 +50,7 @@ class SquinOp(interp.MethodTable):
52
50
 
53
51
  if isinstance(op_sites, NumberSites):
54
52
  n_sites = op_sites.sites
55
- n_controls_attr = stmt.get_attr_or_prop("n_controls")
56
- n_controls = cast(ir.PyAttr[int], n_controls_attr).data
57
- return (NumberSites(sites=n_sites + n_controls),)
53
+ return (NumberSites(sites=n_sites + stmt.n_controls),)
58
54
  else:
59
55
  return (NoSites(),)
60
56
 
@@ -226,7 +226,7 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
226
226
  if args is None:
227
227
  args = tuple(self.lattice.top() for _ in mt.args)
228
228
 
229
- self.run(mt, args, kwargs).expect()
229
+ self.run(mt, args, kwargs)
230
230
  return self.stmt_dags
231
231
 
232
232
 
bloqade/squin/groups.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
3
  from kirin.dialects import ilist
4
-
5
- from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass
4
+ from kirin.rewrite.walk import Walk
6
5
 
7
6
  from . import op, wire, qubit
7
+ from .op.rewrite import PyMultToSquinMult
8
+ from .rewrite.measure_desugar import MeasureDesugarRule
8
9
 
9
10
 
10
11
  @ir.dialect_group(structural_no_opt.union([op, qubit]))
@@ -12,27 +13,34 @@ def kernel(self):
12
13
  fold_pass = passes.Fold(self)
13
14
  typeinfer_pass = passes.TypeInfer(self)
14
15
  ilist_desugar_pass = ilist.IListDesugar(self)
15
- indexing_desugar_pass = IndexingDesugarPass(self)
16
+ measure_desugar_pass = Walk(MeasureDesugarRule())
17
+ py_mult_to_mult_pass = PyMultToSquinMult(self)
16
18
 
17
- def run_pass(method, *, fold=True, typeinfer=True):
19
+ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
18
20
  method.verify()
19
21
  if fold:
20
22
  fold_pass.fixpoint(method)
21
23
 
24
+ py_mult_to_mult_pass(method)
25
+
22
26
  if typeinfer:
23
27
  typeinfer_pass(method)
28
+ measure_desugar_pass.rewrite(method.code)
29
+
24
30
  ilist_desugar_pass(method)
25
- indexing_desugar_pass(method)
31
+
26
32
  if typeinfer:
27
33
  typeinfer_pass(method) # fix types after desugaring
28
- method.code.typecheck()
34
+ method.verify_type()
29
35
 
30
36
  return run_pass
31
37
 
32
38
 
33
39
  @ir.dialect_group(structural_no_opt.union([op, wire]))
34
40
  def wired(self):
41
+ py_mult_to_mult_pass = PyMultToSquinMult(self)
42
+
35
43
  def run_pass(method):
36
- pass
44
+ py_mult_to_mult_pass(method)
37
45
 
38
46
  return run_pass
@@ -0,0 +1,27 @@
1
+ # Put all the proper wrappers here
2
+
3
+ from kirin.lowering import wraps as _wraps
4
+
5
+ from bloqade.squin.op.types import Op
6
+
7
+ from . import stmts as stmts
8
+
9
+
10
+ @_wraps(stmts.PauliError)
11
+ def pauli_error(basis: Op, p: float) -> Op: ...
12
+
13
+
14
+ @_wraps(stmts.PPError)
15
+ def pp_error(op: Op, p: float) -> Op: ...
16
+
17
+
18
+ @_wraps(stmts.Depolarize)
19
+ def depolarize(n_qubits: int, p: float) -> Op: ...
20
+
21
+
22
+ @_wraps(stmts.PauliChannel)
23
+ def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24
+
25
+
26
+ @_wraps(stmts.QubitLoss)
27
+ def qubit_loss(p: float) -> Op: ...
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect(name="squin.noise")
@@ -0,0 +1,59 @@
1
+ from kirin import ir, types
2
+ from kirin.decl import info, statement
3
+
4
+ from bloqade.squin.op.types import OpType
5
+
6
+ from ._dialect import dialect
7
+
8
+
9
+ @statement
10
+ class NoiseChannel(ir.Statement):
11
+ pass
12
+
13
+
14
+ @statement(dialect=dialect)
15
+ class PauliError(NoiseChannel):
16
+ basis: ir.SSAValue = info.argument(OpType)
17
+ p: ir.SSAValue = info.argument(types.Float)
18
+ result: ir.ResultValue = info.result(OpType)
19
+
20
+
21
+ @statement(dialect=dialect)
22
+ class PPError(NoiseChannel):
23
+ """
24
+ Pauli Product Error
25
+ """
26
+
27
+ op: ir.SSAValue = info.argument(OpType)
28
+ p: ir.SSAValue = info.argument(types.Float)
29
+ result: ir.ResultValue = info.result(OpType)
30
+
31
+
32
+ @statement(dialect=dialect)
33
+ class Depolarize(NoiseChannel):
34
+ """
35
+ Apply n-qubit depolaize error to qubits
36
+ NOTE For Stim, this can only accept 1 or 2 qubits
37
+ """
38
+
39
+ n_qubits: int = info.attribute(types.Int)
40
+ p: ir.SSAValue = info.argument(types.Float)
41
+ result: ir.ResultValue = info.result(OpType)
42
+
43
+
44
+ @statement(dialect=dialect)
45
+ class PauliChannel(NoiseChannel):
46
+ # NOTE:
47
+ # 1-qubit 3 params px, py, pz
48
+ # 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz
49
+ # TODO add validation for params (maybe during lowering via custom lower?)
50
+ n_qubits: int = info.attribute()
51
+ params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)])
52
+ result: ir.ResultValue = info.result(OpType)
53
+
54
+
55
+ @statement(dialect=dialect)
56
+ class QubitLoss(NoiseChannel):
57
+ # NOTE: qubit loss error (not supported by Stim)
58
+ p: ir.SSAValue = info.argument(types.Float)
59
+ result: ir.ResultValue = info.result(OpType)
@@ -2,25 +2,47 @@ from kirin import ir as _ir
2
2
  from kirin.prelude import structural_no_opt as _structural_no_opt
3
3
  from kirin.lowering import wraps as _wraps
4
4
 
5
- from . import stmts as stmts, types as types
5
+ from . import stmts as stmts, types as types, rewrite as rewrite
6
6
  from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
7
7
  from ._dialect import dialect as dialect
8
8
 
9
9
 
10
10
  @_wraps(stmts.Kron)
11
- def kron(lhs: types.Op, rhs: types.Op, *, is_unitary: bool = False) -> types.Op: ...
11
+ def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12
+
13
+
14
+ @_wraps(stmts.Mult)
15
+ def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
16
+
17
+
18
+ @_wraps(stmts.Scale)
19
+ def scale(op: types.Op, factor: complex) -> types.Op: ...
12
20
 
13
21
 
14
22
  @_wraps(stmts.Adjoint)
15
- def adjoint(op: types.Op, *, is_unitary: bool = False) -> types.Op: ...
23
+ def adjoint(op: types.Op) -> types.Op: ...
16
24
 
17
25
 
18
26
  @_wraps(stmts.Control)
19
- def control(op: types.Op, *, n_controls: int, is_unitary: bool = False) -> types.Op: ...
27
+ def control(op: types.Op, *, n_controls: int) -> types.Op:
28
+ """
29
+ Create a controlled operator.
30
+
31
+ Note, that when considering atom loss, the operator will not be applied if
32
+ any of the controls has been lost.
33
+
34
+ Args:
35
+ operator: The operator to apply under the control.
36
+ n_controls: The number qubits to be used as control.
37
+
38
+ Returns:
39
+ Operator
40
+ """
41
+ ...
20
42
 
21
43
 
22
44
  @_wraps(stmts.Identity)
23
- def identity(*, size: int) -> types.Op: ...
45
+ def identity(*, sites: int) -> types.Op: ...
24
46
 
25
47
 
26
48
  @_wraps(stmts.Rot)
@@ -75,6 +97,14 @@ def spin_n() -> types.Op: ...
75
97
  def spin_p() -> types.Op: ...
76
98
 
77
99
 
100
+ @_wraps(stmts.U3)
101
+ def u(theta: float, phi: float, lam: float) -> types.Op: ...
102
+
103
+
104
+ @_wraps(stmts.PauliString)
105
+ def pauli_string(*, string: str) -> types.Op: ...
106
+
107
+
78
108
  # stdlibs
79
109
  @_ir.dialect_group(_structural_no_opt.add(dialect))
80
110
  def op(self):
@@ -0,0 +1,5 @@
1
+ import numbers
2
+
3
+ from kirin.ir.attrs.types import PyClass
4
+
5
+ NumberType = PyClass(numbers.Number)
@@ -0,0 +1,46 @@
1
+ """Rewrite py.binop.mult to Mult stmt"""
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import Walk
6
+ from kirin.dialects import py
7
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
8
+
9
+ from .stmts import Mult, Scale
10
+ from .types import OpType
11
+
12
+
13
+ class _PyMultToSquinMult(RewriteRule):
14
+
15
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16
+ if not isinstance(node, py.Mult):
17
+ return RewriteResult()
18
+
19
+ lhs_is_op = node.lhs.type.is_subseteq(OpType)
20
+ rhs_is_op = node.rhs.type.is_subseteq(OpType)
21
+
22
+ if not lhs_is_op and not rhs_is_op:
23
+ return RewriteResult()
24
+
25
+ if lhs_is_op and rhs_is_op:
26
+ mult = Mult(node.lhs, node.rhs)
27
+ node.replace_by(mult)
28
+ return RewriteResult(has_done_something=True)
29
+
30
+ if lhs_is_op:
31
+ scale = Scale(node.lhs, node.rhs)
32
+ node.replace_by(scale)
33
+ return RewriteResult(has_done_something=True)
34
+
35
+ if rhs_is_op:
36
+ scale = Scale(node.rhs, node.lhs)
37
+ node.replace_by(scale)
38
+ return RewriteResult(has_done_something=True)
39
+
40
+ return RewriteResult()
41
+
42
+
43
+ class PyMultToSquinMult(Pass):
44
+
45
+ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
46
+ return Walk(_PyMultToSquinMult()).rewrite(mt.code)
bloqade/squin/op/stmts.py CHANGED
@@ -2,8 +2,8 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
4
  from .types import OpType
5
+ from .number import NumberType
5
6
  from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
6
- from .complex import Complex
7
7
  from ._dialect import dialect
8
8
 
9
9
 
@@ -54,7 +54,7 @@ class Scale(CompositeOp):
54
54
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
55
55
  is_unitary: bool = info.attribute(default=False)
56
56
  op: ir.SSAValue = info.argument(OpType)
57
- factor: ir.SSAValue = info.argument(Complex)
57
+ factor: ir.SSAValue = info.argument(NumberType)
58
58
  result: ir.ResultValue = info.result(OpType)
59
59
 
60
60
 
@@ -103,6 +103,15 @@ class ConstantUnitary(ConstantOp):
103
103
  )
104
104
 
105
105
 
106
+ @statement(dialect=dialect)
107
+ class U3(PrimitiveOp):
108
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
109
+ theta: ir.SSAValue = info.argument(types.Float)
110
+ phi: ir.SSAValue = info.argument(types.Float)
111
+ lam: ir.SSAValue = info.argument(types.Float)
112
+ result: ir.ResultValue = info.result(OpType)
113
+
114
+
106
115
  @statement(dialect=dialect)
107
116
  class PhaseOp(PrimitiveOp):
108
117
  """
@@ -138,6 +147,18 @@ class PauliOp(ConstantUnitary):
138
147
  pass
139
148
 
140
149
 
150
+ @statement(dialect=dialect)
151
+ class PauliString(ConstantUnitary):
152
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
153
+ string: str = info.attribute()
154
+
155
+ def verify(self) -> None:
156
+ if not set("XYZ").issuperset(self.string):
157
+ raise ValueError(
158
+ f"Invalid Pauli string: {self.string}. Must be a combination of 'X', 'Y', and 'Z'."
159
+ )
160
+
161
+
141
162
  @statement(dialect=dialect)
142
163
  class X(PauliOp):
143
164
  pass
bloqade/squin/op/types.py CHANGED
@@ -1,3 +1,5 @@
1
+ from typing import overload
2
+
1
3
  from kirin import types
2
4
 
3
5
 
@@ -6,5 +8,17 @@ class Op:
6
8
  def __matmul__(self, other: "Op") -> "Op":
7
9
  raise NotImplementedError("@ can only be used within a squin kernel program")
8
10
 
11
+ @overload
12
+ def __mul__(self, other: "Op") -> "Op": ...
13
+
14
+ @overload
15
+ def __mul__(self, other: complex) -> "Op": ...
16
+
17
+ def __mul__(self, other) -> "Op":
18
+ raise NotImplementedError("@ can only be used within a squin kernel program")
19
+
20
+ def __rmul__(self, other: complex) -> "Op":
21
+ raise NotImplementedError("@ can only be used within a squin kernel program")
22
+
9
23
 
10
24
  OpType = types.PyClass(Op)