machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,356 @@
1
+ """Tests for profiling infrastructure."""
2
+
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
+ from machine_dialect.mir.profiling import (
7
+ BranchProfile,
8
+ FunctionProfile,
9
+ LoopProfile,
10
+ ProfileCollector,
11
+ ProfileData,
12
+ ProfileReader,
13
+ ProfileWriter,
14
+ )
15
+
16
+
17
class TestProfileData:
    """Exercise the individual profile records and the ``ProfileData`` container."""

    def test_function_profile(self) -> None:
        """A fresh function profile is not hot; updating stats after many calls marks it hot."""
        func = FunctionProfile(name="test_func")
        assert func.name == "test_func"
        assert func.call_count == 0
        assert not func.hot

        # Supply raw totals, then ask the profile to recompute its derived fields.
        func.call_count = 150
        func.total_cycles = 15000
        func.update_stats()

        assert func.hot
        assert func.avg_cycles == 100.0  # 15000 cycles / 150 calls
        assert func.inline_benefit > 0

    def test_branch_profile(self) -> None:
        """A heavily biased branch is reported as predictable after updating stats."""
        branch = BranchProfile(location="func:block:1")
        assert branch.location == "func:block:1"
        assert not branch.predictable

        # 95 taken vs. 5 not taken: a strongly biased branch.
        branch.taken_count = 95
        branch.not_taken_count = 5
        branch.update_stats()

        assert branch.predictable
        assert branch.taken_probability == 0.95

    def test_loop_profile(self) -> None:
        """Recording loop executions accumulates entry, total, average, max and min counts."""
        loop = LoopProfile(location="func:loop1")
        assert loop.location == "func:loop1"
        assert not loop.hot

        # Three separate loop executions with differing trip counts.
        for trip_count in (10, 12, 8):
            loop.record_iteration(trip_count)

        assert loop.entry_count == 3
        assert loop.total_iterations == 30
        assert loop.avg_iterations == 10.0
        assert loop.max_iterations == 12
        assert loop.min_iterations == 8

    def test_profile_data_merge(self) -> None:
        """Merging sums counters for shared entries and brings over new ones."""
        base = ProfileData(module_name="test")
        base.functions["func1"] = FunctionProfile(name="func1", call_count=10)
        base.branches["branch1"] = BranchProfile(location="branch1", taken_count=5, not_taken_count=3)

        incoming = ProfileData(module_name="test")
        incoming.functions["func1"] = FunctionProfile(name="func1", call_count=5)
        incoming.functions["func2"] = FunctionProfile(name="func2", call_count=3)
        incoming.branches["branch1"] = BranchProfile(location="branch1", taken_count=2, not_taken_count=1)

        base.merge(incoming)

        assert base.functions["func1"].call_count == 15  # 10 + 5
        assert "func2" in base.functions
        merged_branch = base.branches["branch1"]
        assert merged_branch.taken_count == 7  # 5 + 2
        assert merged_branch.not_taken_count == 4  # 3 + 1

    def test_hot_function_detection(self) -> None:
        """Only functions above the call-count threshold are returned as hot."""
        data = ProfileData(module_name="test")
        data.functions["cold"] = FunctionProfile(name="cold", call_count=10)
        data.functions["hot"] = FunctionProfile(name="hot", call_count=200)
        data.functions["hot"].update_stats()

        detected = data.get_hot_functions(threshold=100)
        assert "hot" in detected
        assert "cold" not in detected
