sonolus.py 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sonolus.py might be problematic. Click here for more details.

Files changed (68) hide show
  1. sonolus/backend/allocate.py +125 -51
  2. sonolus/backend/blocks.py +756 -756
  3. sonolus/backend/coalesce.py +85 -0
  4. sonolus/backend/constant_evaluation.py +374 -0
  5. sonolus/backend/dead_code.py +80 -0
  6. sonolus/backend/dominance.py +111 -0
  7. sonolus/backend/excepthook.py +37 -37
  8. sonolus/backend/finalize.py +69 -69
  9. sonolus/backend/flow.py +121 -92
  10. sonolus/backend/inlining.py +150 -0
  11. sonolus/backend/ir.py +5 -3
  12. sonolus/backend/liveness.py +173 -0
  13. sonolus/backend/mode.py +24 -24
  14. sonolus/backend/node.py +40 -40
  15. sonolus/backend/ops.py +197 -197
  16. sonolus/backend/optimize.py +37 -9
  17. sonolus/backend/passes.py +52 -6
  18. sonolus/backend/simplify.py +47 -30
  19. sonolus/backend/ssa.py +187 -0
  20. sonolus/backend/utils.py +48 -48
  21. sonolus/backend/visitor.py +892 -882
  22. sonolus/build/cli.py +7 -1
  23. sonolus/build/compile.py +88 -90
  24. sonolus/build/level.py +24 -23
  25. sonolus/build/node.py +43 -43
  26. sonolus/script/archetype.py +23 -6
  27. sonolus/script/array.py +2 -2
  28. sonolus/script/bucket.py +191 -191
  29. sonolus/script/callbacks.py +127 -127
  30. sonolus/script/comptime.py +1 -1
  31. sonolus/script/containers.py +23 -0
  32. sonolus/script/debug.py +19 -3
  33. sonolus/script/easing.py +323 -0
  34. sonolus/script/effect.py +131 -131
  35. sonolus/script/globals.py +269 -269
  36. sonolus/script/graphics.py +200 -150
  37. sonolus/script/instruction.py +151 -151
  38. sonolus/script/internal/__init__.py +5 -5
  39. sonolus/script/internal/builtin_impls.py +144 -144
  40. sonolus/script/internal/context.py +12 -4
  41. sonolus/script/internal/descriptor.py +17 -17
  42. sonolus/script/internal/introspection.py +14 -14
  43. sonolus/script/internal/native.py +40 -38
  44. sonolus/script/internal/value.py +3 -3
  45. sonolus/script/interval.py +120 -112
  46. sonolus/script/iterator.py +214 -214
  47. sonolus/script/math.py +30 -1
  48. sonolus/script/num.py +1 -1
  49. sonolus/script/options.py +191 -191
  50. sonolus/script/particle.py +157 -157
  51. sonolus/script/pointer.py +30 -30
  52. sonolus/script/print.py +81 -81
  53. sonolus/script/random.py +14 -0
  54. sonolus/script/range.py +58 -58
  55. sonolus/script/record.py +3 -3
  56. sonolus/script/runtime.py +2 -0
  57. sonolus/script/sprite.py +333 -333
  58. sonolus/script/text.py +407 -407
  59. sonolus/script/timing.py +42 -42
  60. sonolus/script/transform.py +77 -23
  61. sonolus/script/ui.py +160 -160
  62. sonolus/script/vec.py +81 -78
  63. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.4.dist-info}/METADATA +1 -1
  64. sonolus_py-0.1.4.dist-info/RECORD +84 -0
  65. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.4.dist-info}/WHEEL +1 -1
  66. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.4.dist-info}/licenses/LICENSE +21 -21
  67. sonolus_py-0.1.3.dist-info/RECORD +0 -75
  68. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,85 @@
