machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,435 @@
1
+ """Tests for function inlining optimization."""
2
+
3
+ from machine_dialect.mir.mir_function import MIRFunction
4
+ from machine_dialect.mir.mir_instructions import (
5
+ BinaryOp,
6
+ Call,
7
+ ConditionalJump,
8
+ Copy,
9
+ Jump,
10
+ MIRInstruction,
11
+ Return,
12
+ UnaryOp,
13
+ )
14
+ from machine_dialect.mir.mir_module import MIRModule
15
+ from machine_dialect.mir.mir_types import MIRType
16
+ from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
17
+ from machine_dialect.mir.optimizations.inlining import FunctionInlining, InliningCost
18
+
19
+
20
def create_simple_module() -> MIRModule:
    """Build the ``test_module`` fixture used by the basic inlining tests.

    The module holds three single-block functions:

    - ``add(a, b)``: returns ``a + b``
    - ``multiply(x, y)``: returns ``x * y``
    - ``compute(n)``: calls ``add(n, 10)``, passes that result to
      ``multiply(..., 2)``, adds 5 and returns the sum
    """
    module = MIRModule("test_module")
    # The same (1, 1) tuple is handed to every instruction built below.
    pos = (1, 1)

    # add(a, b) -> a + b
    adder = MIRFunction("add", [Variable("a", MIRType.INT), Variable("b", MIRType.INT)])
    adder_block = adder.cfg.get_or_create_block("entry")
    lhs = Variable("a", MIRType.INT)
    rhs = Variable("b", MIRType.INT)
    total = Temp(MIRType.INT, 0)
    adder_block.instructions = [
        BinaryOp(total, "+", lhs, rhs, pos),
        Return(pos, total),
    ]
    module.functions["add"] = adder

    # multiply(x, y) -> x * y
    multiplier = MIRFunction("multiply", [Variable("x", MIRType.INT), Variable("y", MIRType.INT)])
    multiplier_block = multiplier.cfg.get_or_create_block("entry")
    left = Variable("x", MIRType.INT)
    right = Variable("y", MIRType.INT)
    product = Temp(MIRType.INT, 1)
    multiplier_block.instructions = [
        BinaryOp(product, "*", left, right, pos),
        Return(pos, product),
    ]
    module.functions["multiply"] = multiplier

    # compute(n) -> multiply(add(n, 10), 2) + 5
    computer = MIRFunction("compute", [Variable("n", MIRType.INT)])
    computer_block = computer.cfg.get_or_create_block("entry")
    n_arg = Variable("n", MIRType.INT)
    added = Temp(MIRType.INT, 10)
    scaled = Temp(MIRType.INT, 11)
    answer = Temp(MIRType.INT, 12)
    computer_block.instructions = [
        Call(added, "add", [n_arg, Constant(10)], pos),
        Call(scaled, "multiply", [added, Constant(2)], pos),
        BinaryOp(answer, "+", scaled, Constant(5), pos),
        Return(pos, answer),
    ]
    module.functions["compute"] = computer

    return module
