sonolus.py 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sonolus.py might be problematic. Click here for more details.

Files changed (90)
  1. sonolus/backend/blocks.py +756 -756
  2. sonolus/backend/excepthook.py +37 -37
  3. sonolus/backend/finalize.py +77 -69
  4. sonolus/backend/interpret.py +7 -7
  5. sonolus/backend/ir.py +29 -3
  6. sonolus/backend/mode.py +24 -24
  7. sonolus/backend/node.py +40 -40
  8. sonolus/backend/ops.py +197 -197
  9. sonolus/backend/optimize/__init__.py +0 -0
  10. sonolus/backend/optimize/allocate.py +126 -0
  11. sonolus/backend/optimize/constant_evaluation.py +374 -0
  12. sonolus/backend/optimize/copy_coalesce.py +85 -0
  13. sonolus/backend/optimize/dead_code.py +185 -0
  14. sonolus/backend/optimize/dominance.py +96 -0
  15. sonolus/backend/{flow.py → optimize/flow.py} +122 -92
  16. sonolus/backend/optimize/inlining.py +137 -0
  17. sonolus/backend/optimize/liveness.py +177 -0
  18. sonolus/backend/optimize/optimize.py +44 -0
  19. sonolus/backend/optimize/passes.py +52 -0
  20. sonolus/backend/optimize/simplify.py +191 -0
  21. sonolus/backend/optimize/ssa.py +200 -0
  22. sonolus/backend/place.py +17 -25
  23. sonolus/backend/utils.py +58 -48
  24. sonolus/backend/visitor.py +1151 -882
  25. sonolus/build/cli.py +7 -1
  26. sonolus/build/compile.py +88 -90
  27. sonolus/build/engine.py +10 -5
  28. sonolus/build/level.py +24 -23
  29. sonolus/build/node.py +43 -43
  30. sonolus/script/archetype.py +438 -139
  31. sonolus/script/array.py +27 -10
  32. sonolus/script/array_like.py +297 -0
  33. sonolus/script/bucket.py +253 -191
  34. sonolus/script/containers.py +257 -51
  35. sonolus/script/debug.py +26 -10
  36. sonolus/script/easing.py +365 -0
  37. sonolus/script/effect.py +191 -131
  38. sonolus/script/engine.py +71 -4
  39. sonolus/script/globals.py +303 -269
  40. sonolus/script/instruction.py +205 -151
  41. sonolus/script/internal/__init__.py +5 -5
  42. sonolus/script/internal/builtin_impls.py +255 -144
  43. sonolus/script/{callbacks.py → internal/callbacks.py} +127 -127
  44. sonolus/script/internal/constant.py +139 -0
  45. sonolus/script/internal/context.py +26 -9
  46. sonolus/script/internal/descriptor.py +17 -17
  47. sonolus/script/internal/dict_impl.py +65 -0
  48. sonolus/script/internal/generic.py +6 -9
  49. sonolus/script/internal/impl.py +38 -13
  50. sonolus/script/internal/introspection.py +17 -14
  51. sonolus/script/internal/math_impls.py +121 -0
  52. sonolus/script/internal/native.py +40 -38
  53. sonolus/script/internal/random.py +67 -0
  54. sonolus/script/internal/range.py +81 -0
  55. sonolus/script/internal/transient.py +51 -0
  56. sonolus/script/internal/tuple_impl.py +113 -0
  57. sonolus/script/internal/value.py +3 -3
  58. sonolus/script/interval.py +338 -112
  59. sonolus/script/iterator.py +167 -214
  60. sonolus/script/level.py +24 -0
  61. sonolus/script/num.py +80 -48
  62. sonolus/script/options.py +257 -191
  63. sonolus/script/particle.py +190 -157
  64. sonolus/script/pointer.py +30 -30
  65. sonolus/script/print.py +102 -81
  66. sonolus/script/project.py +8 -0
  67. sonolus/script/quad.py +263 -0
  68. sonolus/script/record.py +47 -16
  69. sonolus/script/runtime.py +52 -1
  70. sonolus/script/sprite.py +418 -333
  71. sonolus/script/text.py +409 -407
  72. sonolus/script/timing.py +114 -42
  73. sonolus/script/transform.py +332 -48
  74. sonolus/script/ui.py +216 -160
  75. sonolus/script/values.py +6 -13
  76. sonolus/script/vec.py +196 -78
  77. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/METADATA +1 -1
  78. sonolus_py-0.1.5.dist-info/RECORD +89 -0
  79. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/WHEEL +1 -1
  80. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/licenses/LICENSE +21 -21
  81. sonolus/backend/allocate.py +0 -51
  82. sonolus/backend/optimize.py +0 -9
  83. sonolus/backend/passes.py +0 -6
  84. sonolus/backend/simplify.py +0 -30
  85. sonolus/script/comptime.py +0 -160
  86. sonolus/script/graphics.py +0 -150
  87. sonolus/script/math.py +0 -92
  88. sonolus/script/range.py +0 -58
  89. sonolus_py-0.1.3.dist-info/RECORD +0 -75
  90. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,374 @@
