machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,318 @@
1
+ """Tests for optimization reporter."""
2
+
3
+ from machine_dialect.mir.reporting.optimization_reporter import (
4
+ ModuleMetrics,
5
+ OptimizationReporter,
6
+ PassMetrics,
7
+ )
8
+
9
+
10
class TestPassMetrics:
    """Unit tests for the ``PassMetrics`` record."""

    def test_pass_metrics_creation(self) -> None:
        """Values given to the constructor are readable back from the instance."""
        pm = PassMetrics(
            pass_name="constant-propagation",
            phase="early",
            metrics={"constants_propagated": 10},
            before_stats={"instructions": 100},
            after_stats={"instructions": 90},
            time_ms=5.5,
        )

        assert pm.pass_name == "constant-propagation"
        assert pm.phase == "early"
        assert pm.metrics["constants_propagated"] == 10
        assert pm.time_ms == 5.5

    def test_get_improvement(self) -> None:
        """get_improvement reports the percentage drop of a tracked statistic."""
        pm = PassMetrics(
            pass_name="dce",
            before_stats={"instructions": 100, "blocks": 10},
            after_stats={"instructions": 80, "blocks": 8},
        )

        # instructions 100 -> 80 and blocks 10 -> 8 are both 20% reductions;
        # a statistic that was never recorded reports no change at all.
        expected_percentages = {
            "instructions": 20.0,
            "blocks": 20.0,
            "missing": 0.0,
        }
        for stat_name, percent in expected_percentages.items():
            assert pm.get_improvement(stat_name) == percent

    def test_get_improvement_zero_before(self) -> None:
        """A zero baseline yields 0.0 instead of an error."""
        pm = PassMetrics(
            pass_name="test",
            before_stats={"count": 0},
            after_stats={"count": 10},
        )

        assert pm.get_improvement("count") == 0.0
55
+
56
+
57
class TestModuleMetrics:
    """Unit tests for the per-module ``ModuleMetrics`` aggregate."""

    def test_module_metrics_creation(self) -> None:
        """A new aggregate starts with no passes and zero accumulated time."""
        mm = ModuleMetrics(
            module_name="test_module",
            optimization_level=2,
        )

        assert mm.module_name == "test_module"
        assert mm.optimization_level == 2
        assert mm.total_time_ms == 0.0
        assert len(mm.pass_metrics) == 0

    def test_add_pass_metrics(self) -> None:
        """Each added pass is stored and its time is added to the total."""
        mm = ModuleMetrics("test")

        for record in (
            PassMetrics("pass1", time_ms=10.0),
            PassMetrics("pass2", time_ms=15.0),
        ):
            mm.add_pass_metrics(record)

        # 10.0 ms + 15.0 ms
        assert len(mm.pass_metrics) == 2
        assert mm.total_time_ms == 25.0

    def test_get_summary(self) -> None:
        """get_summary exposes header fields, pass names, metrics and improvements."""
        mm = ModuleMetrics(
            module_name="test",
            optimization_level=2,
        )

        # Two consecutive passes: 100 -> 95 -> 85 instructions, 2 ms + 3 ms.
        for record in (
            PassMetrics(
                pass_name="constant-propagation",
                metrics={"constants_propagated": 5},
                before_stats={"instructions": 100},
                after_stats={"instructions": 95},
                time_ms=2.0,
            ),
            PassMetrics(
                pass_name="dce",
                metrics={"dead_removed": 10},
                before_stats={"instructions": 95},
                after_stats={"instructions": 85},
                time_ms=3.0,
            ),
        ):
            mm.add_pass_metrics(record)

        report = mm.get_summary()

        # Module-level header fields.
        assert report["module_name"] == "test"
        assert report["optimization_level"] == 2
        assert report["total_passes"] == 2
        assert report["total_time_ms"] == 5.0

        applied = report["passes_applied"]
        for name in ("constant-propagation", "dce"):
            assert name in applied

        # Counters from both passes show up in the aggregated metrics.
        totals = report["total_metrics"]
        assert totals["constants_propagated"] == 5
        assert totals["dead_removed"] == 10

        # Only the presence of an "instructions" improvement is checked here.
        assert "instructions" in report["improvements"]
