machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,422 @@
1
+ """Strength reduction optimization pass.
2
+
3
+ This module implements strength reduction at the MIR level, replacing
4
+ expensive operations with cheaper equivalents.
5
+ """
6
+
7
+ from machine_dialect.mir.basic_block import BasicBlock
8
+ from machine_dialect.mir.mir_function import MIRFunction
9
+ from machine_dialect.mir.mir_instructions import (
10
+ BinaryOp,
11
+ Copy,
12
+ LoadConst,
13
+ MIRInstruction,
14
+ UnaryOp,
15
+ )
16
+ from machine_dialect.mir.mir_transformer import MIRTransformer
17
+ from machine_dialect.mir.mir_types import MIRType
18
+ from machine_dialect.mir.mir_values import Constant, MIRValue
19
+ from machine_dialect.mir.optimization_pass import (
20
+ OptimizationPass,
21
+ PassInfo,
22
+ PassType,
23
+ PreservationLevel,
24
+ )
25
+
26
+
27
+ class StrengthReduction(OptimizationPass):
28
+ """Strength reduction optimization pass."""
29
+
30
def get_info(self) -> PassInfo:
    """Describe the strength-reduction pass.

    Returns:
        A ``PassInfo`` named ``"strength-reduction"``, typed as an
        optimization, with no required passes and ``PreservationLevel.CFG``
        as its preservation level.
    """
    info = PassInfo(
        name="strength-reduction",
        description="Replace expensive operations with cheaper equivalents",
        pass_type=PassType.OPTIMIZATION,
        requires=[],
        preserves=PreservationLevel.CFG,
    )
    return info
43
+
44
def run_on_function(self, function: MIRFunction) -> bool:
    """Apply strength reduction to every basic block of a function.

    Args:
        function: The function to optimize.

    Returns:
        The ``modified`` flag of the transformer used for the rewrites,
        read after all blocks of the function's CFG have been visited.
    """
    rewriter = MIRTransformer(function)
    for current_block in function.cfg.blocks.values():
        self._reduce_block(current_block, rewriter)
    return rewriter.modified
59
+
60
def _reduce_block(self, block: BasicBlock, transformer: MIRTransformer) -> None:
    """Visit each instruction of a block and try to reduce it.

    Binary operations are dispatched to ``_reduce_binary_op`` and unary
    operations to ``_reduce_unary_op``; every other instruction kind is
    left alone.

    Args:
        block: The block to optimize.
        transformer: MIR transformer that performs the replacements.
    """
    # Iterate over a snapshot: the handlers replace instructions through
    # the transformer, which presumably mutates ``block.instructions``.
    snapshot = list(block.instructions)
    for candidate in snapshot:
        if isinstance(candidate, BinaryOp):
            self._reduce_binary_op(candidate, block, transformer)
            continue
        if isinstance(candidate, UnaryOp):
            self._reduce_unary_op(candidate, block, transformer)
