machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,459 @@
1
+ """Comprehensive tests for branch prediction optimization pass."""
2
+
3
+ from unittest.mock import MagicMock
4
+
5
+ from machine_dialect.mir.basic_block import BasicBlock
6
+ from machine_dialect.mir.mir_function import MIRFunction
7
+ from machine_dialect.mir.mir_instructions import (
8
+ BinaryOp,
9
+ ConditionalJump,
10
+ Jump,
11
+ LoadConst,
12
+ Return,
13
+ )
14
+ from machine_dialect.mir.mir_module import MIRModule
15
+ from machine_dialect.mir.mir_types import MIRType
16
+ from machine_dialect.mir.mir_values import Constant, Temp
17
+ from machine_dialect.mir.optimization_pass import PassType, PreservationLevel
18
+ from machine_dialect.mir.optimizations.branch_prediction import (
19
+ BranchInfo,
20
+ BranchPredictionOptimization,
21
+ )
22
+ from machine_dialect.mir.profiling.profile_data import ProfileData
23
+
24
+
25
class TestBranchInfo:
    """Exercise construction of the ``BranchInfo`` dataclass."""

    def test_branch_info_creation(self) -> None:
        """Every explicitly supplied field is stored on the record unchanged."""
        owner = BasicBlock("test_block")
        cond_value = Temp(MIRType.BOOL)
        terminator = ConditionalJump(cond_value, "then_block", (1, 1), "else_block")

        record = BranchInfo(
            block=owner,
            instruction=terminator,
            taken_probability=0.98,
            is_predictable=True,
            is_loop_header=False,
        )

        # Identity-bearing fields first, then the numeric/boolean metadata.
        assert record.block == owner
        assert record.instruction == terminator
        assert record.taken_probability == 0.98
        assert record.is_predictable
        assert not record.is_loop_header

    def test_branch_info_defaults(self) -> None:
        """Omitting ``is_loop_header`` leaves it falsy."""
        owner = BasicBlock("test_block")
        cond_value = Temp(MIRType.BOOL)
        terminator = ConditionalJump(cond_value, "then_block", (1, 1), "else_block")

        record = BranchInfo(
            block=owner,
            instruction=terminator,
            taken_probability=0.5,
            is_predictable=False,
        )

        assert not record.is_loop_header
