machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,447 @@
|
|
1
|
+
"""Type specialization optimization pass.
|
2
|
+
|
3
|
+
This module implements type specialization to generate optimized versions
|
4
|
+
of functions for specific type combinations based on profiling data.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from collections import defaultdict
|
8
|
+
from dataclasses import dataclass
|
9
|
+
|
10
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
11
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
12
|
+
from machine_dialect.mir.mir_instructions import (
|
13
|
+
BinaryOp,
|
14
|
+
Call,
|
15
|
+
)
|
16
|
+
from machine_dialect.mir.mir_module import MIRModule
|
17
|
+
from machine_dialect.mir.mir_types import MIRType, MIRUnionType
|
18
|
+
from machine_dialect.mir.mir_values import Constant, MIRValue, Variable
|
19
|
+
from machine_dialect.mir.optimization_pass import (
|
20
|
+
ModulePass,
|
21
|
+
PassInfo,
|
22
|
+
PassType,
|
23
|
+
PreservationLevel,
|
24
|
+
)
|
25
|
+
from machine_dialect.mir.profiling.profile_data import ProfileData
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
class TypeSignature:
    """Argument and result types that a function specialization targets.

    Instances are used as dictionary keys when grouping call sites, so they
    hash on the pair ``(param_types, return_type)``.

    Attributes:
        param_types: Types of the call's arguments, in positional order.
        return_type: Type of the value the call produces.
    """

    param_types: tuple[MIRType | MIRUnionType, ...]
    return_type: MIRType | MIRUnionType

    # An explicit __hash__ is required: with eq=True and frozen=False,
    # @dataclass would otherwise set __hash__ to None (unhashable).
    # NOTE(review): instances stay mutable; do not reassign fields while an
    # instance is used as a dict key.
    def __hash__(self) -> int:
        """Hash on the ``(param_types, return_type)`` pair."""
        key = (self.param_types, self.return_type)
        return hash(key)

    def __str__(self) -> str:
        """Render the signature as ``(t1, t2, ...) -> ret``."""
        rendered_params = ", ".join(map(str, self.param_types))
        return f"({rendered_params}) -> {self.return_type}"
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
class SpecializationCandidate:
    """A (function, type signature) pair considered for specialization.

    Attributes:
        function_name: Name of the function to specialize.
        signature: Argument/return types the specialized copy targets.
        call_count: Number of calls attributed to this signature.
        benefit: Heuristic benefit score; higher means more worthwhile.
    """

    function_name: str
    signature: TypeSignature
    call_count: int
    benefit: float

    def specialized_name(self) -> str:
        """Build the name used for the specialized copy of the function.

        The result is ``<function_name>__<suffix>``. The suffix joins the
        lowercased ``.name`` of each parameter type with ``_``. A union
        parameter is rendered as ``union_<member1>_<member2>...``. The return
        type is not encoded, so signatures that differ only in return type
        yield the same name.
        """

        def _render(param_type) -> str:
            # Unions list their member types; plain types use their own name.
            if isinstance(param_type, MIRUnionType):
                members = "_".join(member.name.lower() for member in param_type.types)
                return f"union_{members}"
            return param_type.name.lower()

        suffix = "_".join(_render(pt) for pt in self.signature.param_types)
        return f"{self.function_name}__{suffix}"
|
78
|
+
|
79
|
+
|
80
|
+
class TypeSpecialization(ModulePass):
    """Type specialization optimization pass.

    The pass works in four phases:

    1. Record the argument types seen at call sites.
    2. Pick the (function, signature) pairs whose call count reaches
       ``threshold`` and whose estimated benefit is positive.
    3. Clone those functions under a type-suffixed name with typed
       parameters.
    4. Redirect call sites whose inferred signature matches a clone.
    """

    def __init__(
        self,
        profile_data: ProfileData | None = None,
        threshold: int = 100,
        *,
        max_function_size: int = 100,
        max_specializations: int = 10,
    ) -> None:
        """Initialize type specialization pass.

        Args:
            profile_data: Optional profiling data for hot type combinations.
            threshold: Minimum call count to consider specialization.
            max_function_size: Functions with more instructions than this are
                never specialized (limits code bloat).
            max_specializations: Maximum number of candidates specialized per
                run (highest benefit first).
        """
        super().__init__()
        self.profile_data = profile_data
        self.threshold = threshold
        self.max_function_size = max_function_size
        self.max_specializations = max_specializations
        self.stats = {
            "functions_analyzed": 0,
            "functions_specialized": 0,
            "specializations_created": 0,
            "type_checks_eliminated": 0,
        }

        # Callee name -> {observed type signature: accumulated call count}.
        self.type_signatures: dict[str, dict[TypeSignature, int]] = defaultdict(lambda: defaultdict(int))

        # Callee name -> {type signature: name of the specialized clone}.
        self.specializations: dict[str, dict[TypeSignature, str]] = defaultdict(dict)

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="type-specialization",
            description="Create type-specialized function versions",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.NONE,
        )

    def finalize(self) -> None:
        """Finalize the pass after running (nothing to clean up)."""

    def run_on_module(self, module: MIRModule) -> bool:
        """Run type specialization on a module.

        Args:
            module: The module to optimize.

        Returns:
            True if at least one specialized function was added.
        """
        modified = False

        # Phase 1: Collect type signatures from call sites
        self._collect_type_signatures(module)

        # Phase 2: Identify specialization candidates
        candidates = self._identify_candidates(module)

        # Phase 3: Create specialized functions
        for candidate in candidates:
            if self._create_specialization(module, candidate):
                modified = True
                self.stats["specializations_created"] += 1

        # Phase 4: Update call sites to use specialized versions
        if modified:
            self._update_call_sites(module)

        return modified

    def _collect_type_signatures(self, module: MIRModule) -> None:
        """Collect type signatures from all call sites.

        Only calls whose target has a ``name`` and whose arguments all have an
        inferable type are recorded.

        Args:
            module: The module to analyze.
        """
        for function in module.functions.values():
            self.stats["functions_analyzed"] += 1

            for block in function.cfg.blocks.values():
                for inst in block.instructions:
                    if not isinstance(inst, Call) or not hasattr(inst.func, "name"):
                        continue

                    arg_types = self._infer_arg_types(inst.args)
                    if not arg_types:
                        continue

                    func_name = inst.func.name
                    signature = TypeSignature(arg_types, self._infer_return_type(inst))

                    # Use profile data if available.
                    # NOTE(review): the callee's *total* profiled call_count is
                    # added once per static call site and credited to that
                    # site's signature. A callee with several call sites or
                    # several signatures is therefore over-counted. Precise
                    # attribution would need per-call-site profile data.
                    if self.profile_data and func_name in self.profile_data.functions:
                        profile = self.profile_data.functions[func_name]
                        self.type_signatures[func_name][signature] += profile.call_count
                    else:
                        self.type_signatures[func_name][signature] += 1

    def _infer_arg_types(self, args: list[MIRValue]) -> tuple[MIRType | MIRUnionType, ...] | None:
        """Infer types of function arguments.

        Args:
            args: List of argument values.

        Returns:
            Tuple of types, or None if any argument is neither a Constant nor
            a Variable with a known (non-UNKNOWN) type.
        """
        types: list[MIRType | MIRUnionType] = []
        for arg in args:
            if isinstance(arg, Constant):
                types.append(arg.type)
            elif isinstance(arg, Variable) and arg.type != MIRType.UNKNOWN:
                types.append(arg.type)
            else:
                return None  # Can't infer all types
        return tuple(types)

    def _infer_return_type(self, call: Call) -> MIRType | MIRUnionType:
        """Infer return type of a call from its destination value.

        Args:
            call: The call instruction.

        Returns:
            The destination's type, or MIRType.UNKNOWN when the call has no
            typed destination.
        """
        dest = call.dest
        if dest is not None and hasattr(dest, "type"):
            return dest.type
        return MIRType.UNKNOWN

    def _identify_candidates(self, module: MIRModule) -> list[SpecializationCandidate]:
        """Identify functions worth specializing.

        Args:
            module: The module to analyze.

        Returns:
            Up to ``max_specializations`` candidates, highest benefit first.
        """
        candidates: list[SpecializationCandidate] = []

        for func_name, signatures in self.type_signatures.items():
            # Skip callees that are not defined in this module
            if func_name not in module.functions:
                continue

            function = module.functions[func_name]

            # Skip large functions to avoid code bloat
            if self._count_instructions(function) > self.max_function_size:
                continue

            for signature, count in signatures.items():
                if count < self.threshold:
                    continue

                # Benefit combines call frequency, numeric parameter types and
                # eliminable type checks, minus a code-size penalty.
                benefit = self._calculate_benefit(signature, count, function)
                if benefit > 0:
                    candidates.append(
                        SpecializationCandidate(
                            function_name=func_name, signature=signature, call_count=count, benefit=benefit
                        )
                    )

        candidates.sort(key=lambda c: c.benefit, reverse=True)

        # Limit number of specializations to avoid code bloat
        return candidates[: self.max_specializations]

    def _count_instructions(self, function: MIRFunction) -> int:
        """Count instructions in a function.

        Args:
            function: The function to count.

        Returns:
            Total instruction count across all blocks.
        """
        return sum(len(block.instructions) for block in function.cfg.blocks.values())

    def _calculate_benefit(self, signature: TypeSignature, call_count: int, function: MIRFunction) -> float:
        """Calculate specialization benefit.

        Args:
            signature: Type signature to specialize for.
            call_count: Number of calls with this signature.
            function: The function to specialize.

        Returns:
            Estimated benefit score, never below 0.0.
        """
        benefit = 0.0

        # Benefit from call frequency
        benefit += call_count * 0.1

        # Benefit from numeric types (can use specialized instructions)
        for param_type in signature.param_types:
            if param_type in (MIRType.INT, MIRType.FLOAT):
                benefit += 20.0
            elif param_type == MIRType.BOOL:
                benefit += 10.0

        # Benefit from eliminating type checks
        benefit += self._count_type_checks(function) * 5.0

        # Penalty for code size
        benefit -= self._count_instructions(function) * 0.5

        return max(0.0, benefit)

    def _count_type_checks(self, function: MIRFunction) -> int:
        """Count potential type checks in a function.

        Heuristic: every BinaryOp is counted as one operation that might need
        a runtime type check.

        Args:
            function: The function to analyze.

        Returns:
            Number of BinaryOp instructions.
        """
        return sum(
            1
            for block in function.cfg.blocks.values()
            for inst in block.instructions
            if isinstance(inst, BinaryOp)
        )

    def _create_specialization(self, module: MIRModule, candidate: SpecializationCandidate) -> bool:
        """Create a specialized version of a function.

        Args:
            module: The module containing the function.
            candidate: The specialization candidate.

        Returns:
            True if a new specialized function was added to the module.
        """
        original_func = module.functions.get(candidate.function_name)
        if not original_func:
            return False

        specialized_name = candidate.specialized_name()

        if specialized_name in module.functions:
            # The name encodes only parameter types, so signatures differing
            # only in return type collide. The pass may also be re-run on a
            # module it already specialized. Reuse our own earlier clone
            # instead of adding a duplicate. Never redirect calls to an
            # unrelated function that merely has the same name.
            existing = self.specializations.get(candidate.function_name, {})
            if specialized_name in existing.values():
                self.specializations[candidate.function_name][candidate.signature] = specialized_name
            return False

        specialized_func = self._clone_function(original_func, specialized_name)

        # Apply the signature's types to the clone's parameters (extra
        # parameters beyond the signature keep their original types).
        for param, param_type in zip(specialized_func.params, candidate.signature.param_types):
            param.type = param_type

        self._optimize_specialized_function(specialized_func, candidate.signature)

        module.add_function(specialized_func)

        self.specializations[candidate.function_name][candidate.signature] = specialized_name
        self.stats["functions_specialized"] += 1

        return True

    def _clone_function(self, original: MIRFunction, new_name: str) -> MIRFunction:
        """Clone a function with a new name.

        The clone gets fresh parameter Variables (same names and types) and a
        deep copy of every basic block. All copies share one
        ``copy.deepcopy`` memo, so an object referenced from several places
        (e.g. a variable used by many instructions) is copied exactly once.
        The memo is pre-seeded so that references to the original function,
        its CFG and its parameters resolve to the clone's own. The cloned body
        therefore uses the cloned parameters, and nothing in the clone aliases
        the original's instructions.

        NOTE(review): assumes MIR blocks and instructions are plain data
        objects that deepcopy cleanly. Function-level attributes other than
        name, params and CFG blocks/entry (e.g. a declared return type, if
        MIRFunction has one) are not copied; confirm whether any are needed.

        Args:
            original: The function to clone.
            new_name: Name for the cloned function.

        Returns:
            The cloned function.
        """
        import copy

        cloned = MIRFunction(new_name, [Variable(p.name, p.type) for p in original.params])

        memo: dict[int, object] = {
            id(original): cloned,
            id(original.cfg): cloned.cfg,
        }
        for old_param, new_param in zip(original.params, cloned.params):
            memo[id(old_param)] = new_param

        for block in original.cfg.blocks.values():
            new_block: BasicBlock = copy.deepcopy(block, memo)
            cloned.cfg.add_block(new_block)

        # Point the clone's entry at its *own* copy of the entry block; the
        # shared memo returns the copy already made above.
        if original.cfg.entry_block:
            cloned.cfg.entry_block = copy.deepcopy(original.cfg.entry_block, memo)

        return cloned

    def _optimize_specialized_function(self, function: MIRFunction, signature: TypeSignature) -> None:
        """Apply type-specific optimizations to a specialized function.

        Currently this only updates statistics. When every parameter type is
        MIRType.INT, each BinaryOp is counted in ``type_checks_eliminated``.
        The instructions themselves are not rewritten yet.

        Args:
            function: The specialized function.
            signature: The type signature it's specialized for.
        """
        # Loop-invariant: depends only on the signature.
        if not all(t == MIRType.INT for t in signature.param_types):
            return

        for block in function.cfg.blocks.values():
            for inst in block.instructions:
                if isinstance(inst, BinaryOp):
                    # Could replace with a specialized integer instruction
                    self.stats["type_checks_eliminated"] += 1

    def _update_call_sites(self, module: MIRModule) -> None:
        """Update call sites to use specialized versions.

        A call is redirected when its target has a recorded specialization
        for exactly the signature inferred at that call site.

        Args:
            module: The module to update.
        """
        for function in module.functions.values():
            for block in function.cfg.blocks.values():
                for inst in block.instructions:
                    if not isinstance(inst, Call) or not hasattr(inst.func, "name"):
                        continue

                    by_signature = self.specializations.get(inst.func.name)
                    if not by_signature:
                        continue

                    arg_types = self._infer_arg_types(inst.args)
                    if not arg_types:
                        continue

                    signature = TypeSignature(arg_types, self._infer_return_type(inst))
                    specialized_name = by_signature.get(signature)
                    if specialized_name is not None:
                        inst.func.name = specialized_name
|