bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the package's registry page for more details.

Files changed (74):
  1. bloqade/analysis/address/impls.py +5 -9
  2. bloqade/analysis/address/lattice.py +1 -1
  3. bloqade/analysis/fidelity/__init__.py +1 -0
  4. bloqade/analysis/fidelity/analysis.py +69 -0
  5. bloqade/device.py +130 -0
  6. bloqade/noise/__init__.py +2 -1
  7. bloqade/noise/fidelity.py +51 -0
  8. bloqade/noise/native/model.py +1 -2
  9. bloqade/noise/native/rewrite.py +5 -5
  10. bloqade/noise/native/stmts.py +40 -11
  11. bloqade/pyqrack/__init__.py +8 -2
  12. bloqade/pyqrack/base.py +24 -3
  13. bloqade/pyqrack/device.py +166 -0
  14. bloqade/pyqrack/noise/native.py +1 -2
  15. bloqade/pyqrack/qasm2/core.py +31 -15
  16. bloqade/pyqrack/qasm2/glob.py +28 -0
  17. bloqade/pyqrack/qasm2/uop.py +9 -1
  18. bloqade/pyqrack/reg.py +17 -49
  19. bloqade/pyqrack/squin/__init__.py +0 -0
  20. bloqade/pyqrack/squin/op.py +154 -0
  21. bloqade/pyqrack/squin/qubit.py +85 -0
  22. bloqade/pyqrack/squin/runtime.py +515 -0
  23. bloqade/pyqrack/squin/wire.py +69 -0
  24. bloqade/pyqrack/target.py +9 -2
  25. bloqade/pyqrack/task.py +30 -0
  26. bloqade/qasm2/_wrappers.py +11 -1
  27. bloqade/qasm2/dialects/core/stmts.py +15 -4
  28. bloqade/qasm2/dialects/expr/_emit.py +9 -8
  29. bloqade/qasm2/emit/base.py +4 -2
  30. bloqade/qasm2/emit/gate.py +0 -14
  31. bloqade/qasm2/emit/main.py +19 -15
  32. bloqade/qasm2/emit/target.py +2 -6
  33. bloqade/qasm2/glob.py +1 -1
  34. bloqade/qasm2/parse/lowering.py +124 -1
  35. bloqade/qasm2/passes/glob.py +3 -3
  36. bloqade/qasm2/passes/lift_qubits.py +26 -0
  37. bloqade/qasm2/passes/noise.py +6 -14
  38. bloqade/qasm2/passes/parallel.py +3 -3
  39. bloqade/qasm2/passes/py2qasm.py +1 -2
  40. bloqade/qasm2/passes/qasm2py.py +1 -2
  41. bloqade/qasm2/rewrite/desugar.py +6 -6
  42. bloqade/qasm2/rewrite/glob.py +9 -9
  43. bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
  44. bloqade/qasm2/rewrite/insert_qubits.py +34 -0
  45. bloqade/qasm2/rewrite/native_gates.py +54 -55
  46. bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
  47. bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
  48. bloqade/qasm2/types.py +3 -6
  49. bloqade/qbraid/schema.py +10 -12
  50. bloqade/squin/__init__.py +1 -1
  51. bloqade/squin/analysis/nsites/analysis.py +4 -6
  52. bloqade/squin/analysis/nsites/impls.py +2 -6
  53. bloqade/squin/analysis/schedule.py +1 -1
  54. bloqade/squin/groups.py +15 -7
  55. bloqade/squin/noise/__init__.py +27 -0
  56. bloqade/squin/noise/_dialect.py +3 -0
  57. bloqade/squin/noise/stmts.py +59 -0
  58. bloqade/squin/op/__init__.py +35 -5
  59. bloqade/squin/op/number.py +5 -0
  60. bloqade/squin/op/rewrite.py +46 -0
  61. bloqade/squin/op/stmts.py +23 -2
  62. bloqade/squin/op/types.py +14 -0
  63. bloqade/squin/qubit.py +79 -11
  64. bloqade/squin/rewrite/__init__.py +0 -0
  65. bloqade/squin/rewrite/measure_desugar.py +33 -0
  66. bloqade/squin/wire.py +31 -2
  67. bloqade/stim/emit/stim.py +1 -1
  68. bloqade/task.py +94 -0
  69. bloqade/visual/animation/base.py +25 -15
  70. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/METADATA +8 -2
  71. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/RECORD +73 -52
  72. bloqade/squin/op/complex.py +0 -6
  73. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/WHEEL +0 -0
  74. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,6 @@