1
+ from sonolus.backend.flow import BasicBlock, traverse_cfg_preorder
2
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet
3
+ from sonolus.backend.liveness import LivenessAnalysis, get_live
4
+ from sonolus.backend.passes import CompilerPass
5
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
6
+
7
+
8
+ class CopyCoalesce(CompilerPass):
9
+ def requires(self) -> set[CompilerPass]:
10
+ return {LivenessAnalysis()}
11
+
12
+ def run(self, entry: BasicBlock) -> BasicBlock:
13
+ mapping = self.get_mapping(entry)
14
+ for block in traverse_cfg_preorder(entry):
15
+ block.statements = [self.apply_to_stmt(stmt, mapping) for stmt in block.statements]
16
+ block.test = self.apply_to_stmt(block.test, mapping)
17
+ return entry
18
+
19
+ def apply_to_stmt(self, stmt, mapping: dict[TempBlock, TempBlock]):
20
+ match stmt:
21
+ case IRConst():
22
+ return stmt
23
+ case IRSet(place=place, value=value):
24
+ return IRSet(self.apply_to_stmt(place, mapping), self.apply_to_stmt(value, mapping))
25
+ case IRGet(place=place):
26
+ return IRGet(self.apply_to_stmt(place, mapping))
27
+ case IRPureInstr(op=op, args=args):
28
+ return IRPureInstr(op, [self.apply_to_stmt(arg, mapping) for arg in args])
29
+ case IRInstr(op=op, args=args):
30
+ return IRInstr(op, [self.apply_to_stmt(arg, mapping) for arg in args])
31
+ case BlockPlace(block=block, index=index, offset=offset):
32
+ return BlockPlace(self.apply_to_stmt(block, mapping), self.apply_to_stmt(index, mapping), offset)
33
+ case TempBlock():
34
+ return mapping.get(stmt, stmt)
35
+ case SSAPlace() | int() | float():
36
+ return stmt
37
+
38
+ def get_mapping(self, entry: BasicBlock) -> dict[TempBlock, TempBlock]:
39
+ interference = self.get_interference(entry)
40
+ copies = self.get_copies(entry)
41
+
42
+ mapping = {}
43
+
44
+ for target, sources in copies.items():
45
+ for source in sources:
46
+ if source in mapping:
47
+ continue
48
+ if source in interference.get(target, set()):
49
+ continue
50
+ mapping[source] = mapping.get(target, target)
51
+ combined_interference = interference.get(target, set()) | interference.get(source, set())
52
+ interference[source] = combined_interference
53
+ interference[target] = combined_interference
54
+
55
+ return mapping
56
+
57
+ def get_interference(self, entry: BasicBlock) -> dict[TempBlock, set[TempBlock]]:
58
+ result = {}
59
+ for block in traverse_cfg_preorder(entry):
60
+ for stmt in [*block.statements, block.test]:
61
+ live = {p for p in get_live(stmt) if isinstance(p, TempBlock) and p.size == 1}
62
+ for place in live:
63
+ result.setdefault(place, set()).update(live - {place})
64
+ return result
65
+
66
+ def get_copies(self, entry: BasicBlock) -> dict[TempBlock, set[TempBlock]]:
67
+ result = {}
68
+ for block in traverse_cfg_preorder(entry):
69
+ for stmt in block.statements:
70
+ if (
71
+ not isinstance(stmt, IRSet)
72
+ or not isinstance(stmt.place, BlockPlace)
73
+ or not isinstance(stmt.place.block, TempBlock)
74
+ or stmt.place.block.size != 1
75
+ or not isinstance(stmt.value, IRGet)
76
+ or not isinstance(stmt.value.place, BlockPlace)
77
+ or not isinstance(stmt.value.place.block, TempBlock)
78
+ or stmt.value.place.block.size != 1
79
+ ):
80
+ continue
81
+ target = stmt.place.block
82
+ source = stmt.value.place.block
83
+ result.setdefault(target, set()).add(source)
84
+ result.setdefault(source, set()).add(target)
85
+ return result
@@ -0,0 +1,374 @@
1
+ # ruff: noqa: PLR1702
2
+ import functools
3
+ import math
4
+ import operator
5
+ from typing import ClassVar
6
+
7
+ import sonolus.script.math as smath
8
+ from sonolus.backend.flow import BasicBlock, FlowEdge, traverse_cfg_preorder
9
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
10
+ from sonolus.backend.ops import Op
11
+ from sonolus.backend.passes import CompilerPass
12
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
13
+
14
+
15
+ class Undefined:
16
+ pass
17
+
18
+
19
+ class NotAConstant:
20
+ pass
21
+
22
+
23
+ UNDEF = Undefined()
24
+ NAC = NotAConstant()
25
+
26
+
27
+ type Value = float | Undefined | NotAConstant
28
+
29
+
30
+ class SparseConditionalConstantPropagation(CompilerPass):
31
+ SUPPORTED_OPS: ClassVar[set[Op]] = {
32
+ Op.Equal,
33
+ Op.NotEqual,
34
+ Op.Greater,
35
+ Op.GreaterOr,
36
+ Op.Less,
37
+ Op.LessOr,
38
+ Op.Not,
39
+ Op.And,
40
+ Op.Or,
41
+ Op.Negate,
42
+ Op.Add,
43
+ Op.Subtract,
44
+ Op.Multiply,
45
+ Op.Divide,
46
+ Op.Power,
47
+ Op.Log,
48
+ Op.Ceil,
49
+ Op.Floor,
50
+ Op.Round,
51
+ Op.Frac,
52
+ Op.Mod,
53
+ Op.Rem,
54
+ Op.Sin,
55
+ Op.Cos,
56
+ Op.Tan,
57
+ Op.Sinh,
58
+ Op.Cosh,
59
+ Op.Tanh,
60
+ Op.Arcsin,
61
+ Op.Arccos,
62
+ Op.Arctan,
63
+ Op.Arctan2,
64
+ }
65
+
66
+ def run(self, entry: BasicBlock) -> BasicBlock:
67
+ ssa_edges: dict[SSAPlace, set[SSAPlace | BasicBlock]] = {}
68
+ executable_edges: set[FlowEdge] = set()
69
+
70
+ # BasicBlock key means the block's test
71
+ values: dict[SSAPlace | BasicBlock, Value] = {}
72
+ defs: dict[SSAPlace | BasicBlock, IRStmt | dict[FlowEdge, SSAPlace]] = {}
73
+ places_to_blocks: dict[SSAPlace, BasicBlock] = {}
74
+ reachable_blocks: set[BasicBlock] = set()
75
+
76
+ for block in traverse_cfg_preorder(entry):
77
+ incoming_by_src = {edge.src: edge for edge in block.incoming}
78
+ for p, args in block.phis.items():
79
+ if not isinstance(p, SSAPlace):
80
+ continue
81
+ defs[p] = {incoming_by_src[b]: v for b, v in args.items()}
82
+ values[p] = UNDEF
83
+ for arg in args.values():
84
+ ssa_edges.setdefault(arg, set()).add(p)
85
+ for stmt in block.statements:
86
+ if isinstance(stmt, IRSet) and isinstance(stmt.place, SSAPlace):
87
+ defs[stmt.place] = stmt.value
88
+ places_to_blocks[stmt.place] = block
89
+ values[stmt.place] = UNDEF
90
+ for dep in self.get_dependencies(stmt.value, set()):
91
+ ssa_edges.setdefault(dep, set()).add(stmt.place)
92
+ defs[block] = block.test
93
+ values[block] = UNDEF
94
+ for dep in self.get_dependencies(block.test, set()):
95
+ ssa_edges.setdefault(dep, set()).add(block)
96
+
97
+ def visit_phi(p):
98
+ arg_values = [values[v] if b in executable_edges else UNDEF for b, v in defs[p].items()]
99
+ distinct_defined_arg_values = {arg for arg in arg_values if arg is not UNDEF}
100
+ value = values[p]
101
+ if len(distinct_defined_arg_values) == 1:
102
+ new_value = distinct_defined_arg_values.pop()
103
+ elif len(distinct_defined_arg_values) > 1:
104
+ new_value = NAC
105
+ else:
106
+ new_value = UNDEF
107
+ if new_value != value:
108
+ values[p] = new_value
109
+ ssa_worklist.update(ssa_edges.get(p, set()))
110
+
111
+ flow_worklist: set[FlowEdge] = {FlowEdge(entry, entry, None)}
112
+ ssa_worklist: set[SSAPlace | BasicBlock] = set()
113
+ while flow_worklist or ssa_worklist:
114
+ while flow_worklist:
115
+ edge = flow_worklist.pop()
116
+ if edge in executable_edges:
117
+ continue
118
+ executable_edges.add(edge)
119
+ block: BasicBlock = edge.dst
120
+ for p in block.phis:
121
+ visit_phi(p)
122
+ is_first_visit = sum(edge in executable_edges for edge in block.incoming) <= 1
123
+ if is_first_visit:
124
+ for stmt in block.statements:
125
+ if not (isinstance(stmt, IRSet) and isinstance(stmt.place, SSAPlace)):
126
+ continue
127
+ value = values[stmt.place]
128
+ new_value = self.evaluate_stmt(stmt.value, values)
129
+ if new_value != value:
130
+ values[stmt.place] = new_value
131
+ ssa_worklist.update(ssa_edges.get(stmt.place, set()))
132
+ test_value = values[block]
133
+ new_test_value = self.evaluate_stmt(block.test, values)
134
+ if new_test_value != test_value:
135
+ assert new_test_value is not UNDEF
136
+ values[block] = new_test_value
137
+ if new_test_value is NAC:
138
+ flow_worklist.update(block.outgoing)
139
+ reachable_blocks.update(e.dst for e in block.outgoing)
140
+ else:
141
+ taken_edge = next(
142
+ (edge for edge in block.outgoing if edge.cond == new_test_value), None
143
+ ) or next((edge for edge in block.outgoing if edge.cond is None), None)
144
+ if taken_edge:
145
+ flow_worklist.add(taken_edge)
146
+ reachable_blocks.add(taken_edge.dst)
147
+ elif len(block.outgoing) == 1 and next(iter(block.outgoing)).cond is None:
148
+ flow_worklist.update(block.outgoing)
149
+ reachable_blocks.update(e.dst for e in block.outgoing)
150
+ while ssa_worklist:
151
+ p = ssa_worklist.pop()
152
+ defn = defs[p]
153
+ if isinstance(defn, dict):
154
+ # This is a phi
155
+ visit_phi(p)
156
+ elif isinstance(p, BasicBlock):
157
+ # This is the block's test
158
+ test_value = values[p]
159
+ new_test_value = self.evaluate_stmt(defn, values)
160
+ if new_test_value != test_value:
161
+ assert new_test_value is not UNDEF
162
+ values[p] = new_test_value
163
+ if new_test_value is NAC:
164
+ flow_worklist.update(p.outgoing)
165
+ reachable_blocks.update(e.dst for e in p.outgoing)
166
+ else:
167
+ taken_edge = next(edge for edge in p.outgoing if edge.cond == new_test_value) or next(
168
+ (edge for edge in p.outgoing if edge.cond is None), None
169
+ )
170
+ if taken_edge:
171
+ flow_worklist.add(taken_edge)
172
+ reachable_blocks.add(taken_edge.dst)
173
+ else:
174
+ # This is a regular SSA assignment
175
+ if places_to_blocks[p] not in reachable_blocks:
176
+ continue
177
+ value = values[p]
178
+ new_value = self.evaluate_stmt(defn, values)
179
+ if new_value != value:
180
+ values[p] = new_value
181
+ ssa_worklist.update(ssa_edges.get(p, set()))
182
+
183
+ for block in traverse_cfg_preorder(entry):
184
+ block.statements = [self.substitute_constants(stmt, values) for stmt in block.statements]
185
+ block.test = self.substitute_constants(block.test, values)
186
+
187
+ return entry
188
+
189
+ def get_dependencies(self, stmt, dependencies: set[SSAPlace]):
190
+ match stmt:
191
+ case IRConst():
192
+ pass
193
+ case IRPureInstr(op=_, args=args) | IRInstr(op=_, args=args):
194
+ for arg in args:
195
+ self.get_dependencies(arg, dependencies)
196
+ case IRGet(place=SSAPlace() as place):
197
+ dependencies.add(place)
198
+ case IRGet(place=BlockPlace() as place):
199
+ self.get_dependencies(place.block, dependencies)
200
+ case BlockPlace(block=block, index=index, offset=_):
201
+ self.get_dependencies(block, dependencies)
202
+ self.get_dependencies(index, dependencies)
203
+ case SSAPlace():
204
+ dependencies.add(stmt)
205
+ case int() | float() | TempBlock():
206
+ pass
207
+ case _:
208
+ raise TypeError(f"Unexpected statement: {stmt}")
209
+ return dependencies
210
+
211
+ def substitute_constants(self, stmt, values: dict[SSAPlace, Value]):
212
+ match stmt:
213
+ case IRConst():
214
+ return stmt
215
+ case IRPureInstr(op=op, args=args):
216
+ return IRPureInstr(op=op, args=[self.substitute_constants(arg, values) for arg in args])
217
+ case IRInstr(op=op, args=args):
218
+ return IRInstr(op=op, args=[self.substitute_constants(arg, values) for arg in args])
219
+ case IRGet(place=SSAPlace() as place):
220
+ value = values[place]
221
+ if isinstance(value, int | float):
222
+ return IRConst(value)
223
+ return stmt
224
+ case IRGet(place=place):
225
+ return IRGet(place=self.substitute_constants(place, values))
226
+ case IRSet(place=SSAPlace() as place, value=value):
227
+ return IRSet(place=place, value=self.substitute_constants(value, values))
228
+ case IRSet(place=place, value=value):
229
+ return IRSet(
230
+ place=self.substitute_constants(place, values), value=self.substitute_constants(value, values)
231
+ )
232
+ case BlockPlace(block=block, index=index, offset=offset):
233
+ return BlockPlace(
234
+ block=self.substitute_constants(block, values),
235
+ index=self.substitute_constants(index, values),
236
+ offset=offset,
237
+ )
238
+ case SSAPlace():
239
+ value = values[stmt]
240
+ if isinstance(value, int | float):
241
+ return IRConst(value)
242
+ return stmt
243
+ case int() | float() | TempBlock():
244
+ return stmt
245
+ case _:
246
+ raise TypeError(f"Unexpected statement: {stmt}")
247
+
248
+ def evaluate_stmt(self, stmt, values: dict[SSAPlace, Value]) -> Value:
249
+ match stmt:
250
+ case IRConst(value=value):
251
+ return value
252
+ case IRPureInstr(op=op, args=args) | IRInstr(op=op, args=args):
253
+ if op not in self.SUPPORTED_OPS:
254
+ return NAC
255
+ args = [self.evaluate_stmt(arg, values) for arg in args]
256
+ match op:
257
+ case Op.And:
258
+ if any(arg == 0 for arg in args):
259
+ return 0
260
+ case Op.Or:
261
+ if any(arg == 1 for arg in args):
262
+ return 1
263
+ case Op.Multiply:
264
+ if any(arg == 0 for arg in args):
265
+ return 0
266
+ if any(arg is NAC for arg in args):
267
+ return NAC
268
+ if any(arg is UNDEF for arg in args):
269
+ return UNDEF
270
+ match op:
271
+ case Op.Equal:
272
+ assert len(args) == 2
273
+ return args[0] == args[1]
274
+ case Op.NotEqual:
275
+ assert len(args) == 2
276
+ return args[0] != args[1]
277
+ case Op.Greater:
278
+ assert len(args) == 2
279
+ return args[0] > args[1]
280
+ case Op.GreaterOr:
281
+ assert len(args) == 2
282
+ return args[0] >= args[1]
283
+ case Op.Less:
284
+ assert len(args) == 2
285
+ return args[0] < args[1]
286
+ case Op.LessOr:
287
+ assert len(args) == 2
288
+ return args[0] <= args[1]
289
+ case Op.Not:
290
+ assert len(args) == 1
291
+ return int(not args[0])
292
+ case Op.And:
293
+ return all(args)
294
+ case Op.Or:
295
+ return any(args)
296
+ case Op.Negate:
297
+ assert len(args) == 1
298
+ return -args[0]
299
+ case Op.Add:
300
+ return sum(args)
301
+ case Op.Subtract:
302
+ if len(args) == 0:
303
+ return 0
304
+ return args[0] - sum(args[1:])
305
+ case Op.Multiply:
306
+ if len(args) == 0:
307
+ return 1
308
+ return functools.reduce(operator.mul, args, 1)
309
+ case Op.Divide:
310
+ if len(args) == 0:
311
+ return 1
312
+ return args[0] / functools.reduce(operator.mul, args[1:], 1)
313
+ case Op.Power:
314
+ if len(args) == 0:
315
+ return 1
316
+ return functools.reduce(operator.pow, args)
317
+ case Op.Log:
318
+ assert len(args) == 2
319
+ return math.log(args[0], args[1])
320
+ case Op.Ceil:
321
+ assert len(args) == 1
322
+ return math.ceil(args[0])
323
+ case Op.Floor:
324
+ assert len(args) == 1
325
+ return math.floor(args[0])
326
+ case Op.Round:
327
+ assert len(args) == 1
328
+ # This is round half to even in both Python and Sonolus
329
+ return round(args[0])
330
+ case Op.Frac:
331
+ assert len(args) == 1
332
+ return smath.frac(args[0])
333
+ case Op.Mod:
334
+ assert len(args) == 2
335
+ return args[0] % args[1]
336
+ case Op.Rem:
337
+ assert len(args) == 2
338
+ return smath.remainder(args[0], args[1])
339
+ case Op.Sin:
340
+ assert len(args) == 1
341
+ return math.sin(args[0])
342
+ case Op.Cos:
343
+ assert len(args) == 1
344
+ return math.cos(args[0])
345
+ case Op.Tan:
346
+ assert len(args) == 1
347
+ return math.tan(args[0])
348
+ case Op.Sinh:
349
+ assert len(args) == 1
350
+ return math.sinh(args[0])
351
+ case Op.Cosh:
352
+ assert len(args) == 1
353
+ return math.cosh(args[0])
354
+ case Op.Tanh:
355
+ assert len(args) == 1
356
+ return math.tanh(args[0])
357
+ case Op.Arcsin:
358
+ assert len(args) == 1
359
+ return math.asin(args[0])
360
+ case Op.Arccos:
361
+ assert len(args) == 1
362
+ return math.acos(args[0])
363
+ case Op.Arctan:
364
+ assert len(args) == 1
365
+ return math.atan(args[0])
366
+ case Op.Arctan2:
367
+ assert len(args) == 2
368
+ return math.atan2(args[0], args[1])
369
+ case IRGet(place=SSAPlace() as place):
370
+ return values[place]
371
+ case IRGet():
372
+ return NAC
373
+ case IRSet() | _:
374
+ raise TypeError(f"Unexpected statement: {stmt}")
@@ -0,0 +1,80 @@
1
+ from sonolus.backend.flow import BasicBlock, traverse_cfg_preorder
2
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRSet
3
+ from sonolus.backend.liveness import LivenessAnalysis, get_live, get_live_phi_targets
4
+ from sonolus.backend.passes import CompilerPass
5
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
6
+
7
+
8
+ class UnreachableCodeElimination(CompilerPass):
9
+ def run(self, entry: BasicBlock) -> BasicBlock:
10
+ original_blocks = [*traverse_cfg_preorder(entry)]
11
+ worklist = {entry}
12
+ visited = set()
13
+ while worklist:
14
+ block = worklist.pop()
15
+ if block in visited:
16
+ continue
17
+ visited.add(block)
18
+ match block.test:
19
+ case IRConst(value=value):
20
+ block.test = IRConst(0)
21
+ taken_edge = next(
22
+ (edge for edge in block.outgoing if edge.cond == value),
23
+ None,
24
+ ) or next((edge for edge in block.outgoing if edge.cond is None), None)
25
+ assert not block.outgoing or taken_edge
26
+ for edge in [*block.outgoing]:
27
+ if edge is not taken_edge:
28
+ edge.dst.incoming.remove(edge)
29
+ block.outgoing.remove(edge)
30
+ if taken_edge:
31
+ taken_edge.cond = None
32
+ block.outgoing.add(taken_edge)
33
+ worklist.add(taken_edge.dst)
34
+ case _:
35
+ worklist.update(edge.dst for edge in block.outgoing)
36
+ for block in original_blocks:
37
+ if block not in visited:
38
+ for edge in block.outgoing:
39
+ edge.dst.incoming.remove(edge)
40
+ else:
41
+ for args in block.phis.values():
42
+ for src_block in [*args]:
43
+ if src_block not in visited:
44
+ args.pop(src_block)
45
+ return entry
46
+
47
+
48
class DeadCodeElimination(CompilerPass):
    """Remove assignments whose target is not live and phis whose target is not live."""

    def requires(self) -> set[CompilerPass]:
        """Return the passes that must have run first."""
        return {LivenessAnalysis()}

    def preserves(self) -> set[CompilerPass] | None:
        """Return the analyses that remain valid after this pass."""
        return {LivenessAnalysis()}

    def run(self, entry: BasicBlock) -> BasicBlock:
        """Filter the statements and phis of every block reachable from ``entry``.

        An ``IRSet`` is dropped when it writes an ``SSAPlace`` or a ``TempBlock`` that is not in
        the statement's live set, or when it copies a place onto itself.  If the dropped
        assignment's value is an ``IRInstr`` whose op has side effects, the bare instruction is
        kept in its place.

        Returns:
            ``entry``.
        """
        for block in traverse_cfg_preorder(entry):
            kept = []
            for statement in block.statements:
                live = get_live(statement)
                if not isinstance(statement, IRSet):
                    kept.append(statement)
                    continue
                place = statement.place
                value = statement.value
                if (
                    (isinstance(place, SSAPlace) and place not in live)
                    or (
                        isinstance(place, BlockPlace)
                        and isinstance(place.block, TempBlock)
                        and place.block not in live
                    )
                    or (isinstance(value, IRGet) and place == value.place)
                ):
                    # The store is dead; keep the computation only for its side effects.
                    if isinstance(value, IRInstr) and value.op.side_effects:
                        kept.append(value)
                        # The instruction is now a statement itself and takes over the live set.
                        value.live = live
                else:
                    kept.append(statement)
            block.statements = kept
            block.phis = {
                target: args for target, args in block.phis.items() if target in get_live_phi_targets(block)
            }
        return entry