70
+
71
+
72
def create_conditional_module() -> MIRModule:
    """Create a module whose callee contains a conditional branch.

    Contains:
    - abs(x): four-block function (entry / negative / positive / exit) that
      stores ``-x`` or ``x`` in ``result`` depending on ``x < 0`` and returns it
    - process(x): calls abs(x), multiplies the result by 2 and returns it
    """
    module = MIRModule("conditional_module")

    # Create abs function
    abs_func = MIRFunction("abs", [Variable("x", MIRType.INT)])
    # Blocks are requested in this order: entry, positive, negative, exit.
    entry = abs_func.cfg.get_or_create_block("entry")
    positive = abs_func.cfg.get_or_create_block("positive")
    negative = abs_func.cfg.get_or_create_block("negative")
    exit_block = abs_func.cfg.get_or_create_block("exit")

    x_var = Variable("x", MIRType.INT)
    is_negative = Temp(MIRType.BOOL, 20)
    neg_x = Temp(MIRType.INT, 21)
    result_var = Variable("result", MIRType.INT)

    # Entry: check if x < 0 and branch to "negative" or "positive"
    entry.instructions = [
        BinaryOp(is_negative, "<", x_var, Constant(0), (1, 1)),
        ConditionalJump(is_negative, "negative", (1, 1), "positive"),
    ]
    abs_func.cfg.connect(entry, negative)
    abs_func.cfg.connect(entry, positive)

    # Negative branch: result = -x
    negative.instructions = [
        UnaryOp(neg_x, "-", x_var, (1, 1)),
        Copy(result_var, neg_x, (1, 1)),
        Jump("exit", (1, 1)),
    ]
    abs_func.cfg.connect(negative, exit_block)

    # Positive branch: result = x
    positive.instructions = [
        Copy(result_var, x_var, (1, 1)),
        Jump("exit", (1, 1)),
    ]
    abs_func.cfg.connect(positive, exit_block)

    # Exit: return result
    exit_block.instructions = [Return((1, 1), result_var)]

    module.functions["abs"] = abs_func

    # Create process function that calls abs and doubles its result
    process_func = MIRFunction("process", [Variable("x", MIRType.INT)])
    process_entry = process_func.cfg.get_or_create_block("entry")
    x_param = Variable("x", MIRType.INT)
    abs_result = Temp(MIRType.INT, 30)
    doubled = Temp(MIRType.INT, 31)
    process_entry.instructions = [
        Call(abs_result, "abs", [x_param], (1, 1)),
        BinaryOp(doubled, "*", abs_result, Constant(2), (1, 1)),
        Return((1, 1), doubled),
    ]
    module.functions["process"] = process_func

    return module
136
+
137
+
138
def create_large_function_module() -> MIRModule:
    """Build ``large_module``: an oversized callee plus a caller that invokes it.

    ``large_func(x)`` is a single block holding a chain of 100 additions
    (each one adds the loop index to the previous result) followed by a
    return. ``caller(n)`` calls ``large_func(n)`` and returns its result.
    """
    module = MIRModule("large_module")
    pos = (1, 1)  # the same tuple is passed to every instruction below

    # Oversized callee: a long dependency chain of "+" operations.
    big = MIRFunction("large_func", [Variable("x", MIRType.INT)])
    big_entry = big.cfg.get_or_create_block("entry")
    seed = Variable("x", MIRType.INT)

    chain: list[MIRInstruction] = []
    previous: MIRValue = seed
    for step in range(100):
        link = Temp(MIRType.INT, 100 + step)
        chain.append(BinaryOp(link, "+", previous, Constant(step), pos))
        previous = link  # the next addition consumes this result
    chain.append(Return(pos, previous))
    big_entry.instructions = chain

    module.functions["large_func"] = big

    # Caller: a single call site targeting the oversized callee.
    caller = MIRFunction("caller", [Variable("n", MIRType.INT)])
    caller_block = caller.cfg.get_or_create_block("entry")
    n_arg = Variable("n", MIRType.INT)
    call_value = Temp(MIRType.INT, 500)
    caller_block.instructions = [
        Call(call_value, "large_func", [n_arg], pos),
        Return(pos, call_value),
    ]
    module.functions["caller"] = caller

    return module
