machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,435 @@
|
|
1
|
+
"""Tests for function inlining optimization."""
|
2
|
+
|
3
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
4
|
+
from machine_dialect.mir.mir_instructions import (
|
5
|
+
BinaryOp,
|
6
|
+
Call,
|
7
|
+
ConditionalJump,
|
8
|
+
Copy,
|
9
|
+
Jump,
|
10
|
+
MIRInstruction,
|
11
|
+
Return,
|
12
|
+
UnaryOp,
|
13
|
+
)
|
14
|
+
from machine_dialect.mir.mir_module import MIRModule
|
15
|
+
from machine_dialect.mir.mir_types import MIRType
|
16
|
+
from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
|
17
|
+
from machine_dialect.mir.optimizations.inlining import FunctionInlining, InliningCost
|
18
|
+
|
19
|
+
|
20
|
+
def create_simple_module() -> MIRModule:
    """Build a small module exercising straight-line inlining.

    Functions registered, in this order:

    - ``add(a, b)``: returns ``a + b``
    - ``multiply(x, y)``: returns ``x * y``
    - ``compute(n)``: returns ``multiply(add(n, 10), 2) + 5`` via two calls
    """
    loc = (1, 1)
    module = MIRModule("test_module")

    def binary_leaf(name: str, lhs: str, rhs: str, op: str, temp_id: int) -> MIRFunction:
        # Two-parameter leaf function: a single binary op on its parameters, then a return.
        func = MIRFunction(name, [Variable(lhs, MIRType.INT), Variable(rhs, MIRType.INT)])
        block = func.cfg.get_or_create_block("entry")
        left = Variable(lhs, MIRType.INT)
        right = Variable(rhs, MIRType.INT)
        out = Temp(MIRType.INT, temp_id)
        block.instructions = [
            BinaryOp(out, op, left, right, loc),
            Return(loc, out),
        ]
        return func

    module.functions["add"] = binary_leaf("add", "a", "b", "+", 0)
    module.functions["multiply"] = binary_leaf("multiply", "x", "y", "*", 1)

    # compute(n) chains both leaf calls and adds 5 to the product.
    compute = MIRFunction("compute", [Variable("n", MIRType.INT)])
    body = compute.cfg.get_or_create_block("entry")
    n = Variable("n", MIRType.INT)
    total = Temp(MIRType.INT, 10)
    product = Temp(MIRType.INT, 11)
    answer = Temp(MIRType.INT, 12)
    body.instructions = [
        Call(total, "add", [n, Constant(10)], loc),
        Call(product, "multiply", [total, Constant(2)], loc),
        BinaryOp(answer, "+", product, Constant(5), loc),
        Return(loc, answer),
    ]
    module.functions["compute"] = compute

    return module
