sonolus.py 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sonolus.py might be problematic; see the registry's advisory page for more details.

Files changed (90)
  1. sonolus/backend/blocks.py +756 -756
  2. sonolus/backend/excepthook.py +37 -37
  3. sonolus/backend/finalize.py +77 -69
  4. sonolus/backend/interpret.py +7 -7
  5. sonolus/backend/ir.py +29 -3
  6. sonolus/backend/mode.py +24 -24
  7. sonolus/backend/node.py +40 -40
  8. sonolus/backend/ops.py +197 -197
  9. sonolus/backend/optimize/__init__.py +0 -0
  10. sonolus/backend/optimize/allocate.py +126 -0
  11. sonolus/backend/optimize/constant_evaluation.py +374 -0
  12. sonolus/backend/optimize/copy_coalesce.py +85 -0
  13. sonolus/backend/optimize/dead_code.py +185 -0
  14. sonolus/backend/optimize/dominance.py +96 -0
  15. sonolus/backend/{flow.py → optimize/flow.py} +122 -92
  16. sonolus/backend/optimize/inlining.py +137 -0
  17. sonolus/backend/optimize/liveness.py +177 -0
  18. sonolus/backend/optimize/optimize.py +44 -0
  19. sonolus/backend/optimize/passes.py +52 -0
  20. sonolus/backend/optimize/simplify.py +191 -0
  21. sonolus/backend/optimize/ssa.py +200 -0
  22. sonolus/backend/place.py +17 -25
  23. sonolus/backend/utils.py +58 -48
  24. sonolus/backend/visitor.py +1151 -882
  25. sonolus/build/cli.py +7 -1
  26. sonolus/build/compile.py +88 -90
  27. sonolus/build/engine.py +10 -5
  28. sonolus/build/level.py +24 -23
  29. sonolus/build/node.py +43 -43
  30. sonolus/script/archetype.py +438 -139
  31. sonolus/script/array.py +27 -10
  32. sonolus/script/array_like.py +297 -0
  33. sonolus/script/bucket.py +253 -191
  34. sonolus/script/containers.py +257 -51
  35. sonolus/script/debug.py +26 -10
  36. sonolus/script/easing.py +365 -0
  37. sonolus/script/effect.py +191 -131
  38. sonolus/script/engine.py +71 -4
  39. sonolus/script/globals.py +303 -269
  40. sonolus/script/instruction.py +205 -151
  41. sonolus/script/internal/__init__.py +5 -5
  42. sonolus/script/internal/builtin_impls.py +255 -144
  43. sonolus/script/{callbacks.py → internal/callbacks.py} +127 -127
  44. sonolus/script/internal/constant.py +139 -0
  45. sonolus/script/internal/context.py +26 -9
  46. sonolus/script/internal/descriptor.py +17 -17
  47. sonolus/script/internal/dict_impl.py +65 -0
  48. sonolus/script/internal/generic.py +6 -9
  49. sonolus/script/internal/impl.py +38 -13
  50. sonolus/script/internal/introspection.py +17 -14
  51. sonolus/script/internal/math_impls.py +121 -0
  52. sonolus/script/internal/native.py +40 -38
  53. sonolus/script/internal/random.py +67 -0
  54. sonolus/script/internal/range.py +81 -0
  55. sonolus/script/internal/transient.py +51 -0
  56. sonolus/script/internal/tuple_impl.py +113 -0
  57. sonolus/script/internal/value.py +3 -3
  58. sonolus/script/interval.py +338 -112
  59. sonolus/script/iterator.py +167 -214
  60. sonolus/script/level.py +24 -0
  61. sonolus/script/num.py +80 -48
  62. sonolus/script/options.py +257 -191
  63. sonolus/script/particle.py +190 -157
  64. sonolus/script/pointer.py +30 -30
  65. sonolus/script/print.py +102 -81
  66. sonolus/script/project.py +8 -0
  67. sonolus/script/quad.py +263 -0
  68. sonolus/script/record.py +47 -16
  69. sonolus/script/runtime.py +52 -1
  70. sonolus/script/sprite.py +418 -333
  71. sonolus/script/text.py +409 -407
  72. sonolus/script/timing.py +114 -42
  73. sonolus/script/transform.py +332 -48
  74. sonolus/script/ui.py +216 -160
  75. sonolus/script/values.py +6 -13
  76. sonolus/script/vec.py +196 -78
  77. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/METADATA +1 -1
  78. sonolus_py-0.1.5.dist-info/RECORD +89 -0
  79. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/WHEEL +1 -1
  80. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/licenses/LICENSE +21 -21
  81. sonolus/backend/allocate.py +0 -51
  82. sonolus/backend/optimize.py +0 -9
  83. sonolus/backend/passes.py +0 -6
  84. sonolus/backend/simplify.py +0 -30
  85. sonolus/script/comptime.py +0 -160
  86. sonolus/script/graphics.py +0 -150
  87. sonolus/script/math.py +0 -92
  88. sonolus/script/range.py +0 -58
  89. sonolus_py-0.1.3.dist-info/RECORD +0 -75
  90. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import deque
