machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,433 @@
|
|
1
|
+
"""Tests for SSA construction."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from machine_dialect.mir.basic_block import CFG, BasicBlock
|
6
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
7
|
+
from machine_dialect.mir.mir_instructions import (
|
8
|
+
BinaryOp,
|
9
|
+
ConditionalJump,
|
10
|
+
Copy,
|
11
|
+
Jump,
|
12
|
+
LoadConst,
|
13
|
+
Return,
|
14
|
+
StoreVar,
|
15
|
+
)
|
16
|
+
from machine_dialect.mir.mir_types import MIRType
|
17
|
+
from machine_dialect.mir.mir_values import Constant, Variable
|
18
|
+
from machine_dialect.mir.ssa_construction import DominanceInfo, construct_ssa
|
19
|
+
|
20
|
+
|
21
|
+
class TestDominanceInfo:
    """Test dominance analysis."""

    @staticmethod
    def _build_diamond_cfg() -> tuple[CFG, BasicBlock, BasicBlock, BasicBlock, BasicBlock]:
        """Build a diamond-shaped CFG with ``entry`` as its entry block.

        Returns:
            ``(cfg, entry, then_block, else_block, merge_block)``.
        """
        # Shape:
        #      entry
        #     /     \
        #   then   else
        #     \     /
        #      merge
        cfg = CFG()
        entry = BasicBlock("entry")
        then_block = BasicBlock("then")
        else_block = BasicBlock("else")
        merge_block = BasicBlock("merge")

        for block in (entry, then_block, else_block, merge_block):
            cfg.add_block(block)
        cfg.set_entry_block(entry)

        cfg.connect(entry, then_block)
        cfg.connect(entry, else_block)
        cfg.connect(then_block, merge_block)
        cfg.connect(else_block, merge_block)

        return cfg, entry, then_block, else_block, merge_block

    def test_simple_dominance(self) -> None:
        """Test dominance in a straight-line CFG: entry -> block1 -> exit."""
        cfg = CFG()
        entry = BasicBlock("entry")
        block1 = BasicBlock("block1")
        exit_block = BasicBlock("exit")

        cfg.add_block(entry)
        cfg.add_block(block1)
        cfg.add_block(exit_block)
        cfg.set_entry_block(entry)

        cfg.connect(entry, block1)
        cfg.connect(block1, exit_block)

        dom_info = DominanceInfo(cfg)

        # Entry dominates all blocks (including itself)
        assert dom_info.dominates(entry, entry)
        assert dom_info.dominates(entry, block1)
        assert dom_info.dominates(entry, exit_block)

        # Block1 dominates exit but not entry
        assert not dom_info.dominates(block1, entry)
        assert dom_info.dominates(block1, exit_block)

        # Exit dominates only itself
        assert dom_info.dominates(exit_block, exit_block)
        assert not dom_info.dominates(exit_block, entry)
        assert not dom_info.dominates(exit_block, block1)

    def test_dominance_with_branch(self) -> None:
        """Test dominance in a diamond CFG with two branches."""
        cfg, entry, then_block, else_block, merge_block = self._build_diamond_cfg()

        dom_info = DominanceInfo(cfg)

        # Entry dominates all
        assert dom_info.dominates(entry, then_block)
        assert dom_info.dominates(entry, else_block)
        assert dom_info.dominates(entry, merge_block)

        # Neither branch dominates the other or merge, since merge is
        # reachable through either side.
        assert not dom_info.dominates(then_block, else_block)
        assert not dom_info.dominates(else_block, then_block)
        assert not dom_info.dominates(then_block, merge_block)
        assert not dom_info.dominates(else_block, merge_block)

    def test_dominance_frontier(self) -> None:
        """Test dominance frontier calculation on a diamond CFG."""
        cfg, entry, then_block, else_block, merge_block = self._build_diamond_cfg()

        dom_info = DominanceInfo(cfg)

        # Merge is the first block each branch does not strictly dominate,
        # so it is in the dominance frontier of both.
        assert merge_block in dom_info.dominance_frontier[then_block]
        assert merge_block in dom_info.dominance_frontier[else_block]

        # Entry and merge have empty frontiers
        assert len(dom_info.dominance_frontier[entry]) == 0
        assert len(dom_info.dominance_frontier[merge_block]) == 0
|
125
|
+
|
126
|
+
|
127
|
+
class TestSSAConstruction:
    """Test SSA construction."""

    @staticmethod
    def _build_diamond_function() -> tuple[MIRFunction, BasicBlock]:
        """Build ``if (True) x = 1 else x = 2; return x`` as a diamond CFG.

        The entry block loads a constant ``True`` condition and branches to
        ``then``/``else``; each branch stores a different constant into the
        local ``x`` and jumps to ``merge``, which copies ``x`` into a temp and
        returns it. SSA construction has NOT been run on the result.

        Returns:
            ``(func, merge_block)``.
        """
        func = MIRFunction("test", [], MIRType.INT)

        entry = BasicBlock("entry")
        then_block = BasicBlock("then")
        else_block = BasicBlock("else")
        merge_block = BasicBlock("merge")

        for block in (entry, then_block, else_block, merge_block):
            func.cfg.add_block(block)
        func.cfg.set_entry_block(entry)

        func.cfg.connect(entry, then_block)
        func.cfg.connect(entry, else_block)
        func.cfg.connect(then_block, merge_block)
        func.cfg.connect(else_block, merge_block)

        x = Variable("x", MIRType.INT)
        func.add_local(x)

        # Conditional jump in entry
        cond = func.new_temp(MIRType.BOOL)
        entry.add_instruction(LoadConst(cond, True, (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Assign different values in each branch
        then_block.add_instruction(StoreVar(x, Constant(1, MIRType.INT), (1, 1)))
        then_block.add_instruction(Jump("merge", (1, 1)))

        else_block.add_instruction(StoreVar(x, Constant(2, MIRType.INT), (1, 1)))
        else_block.add_instruction(Jump("merge", (1, 1)))

        # Use x at the join point
        result = func.new_temp(MIRType.INT)
        merge_block.add_instruction(Copy(result, x, (1, 1)))
        merge_block.add_instruction(Return((1, 1), result))

        return func, merge_block

    def test_simple_ssa_construction(self) -> None:
        """Test SSA construction for a single-block function (``x = 1; return x``)."""
        func = MIRFunction("test", [], MIRType.EMPTY)

        entry = BasicBlock("entry")
        func.cfg.add_block(entry)
        func.cfg.set_entry_block(entry)

        x = Variable("x", MIRType.INT)
        func.add_local(x)

        # x = 1; return x
        entry.add_instruction(StoreVar(x, Constant(1, MIRType.INT), (1, 1)))
        temp = func.new_temp(MIRType.INT)
        entry.add_instruction(Copy(temp, x, (1, 1)))
        entry.add_instruction(Return((1, 1), temp))

        construct_ssa(func)

        # Exact SSA names are not checked; verify the block is still intact.
        assert len(entry.instructions) > 0
        assert any(isinstance(inst, Return) for inst in entry.instructions)
        # A block with no predecessors has nothing to merge, so no phi.
        assert not entry.phi_nodes

    def test_phi_insertion(self) -> None:
        """Test phi node insertion at a join point."""
        func, merge_block = self._build_diamond_function()

        construct_ssa(func)

        # Phi nodes are kept in the block's phi_nodes list, not in instructions.
        assert len(merge_block.phi_nodes) > 0

        # Phi should have one incoming value per predecessor
        phi = merge_block.phi_nodes[0]
        assert len(phi.incoming) == 2
        incoming_labels = {label for _, label in phi.incoming}
        assert incoming_labels == {"then", "else"}

    def test_ssa_with_loops(self) -> None:
        """Test SSA construction with a loop (self-referencing phi)."""
        func = MIRFunction("test", [], MIRType.EMPTY)

        entry = BasicBlock("entry")
        loop_header = BasicBlock("loop_header")
        loop_body = BasicBlock("loop_body")
        exit_block = BasicBlock("exit")

        for block in (entry, loop_header, loop_body, exit_block):
            func.cfg.add_block(block)
        func.cfg.set_entry_block(entry)

        # Loop structure
        func.cfg.connect(entry, loop_header)
        func.cfg.connect(loop_header, loop_body)
        func.cfg.connect(loop_header, exit_block)
        func.cfg.connect(loop_body, loop_header)  # Back edge

        # Loop counter variable
        i = Variable("i", MIRType.INT)
        func.add_local(i)

        # Initialize counter in entry
        entry.add_instruction(StoreVar(i, Constant(0, MIRType.INT), (1, 1)))
        entry.add_instruction(Jump("loop_header", (1, 1)))

        # Loop header checks i < 10
        cond = func.new_temp(MIRType.BOOL)
        loop_header.add_instruction(BinaryOp(cond, "<", i, Constant(10, MIRType.INT), (1, 1)))
        loop_header.add_instruction(ConditionalJump(cond, "loop_body", (1, 1), "exit"))

        # Loop body increments counter
        new_i = func.new_temp(MIRType.INT)
        loop_body.add_instruction(BinaryOp(new_i, "+", i, Constant(1, MIRType.INT), (1, 1)))
        loop_body.add_instruction(StoreVar(i, new_i, (1, 1)))
        loop_body.add_instruction(Jump("loop_header", (1, 1)))

        exit_block.add_instruction(Return((1, 1)))

        construct_ssa(func)

        # Loop header should have a phi node for the loop variable
        assert len(loop_header.phi_nodes) > 0

        # Phi should have incoming values from entry and the back edge
        phi = loop_header.phi_nodes[0]
        incoming_labels = {label for _, label in phi.incoming}
        assert "entry" in incoming_labels
        assert "loop_body" in incoming_labels

    def test_multiple_variables_ssa(self) -> None:
        """Test SSA construction with multiple variables and no join point."""
        func = MIRFunction("test", [], MIRType.EMPTY)

        entry = BasicBlock("entry")
        block1 = BasicBlock("block1")

        func.cfg.add_block(entry)
        func.cfg.add_block(block1)
        func.cfg.set_entry_block(entry)
        func.cfg.connect(entry, block1)

        x = Variable("x", MIRType.INT)
        y = Variable("y", MIRType.INT)
        z = Variable("z", MIRType.INT)

        func.add_local(x)
        func.add_local(y)
        func.add_local(z)

        # x = 1; y = 2
        entry.add_instruction(StoreVar(x, Constant(1, MIRType.INT), (1, 1)))
        entry.add_instruction(StoreVar(y, Constant(2, MIRType.INT), (1, 1)))

        # z = x + y
        temp = func.new_temp(MIRType.INT)
        entry.add_instruction(BinaryOp(temp, "+", x, y, (1, 1)))
        entry.add_instruction(StoreVar(z, temp, (1, 1)))
        entry.add_instruction(Jump("block1", (1, 1)))

        # return z + x
        result = func.new_temp(MIRType.INT)
        block1.add_instruction(BinaryOp(result, "+", z, x, (1, 1)))
        block1.add_instruction(Return((1, 1), result))

        construct_ssa(func)

        # All definitions are in entry and block1 has a single predecessor,
        # so there is no merge point and no phi should be placed anywhere.
        assert not entry.phi_nodes
        assert not block1.phi_nodes
        assert any(isinstance(inst, Return) for inst in block1.instructions)

    def test_ssa_preserves_semantics(self) -> None:
        """Test that SSA construction preserves program semantics.

        The function computes ``if (cond) x = 1 else x = 2; return x``.
        """
        func, _ = self._build_diamond_function()

        # Count phi nodes AFTER the program is built but BEFORE SSA
        # construction; phis live in phi_nodes, not in instructions.
        phis_before = sum(len(block.phi_nodes) for block in func.cfg.blocks.values())

        construct_ssa(func)

        # Should have added at least one phi node
        phis_after = sum(len(block.phi_nodes) for block in func.cfg.blocks.values())
        assert phis_after > phis_before

        # Should still have exactly one return instruction
        returns = [
            inst
            for block in func.cfg.blocks.values()
            for inst in block.instructions
            if isinstance(inst, Return)
        ]
        assert len(returns) == 1

    def test_loadconst_preservation(self) -> None:
        """Test that SSA construction preserves LoadConst instructions."""
        func = MIRFunction("test_const", [], MIRType.INT)

        # Single-block CFG
        entry = BasicBlock("entry")
        func.cfg.add_block(entry)
        func.cfg.set_entry_block(entry)

        # Temporaries for the constants and the comparison result
        t0 = func.new_temp(MIRType.INT)
        t1 = func.new_temp(MIRType.INT)
        t2 = func.new_temp(MIRType.BOOL)

        entry.add_instruction(LoadConst(t0, Constant(5, MIRType.INT), (1, 1)))
        entry.add_instruction(LoadConst(t1, Constant(1, MIRType.INT), (1, 1)))

        # t2 = t0 <= t1; return t2
        entry.add_instruction(BinaryOp(t2, "<=", t0, t1, (1, 1)))
        entry.add_instruction(Return((1, 1), t2))

        def collect_loadconsts() -> list[LoadConst]:
            return [
                inst
                for block in func.cfg.blocks.values()
                for inst in block.instructions
                if isinstance(inst, LoadConst)
            ]

        loadconst_before = collect_loadconsts()

        construct_ssa(func)

        loadconst_after = collect_loadconsts()

        assert len(loadconst_before) == 2, "Should have 2 LoadConst instructions before SSA"
        assert len(loadconst_after) == 2, "Should have 2 LoadConst instructions after SSA"

        # The constant payloads must survive renaming
        const_values = [inst.constant.value for inst in loadconst_after if hasattr(inst.constant, "value")]
        assert 5 in const_values, "Constant 5 should be preserved"
        assert 1 in const_values, "Constant 1 should be preserved"
|
@@ -0,0 +1,236 @@
|
|
1
|
+
"""Tests for tail call optimization."""
|
2
|
+
|
3
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
4
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
5
|
+
from machine_dialect.mir.mir_instructions import Call, Copy, Return
|
6
|
+
from machine_dialect.mir.mir_module import MIRModule
|
7
|
+
from machine_dialect.mir.mir_types import MIRType
|
8
|
+
from machine_dialect.mir.mir_values import Constant, Temp, Variable
|
9
|
+
from machine_dialect.mir.optimizations.tail_call import TailCallOptimization
|
10
|
+
|
11
|
+
|
12
|
+
def test_simple_tail_call() -> None:
    """Test detection of simple tail call pattern."""
    # Build a self-recursive function whose only block is
    #     t0 = call factorial(n)
    #     return t0
    # The call's result is returned directly, so the call is in tail position.
    # The argument is ``n`` itself rather than ``n - 1``; only the
    # call/return shape matters to the pass, not the argument value.
    module = MIRModule("test")
    func = MIRFunction("factorial", [Variable("n", MIRType.INT)])
    module.add_function(func)

    block = BasicBlock("entry")

    result = Temp(MIRType.INT, 0)
    call_inst = Call(result, "factorial", [Variable("n", MIRType.INT)], (1, 1))
    block.add_instruction(call_inst)
    block.add_instruction(Return((1, 1), result))

    func.cfg.add_block(block)
    func.cfg.entry_block = block

    # Run optimization
    optimizer = TailCallOptimization()
    modified = optimizer.run_on_module(module)

    # The call is flagged as a tail call. Because the callee is the function
    # itself, it is also counted as a recursive tail call.
    assert modified
    assert call_inst.is_tail_call
    assert optimizer.stats["tail_calls_found"] == 1
    assert optimizer.stats["recursive_tail_calls"] == 1
|
42
|
+
|
43
|
+
|
44
|
+
def test_tail_call_with_copy() -> None:
    """A call whose result is copied into a variable and then returned is a tail call."""
    pos = (1, 1)
    mod = MIRModule("test")
    fn = MIRFunction("process", [Variable("x", MIRType.INT)])
    mod.add_function(fn)

    entry = BasicBlock("entry")

    # tmp = call helper(x); result = tmp; return result
    tmp = Temp(MIRType.INT, 0)
    helper_call = Call(tmp, "helper", [Variable("x", MIRType.INT)], pos)
    ret_var = Variable("result", MIRType.INT)
    for inst in (helper_call, Copy(ret_var, tmp, pos), Return(pos, ret_var)):
        entry.add_instruction(inst)

    fn.cfg.add_block(entry)
    fn.cfg.entry_block = entry

    opt = TailCallOptimization()
    changed = opt.run_on_module(mod)

    # The intermediate Copy must not hide the tail position of the call.
    assert changed
    assert helper_call.is_tail_call
    assert opt.stats["tail_calls_found"] == 1
|
75
|
+
|
76
|
+
|
77
|
+
def test_void_tail_call() -> None:
    """A call with no destination followed by a bare return is a tail call."""
    mod = MIRModule("test")
    fn = MIRFunction("cleanup", [])
    mod.add_function(fn)

    entry = BasicBlock("entry")

    # call finalize(); return -- neither the call nor the return carries a value.
    finalize_call = Call(None, "finalize", [], (1, 1))
    entry.add_instruction(finalize_call)
    entry.add_instruction(Return((1, 1), None))

    fn.cfg.add_block(entry)
    fn.cfg.entry_block = entry

    opt = TailCallOptimization()
    changed = opt.run_on_module(mod)

    assert changed
    assert finalize_call.is_tail_call
    assert opt.stats["tail_calls_found"] == 1
|
103
|
+
|
104
|
+
|
105
|
+
def test_non_tail_call() -> None:
    """Test that non-tail calls are not marked."""
    module = MIRModule("test")
    func = MIRFunction("compute", [Variable("x", MIRType.INT)])
    module.add_function(func)

    block = BasicBlock("entry")

    # temp = call helper(x)
    temp = Temp(MIRType.INT, 0)
    call_inst = Call(temp, "helper", [Variable("x", MIRType.INT)], (1, 1))
    block.add_instruction(call_inst)

    # return 42
    # The call's result is discarded and an unrelated constant is returned,
    # so the call is not in tail position.
    block.add_instruction(Return((1, 1), Constant(42, MIRType.INT)))

    func.cfg.add_block(block)
    func.cfg.entry_block = block

    # Run optimization
    optimizer = TailCallOptimization()
    modified = optimizer.run_on_module(module)

    # Nothing is marked, and the pass reports that it changed nothing.
    assert not modified
    assert not call_inst.is_tail_call
    assert optimizer.stats["tail_calls_found"] == 0
|
135
|
+
|
136
|
+
|
137
|
+
def test_multiple_tail_calls() -> None:
    """Test function with multiple tail calls in different blocks."""
    module = MIRModule("test")
    func = MIRFunction("fibonacci", [Variable("n", MIRType.INT)])
    module.add_function(func)

    # block1: t0 = call fibonacci(n); return t0
    # The argument is ``n`` itself; only the call/return shape matters here.
    block1 = BasicBlock("block1")
    temp1 = Temp(MIRType.INT, 0)
    call1 = Call(temp1, "fibonacci", [Variable("n", MIRType.INT)], (1, 1))
    block1.add_instruction(call1)
    block1.add_instruction(Return((1, 1), temp1))

    # block2: t1 = call fibonacci(n); return t1
    block2 = BasicBlock("block2")
    temp2 = Temp(MIRType.INT, 1)
    call2 = Call(temp2, "fibonacci", [Variable("n", MIRType.INT)], (1, 1))
    block2.add_instruction(call2)
    block2.add_instruction(Return((1, 1), temp2))

    # NOTE: no edge is added from block1 to block2 here. The assertion on
    # call2 therefore presumably relies on the pass scanning every block
    # registered in the CFG, not only those reachable from the entry.
    func.cfg.add_block(block1)
    func.cfg.add_block(block2)
    func.cfg.entry_block = block1

    # Run optimization
    optimizer = TailCallOptimization()
    modified = optimizer.run_on_module(module)

    # Both calls target the enclosing function, so each is a recursive tail call.
    assert modified
    assert call1.is_tail_call
    assert call2.is_tail_call
    assert optimizer.stats["tail_calls_found"] == 2
    assert optimizer.stats["recursive_tail_calls"] == 2
|
171
|
+
|
172
|
+
|
173
|
+
def test_mutual_recursion() -> None:
    """Tail calls between mutually recursive functions are found but not counted as self-recursive."""
    module = MIRModule("test")

    def build(name: str, callee: str, temp_id: int) -> Call:
        # Register ``name(n): tmp = call callee(n); return tmp`` in the module
        # and hand back the Call so the test can inspect its flag afterwards.
        fn = MIRFunction(name, [Variable("n", MIRType.INT)])
        entry = BasicBlock("entry")
        tmp = Temp(MIRType.BOOL, temp_id)
        call = Call(tmp, callee, [Variable("n", MIRType.INT)], (1, 1))
        entry.add_instruction(call)
        entry.add_instruction(Return((1, 1), tmp))
        fn.cfg.add_block(entry)
        fn.cfg.entry_block = entry
        module.add_function(fn)
        return call

    even_call = build("even", "odd", 0)
    odd_call = build("odd", "even", 1)

    optimizer = TailCallOptimization()
    modified = optimizer.run_on_module(module)

    stats = optimizer.stats
    assert modified
    assert even_call.is_tail_call
    assert odd_call.is_tail_call
    assert stats["tail_calls_found"] == 2
    # Each call targets the *other* function, so none is self-recursive.
    assert stats["recursive_tail_calls"] == 0
|
210
|
+
|
211
|
+
|
212
|
+
def test_already_optimized() -> None:
    """A call already flagged as a tail call is left alone and not re-counted."""
    mod = MIRModule("test")
    fn = MIRFunction("test", [])
    mod.add_function(fn)

    entry = BasicBlock("entry")
    tmp = Temp(MIRType.INT, 0)
    # Pre-marked call: the pass should find nothing new to do here.
    pre_marked = Call(tmp, "helper", [], (1, 1), is_tail_call=True)
    entry.add_instruction(pre_marked)
    entry.add_instruction(Return((1, 1), tmp))

    fn.cfg.add_block(entry)
    fn.cfg.entry_block = entry

    opt = TailCallOptimization()
    changed = opt.run_on_module(mod)

    assert not changed
    assert pre_marked.is_tail_call  # flag is preserved
    assert opt.stats["tail_calls_found"] == 0  # not counted as a new find
|