machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,433 @@
1
+ """Tests for SSA construction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from machine_dialect.mir.basic_block import CFG, BasicBlock
6
+ from machine_dialect.mir.mir_function import MIRFunction
7
+ from machine_dialect.mir.mir_instructions import (
8
+ BinaryOp,
9
+ ConditionalJump,
10
+ Copy,
11
+ Jump,
12
+ LoadConst,
13
+ Return,
14
+ StoreVar,
15
+ )
16
+ from machine_dialect.mir.mir_types import MIRType
17
+ from machine_dialect.mir.mir_values import Constant, Variable
18
+ from machine_dialect.mir.ssa_construction import DominanceInfo, construct_ssa
19
+
20
+
21
class TestDominanceInfo:
    """Unit tests for ``DominanceInfo``: the dominator relation and dominance frontiers."""

    @staticmethod
    def _build_diamond() -> tuple[CFG, BasicBlock, BasicBlock, BasicBlock, BasicBlock]:
        """Build the diamond CFG ``entry -> {then, else} -> merge``.

        Returns:
            The CFG followed by its entry, then, else and merge blocks.
        """
        graph = CFG()
        head = BasicBlock("entry")
        left = BasicBlock("then")
        right = BasicBlock("else")
        join = BasicBlock("merge")

        for blk in (head, left, right, join):
            graph.add_block(blk)
        graph.set_entry_block(head)

        graph.connect(head, left)
        graph.connect(head, right)
        graph.connect(left, join)
        graph.connect(right, join)
        return graph, head, left, right, join

    def test_simple_dominance(self) -> None:
        """Dominance on the straight-line CFG ``entry -> block1 -> exit``."""
        graph = CFG()
        head = BasicBlock("entry")
        middle = BasicBlock("block1")
        tail = BasicBlock("exit")

        for blk in (head, middle, tail):
            graph.add_block(blk)
        graph.set_entry_block(head)

        graph.connect(head, middle)
        graph.connect(middle, tail)

        info = DominanceInfo(graph)

        # The entry block dominates every block, itself included.
        for blk in (head, middle, tail):
            assert info.dominates(head, blk)

        # block1 dominates its successor but not its predecessor.
        assert not info.dominates(middle, head)
        assert info.dominates(middle, tail)

        # The exit block dominates neither block that precedes it.
        assert not info.dominates(tail, head)
        assert not info.dominates(tail, middle)

    def test_dominance_with_branch(self) -> None:
        """Dominance on a diamond CFG: only the entry dominates the other blocks."""
        graph, head, left, right, join = self._build_diamond()
        info = DominanceInfo(graph)

        for blk in (left, right, join):
            assert info.dominates(head, blk)

        # Neither arm dominates its sibling arm or the merge block.
        assert not info.dominates(left, right)
        assert not info.dominates(right, left)
        assert not info.dominates(left, join)
        assert not info.dominates(right, join)

    def test_dominance_frontier(self) -> None:
        """Dominance frontiers on a diamond CFG."""
        graph, head, left, right, join = self._build_diamond()
        frontier = DominanceInfo(graph).dominance_frontier

        # The merge block is where each arm's dominance ends.
        assert join in frontier[left]
        assert join in frontier[right]

        # The entry and merge blocks have empty frontiers.
        assert len(frontier[head]) == 0
        assert len(frontier[join]) == 0
125
+
126
+
127
class TestSSAConstruction:
    """Tests for ``construct_ssa``: phi placement and preservation of instructions."""

    def test_simple_ssa_construction(self) -> None:
        """Test SSA construction for a single-block function."""
        # Create function with single variable assignment
        func = MIRFunction("test", [], MIRType.EMPTY)

        # Create blocks
        entry = BasicBlock("entry")
        func.cfg.add_block(entry)
        func.cfg.set_entry_block(entry)

        # Create variable
        x = Variable("x", MIRType.INT)
        func.add_local(x)

        # Add instructions: x = 1; return x
        const1 = Constant(1, MIRType.INT)
        entry.add_instruction(StoreVar(x, const1, (1, 1)))
        temp = func.new_temp(MIRType.INT)
        entry.add_instruction(Copy(temp, x, (1, 1)))
        entry.add_instruction(Return((1, 1), temp))

        # Convert to SSA
        construct_ssa(func)

        # Only checks that the pass ran and left the block non-empty; the
        # exact SSA version names are deliberately not asserted.
        assert len(entry.instructions) > 0

    def test_phi_insertion(self) -> None:
        """Test phi node insertion at join points."""
        # Create function with diamond CFG
        func = MIRFunction("test", [], MIRType.INT)

        # Create blocks
        entry = BasicBlock("entry")
        then_block = BasicBlock("then")
        else_block = BasicBlock("else")
        merge_block = BasicBlock("merge")

        func.cfg.add_block(entry)
        func.cfg.add_block(then_block)
        func.cfg.add_block(else_block)
        func.cfg.add_block(merge_block)
        func.cfg.set_entry_block(entry)

        func.cfg.connect(entry, then_block)
        func.cfg.connect(entry, else_block)
        func.cfg.connect(then_block, merge_block)
        func.cfg.connect(else_block, merge_block)

        # Create variable
        x = Variable("x", MIRType.INT)
        func.add_local(x)

        # Add conditional jump in entry
        cond = func.new_temp(MIRType.BOOL)
        entry.add_instruction(LoadConst(cond, True, (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Assign different values in branches
        const1 = Constant(1, MIRType.INT)
        const2 = Constant(2, MIRType.INT)
        then_block.add_instruction(StoreVar(x, const1, (1, 1)))
        then_block.add_instruction(Jump("merge", (1, 1)))

        else_block.add_instruction(StoreVar(x, const2, (1, 1)))
        else_block.add_instruction(Jump("merge", (1, 1)))

        # Use x in merge
        result = func.new_temp(MIRType.INT)
        merge_block.add_instruction(Copy(result, x, (1, 1)))
        merge_block.add_instruction(Return((1, 1), result))

        # Convert to SSA
        construct_ssa(func)

        # Merge block should have phi node (in phi_nodes list, not instructions)
        assert len(merge_block.phi_nodes) > 0

        # Phi should have incoming values from both predecessors
        phi = merge_block.phi_nodes[0]
        assert len(phi.incoming) == 2
        incoming_labels = {label for _, label in phi.incoming}
        assert "then" in incoming_labels
        assert "else" in incoming_labels

    def test_ssa_with_loops(self) -> None:
        """Test SSA construction with loop (self-referencing phi)."""
        # Create function with loop
        func = MIRFunction("test", [], MIRType.EMPTY)

        # Create blocks for loop
        entry = BasicBlock("entry")
        loop_header = BasicBlock("loop_header")
        loop_body = BasicBlock("loop_body")
        exit_block = BasicBlock("exit")

        func.cfg.add_block(entry)
        func.cfg.add_block(loop_header)
        func.cfg.add_block(loop_body)
        func.cfg.add_block(exit_block)
        func.cfg.set_entry_block(entry)

        # Connect for loop structure
        func.cfg.connect(entry, loop_header)
        func.cfg.connect(loop_header, loop_body)
        func.cfg.connect(loop_header, exit_block)
        func.cfg.connect(loop_body, loop_header)  # Back edge

        # Create loop counter variable
        i = Variable("i", MIRType.INT)
        func.add_local(i)

        # Initialize counter in entry
        const0 = Constant(0, MIRType.INT)
        entry.add_instruction(StoreVar(i, const0, (1, 1)))
        entry.add_instruction(Jump("loop_header", (1, 1)))

        # Loop header checks condition
        cond = func.new_temp(MIRType.BOOL)
        ten = Constant(10, MIRType.INT)
        loop_header.add_instruction(BinaryOp(cond, "<", i, ten, (1, 1)))
        loop_header.add_instruction(ConditionalJump(cond, "loop_body", (1, 1), "exit"))

        # Loop body increments counter
        one = Constant(1, MIRType.INT)
        new_i = func.new_temp(MIRType.INT)
        loop_body.add_instruction(BinaryOp(new_i, "+", i, one, (1, 1)))
        loop_body.add_instruction(StoreVar(i, new_i, (1, 1)))
        loop_body.add_instruction(Jump("loop_header", (1, 1)))

        # Exit
        exit_block.add_instruction(Return((1, 1)))

        # Convert to SSA
        construct_ssa(func)

        # Loop header should have phi node for loop variable (in phi_nodes list)
        assert len(loop_header.phi_nodes) > 0

        # Phi should have incoming from entry and loop_body (the back edge)
        phi = loop_header.phi_nodes[0]
        incoming_labels = {label for _, label in phi.incoming}
        assert "entry" in incoming_labels
        assert "loop_body" in incoming_labels

    def test_multiple_variables_ssa(self) -> None:
        """Test SSA construction with multiple variables."""
        # Create function with multiple variables
        func = MIRFunction("test", [], MIRType.EMPTY)

        # Create simple CFG
        entry = BasicBlock("entry")
        block1 = BasicBlock("block1")

        func.cfg.add_block(entry)
        func.cfg.add_block(block1)
        func.cfg.set_entry_block(entry)
        func.cfg.connect(entry, block1)

        # Create multiple variables
        x = Variable("x", MIRType.INT)
        y = Variable("y", MIRType.INT)
        z = Variable("z", MIRType.INT)

        func.add_local(x)
        func.add_local(y)
        func.add_local(z)

        # Assign values in entry
        const1 = Constant(1, MIRType.INT)
        const2 = Constant(2, MIRType.INT)

        entry.add_instruction(StoreVar(x, const1, (1, 1)))
        entry.add_instruction(StoreVar(y, const2, (1, 1)))

        # Compute z = x + y
        temp = func.new_temp(MIRType.INT)
        entry.add_instruction(BinaryOp(temp, "+", x, y, (1, 1)))
        entry.add_instruction(StoreVar(z, temp, (1, 1)))
        entry.add_instruction(Jump("block1", (1, 1)))

        # Use all variables in block1
        result = func.new_temp(MIRType.INT)
        block1.add_instruction(BinaryOp(result, "+", z, x, (1, 1)))
        block1.add_instruction(Return((1, 1), result))

        # Convert to SSA
        construct_ssa(func)

        # The pass must complete without raising and keep block1's terminator
        # (replaces a vacuous ``assert True``).
        assert any(isinstance(inst, Return) for inst in block1.instructions)

    def test_ssa_preserves_semantics(self) -> None:
        """Test that SSA construction preserves program semantics."""
        # Create function that computes: if (cond) x = 1 else x = 2; return x
        func = MIRFunction("test", [], MIRType.INT)

        # Create diamond CFG
        entry = BasicBlock("entry")
        then_block = BasicBlock("then")
        else_block = BasicBlock("else")
        merge_block = BasicBlock("merge")

        func.cfg.add_block(entry)
        func.cfg.add_block(then_block)
        func.cfg.add_block(else_block)
        func.cfg.add_block(merge_block)
        func.cfg.set_entry_block(entry)

        func.cfg.connect(entry, then_block)
        func.cfg.connect(entry, else_block)
        func.cfg.connect(then_block, merge_block)
        func.cfg.connect(else_block, merge_block)

        # Create variable
        x = Variable("x", MIRType.INT)
        func.add_local(x)

        # Add instructions
        cond = func.new_temp(MIRType.BOOL)
        entry.add_instruction(LoadConst(cond, True, (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        const1 = Constant(1, MIRType.INT)
        const2 = Constant(2, MIRType.INT)

        then_block.add_instruction(StoreVar(x, const1, (1, 1)))
        then_block.add_instruction(Jump("merge", (1, 1)))

        else_block.add_instruction(StoreVar(x, const2, (1, 1)))
        else_block.add_instruction(Jump("merge", (1, 1)))

        result = func.new_temp(MIRType.INT)
        merge_block.add_instruction(Copy(result, x, (1, 1)))
        merge_block.add_instruction(Return((1, 1), result))

        # Phi nodes are stored in ``phi_nodes`` rather than ``instructions``,
        # so count those. Take the snapshot only once the function body is
        # fully built; otherwise the "before" count is trivially zero.
        phis_before = sum(len(b.phi_nodes) for b in func.cfg.blocks.values())

        # Convert to SSA
        construct_ssa(func)

        # Should have added at least one phi node
        phis_after = sum(len(b.phi_nodes) for b in func.cfg.blocks.values())
        assert phis_after > phis_before

        # Should still have exactly one return instruction
        returns = [
            inst
            for block in func.cfg.blocks.values()
            for inst in block.instructions
            if isinstance(inst, Return)
        ]
        assert len(returns) == 1

    def test_loadconst_preservation(self) -> None:
        """Test that SSA construction preserves LoadConst instructions."""
        # The function returns the BOOL result of a ``<=`` comparison.
        func = MIRFunction("test_const", [], MIRType.BOOL)

        # Single-block CFG: only an entry block
        entry = BasicBlock("entry")
        func.cfg.add_block(entry)
        func.cfg.set_entry_block(entry)

        # Create temporaries for constants
        t0 = func.new_temp(MIRType.INT)
        t1 = func.new_temp(MIRType.INT)
        t2 = func.new_temp(MIRType.BOOL)

        # Add LoadConst instructions
        const5 = Constant(5, MIRType.INT)
        const1 = Constant(1, MIRType.INT)

        entry.add_instruction(LoadConst(t0, const5, (1, 1)))
        entry.add_instruction(LoadConst(t1, const1, (1, 1)))

        # Add binary operation using the loaded constants
        entry.add_instruction(BinaryOp(t2, "<=", t0, t1, (1, 1)))

        # Return the result
        entry.add_instruction(Return((1, 1), t2))

        def collect_loadconsts() -> list[LoadConst]:
            """Return every LoadConst instruction currently in the function."""
            return [
                inst
                for block in func.cfg.blocks.values()
                for inst in block.instructions
                if isinstance(inst, LoadConst)
            ]

        # Count LoadConst instructions before SSA
        loadconst_before = collect_loadconsts()

        # Apply SSA construction
        construct_ssa(func)

        # Count LoadConst instructions after SSA
        loadconst_after = collect_loadconsts()

        # Verify LoadConst instructions are preserved
        assert len(loadconst_before) == 2, "Should have 2 LoadConst instructions before SSA"
        assert len(loadconst_after) == 2, "Should have 2 LoadConst instructions after SSA"

        # Verify the constants are still correct
        const_values = [inst.constant.value for inst in loadconst_after if hasattr(inst.constant, "value")]
        assert 5 in const_values, "Constant 5 should be preserved"
        assert 1 in const_values, "Constant 1 should be preserved"
@@ -0,0 +1,236 @@
1
+ """Tests for tail call optimization."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import Call, Copy, Return
6
+ from machine_dialect.mir.mir_module import MIRModule
7
+ from machine_dialect.mir.mir_types import MIRType
8
+ from machine_dialect.mir.mir_values import Constant, Temp, Variable
9
+ from machine_dialect.mir.optimizations.tail_call import TailCallOptimization
10
+
11
+
12
+ def test_simple_tail_call() -> None:
13
+ """Test detection of simple tail call pattern."""
14
+ # Create a function with a tail call
15
+ module = MIRModule("test")
16
+ func = MIRFunction("factorial", [Variable("n", MIRType.INT)])
17
+ module.add_function(func)
18
+
19
+ # Create basic block with tail call pattern
20
+ block = BasicBlock("entry")
21
+
22
+ # result = call factorial(n-1)
23
+ result = Temp(MIRType.INT, 0)
24
+ call_inst = Call(result, "factorial", [Variable("n", MIRType.INT)], (1, 1))
25
+ block.add_instruction(call_inst)
26
+
27
+ # return result
28
+ block.add_instruction(Return((1, 1), result))
29
+
30
+ func.cfg.add_block(block)
31
+ func.cfg.entry_block = block
32
+
33
+ # Run optimization
34
+ optimizer = TailCallOptimization()
35
+ modified = optimizer.run_on_module(module)
36
+
37
+ # Check that the call was marked as tail call
38
+ assert modified
39
+ assert call_inst.is_tail_call
40
+ assert optimizer.stats["tail_calls_found"] == 1
41
+ assert optimizer.stats["recursive_tail_calls"] == 1
42
+
43
+
44
+ def test_tail_call_with_copy() -> None:
45
+ """Test detection of tail call with intermediate copy."""
46
+ module = MIRModule("test")
47
+ func = MIRFunction("process", [Variable("x", MIRType.INT)])
48
+ module.add_function(func)
49
+
50
+ block = BasicBlock("entry")
51
+
52
+ # temp = call helper(x)
53
+ temp = Temp(MIRType.INT, 0)
54
+ call_inst = Call(temp, "helper", [Variable("x", MIRType.INT)], (1, 1))
55
+ block.add_instruction(call_inst)
56
+
57
+ # result = temp
58
+ result = Variable("result", MIRType.INT)
59
+ block.add_instruction(Copy(result, temp, (1, 1)))
60
+
61
+ # return result
62
+ block.add_instruction(Return((1, 1), result))
63
+
64
+ func.cfg.add_block(block)
65
+ func.cfg.entry_block = block
66
+
67
+ # Run optimization
68
+ optimizer = TailCallOptimization()
69
+ modified = optimizer.run_on_module(module)
70
+
71
+ # Check that the call was marked as tail call
72
+ assert modified
73
+ assert call_inst.is_tail_call
74
+ assert optimizer.stats["tail_calls_found"] == 1
75
+
76
+
77
+ def test_void_tail_call() -> None:
78
+ """Test detection of void tail call (no return value)."""
79
+ module = MIRModule("test")
80
+ func = MIRFunction("cleanup", [])
81
+ module.add_function(func)
82
+
83
+ block = BasicBlock("entry")
84
+
85
+ # call finalize()
86
+ call_inst = Call(None, "finalize", [], (1, 1))
87
+ block.add_instruction(call_inst)
88
+
89
+ # return
90
+ block.add_instruction(Return((1, 1), None))
91
+
92
+ func.cfg.add_block(block)
93
+ func.cfg.entry_block = block
94
+
95
+ # Run optimization
96
+ optimizer = TailCallOptimization()
97
+ modified = optimizer.run_on_module(module)
98
+
99
+ # Check that the call was marked as tail call
100
+ assert modified
101
+ assert call_inst.is_tail_call
102
+ assert optimizer.stats["tail_calls_found"] == 1
103
+
104
+
105
+ def test_non_tail_call() -> None:
106
+ """Test that non-tail calls are not marked."""
107
+ module = MIRModule("test")
108
+ func = MIRFunction("compute", [Variable("x", MIRType.INT)])
109
+ module.add_function(func)
110
+
111
+ block = BasicBlock("entry")
112
+
113
+ # temp = call helper(x)
114
+ temp = Temp(MIRType.INT, 0)
115
+ call_inst = Call(temp, "helper", [Variable("x", MIRType.INT)], (1, 1))
116
+ block.add_instruction(call_inst)
117
+
118
+ # result = temp + 1 (additional computation after call)
119
+ # We would add a BinaryOp here in real code
120
+
121
+ # return something else
122
+ block.add_instruction(Return((1, 1), Constant(42, MIRType.INT)))
123
+
124
+ func.cfg.add_block(block)
125
+ func.cfg.entry_block = block
126
+
127
+ # Run optimization
128
+ optimizer = TailCallOptimization()
129
+ modified = optimizer.run_on_module(module)
130
+
131
+ # Check that the call was NOT marked as tail call
132
+ assert not modified
133
+ assert not call_inst.is_tail_call
134
+ assert optimizer.stats["tail_calls_found"] == 0
135
+
136
+
137
+ def test_multiple_tail_calls() -> None:
138
+ """Test function with multiple tail calls in different blocks."""
139
+ module = MIRModule("test")
140
+ func = MIRFunction("fibonacci", [Variable("n", MIRType.INT)])
141
+ module.add_function(func)
142
+
143
+ # Block 1: tail call to fib(n-1)
144
+ block1 = BasicBlock("block1")
145
+ temp1 = Temp(MIRType.INT, 0)
146
+ call1 = Call(temp1, "fibonacci", [Variable("n", MIRType.INT)], (1, 1))
147
+ block1.add_instruction(call1)
148
+ block1.add_instruction(Return((1, 1), temp1))
149
+
150
+ # Block 2: tail call to fib(n-2)
151
+ block2 = BasicBlock("block2")
152
+ temp2 = Temp(MIRType.INT, 1)
153
+ call2 = Call(temp2, "fibonacci", [Variable("n", MIRType.INT)], (1, 1))
154
+ block2.add_instruction(call2)
155
+ block2.add_instruction(Return((1, 1), temp2))
156
+
157
+ func.cfg.add_block(block1)
158
+ func.cfg.add_block(block2)
159
+ func.cfg.entry_block = block1
160
+
161
+ # Run optimization
162
+ optimizer = TailCallOptimization()
163
+ modified = optimizer.run_on_module(module)
164
+
165
+ # Check that both calls were marked as tail calls
166
+ assert modified
167
+ assert call1.is_tail_call
168
+ assert call2.is_tail_call
169
+ assert optimizer.stats["tail_calls_found"] == 2
170
+ assert optimizer.stats["recursive_tail_calls"] == 2
171
+
172
+
173
+ def test_mutual_recursion() -> None:
174
+ """Test tail calls in mutually recursive functions."""
175
+ module = MIRModule("test")
176
+
177
+ # Function even calls odd
178
+ even_func = MIRFunction("even", [Variable("n", MIRType.INT)])
179
+ even_block = BasicBlock("entry")
180
+ even_result = Temp(MIRType.BOOL, 0)
181
+ even_call = Call(even_result, "odd", [Variable("n", MIRType.INT)], (1, 1))
182
+ even_block.add_instruction(even_call)
183
+ even_block.add_instruction(Return((1, 1), even_result))
184
+ even_func.cfg.add_block(even_block)
185
+ even_func.cfg.entry_block = even_block
186
+ module.add_function(even_func)
187
+
188
+ # Function odd calls even
189
+ odd_func = MIRFunction("odd", [Variable("n", MIRType.INT)])
190
+ odd_block = BasicBlock("entry")
191
+ odd_result = Temp(MIRType.BOOL, 1)
192
+ odd_call = Call(odd_result, "even", [Variable("n", MIRType.INT)], (1, 1))
193
+ odd_block.add_instruction(odd_call)
194
+ odd_block.add_instruction(Return((1, 1), odd_result))
195
+ odd_func.cfg.add_block(odd_block)
196
+ odd_func.cfg.entry_block = odd_block
197
+ module.add_function(odd_func)
198
+
199
+ # Run optimization
200
+ optimizer = TailCallOptimization()
201
+ modified = optimizer.run_on_module(module)
202
+
203
+ # Check that both mutual recursive calls were marked as tail calls
204
+ assert modified
205
+ assert even_call.is_tail_call
206
+ assert odd_call.is_tail_call
207
+ assert optimizer.stats["tail_calls_found"] == 2
208
+ # These are not self-recursive, so recursive count should be 0
209
+ assert optimizer.stats["recursive_tail_calls"] == 0
210
+
211
+
212
+ def test_already_optimized() -> None:
213
+ """Test that already marked tail calls are not counted again."""
214
+ module = MIRModule("test")
215
+ func = MIRFunction("test", [])
216
+ module.add_function(func)
217
+
218
+ block = BasicBlock("entry")
219
+
220
+ # Create a call already marked as tail call
221
+ result = Temp(MIRType.INT, 0)
222
+ call_inst = Call(result, "helper", [], (1, 1), is_tail_call=True)
223
+ block.add_instruction(call_inst)
224
+ block.add_instruction(Return((1, 1), result))
225
+
226
+ func.cfg.add_block(block)
227
+ func.cfg.entry_block = block
228
+
229
+ # Run optimization
230
+ optimizer = TailCallOptimization()
231
+ modified = optimizer.run_on_module(module)
232
+
233
+ # Check that no modifications were made
234
+ assert not modified
235
+ assert call_inst.is_tail_call # Still marked
236
+ assert optimizer.stats["tail_calls_found"] == 0 # Not counted as new