5
+
6
+ from sonolus.backend.optimize.flow import BasicBlock
7
+
8
+
9
class CompilerPass(ABC):
    """Base class for a pass over the control-flow graph rooted at an entry ``BasicBlock``.

    A pass declares which passes must be active before it runs (``requires``) and which passes remain active
    afterwards (``preserves`` / ``destroys`` / ``applies``); the pass scheduler uses these to insert requirements.

    Passes compare (and hash) equal when they have the same type. Requirements are usually written as fresh
    instances, e.g. ``{SomePass()}``; with identity-based equality such a requirement could never be found among
    the already-active passes, so scheduling would never make progress. Subclasses may still override
    ``__eq__``/``__hash__`` if they need finer-grained identity.
    """

    def requires(self) -> set[CompilerPass]:
        """Return the passes that must be active before this pass can run. Defaults to none."""
        return set()

    def preserves(self) -> set[CompilerPass] | None:
        """Return the only previously active passes that stay valid after this pass, or ``None`` if unspecified."""
        return None

    def destroys(self) -> set[CompilerPass] | None:
        """Return the previously active passes invalidated by this pass, or ``None`` if unspecified."""
        return None

    def applies(self) -> set[CompilerPass]:
        """Return the passes that become active once this pass has run. Defaults to this pass itself."""
        return {self}

    def exists_after(self, passes: set[CompilerPass]) -> set[CompilerPass]:
        """Return the active passes after running this pass, given the active ``passes`` before it.

        If neither ``preserves`` nor ``destroys`` is specified, nothing from ``passes`` is assumed to survive;
        only ``applies()`` is returned.
        """
        preserved = self.preserves()
        destroyed = self.destroys()
        if destroyed is None and preserved is None:
            return self.applies()
        if preserved is not None:
            passes = {p for p in passes if p in preserved}
        if destroyed is not None:
            passes = {p for p in passes if p not in destroyed}
        return passes | self.applies()

    @abstractmethod
    def run(self, entry: BasicBlock) -> BasicBlock:
        """Process the graph rooted at ``entry`` and return the (possibly new) entry block."""

    def __eq__(self, other):
        # Type-based equality so that freshly constructed requirement instances match active passes.
        return type(self) is type(other)

    def __hash__(self):
        return hash(type(self))
36
+
37
+
38
def run_passes(entry: BasicBlock, passes: list[CompilerPass]) -> BasicBlock:
    """Run ``passes`` in order on the graph rooted at ``entry``.

    Before a pass runs, any of its ``requires()`` that are not currently active are scheduled immediately in
    front of it. After each pass, the active set is updated with ``exists_after``.

    Returns:
        The entry block produced by the last pass.

    Raises:
        RuntimeError: if more than 99 passes are pending at once, which usually indicates cyclic requirements.
    """
    pending = deque(passes)
    active = set()
    while pending:
        if len(pending) > 99:
            raise RuntimeError("Likely unsatisfiable pass requirements")
        current = pending.popleft()
        unmet = current.requires() - active
        if not unmet:
            entry = current.run(entry)
            active = current.exists_after(active)
            continue
        # Put the pass back and queue its requirements ahead of it.
        pending.appendleft(current)
        pending.extendleft(unmet)
    return entry