|
70
|
+
|
71
|
+
|
72
|
+
def create_conditional_module() -> MIRModule:
    """Create a module whose callee has internal control flow.

    Contains:
    - abs(x): returns the absolute value of x. The entry block branches on
      ``x < 0``; both the ``negative`` and ``positive`` arms write the
      variable ``result`` and jump to a shared ``exit`` block that returns it.
    - process(x): calls abs(x) and returns the result multiplied by 2.
    """
    module = MIRModule("conditional_module")

    # Create abs function
    abs_func = MIRFunction("abs", [Variable("x", MIRType.INT)])
    entry = abs_func.cfg.get_or_create_block("entry")
    positive = abs_func.cfg.get_or_create_block("positive")
    negative = abs_func.cfg.get_or_create_block("negative")
    exit_block = abs_func.cfg.get_or_create_block("exit")

    x_var = Variable("x", MIRType.INT)
    is_negative = Temp(MIRType.BOOL, 20)
    neg_x = Temp(MIRType.INT, 21)
    # A named variable (not a temp) is used so both branches can assign it
    # before merging in the exit block.
    result_var = Variable("result", MIRType.INT)

    # Entry: check if x < 0
    entry.instructions = [
        BinaryOp(is_negative, "<", x_var, Constant(0), (1, 1)),
        ConditionalJump(is_negative, "negative", (1, 1), "positive"),
    ]
    abs_func.cfg.connect(entry, negative)
    abs_func.cfg.connect(entry, positive)

    # Negative branch: result = -x
    negative.instructions = [
        UnaryOp(neg_x, "-", x_var, (1, 1)),
        Copy(result_var, neg_x, (1, 1)),
        Jump("exit", (1, 1)),
    ]
    abs_func.cfg.connect(negative, exit_block)

    # Positive branch: result = x
    positive.instructions = [
        Copy(result_var, x_var, (1, 1)),
        Jump("exit", (1, 1)),
    ]
    abs_func.cfg.connect(positive, exit_block)

    # Exit: return result
    exit_block.instructions = [Return((1, 1), result_var)]

    module.functions["abs"] = abs_func

    # Create process function that calls abs and doubles the result
    process_func = MIRFunction("process", [Variable("x", MIRType.INT)])
    process_entry = process_func.cfg.get_or_create_block("entry")
    x_param = Variable("x", MIRType.INT)
    abs_result = Temp(MIRType.INT, 30)
    doubled = Temp(MIRType.INT, 31)
    process_entry.instructions = [
        Call(abs_result, "abs", [x_param], (1, 1)),
        BinaryOp(doubled, "*", abs_result, Constant(2), (1, 1)),
        Return((1, 1), doubled),
    ]
    module.functions["process"] = process_func

    return module
|
136
|
+
|
137
|
+
|
138
|
+
def create_large_function_module() -> MIRModule:
    """Build a module whose only callee is too big to be worth inlining.

    ``large_func(x)`` is a straight chain of 100 dependent additions
    (each temp adds the loop index to the previous value), followed by a
    return: 101 instructions in total. ``caller(n)`` simply returns
    ``large_func(n)``.
    """
    loc = (1, 1)
    module = MIRModule("large_module")

    big = MIRFunction("large_func", [Variable("x", MIRType.INT)])
    block = big.cfg.get_or_create_block("entry")
    param = Variable("x", MIRType.INT)

    body: list[MIRInstruction] = []
    acc: MIRValue = param
    for step in range(100):  # 100 chained additions, then one return
        nxt = Temp(MIRType.INT, 100 + step)
        body.append(BinaryOp(nxt, "+", acc, Constant(step), loc))
        acc = nxt
    body.append(Return(loc, acc))
    block.instructions = body

    module.functions["large_func"] = big

    # The single call site that the inliner must leave untouched.
    caller = MIRFunction("caller", [Variable("n", MIRType.INT)])
    caller_block = caller.cfg.get_or_create_block("entry")
    n = Variable("n", MIRType.INT)
    out = Temp(MIRType.INT, 500)
    caller_block.instructions = [
        Call(out, "large_func", [n], loc),
        Return(loc, out),
    ]
    module.functions["caller"] = caller

    return module
|
170
|
+
|
171
|
+
|
172
|
+
def create_recursive_module() -> MIRModule:
    """Build a module containing a self-recursive ``factorial(n)``.

    The entry block branches on ``n <= 1``: ``base_case`` returns the
    constant 1, while ``recursive_case`` returns ``n * factorial(n - 1)``.
    """
    loc = (1, 1)
    module = MIRModule("recursive_module")

    factorial = MIRFunction("factorial", [Variable("n", MIRType.INT)])
    cfg = factorial.cfg
    entry = cfg.get_or_create_block("entry")
    base = cfg.get_or_create_block("base_case")
    recurse = cfg.get_or_create_block("recursive_case")

    n = Variable("n", MIRType.INT)
    cond = Temp(MIRType.BOOL, 40)
    n_less_one = Temp(MIRType.INT, 41)
    sub_result = Temp(MIRType.INT, 42)
    product = Temp(MIRType.INT, 43)

    # Entry: branch on n <= 1
    entry.instructions = [
        BinaryOp(cond, "<=", n, Constant(1), loc),
        ConditionalJump(cond, "base_case", loc, "recursive_case"),
    ]
    for successor in (base, recurse):
        cfg.connect(entry, successor)

    # Base case: factorial(n) = 1
    base.instructions = [Return(loc, Constant(1))]

    # Recursive case: factorial(n) = n * factorial(n - 1)
    recurse.instructions = [
        BinaryOp(n_less_one, "-", n, Constant(1), loc),
        Call(sub_result, "factorial", [n_less_one], loc),
        BinaryOp(product, "*", n, sub_result, loc),
        Return(loc, product),
    ]

    module.functions["factorial"] = factorial

    return module