1
+ # ruff: noqa: PLR1702
2
+ import functools
3
+ import math
4
+ import operator
5
+ from typing import ClassVar
6
+
7
+ import sonolus.script.internal.math_impls as smath
8
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
9
+ from sonolus.backend.ops import Op
10
+ from sonolus.backend.optimize.flow import BasicBlock, FlowEdge, traverse_cfg_preorder
11
+ from sonolus.backend.optimize.passes import CompilerPass
12
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
13
+
14
+
15
class Undefined:
    """Lattice value for a place about which nothing is known yet (its initial state)."""
17
+
18
+
19
class NotAConstant:
    """Lattice value for a place that can hold more than one value at runtime."""
21
+
22
+
23
+ UNDEF = Undefined()
24
+ NAC = NotAConstant()
25
+
26
+
27
+ type Value = float | Undefined | NotAConstant
28
+
29
+
30
+ class SparseConditionalConstantPropagation(CompilerPass):
31
+ SUPPORTED_OPS: ClassVar[set[Op]] = {
32
+ Op.Equal,
33
+ Op.NotEqual,
34
+ Op.Greater,
35
+ Op.GreaterOr,
36
+ Op.Less,
37
+ Op.LessOr,
38
+ Op.Not,
39
+ Op.And,
40
+ Op.Or,
41
+ Op.Negate,
42
+ Op.Add,
43
+ Op.Subtract,
44
+ Op.Multiply,
45
+ Op.Divide,
46
+ Op.Power,
47
+ Op.Log,
48
+ Op.Ceil,
49
+ Op.Floor,
50
+ Op.Round,
51
+ Op.Frac,
52
+ Op.Mod,
53
+ Op.Rem,
54
+ Op.Sin,
55
+ Op.Cos,
56
+ Op.Tan,
57
+ Op.Sinh,
58
+ Op.Cosh,
59
+ Op.Tanh,
60
+ Op.Arcsin,
61
+ Op.Arccos,
62
+ Op.Arctan,
63
+ Op.Arctan2,
64
+ }
65
+
66
+ def run(self, entry: BasicBlock) -> BasicBlock:
67
+ ssa_edges: dict[SSAPlace, set[SSAPlace | BasicBlock]] = {}
68
+ executable_edges: set[FlowEdge] = set()
69
+
70
+ # BasicBlock key means the block's test
71
+ values: dict[SSAPlace | BasicBlock, Value] = {SSAPlace("err", 0): UNDEF}
72
+ defs: dict[SSAPlace | BasicBlock, IRStmt | dict[FlowEdge, SSAPlace]] = {}
73
+ places_to_blocks: dict[SSAPlace, BasicBlock] = {}
74
+ reachable_blocks: set[BasicBlock] = set()
75
+
76
+ for block in traverse_cfg_preorder(entry):
77
+ incoming_by_src = {edge.src: edge for edge in block.incoming}
78
+ for p, args in block.phis.items():
79
+ if not isinstance(p, SSAPlace):
80
+ continue
81
+ defs[p] = {incoming_by_src[b]: v for b, v in args.items()}
82
+ values[p] = UNDEF
83
+ for arg in args.values():
84
+ ssa_edges.setdefault(arg, set()).add(p)
85
+ for stmt in block.statements:
86
+ if isinstance(stmt, IRSet) and isinstance(stmt.place, SSAPlace):
87
+ defs[stmt.place] = stmt.value
88
+ places_to_blocks[stmt.place] = block
89
+ values[stmt.place] = UNDEF
90
+ for dep in self.get_dependencies(stmt.value, set()):
91
+ ssa_edges.setdefault(dep, set()).add(stmt.place)
92
+ defs[block] = block.test
93
+ values[block] = UNDEF
94
+ for dep in self.get_dependencies(block.test, set()):
95
+ ssa_edges.setdefault(dep, set()).add(block)
96
+
97
+ def visit_phi(p):
98
+ arg_values = [values[v] if b in executable_edges else UNDEF for b, v in defs[p].items()]
99
+ distinct_defined_arg_values = {arg for arg in arg_values if arg is not UNDEF}
100
+ value = values[p]
101
+ if len(distinct_defined_arg_values) == 1:
102
+ new_value = distinct_defined_arg_values.pop()
103
+ elif len(distinct_defined_arg_values) > 1:
104
+ new_value = NAC
105
+ else:
106
+ new_value = UNDEF
107
+ if new_value != value:
108
+ values[p] = new_value
109
+ ssa_worklist.update(ssa_edges.get(p, set()))
110
+
111
+ flow_worklist: set[FlowEdge] = {FlowEdge(entry, entry, None)}
112
+ ssa_worklist: set[SSAPlace | BasicBlock] = set()
113
+ while flow_worklist or ssa_worklist:
114
+ while flow_worklist:
115
+ edge = flow_worklist.pop()
116
+ if edge in executable_edges:
117
+ continue
118
+ executable_edges.add(edge)
119
+ block: BasicBlock = edge.dst
120
+ for p in block.phis:
121
+ visit_phi(p)
122
+ is_first_visit = sum(edge in executable_edges for edge in block.incoming) <= 1
123
+ if is_first_visit:
124
+ for stmt in block.statements:
125
+ if not (isinstance(stmt, IRSet) and isinstance(stmt.place, SSAPlace)):
126
+ continue
127
+ value = values[stmt.place]
128
+ new_value = self.evaluate_stmt(stmt.value, values)
129
+ if new_value != value:
130
+ values[stmt.place] = new_value
131
+ ssa_worklist.update(ssa_edges.get(stmt.place, set()))
132
+ test_value = values[block]
133
+ new_test_value = self.evaluate_stmt(block.test, values)
134
+ if new_test_value != test_value:
135
+ assert new_test_value is not UNDEF
136
+ values[block] = new_test_value
137
+ if new_test_value is NAC:
138
+ flow_worklist.update(block.outgoing)
139
+ reachable_blocks.update(e.dst for e in block.outgoing)
140
+ else:
141
+ taken_edge = next(
142
+ (edge for edge in block.outgoing if edge.cond == new_test_value), None
143
+ ) or next((edge for edge in block.outgoing if edge.cond is None), None)
144
+ if taken_edge:
145
+ flow_worklist.add(taken_edge)
146
+ reachable_blocks.add(taken_edge.dst)
147
+ elif len(block.outgoing) == 1 and next(iter(block.outgoing)).cond is None:
148
+ flow_worklist.update(block.outgoing)
149
+ reachable_blocks.update(e.dst for e in block.outgoing)
150
+ while ssa_worklist:
151
+ p = ssa_worklist.pop()
152
+ defn = defs[p]
153
+ if isinstance(defn, dict):
154
+ # This is a phi
155
+ visit_phi(p)
156
+ elif isinstance(p, BasicBlock):
157
+ # This is the block's test
158
+ test_value = values[p]
159
+ new_test_value = self.evaluate_stmt(defn, values)
160
+ if new_test_value != test_value:
161
+ assert new_test_value is not UNDEF
162
+ values[p] = new_test_value
163
+ if new_test_value is NAC:
164
+ flow_worklist.update(p.outgoing)
165
+ reachable_blocks.update(e.dst for e in p.outgoing)
166
+ else:
167
+ taken_edge = next(
168
+ (edge for edge in p.outgoing if edge.cond == new_test_value), None
169
+ ) or next((edge for edge in p.outgoing if edge.cond is None), None)
170
+ if taken_edge:
171
+ flow_worklist.add(taken_edge)
172
+ reachable_blocks.add(taken_edge.dst)
173
+ else:
174
+ # This is a regular SSA assignment
175
+ if places_to_blocks[p] not in reachable_blocks:
176
+ continue
177
+ value = values[p]
178
+ new_value = self.evaluate_stmt(defn, values)
179
+ if new_value != value:
180
+ values[p] = new_value
181
+ ssa_worklist.update(ssa_edges.get(p, set()))
182
+
183
+ for block in traverse_cfg_preorder(entry):
184
+ block.statements = [self.substitute_constants(stmt, values) for stmt in block.statements]
185
+ block.test = self.substitute_constants(block.test, values)
186
+
187
+ return entry
188
+
189
+ def get_dependencies(self, stmt, dependencies: set[SSAPlace]):
190
+ match stmt:
191
+ case IRConst():
192
+ pass
193
+ case IRPureInstr(op=_, args=args) | IRInstr(op=_, args=args):
194
+ for arg in args:
195
+ self.get_dependencies(arg, dependencies)
196
+ case IRGet(place=SSAPlace() as place):
197
+ dependencies.add(place)
198
+ case IRGet(place=BlockPlace() as place):
199
+ self.get_dependencies(place.block, dependencies)
200
+ case BlockPlace(block=block, index=index, offset=_):
201
+ self.get_dependencies(block, dependencies)
202
+ self.get_dependencies(index, dependencies)
203
+ case SSAPlace():
204
+ dependencies.add(stmt)
205
+ case int() | float() | TempBlock():
206
+ pass
207
+ case _:
208
+ raise TypeError(f"Unexpected statement: {stmt}")
209
+ return dependencies
210
+
211
+ def substitute_constants(self, stmt, values: dict[SSAPlace, Value]):
212
+ match stmt:
213
+ case IRConst():
214
+ return stmt
215
+ case IRPureInstr(op=op, args=args):
216
+ return IRPureInstr(op=op, args=[self.substitute_constants(arg, values) for arg in args])
217
+ case IRInstr(op=op, args=args):
218
+ return IRInstr(op=op, args=[self.substitute_constants(arg, values) for arg in args])
219
+ case IRGet(place=SSAPlace() as place):
220
+ value = values[place]
221
+ if isinstance(value, int | float):
222
+ return IRConst(value)
223
+ return stmt
224
+ case IRGet(place=place):
225
+ return IRGet(place=self.substitute_constants(place, values))
226
+ case IRSet(place=SSAPlace() as place, value=value):
227
+ return IRSet(place=place, value=self.substitute_constants(value, values))
228
+ case IRSet(place=place, value=value):
229
+ return IRSet(
230
+ place=self.substitute_constants(place, values), value=self.substitute_constants(value, values)
231
+ )
232
+ case BlockPlace(block=block, index=index, offset=offset):
233
+ return BlockPlace(
234
+ block=self.substitute_constants(block, values),
235
+ index=self.substitute_constants(index, values),
236
+ offset=offset,
237
+ )
238
+ case SSAPlace():
239
+ value = values[stmt]
240
+ if isinstance(value, int | float):
241
+ return IRConst(value)
242
+ return stmt
243
+ case int() | float() | TempBlock():
244
+ return stmt
245
+ case _:
246
+ raise TypeError(f"Unexpected statement: {stmt}")
247
+
248
+ def evaluate_stmt(self, stmt, values: dict[SSAPlace, Value]) -> Value:
249
+ match stmt:
250
+ case IRConst(value=value):
251
+ return value
252
+ case IRPureInstr(op=op, args=args) | IRInstr(op=op, args=args):
253
+ if op not in self.SUPPORTED_OPS:
254
+ return NAC
255
+ args = [self.evaluate_stmt(arg, values) for arg in args]
256
+ match op:
257
+ case Op.And:
258
+ if any(arg == 0 for arg in args):
259
+ return 0
260
+ case Op.Or:
261
+ if any(arg == 1 for arg in args):
262
+ return 1
263
+ case Op.Multiply:
264
+ if any(arg == 0 for arg in args):
265
+ return 0
266
+ if any(arg is NAC for arg in args):
267
+ return NAC
268
+ if any(arg is UNDEF for arg in args):
269
+ return UNDEF
270
+ match op:
271
+ case Op.Equal:
272
+ assert len(args) == 2
273
+ return args[0] == args[1]
274
+ case Op.NotEqual:
275
+ assert len(args) == 2
276
+ return args[0] != args[1]
277
+ case Op.Greater:
278
+ assert len(args) == 2
279
+ return args[0] > args[1]
280
+ case Op.GreaterOr:
281
+ assert len(args) == 2
282
+ return args[0] >= args[1]
283
+ case Op.Less:
284
+ assert len(args) == 2
285
+ return args[0] < args[1]
286
+ case Op.LessOr:
287
+ assert len(args) == 2
288
+ return args[0] <= args[1]
289
+ case Op.Not:
290
+ assert len(args) == 1
291
+ return int(not args[0])
292
+ case Op.And:
293
+ return all(args)
294
+ case Op.Or:
295
+ return any(args)
296
+ case Op.Negate:
297
+ assert len(args) == 1
298
+ return -args[0]
299
+ case Op.Add:
300
+ return sum(args)
301
+ case Op.Subtract:
302
+ if len(args) == 0:
303
+ return 0
304
+ return args[0] - sum(args[1:])
305
+ case Op.Multiply:
306
+ if len(args) == 0:
307
+ return 1
308
+ return functools.reduce(operator.mul, args, 1)
309
+ case Op.Divide:
310
+ if len(args) == 0:
311
+ return 1
312
+ return args[0] / functools.reduce(operator.mul, args[1:], 1)
313
+ case Op.Power:
314
+ if len(args) == 0:
315
+ return 1
316
+ return functools.reduce(operator.pow, args)
317
+ case Op.Log:
318
+ assert len(args) == 2
319
+ return math.log(args[0], args[1])
320
+ case Op.Ceil:
321
+ assert len(args) == 1
322
+ return math.ceil(args[0])
323
+ case Op.Floor:
324
+ assert len(args) == 1
325
+ return math.floor(args[0])
326
+ case Op.Round:
327
+ assert len(args) == 1
328
+ # This is round half to even in both Python and Sonolus
329
+ return round(args[0])
330
+ case Op.Frac:
331
+ assert len(args) == 1
332
+ return smath.frac(args[0])
333
+ case Op.Mod:
334
+ assert len(args) == 2
335
+ return args[0] % args[1]
336
+ case Op.Rem:
337
+ assert len(args) == 2
338
+ return smath.remainder(args[0], args[1])
339
+ case Op.Sin:
340
+ assert len(args) == 1
341
+ return math.sin(args[0])
342
+ case Op.Cos:
343
+ assert len(args) == 1
344
+ return math.cos(args[0])
345
+ case Op.Tan:
346
+ assert len(args) == 1
347
+ return math.tan(args[0])
348
+ case Op.Sinh:
349
+ assert len(args) == 1
350
+ return math.sinh(args[0])
351
+ case Op.Cosh:
352
+ assert len(args) == 1
353
+ return math.cosh(args[0])
354
+ case Op.Tanh:
355
+ assert len(args) == 1
356
+ return math.tanh(args[0])
357
+ case Op.Arcsin:
358
+ assert len(args) == 1
359
+ return math.asin(args[0])
360
+ case Op.Arccos:
361
+ assert len(args) == 1
362
+ return math.acos(args[0])
363
+ case Op.Arctan:
364
+ assert len(args) == 1
365
+ return math.atan(args[0])
366
+ case Op.Arctan2:
367
+ assert len(args) == 2
368
+ return math.atan2(args[0], args[1])
369
+ case IRGet(place=SSAPlace() as place):
370
+ return values[place]
371
+ case IRGet():
372
+ return NAC
373
+ case IRSet() | _:
374
+ raise TypeError(f"Unexpected statement: {stmt}")
@@ -0,0 +1,85 @@
1
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet
2
+ from sonolus.backend.optimize.flow import BasicBlock, traverse_cfg_preorder
3
+ from sonolus.backend.optimize.liveness import LivenessAnalysis, get_live
4
+ from sonolus.backend.optimize.passes import CompilerPass
5
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
6
+
7
+
8
+ class CopyCoalesce(CompilerPass):
9
+ def requires(self) -> set[CompilerPass]:
10
+ return {LivenessAnalysis()}
11
+
12
+ def run(self, entry: BasicBlock) -> BasicBlock:
13
+ mapping = self.get_mapping(entry)
14
+ for block in traverse_cfg_preorder(entry):
15
+ block.statements = [self.apply_to_stmt(stmt, mapping) for stmt in block.statements]
16
+ block.test = self.apply_to_stmt(block.test, mapping)
17
+ return entry
18
+
19
+ def apply_to_stmt(self, stmt, mapping: dict[TempBlock, TempBlock]):
20
+ match stmt:
21
+ case IRConst():
22
+ return stmt
23
+ case IRSet(place=place, value=value):
24
+ return IRSet(self.apply_to_stmt(place, mapping), self.apply_to_stmt(value, mapping))
25
+ case IRGet(place=place):
26
+ return IRGet(self.apply_to_stmt(place, mapping))
27
+ case IRPureInstr(op=op, args=args):
28
+ return IRPureInstr(op, [self.apply_to_stmt(arg, mapping) for arg in args])
29
+ case IRInstr(op=op, args=args):
30
+ return IRInstr(op, [self.apply_to_stmt(arg, mapping) for arg in args])
31
+ case BlockPlace(block=block, index=index, offset=offset):
32
+ return BlockPlace(self.apply_to_stmt(block, mapping), self.apply_to_stmt(index, mapping), offset)
33
+ case TempBlock():
34
+ return mapping.get(stmt, stmt)
35
+ case SSAPlace() | int() | float():
36
+ return stmt
37
+
38
+ def get_mapping(self, entry: BasicBlock) -> dict[TempBlock, TempBlock]:
39
+ interference = self.get_interference(entry)
40
+ copies = self.get_copies(entry)
41
+
42
+ mapping = {}
43
+
44
+ for target, sources in copies.items():
45
+ for source in sources:
46
+ if source in mapping:
47
+ continue
48
+ if source in interference.get(target, set()):
49
+ continue
50
+ mapping[source] = mapping.get(target, target)
51
+ combined_interference = interference.get(target, set()) | interference.get(source, set())
52
+ interference[source] = combined_interference
53
+ interference[target] = combined_interference
54
+
55
+ return mapping
56
+
57
+ def get_interference(self, entry: BasicBlock) -> dict[TempBlock, set[TempBlock]]:
58
+ result = {}
59
+ for block in traverse_cfg_preorder(entry):
60
+ for stmt in [*block.statements, block.test]:
61
+ live = {p for p in get_live(stmt) if isinstance(p, TempBlock) and p.size == 1}
62
+ for place in live:
63
+ result.setdefault(place, set()).update(live - {place})
64
+ return result
65
+
66
+ def get_copies(self, entry: BasicBlock) -> dict[TempBlock, set[TempBlock]]:
67
+ result = {}
68
+ for block in traverse_cfg_preorder(entry):
69
+ for stmt in block.statements:
70
+ if (
71
+ not isinstance(stmt, IRSet)
72
+ or not isinstance(stmt.place, BlockPlace)
73
+ or not isinstance(stmt.place.block, TempBlock)
74
+ or stmt.place.block.size != 1
75
+ or not isinstance(stmt.value, IRGet)
76
+ or not isinstance(stmt.value.place, BlockPlace)
77
+ or not isinstance(stmt.value.place.block, TempBlock)
78
+ or stmt.value.place.block.size != 1
79
+ ):
80
+ continue
81
+ target = stmt.place.block
82
+ source = stmt.value.place.block
83
+ result.setdefault(target, set()).add(source)
84
+ result.setdefault(source, set()).add(target)
85
+ return result
@@ -0,0 +1,185 @@
1
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
2
+ from sonolus.backend.optimize.flow import BasicBlock, traverse_cfg_preorder
3
+ from sonolus.backend.optimize.liveness import HasLiveness, LivenessAnalysis, get_live, get_live_phi_targets
4
+ from sonolus.backend.optimize.passes import CompilerPass
5
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
6
+
7
+
8
+ class UnreachableCodeElimination(CompilerPass):
9
+ def run(self, entry: BasicBlock) -> BasicBlock:
10
+ original_blocks = [*traverse_cfg_preorder(entry)]
11
+ worklist = {entry}
12
+ visited = set()
13
+ while worklist:
14
+ block = worklist.pop()
15
+ if block in visited:
16
+ continue
17
+ visited.add(block)
18
+ match block.test:
19
+ case IRConst(value=value):
20
+ block.test = IRConst(0)
21
+ taken_edge = next(
22
+ (edge for edge in block.outgoing if edge.cond == value),
23
+ None,
24
+ ) or next((edge for edge in block.outgoing if edge.cond is None), None)
25
+ assert not block.outgoing or taken_edge
26
+ for edge in [*block.outgoing]:
27
+ if edge is not taken_edge:
28
+ edge.dst.incoming.remove(edge)
29
+ block.outgoing.remove(edge)
30
+ if taken_edge:
31
+ taken_edge.cond = None
32
+ block.outgoing.add(taken_edge)
33
+ worklist.add(taken_edge.dst)
34
+ case _:
35
+ worklist.update(edge.dst for edge in block.outgoing)
36
+ for block in original_blocks:
37
+ if block not in visited:
38
+ for edge in block.outgoing:
39
+ edge.dst.incoming.remove(edge)
40
+ else:
41
+ for args in block.phis.values():
42
+ for src_block in [*args]:
43
+ if src_block not in visited:
44
+ args.pop(src_block)
45
+ return entry
46
+
47
+
48
class DeadCodeElimination(CompilerPass):
    """Remove assignments to SSA places and temp blocks whose values are never needed.

    A first sweep records the defining statements (or phi argument tuples) of every SSA
    place and temp block. It also seeds the used set from code that runs regardless:
    block tests, non-assignment statements, stores to other places, and side-effecting
    instructions. Uses are then propagated from each used place to whatever its
    definitions read. Finally, stores to never-used destinations and self-copies are
    dropped; a side-effecting instruction from a dropped store is kept as a bare statement.
    """

    def run(self, entry: BasicBlock) -> BasicBlock:
        uses = set()
        defs = {}
        for block in traverse_cfg_preorder(entry):
            for stmt in block.statements:
                self.handle_statement(stmt, uses, defs)
            for phi_target, phi_args in block.phis.items():
                defs.setdefault(phi_target, []).append(tuple(phi_args.values()))
            self.update_uses(block.test, uses)

        pending = list(uses)
        while pending:
            current = pending.pop()
            for definition in defs.get(current, []):
                read = definition if isinstance(definition, tuple) else self.update_uses(definition, set())
                for place in read:
                    if place not in uses:
                        uses.add(place)
                        pending.append(place)

        for block in traverse_cfg_preorder(entry):
            kept = []
            for stmt in block.statements:
                if not isinstance(stmt, IRSet):
                    kept.append(stmt)
                elif not self._is_removable(stmt.place, stmt.value, uses):
                    kept.append(stmt)
                elif isinstance(stmt.value, IRInstr) and stmt.value.op.side_effects:
                    # The destination is dead, but the instruction itself must still run.
                    kept.append(stmt.value)
            block.statements = kept
            block.phis = {target: phi for target, phi in block.phis.items() if target in uses}
        return entry

    @staticmethod
    def _is_removable(place, value, uses) -> bool:
        """Return whether storing ``value`` into ``place`` is pointless given ``uses``."""
        return (
            (isinstance(place, SSAPlace) and place not in uses)
            or (isinstance(place, BlockPlace) and isinstance(place.block, TempBlock) and place.block not in uses)
            or (isinstance(value, IRGet) and place == value.place)
        )

    def handle_statement(
        self,
        stmt: IRStmt | BlockPlace | SSAPlace | TempBlock | int,
        uses: set[HasLiveness],
        defs: dict[HasLiveness, list[IRStmt | tuple[HasLiveness]]],
    ):
        """Record ``stmt`` as a definition, or add what it reads to ``uses`` if it must run."""
        if not isinstance(stmt, IRSet):
            self.update_uses(stmt, uses)
            return
        place = stmt.place
        value = stmt.value
        if isinstance(place, SSAPlace):
            defined = place
        elif isinstance(place, BlockPlace) and isinstance(place.block, TempBlock):
            defined = place.block
        else:
            # Stores to anything other than an SSA place or temp block are always kept.
            self.update_uses(place, uses)
            self.update_uses(value, uses)
            return
        defs.setdefault(defined, []).append(stmt)
        if isinstance(value, IRInstr) and value.op.side_effects:
            self.update_uses(value, uses)

    def update_uses(
        self, stmt: IRStmt | BlockPlace | SSAPlace | TempBlock | int, uses: set[HasLiveness]
    ) -> set[HasLiveness]:
        """Add every SSA place / temp block read by ``stmt`` to ``uses`` and return ``uses``.

        Raises:
            TypeError: for node types this pass does not recognize.
        """
        if isinstance(stmt, (IRPureInstr, IRInstr)):
            for arg in stmt.args:
                self.update_uses(arg, uses)
        elif isinstance(stmt, IRGet):
            self.update_uses(stmt.place, uses)
        elif isinstance(stmt, IRSet):
            target = stmt.place
            if isinstance(target, BlockPlace):
                # Writing into a temp block is a definition, not a read of that block.
                if not isinstance(target.block, TempBlock):
                    self.update_uses(target.block, uses)
                self.update_uses(target.index, uses)
            self.update_uses(stmt.value, uses)
        elif isinstance(stmt, (IRConst, int)):
            pass
        elif isinstance(stmt, BlockPlace):
            self.update_uses(stmt.block, uses)
            self.update_uses(stmt.index, uses)
        elif isinstance(stmt, (TempBlock, SSAPlace)):
            uses.add(stmt)
        else:
            raise TypeError(f"Unexpected statement type: {type(stmt)}")
        return uses