72
+
73
def _reduce_binary_op(
    self,
    inst: BinaryOp,
    block: BasicBlock,
    transformer: MIRTransformer,
) -> None:
    """Apply strength reduction to a binary operation.

    The following rewrites are attempted; the first one that applies
    replaces ``inst`` and ends processing:

    * ``x * 2**k`` or ``2**k * x``  ->  ``x << k``            (k > 0)
    * ``x // 2**k``                 ->  ``x >> k``            (k > 0)
    * ``x % 2**k``                  ->  ``x & (2**k - 1)``

    If none applies, the instruction is handed to
    ``_apply_algebraic_simplifications``.

    ``/`` is intentionally never turned into a shift. A shift always yields
    an integer, so it cannot reproduce a division that is distinct from
    floor division (e.g. ``6 / 4`` vs ``6 >> 2``); only ``//`` is eligible.

    Args:
        inst: Binary operation instruction.
        block: Containing block.
        transformer: MIR transformer used to perform the replacement.
    """
    # NOTE(review): none of the rewrites below checks the type of the
    # non-constant operand, yet shifts and bitwise AND are only meaningful
    # for integers. Confirm float operands cannot reach this pass, or add a
    # type guard.
    if inst.op == "*":
        # x * 1 and 1 * x are skipped here so the algebraic simplifications
        # at the end can turn them into a plain copy.
        if not (self._is_one(inst.right) or self._is_one(inst.left)):
            # Multiplication is commutative, so the constant may sit on
            # either side; the right-hand side is checked first.
            if self._is_power_of_two_constant(inst.right):
                constant, operand = inst.right, inst.left
            elif self._is_power_of_two_constant(inst.left):
                constant, operand = inst.left, inst.right
            else:
                constant = operand = None

            if constant is not None:
                shift = self._get_power_of_two(constant)
                if shift is not None and shift > 0:  # Only optimize for shift > 0
                    shift_const = Constant(shift, MIRType.INT)
                    new_inst = BinaryOp(inst.dest, "<<", operand, shift_const, inst.source_location)
                    transformer.replace_instruction(block, inst, new_inst)
                    self.stats["multiply_to_shift"] = self.stats.get("multiply_to_shift", 0) + 1
                    return

    elif inst.op == "//":
        # x // 1 is left to the algebraic simplifications (plain copy).
        if not self._is_one(inst.right) and self._is_power_of_two_constant(inst.right):
            shift = self._get_power_of_two(inst.right)
            if shift is not None and shift > 0:  # Only optimize for shift > 0
                # NOTE(review): equivalent for negative dividends only if
                # both // and >> round toward negative infinity - confirm
                # against the VM's semantics.
                shift_const = Constant(shift, MIRType.INT)
                new_inst = BinaryOp(inst.dest, ">>", inst.left, shift_const, inst.source_location)
                transformer.replace_instruction(block, inst, new_inst)
                self.stats["divide_to_shift"] = self.stats.get("divide_to_shift", 0) + 1
                return

    elif inst.op == "%":
        if self._is_power_of_two_constant(inst.right):
            power = self._get_constant_value(inst.right)
            if power is not None and power > 0:
                # n % power == n & (power - 1) when power is a power of two.
                # NOTE(review): for negative n this relies on a floored
                # (non-negative) modulo - confirm against the VM.
                mask = Constant(power - 1, MIRType.INT)
                new_inst = BinaryOp(inst.dest, "&", inst.left, mask, inst.source_location)
                transformer.replace_instruction(block, inst, new_inst)
                self.stats["modulo_to_and"] = self.stats.get("modulo_to_and", 0) + 1
                return

    # Nothing was strength-reduced (this includes every "/" instruction):
    # fall back to the identity/absorbing-element simplifications.
    self._apply_algebraic_simplifications(inst, block, transformer)
141
+
142
def _reduce_unary_op(
    self,
    inst: UnaryOp,
    block: BasicBlock,
    transformer: MIRTransformer,
) -> None:
    """Placeholder for unary-operation reductions; currently rewrites nothing.

    Args:
        inst: Unary operation instruction.
        block: Containing block.
        transformer: MIR transformer.
    """
    operator = inst.op

    if operator == "-":
        # Double negation elimination: would need to know whether the
        # operand is itself the result of a negation, which requires
        # def-use chain tracking. Not implemented.
        return

    if operator == "not":
        # not(not(x)) simplification: not implemented.
        return