@@ -0,0 +1,191 @@
1
+ from sonolus.backend.ir import IRConst, IRGet, IRPureInstr, IRSet
2
+ from sonolus.backend.ops import Op
3
+ from sonolus.backend.optimize.flow import BasicBlock, traverse_cfg_preorder
4
+ from sonolus.backend.optimize.passes import CompilerPass
5
+
6
+
7
class CoalesceFlow(CompilerPass):
    """Simplify the control-flow graph by removing trivial blocks and merging straight-line chains.

    Working from ``entry``, each visited block is simplified in several steps:

    * Each outgoing edge is threaded past successors that have no phis, no statements and exactly one
      outgoing edge. Threading stops before a destination that has phis.
    * Conditional edges that lead to the same block as the default edge are dropped.
    * If the block has a single successor with other predecessors, and the block itself is empty (no
      statements, no phis) and the successor has no phis, the block's predecessors are redirected straight
      to that successor.
    * If the single successor's only predecessor is this block, the successor is merged into the block.
      Its phis become plain assignments.

    Returns the (possibly new) entry block.
    """

    def run(self, entry: BasicBlock) -> BasicBlock:
        queue = [entry]
        processed = set()
        while queue:
            block = queue.pop()
            if block in processed:
                continue
            processed.add(block)
            # Thread each outgoing edge past empty pass-through blocks.
            for edge in block.outgoing:
                while True:
                    dst = edge.dst
                    if dst.phis or dst.statements or len(dst.outgoing) != 1 or dst is block:
                        break
                    next_dst = next(iter(dst.outgoing)).dst
                    if next_dst.phis:
                        # Bypassing dst would give next_dst a predecessor its phi arguments do not mention.
                        break
                    dst.incoming.remove(edge)
                    if not dst.incoming:
                        # dst is now unreachable: detach its outgoing edge and never visit it.
                        for dst_edge in dst.outgoing:
                            dst_edge.dst.incoming.remove(dst_edge)
                        processed.add(dst)
                    edge.dst = next_dst
                    next_dst.incoming.add(edge)
                    if dst is edge.dst:
                        # dst's only edge loops back to itself; stop rather than spin forever.
                        break
            # Drop conditional edges that go where the default edge already goes.
            default_edge = next((edge for edge in block.outgoing if edge.cond is None), None)
            if default_edge is not None:
                for edge in [*block.outgoing]:
                    if edge is default_edge:
                        continue
                    if edge.dst is default_edge.dst:
                        block.outgoing.remove(edge)
                        edge.dst.incoming.remove(edge)
            if len(block.outgoing) != 1:
                queue.extend(edge.dst for edge in block.outgoing)
                continue
            next_block = next(iter(block.outgoing)).dst
            if len(next_block.incoming) != 1:
                # next_block has other predecessors, so it cannot be merged into block.
                queue.append(next_block)
                if not block.statements and not block.phis and not next_block.phis:
                    # block is empty: send its predecessors straight to next_block.
                    for edge in block.incoming:
                        edge.dst = next_block
                        next_block.incoming.add(edge)
                    for edge in block.outgoing:  # There should be exactly one
                        next_block.incoming.remove(edge)
                    if block is entry:
                        entry = next_block
                continue
            # block is next_block's only predecessor: append next_block's contents to block.
            for p, args in next_block.phis.items():
                if block not in args:
                    continue
                block.statements.append(IRSet(p, IRGet(args[block])))
            block.statements.extend(next_block.statements)
            block.test = next_block.test
            block.outgoing = next_block.outgoing
            for edge in block.outgoing:
                edge.src = block
                dst = edge.dst
                # Phi arguments that came from next_block now come from block.
                for args in dst.phis.values():
                    if next_block in args:
                        args[block] = args.pop(next_block)
            processed.add(next_block)
            queue.extend(edge.dst for edge in block.outgoing)
            # Revisit block: the merge may expose further simplifications.
            processed.remove(block)
            queue.append(block)
        return entry