152
+
153
+
154
+ class AdvancedDeadCodeElimination(CompilerPass):
155
+ """Slower than regular DeadCodeElimination but can handle cases like definitions after the last use and so on."""
156
+
157
+ def requires(self) -> set[CompilerPass]:
158
+ return {LivenessAnalysis()}
159
+
160
+ def run(self, entry: BasicBlock) -> BasicBlock:
161
+ for block in traverse_cfg_preorder(entry):
162
+ live_stmts = []
163
+ for statement in block.statements:
164
+ live = get_live(statement)
165
+ match statement:
166
+ case IRSet(place=place, value=value):
167
+ is_live = not (
168
+ (isinstance(place, SSAPlace) and place not in live)
169
+ or (
170
+ isinstance(place, BlockPlace)
171
+ and isinstance(place.block, TempBlock)
172
+ and place.block not in live
173
+ )
174
+ or (isinstance(value, IRGet) and place == value.place)
175
+ )
176
+ if is_live:
177
+ live_stmts.append(statement)
178
+ elif isinstance(value, IRInstr) and value.op.side_effects:
179
+ live_stmts.append(value)
180
+ value.live = live
181
+ case other:
182
+ live_stmts.append(other)
183
+ block.statements = live_stmts
184
+ block.phis = {place: phi for place, phi in block.phis.items() if place in get_live_phi_targets(block)}
185
+ return entry