machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,421 @@
|
|
1
|
+
"""Comprehensive tests for type specialization optimization pass."""
|
2
|
+
|
3
|
+
from unittest.mock import MagicMock
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
|
7
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
8
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
9
|
+
from machine_dialect.mir.mir_instructions import (
|
10
|
+
BinaryOp,
|
11
|
+
Call,
|
12
|
+
Return,
|
13
|
+
)
|
14
|
+
from machine_dialect.mir.mir_module import MIRModule
|
15
|
+
from machine_dialect.mir.mir_types import MIRType
|
16
|
+
from machine_dialect.mir.mir_values import Constant, Temp, Variable
|
17
|
+
from machine_dialect.mir.optimization_pass import PassType, PreservationLevel
|
18
|
+
from machine_dialect.mir.optimizations.type_specialization import (
|
19
|
+
SpecializationCandidate,
|
20
|
+
TypeSignature,
|
21
|
+
TypeSpecialization,
|
22
|
+
)
|
23
|
+
from machine_dialect.mir.profiling.profile_data import ProfileData
|
24
|
+
|
25
|
+
|
26
|
+
class TestTypeSignature:
    """Test TypeSignature dataclass."""

    def test_signature_creation(self) -> None:
        """Test creating a type signature."""
        sig = TypeSignature(
            param_types=(MIRType.INT, MIRType.FLOAT),
            return_type=MIRType.INT,
        )
        assert sig.param_types == (MIRType.INT, MIRType.FLOAT)
        assert sig.return_type == MIRType.INT

    def test_signature_hash(self) -> None:
        """Test that type signatures can be hashed and compared.

        Equal signatures must hash identically (the hash/eq contract), so they
        collapse to one entry in a set or dict. Distinct signatures are checked
        via ``!=`` and set de-duplication rather than by comparing raw hash
        values, because Python does not guarantee that unequal objects have
        unequal hashes.
        """
        sig1 = TypeSignature(
            param_types=(MIRType.INT, MIRType.INT),
            return_type=MIRType.INT,
        )
        sig2 = TypeSignature(
            param_types=(MIRType.INT, MIRType.INT),
            return_type=MIRType.INT,
        )
        sig3 = TypeSignature(
            param_types=(MIRType.FLOAT, MIRType.INT),
            return_type=MIRType.INT,
        )

        # Same signatures should be equal and have the same hash
        assert sig1 == sig2
        assert hash(sig1) == hash(sig2)
        # Different signatures must compare unequal and stay distinct in a set
        # (a hash collision alone would not merge them, so this is robust).
        assert sig1 != sig3
        assert len({sig1, sig2, sig3}) == 2

    def test_signature_string_representation(self) -> None:
        """Test string representation of type signature."""
        sig = TypeSignature(
            param_types=(MIRType.INT, MIRType.BOOL),
            return_type=MIRType.FLOAT,
        )
        assert str(sig) == "(int, bool) -> float"
|
65
|
+
|
66
|
+
|
67
|
+
class TestSpecializationCandidate:
    """Tests for the SpecializationCandidate dataclass and its name mangling."""

    def test_candidate_creation(self) -> None:
        """Test creating a specialization candidate."""
        int_pair_sig = TypeSignature(
            param_types=(MIRType.INT, MIRType.INT),
            return_type=MIRType.INT,
        )
        cand = SpecializationCandidate(
            function_name="add",
            signature=int_pair_sig,
            call_count=500,
            benefit=0.85,
        )
        # Every field should round-trip exactly as supplied.
        observed = (cand.function_name, cand.signature, cand.call_count, cand.benefit)
        assert observed == ("add", int_pair_sig, 500, 0.85)

    def test_specialized_name_generation(self) -> None:
        """Test generation of specialized function names."""
        mixed_sig = TypeSignature(
            param_types=(MIRType.INT, MIRType.FLOAT),
            return_type=MIRType.FLOAT,
        )
        cand = SpecializationCandidate(
            function_name="multiply",
            signature=mixed_sig,
            call_count=200,
            benefit=0.5,
        )
        expected_name = "multiply__int_float"
        assert cand.specialized_name() == expected_name

    def test_specialized_name_no_params(self) -> None:
        """Test specialized name for function with no parameters."""
        nullary_sig = TypeSignature(
            param_types=(),
            return_type=MIRType.INT,
        )
        cand = SpecializationCandidate(
            function_name="get_value",
            signature=nullary_sig,
            call_count=100,
            benefit=0.3,
        )
        # With no parameters the mangled suffix is empty, leaving just "__".
        expected_name = "get_value__"
        assert cand.specialized_name() == expected_name