170
+
171
+
172
def create_recursive_module() -> MIRModule:
    """Build ``recursive_module`` containing one self-calling function.

    ``factorial(n)`` has three blocks:

    - ``entry``: computes ``n <= 1`` and branches on it
    - ``base_case``: returns the constant 1
    - ``recursive_case``: returns ``n * factorial(n - 1)``
    """
    module = MIRModule("recursive_module")
    pos = (1, 1)  # the same tuple is passed to every instruction below

    factorial = MIRFunction("factorial", [Variable("n", MIRType.INT)])
    cfg = factorial.cfg
    start = cfg.get_or_create_block("entry")
    stop = cfg.get_or_create_block("base_case")
    recurse = cfg.get_or_create_block("recursive_case")

    n_arg = Variable("n", MIRType.INT)
    at_base = Temp(MIRType.BOOL, 40)
    decremented = Temp(MIRType.INT, 41)
    inner_value = Temp(MIRType.INT, 42)
    product = Temp(MIRType.INT, 43)

    # entry: test n <= 1, then jump to the matching block
    start.instructions = [
        BinaryOp(at_base, "<=", n_arg, Constant(1), pos),
        ConditionalJump(at_base, "base_case", pos, "recursive_case"),
    ]
    cfg.connect(start, stop)
    cfg.connect(start, recurse)

    # base_case: return 1
    stop.instructions = [Return(pos, Constant(1))]

    # recursive_case: return n * factorial(n - 1)
    recurse.instructions = [
        BinaryOp(decremented, "-", n_arg, Constant(1), pos),
        Call(inner_value, "factorial", [decremented], pos),
        BinaryOp(product, "*", n_arg, inner_value, pos),
        Return(pos, product),
    ]

    module.functions["factorial"] = factorial

    return module
210
+
211
+
212
class TestInliningCost:
    """Checks of ``InliningCost.should_inline`` for representative inputs."""

    def test_small_function_always_inlined(self) -> None:
        """A 3-instruction callee at depth 0 is accepted."""
        assert InliningCost(
            instruction_count=3,
            call_site_benefit=5.0,
            size_threshold=50,
            depth=0,
        ).should_inline()

    def test_large_function_not_inlined(self) -> None:
        """A 100-instruction callee is rejected against a threshold of 50."""
        assert not InliningCost(
            instruction_count=100,
            call_site_benefit=10.0,
            size_threshold=50,
            depth=0,
        ).should_inline()

    def test_deep_recursion_prevented(self) -> None:
        """A small, high-benefit callee is still rejected at depth 5."""
        assert not InliningCost(
            instruction_count=5,
            call_site_benefit=20.0,
            size_threshold=50,
            depth=5,  # Too deep
        ).should_inline()

    def test_cost_benefit_analysis(self) -> None:
        """For a 20-instruction callee the benefit value decides the outcome."""
        # Both scenarios share everything except the call-site benefit.
        medium = {"instruction_count": 20, "size_threshold": 50, "depth": 1}

        # High benefit should inline
        worthwhile = InliningCost(call_site_benefit=25.0, **medium)
        assert worthwhile.should_inline()

        # Low benefit should not inline
        marginal = InliningCost(call_site_benefit=15.0, **medium)
        assert not marginal.should_inline()