95
+
96
+
97
class TestProfileCollector:
    """Exercise runtime collection of function, branch and loop events."""

    def test_collector_initialization(self) -> None:
        """A new collector carries the module name, is disabled, and samples every event."""
        fresh = ProfileCollector("test_module")
        assert fresh.profile_data.module_name == "test_module"
        assert not fresh.enabled
        assert fresh.sampling_rate == 1

    def test_function_profiling(self) -> None:
        """One enter/exit pair records a single call and its call site."""
        collector = ProfileCollector()
        collector.enable()

        # A single call of "test_func" originating from call site "main:10".
        collector.enter_function("test_func", "main:10")
        collector.exit_function("test_func")

        collected = collector.get_profile_data()
        assert "test_func" in collected.functions
        recorded = collected.functions["test_func"]
        assert recorded.call_count == 1
        assert "main:10" in recorded.call_sites

    def test_branch_profiling(self) -> None:
        """Taken and not-taken executions are counted separately per branch location."""
        collector = ProfileCollector()
        collector.enable()

        # Two taken executions followed by one not-taken execution.
        for outcome in (True, True, False):
            collector.record_branch("func:block:1", taken=outcome)

        collected = collector.get_profile_data()
        assert "func:block:1" in collected.branches
        recorded = collected.branches["func:block:1"]
        assert recorded.taken_count == 2
        assert recorded.not_taken_count == 1

    def test_loop_profiling(self) -> None:
        """Iterations recorded between enter and exit are attributed to that loop."""
        collector = ProfileCollector()
        collector.enable()

        # One loop execution consisting of five iterations.
        collector.enter_loop("func:loop1")
        for _ in range(5):
            collector.record_loop_iteration()
        collector.exit_loop("func:loop1")

        collected = collector.get_profile_data()
        assert "func:loop1" in collected.loops
        recorded = collected.loops["func:loop1"]
        assert recorded.entry_count == 1
        assert recorded.total_iterations == 5

    def test_sampling(self) -> None:
        """With a sampling rate of 2, only one of two function calls is recorded."""
        collector = ProfileCollector()
        collector.enable(sampling_rate=2)  # Sample every 2nd event

        # Two distinct calls; sampling should let only one of them through.
        for callee in ("func1", "func2"):
            collector.enter_function(callee)
            collector.exit_function(callee)

        collected = collector.get_profile_data()
        # Due to sampling, only one function should be recorded
        assert len(collected.functions) == 1

    def test_hot_path_hints(self) -> None:
        """Hints list a frequently called function and a strongly biased branch."""
        collector = ProfileCollector()
        collector.enable()

        # 200 calls of the same function.
        for _ in range(200):
            collector.enter_function("hot_func")
            collector.exit_function("hot_func")

        # 100 taken executions followed by 5 not-taken ones: a strongly biased branch.
        for outcome, repeats in ((True, 100), (False, 5)):
            for _ in range(repeats):
                collector.record_branch("predictable", taken=outcome)

        hints = collector.get_hot_path_hints()
        assert "hot_func" in hints["hot_functions"]
        assert "predictable" in hints["predictable_branches"]