165
+
166
def _apply_algebraic_simplifications(
    self,
    inst: BinaryOp,
    block: BasicBlock,
    transformer: MIRTransformer,
) -> None:
    """Apply identity / absorbing-element simplifications to a binary op.

    Handled patterns (first match wins within each operator):

    * ``+``: ``x + 0``, ``0 + x`` -> copy of ``x``
    * ``-``: ``x - 0`` -> copy of ``x``; ``x - x`` -> constant 0
    * ``*``: ``x * 0`` -> constant 0; ``x * 1`` -> copy; ``x * -1`` -> ``-x``
      (each also with the operands swapped)
    * ``/``, ``//``: ``x / 1`` -> copy; ``c / c`` -> constant 1, only when
      ``c`` is a non-zero constant
    * ``and``: ``x and False`` -> False; ``x and True`` -> copy of ``x``
    * ``or``: ``x or True`` -> True; ``x or False`` -> copy of ``x``

    When a pattern matches, ``inst`` is replaced through ``transformer`` and
    the ``"algebraic_simplified"`` counter in ``self.stats`` is incremented.

    Args:
        inst: Binary operation instruction.
        block: Containing block.
        transformer: MIR transformer.
    """
    op = inst.op
    left, right = inst.left, inst.right
    dest, location = inst.dest, inst.source_location
    new_inst: MIRInstruction | None = None

    if op == "+":
        # x + 0 = x, 0 + x = x
        if self._is_zero(right):
            new_inst = Copy(dest, left, location)
        elif self._is_zero(left):
            new_inst = Copy(dest, right, location)

    elif op == "-":
        # x - 0 = x
        if self._is_zero(right):
            new_inst = Copy(dest, left, location)
        # x - x = 0
        elif self._values_equal(left, right):
            new_inst = LoadConst(dest, Constant(0, MIRType.INT), location)

    elif op == "*":
        # x * 0 = 0 (checked first: the absorbing element beats identities)
        if self._is_zero(right) or self._is_zero(left):
            new_inst = LoadConst(dest, Constant(0, MIRType.INT), location)
        # x * 1 = x
        elif self._is_one(right):
            new_inst = Copy(dest, left, location)
        elif self._is_one(left):
            new_inst = Copy(dest, right, location)
        # x * -1 = -x
        elif self._is_negative_one(right):
            new_inst = UnaryOp(dest, "-", left, location)
        elif self._is_negative_one(left):
            new_inst = UnaryOp(dest, "-", right, location)

    elif op in ("/", "//"):
        # x / 1 = x
        if self._is_one(right):
            new_inst = Copy(dest, left, location)
        # x / x = 1 holds only for x != 0. For a non-constant operand that
        # cannot be proven here, and folding would silently turn a runtime
        # division by zero into 1, so only non-zero constants are folded.
        elif self._values_equal(left, right):
            divisor = self._get_constant_value(right)
            if divisor is not None and divisor != 0:
                new_inst = LoadConst(dest, Constant(1, MIRType.INT), location)

    elif op == "and":
        # x and False = False (check this first as it's stronger)
        if self._is_false(right) or self._is_false(left):
            new_inst = LoadConst(dest, Constant(False, MIRType.BOOL), location)
        # x and True = x
        elif self._is_true(right):
            new_inst = Copy(dest, left, location)
        elif self._is_true(left):
            new_inst = Copy(dest, right, location)

    elif op == "or":
        # x or True = True (check this first as it's stronger)
        if self._is_true(right) or self._is_true(left):
            new_inst = LoadConst(dest, Constant(True, MIRType.BOOL), location)
        # x or False = x
        elif self._is_false(right):
            new_inst = Copy(dest, left, location)
        elif self._is_false(left):
            new_inst = Copy(dest, right, location)

    if new_inst is None:
        return

    transformer.replace_instruction(block, inst, new_inst)
    self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
295
+
296
def _is_power_of_two_constant(self, value: MIRValue) -> bool:
    """Check if a value is a constant integer power of 2.

    Args:
        value: Value to check.

    Returns:
        True if ``value`` is a ``Constant`` whose payload is a positive
        ``int`` with exactly one bit set. Boolean payloads are rejected.
    """
    if not isinstance(value, Constant):
        return False

    val = value.value
    # bool is a subclass of int, so True would otherwise pass as 2**0.
    # A boolean constant is not an integer power of two.
    if isinstance(val, bool) or not isinstance(val, int):
        return False

    # A positive integer is a power of two iff exactly one bit is set.
    return val > 0 and (val & (val - 1)) == 0
311
+
312
def _get_power_of_two(self, value: MIRValue) -> int | None:
    """Get the exponent ``k`` of a constant equal to ``2**k``.

    Args:
        value: Power of 2 constant.

    Returns:
        The exponent, or None if ``value`` is not a ``Constant`` holding a
        positive integer power of two. Boolean payloads yield None.
    """
    if isinstance(value, Constant):
        val = value.value
        # bool is a subclass of int; exclude it so True is not read as 2**0.
        is_plain_int = isinstance(val, int) and not isinstance(val, bool)
        if is_plain_int and val > 0 and (val & (val - 1)) == 0:
            # For a power of two the single set bit is the highest bit, so
            # its index is bit_length() - 1.
            return val.bit_length() - 1
    return None