1
1
  from typing import Literal
2
2
 
3
3
  from kirin import interp
4
- from kirin.emit.exceptions import EmitError
5
4
 
6
5
  from bloqade.qasm2.parse import ast
7
6
  from bloqade.qasm2.types import QubitType
@@ -19,16 +18,18 @@ class EmitExpr(interp.MethodTable):
19
18
  self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
20
19
  ):
21
20
 
21
+ args: list[ast.Node] = []
22
22
  cparams, qparams = [], []
23
- for arg in stmt.body.blocks[0].args[1:]:
24
- name = frame.get(arg)
25
- if not isinstance(name, ast.Name):
26
- raise EmitError("expected ast.Name")
23
+ for arg in stmt.body.blocks[0].args:
24
+ assert arg.name is not None
25
+
26
+ args.append(ast.Name(id=arg.name))
27
27
  if arg.type.is_subseteq(QubitType):
28
- qparams.append(name.id)
28
+ qparams.append(arg.name)
29
29
  else:
30
- cparams.append(name.id)
31
- emit.run_ssacfg_region(frame, stmt.body)
30
+ cparams.append(arg.name)
31
+
32
+ emit.run_ssacfg_region(frame, stmt.body, tuple(args))
32
33
  emit.output = ast.Gate(
33
34
  name=stmt.sym_name,
34
35
  cparams=cparams,
@@ -36,8 +36,10 @@ class EmitQASM2Base(
36
36
  )
37
37
  return self
38
38
 
39
- def new_frame(self, code: ir.Statement) -> EmitQASM2Frame:
40
- return EmitQASM2Frame.from_func_like(code)
39
+ def initialize_frame(
40
+ self, code: ir.Statement, *, has_parent_access: bool = False
41
+ ) -> EmitQASM2Frame[StmtType]:
42
+ return EmitQASM2Frame(code, has_parent_access=has_parent_access)
41
43
 
42
44
  def run_method(
43
45
  self, method: ir.Method, args: tuple[ast.Node | None, ...]
@@ -86,17 +86,3 @@ class Func(interp.MethodTable):
86
86
  @interp.impl(func.ConstantNone)
87
87
  def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
88
88
  return ()
89
-
90
- @interp.impl(func.Function)
91
- def emit_func(
92
- self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Function
93
- ):
94
- emit.run_ssacfg_region(frame, stmt.body)
95
- cparams, qparams = [], []
96
- for arg in stmt.args:
97
- if arg.type.is_subseteq(QubitType):
98
- qparams.append(frame.get(arg))
99
- else:
100
- cparams.append(frame.get(arg))
101
- emit.output = ast.Gate(stmt.sym_name, cparams, qparams, frame.body)
102
- return ()
@@ -24,7 +24,7 @@ class Func(interp.MethodTable):
24
24
  ):
25
25
  from bloqade.qasm2.dialects import glob, noise, parallel
26
26
 
27
- emit.run_ssacfg_region(frame, stmt.body)
27
+ emit.run_ssacfg_region(frame, stmt.body, ())
28
28
  if emit.dialects.data.intersection(
29
29
  (parallel.dialect, glob.dialect, noise.dialect)
30
30
  ):
@@ -51,12 +51,14 @@ class Cf(interp.MethodTable):
51
51
  self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: cf.ConditionalBranch
52
52
  ):
53
53
  cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