189
+
190
+
191
class TestProfilePersistence:
    """Exercise writing profiles to disk and reading them back."""

    def test_json_roundtrip(self) -> None:
        """Function, branch and loop data survive a JSON write/read cycle."""
        original = ProfileData(module_name="test")
        original.functions["func1"] = FunctionProfile(name="func1", call_count=100, total_cycles=1000)
        original.branches["branch1"] = BranchProfile(location="branch1", taken_count=75, not_taken_count=25)
        original.loops["loop1"] = LoopProfile(location="loop1", entry_count=10)

        # Reserve a path only; the handle is closed before the writer touches the file.
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as handle:
            target = Path(handle.name)

        try:
            ProfileWriter().write_json(original, target)
            restored = ProfileReader().read_json(target)

            assert restored.module_name == "test"
            assert "func1" in restored.functions
            assert restored.functions["func1"].call_count == 100
            assert "branch1" in restored.branches
            assert restored.branches["branch1"].taken_count == 75
            assert "loop1" in restored.loops
            assert restored.loops["loop1"].entry_count == 10
        finally:
            # delete=False above means cleanup is our responsibility.
            target.unlink()

    def test_binary_roundtrip(self) -> None:
        """Module name, sample total and function counts survive a binary write/read cycle."""
        original = ProfileData(module_name="test")
        original.functions["func1"] = FunctionProfile(name="func1", call_count=50)
        original.total_samples = 1000

        # Reserve a path only; the handle is closed before the writer touches the file.
        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as handle:
            target = Path(handle.name)

        try:
            ProfileWriter().write_binary(original, target)
            restored = ProfileReader().read_binary(target)

            assert restored.module_name == "test"
            assert restored.total_samples == 1000
            assert restored.functions["func1"].call_count == 50
        finally:
            target.unlink()

    def test_auto_format_detection(self) -> None:
        """``read_auto`` loads both a JSON-written and a binary-written profile."""
        original = ProfileData(module_name="test")
        original.functions["func1"] = FunctionProfile(name="func1")

        reader = ProfileReader()
        writer = ProfileWriter()

        # JSON first, then binary; each case gets its own temporary file.
        for suffix, dump in ((".json", writer.write_json), (".pkl", writer.write_binary)):
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
                target = Path(handle.name)

            try:
                dump(original, target)
                restored = reader.read_auto(target)
                assert restored.module_name == "test"
            finally:
                target.unlink()

    def test_profile_merging(self) -> None:
        """Merging two profile files sums shared counters and keeps unique functions."""
        first = ProfileData(module_name="test")
        first.functions["func1"] = FunctionProfile(name="func1", call_count=10)

        second = ProfileData(module_name="test")
        second.functions["func1"] = FunctionProfile(name="func1", call_count=5)
        second.functions["func2"] = FunctionProfile(name="func2", call_count=3)

        with tempfile.TemporaryDirectory() as workdir:
            first_path = Path(workdir) / "profile1.json"
            second_path = Path(workdir) / "profile2.json"

            writer = ProfileWriter()
            writer.write_json(first, first_path)
            writer.write_json(second, second_path)

            combined = ProfileReader().merge_profiles([first_path, second_path])

            assert combined.functions["func1"].call_count == 15  # 10 + 5
            assert combined.functions["func2"].call_count == 3

    def test_summary_generation(self) -> None:
        """The text summary mentions the module, the hot function and the branch."""
        data = ProfileData(module_name="test")

        # A frequently called function, with derived stats computed before it is stored.
        busy = FunctionProfile(name="hot_func", call_count=1000, total_cycles=100000)
        busy.update_stats()
        data.functions["hot_func"] = busy

        # A branch taken 95 times out of 100.
        biased = BranchProfile(location="branch1", taken_count=95, not_taken_count=5)
        biased.update_stats()
        data.branches["branch1"] = biased

        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as handle:
            target = Path(handle.name)

        try:
            ProfileWriter().write_summary(data, target)

            report = target.read_text()
            assert "test" in report
            assert "hot_func" in report
            assert "branch1" in report
            # Either rendering of the taken probability is accepted.
            assert "95.0%" in report or "Taken: 0.95" in report
        finally:
            target.unlink()

    def test_validation(self) -> None:
        """Validation passes a consistent profile and warns about cycles without calls."""
        reader = ProfileReader()

        # Consistent profile: calls, cycles and samples are all present.
        consistent = ProfileData(module_name="test")
        consistent.functions["func1"] = FunctionProfile(name="func1", call_count=10, total_cycles=100)
        consistent.functions["func1"].update_stats()
        consistent.total_samples = 10  # Add samples to make it valid

        issues = reader.validate_profile(consistent)
        assert len(issues) == 0

        # Inconsistent profile: cycles were recorded but the function was never called.
        inconsistent = ProfileData(module_name="test")
        inconsistent.functions["func1"] = FunctionProfile(name="func1", call_count=0, total_cycles=100)
        inconsistent.total_samples = 1  # Add samples but still invalid

        issues = reader.validate_profile(inconsistent)
        assert len(issues) > 0
        assert "cycles but no calls" in issues[0]