74
+
75
+
76
class RewriteToSwitch(CompilerPass):
    """Rewrite if-else chains to switch statements.

    Note that this needs inlining (and dead code elimination) to be run first to really do anything useful.

    The pass has three steps:

    * ``ifs_to_switch`` turns a two-way branch on ``x == c`` into a branch on ``x`` with a ``c`` case.
    * ``combine_blocks`` folds an empty default successor that branches on the same test into its predecessor.
    * ``remove_unreachable`` prunes edges to or from blocks no longer reachable from ``entry``.
    """

    def run(self, entry: BasicBlock) -> BasicBlock:
        self.ifs_to_switch(entry)
        self.combine_blocks(entry)
        self.remove_unreachable(entry)
        return entry

    def ifs_to_switch(self, entry: BasicBlock):
        """Turn each if/else on ``x == c`` (or ``c == x``) into a branch on ``x`` with a case for ``c``."""
        for block in traverse_cfg_preorder(entry):
            # Only plain if/else blocks: a default edge plus an edge taken when the test is 0.
            if len(block.outgoing) != 2 or {edge.cond for edge in block.outgoing} != {None, 0}:
                continue
            test = block.test
            if not isinstance(test, IRPureInstr) or test.op != Op.Equal:
                continue
            assert len(test.args) == 2
            if isinstance(test.args[0], IRConst):
                const, other = test.args
            elif isinstance(test.args[1], IRConst):
                other, const = test.args
            else:
                continue
            block.test = other
            # The old default edge (equality held) becomes the ``const`` case.
            # The old ``0`` edge (equality failed) becomes the new default.
            for edge in block.outgoing:
                if edge.cond is None:
                    edge.cond = const.value
                else:
                    edge.cond = None

    def combine_blocks(self, entry: BasicBlock):
        """Merge a block's default successor into it when the successor only re-branches on the same test.

        The successor must have no statements or phis and ``block`` as its only predecessor, and it must not
        be ``block`` itself or ``entry``. The successor's cases are then added to ``block``. A case whose value
        ``block`` already handles can never be reached through the successor, so it is dropped.

        NOTE(review): this relies on ``block.test == next_block.test`` meaning both evaluations yield the same
        value; confirm that tests reaching here never have side effects.
        """
        queue = [entry]
        processed = set()
        while queue:
            block = queue.pop()
            if block in processed:
                continue
            processed.add(block)
            queue.extend(edge.dst for edge in block.outgoing)

            default_edge = next((edge for edge in block.outgoing if edge.cond is None), None)
            if default_edge is None:
                continue

            next_block = default_edge.dst
            if (
                len(next_block.incoming) > 1
                or next_block.statements
                or next_block.phis
                or block.test != next_block.test
                or block is next_block
                or next_block is entry
            ):
                continue

            outgoing_by_cond = {edge.cond: edge for edge in block.outgoing}
            assert len(outgoing_by_cond) == len(block.outgoing)
            # The default edge into next_block is replaced by next_block's own edges below.
            outgoing_by_cond.pop(None)
            for edge in next_block.outgoing:
                if edge.cond in outgoing_by_cond:
                    # This edge is unreachable since an equivalent edge would have been taken
                    edge.dst.incoming.remove(edge)
                    continue
                outgoing_by_cond[edge.cond] = edge
                edge.src = block
                # Phi arguments that came from next_block now come from block.
                for args in edge.dst.phis.values():
                    if next_block in args:
                        args[block] = args.pop(next_block)
            block.outgoing = set(outgoing_by_cond.values())
            processed.add(next_block)
            # Revisit block: its new default successor may be mergeable as well.
            queue.append(block)
            processed.remove(block)

    def remove_unreachable(self, entry: BasicBlock):
        """Keep only edges between blocks that are still reachable from ``entry``."""
        reachable = {*traverse_cfg_preorder(entry)}
        for block in traverse_cfg_preorder(entry):
            block.incoming = {edge for edge in block.incoming if edge.src in reachable}
            block.outgoing = {edge for edge in block.outgoing if edge.dst in reachable}
157
+
158
+
159
class NormalizeSwitch(CompilerPass):
    """Rebase switch cases that form an arithmetic progression onto 0, 1, 2, ...

    A block branching on ``test`` with cases ``a, a + b, a + 2b, ...`` (integral ``a`` and ``b``) is rewritten
    to branch on ``(test - a) / b`` with cases ``0, 1, 2, ...``. Blocks with fewer than two non-default cases,
    or whose cases are already ``0, 1, 2, ...``, are left untouched.
    """

    def run(self, entry: BasicBlock) -> BasicBlock:
        for block in traverse_cfg_preorder(entry):
            conds = {edge.cond for edge in block.outgoing}
            if len(conds) <= 2:
                continue
            assert None in conds, "Non-terminal blocks should always have a default edge"
            conds.remove(None)
            base, step = self.get_offset_stride(conds)
            if base is None:
                continue
            if base == 0 and step == 1:
                continue
            for edge in block.outgoing:
                if edge.cond is not None:
                    edge.cond = (edge.cond - base) // step
            new_test = block.test
            if base != 0:
                new_test = IRPureInstr(Op.Subtract, [new_test, IRConst(base)])
            if step != 1:
                new_test = IRPureInstr(Op.Divide, [new_test, IRConst(step)])
            block.test = new_test
        return entry

    def get_offset_stride(self, cases: set[int]) -> tuple[int | None, int | None]:
        """Return ``(offset, stride)`` if the sorted ``cases`` are ``offset + k * stride``, else ``(None, None)``.

        Both values must be integral. ``cases`` must contain at least two values.
        """
        ordered = sorted(cases)
        first = ordered[0]
        step = ordered[1] - first
        if int(first) != first or int(step) != step:
            return None, None
        for position in range(2, len(ordered)):
            if ordered[position] != first + position * step:
                return None, None
        return first, step
