machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,255 @@
1
+ """Tests for cross-block constant propagation."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import (
6
+ BinaryOp,
7
+ ConditionalJump,
8
+ Copy,
9
+ Jump,
10
+ LoadConst,
11
+ Phi,
12
+ Return,
13
+ )
14
+ from machine_dialect.mir.mir_types import MIRType
15
+ from machine_dialect.mir.mir_values import Constant, Temp, Variable
16
+ from machine_dialect.mir.optimizations.constant_propagation import ConstantPropagation
17
+
18
+
19
class TestCrossBlockConstantPropagation:
    """Test cross-block constant propagation.

    Each test hand-builds a small MIR function and its control-flow graph,
    runs ``ConstantPropagation.run_on_function`` on it, and inspects the result.
    """

    @staticmethod
    def _link(src: BasicBlock, dst: BasicBlock) -> None:
        """Record the CFG edge ``src -> dst`` on both endpoint blocks."""
        src.add_successor(dst)
        dst.add_predecessor(src)

    def test_constant_through_multiple_blocks(self) -> None:
        """Test propagation of constants through a straight-line chain of blocks."""
        func = MIRFunction("test", [])

        # Block 1: x = 10
        block1 = BasicBlock("entry")
        x = Variable("x", MIRType.INT)
        func.add_local(x)
        block1.add_instruction(LoadConst(x, Constant(10, MIRType.INT), (1, 1)))
        block1.add_instruction(Jump("block2", (1, 1)))

        # Block 2: y = x + 5
        block2 = BasicBlock("block2")
        y = Variable("y", MIRType.INT)
        func.add_local(y)
        t1 = Temp(MIRType.INT, 0)
        block2.add_instruction(BinaryOp(t1, "+", x, Constant(5, MIRType.INT), (1, 1)))
        block2.add_instruction(Copy(y, t1, (1, 1)))
        block2.add_instruction(Jump("block3", (1, 1)))

        # Block 3: z = y * 2
        block3 = BasicBlock("block3")
        z = Variable("z", MIRType.INT)
        func.add_local(z)
        t2 = Temp(MIRType.INT, 1)
        block3.add_instruction(BinaryOp(t2, "*", y, Constant(2, MIRType.INT), (1, 1)))
        block3.add_instruction(Copy(z, t2, (1, 1)))
        block3.add_instruction(Return((1, 1), z))

        # CFG: entry -> block2 -> block3
        for block in (block1, block2, block3):
            func.cfg.add_block(block)
        func.cfg.set_entry_block(block1)
        self._link(block1, block2)
        self._link(block2, block3)

        optimizer = ConstantPropagation()
        modified = optimizer.run_on_function(func)

        # x = 10 makes y = 15 and z = 30 foldable, so the pass must report a change.
        assert modified
        # The block must still end with its Return terminator; its operand may
        # or may not have been replaced with a constant.
        final_inst = block3.instructions[-1]
        assert isinstance(final_inst, Return)

    def test_phi_node_constant_propagation(self) -> None:
        """Test constant propagation through phi nodes."""
        func = MIRFunction("test", [])

        # Entry block: branch on a constant-true condition.
        entry = BasicBlock("entry")
        cond = Variable("cond", MIRType.BOOL)
        func.add_local(cond)
        entry.add_instruction(LoadConst(cond, Constant(True, MIRType.BOOL), (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Then block: x = 10
        then_block = BasicBlock("then")
        x_then = Temp(MIRType.INT, 0)
        then_block.add_instruction(LoadConst(x_then, Constant(10, MIRType.INT), (1, 1)))
        then_block.add_instruction(Jump("merge", (1, 1)))

        # Else block: x = 10 (same value)
        else_block = BasicBlock("else")
        x_else = Temp(MIRType.INT, 1)
        else_block.add_instruction(LoadConst(x_else, Constant(10, MIRType.INT), (1, 1)))
        else_block.add_instruction(Jump("merge", (1, 1)))

        # Merge block: x = phi(x_then, x_else); result = x + 5
        merge = BasicBlock("merge")
        x = Variable("x", MIRType.INT)
        func.add_local(x)
        phi = Phi(x, [(x_then, "then"), (x_else, "else")], (1, 1))
        merge.phi_nodes.append(phi)
        result = Temp(MIRType.INT, 2)
        merge.add_instruction(BinaryOp(result, "+", x, Constant(5, MIRType.INT), (1, 1)))
        merge.add_instruction(Return((1, 1), result))

        # CFG: entry -> {then, else} -> merge
        for block in (entry, then_block, else_block, merge):
            func.cfg.add_block(block)
        func.cfg.set_entry_block(entry)
        self._link(entry, then_block)
        self._link(entry, else_block)
        self._link(then_block, merge)
        self._link(else_block, merge)

        optimizer = ConstantPropagation()
        modified = optimizer.run_on_function(func)

        assert modified
        # Both branches assign the same constant (10) to x, so the phi
        # should resolve to 10 and x + 5 should fold to 15.

    def test_loop_constant_propagation(self) -> None:
        """Test that constant propagation does not fold a loop-varying guard."""
        func = MIRFunction("test", [])

        # Entry block: i = 0, sum = 0
        entry = BasicBlock("entry")
        i = Variable("i", MIRType.INT)
        sum_var = Variable("sum", MIRType.INT)
        func.add_local(i)
        func.add_local(sum_var)
        entry.add_instruction(LoadConst(i, Constant(0, MIRType.INT), (1, 1)))
        entry.add_instruction(LoadConst(sum_var, Constant(0, MIRType.INT), (1, 1)))
        entry.add_instruction(Jump("loop", (1, 1)))

        # Loop header: phi nodes for the loop variables (back-edge inputs added below).
        loop = BasicBlock("loop")
        i_phi = Phi(i, [(i, "entry")], (1, 1))
        sum_phi = Phi(sum_var, [(sum_var, "entry")], (1, 1))
        loop.phi_nodes.append(i_phi)
        loop.phi_nodes.append(sum_phi)

        # Guard: i < 10
        t_cond = Temp(MIRType.BOOL, 0)
        loop.add_instruction(BinaryOp(t_cond, "<", i, Constant(10, MIRType.INT), (1, 1)))
        loop.add_instruction(ConditionalJump(t_cond, "body", (1, 1), "exit"))

        # Loop body: sum = sum + i; i = i + 1
        body = BasicBlock("body")
        t_sum = Temp(MIRType.INT, 1)
        body.add_instruction(BinaryOp(t_sum, "+", sum_var, i, (1, 1)))
        body.add_instruction(Copy(sum_var, t_sum, (1, 1)))
        t_i = Temp(MIRType.INT, 2)
        body.add_instruction(BinaryOp(t_i, "+", i, Constant(1, MIRType.INT), (1, 1)))
        body.add_instruction(Copy(i, t_i, (1, 1)))
        body.add_instruction(Jump("loop", (1, 1)))

        # Exit block
        exit_block = BasicBlock("exit")
        exit_block.add_instruction(Return((1, 1), sum_var))

        # CFG: entry -> loop -> {body, exit}, with body -> loop as the back-edge.
        for block in (entry, loop, body, exit_block):
            func.cfg.add_block(block)
        func.cfg.set_entry_block(entry)
        self._link(entry, loop)
        self._link(loop, body)
        self._link(loop, exit_block)
        self._link(body, loop)  # Back-edge

        # Back-edge inputs for the header phi nodes.
        i_phi.incoming.append((i, "body"))
        sum_phi.incoming.append((sum_var, "body"))

        optimizer = ConstantPropagation()
        optimizer.run_on_function(func)

        # Two definitions of `i` reach the loop header: 0 from entry and i + 1
        # from the back-edge. The guard `i < 10` is therefore not a compile-time
        # constant. Folding it would turn the header into an unconditional
        # branch, giving an infinite loop or a skipped body. The header must
        # still end in a conditional jump.
        assert isinstance(loop.instructions[-1], ConditionalJump)

    def test_conditional_constant_propagation(self) -> None:
        """Test constant propagation with conditional branches."""
        func = MIRFunction("test", [])

        # Entry: x = 5, y = 10
        entry = BasicBlock("entry")
        x = Variable("x", MIRType.INT)
        y = Variable("y", MIRType.INT)
        func.add_local(x)
        func.add_local(y)
        entry.add_instruction(LoadConst(x, Constant(5, MIRType.INT), (1, 1)))
        entry.add_instruction(LoadConst(y, Constant(10, MIRType.INT), (1, 1)))

        # Condition x < y (should be constant True)
        cond = Temp(MIRType.BOOL, 0)
        entry.add_instruction(BinaryOp(cond, "<", x, y, (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Then block (should be taken)
        then_block = BasicBlock("then")
        result_then = Temp(MIRType.INT, 1)
        then_block.add_instruction(BinaryOp(result_then, "+", x, y, (1, 1)))
        then_block.add_instruction(Return((1, 1), result_then))

        # Else block (dead code)
        else_block = BasicBlock("else")
        result_else = Temp(MIRType.INT, 2)
        else_block.add_instruction(BinaryOp(result_else, "-", y, x, (1, 1)))
        else_block.add_instruction(Return((1, 1), result_else))

        # CFG: entry -> {then, else}
        for block in (entry, then_block, else_block):
            func.cfg.add_block(block)
        func.cfg.set_entry_block(entry)
        self._link(entry, then_block)
        self._link(entry, else_block)

        optimizer = ConstantPropagation()
        modified = optimizer.run_on_function(func)

        assert modified
        # The condition x < y should fold to True, and the branch may
        # subsequently be simplified.
@@ -0,0 +1,166 @@
1
+ """Test custom optimization passes functionality."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import LoadConst, Return
6
+ from machine_dialect.mir.mir_module import MIRModule
7
+ from machine_dialect.mir.mir_types import MIRType
8
+ from machine_dialect.mir.mir_values import Temp
9
+ from machine_dialect.mir.optimization_config import OptimizationConfig
10
+ from machine_dialect.mir.optimize_mir import optimize_mir
11
+
12
+
13
def create_simple_module() -> MIRModule:
    """Create a simple test module.

    The module holds a single function ``main``, declared to return
    ``MIRType.INT``. Its only block, ``entry``, loads the value 42 into a
    temporary that is never read, then executes a ``Return`` with no value.

    Returns:
        A MIR module with a simple main function.
    """
    module = MIRModule("test")
    func = MIRFunction("main", [], MIRType.INT)
    entry = BasicBlock("entry")
    entry.add_instruction(LoadConst(Temp(MIRType.INT), 42, (1, 1)))
    entry.add_instruction(Return((1, 1)))
    func.cfg.add_block(entry)
    # Mark `entry` as the function's entry block through the CFG API, rather
    # than attaching an undeclared attribute to the CFG object.
    func.cfg.set_entry_block(entry)
    module.add_function(func)
    return module
