machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,255 @@
|
|
1
|
+
"""Tests for cross-block constant propagation."""
|
2
|
+
|
3
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
4
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
5
|
+
from machine_dialect.mir.mir_instructions import (
|
6
|
+
BinaryOp,
|
7
|
+
ConditionalJump,
|
8
|
+
Copy,
|
9
|
+
Jump,
|
10
|
+
LoadConst,
|
11
|
+
Phi,
|
12
|
+
Return,
|
13
|
+
)
|
14
|
+
from machine_dialect.mir.mir_types import MIRType
|
15
|
+
from machine_dialect.mir.mir_values import Constant, Temp, Variable
|
16
|
+
from machine_dialect.mir.optimizations.constant_propagation import ConstantPropagation
|
17
|
+
|
18
|
+
|
19
|
+
class TestCrossBlockConstantPropagation:
    """Test cross-block constant propagation."""

    def test_constant_through_multiple_blocks(self) -> None:
        """Test propagation of constants through a straight-line chain of blocks."""
        # Create function
        func = MIRFunction("test", [])

        # Block 1: x = 10
        block1 = BasicBlock("entry")
        x = Variable("x", MIRType.INT)
        func.add_local(x)
        block1.add_instruction(LoadConst(x, Constant(10, MIRType.INT), (1, 1)))
        block1.add_instruction(Jump("block2", (1, 1)))

        # Block 2: y = x + 5
        block2 = BasicBlock("block2")
        y = Variable("y", MIRType.INT)
        func.add_local(y)
        t1 = Temp(MIRType.INT, 0)
        block2.add_instruction(BinaryOp(t1, "+", x, Constant(5, MIRType.INT), (1, 1)))
        block2.add_instruction(Copy(y, t1, (1, 1)))
        block2.add_instruction(Jump("block3", (1, 1)))

        # Block 3: z = y * 2
        block3 = BasicBlock("block3")
        z = Variable("z", MIRType.INT)
        func.add_local(z)
        t2 = Temp(MIRType.INT, 1)
        block3.add_instruction(BinaryOp(t2, "*", y, Constant(2, MIRType.INT), (1, 1)))
        block3.add_instruction(Copy(z, t2, (1, 1)))
        block3.add_instruction(Return((1, 1), z))

        # Set up CFG: entry -> block2 -> block3
        func.cfg.add_block(block1)
        func.cfg.add_block(block2)
        func.cfg.add_block(block3)
        func.cfg.set_entry_block(block1)

        block1.add_successor(block2)
        block2.add_predecessor(block1)
        block2.add_successor(block3)
        block3.add_predecessor(block2)

        # Run optimization
        optimizer = ConstantPropagation()
        modified = optimizer.run_on_function(func)

        assert modified
        # After optimization the chain x = 10 -> y = 15 -> z = 30 should be
        # foldable. The block must still terminate with a Return; whether its
        # operand is rewritten to a literal constant is not asserted here.
        final_inst = block3.instructions[-1]
        assert isinstance(final_inst, Return)

    def test_phi_node_constant_propagation(self) -> None:
        """Test constant propagation through phi nodes whose inputs agree."""
        func = MIRFunction("test", [])

        # Entry block: branch on a constant condition
        entry = BasicBlock("entry")
        cond = Variable("cond", MIRType.BOOL)
        func.add_local(cond)
        entry.add_instruction(LoadConst(cond, Constant(True, MIRType.BOOL), (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Then block: x = 10
        then_block = BasicBlock("then")
        x_then = Temp(MIRType.INT, 0)
        then_block.add_instruction(LoadConst(x_then, Constant(10, MIRType.INT), (1, 1)))
        then_block.add_instruction(Jump("merge", (1, 1)))

        # Else block: x = 10 (same value)
        else_block = BasicBlock("else")
        x_else = Temp(MIRType.INT, 1)
        else_block.add_instruction(LoadConst(x_else, Constant(10, MIRType.INT), (1, 1)))
        else_block.add_instruction(Jump("merge", (1, 1)))

        # Merge block with phi node
        merge = BasicBlock("merge")
        x = Variable("x", MIRType.INT)
        func.add_local(x)
        phi = Phi(x, [(x_then, "then"), (x_else, "else")], (1, 1))
        merge.phi_nodes.append(phi)

        # Use x in computation
        result = Temp(MIRType.INT, 2)
        merge.add_instruction(BinaryOp(result, "+", x, Constant(5, MIRType.INT), (1, 1)))
        merge.add_instruction(Return((1, 1), result))

        # Set up CFG: diamond entry -> {then, else} -> merge
        func.cfg.add_block(entry)
        func.cfg.add_block(then_block)
        func.cfg.add_block(else_block)
        func.cfg.add_block(merge)
        func.cfg.set_entry_block(entry)

        entry.add_successor(then_block)
        entry.add_successor(else_block)
        then_block.add_predecessor(entry)
        else_block.add_predecessor(entry)
        then_block.add_successor(merge)
        else_block.add_successor(merge)
        merge.add_predecessor(then_block)
        merge.add_predecessor(else_block)

        # Run optimization
        optimizer = ConstantPropagation()
        modified = optimizer.run_on_function(func)

        assert modified
        # Since both branches assign the same constant (10) to x,
        # the phi should resolve to 10 and x + 5 should fold to 15

    def test_loop_constant_propagation(self) -> None:
        """Test that constant propagation stays sound in the presence of a loop."""
        func = MIRFunction("test", [])

        # Entry block: i = 0, sum = 0
        entry = BasicBlock("entry")
        i = Variable("i", MIRType.INT)
        sum_var = Variable("sum", MIRType.INT)
        func.add_local(i)
        func.add_local(sum_var)

        entry.add_instruction(LoadConst(i, Constant(0, MIRType.INT), (1, 1)))
        entry.add_instruction(LoadConst(sum_var, Constant(0, MIRType.INT), (1, 1)))
        entry.add_instruction(Jump("loop", (1, 1)))

        # Loop header block
        loop = BasicBlock("loop")
        # Phi nodes for loop variables (back-edge inputs are added below)
        i_phi = Phi(i, [(i, "entry")], (1, 1))
        sum_phi = Phi(sum_var, [(sum_var, "entry")], (1, 1))
        loop.phi_nodes.append(i_phi)
        loop.phi_nodes.append(sum_phi)

        # Check condition: i < 10
        t_cond = Temp(MIRType.BOOL, 0)
        loop.add_instruction(BinaryOp(t_cond, "<", i, Constant(10, MIRType.INT), (1, 1)))
        loop.add_instruction(ConditionalJump(t_cond, "body", (1, 1), "exit"))

        # Loop body
        body = BasicBlock("body")
        # sum = sum + i
        t_sum = Temp(MIRType.INT, 1)
        body.add_instruction(BinaryOp(t_sum, "+", sum_var, i, (1, 1)))
        body.add_instruction(Copy(sum_var, t_sum, (1, 1)))

        # i = i + 1
        t_i = Temp(MIRType.INT, 2)
        body.add_instruction(BinaryOp(t_i, "+", i, Constant(1, MIRType.INT), (1, 1)))
        body.add_instruction(Copy(i, t_i, (1, 1)))
        body.add_instruction(Jump("loop", (1, 1)))

        # Exit block
        exit_block = BasicBlock("exit")
        exit_block.add_instruction(Return((1, 1), sum_var))

        # Set up CFG: entry -> loop <-> body, loop -> exit
        func.cfg.add_block(entry)
        func.cfg.add_block(loop)
        func.cfg.add_block(body)
        func.cfg.add_block(exit_block)
        func.cfg.set_entry_block(entry)

        entry.add_successor(loop)
        loop.add_predecessor(entry)
        loop.add_successor(body)
        loop.add_successor(exit_block)
        body.add_predecessor(loop)
        body.add_successor(loop)  # Back-edge
        loop.add_predecessor(body)  # Back-edge
        exit_block.add_predecessor(loop)

        # Add back-edge to phi nodes
        i_phi.incoming.append((i, "body"))
        sum_phi.incoming.append((sum_var, "body"))

        # Run optimization
        optimizer = ConstantPropagation()
        optimizer.run_on_function(func)

        # ``i`` is reassigned on every iteration, so its entry value (0) must
        # not be propagated into the header. The ``i < 10`` test is therefore
        # not a constant, and the header must still branch conditionally
        # between the body and the exit.
        assert isinstance(loop.instructions[-1], ConditionalJump)

    def test_conditional_constant_propagation(self) -> None:
        """Test constant propagation with a branch on a foldable condition."""
        func = MIRFunction("test", [])

        # Entry: x = 5, y = 10
        entry = BasicBlock("entry")
        x = Variable("x", MIRType.INT)
        y = Variable("y", MIRType.INT)
        func.add_local(x)
        func.add_local(y)

        entry.add_instruction(LoadConst(x, Constant(5, MIRType.INT), (1, 1)))
        entry.add_instruction(LoadConst(y, Constant(10, MIRType.INT), (1, 1)))

        # Compute condition: x < y (should be constant True)
        cond = Temp(MIRType.BOOL, 0)
        entry.add_instruction(BinaryOp(cond, "<", x, y, (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Then block (should be taken)
        then_block = BasicBlock("then")
        result_then = Temp(MIRType.INT, 1)
        then_block.add_instruction(BinaryOp(result_then, "+", x, y, (1, 1)))
        then_block.add_instruction(Return((1, 1), result_then))

        # Else block (dead code)
        else_block = BasicBlock("else")
        result_else = Temp(MIRType.INT, 2)
        else_block.add_instruction(BinaryOp(result_else, "-", y, x, (1, 1)))
        else_block.add_instruction(Return((1, 1), result_else))

        # Set up CFG: entry -> {then, else}
        func.cfg.add_block(entry)
        func.cfg.add_block(then_block)
        func.cfg.add_block(else_block)
        func.cfg.set_entry_block(entry)

        entry.add_successor(then_block)
        entry.add_successor(else_block)
        then_block.add_predecessor(entry)
        else_block.add_predecessor(entry)

        # Run optimization
        optimizer = ConstantPropagation()
        modified = optimizer.run_on_function(func)

        assert modified
        # The condition x < y should be folded to True
        # and potentially the branch should be simplified
|
@@ -0,0 +1,166 @@
|
|
1
|
+
"""Test custom optimization passes functionality."""
|
2
|
+
|
3
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
4
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
5
|
+
from machine_dialect.mir.mir_instructions import LoadConst, Return
|
6
|
+
from machine_dialect.mir.mir_module import MIRModule
|
7
|
+
from machine_dialect.mir.mir_types import MIRType
|
8
|
+
from machine_dialect.mir.mir_values import Temp
|
9
|
+
from machine_dialect.mir.optimization_config import OptimizationConfig
|
10
|
+
from machine_dialect.mir.optimize_mir import optimize_mir
|
11
|
+
|
12
|
+
|
13
|
+
def create_simple_module() -> MIRModule:
    """Create a simple test module.

    The module holds a single ``main`` function, declared to return ``INT``.
    Its one block loads the constant 42 into a temporary and returns it.

    Returns:
        A MIR module with a simple main function.
    """
    module = MIRModule("test")
    func = MIRFunction("main", [], MIRType.INT)
    entry = BasicBlock("entry")
    result = Temp(MIRType.INT)
    entry.add_instruction(LoadConst(result, 42, (1, 1)))
    # Return the loaded value. ``main`` is declared to return INT, and a
    # value-less return would also leave the LoadConst as dead code.
    entry.add_instruction(Return((1, 1), result))
    func.cfg.add_block(entry)
    # Register the entry block through the CFG API. Assigning an undeclared
    # ``cfg.entry`` attribute would bypass it (and required a type: ignore).
    func.cfg.set_entry_block(entry)
    module.add_function(func)
    return module
|
28
|
+
|
29
|
+
|
30
|
+
def test_custom_passes_override_default() -> None:
    """An explicit pass list replaces, rather than extends, the level-2 pipeline."""
    # Level 2 would normally schedule many passes; the explicit list must win.
    _optimized, stats = optimize_mir(
        create_simple_module(),
        optimization_level=2,
        custom_passes=["constant-propagation", "dce"],
    )

    # Every requested pass reports statistics (dependencies are internal).
    for requested in ("constant-propagation", "dce"):
        assert requested in stats

    # Level-2 defaults that were not requested (CSE, strength reduction) stay out.
    for skipped in ("cse", "strength-reduction"):
        assert skipped not in stats
|
47
|
+
|
48
|
+
|
49
|
+
def test_no_custom_passes_uses_default() -> None:
    """Omitting ``custom_passes`` falls back to the default level-2 pipeline."""
    _optimized, stats = optimize_mir(create_simple_module(), optimization_level=2)

    # Constant handling may be reported under either pass name.
    assert any(name in stats for name in ("constant-propagation", "constant-folding"))
    assert "dce" in stats
    # Common subexpression elimination belongs to the level-2 defaults.
    assert "cse" in stats
|
61
|
+
|
62
|
+
|
63
|
+
def test_custom_passes_empty_list() -> None:
    """An explicitly empty pass list disables every optimization pass."""
    _optimized, stats = optimize_mir(create_simple_module(), optimization_level=2, custom_passes=[])

    # Nothing was requested, so no pass should have reported statistics.
    assert not stats
|
72
|
+
|
73
|
+
|
74
|
+
def test_custom_passes_at_level_0() -> None:
    """An explicit pass list is honored even at optimization level 0."""
    module = create_simple_module()

    # Level 0 schedules no passes by default; the explicit request must still run.
    _optimized, stats = optimize_mir(module, custom_passes=["dce"], optimization_level=0)

    assert "dce" in stats
|
84
|
+
|
85
|
+
|
86
|
+
def test_custom_passes_with_dependencies() -> None:
    """Test that a custom pass with analysis dependencies runs successfully.

    Only the requested optimization pass is checked in the statistics.
    Required analysis passes (e.g. use-def chains for constant propagation)
    are resolved internally by the pass manager and are not asserted here.
    """
    module = create_simple_module()

    # constant-propagation requires use-def-chains analysis
    custom_passes = ["constant-propagation"]
    _optimized, stats = optimize_mir(module, optimization_level=1, custom_passes=custom_passes)

    # The requested optimization must report statistics.
    assert "constant-propagation" in stats
    # Note: dependencies are handled internally by the pass manager, and the
    # stats might not always include analysis passes, so they are not checked.
|
98
|
+
|
99
|
+
|
100
|
+
def test_custom_passes_preserve_module() -> None:
    """Running custom passes keeps the module's functions and name intact."""
    module = create_simple_module()

    # Snapshot the structure up front, in case optimization mutates in place.
    expected_count = len(module.functions)
    expected_name = None
    if module.functions:
        expected_name = next(iter(module.functions.values())).name

    optimized, _stats = optimize_mir(module, optimization_level=1, custom_passes=["dce"])

    assert len(optimized.functions) == expected_count
    if expected_name:
        first_function = next(iter(optimized.functions.values()))
        assert first_function.name == expected_name
    assert optimized.name == module.name
|
115
|
+
|
116
|
+
|
117
|
+
def test_invalid_custom_pass_name() -> None:
    """Test that invalid pass names are handled gracefully.

    Either behavior is accepted:
    - the unknown name is rejected with ``KeyError``/``ValueError``; or
    - it is tolerated, and the valid pass listed alongside it still runs.
    """
    module = create_simple_module()

    # Try to run with an invalid pass name next to a valid one.
    custom_passes = ["invalid-pass-name", "dce"]

    try:
        _optimized, stats = optimize_mir(module, optimization_level=1, custom_passes=custom_passes)
    except (KeyError, ValueError):
        # Rejecting an unknown pass name with an error is acceptable.
        return

    # The unknown name was tolerated, so the valid pass must still have run.
    # (The previous ``or len(stats) >= 0`` made this check always pass.)
    assert "dce" in stats
|
134
|
+
|
135
|
+
|
136
|
+
def test_custom_passes_order_preserved() -> None:
    """Test that a reordered pass list with a repeated pass is accepted.

    This test only checks that each requested pass shows up in the
    statistics. It does not verify the order in which the passes ran.
    TODO(review): assert on the actual execution order once the pass manager
    exposes it.
    """
    module = create_simple_module()

    # Deliberately non-default order, with ``dce`` requested twice.
    custom_passes = ["dce", "constant-propagation", "dce"]
    _optimized, stats = optimize_mir(module, optimization_level=1, custom_passes=custom_passes)

    # Both passes should have run.
    # Note: stats might aggregate multiple runs of the same pass.
    assert "dce" in stats
    assert "constant-propagation" in stats
|
148
|
+
|
149
|
+
|
150
|
+
def test_custom_passes_with_config() -> None:
    """Custom passes override the pipeline of an explicitly supplied config."""
    module = create_simple_module()

    # A level-2 config with pass debugging and statistics switched on.
    config = OptimizationConfig.from_level(2)
    config.debug_passes = True
    config.pass_statistics = True

    _optimized, stats = optimize_mir(
        module,
        optimization_level=2,
        config=config,
        custom_passes=["constant-propagation"],
    )

    # Only the requested pass runs. The config's own level-2 pipeline
    # (which would include CSE) is bypassed.
    assert "constant-propagation" in stats
    assert "cse" not in stats
|