@@ -0,0 +1,200 @@
1
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
2
+ from sonolus.backend.optimize.dominance import DominanceFrontiers, get_df, get_dom_children
3
+ from sonolus.backend.optimize.flow import BasicBlock, FlowEdge, traverse_cfg_preorder
4
+ from sonolus.backend.optimize.passes import CompilerPass
5
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
6
+
7
+
8
class ToSSA(CompilerPass):
    """Convert single-cell temporaries (``TempBlock`` with ``size == 1``) into SSA places.

    The pass works in four steps:

    1. Find the blocks that assign each such temporary.
    2. Place placeholder phis (keyed by the ``TempBlock``) on the iterated dominance frontier of those blocks.
    3. Walk the dominator tree, giving each definition a fresh ``SSAPlace`` and rewriting each use to the
       innermost definition in scope.
    4. Drop the placeholder keys, leaving only phis keyed by ``SSAPlace``.
    """

    def requires(self) -> set[CompilerPass]:
        # get_df / get_dom_children below presumably read results computed by DominanceFrontiers.
        return {DominanceFrontiers()}

    def run(self, entry: BasicBlock) -> BasicBlock:
        defs = self.defs_to_blocks(entry)
        self.insert_phis(defs)
        self.rename(entry, defs, {var: [] for var in defs}, {})
        self.remove_placeholder_phis(entry)
        return entry

    def rename(
        self,
        block: BasicBlock,
        defs: dict[TempBlock, set[BasicBlock]],
        ssa_places: dict[TempBlock, list[SSAPlace]],
        used: dict[str, int],
    ):
        """Rename definitions and uses in ``block`` and, recursively, in its dominator-tree children.

        ``ssa_places`` maps each variable to a stack of SSA places in scope; entries pushed here are popped
        on exit. ``used`` tracks the last number handed out per variable name.
        """
        to_pop = []
        for var, args in [*block.phis.items()]:
            if isinstance(var, SSAPlace):
                continue
            ssa_places[var].append(self.get_new_ssa_place(var.name, used))
            to_pop.append(var)
            # The SSA phi shares the placeholder's args dict, so arguments filled in through the
            # placeholder key (by predecessors) also land in the SSA phi.
            block.phis[ssa_places[var][-1]] = args
        block.statements = [self.rename_stmt(stmt, ssa_places, used, to_pop) for stmt in block.statements]
        # Record, for each successor's placeholder phi, the value reaching it from this block.
        for edge in block.outgoing:
            dst = edge.dst
            for var, args in dst.phis.items():
                if isinstance(var, SSAPlace):
                    continue
                if ssa_places[var]:
                    args[block] = ssa_places[var][-1]
        block.test = self.rename_stmt(block.test, ssa_places, used, to_pop)
        for dom_child in get_dom_children(block):
            self.rename(dom_child, defs, ssa_places, used)
        # Leave the scope of this block's definitions.
        for var in to_pop:
            ssa_places[var].pop()

    def remove_placeholder_phis(self, entry: BasicBlock):
        """Keep only phis keyed by ``SSAPlace``, discarding the ``TempBlock`` placeholders."""
        for block in traverse_cfg_preorder(entry):
            block.phis = {var: args for var, args in block.phis.items() if isinstance(var, SSAPlace)}

    def rename_stmt(
        self, stmt: IRStmt, ssa_places: dict[TempBlock, list[SSAPlace]], used: dict[str, int], to_pop: list[TempBlock]
    ):
        """Return ``stmt`` with size-1 temporaries replaced by SSA places.

        An assignment to such a temporary pushes a fresh SSA place and records the variable in ``to_pop``.
        """
        match stmt:
            case IRConst():
                return stmt
            case IRPureInstr(op=op, args=args):
                return IRPureInstr(op=op, args=[self.rename_stmt(arg, ssa_places, used, to_pop) for arg in args])
            case IRInstr(op=op, args=args):
                return IRInstr(op=op, args=[self.rename_stmt(arg, ssa_places, used, to_pop) for arg in args])
            case IRGet(place=place):
                return IRGet(place=self.rename_stmt(place, ssa_places, used, to_pop))
            case IRSet(place=place, value=value):
                # Rename the value first so it sees the previous definition, then define the new one.
                value = self.rename_stmt(value, ssa_places, used, to_pop)
                if isinstance(place, BlockPlace) and isinstance(place.block, TempBlock) and place.block.size == 1:
                    ssa_places[place.block].append(self.get_new_ssa_place(place.block.name, used))
                    to_pop.append(place.block)
                place = self.rename_stmt(place, ssa_places, used, to_pop)
                return IRSet(place=place, value=value)
            case SSAPlace():
                return stmt
            case TempBlock() if stmt.size == 1:
                if stmt not in ssa_places or not ssa_places[stmt]:
                    # This is an access to a definitely undefined variable
                    # But it might not be reachable in reality, so we should allow it
                    # Maybe there should be an error if this still happens after optimization,
                    # but recovering the location of the error in the original code is hard.
                    # This can happen in places like matching a VarArray[Num, 1] which was just created.
                    # IR generation won't immediately fold a check that size > 0 to false, so here we
                    # might see an access to uninitialized memory even though it's not reachable in reality.
                    return SSAPlace("err", 0)
                return ssa_places[stmt][-1]
            case TempBlock():
                return stmt
            case int():
                return stmt
            case BlockPlace(block=block, index=index, offset=offset):
                # A size-1 temporary is replaced as a whole; index/offset are not consulted for it.
                if isinstance(block, TempBlock) and block.size == 1:
                    return self.rename_stmt(block, ssa_places, used, to_pop)
                return BlockPlace(
                    block=self.rename_stmt(block, ssa_places, used, to_pop),
                    index=self.rename_stmt(index, ssa_places, used, to_pop),
                    offset=self.rename_stmt(offset, ssa_places, used, to_pop),
                )
            case _:
                raise TypeError(f"Unexpected statement: {stmt}")

    def insert_phis(self, defs: dict[TempBlock, set[BasicBlock]]):
        """Add an empty placeholder phi for each variable on the iterated dominance frontier of its definitions."""
        for var, blocks in defs.items():
            df = self.get_iterated_df(blocks)
            for block in df:
                block.phis[var] = {}

    def defs_to_blocks(self, entry: BasicBlock) -> dict[TempBlock, set[BasicBlock]]:
        """Map each size-1 temporary to the set of blocks containing a top-level assignment to it."""
        result = {}
        for block in traverse_cfg_preorder(entry):
            for stmt in block.statements:
                def_block = self.get_stmt_def(stmt)
                if def_block is not None:
                    result.setdefault(def_block, set()).add(block)
        return result

    def get_stmt_def(self, stmt: IRStmt) -> TempBlock | None:
        """Return the size-1 temporary assigned by ``stmt``, or ``None`` if it assigns no such temporary."""
        if (
            isinstance(stmt, IRSet)
            and isinstance(stmt.place, BlockPlace)
            and isinstance(stmt.place.block, TempBlock)
            and stmt.place.block.size == 1
        ):
            return stmt.place.block
        return None

    def get_iterated_df(self, blocks: set[BasicBlock]) -> set[BasicBlock]:
        """Return the iterated dominance frontier of ``blocks`` (worklist closure of ``get_df``)."""
        df = set()
        worklist = set(blocks)
        while worklist:
            block = worklist.pop()
            new_df = get_df(block) - df
            if new_df:
                df.update(new_df)
                worklist.update(new_df)
        return df

    def get_new_ssa_place(self, name: str, used: dict[str, int]) -> SSAPlace:
        """Return a fresh ``SSAPlace`` for ``name``, numbered from 1 upwards."""
        if name not in used:
            used[name] = 0
        used[name] += 1
        return SSAPlace(name, used[name])