127
+
128
+
129
class TestOptimizationReporter:
    """Tests for the ``OptimizationReporter`` pass-tracking front end."""

    def test_reporter_creation(self) -> None:
        """A fresh reporter knows its module name and has no open pass."""
        rep = OptimizationReporter("my_module")

        assert rep.module_metrics.module_name == "my_module"
        assert rep.current_pass is None

    def test_pass_tracking(self) -> None:
        """start_pass opens a pass; end_pass records it and clears the slot."""
        rep = OptimizationReporter("test")

        rep.start_pass(
            "constant-propagation",
            phase="early",
            before_stats={"instructions": 100},
        )

        # While open, the pass is exposed through ``current_pass``.
        active = rep.current_pass
        assert active is not None
        assert active.pass_name == "constant-propagation"

        rep.end_pass(
            metrics={"constants_propagated": 5},
            after_stats={"instructions": 95},
            time_ms=2.5,
        )

        # Closing the pass empties the slot and appends exactly one record.
        assert rep.current_pass is None
        recorded = rep.module_metrics.pass_metrics
        assert len(recorded) == 1

        first = recorded[0]
        assert first.pass_name == "constant-propagation"
        assert first.metrics["constants_propagated"] == 5
        assert first.time_ms == 2.5

    def test_multiple_passes(self) -> None:
        """Sequential passes are all recorded and their times summed."""
        rep = OptimizationReporter("test")

        # (name, size before, size after, change count, elapsed ms)
        for name, size_before, size_after, changes, elapsed in (
            ("pass1", 1000, 900, 10, 5.0),
            ("pass2", 900, 850, 5, 3.0),
        ):
            rep.start_pass(name, before_stats={"size": size_before})
            rep.end_pass(
                metrics={"changes": changes},
                after_stats={"size": size_after},
                time_ms=elapsed,
            )

        assert len(rep.module_metrics.pass_metrics) == 2
        assert rep.module_metrics.total_time_ms == 8.0

    def test_function_metrics(self) -> None:
        """Per-function statistics are stored keyed by function name."""
        rep = OptimizationReporter("test")

        per_function = {
            "main": {"instructions": 50, "blocks": 5, "loops": 2},
            "helper": {"instructions": 20, "blocks": 2, "loops": 0},
        }
        for func_name, stats in per_function.items():
            rep.add_function_metrics(func_name, stats)

        stored = rep.module_metrics.function_metrics
        assert len(stored) == 2
        assert stored["main"]["loops"] == 2
        assert stored["helper"]["instructions"] == 20

    def test_optimization_level(self) -> None:
        """set_optimization_level is reflected in the module metrics."""
        rep = OptimizationReporter("test")
        rep.set_optimization_level(3)

        assert rep.module_metrics.optimization_level == 3

    def test_generate_summary(self) -> None:
        """The text summary lists header fields, pass names and metrics."""
        rep = OptimizationReporter("test_module")
        rep.set_optimization_level(2)

        # One pass shrinking 100 -> 90 instructions (a 10% reduction) in 5 ms.
        rep.start_pass(
            "constant-propagation",
            before_stats={"instructions": 100},
        )
        rep.end_pass(
            metrics={"constants_propagated": 10},
            after_stats={"instructions": 90},
            time_ms=5.0,
        )

        text = rep.generate_summary()

        for fragment in (
            "Module: test_module",
            "Optimization Level: 2",
            "Total Passes: 1",
            "Total Time: 5.00ms",
            "constant-propagation",
            "instructions: 10.0% reduction",
            "constants_propagated: 10",
        ):
            assert fragment in text

    def test_generate_detailed_report(self) -> None:
        """The detailed report has pass and function sections with their data."""
        rep = OptimizationReporter("test_module")

        # (name, size before, size after, "optimized" count, elapsed ms)
        for name, size_before, size_after, optimized, elapsed in (
            ("pass1", 1000, 950, 5, 2.0),
            ("pass2", 950, 920, 3, 1.5),
        ):
            rep.start_pass(name, before_stats={"size": size_before})
            rep.end_pass(
                metrics={"optimized": optimized},
                after_stats={"size": size_after},
                time_ms=elapsed,
            )

        rep.add_function_metrics(
            "main",
            {"complexity": 10, "lines": 50},
        )

        text = rep.generate_detailed_report()

        expected_fragments = (
            "OPTIMIZATION REPORT",
            "DETAILED PASS STATISTICS",
            "Pass: pass1",
            "Pass: pass2",
            "Time: 2.00ms",
            "Time: 1.50ms",
            "FUNCTION METRICS",
            "Function: main",
            "complexity: 10",
        )
        for fragment in expected_fragments:
            assert fragment in text

    def test_get_report_data(self) -> None:
        """get_report_data returns the underlying ModuleMetrics object."""
        rep = OptimizationReporter("test")

        rep.start_pass("test_pass")
        rep.end_pass(metrics={"changes": 5})

        snapshot = rep.get_report_data()

        assert isinstance(snapshot, ModuleMetrics)
        assert snapshot.module_name == "test"
        assert len(snapshot.pass_metrics) == 1

    def test_empty_reporter(self) -> None:
        """A reporter with no passes still produces a well-formed summary."""
        rep = OptimizationReporter("empty")

        text = rep.generate_summary()

        for fragment in (
            "Module: empty",
            "Total Passes: 0",
            "Total Time: 0.00ms",
        ):
            assert fragment in text

    def test_end_pass_without_start(self) -> None:
        """end_pass with no open pass is tolerated and records nothing."""
        rep = OptimizationReporter("test")

        # No preceding start_pass(): this call must not raise.
        rep.end_pass(metrics={"test": 1})

        assert len(rep.module_metrics.pass_metrics) == 0