28
+
29
+
30
def test_custom_passes_override_default() -> None:
    """Custom passes replace the level-2 default pipeline.

    Level 2 would normally schedule many passes. When an explicit list is
    supplied, only the requested passes (plus any internal dependencies, such
    as the use-def chains that constant-propagation relies on) should run.
    """
    requested = ["constant-propagation", "dce"]
    _, stats = optimize_mir(
        create_simple_module(),
        optimization_level=2,
        custom_passes=requested,
    )

    # Every requested pass ran.
    for pass_name in requested:
        assert pass_name in stats

    # Level-2 defaults that were not requested (common subexpression
    # elimination, strength reduction) did not run.
    for skipped in ("cse", "strength-reduction"):
        assert skipped not in stats
47
+
48
+
49
def test_no_custom_passes_uses_default() -> None:
    """Without ``custom_passes``, level 2 runs its default pipeline."""
    _, stats = optimize_mir(create_simple_module(), optimization_level=2)

    # Constant handling may be reported under either pass name.
    constant_pass_ran = any(name in stats for name in ("constant-propagation", "constant-folding"))
    assert constant_pass_ran
    assert "dce" in stats
    # Common subexpression elimination is part of the level-2 defaults.
    assert "cse" in stats
61
+
62
+
63
def test_custom_passes_empty_list() -> None:
    """An explicitly empty custom pass list runs nothing, even at level 2."""
    _, stats = optimize_mir(create_simple_module(), optimization_level=2, custom_passes=[])

    # No optimization pass may appear in the statistics.
    assert not stats
