machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,204 @@
1
+ """Tests for associativity and commutativity optimizations in algebraic simplification."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import BinaryOp, LoadConst
6
+ from machine_dialect.mir.mir_module import MIRModule
7
+ from machine_dialect.mir.mir_transformer import MIRTransformer
8
+ from machine_dialect.mir.mir_types import MIRType
9
+ from machine_dialect.mir.mir_values import Constant, Temp
10
+ from machine_dialect.mir.optimizations.algebraic_simplification import AlgebraicSimplification
11
+
12
+
13
class TestAlgebraicAssociativity:
    """Test associativity and commutativity optimizations.

    Every test builds a single-block function, runs ``AlgebraicSimplification``
    on it, and then inspects the rewritten instructions and the pass's
    ``stats`` counters.
    """

    def setup_method(self) -> None:
        """Create a module with one INT-returning function and an empty entry block."""
        self.module = MIRModule("test")
        self.func = MIRFunction("test_func", [], MIRType.INT)
        self.block = BasicBlock("entry")
        self.func.cfg.add_block(self.block)
        self.func.cfg.entry_block = self.block
        self.module.add_function(self.func)
        # Not referenced by the tests below; kept so this fixture matches the
        # fixtures of the sibling algebraic-simplification test classes.
        self.transformer = MIRTransformer(self.func)
        self.opt = AlgebraicSimplification()

    @staticmethod
    def _assert_binary_op(inst: object, op: str, left: object, right: object) -> None:
        """Assert that ``inst`` is ``BinaryOp(_, op, left, right)``.

        An ``int`` expected operand matches a ``Constant`` carrying that value;
        any other expected operand (e.g. a ``Temp``) is compared with ``==``.
        """
        assert isinstance(inst, BinaryOp)
        assert inst.op == op
        for actual, expected in ((inst.left, left), (inst.right, right)):
            if isinstance(expected, int):
                assert isinstance(actual, Constant)
                assert actual.value == expected
            else:
                assert actual == expected

    def test_addition_associativity_left(self) -> None:
        """Test (a + 2) + 3 → a + 5."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)

        # t0 = 10 (plays the role of "a"), t1 = t0 + 2, t2 = t1 + 3
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "+", t0, Constant(2, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "+", t1, Constant(3, MIRType.INT), (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("associativity_applied") == 1
        instructions = list(self.block.instructions)
        # The last instruction should become t2 = t0 + 5 (2 + 3 folded).
        self._assert_binary_op(instructions[2], "+", t0, 5)

    def test_multiplication_associativity_left(self) -> None:
        """Test (a * 2) * 3 → a * 6."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)

        # t0 = 10 (plays the role of "a"), t1 = t0 * 2, t2 = t1 * 3
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "*", t0, Constant(2, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "*", t1, Constant(3, MIRType.INT), (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("associativity_applied") == 1
        instructions = list(self.block.instructions)
        # The last instruction should become t2 = t0 * 6 (2 * 3 folded).
        self._assert_binary_op(instructions[2], "*", t0, 6)

    def test_addition_commutativity_right(self) -> None:
        """Test 3 + (a + 2) → 5 + a."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)

        # t0 = 10 (plays the role of "a"), t1 = t0 + 2, t2 = 3 + t1
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "+", t0, Constant(2, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "+", Constant(3, MIRType.INT), t1, (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("associativity_applied") == 1
        instructions = list(self.block.instructions)
        # The last instruction should become t2 = 5 + t0 (constant kept on the left).
        self._assert_binary_op(instructions[2], "+", 5, t0)

    def test_multiplication_commutativity_right(self) -> None:
        """Test 3 * (a * 2) → 6 * a."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)

        # t0 = 10 (plays the role of "a"), t1 = t0 * 2, t2 = 3 * t1
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "*", t0, Constant(2, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "*", Constant(3, MIRType.INT), t1, (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("associativity_applied") == 1
        instructions = list(self.block.instructions)
        # The last instruction should become t2 = 6 * t0 (constant kept on the left).
        self._assert_binary_op(instructions[2], "*", 6, t0)

    def test_nested_addition_associativity(self) -> None:
        """Test ((a + 1) + 2) + 3 → a + 6 in a single pass."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # t0 = 10 (plays the role of "a"), t1 = t0 + 1, t2 = t1 + 2, t3 = t2 + 3
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "+", t0, Constant(1, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "+", t1, Constant(2, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "+", t2, Constant(3, MIRType.INT), (1, 1)))

        # A single run is expected to fold the whole chain.
        changed = self.opt.run_on_function(self.func)
        assert changed

        # One rewrite for t2 and one for t3.
        assert self.opt.stats.get("associativity_applied", 0) >= 2

        instructions = list(self.block.instructions)

        # t2 = t0 + 3 (1 + 2 folded)
        self._assert_binary_op(instructions[2], "+", t0, 3)

        # t3 = t0 + 6 (the already-folded 3, plus 3)
        self._assert_binary_op(instructions[3], "+", t0, 6)

        # A second run must find nothing left to do (fixed point reached).
        self.opt.stats.clear()
        changed = self.opt.run_on_function(self.func)
        assert not changed, "Second pass should find nothing to optimize"

    def test_no_associativity_without_constants(self) -> None:
        """Test that (a + b) + c doesn't change without constants."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)
        t4 = Temp(MIRType.INT)

        # Both additions use only temps, so there is no constant pair to fold.
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(20, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t2, Constant(30, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "+", t0, t1, (1, 1)))
        self.block.add_instruction(BinaryOp(t4, "+", t3, t2, (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert not changed
        assert "associativity_applied" not in self.opt.stats
        instructions = list(self.block.instructions)
        # The final addition must be left exactly as written: t4 = t3 + t2.
        self._assert_binary_op(instructions[4], "+", t3, t2)
@@ -0,0 +1,221 @@
1
+ """Tests for complex pattern matching in algebraic simplification."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import BinaryOp, Copy, LoadConst, UnaryOp
6
+ from machine_dialect.mir.mir_module import MIRModule
7
+ from machine_dialect.mir.mir_transformer import MIRTransformer
8
+ from machine_dialect.mir.mir_types import MIRType
9
+ from machine_dialect.mir.mir_values import Constant, Temp
10
+ from machine_dialect.mir.optimizations.algebraic_simplification import AlgebraicSimplification
11
+
12
+
13
class TestAlgebraicComplexPatterns:
    """Test complex pattern matching in algebraic simplification.

    Every test builds a single-block function, runs ``AlgebraicSimplification``
    on it, and checks both the rewritten instruction and the
    ``complex_pattern_matched`` counter in the pass's ``stats``.
    """

    def setup_method(self) -> None:
        """Create a module with one INT-returning function and an empty entry block."""
        self.module = MIRModule("test")
        self.func = MIRFunction("test_func", [], MIRType.INT)
        self.block = BasicBlock("entry")
        self.func.cfg.add_block(self.block)
        self.func.cfg.entry_block = self.block
        self.module.add_function(self.func)
        # Not referenced by the tests below; kept so this fixture matches the
        # fixtures of the sibling algebraic-simplification test classes.
        self.transformer = MIRTransformer(self.func)
        self.opt = AlgebraicSimplification()

    @staticmethod
    def _assert_copy(inst: object, source: object, dest: object) -> None:
        """Assert that ``inst`` is a ``Copy`` from ``source`` into ``dest``."""
        assert isinstance(inst, Copy)
        assert inst.source == source
        assert inst.dest == dest

    def _assert_single_match(self, changed: bool) -> None:
        """Assert the pass reported a change and exactly one complex-pattern match."""
        assert changed
        assert self.opt.stats.get("complex_pattern_matched") == 1

    def test_add_then_subtract_pattern(self) -> None:
        """Test (a + b) - b → a."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # a = 10, b = 5, t2 = a + b, t3 = t2 - b
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(5, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "+", t0, t1, (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "-", t2, t1, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become a plain copy: t3 = a.
        self._assert_copy(list(self.block.instructions)[3], t0, t3)

    def test_subtract_then_add_pattern(self) -> None:
        """Test (a - b) + b → a."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # a = 10, b = 5, t2 = a - b, t3 = t2 + b
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(5, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "-", t0, t1, (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "+", t2, t1, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become a plain copy: t3 = a.
        self._assert_copy(list(self.block.instructions)[3], t0, t3)

    def test_multiply_then_divide_pattern(self) -> None:
        """Test (a * b) / b → a."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # a = 10, b = 5, t2 = a * b, t3 = t2 / b
        # NOTE(review): this rewrite drops the division, so a b == 0 trap would
        # disappear; b is a non-zero constant here — confirm the pass guards this.
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(5, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "*", t0, t1, (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "/", t2, t1, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become a plain copy: t3 = a.
        self._assert_copy(list(self.block.instructions)[3], t0, t3)

    def test_divide_then_multiply_pattern(self) -> None:
        """Test (a / b) * b → a."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # a = 10, b = 5, t2 = a / b, t3 = t2 * b
        # NOTE(review): with truncating integer division, (a / b) * b equals a
        # only when b divides a (true for 10 and 5). Confirm that INT "/" in MIR
        # makes this rewrite sound in general.
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(5, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "/", t0, t1, (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "*", t2, t1, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become a plain copy: t3 = a.
        self._assert_copy(list(self.block.instructions)[3], t0, t3)

    def test_zero_minus_x_pattern(self) -> None:
        """Test 0 - x → -x."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)

        # x = 42, t1 = 0 - x
        self.block.add_instruction(LoadConst(t0, Constant(42, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "-", Constant(0, MIRType.INT), t0, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The subtraction should become a unary negation: t1 = -x.
        unary_inst = list(self.block.instructions)[1]
        assert isinstance(unary_inst, UnaryOp)
        assert unary_inst.op == "-"
        assert unary_inst.operand == t0
        assert unary_inst.dest == t1

    def test_chained_subtraction_constants(self) -> None:
        """Test (a - b) - c → a - (b + c) when b and c are constants."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)

        # a = t0, t1 = a - 3, t2 = t1 - 2 → t2 = a - 5
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "-", t0, Constant(3, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "-", t1, Constant(2, MIRType.INT), (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become t2 = t0 - 5 (3 + 2 folded).
        binary_inst = list(self.block.instructions)[2]
        assert isinstance(binary_inst, BinaryOp)
        assert binary_inst.op == "-"
        assert binary_inst.left == t0
        assert isinstance(binary_inst.right, Constant)
        assert binary_inst.right.value == 5

    def test_commutative_add_subtract_pattern(self) -> None:
        """Test (b + a) - b → a (commutative version)."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # a = 10, b = 5, t2 = b + a, t3 = t2 - b
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(5, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "+", t1, t0, (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "-", t2, t1, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become a plain copy: t3 = a.
        self._assert_copy(list(self.block.instructions)[3], t0, t3)

    def test_commutative_multiply_divide_pattern(self) -> None:
        """Test (b * a) / b → a (commutative version)."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        t2 = Temp(MIRType.INT)
        t3 = Temp(MIRType.INT)

        # a = 10, b = 5, t2 = b * a, t3 = t2 / b
        self.block.add_instruction(LoadConst(t0, Constant(10, MIRType.INT), (1, 1)))
        self.block.add_instruction(LoadConst(t1, Constant(5, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t2, "*", t1, t0, (1, 1)))
        self.block.add_instruction(BinaryOp(t3, "/", t2, t1, (1, 1)))

        self._assert_single_match(self.opt.run_on_function(self.func))
        # The last instruction should become a plain copy: t3 = a.
        self._assert_copy(list(self.block.instructions)[3], t0, t3)
@@ -0,0 +1,126 @@
1
+ """Tests for division optimizations in algebraic simplification."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import BinaryOp, Copy, LoadConst, UnaryOp
6
+ from machine_dialect.mir.mir_module import MIRModule
7
+ from machine_dialect.mir.mir_transformer import MIRTransformer
8
+ from machine_dialect.mir.mir_types import MIRType
9
+ from machine_dialect.mir.mir_values import Constant, Temp
10
+ from machine_dialect.mir.optimizations.algebraic_simplification import AlgebraicSimplification
11
+
12
+
13
class TestAlgebraicSimplificationDivision:
    """Test division operation simplifications.

    Each test builds a single entry block, runs AlgebraicSimplification once
    over the function, and then checks both the ``division_simplified``
    statistic and the shape of the rewritten instruction.
    """

    def setup_method(self) -> None:
        """Set up a module holding one function with a single entry block."""
        self.module = MIRModule("test")
        self.func = MIRFunction("test_func", [], MIRType.INT)
        self.block = BasicBlock("entry")
        self.func.cfg.add_block(self.block)
        self.func.cfg.entry_block = self.block
        self.module.add_function(self.func)
        # NOTE(review): no test below reads self.transformer; it is kept so the
        # fixture setup is unchanged.
        self.transformer = MIRTransformer(self.func)
        self.opt = AlgebraicSimplification()

    def test_divide_by_one(self) -> None:
        """Test x / 1 → x."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        self.block.add_instruction(LoadConst(t0, Constant(42, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "/", t0, Constant(1, MIRType.INT), (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("division_simplified") == 1
        instructions = list(self.block.instructions)
        copy_inst = instructions[1]
        # The isinstance check also narrows the type for static checkers.
        assert isinstance(copy_inst, Copy)
        assert copy_inst.source == t0
        # The rewrite must still write the division's original destination.
        assert copy_inst.dest == t1

    def test_divide_self(self) -> None:
        """Test x / x → 1."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        self.block.add_instruction(LoadConst(t0, Constant(42, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "/", t0, t0, (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("division_simplified") == 1
        instructions = list(self.block.instructions)
        load_inst = instructions[1]
        assert isinstance(load_inst, LoadConst)
        assert load_inst.constant.value == 1

    def test_zero_divided_by_x(self) -> None:
        """Test 0 / x → 0."""
        t0 = Temp(MIRType.INT)
        # Here the divisor x is supplied as the constant 42 rather than a Temp.
        self.block.add_instruction(BinaryOp(t0, "/", Constant(0, MIRType.INT), Constant(42, MIRType.INT), (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("division_simplified") == 1
        instructions = list(self.block.instructions)
        load_inst = instructions[0]
        assert isinstance(load_inst, LoadConst)
        assert load_inst.constant.value == 0

    def test_divide_by_negative_one(self) -> None:
        """Test x / -1 → -x."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        self.block.add_instruction(LoadConst(t0, Constant(42, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "/", t0, Constant(-1, MIRType.INT), (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("division_simplified") == 1
        instructions = list(self.block.instructions)
        unary_inst = instructions[1]
        assert isinstance(unary_inst, UnaryOp)
        assert unary_inst.op == "-"
        assert unary_inst.operand == t0

    def test_integer_divide_by_one(self) -> None:
        """Test x // 1 → x."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        self.block.add_instruction(LoadConst(t0, Constant(42, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "//", t0, Constant(1, MIRType.INT), (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("division_simplified") == 1
        instructions = list(self.block.instructions)
        copy_inst = instructions[1]
        assert isinstance(copy_inst, Copy)
        assert copy_inst.source == t0
        # The rewrite must still write the division's original destination.
        assert copy_inst.dest == t1

    def test_integer_divide_self(self) -> None:
        """Test x // x → 1."""
        t0 = Temp(MIRType.INT)
        t1 = Temp(MIRType.INT)
        self.block.add_instruction(LoadConst(t0, Constant(42, MIRType.INT), (1, 1)))
        self.block.add_instruction(BinaryOp(t1, "//", t0, t0, (1, 1)))

        changed = self.opt.run_on_function(self.func)

        assert changed
        assert self.opt.stats.get("division_simplified") == 1
        instructions = list(self.block.instructions)
        load_inst = instructions[1]
        assert isinstance(load_inst, LoadConst)
        assert load_inst.constant.value == 1