@@ -0,0 +1,294 @@
1
+ """Tests for the pass manager and optimization framework."""
2
+
3
+ from machine_dialect.mir.basic_block import BasicBlock
4
+ from machine_dialect.mir.mir_function import MIRFunction
5
+ from machine_dialect.mir.mir_instructions import (
6
+ BinaryOp,
7
+ LoadConst,
8
+ Return,
9
+ )
10
+ from machine_dialect.mir.mir_module import MIRModule
11
+ from machine_dialect.mir.mir_types import MIRType
12
+ from machine_dialect.mir.mir_values import Constant, Temp
13
+ from machine_dialect.mir.optimization_config import (
14
+ OptimizationConfig,
15
+ OptimizationPipeline,
16
+ )
17
+ from machine_dialect.mir.optimizations import register_all_passes
18
+ from machine_dialect.mir.pass_manager import PassManager
19
+
20
+
21
def create_test_module() -> MIRModule:
    """Build a one-function module containing easy optimization targets.

    The ``main`` function has a single ``entry`` block. It computes
    ``(2 + 3) * 4`` through the temporaries ``t0``..``t4`` and returns ``t4``.
    The addition is a constant-folding target and the multiplication by 4 is
    a strength-reduction target.

    Returns:
        A test MIR module.
    """
    module = MIRModule("test")

    main_func = MIRFunction("main")
    entry = BasicBlock("entry")
    main_func.cfg.add_block(entry)
    main_func.cfg.set_entry_block(entry)

    t0, t1, t2, t3, t4 = (Temp(MIRType.INT, index) for index in range(5))
    where = (1, 1)

    body = (
        # Constant folding opportunity: 2 + 3 = 5
        LoadConst(t0, Constant(2, MIRType.INT), where),
        LoadConst(t1, Constant(3, MIRType.INT), where),
        BinaryOp(t2, "+", t0, t1, where),
        # Strength reduction opportunity: x * 4 -> x << 2
        LoadConst(t3, Constant(4, MIRType.INT), where),
        BinaryOp(t4, "*", t2, t3, where),
        # Return result
        Return(where, t4),
    )
    for instruction in body:
        entry.add_instruction(instruction)

    module.add_function(main_func)
    return module