63
+
64
class TestBranchPredictionOptimization:
    """Test BranchPredictionOptimization pass."""

    @staticmethod
    def _make_profile(taken_probability: float, predictable: bool = True) -> MagicMock:
        """Build a ``ProfileData`` mock with a single branch entry for the fixture.

        The entry is stored under the key ``"test_func:entry"``, which matches the
        function and block names created in ``setup_method``.

        Args:
            taken_probability: Value exposed as the branch profile's ``taken_probability``.
            predictable: Value exposed as the branch profile's ``predictable`` flag.

        Returns:
            A ``MagicMock`` specced on ``ProfileData`` whose ``branches`` dict holds
            one mocked branch profile.
        """
        from unittest.mock import Mock

        branch_profile = Mock()
        branch_profile.taken_probability = taken_probability
        branch_profile.predictable = predictable

        profile = MagicMock(spec=ProfileData)
        profile.branches = {"test_func:entry": branch_profile}
        return profile

    def setup_method(self) -> None:
        """Set up test fixtures."""
        self.module = MIRModule("test")

        # Create a function with conditional branch
        self.func = MIRFunction("test_func", [], MIRType.INT)

        # Create basic blocks
        self.entry_block = BasicBlock("entry")
        self.then_block = BasicBlock("then")
        self.else_block = BasicBlock("else")
        self.merge_block = BasicBlock("merge")

        # Build control flow: if (x > 5) then ... else ...
        x = Temp(MIRType.INT)
        cond = Temp(MIRType.BOOL)
        result = Temp(MIRType.INT)

        # Entry block
        self.entry_block.add_instruction(LoadConst(x, Constant(10, MIRType.INT), (1, 1)))
        self.entry_block.add_instruction(BinaryOp(cond, ">", x, Constant(5, MIRType.INT), (1, 1)))
        self.entry_block.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        # Then block
        self.then_block.add_instruction(LoadConst(result, Constant(1, MIRType.INT), (1, 1)))
        self.then_block.add_instruction(Jump("merge", (1, 1)))

        # Else block
        self.else_block.add_instruction(LoadConst(result, Constant(0, MIRType.INT), (1, 1)))
        self.else_block.add_instruction(Jump("merge", (1, 1)))

        # Merge block
        self.merge_block.add_instruction(Return((1, 1), result))

        # Add blocks to CFG
        self.func.cfg.add_block(self.entry_block)
        self.func.cfg.add_block(self.then_block)
        self.func.cfg.add_block(self.else_block)
        self.func.cfg.add_block(self.merge_block)
        self.func.cfg.entry_block = self.entry_block

        # Set up control flow edges via successors/predecessors
        self.entry_block.add_successor(self.then_block)
        self.entry_block.add_successor(self.else_block)
        self.then_block.add_predecessor(self.entry_block)
        self.else_block.add_predecessor(self.entry_block)

        self.then_block.add_successor(self.merge_block)
        self.else_block.add_successor(self.merge_block)
        self.merge_block.add_predecessor(self.then_block)
        self.merge_block.add_predecessor(self.else_block)

        self.module.add_function(self.func)

    def test_pass_initialization(self) -> None:
        """Test initialization of branch prediction pass."""
        opt = BranchPredictionOptimization(predictability_threshold=0.9)
        assert opt.profile_data is None
        assert opt.predictability_threshold == 0.9
        assert opt.stats["branches_analyzed"] == 0

    def test_pass_info(self) -> None:
        """Test pass information."""
        opt = BranchPredictionOptimization()
        info = opt.get_info()
        assert info.name == "branch-prediction"
        assert info.pass_type == PassType.OPTIMIZATION
        # Branch prediction might preserve CFG structure
        assert info.preserves in [PreservationLevel.NONE, PreservationLevel.CFG]

    def test_collect_branch_info(self) -> None:
        """Test collecting branch information."""
        opt = BranchPredictionOptimization()
        opt.profile_data = self._make_profile(0.98)

        branches = opt._collect_branch_info(self.func)

        # Should find one branch (in entry block)
        assert len(branches) == 1
        branch_info = branches[0]
        assert branch_info.block.label == "entry"
        assert isinstance(branch_info.instruction, ConditionalJump)
        assert branch_info.taken_probability == 0.98
        assert branch_info.is_predictable

    def test_collect_branch_info_no_profile(self) -> None:
        """Test collecting branch information without profile data."""
        opt = BranchPredictionOptimization()

        branches = opt._collect_branch_info(self.func)

        # Should find one branch with default probability
        assert len(branches) == 1
        branch_info = branches[0]
        assert branch_info.taken_probability == 0.5  # Default
        assert not branch_info.is_predictable

    def test_reorder_blocks_for_fallthrough(self) -> None:
        """Test reordering blocks for better fallthrough."""
        opt = BranchPredictionOptimization()

        # The then-branch is highly likely to be taken.
        opt.profile_data = self._make_profile(0.99)

        # Force a non-optimal layout: put "else" between "entry" and the hot "then"
        # block, so the pass must move "then" up to make it the fallthrough.
        # setup_method always creates all four blocks, so no existence check is needed.
        cfg_blocks = self.func.cfg.blocks
        self.func.cfg.blocks = {label: cfg_blocks[label] for label in ("entry", "else", "then", "merge")}

        # Collect branch info
        branches = opt._collect_branch_info(self.func)

        # Reorder blocks
        reordered = opt._reorder_blocks(self.func, branches)

        # Should reorder to put likely path (then) after entry
        assert reordered

        # Check new order has then block after entry
        new_order = list(self.func.cfg.blocks.keys())
        entry_idx = new_order.index("entry")
        then_idx = new_order.index("then")
        # Then should come right after entry for better fallthrough
        assert then_idx == entry_idx + 1

    def test_add_branch_hints(self) -> None:
        """Test adding branch hints."""
        opt = BranchPredictionOptimization()

        # Create predictable branch info
        last_inst = list(self.entry_block.instructions)[-1]
        assert isinstance(last_inst, ConditionalJump)
        branch_info = BranchInfo(
            block=self.entry_block,
            instruction=last_inst,
            taken_probability=0.98,
            is_predictable=True,
        )

        # Add branch hint (method takes single BranchInfo, not list)
        hint_added = opt._add_branch_hint(branch_info)

        # Check if hint was added (returns bool)
        assert isinstance(hint_added, bool)

    def test_convert_to_select(self) -> None:
        """Test converting predictable branches to select instructions."""
        opt = BranchPredictionOptimization()

        # Create simple if-then-else that can be converted
        func = MIRFunction("simple", [], MIRType.INT)
        entry = BasicBlock("entry")
        then_block = BasicBlock("then")
        else_block = BasicBlock("else")
        merge = BasicBlock("merge")

        cond = Temp(MIRType.BOOL)
        result = Temp(MIRType.INT)

        # Simple pattern: result = cond ? 1 : 0
        entry.add_instruction(LoadConst(cond, Constant(True, MIRType.BOOL), (1, 1)))
        entry.add_instruction(ConditionalJump(cond, "then", (1, 1), "else"))

        then_block.add_instruction(LoadConst(result, Constant(1, MIRType.INT), (1, 1)))
        then_block.add_instruction(Jump("merge", (1, 1)))

        else_block.add_instruction(LoadConst(result, Constant(0, MIRType.INT), (1, 1)))
        else_block.add_instruction(Jump("merge", (1, 1)))

        merge.add_instruction(Return((1, 1), result))

        func.cfg.add_block(entry)
        func.cfg.add_block(then_block)
        func.cfg.add_block(else_block)
        func.cfg.add_block(merge)
        func.cfg.entry_block = entry

        # Create highly predictable branch
        last_inst = list(entry.instructions)[-1]
        assert isinstance(last_inst, ConditionalJump)
        branch_info = BranchInfo(
            block=entry,
            instruction=last_inst,
            taken_probability=0.99,
            is_predictable=True,
        )

        # Try to convert (takes single BranchInfo)
        converted = opt._convert_to_select(branch_info)

        # Note: Actual conversion depends on implementation
        # This test verifies the method exists and runs
        assert isinstance(converted, bool)

    def test_detect_loop_headers(self) -> None:
        """Test detecting loop header branches."""
        # Create a function with a loop
        func = MIRFunction("loop_func", [], MIRType.INT)

        # Create loop structure
        entry = BasicBlock("entry")
        loop_header = BasicBlock("loop_header")
        loop_body = BasicBlock("loop_body")
        loop_exit = BasicBlock("loop_exit")

        i = Temp(MIRType.INT)
        cond = Temp(MIRType.BOOL)

        # Entry
        entry.add_instruction(LoadConst(i, Constant(0, MIRType.INT), (1, 1)))
        entry.add_instruction(Jump("loop_header", (1, 1)))

        # Loop header (condition check)
        loop_header.add_instruction(BinaryOp(cond, "<", i, Constant(10, MIRType.INT), (1, 1)))
        loop_header.add_instruction(ConditionalJump(cond, "loop_body", (1, 1), "loop_exit"))

        # Loop body
        loop_body.add_instruction(BinaryOp(i, "+", i, Constant(1, MIRType.INT), (1, 1)))
        loop_body.add_instruction(Jump("loop_header", (1, 1)))

        # Loop exit
        loop_exit.add_instruction(Return((1, 1), i))

        func.cfg.add_block(entry)
        func.cfg.add_block(loop_header)
        func.cfg.add_block(loop_body)
        func.cfg.add_block(loop_exit)
        func.cfg.entry_block = entry

        # Set up control flow edges
        entry.add_successor(loop_header)
        loop_header.add_predecessor(entry)

        loop_header.add_successor(loop_body)
        loop_header.add_successor(loop_exit)
        loop_body.add_predecessor(loop_header)
        loop_exit.add_predecessor(loop_header)

        loop_body.add_successor(loop_header)  # Back edge
        loop_header.add_predecessor(loop_body)

        opt = BranchPredictionOptimization()
        branches = opt._collect_branch_info(func)

        # Exactly one conditional branch lives in the loop header.
        loop_branches = [b for b in branches if b.block.label == "loop_header"]
        assert len(loop_branches) == 1

        # NOTE(review): this test does not yet assert loop_branches[0].is_loop_header.
        # Add that check once _collect_branch_info is confirmed to set the flag for
        # blocks that are targets of back edges.

    def test_run_on_module_with_profile(self) -> None:
        """Test running optimization with profile data."""
        profile = self._make_profile(0.97)
        profile.get_function_metrics = MagicMock(
            return_value={
                "call_count": 1000,
                "branches": {
                    ("entry", "then"): 970,
                    ("entry", "else"): 30,
                },
            }
        )

        opt = BranchPredictionOptimization(profile_data=profile)

        # Run optimization
        opt.run_on_module(self.module)

        # Should analyze branches
        assert opt.stats["branches_analyzed"] > 0

    def test_run_on_module_without_profile(self) -> None:
        """Test running optimization without profile data."""
        opt = BranchPredictionOptimization()

        # Run optimization
        opt.run_on_module(self.module)

        # Should still analyze branches with defaults
        assert opt.stats["branches_analyzed"] > 0

    def test_highly_biased_branch_optimization(self) -> None:
        """Test optimization of highly biased branches."""
        # Branch taken 99.9% of the time, above the 0.99 predictability threshold.
        opt = BranchPredictionOptimization(
            profile_data=self._make_profile(0.999),
            predictability_threshold=0.99,
        )

        branches = opt._collect_branch_info(self.func)

        # Branch should be marked as highly predictable (0.999 > 0.99)
        assert branches[0].is_predictable

        # Run full optimization on the module
        opt.run_on_module(self.module)

        # With a highly predictable branch, at least one optimization should apply.
        assert opt.stats["branches_analyzed"] > 0
        total_optimizations = (
            opt.stats.get("branch_hints_added", 0)
            + opt.stats.get("blocks_reordered", 0)
            + opt.stats.get("branches_converted_to_select", 0)
        )
        assert total_optimizations > 0

    def test_multiple_branches(self) -> None:
        """Test handling multiple branches in a function."""
        # Create function with multiple branches
        func = MIRFunction("multi_branch", [], MIRType.INT)

        blocks = []
        for i in range(5):
            block = BasicBlock(f"block_{i}")
            if i < 4:
                cond = Temp(MIRType.BOOL)
                block.add_instruction(LoadConst(cond, Constant(True, MIRType.BOOL), (1, 1)))
                block.add_instruction(ConditionalJump(cond, f"block_{i + 1}", (1, 1), "exit"))
            else:
                block.add_instruction(Return((1, 1), Constant(0, MIRType.INT)))
            blocks.append(block)
            func.cfg.add_block(block)

        exit_block = BasicBlock("exit")
        exit_block.add_instruction(Return((1, 1), Constant(1, MIRType.INT)))
        func.cfg.add_block(exit_block)

        func.cfg.entry_block = blocks[0]

        opt = BranchPredictionOptimization()
        branches = opt._collect_branch_info(func)

        # Should find 4 branches
        assert len(branches) == 4

    def test_edge_cases(self) -> None:
        """Test edge cases."""
        opt = BranchPredictionOptimization()

        # Empty function
        empty_func = MIRFunction("empty", [], MIRType.EMPTY)
        branches = opt._collect_branch_info(empty_func)
        assert len(branches) == 0

        # Function with no branches
        no_branch_func = MIRFunction("no_branch", [], MIRType.INT)
        block = BasicBlock("entry")
        block.add_instruction(Return((1, 1), Constant(42, MIRType.INT)))
        no_branch_func.cfg.add_block(block)
        no_branch_func.cfg.entry_block = block

        branches = opt._collect_branch_info(no_branch_func)
        assert len(branches) == 0
