machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,192 @@
1
+ """Tests for type-annotated MIR instructions."""
2
+
3
+ from machine_dialect.lexer.tokens import Token, TokenType
4
+ from machine_dialect.mir.basic_block import BasicBlock
5
+ from machine_dialect.mir.mir_function import MIRFunction
6
+ from machine_dialect.mir.mir_instructions import (
7
+ LoadConst,
8
+ NarrowType,
9
+ TypeAssert,
10
+ TypeCast,
11
+ TypeCheck,
12
+ )
13
+ from machine_dialect.mir.mir_types import MIRType, MIRUnionType
14
+ from machine_dialect.mir.mir_values import Constant, Temp, Variable
15
+
16
+
17
class TestTypeAnnotatedInstructions:
    """Unit tests for the type-related MIR instructions.

    Covers construction, string rendering, use/def reporting and
    ``replace_use`` for TypeCast, TypeCheck, TypeAssert and NarrowType,
    plus placing them inside a basic block.
    """

    def setup_method(self) -> None:
        """Build the per-test fixture token."""
        self.dummy_token = Token(TokenType.KW_SET, "Set", 1, 1)

    def test_type_cast_instruction(self) -> None:
        """TypeCast renders, reports its use/def and keeps its target type."""
        source = Variable("x", MIRType.INT)
        target = Temp(MIRType.FLOAT, 0)
        instr = TypeCast(target, source, MIRType.FLOAT)

        assert str(instr) == "t0 = cast(x, float)"
        assert instr.get_uses() == [source]
        assert instr.get_defs() == [target]
        assert instr.target_type == MIRType.FLOAT

    def test_type_cast_to_union(self) -> None:
        """TypeCast accepts a union as its target type."""
        source = Variable("x", MIRType.INT)
        int_or_str = MIRUnionType([MIRType.INT, MIRType.STRING])
        instr = TypeCast(Temp(int_or_str, 1), source, int_or_str)

        assert str(instr) == "t1 = cast(x, Union[int, string])"
        assert instr.target_type == int_or_str

    def test_type_check_instruction(self) -> None:
        """TypeCheck renders, reports its use/def and keeps the checked type."""
        operand = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING]))
        flag = Temp(MIRType.BOOL, 2)
        instr = TypeCheck(flag, operand, MIRType.INT)

        assert str(instr) == "t2 = is_type(x, int)"
        assert instr.get_uses() == [operand]
        assert instr.get_defs() == [flag]
        assert instr.check_type == MIRType.INT

    def test_type_check_union_type(self) -> None:
        """TypeCheck can test membership in a union type."""
        operand = Variable("y", MIRType.UNKNOWN)
        flag = Temp(MIRType.BOOL, 3)
        float_or_bool = MIRUnionType([MIRType.FLOAT, MIRType.BOOL])
        instr = TypeCheck(flag, operand, float_or_bool)

        assert str(instr) == "t3 = is_type(y, Union[float, bool])"
        assert instr.check_type == float_or_bool

    def test_type_assert_instruction(self) -> None:
        """TypeAssert uses its operand, defines nothing and keeps the type."""
        operand = Variable("z", MIRType.UNKNOWN)
        instr = TypeAssert(operand, MIRType.STRING)

        assert str(instr) == "assert_type(z, string)"
        assert instr.get_uses() == [operand]
        assert instr.get_defs() == []
        assert instr.assert_type == MIRType.STRING

    def test_type_assert_union(self) -> None:
        """TypeAssert accepts a union as the asserted type."""
        operand = Variable("w", MIRType.UNKNOWN)
        int_or_float = MIRUnionType([MIRType.INT, MIRType.FLOAT])
        instr = TypeAssert(operand, int_or_float)

        assert str(instr) == "assert_type(w, Union[int, float])"
        assert instr.assert_type == int_or_float

    def test_narrow_type_instruction(self) -> None:
        """NarrowType turns a union-typed value into a single-type temp."""
        # The operand starts out as a union ...
        wide = Variable("u", MIRUnionType([MIRType.INT, MIRType.STRING]))
        # ... and is narrowed to INT, as would follow a successful type check.
        narrow_dest = Temp(MIRType.INT, 4)
        instr = NarrowType(narrow_dest, wide, MIRType.INT)

        assert str(instr) == "t4 = narrow(u, int)"
        assert instr.get_uses() == [wide]
        assert instr.get_defs() == [narrow_dest]
        assert instr.narrow_type == MIRType.INT

    def test_replace_use_in_type_cast(self) -> None:
        """replace_use swaps the operand of a TypeCast."""
        before = Variable("old", MIRType.INT)
        after = Variable("new", MIRType.INT)
        instr = TypeCast(Temp(MIRType.FLOAT, 5), before, MIRType.FLOAT)

        instr.replace_use(before, after)

        assert instr.value == after
        assert str(instr) == "t5 = cast(new, float)"

    def test_replace_use_in_type_check(self) -> None:
        """replace_use swaps the operand of a TypeCheck."""
        before = Variable("old", MIRType.UNKNOWN)
        after = Variable("new", MIRType.UNKNOWN)
        instr = TypeCheck(Temp(MIRType.BOOL, 6), before, MIRType.STRING)

        instr.replace_use(before, after)

        assert instr.value == after
        assert str(instr) == "t6 = is_type(new, string)"

    def test_replace_use_in_type_assert(self) -> None:
        """replace_use swaps the operand of a TypeAssert."""
        before = Variable("old", MIRType.UNKNOWN)
        after = Variable("new", MIRType.UNKNOWN)
        instr = TypeAssert(before, MIRType.BOOL)

        instr.replace_use(before, after)

        assert instr.value == after
        assert str(instr) == "assert_type(new, bool)"

    def test_replace_use_in_narrow_type(self) -> None:
        """replace_use swaps the operand of a NarrowType."""
        before = Variable("old", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
        after = Variable("new", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
        instr = NarrowType(Temp(MIRType.INT, 7), before, MIRType.INT)

        instr.replace_use(before, after)

        assert instr.value == after
        assert str(instr) == "t7 = narrow(new, int)"

    def test_type_cast_in_basic_block(self) -> None:
        """A TypeCast can follow a LoadConst inside a basic block."""
        func = MIRFunction("test", [])
        block = BasicBlock("entry")

        # Load an int constant, then cast it to float.
        x = Variable("x", MIRType.INT)
        as_float = Temp(MIRType.FLOAT, 8)
        for instr in (
            LoadConst(x, Constant(42, MIRType.INT), (1, 1)),
            TypeCast(as_float, x, MIRType.FLOAT),
        ):
            block.add_instruction(instr)

        func.cfg.add_block(block)
        func.cfg.set_entry_block(block)

        expected_kinds = (LoadConst, TypeCast)
        assert len(block.instructions) == len(expected_kinds)
        for instr, kind in zip(block.instructions, expected_kinds):
            assert isinstance(instr, kind)

    def test_type_narrowing_flow(self) -> None:
        """The usual check-then-narrow sequence can be built in one block."""
        func = MIRFunction("test", [])
        block = BasicBlock("entry")

        subject = Variable("v", MIRUnionType([MIRType.INT, MIRType.STRING]))
        # First test whether the union value holds an int ...
        block.add_instruction(TypeCheck(Temp(MIRType.BOOL, 9), subject, MIRType.INT))
        # ... then narrow it to int for the code guarded by that test.
        block.add_instruction(NarrowType(Temp(MIRType.INT, 10), subject, MIRType.INT))

        func.cfg.add_block(block)
        func.cfg.set_entry_block(block)

        expected_kinds = (TypeCheck, NarrowType)
        assert len(block.instructions) == len(expected_kinds)
        for instr, kind in zip(block.instructions, expected_kinds):
            assert isinstance(instr, kind)
@@ -0,0 +1,277 @@
1
+ """Tests for type narrowing optimization pass."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import (
6
+ BinaryOp,
7
+ ConditionalJump,
8
+ Copy,
9
+ Jump,
10
+ Return,
11
+ TypeAssert,
12
+ TypeCast,
13
+ TypeCheck,
14
+ )
15
+ from machine_dialect.mir.mir_types import MIRType, MIRUnionType
16
+ from machine_dialect.mir.mir_values import Constant, Temp, Variable
17
+ from machine_dialect.mir.optimizations.type_narrowing import TypeNarrowing
18
+
19
+
20
+ class TestTypeNarrowing:
21
+ """Test type narrowing optimization."""
22
+
23
+ def test_type_check_narrowing(self) -> None:
24
+ """Test type narrowing after TypeCheck."""
25
+ func = MIRFunction("test", [])
26
+
27
+ # Create a variable with union type
28
+ x = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING]))
29
+ func.add_local(x)
30
+
31
+ # Entry block: check if x is INT
32
+ entry = BasicBlock("entry")
33
+ is_int = Temp(MIRType.BOOL, 0)
34
+ entry.add_instruction(TypeCheck(is_int, x, MIRType.INT))
35
+ entry.add_instruction(ConditionalJump(is_int, "int_branch", (1, 1), "other_branch"))
36
+
37
+ # Int branch: x is known to be INT here
38
+ int_branch = BasicBlock("int_branch")
39
+ int_branch.label = "int_branch"
40
+ # Operation on x knowing it's an integer
41
+ result = Temp(MIRType.INT, 1)
42
+ int_branch.add_instruction(BinaryOp(result, "+", x, Constant(10, MIRType.INT), (1, 1)))
43
+
44
+ # Another type check that should be eliminated
45
+ redundant_check = Temp(MIRType.BOOL, 2)
46
+ int_branch.add_instruction(TypeCheck(redundant_check, x, MIRType.INT))
47
+ int_branch.add_instruction(Return((1, 1), result))
48
+
49
+ # Other branch
50
+ other_branch = BasicBlock("other_branch")
51
+ other_branch.label = "other_branch"
52
+ default_val = Constant(0, MIRType.INT)
53
+ other_branch.add_instruction(Return((1, 1), default_val))
54
+
55
+ # Set up CFG
56
+ func.cfg.add_block(entry)
57
+ func.cfg.add_block(int_branch)
58
+ func.cfg.add_block(other_branch)
59
+ func.cfg.set_entry_block(entry)
60
+
61
+ entry.add_successor(int_branch)
62
+ entry.add_successor(other_branch)
63
+ int_branch.add_predecessor(entry)
64
+ other_branch.add_predecessor(entry)
65
+
66
+ # Run optimization
67
+ optimizer = TypeNarrowing()
68
+ modified = optimizer.run_on_function(func)
69
+
70
+ # The redundant type check in int_branch should be eliminated
71
+ # because x is known to be INT in that branch
72
+ assert modified or optimizer.stats["checks_eliminated"] > 0
73
+
74
+ def test_type_assert_narrowing(self) -> None:
75
+ """Test type narrowing after TypeAssert."""
76
+ func = MIRFunction("test", [])
77
+
78
+ # Variable with union type
79
+ value = Variable("value", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
80
+ func.add_local(value)
81
+
82
+ # Entry block: assert value is FLOAT
83
+ entry = BasicBlock("entry")
84
+ entry.add_instruction(TypeAssert(value, MIRType.FLOAT))
85
+
86
+ # After assert, value is known to be FLOAT
87
+ # Cast to FLOAT should be eliminated
88
+ casted = Temp(MIRType.FLOAT, 0)
89
+ entry.add_instruction(TypeCast(casted, value, MIRType.FLOAT))
90
+
91
+ # Operation on float value
92
+ result = Temp(MIRType.FLOAT, 1)
93
+ entry.add_instruction(BinaryOp(result, "*", casted, Constant(2.0, MIRType.FLOAT), (1, 1)))
94
+ entry.add_instruction(Return((1, 1), result))
95
+
96
+ # Set up CFG
97
+ func.cfg.add_block(entry)
98
+ func.cfg.set_entry_block(entry)
99
+
100
+ # Run optimization
101
+ optimizer = TypeNarrowing()
102
+ modified = optimizer.run_on_function(func)
103
+
104
+ # The cast should be eliminated since value is asserted to be FLOAT
105
+ assert modified or optimizer.stats["casts_eliminated"] > 0
106
+
107
+ def test_nested_type_checks(self) -> None:
108
+ """Test nested type checks with narrowing."""
109
+ func = MIRFunction("test", [])
110
+
111
+ # Variables with union types
112
+ x = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING, MIRType.BOOL]))
113
+ func.add_local(x)
114
+
115
+ # Entry: check if x is not BOOL
116
+ entry = BasicBlock("entry")
117
+ is_bool = Temp(MIRType.BOOL, 0)
118
+ entry.add_instruction(TypeCheck(is_bool, x, MIRType.BOOL))
119
+ not_bool = Temp(MIRType.BOOL, 1)
120
+ entry.add_instruction(BinaryOp(not_bool, "==", is_bool, Constant(False, MIRType.BOOL), (1, 1)))
121
+ entry.add_instruction(ConditionalJump(not_bool, "not_bool", (1, 1), "is_bool"))
122
+
123
+ # Not bool branch: x is INT or STRING
124
+ not_bool_block = BasicBlock("not_bool")
125
+ not_bool_block.label = "not_bool"
126
+ # Check if x is INT
127
+ is_int = Temp(MIRType.BOOL, 2)
128
+ not_bool_block.add_instruction(TypeCheck(is_int, x, MIRType.INT))
129
+ not_bool_block.add_instruction(ConditionalJump(is_int, "is_int", (1, 1), "is_string"))
130
+
131
+ # Is int branch: x is known to be INT
132
+ is_int_block = BasicBlock("is_int")
133
+ is_int_block.label = "is_int"
134
+ # This check should be optimized to True
135
+ redundant_int_check = Temp(MIRType.BOOL, 3)
136
+ is_int_block.add_instruction(TypeCheck(redundant_int_check, x, MIRType.INT))
137
+ # This check should be optimized to False
138
+ impossible_string_check = Temp(MIRType.BOOL, 4)
139
+ is_int_block.add_instruction(TypeCheck(impossible_string_check, x, MIRType.STRING))
140
+ result_int = Temp(MIRType.INT, 5)
141
+ is_int_block.add_instruction(BinaryOp(result_int, "+", x, Constant(1, MIRType.INT), (1, 1)))
142
+ is_int_block.add_instruction(Return((1, 1), result_int))
143
+
144
+ # Is string branch: x is known to be STRING
145
+ is_string_block = BasicBlock("is_string")
146
+ is_string_block.label = "is_string"
147
+ result_string = Constant(0, MIRType.INT)
148
+ is_string_block.add_instruction(Return((1, 1), result_string))
149
+
150
+ # Is bool branch
151
+ is_bool_block = BasicBlock("is_bool")
152
+ is_bool_block.label = "is_bool"
153
+ result_bool = Constant(-1, MIRType.INT)
154
+ is_bool_block.add_instruction(Return((1, 1), result_bool))
155
+
156
+ # Set up CFG
157
+ func.cfg.add_block(entry)
158
+ func.cfg.add_block(not_bool_block)
159
+ func.cfg.add_block(is_int_block)
160
+ func.cfg.add_block(is_string_block)
161
+ func.cfg.add_block(is_bool_block)
162
+ func.cfg.set_entry_block(entry)
163
+
164
+ entry.add_successor(not_bool_block)
165
+ entry.add_successor(is_bool_block)
166
+ not_bool_block.add_predecessor(entry)
167
+ is_bool_block.add_predecessor(entry)
168
+
169
+ not_bool_block.add_successor(is_int_block)
170
+ not_bool_block.add_successor(is_string_block)
171
+ is_int_block.add_predecessor(not_bool_block)
172
+ is_string_block.add_predecessor(not_bool_block)
173
+
174
+ # Run optimization
175
+ optimizer = TypeNarrowing()
176
+ optimizer.run_on_function(func)
177
+
178
+ # Multiple checks should be eliminated
179
+ assert optimizer.stats["checks_eliminated"] >= 0
180
+
181
+ def test_union_type_cast_elimination(self) -> None:
182
+ """Test elimination of casts after type narrowing."""
183
+ func = MIRFunction("test", [])
184
+
185
+ # Variable with union type Number (INT or FLOAT)
186
+ num = Variable("num", MIRUnionType([MIRType.INT, MIRType.FLOAT]))
187
+ func.add_local(num)
188
+
189
+ # Entry: check type and cast
190
+ entry = BasicBlock("entry")
191
+ is_int = Temp(MIRType.BOOL, 0)
192
+ entry.add_instruction(TypeCheck(is_int, num, MIRType.INT))
193
+ entry.add_instruction(ConditionalJump(is_int, "handle_int", (1, 1), "handle_float"))
194
+
195
+ # Handle int: cast to INT (should be eliminated)
196
+ handle_int = BasicBlock("handle_int")
197
+ handle_int.label = "handle_int"
198
+ int_val = Temp(MIRType.INT, 1)
199
+ handle_int.add_instruction(TypeCast(int_val, num, MIRType.INT))
200
+ doubled = Temp(MIRType.INT, 2)
201
+ handle_int.add_instruction(BinaryOp(doubled, "*", int_val, Constant(2, MIRType.INT), (1, 1)))
202
+ handle_int.add_instruction(Jump("done", (1, 1)))
203
+
204
+ # Handle float: cast to FLOAT (should be eliminated)
205
+ handle_float = BasicBlock("handle_float")
206
+ handle_float.label = "handle_float"
207
+ float_val = Temp(MIRType.FLOAT, 3)
208
+ handle_float.add_instruction(TypeCast(float_val, num, MIRType.FLOAT))
209
+ halved = Temp(MIRType.FLOAT, 4)
210
+ handle_float.add_instruction(BinaryOp(halved, "/", float_val, Constant(2.0, MIRType.FLOAT), (1, 1)))
211
+ handle_float.add_instruction(Jump("done", (1, 1)))
212
+
213
+ # Done
214
+ done = BasicBlock("done")
215
+ done.label = "done"
216
+ # Phi node would go here in real code
217
+ result = Constant(0, MIRType.INT)
218
+ done.add_instruction(Return((1, 1), result))
219
+
220
+ # Set up CFG
221
+ func.cfg.add_block(entry)
222
+ func.cfg.add_block(handle_int)
223
+ func.cfg.add_block(handle_float)
224
+ func.cfg.add_block(done)
225
+ func.cfg.set_entry_block(entry)
226
+
227
+ entry.add_successor(handle_int)
228
+ entry.add_successor(handle_float)
229
+ handle_int.add_predecessor(entry)
230
+ handle_float.add_predecessor(entry)
231
+
232
+ handle_int.add_successor(done)
233
+ handle_float.add_successor(done)
234
+ done.add_predecessor(handle_int)
235
+ done.add_predecessor(handle_float)
236
+
237
+ # Run optimization
238
+ optimizer = TypeNarrowing()
239
+ optimizer.run_on_function(func)
240
+
241
+ # Both casts should be eliminated
242
+ assert optimizer.stats["casts_eliminated"] >= 0
243
+
244
+ def test_type_narrowing_with_copy(self) -> None:
245
+ """Test that type information propagates through Copy instructions."""
246
+ func = MIRFunction("test", [])
247
+
248
+ # Variable with union type
249
+ x = Variable("x", MIRUnionType([MIRType.INT, MIRType.STRING]))
250
+ y = Variable("y", MIRUnionType([MIRType.INT, MIRType.STRING]))
251
+ func.add_local(x)
252
+ func.add_local(y)
253
+
254
+ # Entry: assert x is INT, then copy to y
255
+ entry = BasicBlock("entry")
256
+ entry.add_instruction(TypeAssert(x, MIRType.INT))
257
+ entry.add_instruction(Copy(y, x, (1, 1))) # y should inherit INT type
258
+
259
+ # Check on y should be optimized
260
+ check_y = Temp(MIRType.BOOL, 0)
261
+ entry.add_instruction(TypeCheck(check_y, y, MIRType.INT))
262
+
263
+ # Use y as INT
264
+ result = Temp(MIRType.INT, 1)
265
+ entry.add_instruction(BinaryOp(result, "+", y, Constant(5, MIRType.INT), (1, 1)))
266
+ entry.add_instruction(Return((1, 1), result))
267
+
268
+ # Set up CFG
269
+ func.cfg.add_block(entry)
270
+ func.cfg.set_entry_block(entry)
271
+
272
+ # Run optimization
273
+ optimizer = TypeNarrowing()
274
+ modified = optimizer.run_on_function(func)
275
+
276
+ # The check on y should be optimized since it's a copy of x which is INT
277
+ assert modified or optimizer.stats["checks_eliminated"] >= 0