58
+
59
+
60
def test_pass_manager_creation() -> None:
    """A default PassManager exposes its registry, analysis manager and scheduler."""
    manager = PassManager()
    assert manager is not None

    for component in ("registry", "analysis_manager", "scheduler"):
        assert getattr(manager, component) is not None
67
+
68
+
69
def test_register_passes() -> None:
    """register_all_passes makes the expected pass names visible in the registry."""
    manager = PassManager()
    register_all_passes(manager)

    available = manager.registry.list_passes()
    required = (
        "dce",
        "constant-propagation",
        "cse",
        "strength-reduction",
        "use-def-chains",
        "loop-analysis",
    )
    # Collect every name that is absent so a failure lists them all.
    missing = [name for name in required if name not in available]
    assert not missing
82
+
83
+
84
def test_optimization_pipeline() -> None:
    """Check the pass lists produced for optimization levels 0 through 3."""
    by_level = {
        level: OptimizationPipeline.get_passes(OptimizationConfig.from_level(level))
        for level in (0, 1, 2, 3)
    }

    # -O0 schedules nothing.
    assert len(by_level[0]) == 0  # No optimizations at -O0

    # -O1 contains the basic cleanups.
    assert "constant-propagation" in by_level[1]
    assert "dce" in by_level[1]

    # -O2 adds CSE and is strictly longer than -O1.
    assert "cse" in by_level[2]
    assert len(by_level[2]) > len(by_level[1])

    # -O3 is at least as long as -O2.
    assert len(by_level[3]) >= len(by_level[2])
103
+
104
+
105
def test_run_optimizations() -> None:
    """Run the -O1 pipeline on the sample module and inspect the statistics."""
    module = create_test_module()
    pm = PassManager()
    register_all_passes(pm)

    level1 = OptimizationConfig.from_level(1)
    pipeline = OptimizationPipeline.get_passes(level1)

    changed = pm.run_passes(module, pipeline, level1.level)
    # The module is expected to change unless the pipeline is empty.
    assert changed or len(pipeline) == 0

    assert isinstance(pm.get_statistics(), dict)
122
+
123
+
124
def test_constant_propagation() -> None:
    """Run ConstantPropagation directly on ``main`` and require a stats result."""
    from machine_dialect.mir.optimizations.constant_propagation import (
        ConstantPropagation,
    )

    module = create_test_module()
    target = module.get_function("main")
    assert target is not None

    propagation = ConstantPropagation()
    propagation.initialize()
    propagation.run_on_function(target)
    propagation.finalize()

    # The exact counters depend on the implementation; only require a result.
    assert propagation.get_stats() is not None
144
+
145
+
146
def test_dead_code_elimination() -> None:
    """Run DCE on a block whose ``t1 = t0 + t0`` result is never read."""
    from machine_dialect.mir.analyses.use_def_chains import UseDefChainsAnalysis
    from machine_dialect.mir.optimizations.dce import DeadCodeElimination

    func = MIRFunction("test")
    entry = BasicBlock("entry")
    func.cfg.add_block(entry)
    func.cfg.set_entry_block(entry)

    live, dead = Temp(MIRType.INT, 0), Temp(MIRType.INT, 1)
    for instr in (
        LoadConst(live, Constant(42, MIRType.INT), (1, 1)),
        BinaryOp(dead, "+", live, live, (1, 1)),  # result never used
        Return((1, 1), live),  # only t0 reaches the return
    ):
        entry.add_instruction(instr)

    # Precompute use-def chains so the stub below can hand them to DCE.
    chains = UseDefChainsAnalysis().run_on_function(func)

    dce = DeadCodeElimination()
    # Anonymous stand-in for the analysis manager: whatever is requested
    # through get_analysis(name, function), it returns the chains above.
    stub_manager = type(
        "",
        (),
        {"get_analysis": lambda _self, _name, _func: chains},
    )()
    dce.analysis_manager = stub_manager
    dce.initialize()
    changed = dce.run_on_function(func)
    dce.finalize()

    # NOTE(review): this passes whenever run_on_function reports any
    # modification, without checking that the BinaryOp was the instruction
    # removed.
    assert changed or len(entry.instructions) == 2
