machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,551 @@
|
|
1
|
+
"""Function inlining optimization pass.
|
2
|
+
|
3
|
+
This module implements function inlining to eliminate call overhead and
|
4
|
+
enable further optimizations by exposing more code to analysis.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from collections import defaultdict
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from typing import Any
|
10
|
+
|
11
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
12
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
13
|
+
from machine_dialect.mir.mir_instructions import (
|
14
|
+
Call,
|
15
|
+
ConditionalJump,
|
16
|
+
Copy,
|
17
|
+
Jump,
|
18
|
+
MIRInstruction,
|
19
|
+
Phi,
|
20
|
+
Return,
|
21
|
+
)
|
22
|
+
from machine_dialect.mir.mir_module import MIRModule
|
23
|
+
from machine_dialect.mir.mir_transformer import MIRTransformer
|
24
|
+
from machine_dialect.mir.mir_types import MIRType
|
25
|
+
from machine_dialect.mir.mir_values import FunctionRef, MIRValue, Temp, Variable
|
26
|
+
from machine_dialect.mir.optimization_pass import (
|
27
|
+
ModulePass,
|
28
|
+
PassInfo,
|
29
|
+
PassType,
|
30
|
+
PreservationLevel,
|
31
|
+
)
|
32
|
+
|
33
|
+
|
34
|
+
@dataclass
class InliningCost:
    """Cost model for inlining decisions.

    Attributes:
        instruction_count: Number of instructions in the callee.
        call_site_benefit: Estimated benefit of inlining at this call site.
        size_threshold: Maximum callee size (in instructions) eligible for inlining.
        depth: Current inlining depth/counter for the callee (guards against
            unbounded recursive inlining).
        max_depth: Largest ``depth`` at which inlining is still allowed.
            Defaults to 3, the previously hard-coded limit.
        small_function_size: Callees with at most this many instructions are
            always inlined once the depth and size checks pass. Defaults to 5,
            the previously hard-coded limit.
    """

    instruction_count: int
    call_site_benefit: float
    size_threshold: int
    depth: int
    max_depth: int = 3
    small_function_size: int = 5

    def should_inline(self) -> bool:
        """Determine if the callee should be inlined.

        The checks are applied in order: depth limit, size threshold,
        always-inline for tiny callees, then benefit-vs-cost for the rest.

        Returns:
            True if inlining is considered beneficial.
        """
        # Refuse past the depth limit so recursive callees cannot be inlined forever.
        if self.depth > self.max_depth:
            return False

        # Never inline callees larger than the configured threshold.
        if self.instruction_count > self.size_threshold:
            return False

        # Tiny callees are always worth inlining.
        if self.instruction_count <= self.small_function_size:
            return True

        # Medium callees: the cost is one unit per instruction; inline only when
        # the estimated call-site benefit strictly exceeds it.
        cost = float(self.instruction_count)
        return self.call_site_benefit > cost
|
75
|
+
|
76
|
+
|
77
|
+
class FunctionInlining(ModulePass):
    """Function inlining optimization pass.

    For every function in a module, this pass repeatedly looks for ``Call``
    instructions whose target is a ``FunctionRef`` to another function of the
    same module. When the ``InliningCost`` model approves, it splices a copy of
    the callee's CFG into the caller in place of the call.

    Known limitations:
        - Callee temporaries are renamed and parameters are bound to the call
          arguments. Other callee variables (e.g. ``StoreVar`` targets) keep
          their names and so share the caller's namespace.
        - Instruction types not handled by ``_clone_instruction`` are copied
          without operand remapping.
    """

    def __init__(self, size_threshold: int = 50) -> None:
        """Initialize inlining pass.

        Args:
            size_threshold: Maximum callee instruction count to consider for inlining.
        """
        super().__init__()
        self.size_threshold = size_threshold
        self.stats = {"inlined": 0, "call_sites_processed": 0}
        # Per-callee counter fed to the cost model as ``depth``. It is incremented
        # on every successful inline of that callee and only reset by
        # ``initialize``/``finalize``. It therefore caps how often each callee is
        # inlined per run, which is also what guarantees termination for
        # (mutually) recursive callees.
        self.inlining_depth: dict[str, int] = defaultdict(int)

    def initialize(self) -> None:
        """Initialize the pass before running."""
        super().initialize()
        # Re-initialize stats after base class clears them
        self.stats = {"inlined": 0, "call_sites_processed": 0}
        self.inlining_depth = defaultdict(int)

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="inline",
            description="Inline function calls",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.NONE,
        )

    def run_on_module(self, module: MIRModule) -> bool:
        """Run inlining on a module.

        Args:
            module: The module to optimize.

        Returns:
            True if the module was modified.
        """
        modified = False
        # Plain loop rather than any(): every function must be processed.
        for function in module.functions.values():
            if self._inline_calls_in_function(function, module):
                modified = True
        return modified

    def _inline_calls_in_function(self, function: MIRFunction, module: MIRModule) -> bool:
        """Inline calls within a function until no more opportunities remain.

        Args:
            function: The function to process.
            module: The containing module.

        Returns:
            True if modifications were made.
        """
        modified = False
        transformer = MIRTransformer(function)

        # Iterate to a fixed point: inlining can expose new call sites.
        changed = True
        while changed:
            changed = False

            # Re-scan every iteration because the previous inline rewrote the CFG.
            for call_inst, block in self._find_call_sites(function):
                # NOTE: a call site that survives several re-scans is counted each time.
                self.stats["call_sites_processed"] += 1

                # Only direct calls to functions defined in this module are candidates.
                if not isinstance(call_inst.func, FunctionRef):
                    continue
                callee_name = call_inst.func.name
                if callee_name not in module.functions:
                    continue
                callee = module.functions[callee_name]

                cost = self._calculate_inlining_cost(callee, call_inst, self.inlining_depth[callee_name])
                if not cost.should_inline():
                    continue

                # Never inline a function into itself (direct recursion).
                if callee_name == function.name:
                    continue

                # Defensive: the call may no longer be in its block.
                if call_inst not in block.instructions:
                    continue

                self.inlining_depth[callee_name] += 1
                if self._inline_call(call_inst, block, callee, function, transformer):
                    modified = True
                    changed = True
                    self.stats["inlined"] += 1
                    # The counter stays incremented on purpose (see __init__).
                    # Break to re-find call sites in the rewritten CFG.
                    break
                self.inlining_depth[callee_name] -= 1

        return modified

    def _find_call_sites(self, function: MIRFunction) -> list[tuple[Call, BasicBlock]]:
        """Find all call instructions in a function.

        Args:
            function: The function to search.

        Returns:
            List of (call instruction, containing block) pairs, in block order.
        """
        return [
            (inst, block)
            for block in function.cfg.blocks.values()
            for inst in block.instructions
            if isinstance(inst, Call)
        ]

    def _calculate_inlining_cost(self, callee: MIRFunction, call_inst: Call, depth: int) -> InliningCost:
        """Calculate the cost of inlining a function.

        Args:
            callee: The function to inline.
            call_inst: The call instruction.
            depth: Current inlining depth.

        Returns:
            Inlining cost information.
        """
        # Count instructions in callee
        instruction_count = sum(len(block.instructions) for block in callee.cfg.blocks.values())

        # Call-site benefit heuristics (call frequency, e.g. calls inside loops,
        # is not considered here):
        # - constant arguments enable constant propagation
        # - single-block callees are trivial to splice
        # - multiple returns complicate the CFG merge
        benefit = 10.0  # Base benefit from removing call overhead

        # Bonus for constant arguments
        from machine_dialect.mir.mir_values import Constant

        const_args = sum(1 for arg in call_inst.args if isinstance(arg, Constant))
        benefit += const_args * 5.0

        # Bonus for simple functions (single block)
        if len(callee.cfg.blocks) == 1:
            benefit += 10.0

        # Penalty for multiple returns (complex CFG)
        return_count = sum(
            1 for block in callee.cfg.blocks.values() for inst in block.instructions if isinstance(inst, Return)
        )
        if return_count > 1:
            benefit -= (return_count - 1) * 5.0

        return InliningCost(
            instruction_count=instruction_count,
            call_site_benefit=benefit,
            size_threshold=self.size_threshold,
            depth=depth,
        )

    @staticmethod
    def _unique_label(function: MIRFunction, base: str) -> str:
        """Return ``base`` if no block of ``function`` uses it, else ``base_N`` for the first free N."""
        taken = {block.label for block in function.cfg.blocks.values()}
        if base not in taken:
            return base
        suffix = 1
        while f"{base}_{suffix}" in taken:
            suffix += 1
        return f"{base}_{suffix}"

    @staticmethod
    def _retarget_phis(block: BasicBlock, old_label: str, new_label: str) -> None:
        """Rewrite Phi incoming edges in ``block`` that name ``old_label`` to ``new_label``."""
        for inst in block.instructions:
            if isinstance(inst, Phi):
                inst.incoming = [
                    (value, new_label if label == old_label else label) for value, label in inst.incoming
                ]

    def _inline_call(
        self,
        call_inst: Call,
        call_block: BasicBlock,
        callee: MIRFunction,
        caller: MIRFunction,
        transformer: MIRTransformer,
    ) -> bool:
        """Inline a function call.

        The call block is split at the call. The part before the call jumps to
        the cloned callee entry. Each cloned return block copies its value into
        the call's destination (if any) and jumps to a continuation block that
        holds the instructions after the call.

        Args:
            call_inst: The call instruction to inline.
            call_block: The block containing the call.
            callee: The function to inline.
            caller: The calling function.
            transformer: MIR transformer.

        Returns:
            True if inlining succeeded. False if the call was left untouched
            (empty callee or argument/parameter count mismatch).
        """
        # An empty callee has no entry block to jump to; leave the call alone.
        if not callee.cfg.blocks:
            return False

        # Parameters may be plain strings or Variable objects.
        params = list(callee.params) if getattr(callee, "params", None) else []
        if len(params) != len(call_inst.args):
            # Binding would be wrong; decline instead of raising and aborting the pass.
            return False

        # Map callee parameters to the call's argument values.
        value_map: dict[MIRValue, MIRValue] = {}
        for param, arg in zip(params, call_inst.args, strict=True):
            if isinstance(param, Variable):
                # Use the callee's own parameter object so non-INT parameters bind too.
                value_map[param] = arg
                param_name = param.name
            else:
                param_name = param if isinstance(param, str) else str(param)
            # Backward-compatible key: parameters were historically assumed to be INT.
            value_map.setdefault(Variable(param_name, MIRType.INT), arg)

        # Clone the callee's CFG
        _cloned_blocks, entry_block, return_blocks = self._clone_function_body(callee, caller, value_map, transformer)

        # Split the call block at the call instruction
        call_idx = call_block.instructions.index(call_inst)
        pre_call = call_block.instructions[:call_idx]
        post_call = call_block.instructions[call_idx + 1 :]

        # The continuation block receives the code after the call and the original successors.
        cont_block = BasicBlock(self._unique_label(caller, f"{call_block.label}_cont"))
        caller.cfg.add_block(cont_block)
        cont_block.instructions = post_call
        cont_block.successors = call_block.successors.copy()

        # The successors are now reached from cont_block. Replace the predecessor
        # in place (keeps order, handles duplicate edges) and fix Phi edge labels.
        for succ in cont_block.successors:
            succ.predecessors[:] = [cont_block if pred is call_block else pred for pred in succ.predecessors]
            self._retarget_phis(succ, call_block.label, cont_block.label)

        # Modify call block to jump to inlined entry
        call_block.instructions = [*pre_call, Jump(entry_block.label, call_inst.source_location)]
        call_block.successors = [entry_block]
        entry_block.predecessors.append(call_block)

        # Replace each cloned Return (always the last instruction of its block)
        # with an optional copy into the call's destination plus a jump to the continuation.
        dest = call_inst.dest
        for ret_block in return_blocks:
            ret_inst = ret_block.instructions[-1]
            assert isinstance(ret_inst, Return)
            loc = ret_inst.source_location
            epilogue: list[MIRInstruction] = []
            if dest is not None and ret_inst.value is not None:
                epilogue.append(Copy(dest, ret_inst.value, loc))
            epilogue.append(Jump(cont_block.label, loc))
            ret_block.instructions[-1:] = epilogue
            ret_block.successors = [cont_block]
            cont_block.predecessors.append(ret_block)

        transformer.modified = True
        return True

    def _clone_function_body(
        self,
        callee: MIRFunction,
        caller: MIRFunction,
        value_map: dict[MIRValue, MIRValue],
        transformer: MIRTransformer,
    ) -> tuple[dict[str, BasicBlock], BasicBlock, list[BasicBlock]]:
        """Clone a function's body into the caller for inlining.

        New blocks are added to ``caller.cfg`` with labels derived from
        ``inlined_{callee}_{label}``. A numeric suffix is added when that label
        is already used, e.g. when the same callee is inlined twice.

        Args:
            callee: The function to clone.
            caller: The calling function.
            value_map: Mapping from callee values to caller values. It is extended
                in place with fresh temporaries.
            transformer: MIR transformer (currently unused).

        Returns:
            Tuple of (cloned blocks dict, entry block, return blocks list).
        """
        block_map: dict[BasicBlock, BasicBlock] = {}
        cloned_blocks: dict[str, BasicBlock] = {}

        # First pass: create all blocks
        for old_block in callee.cfg.blocks.values():
            new_label = self._unique_label(caller, f"inlined_{callee.name}_{old_block.label}")
            new_block = BasicBlock(new_label)
            caller.cfg.add_block(new_block)
            block_map[old_block] = new_block
            cloned_blocks[new_label] = new_block

        # Map entry block: prefer the CFG's entry, then a block labelled "entry",
        # then the first block.
        entry_block: BasicBlock | None = None
        if callee.cfg.entry_block:
            entry_block = block_map[callee.cfg.entry_block]
        else:
            for old_block in callee.cfg.blocks.values():
                if old_block.label == "entry":
                    entry_block = block_map[old_block]
                    break
            if not entry_block and block_map:
                entry_block = next(iter(block_map.values()))
            if not entry_block:
                # Only reachable for an empty callee, which _inline_call rejects up front.
                entry_block = BasicBlock("inline_entry")

        assert entry_block is not None, "Entry block must be set"

        # Fresh temporaries continue from the caller's temp counter.
        temp_counter = caller._next_temp_id

        def map_value(value: MIRValue) -> MIRValue:
            """Map a value from callee to caller."""
            nonlocal temp_counter
            if value in value_map:
                return value_map[value]
            if isinstance(value, Temp):
                # Create new temp with unique ID
                new_temp = Temp(value.type, temp_counter)
                temp_counter += 1
                caller._next_temp_id = temp_counter
                value_map[value] = new_temp
                return new_temp
            # Constants and other values remain unchanged
            return value

        # Second pass: clone instructions and update CFG edges
        return_blocks: list[BasicBlock] = []
        for old_block, new_block in block_map.items():
            for inst in old_block.instructions:
                cloned_inst = self._clone_instruction(inst, map_value, block_map)
                new_block.instructions.append(cloned_inst)
                if isinstance(cloned_inst, Return):
                    # Anything after a Return is unreachable. Stopping here keeps the
                    # Return last in its block, which _inline_call relies on.
                    return_blocks.append(new_block)
                    break

            for succ in old_block.successors:
                new_succ = block_map[succ]
                new_block.successors.append(new_succ)
                new_succ.predecessors.append(new_block)

        return cloned_blocks, entry_block, return_blocks

    @staticmethod
    def _remap_label(label: str, block_map: dict[BasicBlock, BasicBlock]) -> str:
        """Return the cloned block's label for a callee ``label``; unknown labels are kept."""
        for old_b, new_b in block_map.items():
            if old_b.label == label:
                return new_b.label
        return label

    def _clone_instruction(
        self,
        inst: MIRInstruction,
        map_value: Any,
        block_map: dict[BasicBlock, BasicBlock],
    ) -> MIRInstruction:
        """Clone an instruction with value and label remapping.

        Args:
            inst: The instruction to clone.
            map_value: Function mapping callee values to caller values.
            block_map: Mapping from old blocks to new blocks.

        Returns:
            Cloned instruction (always a new object).
        """
        # Import here to avoid circular dependency
        import copy

        from machine_dialect.mir.mir_instructions import BinaryOp, LoadConst, Print, StoreVar, UnaryOp

        if isinstance(inst, BinaryOp):
            return BinaryOp(
                map_value(inst.dest),
                inst.op,
                map_value(inst.left),
                map_value(inst.right),
                inst.source_location,
            )
        if isinstance(inst, UnaryOp):
            return UnaryOp(
                map_value(inst.dest),
                inst.op,
                map_value(inst.operand),
                inst.source_location,
            )
        if isinstance(inst, Copy):
            return Copy(
                map_value(inst.dest),
                map_value(inst.source),
                inst.source_location,
            )
        if isinstance(inst, LoadConst):
            return LoadConst(
                map_value(inst.dest),
                inst.constant.value if hasattr(inst.constant, "value") else inst.constant,  # Use the constant value
                inst.source_location,
            )
        if isinstance(inst, StoreVar):
            return StoreVar(
                inst.var,  # Variable names stay the same
                map_value(inst.source),
                inst.source_location,
            )
        if isinstance(inst, Call):
            return Call(
                map_value(inst.dest) if inst.dest is not None else None,
                inst.func,
                [map_value(arg) for arg in inst.args],
                inst.source_location,
            )
        if isinstance(inst, Return):
            return Return(
                inst.source_location,
                map_value(inst.value) if inst.value is not None else None,
            )
        if isinstance(inst, ConditionalJump):
            return ConditionalJump(
                map_value(inst.condition),
                self._remap_label(inst.true_label, block_map),
                inst.source_location,
                self._remap_label(inst.false_label, block_map) if inst.false_label else inst.false_label,
            )
        if isinstance(inst, Jump):
            return Jump(self._remap_label(inst.label, block_map), inst.source_location)
        if isinstance(inst, Phi):
            new_incoming = [(map_value(value), self._remap_label(label, block_map)) for value, label in inst.incoming]
            return Phi(map_value(inst.dest), new_incoming, inst.source_location)
        if isinstance(inst, Print):
            return Print(map_value(inst.value), inst.source_location)

        # Unsupported instruction type: copy it so the caller and callee never share
        # (and later co-mutate) one instruction object.
        # NOTE(review): operands of these instructions are NOT remapped; extend the
        # cases above when new instruction types reach inlinable functions.
        return copy.copy(inst)

    def finalize(self) -> None:
        """Finalize the pass after running.

        Cleans up any temporary state.
        """
        self.inlining_depth.clear()

    def get_statistics(self) -> dict[str, int]:
        """Get optimization statistics.

        Returns:
            Dictionary of statistics.
        """
        return self.stats
|