machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,483 @@
|
|
1
|
+
"""Tests for loop unrolling optimization pass."""
|
2
|
+
|
3
|
+
from unittest.mock import Mock, patch
|
4
|
+
|
5
|
+
from machine_dialect.mir.analyses.loop_analysis import Loop
|
6
|
+
from machine_dialect.mir.basic_block import CFG, BasicBlock
|
7
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
8
|
+
from machine_dialect.mir.mir_instructions import (
|
9
|
+
BinaryOp,
|
10
|
+
ConditionalJump,
|
11
|
+
Jump,
|
12
|
+
LoadConst,
|
13
|
+
)
|
14
|
+
from machine_dialect.mir.mir_types import MIRType
|
15
|
+
from machine_dialect.mir.mir_values import Constant, Temp
|
16
|
+
from machine_dialect.mir.optimization_pass import PassInfo, PassType, PreservationLevel
|
17
|
+
from machine_dialect.mir.optimizations.loop_unrolling import LoopUnrolling
|
18
|
+
|
19
|
+
|
20
|
+
class TestLoopUnrolling:
    """Unit tests for the ``LoopUnrolling`` MIR optimization pass.

    Most tests exercise one helper of the pass in isolation. Loop-analysis
    results are replaced by ``Mock(spec=Loop)`` stand-ins, and sibling helpers
    are patched with ``patch.object`` where a test should not depend on them.
    """

    def test_initialization(self) -> None:
        """A freshly constructed pass has default thresholds and zeroed stats."""
        pass_instance = LoopUnrolling()

        assert pass_instance.unroll_threshold == 4
        assert pass_instance.max_body_size == 20
        assert pass_instance.stats == {"unrolled": 0, "loops_processed": 0}

    def test_get_info(self) -> None:
        """``get_info`` reports name, description, type, requirements and preservation."""
        pass_instance = LoopUnrolling()
        info = pass_instance.get_info()

        assert isinstance(info, PassInfo)
        assert info.name == "loop-unrolling"
        assert info.description == "Unroll small loops to reduce overhead"
        assert info.pass_type == PassType.OPTIMIZATION
        assert info.requires == ["loop-analysis", "dominance"]
        assert info.preserves == PreservationLevel.CFG

    def test_get_loops_in_order(self) -> None:
        """Loops are ordered innermost first, i.e. by descending ``depth``."""
        from typing import cast

        pass_instance = LoopUnrolling()

        # Mock loops whose depths are deliberately out of order.
        loop1 = Mock(spec=Loop)
        loop1.depth = 1

        loop2 = Mock(spec=Loop)
        loop2.depth = 3

        loop3 = Mock(spec=Loop)
        loop3.depth = 2

        # cast() only satisfies the type checker; the mocks stand in for Loop.
        loops = cast(list[Loop], [loop1, loop2, loop3])
        ordered = pass_instance._get_loops_in_order(loops)

        # Expected order by depth descending: 3, 2, 1 (checked by identity).
        assert ordered[0] is loop2
        assert ordered[1] is loop3
        assert ordered[2] is loop1

    def test_find_defining_instruction(self) -> None:
        """The instruction that defines a value is found; unknown values give None."""
        pass_instance = LoopUnrolling()

        # Block with two instructions: t0 = 42; t1 = t0 + 1
        block = BasicBlock("test")
        t0 = Temp(MIRType.INT, temp_id=0)
        t1 = Temp(MIRType.INT, temp_id=1)

        inst1 = LoadConst(t0, 42, (1, 1))
        block.add_instruction(inst1)

        inst2 = BinaryOp(t1, "+", t0, Constant(1), (2, 1))
        block.add_instruction(inst2)

        # The exact instruction objects stored in the block must be returned.
        result = pass_instance._find_defining_instruction(t0, block)
        assert result is inst1

        result = pass_instance._find_defining_instruction(t1, block)
        assert result is inst2

        # A temp that no instruction in the block defines.
        t2 = Temp(MIRType.INT, temp_id=2)
        result = pass_instance._find_defining_instruction(t2, block)
        assert result is None

    def test_should_unroll_too_large(self) -> None:
        """Loops whose body exceeds ``max_body_size`` are not unrolled."""
        pass_instance = LoopUnrolling()

        loop = Mock(spec=Loop)
        block1 = BasicBlock("loop_body")

        # Derive the size from the configured limit so the test tracks it.
        oversized = pass_instance.max_body_size + 5
        for i in range(oversized):
            t = Temp(MIRType.INT, temp_id=i)
            block1.add_instruction(LoadConst(t, i, (i, 1)))

        loop.blocks = [block1]

        function = Mock(spec=MIRFunction)
        result = pass_instance._should_unroll(loop, function)

        assert result is False

    def test_should_unroll_unknown_iteration_count(self) -> None:
        """Loops whose iteration count cannot be determined are not unrolled."""
        pass_instance = LoopUnrolling()

        # Small single-block loop, so size alone would not block unrolling.
        loop = Mock(spec=Loop)
        block = BasicBlock("loop_body")
        t0 = Temp(MIRType.INT, temp_id=0)
        block.add_instruction(LoadConst(t0, 1, (1, 1)))
        loop.blocks = [block]
        loop.header = block

        function = Mock(spec=MIRFunction)

        # Force the "unknown trip count" path.
        with patch.object(pass_instance, "_get_iteration_count", return_value=None):
            result = pass_instance._should_unroll(loop, function)

        assert result is False

    def test_should_unroll_valid_loop(self) -> None:
        """A small loop with a known iteration count is selected for unrolling."""
        pass_instance = LoopUnrolling()

        loop = Mock(spec=Loop)
        block = BasicBlock("loop_body")
        t0 = Temp(MIRType.INT, temp_id=0)
        block.add_instruction(LoadConst(t0, 1, (1, 1)))
        loop.blocks = [block]

        function = Mock(spec=MIRFunction)

        # Pretend the trip count is known and small.
        with patch.object(pass_instance, "_get_iteration_count", return_value=4):
            result = pass_instance._should_unroll(loop, function)

        assert result is True

    def test_get_iteration_count_constant_bound(self) -> None:
        """A header condition ``i < 10`` yields an iteration count of 10."""
        pass_instance = LoopUnrolling()

        header = BasicBlock("header")
        i = Temp(MIRType.INT, temp_id=0)
        cond = Temp(MIRType.BOOL, temp_id=1)

        # cond = i < 10
        cmp_inst = BinaryOp(cond, "<", i, Constant(10), (1, 1))
        header.add_instruction(cmp_inst)

        # if cond goto body else exit
        jump_inst = ConditionalJump(cond, "body", (2, 1), "exit")
        header.add_instruction(jump_inst)

        loop = Mock(spec=Loop)
        loop.header = header

        function = Mock(spec=MIRFunction)

        # Route the condition lookup straight to the comparison.
        with patch.object(pass_instance, "_find_defining_instruction", return_value=cmp_inst):
            result = pass_instance._get_iteration_count(loop, function)
            assert result == 10

    def test_get_iteration_count_less_equal(self) -> None:
        """A header condition ``i <= 5`` yields 6 iterations (0..5 inclusive)."""
        pass_instance = LoopUnrolling()

        header = BasicBlock("header")
        i = Temp(MIRType.INT, temp_id=0)
        cond = Temp(MIRType.BOOL, temp_id=1)

        # cond = i <= 5
        cmp_inst = BinaryOp(cond, "<=", i, Constant(5), (1, 1))
        header.add_instruction(cmp_inst)

        jump_inst = ConditionalJump(cond, "body", (2, 1), "exit")
        header.add_instruction(jump_inst)

        loop = Mock(spec=Loop)
        loop.header = header

        function = Mock(spec=MIRFunction)

        with patch.object(pass_instance, "_find_defining_instruction", return_value=cmp_inst):
            result = pass_instance._get_iteration_count(loop, function)
            assert result == 6  # 0 to 5 inclusive

    def test_get_iteration_count_non_constant(self) -> None:
        """A header condition against a variable bound (``i < n``) yields None."""
        pass_instance = LoopUnrolling()

        header = BasicBlock("header")
        i = Temp(MIRType.INT, temp_id=0)
        n = Temp(MIRType.INT, temp_id=1)  # variable bound, not a Constant
        cond = Temp(MIRType.BOOL, temp_id=2)

        # cond = i < n
        cmp_inst = BinaryOp(cond, "<", i, n, (1, 1))
        header.add_instruction(cmp_inst)

        jump_inst = ConditionalJump(cond, "body", (2, 1), "exit")
        header.add_instruction(jump_inst)

        loop = Mock(spec=Loop)
        loop.header = header

        function = Mock(spec=MIRFunction)

        # No patching here: the real _find_defining_instruction is used.
        result = pass_instance._get_iteration_count(loop, function)
        assert result is None

    def test_clone_block(self) -> None:
        """Cloning a block suffixes its label and copies every instruction."""
        pass_instance = LoopUnrolling()

        original = BasicBlock("original")
        t0 = Temp(MIRType.INT, temp_id=0)
        t1 = Temp(MIRType.INT, temp_id=1)

        original.add_instruction(LoadConst(t0, 42, (1, 1)))
        original.add_instruction(BinaryOp(t1, "+", t0, Constant(1), (2, 1)))
        original.add_instruction(Jump("next", (3, 1)))

        cloned = pass_instance._clone_block(original, "_unroll_1")

        assert cloned.label == "original_unroll_1"
        assert len(cloned.instructions) == 3

        # Check identity, not equality: each instruction must be a new object,
        # whatever __eq__ the instruction classes define.
        assert cloned.instructions is not original.instructions
        assert all(
            copy is not source
            for copy, source in zip(cloned.instructions, original.instructions)
        )

        # Instruction kinds are preserved in order.
        assert isinstance(cloned.instructions[0], LoadConst)
        assert isinstance(cloned.instructions[1], BinaryOp)
        assert isinstance(cloned.instructions[2], Jump)

    def test_clone_instruction(self) -> None:
        """Cloning individual instructions produces equivalent, distinct objects."""
        pass_instance = LoopUnrolling()

        t0 = Temp(MIRType.INT, temp_id=0)

        # LoadConst: new object with an equivalent constant.
        inst1 = LoadConst(t0, 42, (1, 1))
        cloned1 = pass_instance._clone_instruction(inst1, "_suffix")
        assert isinstance(cloned1, LoadConst)
        assert cloned1 is not inst1
        assert cloned1.constant == inst1.constant

        # Jump: the target label is currently preserved unchanged.
        inst2 = Jump("target", (2, 1))
        cloned2 = pass_instance._clone_instruction(inst2, "_suffix")
        assert isinstance(cloned2, Jump)
        assert cloned2.label == "target"

        # ConditionalJump: both branch labels are currently preserved unchanged.
        cond = Temp(MIRType.BOOL, temp_id=1)
        inst3 = ConditionalJump(cond, "true_branch", (3, 1), "false_branch")
        cloned3 = pass_instance._clone_instruction(inst3, "_suffix")
        assert isinstance(cloned3, ConditionalJump)
        assert cloned3.true_label == "true_branch"
        assert cloned3.false_label == "false_branch"

    def test_update_loop_increment(self) -> None:
        """The induction increment ``i = i + 1`` is scaled to the unroll factor."""
        pass_instance = LoopUnrolling()

        block = BasicBlock("increment")
        i = Temp(MIRType.INT, temp_id=0)

        # i = i + 1
        inc_inst = BinaryOp(i, "+", i, Constant(1), (1, 1))
        block.add_instruction(inc_inst)

        loop = Mock(spec=Loop)
        loop.blocks = [block]

        pass_instance._update_loop_increment(loop, 4)

        # The right operand is replaced with a new Constant holding the factor.
        assert isinstance(inc_inst.right, Constant)
        assert inc_inst.right.value == 4

    def test_connect_unrolled_blocks(self) -> None:
        """Each body copy is chained to jump into the next unrolled copy."""
        pass_instance = LoopUnrolling()

        # Original loop body, ending in the back-edge to the header.
        body1 = BasicBlock("body1")
        body1.add_instruction(Jump("header", (1, 1)))

        # Unrolled copies, each still pointing at the header.
        unroll1 = BasicBlock("body1_unroll_1")
        unroll1.add_instruction(Jump("header", (2, 1)))

        unroll2 = BasicBlock("body1_unroll_2")
        unroll2.add_instruction(Jump("header", (3, 1)))

        loop = Mock(spec=Loop)

        pass_instance._connect_unrolled_blocks(loop, [unroll1, unroll2], [body1], 3)

        # Original body now falls into the first copy.
        first_exit = body1.instructions[-1]
        assert isinstance(first_exit, Jump)
        assert first_exit.label == "body1_unroll_1"

        # First copy now falls into the second copy.
        second_exit = unroll1.instructions[-1]
        assert isinstance(second_exit, Jump)
        assert second_exit.label == "body1_unroll_2"

    def test_run_on_function_no_analyses(self) -> None:
        """``run_on_function`` reports no change when required analyses are missing."""
        pass_instance = LoopUnrolling()

        function = Mock(spec=MIRFunction)

        with patch.object(pass_instance, "get_analysis", return_value=None):
            result = pass_instance.run_on_function(function)

        assert result is False

    def test_run_on_function_with_config(self) -> None:
        """A manually set ``unroll_threshold`` survives a run over a loop-free function."""
        pass_instance = LoopUnrolling()

        # Threshold is set directly; no configuration object is involved.
        pass_instance.unroll_threshold = 8

        function = Mock(spec=MIRFunction)

        # Analyses are requested in order: loop info, then dominance.
        loop_info = Mock()
        loop_info.loops = []
        dominance = Mock()

        with patch.object(pass_instance, "get_analysis") as mock_get:
            mock_get.side_effect = [loop_info, dominance]
            result = pass_instance.run_on_function(function)

        # The run must not overwrite the directly-set threshold.
        assert pass_instance.unroll_threshold == 8
        assert result is False  # No loops to process

    def test_run_on_function_process_loops(self) -> None:
        """Each loop is counted as processed even when it is not unrolled."""
        pass_instance = LoopUnrolling()

        function = MIRFunction("test_func", [], MIRType.EMPTY)
        function.cfg = CFG()

        # Minimal two-block loop.
        header = BasicBlock("header")
        body = BasicBlock("body")

        loop = Mock(spec=Loop)
        loop.depth = 1
        loop.header = header
        loop.blocks = [header, body]

        loop_info = Mock()
        loop_info.loops = [loop]
        dominance = Mock()

        with patch.object(pass_instance, "get_analysis") as mock_get:
            mock_get.side_effect = [loop_info, dominance]

            with patch.object(pass_instance, "_get_loops_in_order", return_value=[loop]):
                # Decline to unroll so only the bookkeeping is exercised.
                with patch.object(pass_instance, "_should_unroll", return_value=False):
                    result = pass_instance.run_on_function(function)

        assert pass_instance.stats["loops_processed"] == 1
        assert pass_instance.stats["unrolled"] == 0
        assert result is False

    def test_finalize(self) -> None:
        """``finalize`` can be called on a fresh pass without raising."""
        pass_instance = LoopUnrolling()
        pass_instance.finalize()

    def test_get_statistics(self) -> None:
        """``get_statistics`` reflects the current contents of ``stats``."""
        pass_instance = LoopUnrolling()

        stats = pass_instance.get_statistics()
        assert stats == {"unrolled": 0, "loops_processed": 0}

        pass_instance.stats["unrolled"] = 3
        pass_instance.stats["loops_processed"] = 5

        stats = pass_instance.get_statistics()
        assert stats == {"unrolled": 3, "loops_processed": 5}

    def test_unroll_loop_empty_body(self) -> None:
        """A loop consisting only of its header is not unrolled."""
        pass_instance = LoopUnrolling()

        header = BasicBlock("header")
        loop = Mock(spec=Loop)
        loop.header = header
        loop.blocks = [header]

        function = MIRFunction("test", [], MIRType.EMPTY)
        function.cfg = CFG()

        transformer = Mock()

        # Body blocks are loop.blocks minus the header, which is empty here.
        result = pass_instance._unroll_loop(loop, function, transformer)
        assert result is False

    def test_unroll_loop_success(self) -> None:
        """Unrolling a header+body loop adds cloned blocks and flags the transformer."""
        pass_instance = LoopUnrolling()
        pass_instance.unroll_threshold = 2

        header = BasicBlock("header")
        body = BasicBlock("body")
        body.add_instruction(Jump("header", (1, 1)))

        loop = Mock(spec=Loop)
        loop.header = header
        loop.blocks = [header, body]

        function = MIRFunction("test", [], MIRType.EMPTY)
        function.cfg = CFG()

        transformer = Mock()
        transformer.modified = False

        # Stub out the rewiring helpers; only block cloning is under test here.
        with patch.object(pass_instance, "_connect_unrolled_blocks"):
            with patch.object(pass_instance, "_update_loop_increment"):
                result = pass_instance._unroll_loop(loop, function, transformer)

        assert result is True
        assert transformer.modified is True

        # (unroll_factor - 1) new blocks per body block: with a factor of 2
        # and one body block, exactly one block is added to the empty CFG.
        assert len(function.cfg.blocks) == 1
|