177
+
178
+
179
def test_strength_reduction() -> None:
    """Run StrengthReduction on a multiplication by a power of two (``t0 * 8``)."""
    from machine_dialect.mir.optimizations.strength_reduction import StrengthReduction

    fn = MIRFunction("test")
    block = BasicBlock("entry")
    fn.cfg.add_block(block)
    fn.cfg.set_entry_block(block)

    # t0 = 10; t1 = t0 * 8; return t1
    lhs = Temp(MIRType.INT, 0)
    product = Temp(MIRType.INT, 1)
    for instruction in (
        LoadConst(lhs, Constant(10, MIRType.INT), (1, 1)),
        BinaryOp(product, "*", lhs, Constant(8, MIRType.INT), (1, 1)),
        Return((1, 1), product),
    ):
        block.add_instruction(instruction)

    reducer = StrengthReduction()
    reducer.initialize()
    reducer.run_on_function(fn)
    reducer.finalize()

    # Smoke check: the multiply-to-shift counter is read (defaulting to 0);
    # any non-negative value is accepted.
    counters = reducer.get_stats()
    assert counters.get("multiply_to_shift", 0) >= 0
+
205
+
206
def test_cse() -> None:
    """Run common subexpression elimination on a block that computes ``t0 + t1`` twice."""
    from machine_dialect.mir.optimizations.cse import CommonSubexpressionElimination

    fn = MIRFunction("test")
    block = BasicBlock("entry")
    fn.cfg.add_block(block)
    fn.cfg.set_entry_block(block)

    a, b, first_sum, second_sum = (Temp(MIRType.INT, idx) for idx in range(4))

    program = [
        LoadConst(a, Constant(5, MIRType.INT), (1, 1)),
        LoadConst(b, Constant(7, MIRType.INT), (1, 1)),
        BinaryOp(first_sum, "+", a, b, (1, 1)),  # first evaluation
        BinaryOp(second_sum, "+", a, b, (1, 1)),  # identical re-evaluation
        Return((1, 1), second_sum),
    ]
    for inst in program:
        block.add_instruction(inst)

    eliminator = CommonSubexpressionElimination()
    eliminator.initialize()
    eliminator.run_on_function(fn)
    eliminator.finalize()

    # Smoke check: the local-CSE counter is read (defaulting to 0);
    # any non-negative value is accepted.
    counters = eliminator.get_stats()
    assert counters.get("local_cse_eliminated", 0) >= 0
+
237
+
238
def test_analysis_caching() -> None:
    """Check that the analysis manager returns a cached result until it is invalidated."""
    from machine_dialect.mir.analyses.use_def_chains import UseDefChainsAnalysis

    pm = PassManager()
    register_all_passes(pm)

    test_module = create_test_module()
    main_fn = test_module.get_function("main")
    assert main_fn is not None  # Narrows the Optional for mypy

    key = "use-def-chains"
    use_def = UseDefChainsAnalysis()
    pm.analysis_manager.register_analysis(key, use_def)

    # The first lookup computes the analysis.
    first = pm.analysis_manager.get_analysis(key, main_fn)
    assert first is not None

    # A repeated lookup must hand back the very same cached object.
    assert pm.analysis_manager.get_analysis(key, main_fn) is first

    # After invalidation a result must still be available; whether it is a
    # fresh object is left to the implementation.
    pm.analysis_manager.invalidate([key])
    assert pm.analysis_manager.get_analysis(key, main_fn) is not None
