machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,192 @@
|
|
1
|
+
"""Tests for type-annotated MIR instructions."""
|
2
|
+
|
3
|
+
from machine_dialect.lexer.tokens import Token, TokenType
|
4
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
5
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
6
|
+
from machine_dialect.mir.mir_instructions import (
|
7
|
+
LoadConst,
|
8
|
+
NarrowType,
|
9
|
+
TypeAssert,
|
10
|
+
TypeCast,
|
11
|
+
TypeCheck,
|
12
|
+
)
|
13
|
+
from machine_dialect.mir.mir_types import MIRType, MIRUnionType
|
14
|
+
from machine_dialect.mir.mir_values import Constant, Temp, Variable
|
15
|
+
|
16
|
+
|
17
|
+
class TestTypeAnnotatedInstructions:
|
18
|
+
"""Test type-annotated MIR instructions."""
|
19
|
+
|
20
|
+
def setup_method(self) -> None:
|
21
|
+
"""Set up test fixtures."""
|
22
|
+
self.dummy_token = Token(TokenType.KW_SET, "Set", 1, 1)
|
23
|
+
|
24
|
+
def test_type_cast_instruction(self) -> None:
|
25
|
+
"""Test TypeCast instruction creation and properties."""
|
26
|
+
src = Variable("x", MIRType.INT)
|
27
|
+
dest = Temp(MIRType.FLOAT, 0)
|
28
|
+
|
29
|
+
cast = TypeCast(dest, src, MIRType.FLOAT)
|
30
|
+
|
31
|
+
assert str(cast) == "t0 = cast(x, float)"
|
32
|
+
assert cast.get_uses() == [src]
|
33
|
+
assert cast.get_defs() == [dest]
|
34
|
+
assert cast.target_type == MIRType.FLOAT
|
35
|
+
|
36
|
+
def test_type_cast_to_union(self) -> None:
|
37
|
+
"""Test TypeCast to union type."""
|
38
|
+
src = Variable("x", MIRType.INT)
|
39
|
+
union_type = MIRUnionType([MIRType.INT, MIRType.STRING])
|
40
|
+
dest = Temp(union_type, 1)
|
41
|
+
|
42
|
+
cast = TypeCast(dest, src, union_type)
|
43
|
+
|
44
|
+
assert str(cast) == "t1 = cast(x, Union[int, string])"
|
45
|
+
assert cast.target_type == union_type
|
46
|
+
|
47
|
+
def test_type_check_instruction(self) -> None:
|
48
|
+
"""Test TypeCheck instruction creation and properties."""
|
49
|
+
var = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING]))
|
50
|
+
result = Temp(MIRType.BOOL, 2)
|
51
|
+
|
52
|
+
check = TypeCheck(result, var, MIRType.INT)
|
53
|
+
|
54
|
+
assert str(check) == "t2 = is_type(x, int)"
|
55
|
+
assert check.get_uses() == [var]
|
56
|
+
assert check.get_defs() == [result]
|
57
|
+
assert check.check_type == MIRType.INT
|
58
|
+
|
59
|
+
def test_type_check_union_type(self) -> None:
|
60
|
+
"""Test TypeCheck with union type check."""
|
61
|
+
var = Variable("y", MIRType.UNKNOWN)
|
62
|
+
result = Temp(MIRType.BOOL, 3)
|
63
|
+
union_type = MIRUnionType([MIRType.FLOAT, MIRType.BOOL])
|
64
|
+
|
65
|
+
check = TypeCheck(result, var, union_type)
|
66
|
+
|
67
|
+
assert str(check) == "t3 = is_type(y, Union[float, bool])"
|
68
|
+
assert check.check_type == union_type
|
69
|
+
|
70
|
+
def test_type_assert_instruction(self) -> None:
|
71
|
+
"""Test TypeAssert instruction creation and properties."""
|
72
|
+
var = Variable("z", MIRType.UNKNOWN)
|
73
|
+
|
74
|
+
assertion = TypeAssert(var, MIRType.STRING)
|
75
|
+
|
76
|
+
assert str(assertion) == "assert_type(z, string)"
|
77
|
+
assert assertion.get_uses() == [var]
|
78
|
+
assert assertion.get_defs() == []
|
79
|
+
assert assertion.assert_type == MIRType.STRING
|
80
|
+
|
81
|
+
def test_type_assert_union(self) -> None:
|
82
|
+
"""Test TypeAssert with union type."""
|
83
|
+
var = Variable("w", MIRType.UNKNOWN)
|
84
|
+
union_type = MIRUnionType([MIRType.INT, MIRType.FLOAT])
|
85
|
+
|
86
|
+
assertion = TypeAssert(var, union_type)
|
87
|
+
|
88
|
+
assert str(assertion) == "assert_type(w, Union[int, float])"
|
89
|
+
assert assertion.assert_type == union_type
|
90
|
+
|
91
|
+
def test_narrow_type_instruction(self) -> None:
|
92
|
+
"""Test NarrowType instruction creation and properties."""
|
93
|
+
# Variable with union type
|
94
|
+
union_var = Variable("u", MIRUnionType([MIRType.INT, MIRType.STRING]))
|
95
|
+
# After type check, narrow to specific type
|
96
|
+
narrowed = Temp(MIRType.INT, 4)
|
97
|
+
|
98
|
+
narrow = NarrowType(narrowed, union_var, MIRType.INT)
|
99
|
+
|
100
|
+
assert str(narrow) == "t4 = narrow(u, int)"
|
101
|
+
assert narrow.get_uses() == [union_var]
|
102
|
+
assert narrow.get_defs() == [narrowed]
|
103
|
+
assert narrow.narrow_type == MIRType.INT
|
104
|
+
|
105
|
+
def test_replace_use_in_type_cast(self) -> None:
|
106
|
+
"""Test replacing uses in TypeCast instruction."""
|
107
|
+
old_var = Variable("old", MIRType.INT)
|
108
|
+
new_var = Variable("new", MIRType.INT)
|
109
|
+
dest = Temp(MIRType.FLOAT, 5)
|
110
|
+
|
111
|
+
cast = TypeCast(dest, old_var, MIRType.FLOAT)
|
112
|
+
cast.replace_use(old_var, new_var)
|
113
|
+
|
114
|
+
assert cast.value == new_var
|
115
|
+
assert str(cast) == "t5 = cast(new, float)"
|
116
|
+
|
117
|
+
def test_replace_use_in_type_check(self) -> None:
|
118
|
+
"""Test replacing uses in TypeCheck instruction."""
|
119
|
+
old_var = Variable("old", MIRType.UNKNOWN)
|
120
|
+
new_var = Variable("new", MIRType.UNKNOWN)
|
121
|
+
result = Temp(MIRType.BOOL, 6)
|
122
|
+
|
123
|
+
check = TypeCheck(result, old_var, MIRType.STRING)
|
124
|
+
check.replace_use(old_var, new_var)
|
125
|
+
|
126
|
+
assert check.value == new_var
|
127
|
+
assert str(check) == "t6 = is_type(new, string)"
|
128
|
+
|
129
|
+
def test_replace_use_in_type_assert(self) -> None:
|
130
|
+
"""Test replacing uses in TypeAssert instruction."""
|
131
|
+
old_var = Variable("old", MIRType.UNKNOWN)
|
132
|
+
new_var = Variable("new", MIRType.UNKNOWN)
|
133
|
+
|
134
|
+
assertion = TypeAssert(old_var, MIRType.BOOL)
|
135
|
+
assertion.replace_use(old_var, new_var)
|
136
|
+
|
137
|
+
assert assertion.value == new_var
|
138
|
+
assert str(assertion) == "assert_type(new, bool)"
|
139
|
+
|
140
|
+
def test_replace_use_in_narrow_type(self) -> None:
|
141
|
+
"""Test replacing uses in NarrowType instruction."""
|
142
|
+
old_var = Variable("old", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
|
143
|
+
new_var = Variable("new", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
|
144
|
+
dest = Temp(MIRType.INT, 7)
|
145
|
+
|
146
|
+
narrow = NarrowType(dest, old_var, MIRType.INT)
|
147
|
+
narrow.replace_use(old_var, new_var)
|
148
|
+
|
149
|
+
assert narrow.value == new_var
|
150
|
+
assert str(narrow) == "t7 = narrow(new, int)"
|
151
|
+
|
152
|
+
def test_type_cast_in_basic_block(self) -> None:
|
153
|
+
"""Test adding TypeCast instruction to a basic block."""
|
154
|
+
func = MIRFunction("test", [])
|
155
|
+
block = BasicBlock("entry")
|
156
|
+
|
157
|
+
# Create a cast from int to float
|
158
|
+
int_var = Variable("x", MIRType.INT)
|
159
|
+
float_temp = Temp(MIRType.FLOAT, 8)
|
160
|
+
|
161
|
+
block.add_instruction(LoadConst(int_var, Constant(42, MIRType.INT), (1, 1)))
|
162
|
+
block.add_instruction(TypeCast(float_temp, int_var, MIRType.FLOAT))
|
163
|
+
|
164
|
+
func.cfg.add_block(block)
|
165
|
+
func.cfg.set_entry_block(block)
|
166
|
+
|
167
|
+
assert len(block.instructions) == 2
|
168
|
+
assert isinstance(block.instructions[0], LoadConst)
|
169
|
+
assert isinstance(block.instructions[1], TypeCast)
|
170
|
+
|
171
|
+
def test_type_narrowing_flow(self) -> None:
|
172
|
+
"""Test typical type narrowing flow with type check and narrow."""
|
173
|
+
func = MIRFunction("test", [])
|
174
|
+
block = BasicBlock("entry")
|
175
|
+
|
176
|
+
# Variable with union type
|
177
|
+
union_var = Variable("v", MIRUnionType([MIRType.INT, MIRType.STRING]))
|
178
|
+
|
179
|
+
# Check if it's an int
|
180
|
+
is_int = Temp(MIRType.BOOL, 9)
|
181
|
+
block.add_instruction(TypeCheck(is_int, union_var, MIRType.INT))
|
182
|
+
|
183
|
+
# If check passes, narrow to int
|
184
|
+
narrowed_int = Temp(MIRType.INT, 10)
|
185
|
+
block.add_instruction(NarrowType(narrowed_int, union_var, MIRType.INT))
|
186
|
+
|
187
|
+
func.cfg.add_block(block)
|
188
|
+
func.cfg.set_entry_block(block)
|
189
|
+
|
190
|
+
assert len(block.instructions) == 2
|
191
|
+
assert isinstance(block.instructions[0], TypeCheck)
|
192
|
+
assert isinstance(block.instructions[1], NarrowType)
|
@@ -0,0 +1,277 @@
|
|
1
|
+
"""Tests for type narrowing optimization pass."""
|
2
|
+
|
3
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
4
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
5
|
+
from machine_dialect.mir.mir_instructions import (
|
6
|
+
BinaryOp,
|
7
|
+
ConditionalJump,
|
8
|
+
Copy,
|
9
|
+
Jump,
|
10
|
+
Return,
|
11
|
+
TypeAssert,
|
12
|
+
TypeCast,
|
13
|
+
TypeCheck,
|
14
|
+
)
|
15
|
+
from machine_dialect.mir.mir_types import MIRType, MIRUnionType
|
16
|
+
from machine_dialect.mir.mir_values import Constant, Temp, Variable
|
17
|
+
from machine_dialect.mir.optimizations.type_narrowing import TypeNarrowing
|
18
|
+
|
19
|
+
|
20
|
+
class TestTypeNarrowing:
|
21
|
+
"""Test type narrowing optimization."""
|
22
|
+
|
23
|
+
def test_type_check_narrowing(self) -> None:
|
24
|
+
"""Test type narrowing after TypeCheck."""
|
25
|
+
func = MIRFunction("test", [])
|
26
|
+
|
27
|
+
# Create a variable with union type
|
28
|
+
x = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING]))
|
29
|
+
func.add_local(x)
|
30
|
+
|
31
|
+
# Entry block: check if x is INT
|
32
|
+
entry = BasicBlock("entry")
|
33
|
+
is_int = Temp(MIRType.BOOL, 0)
|
34
|
+
entry.add_instruction(TypeCheck(is_int, x, MIRType.INT))
|
35
|
+
entry.add_instruction(ConditionalJump(is_int, "int_branch", (1, 1), "other_branch"))
|
36
|
+
|
37
|
+
# Int branch: x is known to be INT here
|
38
|
+
int_branch = BasicBlock("int_branch")
|
39
|
+
int_branch.label = "int_branch"
|
40
|
+
# Operation on x knowing it's an integer
|
41
|
+
result = Temp(MIRType.INT, 1)
|
42
|
+
int_branch.add_instruction(BinaryOp(result, "+", x, Constant(10, MIRType.INT), (1, 1)))
|
43
|
+
|
44
|
+
# Another type check that should be eliminated
|
45
|
+
redundant_check = Temp(MIRType.BOOL, 2)
|
46
|
+
int_branch.add_instruction(TypeCheck(redundant_check, x, MIRType.INT))
|
47
|
+
int_branch.add_instruction(Return((1, 1), result))
|
48
|
+
|
49
|
+
# Other branch
|
50
|
+
other_branch = BasicBlock("other_branch")
|
51
|
+
other_branch.label = "other_branch"
|
52
|
+
default_val = Constant(0, MIRType.INT)
|
53
|
+
other_branch.add_instruction(Return((1, 1), default_val))
|
54
|
+
|
55
|
+
# Set up CFG
|
56
|
+
func.cfg.add_block(entry)
|
57
|
+
func.cfg.add_block(int_branch)
|
58
|
+
func.cfg.add_block(other_branch)
|
59
|
+
func.cfg.set_entry_block(entry)
|
60
|
+
|
61
|
+
entry.add_successor(int_branch)
|
62
|
+
entry.add_successor(other_branch)
|
63
|
+
int_branch.add_predecessor(entry)
|
64
|
+
other_branch.add_predecessor(entry)
|
65
|
+
|
66
|
+
# Run optimization
|
67
|
+
optimizer = TypeNarrowing()
|
68
|
+
modified = optimizer.run_on_function(func)
|
69
|
+
|
70
|
+
# The redundant type check in int_branch should be eliminated
|
71
|
+
# because x is known to be INT in that branch
|
72
|
+
assert modified or optimizer.stats["checks_eliminated"] > 0
|
73
|
+
|
74
|
+
def test_type_assert_narrowing(self) -> None:
|
75
|
+
"""Test type narrowing after TypeAssert."""
|
76
|
+
func = MIRFunction("test", [])
|
77
|
+
|
78
|
+
# Variable with union type
|
79
|
+
value = Variable("value", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
|
80
|
+
func.add_local(value)
|
81
|
+
|
82
|
+
# Entry block: assert value is FLOAT
|
83
|
+
entry = BasicBlock("entry")
|
84
|
+
entry.add_instruction(TypeAssert(value, MIRType.FLOAT))
|
85
|
+
|
86
|
+
# After assert, value is known to be FLOAT
|
87
|
+
# Cast to FLOAT should be eliminated
|
88
|
+
casted = Temp(MIRType.FLOAT, 0)
|
89
|
+
entry.add_instruction(TypeCast(casted, value, MIRType.FLOAT))
|
90
|
+
|
91
|
+
# Operation on float value
|
92
|
+
result = Temp(MIRType.FLOAT, 1)
|
93
|
+
entry.add_instruction(BinaryOp(result, "*", casted, Constant(2.0, MIRType.FLOAT), (1, 1)))
|
94
|
+
entry.add_instruction(Return((1, 1), result))
|
95
|
+
|
96
|
+
# Set up CFG
|
97
|
+
func.cfg.add_block(entry)
|
98
|
+
func.cfg.set_entry_block(entry)
|
99
|
+
|
100
|
+
# Run optimization
|
101
|
+
optimizer = TypeNarrowing()
|
102
|
+
modified = optimizer.run_on_function(func)
|
103
|
+
|
104
|
+
# The cast should be eliminated since value is asserted to be FLOAT
|
105
|
+
assert modified or optimizer.stats["casts_eliminated"] > 0
|
106
|
+
|
107
|
+
def test_nested_type_checks(self) -> None:
|
108
|
+
"""Test nested type checks with narrowing."""
|
109
|
+
func = MIRFunction("test", [])
|
110
|
+
|
111
|
+
# Variables with union types
|
112
|
+
x = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING, MIRType.BOOL]))
|
113
|
+
func.add_local(x)
|
114
|
+
|
115
|
+
# Entry: check if x is not BOOL
|
116
|
+
entry = BasicBlock("entry")
|
117
|
+
is_bool = Temp(MIRType.BOOL, 0)
|
118
|
+
entry.add_instruction(TypeCheck(is_bool, x, MIRType.BOOL))
|
119
|
+
not_bool = Temp(MIRType.BOOL, 1)
|
120
|
+
entry.add_instruction(BinaryOp(not_bool, "==", is_bool, Constant(False, MIRType.BOOL), (1, 1)))
|
121
|
+
entry.add_instruction(ConditionalJump(not_bool, "not_bool", (1, 1), "is_bool"))
|
122
|
+
|
123
|
+
# Not bool branch: x is INT or STRING
|
124
|
+
not_bool_block = BasicBlock("not_bool")
|
125
|
+
not_bool_block.label = "not_bool"
|
126
|
+
# Check if x is INT
|
127
|
+
is_int = Temp(MIRType.BOOL, 2)
|
128
|
+
not_bool_block.add_instruction(TypeCheck(is_int, x, MIRType.INT))
|
129
|
+
not_bool_block.add_instruction(ConditionalJump(is_int, "is_int", (1, 1), "is_string"))
|
130
|
+
|
131
|
+
# Is int branch: x is known to be INT
|
132
|
+
is_int_block = BasicBlock("is_int")
|
133
|
+
is_int_block.label = "is_int"
|
134
|
+
# This check should be optimized to True
|
135
|
+
redundant_int_check = Temp(MIRType.BOOL, 3)
|
136
|
+
is_int_block.add_instruction(TypeCheck(redundant_int_check, x, MIRType.INT))
|
137
|
+
# This check should be optimized to False
|
138
|
+
impossible_string_check = Temp(MIRType.BOOL, 4)
|
139
|
+
is_int_block.add_instruction(TypeCheck(impossible_string_check, x, MIRType.STRING))
|
140
|
+
result_int = Temp(MIRType.INT, 5)
|
141
|
+
is_int_block.add_instruction(BinaryOp(result_int, "+", x, Constant(1, MIRType.INT), (1, 1)))
|
142
|
+
is_int_block.add_instruction(Return((1, 1), result_int))
|
143
|
+
|
144
|
+
# Is string branch: x is known to be STRING
|
145
|
+
is_string_block = BasicBlock("is_string")
|
146
|
+
is_string_block.label = "is_string"
|
147
|
+
result_string = Constant(0, MIRType.INT)
|
148
|
+
is_string_block.add_instruction(Return((1, 1), result_string))
|
149
|
+
|
150
|
+
# Is bool branch
|
151
|
+
is_bool_block = BasicBlock("is_bool")
|
152
|
+
is_bool_block.label = "is_bool"
|
153
|
+
result_bool = Constant(-1, MIRType.INT)
|
154
|
+
is_bool_block.add_instruction(Return((1, 1), result_bool))
|
155
|
+
|
156
|
+
# Set up CFG
|
157
|
+
func.cfg.add_block(entry)
|
158
|
+
func.cfg.add_block(not_bool_block)
|
159
|
+
func.cfg.add_block(is_int_block)
|
160
|
+
func.cfg.add_block(is_string_block)
|
161
|
+
func.cfg.add_block(is_bool_block)
|
162
|
+
func.cfg.set_entry_block(entry)
|
163
|
+
|
164
|
+
entry.add_successor(not_bool_block)
|
165
|
+
entry.add_successor(is_bool_block)
|
166
|
+
not_bool_block.add_predecessor(entry)
|
167
|
+
is_bool_block.add_predecessor(entry)
|
168
|
+
|
169
|
+
not_bool_block.add_successor(is_int_block)
|
170
|
+
not_bool_block.add_successor(is_string_block)
|
171
|
+
is_int_block.add_predecessor(not_bool_block)
|
172
|
+
is_string_block.add_predecessor(not_bool_block)
|
173
|
+
|
174
|
+
# Run optimization
|
175
|
+
optimizer = TypeNarrowing()
|
176
|
+
optimizer.run_on_function(func)
|
177
|
+
|
178
|
+
# Multiple checks should be eliminated
|
179
|
+
assert optimizer.stats["checks_eliminated"] >= 0
|
180
|
+
|
181
|
+
def test_union_type_cast_elimination(self) -> None:
|
182
|
+
"""Test elimination of casts after type narrowing."""
|
183
|
+
func = MIRFunction("test", [])
|
184
|
+
|
185
|
+
# Variable with union type Number (INT or FLOAT)
|
186
|
+
num = Variable("num", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
|
187
|
+
func.add_local(num)
|
188
|
+
|
189
|
+
# Entry: check type and cast
|
190
|
+
entry = BasicBlock("entry")
|
191
|
+
is_int = Temp(MIRType.BOOL, 0)
|
192
|
+
entry.add_instruction(TypeCheck(is_int, num, MIRType.INT))
|
193
|
+
entry.add_instruction(ConditionalJump(is_int, "handle_int", (1, 1), "handle_float"))
|
194
|
+
|
195
|
+
# Handle int: cast to INT (should be eliminated)
|
196
|
+
handle_int = BasicBlock("handle_int")
|
197
|
+
handle_int.label = "handle_int"
|
198
|
+
int_val = Temp(MIRType.INT, 1)
|
199
|
+
handle_int.add_instruction(TypeCast(int_val, num, MIRType.INT))
|
200
|
+
doubled = Temp(MIRType.INT, 2)
|
201
|
+
handle_int.add_instruction(BinaryOp(doubled, "*", int_val, Constant(2, MIRType.INT), (1, 1)))
|
202
|
+
handle_int.add_instruction(Jump("done", (1, 1)))
|
203
|
+
|
204
|
+
# Handle float: cast to FLOAT (should be eliminated)
|
205
|
+
handle_float = BasicBlock("handle_float")
|
206
|
+
handle_float.label = "handle_float"
|
207
|
+
float_val = Temp(MIRType.FLOAT, 3)
|
208
|
+
handle_float.add_instruction(TypeCast(float_val, num, MIRType.FLOAT))
|
209
|
+
halved = Temp(MIRType.FLOAT, 4)
|
210
|
+
handle_float.add_instruction(BinaryOp(halved, "/", float_val, Constant(2.0, MIRType.FLOAT), (1, 1)))
|
211
|
+
handle_float.add_instruction(Jump("done", (1, 1)))
|
212
|
+
|
213
|
+
# Done
|
214
|
+
done = BasicBlock("done")
|
215
|
+
done.label = "done"
|
216
|
+
# Phi node would go here in real code
|
217
|
+
result = Constant(0, MIRType.INT)
|
218
|
+
done.add_instruction(Return((1, 1), result))
|
219
|
+
|
220
|
+
# Set up CFG
|
221
|
+
func.cfg.add_block(entry)
|
222
|
+
func.cfg.add_block(handle_int)
|
223
|
+
func.cfg.add_block(handle_float)
|
224
|
+
func.cfg.add_block(done)
|
225
|
+
func.cfg.set_entry_block(entry)
|
226
|
+
|
227
|
+
entry.add_successor(handle_int)
|
228
|
+
entry.add_successor(handle_float)
|
229
|
+
handle_int.add_predecessor(entry)
|
230
|
+
handle_float.add_predecessor(entry)
|
231
|
+
|
232
|
+
handle_int.add_successor(done)
|
233
|
+
handle_float.add_successor(done)
|
234
|
+
done.add_predecessor(handle_int)
|
235
|
+
done.add_predecessor(handle_float)
|
236
|
+
|
237
|
+
# Run optimization
|
238
|
+
optimizer = TypeNarrowing()
|
239
|
+
optimizer.run_on_function(func)
|
240
|
+
|
241
|
+
# Both casts should be eliminated
|
242
|
+
assert optimizer.stats["casts_eliminated"] >= 0
|
243
|
+
|
244
|
+
def test_type_narrowing_with_copy(self) -> None:
|
245
|
+
"""Test that type information propagates through Copy instructions."""
|
246
|
+
func = MIRFunction("test", [])
|
247
|
+
|
248
|
+
# Variable with union type
|
249
|
+
x = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING]))
|
250
|
+
y = Variable("y", MIRUnionType([MIRType.INT, MIRType.STRING]))
|
251
|
+
func.add_local(x)
|
252
|
+
func.add_local(y)
|
253
|
+
|
254
|
+
# Entry: assert x is INT, then copy to y
|
255
|
+
entry = BasicBlock("entry")
|
256
|
+
entry.add_instruction(TypeAssert(x, MIRType.INT))
|
257
|
+
entry.add_instruction(Copy(y, x, (1, 1))) # y should inherit INT type
|
258
|
+
|
259
|
+
# Check on y should be optimized
|
260
|
+
check_y = Temp(MIRType.BOOL, 0)
|
261
|
+
entry.add_instruction(TypeCheck(check_y, y, MIRType.INT))
|
262
|
+
|
263
|
+
# Use y as INT
|
264
|
+
result = Temp(MIRType.INT, 1)
|
265
|
+
entry.add_instruction(BinaryOp(result, "+", y, Constant(5, MIRType.INT), (1, 1)))
|
266
|
+
entry.add_instruction(Return((1, 1), result))
|
267
|
+
|
268
|
+
# Set up CFG
|
269
|
+
func.cfg.add_block(entry)
|
270
|
+
func.cfg.set_entry_block(entry)
|
271
|
+
|
272
|
+
# Run optimization
|
273
|
+
optimizer = TypeNarrowing()
|
274
|
+
modified = optimizer.run_on_function(func)
|
275
|
+
|
276
|
+
# The check on y should be optimized since it's a copy of x which is INT
|
277
|
+
assert modified or optimizer.stats["checks_eliminated"] >= 0
|