331
+
332
def _get_constant_value(self, value: MIRValue) -> int | float | bool | None:
    """Extract the numeric or boolean payload of a constant.

    Args:
        value: MIR value.

    Returns:
        ``value.value`` when ``value`` is a ``Constant`` whose payload is an
        int, float or bool; None in every other case.
    """
    if not isinstance(value, Constant):
        return None
    payload = value.value
    return payload if isinstance(payload, (int, float, bool)) else None
346
+
347
def _is_zero(self, value: MIRValue) -> bool:
    """Tell whether a value is a constant that compares equal to 0.

    Args:
        value: Value to check.

    Returns:
        True if the constant payload obtained from ``_get_constant_value``
        equals 0; False otherwise, including for non-constants.
    """
    return self._get_constant_value(value) == 0
358
+
359
def _is_one(self, value: MIRValue) -> bool:
    """Tell whether a value is a constant that compares equal to 1.

    Args:
        value: Value to check.

    Returns:
        True if the constant payload obtained from ``_get_constant_value``
        equals 1; False otherwise, including for non-constants.
    """
    return self._get_constant_value(value) == 1
370
+
371
def _is_negative_one(self, value: MIRValue) -> bool:
    """Tell whether a value is a constant that compares equal to -1.

    Args:
        value: Value to check.

    Returns:
        True if the constant payload obtained from ``_get_constant_value``
        equals -1; False otherwise, including for non-constants.
    """
    return self._get_constant_value(value) == -1
382
+
383
def _is_true(self, value: MIRValue) -> bool:
    """Tell whether a value is the boolean constant True.

    The check is by identity, so a numeric 1 does not count.

    Args:
        value: Value to check.

    Returns:
        True if the constant payload is exactly ``True``.
    """
    return self._get_constant_value(value) is True
394
+
395
def _is_false(self, value: MIRValue) -> bool:
    """Tell whether a value is the boolean constant False.

    The check is by identity, so a numeric 0 does not count.

    Args:
        value: Value to check.

    Returns:
        True if the constant payload is exactly ``False``.
    """
    return self._get_constant_value(value) is False
406
+
407
+ def _values_equal(self, v1: MIRValue, v2: MIRValue) -> bool:
408
+ """Check if two values are equal.
409
+
410
+ Args:
411
+ v1: First value.
412
+ v2: Second value.
413
+
414
+ Returns:
415
+ True if values are equal.
416
+ """
417
+ # Simple equality check - could be enhanced
418
+ return v1 == v2
419
+
420
def finalize(self) -> None:
    """Finish the pass; strength reduction has nothing to clean up."""
    return None