|
114
|
+
|
115
|
+
|
116
|
+
class TestTypeSpecialization:
    """Test TypeSpecialization optimization pass."""

    @staticmethod
    def _install_function(module: MIRModule, func: MIRFunction, block: BasicBlock) -> None:
        """Make ``block`` the sole entry block of ``func`` and register ``func`` in ``module``.

        Every test here builds single-block functions the same way; keeping the
        three setup calls in one place guarantees they always run in this order.
        """
        func.cfg.add_block(block)
        func.cfg.entry_block = block
        module.add_function(func)

    @pytest.fixture
    def module(self) -> MIRModule:
        """Test fixture providing a MIRModule with a simple add function."""
        module = MIRModule("test")

        # Create a simple function to specialize
        func = MIRFunction(
            "add",
            [Variable("a", MIRType.UNKNOWN), Variable("b", MIRType.UNKNOWN)],
            MIRType.UNKNOWN,
        )
        block = BasicBlock("entry")

        # Add simple addition: result = a + b; return result
        a = Variable("a", MIRType.UNKNOWN)
        b = Variable("b", MIRType.UNKNOWN)
        result = Temp(MIRType.UNKNOWN)
        block.add_instruction(BinaryOp(result, "+", a, b, (1, 1)))
        block.add_instruction(Return((1, 1), result))

        self._install_function(module, func, block)

        return module

    def test_pass_initialization(self) -> None:
        """Test initialization of type specialization pass."""
        opt = TypeSpecialization(threshold=50)
        assert opt.profile_data is None
        assert opt.threshold == 50
        assert opt.stats["functions_analyzed"] == 0
        assert opt.stats["functions_specialized"] == 0

    def test_pass_info(self) -> None:
        """Test pass information."""
        opt = TypeSpecialization()
        info = opt.get_info()
        assert info.name == "type-specialization"
        assert info.pass_type == PassType.OPTIMIZATION
        assert info.preserves == PreservationLevel.NONE

    def test_collect_type_signatures(self, module: MIRModule) -> None:
        """Test collecting type signatures from call sites."""
        opt = TypeSpecialization()

        # Create a caller function with typed calls
        caller = MIRFunction("caller", [], MIRType.EMPTY)
        block = BasicBlock("entry")

        # Call add(1, 2) - both int
        t1 = Temp(MIRType.INT)
        block.add_instruction(Call(t1, "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1)))

        # Call add(1.0, 2.0) - both float
        t2 = Temp(MIRType.FLOAT)
        block.add_instruction(Call(t2, "add", [Constant(1.0, MIRType.FLOAT), Constant(2.0, MIRType.FLOAT)], (1, 1)))

        # Call add(1, 2) again - int
        t3 = Temp(MIRType.INT)
        block.add_instruction(Call(t3, "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1)))

        self._install_function(module, caller, block)

        # Collect signatures
        opt._collect_type_signatures(module)

        # Check collected signatures
        assert "add" in opt.type_signatures
        signatures = opt.type_signatures["add"]

        # Should have two different signatures
        assert len(signatures) == 2

        # Check int signature (called twice)
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        assert int_sig in signatures
        assert signatures[int_sig] == 2

        # Check float signature (called once)
        float_sig = TypeSignature((MIRType.FLOAT, MIRType.FLOAT), MIRType.FLOAT)
        assert float_sig in signatures
        assert signatures[float_sig] == 1

    def test_identify_candidates(self, module: MIRModule) -> None:
        """Test identifying specialization candidates."""
        opt = TypeSpecialization(threshold=2)

        # Set up type signatures
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        float_sig = TypeSignature((MIRType.FLOAT, MIRType.FLOAT), MIRType.FLOAT)

        opt.type_signatures["add"][int_sig] = 10  # Above threshold
        opt.type_signatures["add"][float_sig] = 1  # Below threshold

        candidates = opt._identify_candidates(module)

        # Should only have one candidate (int signature)
        assert len(candidates) == 1
        candidate = candidates[0]
        assert candidate.function_name == "add"
        assert candidate.signature == int_sig
        assert candidate.call_count == 10

    def test_calculate_benefit(self, module: MIRModule) -> None:
        """Test benefit calculation for specialization."""
        opt = TypeSpecialization()

        # Test with specific type signature (high benefit)
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        func = module.functions["add"]
        benefit = opt._calculate_benefit(int_sig, 100, func)
        assert benefit > 0

        # Test with UNKNOWN types (lower benefit)
        any_sig = TypeSignature((MIRType.UNKNOWN, MIRType.UNKNOWN), MIRType.UNKNOWN)
        benefit_any = opt._calculate_benefit(any_sig, 100, func)
        assert benefit_any <= benefit

    def test_create_specialized_function(self, module: MIRModule) -> None:
        """Test creating a specialized function."""
        opt = TypeSpecialization()

        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        candidate = SpecializationCandidate(
            function_name="add",
            signature=int_sig,
            call_count=100,
            benefit=0.8,
        )

        # Create specialization (returns True/False)
        created = opt._create_specialization(module, candidate)
        assert created

        # Check that specialized function was added to module
        specialized_name = candidate.specialized_name()
        assert specialized_name in module.functions

        specialized = module.functions[specialized_name]
        assert specialized.name == "add__int_int"
        assert len(specialized.params) == 2
        assert specialized.params[0].type == MIRType.INT
        assert specialized.params[1].type == MIRType.INT
        # Note: return_type might be set differently during specialization,
        # so only check that one is present.
        assert specialized.return_type is not None

        # Check that blocks were copied
        assert len(specialized.cfg.blocks) == 1

    def test_update_call_sites(self, module: MIRModule) -> None:
        """Test updating call sites to use specialized functions."""
        opt = TypeSpecialization()

        # Create specialized function mapping
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        opt.specializations["add"][int_sig] = "add__int_int"

        # Create a caller with matching call
        caller = MIRFunction("caller", [], MIRType.EMPTY)
        block = BasicBlock("entry")

        t1 = Temp(MIRType.INT)
        call_inst = Call(t1, "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1))
        block.add_instruction(call_inst)

        self._install_function(module, caller, block)

        # Update call sites
        opt._update_call_sites(module)

        # Check that call was updated
        updated_call = next(iter(block.instructions))
        assert isinstance(updated_call, Call)
        # Call has 'func' attribute which is a FunctionRef (with @ prefix)
        assert str(updated_call.func) == "@add__int_int"

    def test_run_on_module_with_profile(self, module: MIRModule) -> None:
        """Test running type specialization with profile data."""
        # Create mock profile data
        profile = MagicMock(spec=ProfileData)
        profile.get_function_metrics = MagicMock(
            return_value={
                "call_count": 1000,
                "type_signatures": {
                    ((MIRType.INT, MIRType.INT), MIRType.INT): 800,
                    ((MIRType.FLOAT, MIRType.FLOAT), MIRType.FLOAT): 200,
                },
            }
        )

        opt = TypeSpecialization(profile_data=profile, threshold=100)

        # Run optimization
        changed = opt.run_on_module(module)

        # Should have analyzed functions (might not change if threshold not met)
        assert opt.stats["functions_analyzed"] > 0
        # Changed flag depends on whether specialization was created
        if changed:
            assert opt.stats["specializations_created"] > 0

    def test_run_on_module_without_profile(self, module: MIRModule) -> None:
        """Test running type specialization without profile data."""
        opt = TypeSpecialization(threshold=1)

        # Add a caller to create type signatures
        caller = MIRFunction("main", [], MIRType.EMPTY)
        block = BasicBlock("entry")

        # Multiple calls with int types
        for _ in range(5):
            t = Temp(MIRType.INT)
            block.add_instruction(Call(t, "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1)))

        self._install_function(module, caller, block)

        # Run optimization
        changed = opt.run_on_module(module)

        # Should have analyzed functions
        assert opt.stats["functions_analyzed"] > 0

        # Check if specialization was created (depends on threshold)
        if changed:
            assert opt.stats["specializations_created"] > 0

    def test_no_specialization_below_threshold(self, module: MIRModule) -> None:
        """Test that no specialization occurs below threshold."""
        opt = TypeSpecialization(threshold=1000)  # Very high threshold

        # Add a caller with few calls
        caller = MIRFunction("main", [], MIRType.EMPTY)
        block = BasicBlock("entry")

        t = Temp(MIRType.INT)
        block.add_instruction(Call(t, "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1)))

        self._install_function(module, caller, block)

        # Run optimization
        changed = opt.run_on_module(module)

        # Should not have made changes
        assert not changed
        assert opt.stats["specializations_created"] == 0

    def test_multiple_function_specialization(self, module: MIRModule) -> None:
        """Test specializing multiple functions."""
        opt = TypeSpecialization(threshold=2)

        # Add another function to specialize
        mul_func = MIRFunction(
            "multiply",
            [Variable("x", MIRType.UNKNOWN), Variable("y", MIRType.UNKNOWN)],
            MIRType.UNKNOWN,
        )
        block = BasicBlock("entry")
        x = Variable("x", MIRType.UNKNOWN)
        y = Variable("y", MIRType.UNKNOWN)
        result = Temp(MIRType.UNKNOWN)
        block.add_instruction(BinaryOp(result, "*", x, y, (1, 1)))
        block.add_instruction(Return((1, 1), result))
        self._install_function(module, mul_func, block)

        # Add caller with calls to both functions
        caller = MIRFunction("main", [], MIRType.EMPTY)
        block = BasicBlock("entry")

        # Call add multiple times
        for _ in range(3):
            t = Temp(MIRType.INT)
            block.add_instruction(Call(t, "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1)))

        # Call multiply multiple times
        for _ in range(3):
            t = Temp(MIRType.FLOAT)
            block.add_instruction(
                Call(t, "multiply", [Constant(1.0, MIRType.FLOAT), Constant(2.0, MIRType.FLOAT)], (1, 1))
            )

        self._install_function(module, caller, block)

        # Run optimization
        changed = opt.run_on_module(module)

        # Should have specialized both functions
        assert changed
        assert opt.stats["functions_specialized"] >= 2
|