264
+
265
+
266
class TestFunctionInlining:
    """End-to-end checks of the ``FunctionInlining`` pass on small MIR modules."""

    @staticmethod
    def _instructions_of(function: MIRFunction) -> list[MIRInstruction]:
        """Return every instruction of ``function``, block by block."""
        return [
            instruction
            for block in function.cfg.blocks.values()
            for instruction in block.instructions
        ]

    def test_simple_inlining(self) -> None:
        """Both calls inside ``compute`` are inlined, leaving ``+`` and ``*`` ops."""
        module = create_simple_module()
        inliner = FunctionInlining(size_threshold=50)

        assert inliner.run_on_module(module), "Module should be modified"

        statistics = inliner.get_statistics()
        assert statistics["inlined"] >= 2, "Should inline add and multiply calls"
        assert statistics["call_sites_processed"] >= 2

        # Operators of every BinaryOp now present in compute.
        operators = [
            instruction.op
            for instruction in self._instructions_of(module.functions["compute"])
            if isinstance(instruction, BinaryOp)
        ]
        assert any(op == "+" for op in operators), "Add operation should be inlined"
        assert any(op == "*" for op in operators), "Multiply operation should be inlined"

    def test_conditional_inlining(self) -> None:
        """Inlining ``abs`` brings its ConditionalJump into ``process``."""
        module = create_conditional_module()
        inliner = FunctionInlining(size_threshold=50)

        assert inliner.run_on_module(module), "Module should be modified"

        # The branch can only come from the inlined callee.
        process_body = self._instructions_of(module.functions["process"])
        assert any(
            isinstance(instruction, ConditionalJump) for instruction in process_body
        ), "Conditional from abs should be inlined"

        statistics = inliner.get_statistics()
        assert statistics["inlined"] >= 1, "Should inline abs call"

    def test_large_function_not_inlined(self) -> None:
        """The call to the oversized callee is left in place."""
        module = create_large_function_module()
        inliner = FunctionInlining(size_threshold=50)

        assert not inliner.run_on_module(module), "Large function should not be inlined"

        caller_body = self._instructions_of(module.functions["caller"])
        assert any(
            isinstance(instruction, Call) for instruction in caller_body
        ), "Call to large function should remain"

        statistics = inliner.get_statistics()
        assert statistics["inlined"] == 0, "No functions should be inlined"

    def test_recursive_not_directly_inlined(self) -> None:
        """``factorial`` still contains a call to itself after the pass runs."""
        module = create_recursive_module()
        inliner = FunctionInlining(size_threshold=50)

        # The pass result is deliberately ignored; only the body is inspected.
        inliner.run_on_module(module)

        factorial_body = self._instructions_of(module.functions["factorial"])
        assert any(
            isinstance(instruction, Call) and instruction.func.name == "factorial"
            for instruction in factorial_body
        ), "Recursive call should not be inlined"

    def test_constant_propagation_benefit(self) -> None:
        """A call with a constant argument is inlined even with threshold 10."""
        module = MIRModule("const_module")
        pos = (1, 1)  # the same tuple is passed to every instruction below

        # Callee: simple(x) -> x * 2
        callee = MIRFunction("simple", [Variable("x", MIRType.INT)])
        callee_block = callee.cfg.get_or_create_block("entry")
        x_arg = Variable("x", MIRType.INT)
        doubled = Temp(MIRType.INT, 60)
        callee_block.instructions = [
            BinaryOp(doubled, "*", x_arg, Constant(2), pos),
            Return(pos, doubled),
        ]
        module.functions["simple"] = callee

        # Caller: no parameters, calls simple(5)
        caller = MIRFunction("caller", [])
        caller_block = caller.cfg.get_or_create_block("entry")
        call_value = Temp(MIRType.INT, 61)
        caller_block.instructions = [
            Call(call_value, "simple", [Constant(5)], pos),  # Constant argument
            Return(pos, call_value),
        ]
        module.functions["caller"] = caller

        inliner = FunctionInlining(size_threshold=10)
        assert inliner.run_on_module(module), "Function with constant argument should be inlined"

        # The call must be gone and the callee's multiplication must be present.
        caller_body = self._instructions_of(module.functions["caller"])
        assert not any(
            isinstance(instruction, Call) for instruction in caller_body
        ), "Call should be inlined"
        assert any(
            isinstance(instruction, BinaryOp) for instruction in caller_body
        ), "Binary operation should be present"

    def test_no_functions_to_inline(self) -> None:
        """A module without any call sites is reported as unmodified."""
        module = MIRModule("empty_module")
        pos = (1, 1)  # the same tuple is passed to every instruction below

        # Single function with no calls: no_calls(x) -> x * 2
        lonely = MIRFunction("no_calls", [Variable("x", MIRType.INT)])
        lonely_block = lonely.cfg.get_or_create_block("entry")
        x_arg = Variable("x", MIRType.INT)
        doubled = Temp(MIRType.INT, 70)
        lonely_block.instructions = [
            BinaryOp(doubled, "*", x_arg, Constant(2), pos),
            Return(pos, doubled),
        ]
        module.functions["no_calls"] = lonely

        inliner = FunctionInlining()
        assert not inliner.run_on_module(module), "Module should not be modified"

        statistics = inliner.get_statistics()
        assert statistics["inlined"] == 0
        assert statistics["call_sites_processed"] == 0