@@ -0,0 +1,111 @@
1
+ from sonolus.backend.flow import (
2
+ BasicBlock,
3
+ traverse_cfg_reverse_postorder,
4
+ )
5
+ from sonolus.backend.passes import CompilerPass
6
+
7
+ # traverse_cfg_preorder(entry: BasicBlock) -> Iterator[BasicBlock]
8
+ # traverse_cfg_postorder(entry: BasicBlock) -> Iterator[BasicBlock]
9
+ # traverse_cfg_reverse_postorder(entry: BasicBlock) -> Iterator[BasicBlock]
10
+
11
+ # class BasicBlock:
12
+ # phis: dict[SSAPlace, dict[Self, SSAPlace]]
13
+ # statements: list[IRStmt]
14
+ # test: IRExpr
15
+ # incoming: set[FlowEdge]
16
+ # outgoing: set[FlowEdge]
17
+
18
+ # class FlowEdge:
19
+ # src: "BasicBlock"
20
+ # dst: "BasicBlock"
21
+
22
+
23
class DominanceFrontiers(CompilerPass):
    """Annotate each block reachable from the entry with dominator information.

    After ``run`` every such block carries ``num`` (reverse post-order index), ``idom``
    (immediate dominator; the entry block is its own), ``dom_children`` and ``df``.
    """

    def destroys(self) -> set[CompilerPass] | None:
        """Return the analyses invalidated by this pass (none)."""
        return set()

    def run(self, entry: BasicBlock) -> BasicBlock:
        """Compute all annotations for the CFG rooted at ``entry`` and return ``entry``."""
        ordered = list(traverse_cfg_reverse_postorder(entry))

        self.number_blocks(ordered)
        self.initialize_idoms(ordered, entry)
        self.compute_idoms(ordered)
        self.build_dominator_tree(ordered)
        self.compute_dominance_frontiers(ordered)

        return entry

    def number_blocks(self, blocks: list[BasicBlock]):
        """Store each block's position in ``blocks`` as ``block.num``."""
        for position, blk in enumerate(blocks):
            blk.num = position

    def initialize_idoms(self, blocks: list[BasicBlock], entry_block: BasicBlock):
        """Reset ``idom`` to ``None`` on all blocks, then make the entry block its own idom."""
        for blk in blocks:
            blk.idom = None
        entry_block.idom = entry_block

    def compute_idoms(self, blocks: list[BasicBlock]):
        """Repeat the idom update over all non-entry blocks until no ``idom`` changes."""
        while True:
            changed = False
            for b in blocks[1:]:  # the first block is the entry and keeps its idom
                candidate = None
                for edge in b.incoming:
                    pred = edge.src
                    if pred.idom is None:
                        # This predecessor has not been processed yet.
                        continue
                    candidate = pred if candidate is None else self.intersect(pred, candidate)
                if b.idom != candidate:
                    b.idom = candidate
                    changed = True
            if not changed:
                break

    def build_dominator_tree(self, blocks: list[BasicBlock]):
        """Fill ``dom_children`` so that every block is listed under its immediate dominator."""
        for blk in blocks:
            blk.dom_children = []

        for blk in blocks:
            parent = blk.idom
            if parent != blk:  # the entry block is its own idom and has no parent
                parent.dom_children.append(blk)

    def compute_dominance_frontiers(self, blocks: list[BasicBlock]):
        """Fill ``df`` on every block by walking up from the predecessors of each join block."""
        for blk in blocks:
            blk.df = set()

        for join in blocks:
            if len(join.incoming) < 2:
                continue
            for edge in join.incoming:
                runner = edge.src
                # Every block from the predecessor up to, but excluding, the join's idom
                # has the join in its frontier.
                while runner != join.idom:
                    runner.df.add(join)
                    runner = runner.idom

    def intersect(self, b1: BasicBlock, b2: BasicBlock) -> BasicBlock:
        """Walk both blocks up their ``idom`` chains until they meet and return the meeting block."""
        while b1 != b2:
            while b1.num > b2.num:
                b1 = b1.idom
            while b2.num > b1.num:
                b2 = b2.idom
        return b1

    def __eq__(self, other):
        """All instances of this pass compare equal."""
        return isinstance(other, DominanceFrontiers)

    def __hash__(self):
        """Hash by class so that equal instances collapse inside sets of passes."""
        return hash(DominanceFrontiers)
104
+
105
+
106
def get_df(block: BasicBlock) -> set[BasicBlock]:
    """Return the dominance frontier stored in the ``df`` attribute of ``block``."""
    frontier = block.df
    return frontier
108
+
109
+
110
def get_dom_children(block: BasicBlock) -> list[BasicBlock]:
    """Return the dominator-tree children stored in the ``dom_children`` attribute of ``block``."""
    children = block.dom_children
    return children