139
+
140
+
141
+ class FromSSA(CompilerPass):
142
+ def run(self, entry: BasicBlock) -> BasicBlock:
143
+ for block in [*traverse_cfg_preorder(entry)]:
144
+ self.process_block(block)
145
+ return entry
146
+
147
+ def process_block(self, block: BasicBlock):
148
+ incoming = [*block.incoming]
149
+ block.incoming.clear()
150
+ for edge in incoming:
151
+ between_block = BasicBlock()
152
+ edge.dst = between_block
153
+ between_block.incoming.add(edge)
154
+ next_edge = FlowEdge(between_block, block, None)
155
+ block.incoming.add(next_edge)
156
+ between_block.outgoing.add(next_edge)
157
+ for args in block.phis.values():
158
+ if edge.src in args:
159
+ args[between_block] = args.pop(edge.src)
160
+ for var, args in block.phis.items():
161
+ for src, arg in args.items():
162
+ src.statements.append(
163
+ IRSet(place=self.place_from_ssa_place(var), value=IRGet(place=self.place_from_ssa_place(arg)))
164
+ )
165
+ block.phis = {}
166
+ block.statements = [self.process_stmt(stmt) for stmt in block.statements]
167
+ block.test = self.process_stmt(block.test)
168
+
169
+ def process_stmt(self, stmt: IRStmt):
170
+ match stmt:
171
+ case IRConst():
172
+ return stmt
173
+ case IRPureInstr(op=op, args=args):
174
+ return IRPureInstr(op=op, args=[self.process_stmt(arg) for arg in args])
175
+ case IRInstr(op=op, args=args):
176
+ return IRInstr(op=op, args=[self.process_stmt(arg) for arg in args])
177
+ case IRGet(place=place):
178
+ return IRGet(place=self.process_stmt(place))
179
+ case IRSet(place=place, value=value):
180
+ return IRSet(place=self.process_stmt(place), value=self.process_stmt(value))
181
+ case SSAPlace():
182
+ return self.place_from_ssa_place(stmt)
183
+ case TempBlock():
184
+ return stmt
185
+ case int():
186
+ return stmt
187
+ case BlockPlace(block=block, index=index, offset=offset):
188
+ return BlockPlace(
189
+ block=self.process_stmt(block),
190
+ index=self.process_stmt(index),
191
+ offset=self.process_stmt(offset),
192
+ )
193
+ case _:
194
+ raise TypeError(f"Unexpected statement: {stmt}")
195
+
196
+ def temp_block_from_ssa_place(self, ssa_place: SSAPlace) -> TempBlock:
197
+ return TempBlock(f"{ssa_place.name}.{ssa_place.num}")
198
+
199
+ def place_from_ssa_place(self, ssa_place: SSAPlace) -> BlockPlace:
200
+ return BlockPlace(block=self.temp_block_from_ssa_place(ssa_place), index=0, offset=0)
sonolus/backend/place.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Iterator
2
- from typing import Self
2
+ from typing import NamedTuple, Self
3
3
 