+
267
+
268
def test_full_optimization_pipeline() -> None:
    """Run the O2 pipeline end to end and check it does not grow the instruction count."""
    module = create_test_module()
    pm = PassManager()
    register_all_passes(pm)

    def count_instructions() -> int:
        """Total number of instructions over every block of every function in ``module``."""
        total = 0
        for function in module.functions.values():
            for block in function.cfg.blocks.values():
                total += len(block.instructions)
        return total

    before = count_instructions()

    o2 = OptimizationConfig.from_level(2)
    pipeline = OptimizationPipeline.get_passes(o2)
    pm.run_passes(module, pipeline, o2.level)

    # Optimization must leave the same number of instructions or fewer.
    assert count_instructions() <= before

    # Statistics are only required to be non-empty when the pipeline is non-empty.
    stats = pm.get_statistics()
    if pipeline:
        assert len(stats) > 0
@@ -0,0 +1,64 @@
1
+ """Test that all passes are properly registered."""
2
+
3
+ import importlib
4
+ import inspect
5
+ import pkgutil
6
+
7
+ from machine_dialect.mir.optimization_pass import AnalysisPass, OptimizationPass, Pass
8
+ from machine_dialect.mir.optimizations import register_all_passes
9
+ from machine_dialect.mir.pass_manager import PassManager
10
+
11
+
12
class TestPassRegistration:
    """Test that every concrete pass class is registered by register_all_passes()."""

    @staticmethod
    def _discover_pass_names() -> set[str]:
        """Collect ``get_info().name`` of every concrete Pass subclass in the MIR pass packages.

        Only the direct, non-package child modules of ``machine_dialect.mir.analyses``
        and ``machine_dialect.mir.optimizations`` are scanned; sub-packages are skipped.
        Modules that fail to import with ``ImportError`` are skipped. Classes that cannot
        be instantiated without arguments, or that lack a usable ``get_info()``, are also
        skipped.

        Returns:
            The set of pass names reported by the discovered classes.
        """
        import machine_dialect.mir.analyses as analyses_pkg
        import machine_dialect.mir.optimizations as opt_pkg

        base_classes = (Pass, OptimizationPass, AnalysisPass)
        discovered: set[str] = set()

        for pkg in (analyses_pkg, opt_pkg):
            for _importer, modname, ispkg in pkgutil.iter_modules(pkg.__path__, prefix=pkg.__name__ + "."):
                if ispkg:
                    continue
                try:
                    module = importlib.import_module(modname)
                except ImportError:
                    # Some modules might have additional (optional) dependencies.
                    continue

                for name, obj in inspect.getmembers(module, inspect.isclass):
                    # Only concrete Pass subclasses, excluding the framework base classes.
                    if not issubclass(obj, Pass) or obj in base_classes or inspect.isabstract(obj):
                        continue
                    # Skip classes bound in this module under a name that ends in "Pass"
                    # but differs from the class's own __name__ (i.e. bound under an alias).
                    if name.endswith("Pass") and name != obj.__name__:
                        continue
                    try:
                        # Instantiate with no arguments to read the pass's registered name.
                        discovered.add(obj().get_info().name)
                    except (TypeError, AttributeError, ImportError):
                        # Required constructor args, no usable get_info(), or a missing
                        # optional dependency: skip just this class.
                        continue

        return discovered

    def test_all_passes_registered(self) -> None:
        """Verify that all discovered concrete Pass subclasses are registered."""
        pm = PassManager()
        register_all_passes(pm)
        # Reads the registry's private ``_passes`` mapping to list the registered names.
        registered_names = set(pm.registry._passes.keys())

        unregistered = self._discover_pass_names() - registered_names

        assert unregistered == set(), (
            f"Found unregistered passes: {sorted(unregistered)}. "
            "Add them to register_all_passes() in mir/optimizations/__init__.py"
        )

        # Sanity floor: require strictly more than 10 registered passes.
        assert len(registered_names) > 10, f"Expected more than 10 registered passes, found {len(registered_names)}"