@@ -0,0 +1,207 @@
1
+ """Tail call optimization pass.
2
+
3
+ This module implements tail call optimization to transform recursive calls
4
+ in tail position into jumps, eliminating stack growth for tail-recursive functions.
5
+ """
6
+
7
+ from machine_dialect.mir.basic_block import BasicBlock
8
+ from machine_dialect.mir.mir_function import MIRFunction
9
+ from machine_dialect.mir.mir_instructions import Call, Copy, Return
10
+ from machine_dialect.mir.mir_module import MIRModule
11
+ from machine_dialect.mir.optimization_pass import (
12
+ ModulePass,
13
+ PassInfo,
14
+ PassType,
15
+ PreservationLevel,
16
+ )
17
+
18
+
19
class TailCallOptimization(ModulePass):
    """Tail call optimization pass.

    This pass identifies function calls in tail position and marks them
    for optimization by setting ``is_tail_call`` on the ``Call``
    instruction. A call is treated as being in tail position if, within
    the same basic block:

    1. It is immediately followed by a return of its result. A void call
       (``dest`` is ``None``) followed by a bare return (``value`` is
       ``None``) is covered by this same rule.
    2. Its result is copied once and that copy is immediately returned.

    The actual transformation to jumps happens during bytecode generation.
    """

    def __init__(self) -> None:
        """Initialize the tail call optimization pass."""
        super().__init__()
        # Counters accumulated across every module this instance runs on.
        self.stats = {
            "tail_calls_found": 0,
            "functions_optimized": 0,
            "recursive_tail_calls": 0,
        }

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="tail-call",
            description="Optimize tail calls into jumps",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.CFG,
        )

    def finalize(self) -> None:
        """Finalize the pass after running.

        Override from base class - no special finalization needed.
        """
        pass

    def run_on_module(self, module: MIRModule) -> bool:
        """Run tail call optimization on a module.

        Args:
            module: The module to optimize.

        Returns:
            True if the module was modified.
        """
        modified = False

        # Process each function; count every function that gained a mark.
        for func_name, function in module.functions.items():
            if self._optimize_tail_calls_in_function(function, func_name):
                modified = True
                self.stats["functions_optimized"] += 1

        return modified

    def _optimize_tail_calls_in_function(self, function: MIRFunction, func_name: str) -> bool:
        """Optimize tail calls in a single function.

        Args:
            function: The function to optimize.
            func_name: Name of the function (for recursive call detection).

        Returns:
            True if the function was modified.
        """
        modified = False

        # Every block must be visited, so do not short-circuit on the
        # first block that reports a change.
        for block in function.cfg.blocks.values():
            if self._optimize_tail_calls_in_block(block, func_name):
                modified = True

        return modified

    def _optimize_tail_calls_in_block(self, block: BasicBlock, func_name: str) -> bool:
        """Optimize tail calls in a basic block.

        Scans the block for ``Call`` instructions that are not yet marked
        and whose result is returned right away (see
        ``_call_result_is_returned``), and marks each one found.

        Args:
            block: The basic block to process.
            func_name: Name of the containing function.

        Returns:
            True if the block was modified.
        """
        modified = False

        for index, inst in enumerate(block.instructions):
            # Only unmarked calls are candidates.
            if not isinstance(inst, Call) or inst.is_tail_call:
                continue

            if self._call_result_is_returned(block, index, inst):
                self._mark_tail_call(inst, func_name)
                modified = True

        return modified

    def _call_result_is_returned(self, block: BasicBlock, call_index: int, call: Call) -> bool:
        """Check whether a call's result is returned straight after the call.

        Args:
            block: The basic block containing the call.
            call_index: Index of the call within ``block.instructions``.
            call: The call instruction located at ``call_index``.

        Returns:
            True if the call is directly followed by a return of its
            result, or by a single copy of its result that is then returned.
        """
        instructions = block.instructions
        if call_index + 1 >= len(instructions):
            return False

        next_inst = instructions[call_index + 1]

        # Direct return of the call result. For a void call both sides of
        # the comparison are None, so "call; return" matches here as well.
        if isinstance(next_inst, Return):
            return bool(next_inst.value == call.dest)

        # Call result copied to a variable, then that variable returned.
        if isinstance(next_inst, Copy) and call_index + 2 < len(instructions):
            third_inst = instructions[call_index + 2]
            return bool(
                isinstance(third_inst, Return)
                and next_inst.source == call.dest
                and third_inst.value == next_inst.dest
            )

        return False

    def _mark_tail_call(self, call: Call, func_name: str) -> None:
        """Flag a call as a tail call and update the statistics.

        Args:
            call: The call instruction to mark.
            func_name: Name of the containing function, used to detect
                self-recursive calls.
        """
        call.is_tail_call = True
        self.stats["tail_calls_found"] += 1

        # A callee whose name matches the enclosing function is recursive.
        if hasattr(call.func, "name") and call.func.name == func_name:
            self.stats["recursive_tail_calls"] += 1

    def _is_tail_position(self, block: BasicBlock, instruction_index: int) -> bool:
        """Check if an instruction is in tail position.

        An instruction is considered to be in tail position here if every
        instruction after it in the same block, up to the first return, is
        a copy. This helper is not called by the other methods of this
        class.

        Args:
            block: The basic block containing the instruction.
            instruction_index: Index of the instruction in the block.

        Returns:
            True if the instruction is in tail position.
        """
        instructions = block.instructions

        # Check remaining instructions after this one
        for i in range(instruction_index + 1, len(instructions)):
            inst = instructions[i]

            # Return is ok
            if isinstance(inst, Return):
                return True

            # Copy is ok if it's just moving the result
            if isinstance(inst, Copy):
                continue

            # Any other instruction means not in tail position
            return False

        # Reaching the end of the block without a return would need CFG
        # analysis of the successors, which is not implemented; be
        # conservative and report "not in tail position".
        return False

    def get_statistics(self) -> dict[str, int]:
        """Get optimization statistics.

        Returns:
            Dictionary of statistics. This is the pass's own counter
            dictionary, not a copy.
        """
        return self.stats