machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,551 @@
1
+ """Function inlining optimization pass.
2
+
3
+ This module implements function inlining to eliminate call overhead and
4
+ enable further optimizations by exposing more code to analysis.
5
+ """
6
+
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ from machine_dialect.mir.basic_block import BasicBlock
12
+ from machine_dialect.mir.mir_function import MIRFunction
13
+ from machine_dialect.mir.mir_instructions import (
14
+ Call,
15
+ ConditionalJump,
16
+ Copy,
17
+ Jump,
18
+ MIRInstruction,
19
+ Phi,
20
+ Return,
21
+ )
22
+ from machine_dialect.mir.mir_module import MIRModule
23
+ from machine_dialect.mir.mir_transformer import MIRTransformer
24
+ from machine_dialect.mir.mir_types import MIRType
25
+ from machine_dialect.mir.mir_values import FunctionRef, MIRValue, Temp, Variable
26
+ from machine_dialect.mir.optimization_pass import (
27
+ ModulePass,
28
+ PassInfo,
29
+ PassType,
30
+ PreservationLevel,
31
+ )
32
+
33
+
34
@dataclass
class InliningCost:
    """Cost model for a single inlining decision.

    Attributes:
        instruction_count: Number of instructions in the candidate callee.
        call_site_benefit: Estimated benefit of inlining at this call site.
        size_threshold: Largest callee (in instructions) that may be inlined.
        depth: Current inlining depth for the callee; used to stop runaway
            (e.g. mutually recursive) inlining.
        max_depth: Largest ``depth`` at which inlining is still permitted.
            Defaults to 3, the limit that used to be hard-coded.
        small_function_threshold: Callees with at most this many instructions
            are inlined without a cost/benefit comparison. Defaults to 5, the
            limit that used to be hard-coded.
    """

    instruction_count: int
    call_site_benefit: float
    size_threshold: int
    depth: int
    max_depth: int = 3
    small_function_threshold: int = 5

    def should_inline(self) -> bool:
        """Decide whether the callee should be inlined at this call site.

        The checks run in a fixed order: the depth limit and the size ceiling
        can each veto inlining before the "small function" shortcut is
        considered, so a tiny callee is still rejected when it exceeds
        ``size_threshold`` or the depth limit.

        Returns:
            True if inlining is considered beneficial.
        """
        # Too deep: refuse, so repeated inlining cannot go on forever.
        if self.depth > self.max_depth:
            return False

        # Hard ceiling on callee size.
        if self.instruction_count > self.size_threshold:
            return False

        # Small callees are always worth inlining.
        if self.instruction_count <= self.small_function_threshold:
            return True

        # Medium-sized callees: inline only when the estimated benefit is
        # strictly greater than the cost (one unit per instruction).
        return self.call_site_benefit > float(self.instruction_count)