|
210
|
+
|
211
|
+
|
212
|
+
class TestInliningCost:
    """Unit tests for the ``InliningCost`` decision model."""

    @staticmethod
    def _decide(instruction_count: int, benefit: float, depth: int) -> bool:
        """Build an ``InliningCost`` with the standard size threshold of 50 and return its verdict."""
        model = InliningCost(
            instruction_count=instruction_count,
            call_site_benefit=benefit,
            size_threshold=50,
            depth=depth,
        )
        return model.should_inline()

    def test_small_function_always_inlined(self) -> None:
        """A tiny top-level callee (3 instructions) is inlined."""
        assert self._decide(3, 5.0, 0)

    def test_large_function_not_inlined(self) -> None:
        """A callee well above the size threshold (100 > 50) is rejected."""
        assert not self._decide(100, 10.0, 0)

    def test_deep_recursion_prevented(self) -> None:
        """Even a small, high-benefit callee is rejected at depth 5."""
        assert not self._decide(5, 20.0, 5)

    def test_cost_benefit_analysis(self) -> None:
        """Medium callees (20 instructions) are inlined only when the benefit is high enough."""
        # Benefit 25.0 tips the balance toward inlining.
        assert self._decide(20, 25.0, 1)
        # Benefit 15.0 does not.
        assert not self._decide(20, 15.0, 1)