@@ -0,0 +1,307 @@
1
+ """Tests for virtual register allocation."""
2
+
3
+ from machine_dialect.ast import (
4
+ Expression,
5
+ Identifier,
6
+ InfixExpression,
7
+ Program,
8
+ ReturnStatement,
9
+ SetStatement,
10
+ WholeNumberLiteral,
11
+ )
12
+ from machine_dialect.lexer.tokens import Token, TokenType
13
+ from machine_dialect.mir.hir_to_mir import lower_to_mir
14
+ from machine_dialect.mir.register_allocation import (
15
+ LifetimeAnalyzer,
16
+ RegisterAllocator,
17
+ )
18
+
19
+
20
class TestRegisterAllocation:
    """Test virtual register allocation.

    Each test builds a small AST by hand, lowers it with ``lower_to_mir`` and
    runs ``RegisterAllocator`` or ``LifetimeAnalyzer`` on the ``__main__``
    function of the resulting module. The private ``_``-prefixed helpers only
    construct AST nodes; every token they create uses position (0, 0).
    """

    def _create_infix(self, left: Expression, op: str, right: Expression) -> InfixExpression:
        """Build an InfixExpression for ``left op right``.

        Args:
            left: Left operand.
            op: Operator literal; only "+" and "*" are supported.
            right: Right operand.

        Returns:
            The infix expression with both operands attached.

        Raises:
            ValueError: If ``op`` is not one of the supported operators.
        """
        if op == "+":
            token_type = TokenType.OP_PLUS
        elif op == "*":
            token_type = TokenType.OP_STAR
        else:
            # Refuse unknown operators instead of silently tagging them as
            # OP_STAR, which would give a token whose type and literal disagree.
            raise ValueError(f"Unsupported infix operator for this helper: {op!r}")

        expr = InfixExpression(self._token(token_type, op), op, left)
        # The right operand is assigned after construction; the constructor is
        # only given the token, the operator and the left operand.
        expr.right = right
        return expr

    def _token(self, token_type: TokenType, value: str = "") -> Token:
        """Create a token for testing."""
        return Token(token_type, value, 0, 0)

    def _ident(self, name: str) -> Identifier:
        """Create an Identifier node called ``name`` with its own token."""
        return Identifier(self._token(TokenType.MISC_IDENT, name), name)

    def _number(self, value: int) -> WholeNumberLiteral:
        """Create a WholeNumberLiteral node whose token literal is ``str(value)``."""
        return WholeNumberLiteral(self._token(TokenType.LIT_WHOLE_NUMBER, str(value)), value)

    def _set(self, name: str, value: Expression) -> SetStatement:
        """Create a SetStatement assigning ``value`` to the variable ``name``."""
        return SetStatement(self._token(TokenType.KW_SET, "set"), self._ident(name), value)

    def _return(self, value: Expression) -> ReturnStatement:
        """Create a ReturnStatement returning ``value``."""
        return ReturnStatement(self._token(TokenType.KW_RETURN, "return"), return_value=value)

    def test_basic_register_allocation(self) -> None:
        """Test basic register allocation for simple function."""
        # AST for: x = 10, y = 20, z = x + y, return z
        program = Program(
            statements=[
                self._set("x", self._number(10)),
                self._set("y", self._number(20)),
                self._set("z", self._create_infix(self._ident("x"), "+", self._ident("y"))),
                self._return(self._ident("z")),
            ]
        )

        mir_module = lower_to_mir(program)
        main_func = mir_module.get_function("__main__")
        assert main_func is not None

        # Perform register allocation with the allocator's default settings
        allocation = RegisterAllocator(main_func).allocate()

        # Check that values got allocated
        assert len(allocation.allocations) > 0
        # Check that we didn't use too many registers
        assert allocation.max_registers < 256
        # No spilling needed for simple function
        assert len(allocation.spilled_values) == 0

    def test_register_allocation_with_spilling(self) -> None:
        """Test register allocation with spilling when registers are limited."""
        num_vars = 15
        max_registers = 8  # deliberately fewer registers than variables

        # var_0 = 0, var_1 = 1, ..., var_14 = 14
        statements = [self._set(f"var_{i}", self._number(i)) for i in range(num_vars)]

        # Build var_0 + var_1 + ... + var_14 as one left-nested expression so
        # that every variable is read after all of them have been assigned;
        # this is intended to make their lifetimes overlap.
        result_expr: Expression = self._ident("var_0")
        for i in range(1, num_vars):
            result_expr = self._create_infix(result_expr, "+", self._ident(f"var_{i}"))

        statements.append(self._set("result", result_expr))

        program = Program(statements=statements)  # type: ignore[arg-type]
        mir_module = lower_to_mir(program)
        main_func = mir_module.get_function("__main__")
        assert main_func is not None

        # Allocate with very limited registers (less than the number of overlapping values)
        allocator = RegisterAllocator(main_func, max_registers=max_registers)
        allocation = allocator.allocate()

        # Check that some values were spilled (15+ values but only 8 registers)
        assert len(allocation.spilled_values) > 0, (
            f"Expected spilling but got none. Allocated: {allocation.max_registers} registers"
        )
        # Should use most of the available registers
        assert 5 <= allocation.max_registers <= max_registers

    def test_lifetime_analysis(self) -> None:
        """Test lifetime analysis for temporaries."""
        program = Program(
            statements=[
                self._set("a", self._number(1)),
                self._set("b", self._number(2)),
                self._set("c", self._create_infix(self._ident("a"), "+", self._ident("b"))),
                # 'a' and 'b' not used after this point
                self._set("d", self._number(3)),
                self._set("e", self._create_infix(self._ident("c"), "*", self._ident("d"))),
                self._return(self._ident("e")),
            ]
        )

        mir_module = lower_to_mir(program)
        main_func = mir_module.get_function("__main__")
        assert main_func is not None

        # Analyze lifetimes
        lifetimes = LifetimeAnalyzer(main_func).analyze()

        # Check that we have lifetime info
        assert len(lifetimes) > 0

        # Only the well-formedness of each (start, end) range is asserted;
        # the relative lengths of the lifetimes are not checked here.
        for start, end in lifetimes.values():
            assert start <= end  # Valid lifetime range

    def test_reusable_slot_detection(self) -> None:
        """Test detection of reusable stack slots."""
        # Two independent groups of temporaries: temp1/temp2 feed result1 and
        # are not read afterwards, then temp3/temp4 feed result2.
        program = Program(
            statements=[
                # First set of variables
                self._set("temp1", self._number(1)),
                self._set("temp2", self._number(2)),
                self._set(
                    "result1",
                    self._create_infix(self._ident("temp1"), "+", self._ident("temp2")),
                ),
                # temp1 and temp2 dead after this
                # Second set of variables (can reuse slots)
                self._set("temp3", self._number(3)),
                self._set("temp4", self._number(4)),
                self._set(
                    "result2",
                    self._create_infix(self._ident("temp3"), "+", self._ident("temp4")),
                ),
                self._return(
                    self._create_infix(self._ident("result1"), "+", self._ident("result2")),
                ),
            ]
        )

        mir_module = lower_to_mir(program)
        main_func = mir_module.get_function("__main__")
        assert main_func is not None

        # Analyze lifetimes first, then ask the same analyzer for reusable slots
        analyzer = LifetimeAnalyzer(main_func)
        lifetimes = analyzer.analyze()
        reusable_groups = analyzer.find_reusable_slots()

        # Should identify some values that can share slots
        assert len(reusable_groups) > 0

        # The number of slot groups must not exceed the number of tracked
        # values (the check allows equality, i.e. it does not require reuse).
        total_slots_needed = len(reusable_groups)
        total_variables = len(lifetimes)
        assert total_slots_needed <= total_variables

    def test_linear_scan_ordering(self) -> None:
        """Test that linear scan processes intervals in correct order."""
        program = Program(
            statements=[
                self._set("first", self._number(1)),
                self._set("second", self._number(2)),
                self._set("third", self._number(3)),
            ]
        )

        mir_module = lower_to_mir(program)
        main_func = mir_module.get_function("__main__")
        assert main_func is not None

        allocator = RegisterAllocator(main_func)

        # Build intervals by calling the allocator's internal steps directly
        allocator._build_instruction_positions()
        allocator._compute_live_intervals()

        # Check intervals are created
        assert len(allocator.live_intervals) > 0

        # Check intervals are sorted by start position (non-decreasing)
        starts = [interval.start for interval in allocator.live_intervals]
        assert starts == sorted(starts)