sonolus.py 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sonolus.py might be problematic. Click here for more details.

Files changed (66)
  1. sonolus/backend/excepthook.py +30 -0
  2. sonolus/backend/finalize.py +15 -1
  3. sonolus/backend/ops.py +4 -0
  4. sonolus/backend/optimize/allocate.py +5 -5
  5. sonolus/backend/optimize/constant_evaluation.py +124 -19
  6. sonolus/backend/optimize/copy_coalesce.py +15 -12
  7. sonolus/backend/optimize/dead_code.py +7 -6
  8. sonolus/backend/optimize/dominance.py +2 -2
  9. sonolus/backend/optimize/flow.py +54 -8
  10. sonolus/backend/optimize/inlining.py +137 -30
  11. sonolus/backend/optimize/liveness.py +2 -2
  12. sonolus/backend/optimize/optimize.py +15 -1
  13. sonolus/backend/optimize/passes.py +11 -3
  14. sonolus/backend/optimize/simplify.py +137 -8
  15. sonolus/backend/optimize/ssa.py +47 -13
  16. sonolus/backend/place.py +5 -4
  17. sonolus/backend/utils.py +24 -0
  18. sonolus/backend/visitor.py +260 -17
  19. sonolus/build/cli.py +47 -19
  20. sonolus/build/compile.py +12 -5
  21. sonolus/build/engine.py +70 -1
  22. sonolus/build/level.py +3 -3
  23. sonolus/build/project.py +2 -2
  24. sonolus/script/archetype.py +27 -24
  25. sonolus/script/array.py +25 -19
  26. sonolus/script/array_like.py +46 -49
  27. sonolus/script/bucket.py +1 -1
  28. sonolus/script/containers.py +22 -26
  29. sonolus/script/debug.py +24 -47
  30. sonolus/script/effect.py +1 -1
  31. sonolus/script/engine.py +2 -2
  32. sonolus/script/globals.py +3 -3
  33. sonolus/script/instruction.py +3 -3
  34. sonolus/script/internal/builtin_impls.py +155 -28
  35. sonolus/script/internal/constant.py +13 -3
  36. sonolus/script/internal/context.py +46 -15
  37. sonolus/script/internal/impl.py +9 -3
  38. sonolus/script/internal/introspection.py +8 -1
  39. sonolus/script/internal/math_impls.py +17 -0
  40. sonolus/script/internal/native.py +5 -5
  41. sonolus/script/internal/range.py +14 -17
  42. sonolus/script/internal/simulation_context.py +1 -1
  43. sonolus/script/internal/transient.py +2 -2
  44. sonolus/script/internal/value.py +42 -4
  45. sonolus/script/interval.py +15 -15
  46. sonolus/script/iterator.py +38 -107
  47. sonolus/script/maybe.py +139 -0
  48. sonolus/script/num.py +30 -15
  49. sonolus/script/options.py +1 -1
  50. sonolus/script/particle.py +1 -1
  51. sonolus/script/pointer.py +1 -1
  52. sonolus/script/project.py +24 -5
  53. sonolus/script/quad.py +15 -15
  54. sonolus/script/record.py +21 -12
  55. sonolus/script/runtime.py +22 -18
  56. sonolus/script/sprite.py +1 -1
  57. sonolus/script/stream.py +69 -85
  58. sonolus/script/transform.py +35 -34
  59. sonolus/script/values.py +10 -10
  60. sonolus/script/vec.py +23 -20
  61. {sonolus_py-0.3.3.dist-info → sonolus_py-0.4.0.dist-info}/METADATA +1 -1
  62. sonolus_py-0.4.0.dist-info/RECORD +93 -0
  63. sonolus_py-0.3.3.dist-info/RECORD +0 -92
  64. {sonolus_py-0.3.3.dist-info → sonolus_py-0.4.0.dist-info}/WHEEL +0 -0
  65. {sonolus_py-0.3.3.dist-info → sonolus_py-0.4.0.dist-info}/entry_points.txt +0 -0
  66. {sonolus_py-0.3.3.dist-info → sonolus_py-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,29 @@
1
+ from sonolus.backend.blocks import BlockData
1
2
  from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
2
3
  from sonolus.backend.optimize.flow import BasicBlock, traverse_cfg_preorder
3
- from sonolus.backend.optimize.passes import CompilerPass
4
+ from sonolus.backend.optimize.passes import CompilerPass, OptimizerConfig
4
5
  from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
5
6
 
7
+ RUNTIME_CONSTANT_BLOCKS = {
8
+ "RuntimeEnvironment",
9
+ "RuntimeUI",
10
+ "RuntimeUIConfiguration",
11
+ "LevelData",
12
+ "LevelOption",
13
+ "LevelBucket",
14
+ "LevelScore",
15
+ "LevelLife",
16
+ "EngineRom",
17
+ "ArchetypeLife",
18
+ "RuntimeCanvas",
19
+ "PreviewData",
20
+ "PreviewOption",
21
+ "TutorialData",
22
+ }
23
+
6
24
 
7
25
  class InlineVars(CompilerPass):
8
- def run(self, entry: BasicBlock) -> BasicBlock:
26
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
9
27
  use_counts: dict[SSAPlace, int] = {}
10
28
  definitions: dict[SSAPlace, IRStmt] = {}
11
29
 
@@ -15,49 +33,97 @@ class InlineVars(CompilerPass):
15
33
  if isinstance(stmt, IRSet) and isinstance(stmt.place, SSAPlace):
16
34
  definitions[stmt.place] = stmt.value
17
35
  self.count_uses(block.test, use_counts)
36
+ for tgt, args in block.phis.items():
37
+ for arg in args.values():
38
+ self.count_uses(arg, use_counts)
39
+ if len(args) == 1:
40
+ arg = next(iter(args.values()))
41
+ definitions[tgt] = IRGet(place=arg)
18
42
 
43
+ for defn in definitions.values():
44
+ if isinstance(defn, IRGet) and isinstance(defn.place, SSAPlace):
45
+ use_counts[defn.place] -= 1
46
+
47
+ canonical_definitions: dict[SSAPlace, IRStmt] = {}
19
48
  for p, defn in definitions.items():
49
+ canonical_definitions[p] = defn
50
+ # Update the definition if it's a Get from another SSAPlace until we reach a definition that is not a Get
51
+ while defn and isinstance(defn, IRGet) and isinstance(defn.place, SSAPlace):
52
+ canonical_definitions[p] = defn
53
+ defn = definitions.get(defn.place, None) # Can be None if it's a phi
54
+ canonical_defn = canonical_definitions[p]
55
+ if (
56
+ use_counts.get(p, 0) > 0
57
+ and isinstance(canonical_defn, IRGet)
58
+ and isinstance(canonical_defn.place, SSAPlace)
59
+ ):
60
+ use_counts[canonical_defn.place] = use_counts.get(canonical_defn.place, 0) + 1
61
+
62
+ for p, defn in canonical_definitions.items():
63
+ if isinstance(defn, IRGet) and isinstance(defn.place, SSAPlace):
64
+ inner_p = defn.place
65
+ inner_defn = canonical_definitions.get(inner_p)
66
+ if (
67
+ inner_defn
68
+ and self.is_inlinable(inner_defn, config.callback)
69
+ and (use_counts.get(inner_p, 0) <= 1 or self.is_free_to_inline(inner_defn, config.callback))
70
+ ):
71
+ canonical_definitions[p] = inner_defn
72
+
73
+ inlined_definitions = {**canonical_definitions}
74
+ for p, defn in canonical_definitions.items():
20
75
  while True:
21
- if isinstance(defn, IRGet) and isinstance(defn.place, SSAPlace) and defn.place in definitions:
22
- inside_defn = definitions[defn.place]
23
- if not self.is_inlinable(inside_defn):
24
- break
25
- defn = inside_defn
26
- continue
27
76
  inlinable_uses = self.get_inlinable_uses(defn, set())
28
77
  subs = {}
29
78
  for inside_p in inlinable_uses:
30
- if inside_p not in definitions:
79
+ if inside_p not in canonical_definitions:
31
80
  continue
32
- inside_defn = definitions[inside_p]
33
- if not self.is_inlinable(inside_defn):
81
+ inside_defn = canonical_definitions[inside_p]
82
+ if not self.is_inlinable(inside_defn, config.callback):
34
83
  continue
35
- if (isinstance(inside_defn, IRGet) and isinstance(inside_defn.place, SSAPlace)) or use_counts[
36
- inside_p
37
- ] == 1:
84
+ if (
85
+ (isinstance(inside_defn, IRGet) and isinstance(inside_defn.place, SSAPlace))
86
+ or use_counts[inside_p] == 1
87
+ or self.is_free_to_inline(inside_defn, config.callback)
88
+ ):
38
89
  subs[inside_p] = inside_defn
39
90
  if not subs:
40
91
  break
41
92
  defn = self.substitute(defn, subs)
42
- definitions[p] = defn
93
+ inlined_definitions[p] = defn
43
94
 
44
- valid = {p for p in definitions if self.is_inlinable(definitions[p]) and use_counts.get(p, 0) <= 1}
95
+ valid = {
96
+ p
97
+ for p in inlined_definitions
98
+ if self.is_inlinable(inlined_definitions[p], config.callback)
99
+ and (use_counts.get(p, 0) <= 1 or self.is_free_to_inline(inlined_definitions[p], config.callback))
100
+ }
45
101
 
46
102
  for block in traverse_cfg_preorder(entry):
47
103
  new_statements = []
48
104
  for stmt in [*block.statements, block.test]:
49
- inlinable_uses = self.get_inlinable_uses(stmt, set())
50
- subs = {}
51
- for p in inlinable_uses:
52
- if p not in valid:
53
- continue
54
- definition = definitions[p]
55
- subs[p] = definition
56
-
57
- if subs:
58
- new_statements.append(self.substitute(stmt, subs))
59
- else:
105
+ if (
106
+ isinstance(stmt, IRSet)
107
+ and isinstance(stmt.place, SSAPlace)
108
+ and isinstance(stmt.value, IRGet)
109
+ and isinstance(stmt.value.place, SSAPlace)
110
+ ):
111
+ # Don't bother inlining a direct alias since it can get optimized away later and
112
+ # reordering can reduce optimality since we don't have many other code motion optimizations.
60
113
  new_statements.append(stmt)
114
+ continue
115
+ while True:
116
+ inlinable_uses = self.get_inlinable_uses(stmt, set())
117
+ subs = {}
118
+ for p in inlinable_uses:
119
+ if p in valid:
120
+ subs[p] = inlined_definitions[p]
121
+
122
+ if subs:
123
+ stmt = self.substitute(stmt, subs)
124
+ else:
125
+ new_statements.append(stmt)
126
+ break
61
127
 
62
128
  block.statements = new_statements[:-1]
63
129
  block.test = new_statements[-1]
@@ -123,14 +189,55 @@ class InlineVars(CompilerPass):
123
189
  raise TypeError(f"Unexpected statement: {stmt}")
124
190
  return uses
125
191
 
126
- def is_inlinable(self, stmt):
192
+ def is_inlinable(self, stmt, callback: str):
193
+ match stmt:
194
+ case IRConst():
195
+ return True
196
+ case IRInstr(op=op, args=args) | IRPureInstr(op=op, args=args):
197
+ return not op.side_effects and op.pure and all(self.is_inlinable(arg, callback) for arg in args)
198
+ case IRGet():
199
+ return isinstance(stmt.place, SSAPlace) or (
200
+ isinstance(stmt.place, BlockPlace)
201
+ and isinstance(stmt.place.block, BlockData)
202
+ and callback not in stmt.place.block.writable
203
+ and isinstance(stmt.place.index, int | SSAPlace)
204
+ )
205
+ case IRSet():
206
+ return False
207
+ case _:
208
+ raise TypeError(f"Unexpected statement: {stmt}")
209
+
210
+ def is_free_to_inline(self, stmt: IRStmt, callback: str) -> bool:
211
+ match stmt:
212
+ case IRConst():
213
+ return True
214
+ case IRInstr() | IRPureInstr():
215
+ return self.is_runtime_constant(stmt, callback)
216
+ case IRGet():
217
+ return isinstance(stmt.place, SSAPlace) or (
218
+ isinstance(stmt.place, BlockPlace)
219
+ and isinstance(stmt.place.block, float | int)
220
+ and isinstance(stmt.place.index, float | int)
221
+ )
222
+ case IRSet():
223
+ return False
224
+ case _:
225
+ raise TypeError(f"Unexpected statement: {stmt}")
226
+
227
+ def is_runtime_constant(self, stmt: IRStmt, callback: str) -> bool:
127
228
  match stmt:
128
229
  case IRConst():
129
230
  return True
130
231
  case IRInstr(op=op, args=args) | IRPureInstr(op=op, args=args):
131
- return not op.side_effects and op.pure and all(self.is_inlinable(arg) for arg in args)
232
+ return not op.side_effects and op.pure and all(self.is_runtime_constant(arg, callback) for arg in args)
132
233
  case IRGet():
133
- return isinstance(stmt.place, SSAPlace)
234
+ return (
235
+ isinstance(stmt.place, BlockPlace)
236
+ and isinstance(stmt.place.block, BlockData)
237
+ and callback not in stmt.place.block.writable
238
+ and stmt.place.block.name in RUNTIME_CONSTANT_BLOCKS
239
+ and isinstance(stmt.place.index, int | SSAPlace)
240
+ )
134
241
  case IRSet():
135
242
  return False
136
243
  case _:
@@ -2,7 +2,7 @@ from collections import deque
2
2
 
3
3
  from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
4
4
  from sonolus.backend.optimize.flow import BasicBlock, traverse_cfg_preorder
5
- from sonolus.backend.optimize.passes import CompilerPass
5
+ from sonolus.backend.optimize.passes import CompilerPass, OptimizerConfig
6
6
  from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
7
7
 
8
8
  type HasLiveness = SSAPlace | TempBlock
@@ -12,7 +12,7 @@ class LivenessAnalysis(CompilerPass):
12
12
  def destroys(self) -> set[CompilerPass]:
13
13
  return set()
14
14
 
15
- def run(self, entry: BasicBlock) -> BasicBlock:
15
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
16
16
  self.preprocess(entry)
17
17
  self.process(entry)
18
18
  return entry
@@ -7,7 +7,13 @@ from sonolus.backend.optimize.dead_code import (
7
7
  UnreachableCodeElimination,
8
8
  )
9
9
  from sonolus.backend.optimize.inlining import InlineVars
10
- from sonolus.backend.optimize.simplify import CoalesceFlow, NormalizeSwitch, RewriteToSwitch
10
+ from sonolus.backend.optimize.simplify import (
11
+ CoalesceFlow,
12
+ CoalesceSmallConditionalBlocks,
13
+ NormalizeSwitch,
14
+ RemoveRedundantArguments,
15
+ RewriteToSwitch,
16
+ )
11
17
  from sonolus.backend.optimize.ssa import FromSSA, ToSSA
12
18
 
13
19
  MINIMAL_PASSES = (
@@ -27,6 +33,7 @@ STANDARD_PASSES = (
27
33
  CoalesceFlow(),
28
34
  UnreachableCodeElimination(),
29
35
  DeadCodeElimination(),
36
+ CoalesceSmallConditionalBlocks(),
30
37
  ToSSA(),
31
38
  SparseConditionalConstantPropagation(),
32
39
  UnreachableCodeElimination(),
@@ -34,7 +41,14 @@ STANDARD_PASSES = (
34
41
  CoalesceFlow(),
35
42
  InlineVars(),
36
43
  DeadCodeElimination(),
44
+ InlineVars(),
45
+ CoalesceFlow(),
46
+ SparseConditionalConstantPropagation(),
47
+ RemoveRedundantArguments(),
48
+ DeadCodeElimination(),
49
+ CoalesceFlow(),
37
50
  RewriteToSwitch(),
51
+ InlineVars(),
38
52
  FromSSA(),
39
53
  CoalesceFlow(),
40
54
  CopyCoalesce(),
@@ -3,10 +3,18 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from collections import deque
5
5
  from collections.abc import Sequence
6
+ from dataclasses import dataclass
6
7
 
8
+ from sonolus.backend.mode import Mode
7
9
  from sonolus.backend.optimize.flow import BasicBlock
8
10
 
9
11
 
12
+ @dataclass
13
+ class OptimizerConfig:
14
+ mode: Mode | None = None
15
+ callback: str | None = None
16
+
17
+
10
18
  class CompilerPass(ABC):
11
19
  def requires(self) -> set[CompilerPass]:
12
20
  return set()
@@ -32,11 +40,11 @@ class CompilerPass(ABC):
32
40
  return passes | self.applies()
33
41
 
34
42
  @abstractmethod
35
- def run(self, entry: BasicBlock) -> BasicBlock:
43
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
36
44
  pass
37
45
 
38
46
 
39
- def run_passes(entry: BasicBlock, passes: Sequence[CompilerPass]) -> BasicBlock:
47
+ def run_passes(entry: BasicBlock, passes: Sequence[CompilerPass], config: OptimizerConfig) -> BasicBlock:
40
48
  active_passes = set()
41
49
  queue = deque(passes)
42
50
  while queue:
@@ -48,6 +56,6 @@ def run_passes(entry: BasicBlock, passes: Sequence[CompilerPass]) -> BasicBlock:
48
56
  queue.appendleft(current_pass)
49
57
  queue.extendleft(missing_requirements)
50
58
  continue
51
- entry = current_pass.run(entry)
59
+ entry = current_pass.run(entry, config)
52
60
  active_passes = current_pass.exists_after(active_passes)
53
61
  return entry
@@ -1,11 +1,12 @@
1
- from sonolus.backend.ir import IRConst, IRGet, IRPureInstr, IRSet
1
+ from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet
2
2
  from sonolus.backend.ops import Op
3
- from sonolus.backend.optimize.flow import BasicBlock, traverse_cfg_preorder
4
- from sonolus.backend.optimize.passes import CompilerPass
3
+ from sonolus.backend.optimize.flow import BasicBlock, FlowEdge, traverse_cfg_preorder
4
+ from sonolus.backend.optimize.passes import CompilerPass, OptimizerConfig
5
+ from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
5
6
 
6
7
 
7
8
  class CoalesceFlow(CompilerPass):
8
- def run(self, entry: BasicBlock) -> BasicBlock:
9
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
9
10
  queue = [entry]
10
11
  processed = set()
11
12
  while queue:
@@ -16,7 +17,7 @@ class CoalesceFlow(CompilerPass):
16
17
  for edge in block.outgoing:
17
18
  while True:
18
19
  dst = edge.dst
19
- if dst.phis or dst.statements or len(dst.outgoing) != 1 or dst is block:
20
+ if dst.phis or dst.statements or len(dst.outgoing) != 1 or dst is block or dst is entry:
20
21
  break
21
22
  next_dst = next(iter(dst.outgoing)).dst
22
23
  if next_dst.phis:
@@ -42,6 +43,8 @@ class CoalesceFlow(CompilerPass):
42
43
  queue.extend(edge.dst for edge in block.outgoing)
43
44
  continue
44
45
  next_block = next(iter(block.outgoing)).dst
46
+ if next_block is block or next_block is entry:
47
+ continue
45
48
  if len(next_block.incoming) != 1:
46
49
  queue.append(next_block)
47
50
  if not block.statements and not block.phis and not next_block.phis:
@@ -55,7 +58,8 @@ class CoalesceFlow(CompilerPass):
55
58
  continue
56
59
  for p, args in next_block.phis.items():
57
60
  if block not in args:
58
- continue
61
+ # This is the only predecessor to the block, so it must be a phi argument
62
+ raise ValueError("Missing phi argument")
59
63
  block.statements.append(IRSet(p, IRGet(args[block])))
60
64
  block.statements.extend(next_block.statements)
61
65
  block.test = next_block.test
@@ -73,13 +77,48 @@ class CoalesceFlow(CompilerPass):
73
77
  return entry
74
78
 
75
79
 
80
+ class CoalesceSmallConditionalBlocks(CompilerPass):
81
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
82
+ queue = [entry]
83
+ processed = set()
84
+ while queue:
85
+ block = queue.pop()
86
+ if block.phis:
87
+ raise RuntimeError("SSA form is not supported in this pass")
88
+ if block in processed:
89
+ continue
90
+ processed.add(block)
91
+ while len(block.outgoing) == 1:
92
+ next_edge = next(iter(block.outgoing))
93
+ next_block = next_edge.dst
94
+ if len(next_block.statements) <= 1:
95
+ next_block.incoming.remove(next_edge)
96
+ block.test = next_block.test
97
+ block.outgoing = {FlowEdge(src=block, dst=edge.dst, cond=edge.cond) for edge in next_block.outgoing}
98
+ block.statements.extend(next_block.statements)
99
+ for edge in block.outgoing:
100
+ edge.dst.incoming.add(edge)
101
+ else:
102
+ break
103
+ queue.extend(
104
+ edge.dst
105
+ for edge in sorted(block.outgoing, key=lambda e: (e.cond is not None, e.cond))
106
+ if edge.dst not in processed
107
+ )
108
+
109
+ reachable_blocks = set(traverse_cfg_preorder(entry))
110
+ for block in traverse_cfg_preorder(entry):
111
+ block.incoming = {edge for edge in block.incoming if edge.src in reachable_blocks}
112
+ return entry
113
+
114
+
76
115
  class RewriteToSwitch(CompilerPass):
77
116
  """Rewrite if-else chains to switch statements.
78
117
 
79
118
  Note that this needs inlining (and dead code elimination) to be run first to really do anything useful.
80
119
  """
81
120
 
82
- def run(self, entry: BasicBlock) -> BasicBlock:
121
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
83
122
  self.ifs_to_switch(entry)
84
123
  self.combine_blocks(entry)
85
124
  self.remove_unreachable(entry)
@@ -159,7 +198,7 @@ class RewriteToSwitch(CompilerPass):
159
198
  class NormalizeSwitch(CompilerPass):
160
199
  """Normalize branches like cond -> case a, case a + b, case a + 2b to ((cond - a) / b) -> case 0, case 1, case 2."""
161
200
 
162
- def run(self, entry: BasicBlock) -> BasicBlock:
201
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
163
202
  for block in traverse_cfg_preorder(entry):
164
203
  cases = {edge.cond for edge in block.outgoing}
165
204
  if len(cases) <= 2:
@@ -189,3 +228,93 @@ class NormalizeSwitch(CompilerPass):
189
228
  if case != offset + i * stride:
190
229
  return None, None
191
230
  return offset, stride
231
+
232
+
233
+ class RemoveRedundantArguments(CompilerPass):
234
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
235
+ for block in traverse_cfg_preorder(entry):
236
+ block.statements = [self.update_statement(stmt) for stmt in block.statements]
237
+ block.test = self.update_statement(block.test)
238
+ return entry
239
+
240
+ def update_statement(self, stmt):
241
+ match stmt:
242
+ case IRPureInstr() | IRPureInstr():
243
+ op = stmt.op
244
+ args = stmt.args
245
+ match op:
246
+ case Op.Add:
247
+ args = [arg for arg in args if not (isinstance(arg, IRConst) and arg.value == 0)]
248
+ if len(args) == 1:
249
+ return args[0]
250
+ case Op.Subtract:
251
+ args = [
252
+ args[0],
253
+ *(arg for arg in args[1:] if not (isinstance(arg, IRConst) and arg.value == 0)),
254
+ ]
255
+ if len(args) == 1:
256
+ return args[0]
257
+ case Op.Multiply:
258
+ args = [arg for arg in args if not (isinstance(arg, IRConst) and arg.value == 1)]
259
+ if len(args) == 1:
260
+ return args[0]
261
+ case Op.Divide:
262
+ args = [
263
+ args[0],
264
+ *(arg for arg in args[1:] if not (isinstance(arg, IRConst) and arg.value == 1)),
265
+ ]
266
+ if len(args) == 1:
267
+ return args[0]
268
+ return type(stmt)(op=op, args=[self.update_statement(arg) for arg in args])
269
+ case IRSet(place=place, value=value):
270
+ return IRSet(place=place, value=self.update_statement(value))
271
+ case _:
272
+ return stmt
273
+
274
+
275
+ class RenumberVars(CompilerPass):
276
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
277
+ numbers = {}
278
+ for block in traverse_cfg_preorder(entry):
279
+ for stmt in block.statements:
280
+ if (
281
+ isinstance(stmt, IRSet)
282
+ and isinstance(stmt.place, BlockPlace)
283
+ and isinstance(stmt.place.block, TempBlock)
284
+ and stmt.place.block.size == 1
285
+ and stmt.place.block not in numbers
286
+ ):
287
+ numbers[stmt.place.block] = len(numbers) + 1
288
+ for block in traverse_cfg_preorder(entry):
289
+ block.statements = [self.update_statement(stmt, numbers) for stmt in block.statements]
290
+ block.test = self.update_statement(block.test, numbers)
291
+ return entry
292
+
293
+ def update_statement(self, stmt, numbers: dict[TempBlock, int]):
294
+ match stmt:
295
+ case IRConst():
296
+ return stmt
297
+ case IRPureInstr(op=op, args=args):
298
+ return IRPureInstr(op=op, args=[self.update_statement(arg, numbers) for arg in args])
299
+ case IRInstr(op=op, args=args):
300
+ return IRInstr(op=op, args=[self.update_statement(arg, numbers) for arg in args])
301
+ case IRGet(place=place):
302
+ return IRGet(place=self.update_statement(place, numbers))
303
+ case IRSet(place=SSAPlace() as place, value=value):
304
+ return IRSet(place=place, value=self.update_statement(value, numbers))
305
+ case IRSet(place=place, value=value):
306
+ return IRSet(place=self.update_statement(place, numbers), value=self.update_statement(value, numbers))
307
+ case BlockPlace(block=block, index=index, offset=offset):
308
+ return BlockPlace(
309
+ block=self.update_statement(block, numbers),
310
+ index=self.update_statement(index, numbers),
311
+ offset=offset,
312
+ )
313
+ case SSAPlace():
314
+ return stmt
315
+ case TempBlock() as b if b in numbers:
316
+ return TempBlock(f"v{numbers[b]}", size=1)
317
+ case int() | float() | TempBlock():
318
+ return stmt
319
+ case _:
320
+ raise TypeError(f"Unexpected statement: {stmt}")
@@ -1,7 +1,7 @@
1
1
  from sonolus.backend.ir import IRConst, IRGet, IRInstr, IRPureInstr, IRSet, IRStmt
2
2
  from sonolus.backend.optimize.dominance import DominanceFrontiers, get_df, get_dom_children
3
3
  from sonolus.backend.optimize.flow import BasicBlock, FlowEdge, traverse_cfg_preorder
4
- from sonolus.backend.optimize.passes import CompilerPass
4
+ from sonolus.backend.optimize.passes import CompilerPass, OptimizerConfig
5
5
  from sonolus.backend.place import BlockPlace, SSAPlace, TempBlock
6
6
 
7
7
 
@@ -9,7 +9,7 @@ class ToSSA(CompilerPass):
9
9
  def requires(self) -> set[CompilerPass]:
10
10
  return {DominanceFrontiers()}
11
11
 
12
- def run(self, entry: BasicBlock) -> BasicBlock:
12
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
13
13
  defs = self.defs_to_blocks(entry)
14
14
  self.insert_phis(defs)
15
15
  self.rename(entry, defs, {var: [] for var in defs}, {})
@@ -139,15 +139,15 @@ class ToSSA(CompilerPass):
139
139
 
140
140
 
141
141
  class FromSSA(CompilerPass):
142
- def run(self, entry: BasicBlock) -> BasicBlock:
142
+ def run(self, entry: BasicBlock, config: OptimizerConfig) -> BasicBlock:
143
143
  for block in [*traverse_cfg_preorder(entry)]:
144
144
  self.process_block(block)
145
145
  return entry
146
146
 
147
147
  def process_block(self, block: BasicBlock):
148
- incoming = [*block.incoming]
148
+ orig_incoming = [*block.incoming]
149
149
  block.incoming.clear()
150
- for edge in incoming:
150
+ for edge in orig_incoming:
151
151
  between_block = BasicBlock()
152
152
  edge.dst = between_block
153
153
  between_block.incoming.add(edge)
@@ -156,12 +156,46 @@ class FromSSA(CompilerPass):
156
156
  between_block.outgoing.add(next_edge)
157
157
  for args in block.phis.values():
158
158
  if edge.src in args:
159
- args[between_block] = args.pop(edge.src)
159
+ args[between_block] = args[edge.src]
160
+ for edge in orig_incoming:
161
+ # Multiple edges with different conditions can connect two of the same blocks,
162
+ # so we need to remove the old phi arguments in a pass at the end.
163
+ for args in block.phis.values():
164
+ if edge.src in args:
165
+ del args[edge.src]
166
+ incoming_blocks = {edge.src for edge in block.incoming}
167
+ args_by_src = {}
168
+ for args in block.phis.values():
169
+ for src, arg in args.items():
170
+ if src not in args_by_src:
171
+ args_by_src[src] = set()
172
+ args_by_src[src].add(arg)
160
173
  for var, args in block.phis.items():
161
174
  for src, arg in args.items():
162
- src.statements.append(
163
- IRSet(place=self.place_from_ssa_place(var), value=IRGet(place=self.place_from_ssa_place(arg)))
164
- )
175
+ if src not in incoming_blocks:
176
+ # Edges may have been rewritten so a phi refers to a block that is no longer directly connected.
177
+ continue
178
+ if var in args_by_src[src]:
179
+ # Make an extra copy first of values that may be overwritten by another assignment.
180
+ src.statements.append(
181
+ IRSet(
182
+ place=self.place_from_ssa_place(var, "*"), value=IRGet(place=self.place_from_ssa_place(arg))
183
+ )
184
+ )
185
+ for var, args in block.phis.items():
186
+ for src, arg in args.items():
187
+ if src not in incoming_blocks:
188
+ continue
189
+ if var in args_by_src[src]:
190
+ src.statements.append(
191
+ IRSet(
192
+ place=self.place_from_ssa_place(var), value=IRGet(place=self.place_from_ssa_place(var, "*"))
193
+ )
194
+ )
195
+ else:
196
+ src.statements.append(
197
+ IRSet(place=self.place_from_ssa_place(var), value=IRGet(place=self.place_from_ssa_place(arg)))
198
+ )
165
199
  block.phis = {}
166
200
  block.statements = [self.process_stmt(stmt) for stmt in block.statements]
167
201
  block.test = self.process_stmt(block.test)
@@ -193,8 +227,8 @@ class FromSSA(CompilerPass):
193
227
  case _:
194
228
  raise TypeError(f"Unexpected statement: {stmt}")
195
229
 
196
- def temp_block_from_ssa_place(self, ssa_place: SSAPlace) -> TempBlock:
197
- return TempBlock(f"{ssa_place.name}.{ssa_place.num}")
230
+ def temp_block_from_ssa_place(self, ssa_place: SSAPlace, suffix: str = "") -> TempBlock:
231
+ return TempBlock(f"{ssa_place.name}.{ssa_place.num}{suffix}")
198
232
 
199
- def place_from_ssa_place(self, ssa_place: SSAPlace) -> BlockPlace:
200
- return BlockPlace(block=self.temp_block_from_ssa_place(ssa_place), index=0, offset=0)
233
+ def place_from_ssa_place(self, ssa_place: SSAPlace, suffix: str = "") -> BlockPlace:
234
+ return BlockPlace(block=self.temp_block_from_ssa_place(ssa_place, suffix), index=0, offset=0)
sonolus/backend/place.py CHANGED
@@ -1,5 +1,6 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections.abc import Iterator
2
- from typing import Self
3
4
 
4
5
  from sonolus.backend.blocks import Block
5
6
 
@@ -22,10 +23,10 @@ class TempBlock:
22
23
  def __str__(self):
23
24
  return f"{self.name}"
24
25
 
25
- def __getitem__(self, item) -> "BlockPlace":
26
+ def __getitem__(self, item) -> BlockPlace:
26
27
  return BlockPlace(self, item)
27
28
 
28
- def __iter__(self) -> "Iterator[BlockPlace]":
29
+ def __iter__(self) -> Iterator[BlockPlace]:
29
30
  for i in range(self.size):
30
31
  yield self[i]
31
32
 
@@ -78,7 +79,7 @@ class BlockPlace:
78
79
  else:
79
80
  return f"{self.block}[{self.index} + {self.offset}]"
80
81
 
81
- def add_offset(self, offset: int) -> Self:
82
+ def add_offset(self, offset: int) -> BlockPlace:
82
83
  return BlockPlace(self.block, self.index, self.offset + offset)
83
84
 
84
85
  def __eq__(self, other):
sonolus/backend/utils.py CHANGED
@@ -61,3 +61,27 @@ def scan_writes(node: ast.AST) -> set[str]:
61
61
  visitor = ScanWrites()
62
62
  visitor.visit(node)
63
63
  return set(visitor.writes)
64
+
65
+
66
+ class HasDirectYield(ast.NodeVisitor):
67
+ def __init__(self):
68
+ self.started = False
69
+ self.has_yield = False
70
+
71
+ def visit_Yield(self, node: ast.Yield):
72
+ self.has_yield = True
73
+
74
+ def visit_YieldFrom(self, node: ast.YieldFrom):
75
+ self.has_yield = True
76
+
77
+ def visit_FunctionDef(self, node: ast.FunctionDef):
78
+ if self.started:
79
+ return
80
+ self.started = True
81
+ self.generic_visit(node)
82
+
83
+
84
+ def has_yield(node: ast.AST) -> bool:
85
+ visitor = HasDirectYield()
86
+ visitor.visit(node)
87
+ return visitor.has_yield