|
264
|
+
|
265
|
+
|
266
|
+
class TestFunctionInlining:
    """Test suite for the ``FunctionInlining`` module pass."""

    @staticmethod
    def _instructions(func: MIRFunction) -> list[MIRInstruction]:
        """Return every instruction of ``func``, flattened across its CFG blocks."""
        return [inst for block in func.cfg.blocks.values() for inst in block.instructions]

    def test_simple_inlining(self) -> None:
        """Test inlining of simple functions.

        ``compute`` already contains one ``+`` (the trailing ``+ 5``) before
        inlining, so merely finding an addition proves nothing. Instead,
        require that both call sites are gone and that an additional ``+``
        from the inlined ``add`` body is present.
        """
        module = create_simple_module()
        inliner = FunctionInlining(size_threshold=50)

        # Run inlining
        modified = inliner.run_on_module(module)
        assert modified, "Module should be modified"

        # Check statistics
        stats = inliner.get_statistics()
        assert stats["inlined"] >= 2, "Should inline add and multiply calls"
        assert stats["call_sites_processed"] >= 2

        # Check that compute function has inlined code
        instructions = self._instructions(module.functions["compute"])
        remaining_calls = [inst for inst in instructions if isinstance(inst, Call)]
        add_ops = [inst for inst in instructions if isinstance(inst, BinaryOp) and inst.op == "+"]
        mul_ops = [inst for inst in instructions if isinstance(inst, BinaryOp) and inst.op == "*"]

        assert not remaining_calls, "Calls to add and multiply should be replaced by their bodies"
        assert len(add_ops) >= 2, "Add operation should be inlined alongside the original '+ 5'"
        assert mul_ops, "Multiply operation should be inlined"

    def test_conditional_inlining(self) -> None:
        """Test inlining of functions with conditionals."""
        module = create_conditional_module()
        inliner = FunctionInlining(size_threshold=50)

        # Run inlining
        modified = inliner.run_on_module(module)
        assert modified, "Module should be modified"

        # The branch from abs's entry block must now live inside process.
        process_insts = self._instructions(module.functions["process"])
        has_conditional = any(isinstance(inst, ConditionalJump) for inst in process_insts)
        assert has_conditional, "Conditional from abs should be inlined"

        # Check statistics
        stats = inliner.get_statistics()
        assert stats["inlined"] >= 1, "Should inline abs call"

    def test_large_function_not_inlined(self) -> None:
        """Test that large functions are not inlined."""
        module = create_large_function_module()
        inliner = FunctionInlining(size_threshold=50)

        # Run inlining
        modified = inliner.run_on_module(module)
        assert not modified, "Large function should not be inlined"

        # Check that call remains
        caller_insts = self._instructions(module.functions["caller"])
        has_call = any(isinstance(inst, Call) for inst in caller_insts)
        assert has_call, "Call to large function should remain"

        # Check statistics
        stats = inliner.get_statistics()
        assert stats["inlined"] == 0, "No functions should be inlined"

    def test_recursive_not_directly_inlined(self) -> None:
        """Test that recursive functions are not directly inlined."""
        module = create_recursive_module()
        inliner = FunctionInlining(size_threshold=50)

        # Run inlining
        inliner.run_on_module(module)

        # The recursive call should not be inlined into itself
        fact_insts = self._instructions(module.functions["factorial"])
        has_recursive_call = any(
            isinstance(inst, Call) and inst.func.name == "factorial" for inst in fact_insts
        )
        assert has_recursive_call, "Recursive call should not be inlined"

    def test_constant_propagation_benefit(self) -> None:
        """Test that a small callee invoked with a constant argument is inlined.

        Note: the callee has only 2 instructions, well under the threshold of
        10, so this checks the constant-argument call path is inlined rather
        than isolating the constant-argument bonus itself.
        """
        module = MIRModule("const_module")

        # Create simple function
        simple_func = MIRFunction("simple", [Variable("x", MIRType.INT)])
        entry = simple_func.cfg.get_or_create_block("entry")
        x_var = Variable("x", MIRType.INT)
        result = Temp(MIRType.INT, 60)
        entry.instructions = [
            BinaryOp(result, "*", x_var, Constant(2), (1, 1)),
            Return((1, 1), result),
        ]
        module.functions["simple"] = simple_func

        # Create caller with constant argument
        caller_func = MIRFunction("caller", [])
        caller_entry = caller_func.cfg.get_or_create_block("entry")
        call_result = Temp(MIRType.INT, 61)
        caller_entry.instructions = [
            Call(call_result, "simple", [Constant(5)], (1, 1)),  # Constant argument
            Return((1, 1), call_result),
        ]
        module.functions["caller"] = caller_func

        # Run inlining
        inliner = FunctionInlining(size_threshold=10)
        modified = inliner.run_on_module(module)

        assert modified, "Function with constant argument should be inlined"

        # Check that the call was inlined
        caller_insts = self._instructions(module.functions["caller"])
        has_call = any(isinstance(inst, Call) for inst in caller_insts)
        has_binary_op = any(isinstance(inst, BinaryOp) for inst in caller_insts)

        assert not has_call, "Call should be inlined"
        assert has_binary_op, "Binary operation should be present"

    def test_no_functions_to_inline(self) -> None:
        """Test module with no inlinable functions."""
        module = MIRModule("empty_module")

        # Single function with no calls
        func = MIRFunction("no_calls", [Variable("x", MIRType.INT)])
        entry = func.cfg.get_or_create_block("entry")
        x_var = Variable("x", MIRType.INT)
        result = Temp(MIRType.INT, 70)
        entry.instructions = [
            BinaryOp(result, "*", x_var, Constant(2), (1, 1)),
            Return((1, 1), result),
        ]
        module.functions["no_calls"] = func

        # Run inlining
        inliner = FunctionInlining()
        modified = inliner.run_on_module(module)

        assert not modified, "Module should not be modified"

        stats = inliner.get_statistics()
        assert stats["inlined"] == 0
        assert stats["call_sites_processed"] == 0
|