75
+
76
+
77
class FunctionInlining(ModulePass):
    """Function inlining optimization pass.

    For every call whose target is a ``FunctionRef`` naming a function in the
    same module, the pass consults :class:`InliningCost` and, when inlining is
    accepted, splices a renamed copy of the callee's blocks into the caller in
    place of the call.

    Known limitations (visible in this class, not fixed here):
        * Only ``Temp`` values are renamed while cloning; callee ``Variable``
          values that are not parameters keep their names in the caller.
        * Instruction kinds not handled explicitly by ``_clone_instruction``
          are reused as-is, without value remapping.
    """

    def __init__(self, size_threshold: int = 50) -> None:
        """Initialize inlining pass.

        Args:
            size_threshold: Maximum function size (in instructions) to
                consider for inlining.
        """
        super().__init__()
        self.size_threshold = size_threshold
        self.stats = {"inlined": 0, "call_sites_processed": 0}
        # Per-callee counter fed to InliningCost as ``depth``.
        self.inlining_depth: dict[str, int] = defaultdict(int)

    def initialize(self) -> None:
        """Initialize the pass before running."""
        super().initialize()
        # Re-initialize stats after base class clears them
        self.stats = {"inlined": 0, "call_sites_processed": 0}
        self.inlining_depth = defaultdict(int)

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="inline",
            description="Inline function calls",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.NONE,
        )

    def run_on_module(self, module: MIRModule) -> bool:
        """Run inlining on a module.

        Args:
            module: The module to optimize.

        Returns:
            True if the module was modified.
        """
        modified = False
        for function in module.functions.values():
            # Every function must be visited, so do not short-circuit.
            if self._inline_calls_in_function(function, module):
                modified = True
        return modified

    def _inline_calls_in_function(self, function: MIRFunction, module: MIRModule) -> bool:
        """Inline calls within a function until no call site qualifies.

        Args:
            function: The function to process.
            module: The containing module.

        Returns:
            True if modifications were made.
        """
        modified = False
        transformer = MIRTransformer(function)

        # Inlining can expose new call sites (those of the inlined body), so
        # rescan after every successful inline until a fixed point is reached.
        changed = True
        while changed:
            changed = False

            for call_inst, block in self._find_call_sites(function):
                self.stats["call_sites_processed"] += 1

                # Only direct calls to functions of this module are candidates.
                if not isinstance(call_inst.func, FunctionRef):
                    continue

                callee_name = call_inst.func.name
                if callee_name not in module.functions:
                    continue

                # Don't inline recursive functions directly
                if callee_name == function.name:
                    continue

                callee = module.functions[callee_name]

                cost = self._calculate_inlining_cost(callee, call_inst, self.inlining_depth[callee_name])
                if not cost.should_inline():
                    continue

                # The call might have been removed by previous inlining.
                if call_inst not in block.instructions:
                    continue

                # The counter is only rolled back when inlining is declined.
                # After a successful inline it stays incremented, so each
                # callee can be inlined only a bounded number of times per
                # run (InliningCost rejects large depths). That bound is what
                # guarantees this loop terminates for mutually recursive
                # callees.
                self.inlining_depth[callee_name] += 1
                if self._inline_call(call_inst, block, callee, function, transformer):
                    modified = True
                    changed = True
                    self.stats["inlined"] += 1
                    # The block structure changed: re-find call sites.
                    break
                self.inlining_depth[callee_name] -= 1

        return modified

    def _find_call_sites(self, function: MIRFunction) -> list[tuple[Call, BasicBlock]]:
        """Find all call instructions in a function.

        Args:
            function: The function to search.

        Returns:
            List of (call instruction, containing block) pairs.
        """
        return [
            (inst, block)
            for block in function.cfg.blocks.values()
            for inst in block.instructions
            if isinstance(inst, Call)
        ]

    def _calculate_inlining_cost(self, callee: MIRFunction, call_inst: Call, depth: int) -> InliningCost:
        """Calculate the cost of inlining a function.

        The benefit starts at a base value for removing the call overhead,
        grows with every constant argument and for single-block callees, and
        shrinks for every return instruction beyond the first.

        Args:
            callee: The function to inline.
            call_inst: The call instruction.
            depth: Current inlining depth.

        Returns:
            Inlining cost information.
        """
        # Imported locally, as in the original code.
        from machine_dialect.mir.mir_values import Constant

        callee_blocks = list(callee.cfg.blocks.values())
        instruction_count = sum(len(block.instructions) for block in callee_blocks)

        benefit = 10.0  # Base benefit from removing call overhead

        # Constant arguments enable constant propagation in the inlined body.
        const_args = sum(1 for arg in call_inst.args if isinstance(arg, Constant))
        benefit += const_args * 5.0

        # Bonus for simple functions (single block)
        if len(callee_blocks) == 1:
            benefit += 10.0

        # Penalty for multiple returns (more edges to merge).
        return_count = sum(1 for block in callee_blocks for inst in block.instructions if isinstance(inst, Return))
        if return_count > 1:
            benefit -= (return_count - 1) * 5.0

        return InliningCost(
            instruction_count=instruction_count,
            call_site_benefit=benefit,
            size_threshold=self.size_threshold,
            depth=depth,
        )

    @staticmethod
    def _unique_label(base: str, taken: set[str]) -> str:
        """Return ``base``, or ``base_<n>`` for the smallest n >= 1 not in ``taken``."""
        if base not in taken:
            return base
        suffix = 1
        while f"{base}_{suffix}" in taken:
            suffix += 1
        return f"{base}_{suffix}"

    @staticmethod
    def _retarget_phis(block: BasicBlock, old_label: str, new_label: str) -> None:
        """Rewrite Phi incoming labels in ``block`` from ``old_label`` to ``new_label``.

        Only Phi instructions stored in ``block.instructions`` are visited.
        Affected Phis are replaced by new Phi instructions built from the
        same destination and source location.
        """
        for idx, inst in enumerate(block.instructions):
            if not isinstance(inst, Phi):
                continue
            incoming = list(inst.incoming)
            if all(label != old_label for _, label in incoming):
                continue
            block.instructions[idx] = Phi(
                inst.dest,
                [(value, new_label if label == old_label else label) for value, label in incoming],
                inst.source_location,
            )

    def _inline_call(
        self,
        call_inst: Call,
        call_block: BasicBlock,
        callee: MIRFunction,
        caller: MIRFunction,
        transformer: MIRTransformer,
    ) -> bool:
        """Inline a function call.

        The call block is split at the call: the instructions before it stay
        in ``call_block`` and are followed by a jump to the cloned entry
        block. The instructions after it move to a new continuation block
        that inherits the original successors. Every cloned return becomes
        an optional copy into the call's destination plus a jump to the
        continuation block.

        Args:
            call_inst: The call instruction to inline.
            call_block: The block containing the call.
            callee: The function to inline.
            caller: The calling function.
            transformer: MIR transformer.

        Returns:
            True if inlining succeeded. False if the call site was declined
            (empty callee, argument/parameter count mismatch, or the call is
            no longer in ``call_block``); in that case the caller is left
            untouched.
        """
        # All "decline" checks run before anything is mutated.
        if len(callee.cfg.blocks) == 0:
            # Nothing to splice in; jumping into an empty body would leave
            # the caller's CFG broken.
            return False

        params = list(getattr(callee, "params", None) or [])
        args = list(call_inst.args)
        if len(params) != len(args):
            # Inlining is optional: skip a malformed call site rather than
            # abort the whole pass.
            return False

        try:
            call_idx = call_block.instructions.index(call_inst)
        except ValueError:
            return False

        # Map parameters -> arguments. Params may be strings or Variables.
        value_map: dict[MIRValue, MIRValue] = {}
        for param, arg in zip(params, args, strict=True):
            if isinstance(param, Variable):
                # Map the real parameter object so its declared type is kept.
                value_map[param] = arg
                param_name = param.name
            else:
                param_name = str(param)
            # Also map an INT-typed variable of the same name. Parameter
            # types are not known when params are plain strings, so INT is
            # assumed.
            value_map[Variable(param_name, MIRType.INT)] = arg

        # Clone the callee's CFG into the caller.
        _cloned_blocks, entry_block, return_blocks = self._clone_function_body(callee, caller, value_map, transformer)

        # Split the call block at the call instruction.
        pre_call = call_block.instructions[:call_idx]
        post_call = call_block.instructions[call_idx + 1 :]
        original_successors = list(call_block.successors)

        # The continuation block holds the code after the call. Its label is
        # made unique against all blocks currently in the caller.
        taken = {blk.label for blk in caller.cfg.blocks.values()}
        cont_block = BasicBlock(self._unique_label(f"{call_block.label}_cont", taken))
        caller.cfg.add_block(cont_block)
        cont_block.instructions = post_call
        cont_block.successors = list(original_successors)

        # The call block now ends with a jump into the inlined body.
        call_block.instructions = [*pre_call, Jump(entry_block.label, call_inst.source_location)]
        call_block.successors = [entry_block]
        entry_block.predecessors.append(call_block)

        # The original successors are now reached from the continuation
        # block, so both the predecessor lists and any Phi incoming labels
        # must name it. This runs after the call block's instruction list was
        # rebuilt so a self-loop keeps its retargeted Phis.
        for succ in original_successors:
            succ.predecessors.remove(call_block)
            succ.predecessors.append(cont_block)
            self._retarget_phis(succ, call_block.label, cont_block.label)

        # Turn every cloned return into a jump to the continuation block,
        # preceded by a copy of the returned value when the call has a
        # destination.
        dest = call_inst.dest
        for ret_block in return_blocks:
            ret_inst = ret_block.instructions[-1]
            source_loc = ret_inst.source_location if hasattr(ret_inst, "source_location") else (0, 0)
            replacement: list[MIRInstruction] = []
            if dest is not None:
                assert isinstance(ret_inst, Return)
                if ret_inst.value is not None:
                    replacement.append(Copy(dest, ret_inst.value, source_loc))
            replacement.append(Jump(cont_block.label, source_loc))
            ret_block.instructions[-1:] = replacement
            ret_block.successors = [cont_block]
            cont_block.predecessors.append(ret_block)

        transformer.modified = True
        return True

    def _clone_function_body(
        self,
        callee: MIRFunction,
        caller: MIRFunction,
        value_map: dict[MIRValue, MIRValue],
        transformer: MIRTransformer,
    ) -> tuple[dict[str, BasicBlock], BasicBlock, list[BasicBlock]]:
        """Clone a function's body for inlining.

        Each callee block is copied into a new block registered with the
        caller's CFG. Labels have the form ``inlined_<callee>_<label>``. When
        any of them already exists in the caller (for example because the
        same callee was inlined before), a numeric component is inserted
        (``inlined_<callee>_<n>_<label>``) so labels stay unique.

        Args:
            callee: The function to clone.
            caller: The calling function.
            value_map: Mapping from callee values to caller values. It is
                extended in place with the fresh temporaries created here.
            transformer: MIR transformer.

        Returns:
            Tuple of (cloned blocks dict, entry block, return blocks list).
        """
        source_blocks = list(callee.cfg.blocks.values())

        # Pick a label prefix that cannot collide with existing caller blocks.
        taken = {blk.label for blk in caller.cfg.blocks.values()}
        prefix = f"inlined_{callee.name}"
        attempt = 0
        while any(f"{prefix}_{blk.label}" in taken for blk in source_blocks):
            attempt += 1
            prefix = f"inlined_{callee.name}_{attempt}"

        # First pass: create all blocks
        block_map: dict[BasicBlock, BasicBlock] = {}
        cloned_blocks: dict[str, BasicBlock] = {}
        for old_block in source_blocks:
            new_label = f"{prefix}_{old_block.label}"
            new_block = BasicBlock(new_label)
            caller.cfg.add_block(new_block)
            block_map[old_block] = new_block
            cloned_blocks[new_label] = new_block

        # Map entry block - if not set, fall back to the block labelled
        # "entry", then to the first block.
        entry_block: BasicBlock | None = None
        if callee.cfg.entry_block is not None:
            entry_block = block_map[callee.cfg.entry_block]
        else:
            for old_block in source_blocks:
                if old_block.label == "entry":
                    entry_block = block_map[old_block]
                    break
            if entry_block is None and block_map:
                entry_block = next(iter(block_map.values()))
        if entry_block is None:
            # Empty callee: hand back a placeholder block. It is NOT
            # registered with the caller's CFG.
            entry_block = BasicBlock("inline_entry")

        def map_value(value: MIRValue) -> MIRValue:
            """Map a value from callee to caller."""
            if value in value_map:
                return value_map[value]
            if isinstance(value, Temp):
                # Give every callee temp a fresh id from the caller's counter
                # so it cannot clash with the caller's own temporaries.
                fresh = Temp(value.type, caller._next_temp_id)
                caller._next_temp_id += 1
                value_map[value] = fresh
                return fresh
            # Constants and other values remain unchanged
            return value

        # Built once so jump targets are not searched for per instruction.
        # Assumes labels are unique within the callee.
        label_map = {old.label: new.label for old, new in block_map.items()}

        # Second pass: clone instructions and update CFG
        return_blocks: list[BasicBlock] = []
        for old_block, new_block in block_map.items():
            for inst in old_block.instructions:
                cloned_inst = self._clone_instruction(inst, map_value, block_map, label_map)
                new_block.instructions.append(cloned_inst)

                # Track return blocks (each block at most once).
                if isinstance(cloned_inst, Return) and not (return_blocks and return_blocks[-1] is new_block):
                    return_blocks.append(new_block)

            # Mirror the callee's edges between the cloned blocks.
            for succ in old_block.successors:
                new_succ = block_map[succ]
                new_block.successors.append(new_succ)
                new_succ.predecessors.append(new_block)

        return cloned_blocks, entry_block, return_blocks

    def _clone_instruction(
        self,
        inst: MIRInstruction,
        map_value: Any,
        block_map: dict[BasicBlock, BasicBlock],
        label_map: dict[str, str] | None = None,
    ) -> MIRInstruction:
        """Clone an instruction with value remapping.

        Args:
            inst: The instruction to clone.
            map_value: Function to map values.
            block_map: Mapping from old blocks to new blocks.
            label_map: Optional precomputed mapping from old block labels to
                new block labels. When omitted it is derived from
                ``block_map``.

        Returns:
            Cloned instruction. Instruction kinds not handled below are
            returned unchanged (the same object, with unmapped operands).
        """
        # Import here to avoid circular dependency
        from machine_dialect.mir.mir_instructions import BinaryOp, LoadConst, Print, StoreVar, UnaryOp

        def remap_label(label: Any) -> Any:
            """Translate a callee block label; unknown labels pass through."""
            nonlocal label_map
            if label_map is None:
                label_map = {old.label: new.label for old, new in block_map.items()}
            return label_map.get(label, label)

        if isinstance(inst, BinaryOp):
            return BinaryOp(
                map_value(inst.dest),
                inst.op,
                map_value(inst.left),
                map_value(inst.right),
                inst.source_location,
            )
        if isinstance(inst, UnaryOp):
            return UnaryOp(
                map_value(inst.dest),
                inst.op,
                map_value(inst.operand),
                inst.source_location,
            )
        if isinstance(inst, Copy):
            return Copy(
                map_value(inst.dest),
                map_value(inst.source),
                inst.source_location,
            )
        if isinstance(inst, LoadConst):
            # Pass the raw constant value when the constant wraps one.
            constant = inst.constant.value if hasattr(inst.constant, "value") else inst.constant
            return LoadConst(map_value(inst.dest), constant, inst.source_location)
        if isinstance(inst, StoreVar):
            return StoreVar(
                inst.var,  # Variable names stay the same
                map_value(inst.source),
                inst.source_location,
            )
        if isinstance(inst, Call):
            return Call(
                map_value(inst.dest) if inst.dest is not None else None,
                inst.func,
                [map_value(arg) for arg in inst.args],
                inst.source_location,
            )
        if isinstance(inst, Return):
            return Return(
                inst.source_location,
                map_value(inst.value) if inst.value is not None else None,
            )
        if isinstance(inst, ConditionalJump):
            # A missing/empty false label is passed through untouched.
            false_label = remap_label(inst.false_label) if inst.false_label else inst.false_label
            return ConditionalJump(
                map_value(inst.condition),
                remap_label(inst.true_label),
                inst.source_location,
                false_label,
            )
        if isinstance(inst, Jump):
            return Jump(remap_label(inst.label), inst.source_location)
        if isinstance(inst, Phi):
            new_incoming = [(map_value(value), remap_label(label)) for value, label in inst.incoming]
            return Phi(map_value(inst.dest), new_incoming, inst.source_location)
        if isinstance(inst, Print):
            return Print(map_value(inst.value), inst.source_location)

        # NOTE(review): any other instruction type is returned as-is. That
        # object is shared with the callee and its operands are not remapped;
        # extend the chain above when new instruction types need inlining.
        return inst

    def finalize(self) -> None:
        """Finalize the pass after running.

        Cleans up any temporary state.
        """
        self.inlining_depth.clear()

    def get_statistics(self) -> dict[str, int]:
        """Get optimization statistics.

        Returns:
            Dictionary of statistics.
        """
        return self.stats