54
- body_frame = emit.new_frame(stmt)
55
- body_frame.entries.update(frame.entries)
56
- body_frame.set_values(
57
- stmt.then_successor.args, frame.get_values(stmt.then_arguments)
58
- )
59
- emit.emit_block(body_frame, stmt.then_successor)
54
+
55
+ with emit.new_frame(stmt) as body_frame:
56
+ body_frame.entries.update(frame.entries)
57
+ body_frame.set_values(
58
+ stmt.then_successor.args, frame.get_values(stmt.then_arguments)
59
+ )
60
+ emit.emit_block(body_frame, stmt.then_successor)
61
+
60
62
  frame.body.append(
61
63
  ast.IfStmt(
62
64
  cond,
@@ -91,15 +93,17 @@ class Scf(interp.MethodTable):
91
93
  )
92
94
 
93
95
  cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
94
- then_frame = emit.new_frame(stmt)
95
- then_frame.entries.update(frame.entries)
96
- emit.emit_block(then_frame, stmt.then_body.blocks[0])
97
- frame.body.append(
98
- ast.IfStmt(
99
- cond,
100
- body=then_frame.body, # type: ignore
96
+
97
+ with emit.new_frame(stmt) as then_frame:
98
+ then_frame.entries.update(frame.entries)
99
+ emit.emit_block(then_frame, stmt.then_body.blocks[0])
100
+ frame.body.append(
101
+ ast.IfStmt(
102
+ cond,
103
+ body=then_frame.body, # type: ignore
104
+ )
101
105
  )
102
- )
106
+
103
107
  term = stmt.then_body.blocks[0].last_stmt
104
108
  if isinstance(term, scf.Yield):
105
109
  return then_frame.get_values(term.values)
@@ -101,9 +101,7 @@ class QASM2:
101
101
 
102
102
  Py2QASM(entry.dialects)(entry)
103
103
  target_main = EmitQASM2Main(self.main_target)
104
- target_main.run(
105
- entry, tuple(ast.Name(name) for name in entry.arg_names[1:])
106
- ).expect()
104
+ target_main.run(entry, ())
107
105
 
108
106
  main_program = target_main.output
109
107
  assert main_program is not None, f"failed to emit {entry.sym_name}"
@@ -133,9 +131,7 @@ class QASM2:
133
131
 
134
132
  Py2QASM(fn.dialects)(fn)
135
133
 
136
- target_gate.run(
137
- fn, tuple(ast.Name(name) for name in fn.arg_names[1:])
138
- ).expect()
134
+ target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:]))
139
135
  assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
140
136
  extra.append(target_gate.output)
141
137
 
bloqade/qasm2/glob.py CHANGED
@@ -11,7 +11,7 @@ from .dialects import glob
11
11
 
12
12
  @wraps(glob.UGate)
13
13
  def u(
14
- theta: float, phi: float, lam: float, registers: ilist.IList[QReg, Any] | list
14
+ registers: ilist.IList[QReg, Any] | list, theta: float, phi: float, lam: float
15
15
  ) -> None:
16
16
  """Apply a U gate to all qubits in the input registers.
17
17
 
@@ -1,3 +1,5 @@
1
+ import os
2
+ import pathlib
1
3
  from typing import Any
2
4
  from dataclasses import field, dataclass
3
5
 
@@ -17,6 +19,119 @@ class QASM2(lowering.LoweringABC[ast.Node]):
17
19
  hint_show_lineno: bool = field(default=True, kw_only=True)
18
20
  stacktrace: bool = field(default=True, kw_only=True)
19
21
 
22
+ def loads(
23
+ self,
24
+ source: str,
25
+ kernel_name: str,
26
+ *,
27
+ returns: str | None = None,
28
+ globals: dict[str, Any] | None = None,
29
+ file: str | None = None,
30
+ lineno_offset: int = 0,
31
+ col_offset: int = 0,
32
+ compactify: bool = True,
33
+ ) -> ir.Method:
34
+ from ..parse import loads
35
+
36
+ # TODO: add source info
37
+ stmt = loads(source)
38
+
39
+ state = lowering.State(
40
+ self,
41
+ file=file,
42
+ lineno_offset=lineno_offset,
43
+ col_offset=col_offset,
44
+ )
45
+ with state.frame(
46
+ [stmt],
47
+ globals=globals,
48
+ finalize_next=False,
49
+ ) as frame:
50
+ try:
51
+ self.visit(state, stmt)
52
+ # append return statement with the return values
53
+ if returns is not None:
54
+ return_value = frame.get(returns)
55
+ if return_value is None:
56
+ raise lowering.BuildError(f"Cannot find return value {returns}")
57
+ else:
58
+ return_value = func.ConstantNone()
59
+
60
+ return_node = frame.push(func.Return(value_or_stmt=return_value))
61
+
62
+ except lowering.BuildError as e:
63
+ hint = state.error_hint(
64
+ e,
65
+ max_lines=self.max_lines,
66
+ indent=self.hint_indent,
67
+ show_lineno=self.hint_show_lineno,
68
+ )
69
+ if self.stacktrace:
70
+ raise Exception(
71
+ f"{e.args[0]}\n\n{hint}",
72
+ *e.args[1:],
73
+ ) from e
74
+ else:
75
+ e.args = (hint,)
76
+ raise e
77
+
78
+ region = frame.curr_region
79
+
80
+ if compactify:
81
+ from kirin.rewrite import Walk, CFGCompactify
82
+
83
+ Walk(CFGCompactify()).rewrite(region)
84
+
85
+ code = func.Function(
86
+ sym_name=kernel_name,
87
+ signature=func.Signature((), return_node.value.type),
88
+ body=region,
89
+ )
90
+
91
+ return ir.Method(
92
+ mod=None,
93
+ py_func=None,
94
+ sym_name=kernel_name,
95
+ arg_names=[],
96
+ dialects=self.dialects,
97
+ code=code,
98
+ )
99
+
100
+ def loadfile(
101
+ self,
102
+ file: str | pathlib.Path,
103
+ *,
104
+ kernel_name: str | None = None,
105
+ returns: str | None = None,
106
+ globals: dict[str, Any] | None = None,
107
+ lineno_offset: int = 0,
108
+ col_offset: int = 0,
109
+ compactify: bool = True,
110
+ ) -> ir.Method:
111
+ if isinstance(file, str):
112
+ file = pathlib.Path(*os.path.split(file))
113
+
114
+ if not file.is_file() or not file.name.endswith(".qasm"):
115
+ raise ValueError("File must be a .qasm file")
116
+
117
+ kernel_name = (
118
+ file.name.replace(".qasm", "") if kernel_name is None else kernel_name
119
+ )
120
+
121
+ with file.open("r") as f:
122
+ source = f.read()
123
+
124
+ return self.loads(
125
+ source,
126
+ kernel_name,
127
+ returns=returns,
128
+ globals=globals,
129
+ file=str(file),
130
+ lineno_offset=lineno_offset,
131
+ col_offset=col_offset,
132
+ compactify=compactify,
133
+ )
134
+
20
135
  def run(
21
136
  self,
22
137
  stmt: ast.Node,
@@ -85,6 +200,10 @@ class QASM2(lowering.LoweringABC[ast.Node]):
85
200
  stmt = expr.ConstInt(value=value)
86
201
  elif isinstance(value, float):
87
202
  stmt = expr.ConstFloat(value=value)
203
+ else:
204
+ raise lowering.BuildError(
205
+ f"Expected value of type float or int, got {type(value)}."
206
+ )
88
207
  state.current_frame.push(stmt)
89
208
  return stmt.result
90
209
 
@@ -99,6 +218,8 @@ class QASM2(lowering.LoweringABC[ast.Node]):
99
218
  dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"]
100
219
  elif isinstance(node.header, ast.Kirin):
101
220
  dialects = node.header.dialects
221
+ else:
222
+ raise lowering.BuildError(f"Unexpected node header {node.header}")
102
223
 
103
224
  for dialect in dialects:
104
225
  if dialect not in allowed:
@@ -278,7 +399,7 @@ class QASM2(lowering.LoweringABC[ast.Node]):
278
399
  def visit_UnaryOp(self, state: lowering.State[ast.Node], node: ast.UnaryOp):
279
400
  if node.op == "-":
280
401
  stmt = expr.Neg(value=state.lower(node.operand).expect_one())
281
- return stmt.result
402
+ return state.current_frame.push(stmt).result
282
403
  else:
283
404
  return state.lower(node.operand).expect_one()
284
405
 
@@ -295,6 +416,8 @@ class QASM2(lowering.LoweringABC[ast.Node]):
295
416
  stmt = core.QRegGet(reg, addr.result)
296
417
  elif reg.type.is_subseteq(CRegType):
297
418
  stmt = core.CRegGet(reg, addr.result)
419
+ else:
420
+ raise lowering.BuildError(f"Unexpected register type {reg.type}")
298
421
  return state.current_frame.push(stmt).result
299
422
 
300
423
  def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call):
@@ -4,7 +4,7 @@ which converts global gates to single qubit gates.
4
4
  """
5
5
 
6
6
  from kirin import ir
7
- from kirin.rewrite import cse, dce, walk, result
7
+ from kirin.rewrite import abc, cse, dce, walk
8
8
  from kirin.passes.abc import Pass
9
9
  from kirin.passes.fold import Fold
10
10
  from kirin.rewrite.fixpoint import Fixpoint
@@ -54,7 +54,7 @@ class GlobalToUOP(Pass):
54
54
  frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
55
55
  return GlobalToUOpRule(frame.entries)
56
56
 
57
- def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
57
+ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
58
58
  rewriter = walk.Walk(self.generate_rule(mt))
59
59
  result = rewriter.rewrite(mt.code)
60
60
 
@@ -106,7 +106,7 @@ class GlobalToParallel(Pass):
106
106
  frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