@@ -0,0 +1,168 @@
1
+ """Tests for improved call statement lowering."""
2
+
3
+ from machine_dialect.ast import (
4
+ Arguments,
5
+ CallStatement,
6
+ Identifier,
7
+ Program,
8
+ StringLiteral,
9
+ WholeNumberLiteral,
10
+ YesNoLiteral,
11
+ )
12
+ from machine_dialect.lexer.tokens import Token, TokenType
13
+ from machine_dialect.mir.hir_to_mir import lower_to_mir
14
+
15
+
16
class TestCallStatementLowering:
    """Test improved call statement handling.

    Each test builds one ``CallStatement`` AST node by hand and lowers a
    one-statement ``Program`` with ``lower_to_mir``. It then inspects the MIR
    ``Call`` instruction emitted into the ``__main__`` function.
    """

    @staticmethod
    def _lower_and_find_calls(call: CallStatement) -> list:
        """Lower ``call`` as a one-statement program and return its MIR ``Call`` instructions.

        Asserts that a ``__main__`` function was produced. Instructions are
        returned in block iteration order, then in instruction order.
        """
        program = Program(statements=[call])
        mir_module = lower_to_mir(program)

        main_func = mir_module.get_function("__main__")
        assert main_func is not None

        # Matched by class name so this file does not depend on the MIR
        # instruction module's import path.
        return [
            inst
            for block in main_func.cfg.blocks.values()
            for inst in block.instructions
            if type(inst).__name__ == "Call"
        ]

    @staticmethod
    def _assert_first_call_arg_count(calls: list, expected: int) -> None:
        """Assert that at least one ``Call`` exists and the first has ``expected`` args.

        The existence check guarantees no test passes vacuously when lowering
        emits no call at all.
        """
        assert calls, "Call instruction not found"
        first = calls[0]
        assert hasattr(first, "args") and len(first.args) == expected

    def test_call_with_positional_arguments(self) -> None:
        """Test call statement with positional arguments."""
        # Call `print` with "Hello", 42.
        args = Arguments(Token(TokenType.DELIM_LPAREN, "(", 0, 0))
        args.positional = [
            StringLiteral(Token(TokenType.LIT_TEXT, '"Hello"', 0, 0), '"Hello"'),
            WholeNumberLiteral(Token(TokenType.LIT_WHOLE_NUMBER, "42", 0, 0), 42),
        ]

        call = CallStatement(
            token=Token(TokenType.KW_USE, "use", 0, 0),
            function_name=StringLiteral(Token(TokenType.LIT_TEXT, '"print"', 0, 0), '"print"'),
            arguments=args,
        )

        calls = self._lower_and_find_calls(call)
        self._assert_first_call_arg_count(calls, 2)

    def test_call_with_named_arguments(self) -> None:
        """Test call statement with named arguments."""
        # Call `format` with name: "Alice", age: 30.
        args = Arguments(Token(TokenType.DELIM_LPAREN, "(", 0, 0))
        args.named = [
            (
                Identifier(Token(TokenType.MISC_IDENT, "name", 0, 0), "name"),
                StringLiteral(Token(TokenType.LIT_TEXT, '"Alice"', 0, 0), '"Alice"'),
            ),
            (
                Identifier(Token(TokenType.MISC_IDENT, "age", 0, 0), "age"),
                WholeNumberLiteral(Token(TokenType.LIT_WHOLE_NUMBER, "30", 0, 0), 30),
            ),
        ]

        call = CallStatement(
            token=Token(TokenType.KW_USE, "use", 0, 0),
            function_name=Identifier(Token(TokenType.MISC_IDENT, "format", 0, 0), "format"),
            arguments=args,
        )

        calls = self._lower_and_find_calls(call)
        # Named args are expected to be converted to positional call args.
        self._assert_first_call_arg_count(calls, 2)

    def test_call_with_mixed_arguments(self) -> None:
        """Test call with both positional and named arguments."""
        args = Arguments(Token(TokenType.DELIM_LPAREN, "(", 0, 0))
        args.positional = [StringLiteral(Token(TokenType.LIT_TEXT, '"test"', 0, 0), '"test"')]
        args.named = [
            (
                Identifier(Token(TokenType.MISC_IDENT, "verbose", 0, 0), "verbose"),
                YesNoLiteral(Token(TokenType.LIT_YES, "true", 0, 0), True),
            )
        ]

        call = CallStatement(
            token=Token(TokenType.KW_USE, "use", 0, 0),
            function_name=StringLiteral(Token(TokenType.LIT_TEXT, '"run"', 0, 0), '"run"'),
            arguments=args,
        )

        calls = self._lower_and_find_calls(call)
        # Both arguments should be present. The helper also asserts that a
        # Call exists, so this test can no longer pass vacuously.
        self._assert_first_call_arg_count(calls, 2)

    def test_call_with_single_argument(self) -> None:
        """Test call with single argument not wrapped in Arguments."""
        call = CallStatement(
            token=Token(TokenType.KW_USE, "use", 0, 0),
            function_name=StringLiteral(Token(TokenType.LIT_TEXT, '"print"', 0, 0), '"print"'),
            arguments=StringLiteral(Token(TokenType.LIT_TEXT, '"Hello World"', 0, 0), '"Hello World"'),
        )

        calls = self._lower_and_find_calls(call)
        # A bare expression argument should become a single call argument.
        self._assert_first_call_arg_count(calls, 1)

    def test_call_without_arguments(self) -> None:
        """Test call without any arguments."""
        call = CallStatement(
            token=Token(TokenType.KW_USE, "use", 0, 0),
            function_name=Identifier(Token(TokenType.MISC_IDENT, "exit", 0, 0), "exit"),
            arguments=None,
        )

        calls = self._lower_and_find_calls(call)
        # `arguments=None` should lower to a call with an empty argument list.
        self._assert_first_call_arg_count(calls, 0)