72
+
73
+
74
def test_custom_passes_at_level_0() -> None:
    """Custom passes still run when the optimization level is 0.

    Level 0 normally runs no optimizations, so a pass that shows up in the
    statistics must have come from the explicit request.
    """
    requested = ["dce"]
    _, stats = optimize_mir(create_simple_module(), optimization_level=0, custom_passes=requested)
    assert "dce" in stats
84
+
85
+
86
def test_custom_passes_with_dependencies() -> None:
    """A custom pass that has analysis prerequisites still runs.

    ``constant-propagation`` requires use-def chain analysis. The pass manager
    handles such dependencies internally, and analysis passes are not
    guaranteed to appear in the statistics. Only the requested optimization
    pass is therefore checked.
    """
    _, stats = optimize_mir(
        create_simple_module(),
        optimization_level=1,
        custom_passes=["constant-propagation"],
    )
    assert "constant-propagation" in stats
98
+
99
+
100
def test_custom_passes_preserve_module() -> None:
    """Running custom passes keeps the module name and its function set intact."""
    module = create_simple_module()
    functions_before = len(module.functions)
    first_name_before = None
    if module.functions:
        first_name_before = next(iter(module.functions.values())).name

    optimized, _ = optimize_mir(module, optimization_level=1, custom_passes=["dce"])

    # Same number of functions, same leading function, same module name.
    assert len(optimized.functions) == functions_before
    if first_name_before:
        first_after = next(iter(optimized.functions.values()))
        assert first_after.name == first_name_before
    assert optimized.name == module.name