107
107
  return GlobalToParallelRule(frame.entries)
108
108
 
109
- def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
109
+ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
110
110
  rewriter = walk.Walk(self.generate_rule(mt))
111
111
  result = rewriter.rewrite(mt.code)
112
112
 
@@ -0,0 +1,26 @@
1
+ from kirin import ir
2
+ from kirin.passes import Pass
3
+ from kirin.rewrite import (
4
+ Walk,
5
+ Chain,
6
+ Fixpoint,
7
+ ConstantFold,
8
+ CommonSubexpressionElimination,
9
+ )
10
+ from kirin.passes.hint_const import HintConst
11
+
12
+ from bloqade.qasm2.rewrite.insert_qubits import InsertGetQubit
13
+
14
+
15
+ class LiftQubits(Pass):
16
+ """This pass lifts the creation of qubits to the block where the register is defined."""
17
+
18
+ def unsafe_run(self, mt: ir.Method):
19
+ result = Walk(InsertGetQubit()).rewrite(mt.code)
20
+ result = HintConst(self.dialects).unsafe_run(mt).join(result)
21
+ result = (
22
+ Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
23
+ .rewrite(mt.code)
24
+ .join(result)
25
+ )
26
+ return result
@@ -4,16 +4,13 @@ from kirin import ir
4
4
  from kirin.passes import Pass
5
5
  from kirin.rewrite import (
6
6
  Walk,
7
- Chain,
8
7
  Fixpoint,
9
- ConstantFold,
10
8
  DeadCodeElimination,
11
- CommonSubexpressionElimination,
12
9
  )
13
- from kirin.rewrite.result import RewriteResult
14
10
 
15
11
  from bloqade.noise import native
16
12
  from bloqade.analysis import address
13
+ from bloqade.qasm2.passes.lift_qubits import LiftQubits
17
14
  from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
18
15
 
19
16
 
@@ -38,24 +35,19 @@ class NoisePass(Pass):
38
35
  self.address_analysis = address.AddressAnalysis(self.dialects)
39
36
 
40
37
  def unsafe_run(self, mt: ir.Method):
41
- result = RewriteResult()
42
-
43
- frame, res = self.address_analysis.run_analysis(mt, no_raise=False)
38
+ result = LiftQubits(self.dialects).unsafe_run(mt)
39
+ frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
44
40
  result = (
45
41
  Walk(
46
42
  NoiseRewriteRule(
47
43
  address_analysis=frame.entries,
48
44
  noise_model=self.noise_model,
49
45
  gate_noise_params=self.gate_noise_params,
50
- )
46
+ ),
47
+ reverse=True,
51
48
  )
52
49
  .rewrite(mt.code)
53
50
  .join(result)
54
51
  )
55
- rule = Chain(
56
- ConstantFold(),
57
- DeadCodeElimination(),
58
- CommonSubexpressionElimination(),
59
- )
60
- result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
52
+ result = Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result)
61
53
  return result
@@ -16,7 +16,7 @@ from kirin.rewrite import (
16
16
  ConstantFold,
17
17
  DeadCodeElimination,
18
18
  CommonSubexpressionElimination,
19
- result,
19
+ abc,
20
20
  )
21
21
  from kirin.analysis import const
22
22
 
@@ -84,7 +84,7 @@ class ParallelToUOp(Pass):
84
84
 
85
85
  return ParallelToUOpRule(id_map=id_map, address_analysis=frame.entries)
86
86
 
87
- def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
87
+ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
88
88
  result = Walk(self.generate_rule(mt)).rewrite(mt.code)