4
4
  from sonolus.backend.blocks import Block
5
5
 
@@ -8,13 +8,9 @@ type BlockValue = Block | int | TempBlock | Place
8
8
  type IndexValue = int | Place
9
9
 
10
10
 
11
- class TempBlock:
11
+ class TempBlock(NamedTuple):
12
12
  name: str
13
- size: int
14
-
15
- def __init__(self, name: str, size: int = 1):
16
- self.name = name
17
- self.size = size
13
+ size: int = 1
18
14
 
19
15
  def __repr__(self):
20
16
  return f"TempBlock(name={self.name!r}, size={self.size!r})"
@@ -33,19 +29,14 @@ class TempBlock:
33
29
  return isinstance(other, TempBlock) and self.name == other.name and self.size == other.size
34
30
 
35
31
  def __hash__(self):
36
- return hash((self.name, self.size))
32
+ return hash(self.name) # Typically will be unique by name alone
37
33
 
38
34
 
39
- class BlockPlace:
35
+ class BlockPlace(NamedTuple):
40
36
  block: BlockValue
41
- index: IndexValue
37
+ index: IndexValue = 0
42
38
  offset: int = 0
43
39
 
44
- def __init__(self, block: BlockValue, index: IndexValue = 0, offset: int = 0):
45
- self.block = block
46
- self.index = index
47
- self.offset = offset
48
-
49
40
  def __repr__(self):
50
41
  return f"BlockPlace(block={self.block!r}, index={self.index!r}, offset={self.offset!r})"
51
42
 
@@ -59,24 +50,25 @@ class BlockPlace:
59
50
  else:
60
51
  return f"{self.block}[{self.index} + {self.offset}]"
61
52
 
53
+ def add_offset(self, offset: int) -> Self:
54
+ return BlockPlace(self.block, self.index, self.offset + offset)
55
+
62
56
  def __eq__(self, other):
63
- return isinstance(other, BlockPlace) and self.block == other.block and self.index == other.index
57
+ return (
58
+ isinstance(other, BlockPlace)
59
+ and self.block == other.block
60
+ and self.index == other.index
61
+ and self.offset == other.offset
62
+ )
64
63
 
65
64
  def __hash__(self):
66
- return hash((self.block, self.index))
65
+ return hash((self.block, self.index, self.offset))
67
66
 
68
- def add_offset(self, offset: int) -> Self:
69
- return BlockPlace(self.block, self.index, self.offset + offset)
70
67
 
71
-
72
- class SSAPlace:
68
+ class SSAPlace(NamedTuple):
73
69
  name: str
74
70
  num: int
75
71
 
76
- def __init__(self, name: str, num: int):
77
- self.name = name
78
- self.num = num
79
-
80
72
  def __repr__(self):
81
73
  return f"SSAPlace(name={self.name!r}, num={self.num!r})"
82
74
 