115
+
116
+
117
def test_invalid_custom_pass_name() -> None:
    """Test that an unknown pass name is either rejected or skipped.

    Two outcomes are acceptable:

    * ``optimize_mir`` raises ``KeyError`` or ``ValueError`` for the unknown
      name, or
    * it ignores the unknown name and still runs the valid ``dce`` pass.
    """
    module = create_simple_module()
    custom_passes = ["invalid-pass-name", "dce"]

    try:
        _optimized, stats = optimize_mir(module, optimization_level=1, custom_passes=custom_passes)
    except (KeyError, ValueError):
        # Rejecting the unknown pass name outright is acceptable.
        pass
    else:
        # The call succeeded, so the unknown name was skipped and the valid
        # pass must still have run.
        assert "dce" in stats
134
+
135
+
136
def test_custom_passes_order_preserved() -> None:
    """A custom pass list in non-default order, with a repeat, runs every pass.

    The statistics may aggregate repeated runs of the same pass. This test
    therefore only checks that each listed pass ran; it does not observe the
    order in which they ran.
    """
    pipeline = ["dce", "constant-propagation", "dce"]
    _, stats = optimize_mir(create_simple_module(), optimization_level=1, custom_passes=pipeline)

    # Check each distinct pass once, in first-appearance order.
    for name in dict.fromkeys(pipeline):
        assert name in stats
148
+
149
+
150
def test_custom_passes_with_config() -> None:
    """Custom passes take precedence over a caller-supplied OptimizationConfig."""
    module = create_simple_module()

    # Level-2 config with debugging and per-pass statistics switched on.
    config = OptimizationConfig.from_level(2)
    config.debug_passes = True
    config.pass_statistics = True

    _, stats = optimize_mir(
        module,
        optimization_level=2,
        config=config,
        custom_passes=["constant-propagation"],
    )

    assert "constant-propagation" in stats
    # CSE belongs to the level-2 defaults but was not requested.
    assert "cse" not in stats