89
89
  rule = Chain(
90
90
  ConstantFold(),
@@ -140,7 +140,7 @@ class UOpToParallel(Pass):
140
140
  def __post_init__(self):
141
141
  self.constprop = const.Propagate(self.dialects)
142
142
 
143
- def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
143
+ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
144
144
  result = Walk(RaiseRegisterRule()).rewrite(mt.code)
145
145
 
146
146
  # do not run the parallelization because registers are not at the top
@@ -4,8 +4,7 @@ from kirin import ir
4
4
  from kirin.passes import Pass
5
5
  from kirin.rewrite import Walk, Fixpoint
6
6
  from kirin.dialects import py, math
7
- from kirin.rewrite.abc import RewriteRule
8
- from kirin.rewrite.result import RewriteResult
7
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
9
8
 
10
9
  from bloqade.qasm2.dialects import core, expr
11
10
 
@@ -6,8 +6,7 @@ from kirin import ir
6
6
  from kirin.passes import Pass
7
7
  from kirin.rewrite import Walk, Fixpoint
8
8
  from kirin.dialects import py, math
9
- from kirin.rewrite.abc import RewriteRule
10
- from kirin.rewrite.result import RewriteResult
9
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
11
10
 
12
11
  from bloqade.qasm2.dialects import core, expr
13
12
 
@@ -2,27 +2,27 @@ from dataclasses import dataclass
2
2
 
3
3
  from kirin import ir
4
4
  from kirin.passes import Pass
5
- from kirin.rewrite import abc, walk, result
5
+ from kirin.rewrite import abc, walk
6
6
  from kirin.dialects import py
7
7
 
8
8
  from bloqade.qasm2.dialects import core
9
9
 
10
10
 
11
11
  class IndexingDesugarRule(abc.RewriteRule):
12
- def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
12
+ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
13
13
  if isinstance(node, py.indexing.GetItem):
14
14
  if node.obj.type.is_subseteq(core.QRegType):
15
15
  node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
16
- return result.RewriteResult(has_done_something=True)
16
+ return abc.RewriteResult(has_done_something=True)
17
17
  elif node.obj.type.is_subseteq(core.CRegType):
18
18
  node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
19
- return result.RewriteResult(has_done_something=True)
19
+ return abc.RewriteResult(has_done_something=True)
20
20
 
21
- return result.RewriteResult()
21
+ return abc.RewriteResult()
22
22
 
23
23
 
24
24
  @dataclass
25
25
  class IndexingDesugarPass(Pass):
26
- def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
26
+ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
27
27
 
28
28
  return walk.Walk(IndexingDesugarRule()).rewrite(mt.code)
@@ -2,7 +2,7 @@ from typing import Dict, List
2
2
  from dataclasses import dataclass
3
3
 
4
4
  from kirin import ir
5
- from kirin.rewrite import abc, result
5
+ from kirin.rewrite import abc
6
6
  from kirin.dialects import py, ilist
7
7
 
8
8
  from bloqade import qasm2
@@ -47,18 +47,18 @@ class GlobalRewriteBase:
47
47
  @dataclass
48
48
  class GlobalToParallelRule(abc.RewriteRule, GlobalRewriteBase):
49
49
 
50
- def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
50
+ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
51
51
  if type(node) in glob.dialect.stmts:
52
52
  return getattr(self, f"rewrite_{node.name}")(node)
53
53
 
54
- return result.RewriteResult()
54
+ return abc.RewriteResult()
55
55
 
56
56
  def rewrite_ugate(self, node: glob.UGate):
57
57
 
58
58
  new_stmts, qubit_ssa = self.get_qubit_ssa(node)
59
59
 
60
60
  if qubit_ssa is None:
61
- return result.RewriteResult()
61
+ return abc.RewriteResult()
62
62
 
63
63
  new_stmts.append(qargs := ilist.New(values=qubit_ssa))
64
64
  new_stmts.append(
@@ -72,24 +72,24 @@ class GlobalToParallelRule(abc.RewriteRule, GlobalRewriteBase):
72
72
 
73
73
  node.delete()
74
74
 
75
- return result.RewriteResult(has_done_something=True)
75
+ return abc.RewriteResult(has_done_something=True)
76
76
 
77
77
 
78
78
  @dataclass
79
79
  class GlobalToUOpRule(abc.RewriteRule, GlobalRewriteBase):
80
80
 
81
- def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
81
+ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
82
82
  if type(node) in glob.dialect.stmts:
83
83
  return getattr(self, f"rewrite_{node.name}")(node)
84
84
 
85
- return result.RewriteResult()
85
+ return abc.RewriteResult()
86
86
 
87
87
  def rewrite_ugate(self, node: glob.UGate):
88
88
 
89
89
  new_stmts, qubit_ssa = self.get_qubit_ssa(node)
90
90
 
91
91
  if qubit_ssa is None:
92
- return result.RewriteResult()
92
+ return abc.RewriteResult()
93
93
 
94
94
  for qarg in qubit_ssa:
95
95
  new_stmts.append(
@@ -100,4 +100,4 @@ class GlobalToUOpRule(abc.RewriteRule, GlobalRewriteBase):
100
100
  stmt.insert_before(node)
101
101
 
102
102
  node.delete()
103
- return result.RewriteResult(has_done_something=True)
103
+ return abc.RewriteResult(has_done_something=True)