sonolus/backend/utils.py CHANGED
@@ -1,48 +1,58 @@
1
- # ruff: noqa: N802
2
- import ast
3
- import inspect
4
- from collections.abc import Callable
5
- from pathlib import Path
6
-
7
-
8
- def get_function(fn: Callable) -> tuple[str, ast.FunctionDef]:
9
- # This preserves both line number and column number in the returned node
10
- source_file = inspect.getsourcefile(fn)
11
- _, start_line = inspect.getsourcelines(fn)
12
- base_tree = ast.parse(Path(source_file).read_text(encoding="utf-8"))
13
- return source_file, find_function(base_tree, start_line)
14
-
15
-
16
- class FindFunction(ast.NodeVisitor):
17
- def __init__(self, line):
18
- self.line = line
19
- self.node: ast.FunctionDef | None = None
20
-
21
- def visit_FunctionDef(self, node: ast.FunctionDef):
22
- if node.lineno == self.line or (
23
- node.decorator_list and (node.decorator_list[-1].end_lineno <= self.line <= node.lineno)
24
- ):
25
- self.node = node
26
- else:
27
- self.generic_visit(node)
28
-
29
-
30
- def find_function(tree: ast.Module, line: int):
31
- visitor = FindFunction(line)
32
- visitor.visit(tree)
33
- return visitor.node
34
-
35
-
36
- class ScanWrites(ast.NodeVisitor):
37
- def __init__(self):
38
- self.writes = []
39
-
40
- def visit_Name(self, node):
41
- if isinstance(node.ctx, ast.Store | ast.Delete):
42
- self.writes.append(node.id)
43
-
44
-
45
- def scan_writes(node: ast.AST) -> set[str]:
46
- visitor = ScanWrites()
47
- visitor.visit(node)
48
- return set(visitor.writes)
1
+ # ruff: noqa: N802
2
+ import ast
3
+ import inspect
4
+ from collections.abc import Callable
5
+ from functools import cache
6
+ from pathlib import Path
7
+
8
+
9
@cache
def get_function(fn: Callable) -> tuple[str, ast.FunctionDef]:
    """Return ``(source_file, node)``: the file defining ``fn`` and the AST node found at its first line.

    The whole file is parsed, rather than just ``fn``'s own source, so the node's line and column numbers
    refer to the real file. Results are memoized per ``fn``.
    """
    path = inspect.getsourcefile(fn)
    first_line = inspect.getsourcelines(fn)[1]
    module_tree = ast.parse(Path(path).read_text(encoding="utf-8"))
    node = find_function(module_tree, first_line)
    return path, node
16
+
17
+
18
class FindFunction(ast.NodeVisitor):
    """AST visitor that finds the function definition or lambda starting on a given source line.

    ``line`` is expected to come from ``inspect.getsourcelines``. For a decorated function, that is the line
    of its first decorator, not the ``def`` line.
    """

    def __init__(self, line):
        self.line = line
        # The matching node, or None if nothing starting on ``line`` has been found.
        self.node: ast.FunctionDef | ast.Lambda | None = None

    def visit_FunctionDef(self, node: ast.FunctionDef):
        decorators = node.decorator_list
        if (
            node.lineno == self.line
            # Decorated functions are reported at their first decorator. This also covers functions with
            # several decorators or a decorator spanning multiple lines.
            or (decorators and decorators[0].lineno == self.line)
            or (decorators and decorators[-1].end_lineno <= self.line <= node.lineno)
        ):
            self.node = node
        else:
            self.generic_visit(node)

    def visit_Lambda(self, node: ast.Lambda):
        if node.lineno == self.line:
            # Lambdas carry no name, so two on one line cannot be told apart.
            if self.node is not None:
                raise ValueError("Multiple functions defined on the same line are not supported")
            self.node = node
        else:
            self.generic_visit(node)
38
+
39
+
40
def find_function(tree: ast.Module, line: int):
    """Return the node that ``FindFunction`` locates at ``line`` within ``tree``, or ``None`` if there is none."""
    (finder := FindFunction(line)).visit(tree)
    return finder.node
44
+
45
+
46
+ class ScanWrites(ast.NodeVisitor):
47
+ def __init__(self):
48
+ self.writes = []
49
+
50
+ def visit_Name(self, node):
51
+ if isinstance(node.ctx, ast.Store | ast.Delete):
52
+ self.writes.append(node.id)
53
+
54
+
55
def scan_writes(node: ast.AST) -> set[str]:
    """Return the distinct names that ``ScanWrites`` records as written within ``node``."""
    scanner = ScanWrites()
    scanner.visit(node)
    return {*scanner.writes}
+ return set(visitor.writes)