machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,302 @@
1
+ """Virtual register allocation for MIR.
2
+
3
+ This module implements register allocation using the linear scan algorithm.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+
8
+ from machine_dialect.mir.mir_function import MIRFunction
9
+ from machine_dialect.mir.mir_instructions import MIRInstruction
10
+ from machine_dialect.mir.mir_values import MIRValue, Temp, Variable
11
+
12
+
13
@dataclass
class LiveInterval:
    """Range of program positions over which a single value is tracked.

    Attributes:
        value: The MIR value the interval belongs to.
        start: Position at which the interval begins.
        end: Position at which the interval ends.
        register: Register/slot number assigned to the value; ``None``
            until one has been assigned.
    """

    value: MIRValue
    start: int
    end: int
    register: int | None = None
28
+
29
+
30
@dataclass
class RegisterAllocation:
    """Outcome of a register allocation run.

    Attributes:
        allocations: Register number chosen for each MIR value.
        spilled_values: Values that have to be spilled to memory.
        max_registers: Maximum number of registers used.
    """

    allocations: dict[MIRValue, int]
    spilled_values: set[MIRValue]
    max_registers: int
43
+
44
+
45
class RegisterAllocator:
    """Allocates virtual registers for MIR values using linear scan.

    Live intervals are derived from the order in which the function's
    blocks and instructions are iterated. Intervals are scanned by start
    position and each one receives the lowest free register. When no
    register is free, either the incoming interval or the active interval
    that ends last is spilled; spilled values are recorded in
    ``spilled_values`` and mapped to a negative slot number.
    """

    def __init__(self, function: MIRFunction, max_registers: int = 256) -> None:
        """Initialize the register allocator.

        Args:
            function: The MIR function to allocate registers for.
            max_registers: Maximum number of available registers.
        """
        self.function = function
        self.max_registers = max_registers
        self.live_intervals: list[LiveInterval] = []
        self.active_intervals: list[LiveInterval] = []
        self.free_registers: list[int] = list(range(max_registers))
        self.instruction_positions: dict[MIRInstruction, int] = {}
        self.spilled_values: set[MIRValue] = set()
        self.next_spill_slot = 0  # Track spill slot allocation

    def allocate(self) -> RegisterAllocation:
        """Perform register allocation.

        Calling this more than once is safe: state left behind by an
        earlier run is discarded before the new run starts.

        Returns:
            The register allocation result. ``max_registers`` in the result
            is the number of real registers used (spill slots excluded).
        """
        # Leftovers from an earlier run would duplicate live intervals and
        # keep registers marked as taken, so start from a clean slate.
        if self.live_intervals or self.instruction_positions:
            self._reset_state()

        self._build_instruction_positions()
        self._compute_live_intervals()

        # Linear scan requires intervals ordered by start position.
        self.live_intervals.sort(key=lambda x: x.start)

        allocations = self._linear_scan()

        # Only non-negative entries are registers; negatives are spill slots.
        max_reg_used = max((reg + 1 for reg in allocations.values() if reg >= 0), default=0)

        return RegisterAllocation(
            allocations=allocations, spilled_values=self.spilled_values, max_registers=max_reg_used
        )

    def _reset_state(self) -> None:
        """Discard all per-run state so a new allocation starts fresh."""
        self.live_intervals = []
        self.active_intervals = []
        self.free_registers = list(range(self.max_registers))
        self.instruction_positions = {}
        # Rebind instead of clearing: a previously returned result may still
        # reference the old set.
        self.spilled_values = set()
        self.next_spill_slot = 0

    def _build_instruction_positions(self) -> None:
        """Build a mapping from instructions to positions.

        Positions are consecutive integers assigned in block iteration
        order, then instruction order within each block.
        """
        position = 0
        for block in self.function.cfg.blocks.values():
            for inst in block.instructions:
                self.instruction_positions[inst] = position
                position += 1

    def _compute_live_intervals(self) -> None:
        """Compute live intervals for all allocatable values.

        An interval starts at the first definition of a value (or at
        position 0 when the value is used before any definition is seen)
        and ends at the last position where it is defined or used.
        """
        first_def: dict[MIRValue, int] = {}
        last_use: dict[MIRValue, int] = {}

        for block in self.function.cfg.blocks.values():
            for inst in block.instructions:
                position = self.instruction_positions[inst]

                for def_val in inst.get_defs():
                    if self._should_allocate(def_val):
                        first_def.setdefault(def_val, position)
                        last_use[def_val] = position  # Def is also a use

                for use_val in inst.get_uses():
                    if self._should_allocate(use_val):
                        last_use[use_val] = position
                        # Value used before defined (parameter or external)
                        first_def.setdefault(use_val, 0)

        for value, start in first_def.items():
            self.live_intervals.append(LiveInterval(value=value, start=start, end=last_use.get(value, start)))

    def _should_allocate(self, value: MIRValue) -> bool:
        """Check if a value needs register allocation.

        Args:
            value: The value to check.

        Returns:
            True if the value is a temp or a variable.
        """
        return isinstance(value, Temp | Variable)

    def _next_spill_location(self) -> int:
        """Reserve a new spill slot.

        Returns:
            A negative number below ``-max_registers`` identifying the slot,
            so it can never be mistaken for a register number.
        """
        self.next_spill_slot += 1
        return -(self.max_registers + self.next_spill_slot)

    def _linear_scan(self) -> dict[MIRValue, int]:
        """Perform linear scan register allocation.

        Expects ``live_intervals`` to be sorted by start position.

        Returns:
            Mapping from values to register numbers (non-negative) or spill
            slots (negative).
        """
        allocations: dict[MIRValue, int] = {}

        for interval in self.live_intervals:
            self._expire_old_intervals(interval.start)

            if self.free_registers:
                interval.register = self.free_registers.pop(0)
            else:
                # All registers are in use: somebody has to be spilled.
                evicted = self._spill_at_interval(interval)
                if evicted is not None:
                    # The evicted value no longer owns its register; without
                    # remapping it would share the register with the value
                    # that just took it over.
                    allocations[evicted.value] = self._next_spill_location()

            if interval.register is None:
                # The incoming interval itself was spilled to memory.
                self.spilled_values.add(interval.value)
                allocations[interval.value] = self._next_spill_location()
                continue

            allocations[interval.value] = interval.register
            self.active_intervals.append(interval)
            # Keep active intervals ordered by end position.
            self.active_intervals.sort(key=lambda x: x.end)

        return allocations

    def _expire_old_intervals(self, current_position: int) -> None:
        """Expire intervals that are no longer live.

        Intervals ending strictly before ``current_position`` are removed
        from the active list and their registers are returned to the pool.

        Args:
            current_position: The current position in the program.
        """
        expired_count = 0
        for active in self.active_intervals:
            if active.end >= current_position:
                break  # Sorted by end, so we can stop
            expired_count += 1

        expired = self.active_intervals[:expired_count]
        del self.active_intervals[:expired_count]

        released = [iv.register for iv in expired if iv.register is not None and iv.register >= 0]
        if released:
            self.free_registers.extend(released)
            # Keep the pool ordered so the lowest register is handed out first.
            self.free_registers.sort()

    def _spill_at_interval(self, interval: LiveInterval) -> LiveInterval | None:
        """Spill a value when no registers are available.

        Either ``interval`` is spilled (its ``register`` is set to ``None``)
        or the active interval that ends last is evicted and its register is
        handed to ``interval``.

        Args:
            interval: The interval that needs a register.

        Returns:
            The active interval that was evicted to make room, or ``None``
            when ``interval`` itself was spilled.
        """
        if not self.active_intervals:
            # No active intervals, must spill current
            self.spilled_values.add(interval.value)
            interval.register = None
            return None

        # active_intervals is sorted by end, so the last one ends furthest away.
        spill_candidate = self.active_intervals[-1]

        if spill_candidate.end > interval.end:
            # Spill the furthest interval and give its register to current
            self.active_intervals.remove(spill_candidate)
            interval.register = spill_candidate.register
            self.spilled_values.add(spill_candidate.value)
            spill_candidate.register = None
            return spill_candidate

        # Current interval ends later, spill it instead
        self.spilled_values.add(interval.value)
        interval.register = None
        return None
224
+
225
+
226
class LifetimeAnalyzer:
    """Computes how long temps and variables stay in use within a function."""

    def __init__(self, function: MIRFunction) -> None:
        """Store the function and start with no recorded lifetimes.

        Args:
            function: The function to analyze.
        """
        self.lifetimes: dict[MIRValue, tuple[int, int]] = {}
        self.function = function

    def analyze(self) -> dict[MIRValue, tuple[int, int]]:
        """Record the first and last position at which each value appears.

        Instructions are numbered consecutively in block iteration order.
        Both definitions and uses count as appearances; only temps and
        variables are tracked.

        Returns:
            Mapping from values to (first_use, last_use) positions.
        """
        position = 0

        for block in self.function.cfg.blocks.values():
            for inst in block.instructions:
                # Definitions are visited before uses, as both extend the
                # lifetime in exactly the same way.
                for fetch_operands in (inst.get_defs, inst.get_uses):
                    for operand in fetch_operands():
                        if not isinstance(operand, Temp | Variable):
                            continue
                        begin, _ = self.lifetimes.get(operand, (position, position))
                        self.lifetimes[operand] = (begin, position)

                position += 1

        return self.lifetimes

    def find_reusable_slots(self) -> list[set[MIRValue]]:
        """Group values whose lifetimes never overlap.

        Values are taken in order of lifetime start and placed into the
        first existing group in which they overlap with no member; a new
        group is opened when no such group exists.

        Returns:
            List of sets where each set contains values that can share a slot.
        """
        groups: list[set[MIRValue]] = []
        by_start = sorted(self.lifetimes, key=lambda candidate: self.lifetimes[candidate][0])

        for value in by_start:
            start, end = self.lifetimes[value]
            for group in groups:
                # Two lifetimes are disjoint when one ends before the other begins.
                if all(end < self.lifetimes[member][0] or start > self.lifetimes[member][1] for member in group):
                    group.add(value)
                    break
            else:
                groups.append({value})

        return groups
@@ -0,0 +1,17 @@
1
+ """MIR optimization reporting infrastructure."""
2
+
3
+ from machine_dialect.mir.reporting.optimization_reporter import OptimizationReporter
4
+ from machine_dialect.mir.reporting.report_formatter import (
5
+ HTMLReportFormatter,
6
+ JSONReportFormatter,
7
+ ReportFormatter,
8
+ TextReportFormatter,
9
+ )
10
+
11
+ __all__ = [
12
+ "HTMLReportFormatter",
13
+ "JSONReportFormatter",
14
+ "OptimizationReporter",
15
+ "ReportFormatter",
16
+ "TextReportFormatter",
17
+ ]
@@ -0,0 +1,314 @@
1
+ """Optimization reporter for collecting and aggregating pass statistics.
2
+
3
+ This module provides infrastructure for collecting optimization statistics
4
+ from various passes and generating comprehensive reports.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from enum import Enum
9
+ from typing import Any
10
+
11
+
12
class MetricType(Enum):
    """Kinds of metric that can be collected."""

    # Plain tally, e.g. instructions removed.
    COUNT = "count"
    # Value expressed as a percentage.
    PERCENTAGE = "percentage"
    # Size in bytes.
    SIZE = "size"
    # Duration in milliseconds.
    TIME = "time"
    # Ratio between two values.
    RATIO = "ratio"
20
+
21
+
22
@dataclass
class PassMetrics:
    """Numbers gathered for one run of an optimization pass.

    Attributes:
        pass_name: Name of the optimization pass.
        phase: Optimization phase the pass belongs to.
        metrics: Pass-specific metric values keyed by metric name.
        before_stats: Statistics captured before the pass ran.
        after_stats: Statistics captured after the pass ran.
        time_ms: Time taken to run the pass in milliseconds.
    """

    pass_name: str
    phase: str = "main"
    metrics: dict[str, int] = field(default_factory=dict)
    before_stats: dict[str, int] = field(default_factory=dict)
    after_stats: dict[str, int] = field(default_factory=dict)
    time_ms: float = 0.0

    def get_improvement(self, metric: str) -> float:
        """Return by what percentage a statistic shrank during the pass.

        A statistic missing from either snapshot is treated as 0.

        Args:
            metric: Metric name.

        Returns:
            Improvement percentage (positive means reduction); 0.0 when the
            "before" value is 0.
        """
        baseline = self.before_stats.get(metric, 0)
        if baseline == 0:
            # Nothing to compare against; also avoids dividing by zero.
            return 0.0

        remaining = self.after_stats.get(metric, 0)
        return ((baseline - remaining) / baseline) * 100
56
+
57
+
58
@dataclass
class ModuleMetrics:
    """Everything recorded while optimizing one module.

    Attributes:
        module_name: Name of the module.
        function_metrics: Metrics keyed by function name.
        pass_metrics: Metrics of each recorded pass, in recording order.
        total_time_ms: Sum of the recorded passes' times.
        optimization_level: Optimization level used.
    """

    module_name: str
    function_metrics: dict[str, dict[str, Any]] = field(default_factory=dict)
    pass_metrics: list[PassMetrics] = field(default_factory=list)
    total_time_ms: float = 0.0
    optimization_level: int = 0

    def add_pass_metrics(self, metrics: PassMetrics) -> None:
        """Record a pass and add its time to the module total.

        Args:
            metrics: Pass metrics to add.
        """
        self.total_time_ms += metrics.time_ms
        self.pass_metrics.append(metrics)

    def get_summary(self) -> dict[str, Any]:
        """Condense the recorded passes into a single dictionary.

        ``improvements`` sums, per statistic, the improvement percentage of
        every pass that has the statistic in both its before and after
        snapshots. ``total_metrics`` sums each pass-specific metric across
        all passes.

        Returns:
            Dictionary of summary statistics.
        """
        improvements: dict[str, float] = {}
        metric_totals: dict[str, int] = {}

        for entry in self.pass_metrics:
            for stat in entry.before_stats:
                if stat not in entry.after_stats:
                    continue
                improvements[stat] = improvements.get(stat, 0.0) + entry.get_improvement(stat)

            for name, amount in entry.metrics.items():
                metric_totals[name] = metric_totals.get(name, 0) + amount

        return {
            "module_name": self.module_name,
            "optimization_level": self.optimization_level,
            "total_passes": len(self.pass_metrics),
            "total_time_ms": self.total_time_ms,
            "passes_applied": [entry.pass_name for entry in self.pass_metrics],
            "improvements": improvements,
            "total_metrics": metric_totals,
        }
122
+
123
+
124
+ class OptimizationReporter:
125
+ """Collects and reports optimization statistics.
126
+
127
+ This class aggregates statistics from multiple optimization passes
128
+ and generates comprehensive reports about the optimization process.
129
+ """
130
+
131
def __init__(self, module_name: str = "unknown") -> None:
    """Create a reporter with empty metrics and no pass in progress.

    Args:
        module_name: Name of the module being optimized.
    """
    # No pass is being tracked until start_pass() is called.
    self.current_pass: PassMetrics | None = None
    self.module_metrics = ModuleMetrics(module_name=module_name)
139
+
140
def start_pass(
    self,
    pass_name: str,
    phase: str = "main",
    before_stats: dict[str, int] | None = None,
) -> None:
    """Begin recording a pass, replacing any pass currently being tracked.

    Args:
        pass_name: Name of the pass.
        phase: Optimization phase.
        before_stats: Statistics before the pass; an empty dict is used
            when omitted or empty.
    """
    baseline = before_stats or {}
    self.current_pass = PassMetrics(pass_name=pass_name, phase=phase, before_stats=baseline)
158
+
159
+ def end_pass(
160
+ self,
161
+ metrics: dict[str, int] | None = None,
162
+ after_stats: dict[str, int] | None = None,
163
+ time_ms: float = 0.0,
164
+ ) -> None:
165
+ """End tracking the current pass.
166
+
167
+ Args:
168
+ metrics: Pass-specific metrics.
169
+ after_stats: Statistics after the pass.
170
+ time_ms: Time taken by the pass.
171
+ """
172
+ if self.current_pass:
173
+ self.current_pass.metrics = metrics or {}
174
+ self.current_pass.after_stats = after_stats or {}
175
+ self.current_pass.time_ms = time_ms
176
+ self.module_metrics.add_pass_metrics(self.current_pass)
177
+ self.current_pass = None
178
+
179
+ def add_function_metrics(self, func_name: str, metrics: dict[str, Any]) -> None:
180
+ """Add metrics for a specific function.
181
+
182
+ Args:
183
+ func_name: Function name.
184
+ metrics: Function metrics.
185
+ """
186
+ self.module_metrics.function_metrics[func_name] = metrics
187
+
188
+ def add_custom_stats(self, pass_name: str, stats: dict[str, int]) -> None:
189
+ """Add custom statistics for a pass.
190
+
191
+ Args:
192
+ pass_name: Name of the pass.
193
+ stats: Statistics to add.
194
+ """
195
+ # Create a pass metrics entry for custom stats
196
+ metrics = PassMetrics(pass_name=pass_name, phase="bytecode", metrics=stats)
197
+ self.module_metrics.add_pass_metrics(metrics)
198
+
199
+ def set_optimization_level(self, level: int) -> None:
200
+ """Set the optimization level.
201
+
202
+ Args:
203
+ level: Optimization level (0-3).
204
+ """
205
+ self.module_metrics.optimization_level = level
206
+
207
+ def get_report_data(self) -> ModuleMetrics:
208
+ """Get the collected metrics.
209
+
210
+ Returns:
211
+ Module metrics.
212
+ """
213
+ return self.module_metrics
214
+
215
+ def generate_summary(self) -> str:
216
+ """Generate a text summary of optimizations.
217
+
218
+ Returns:
219
+ Text summary.
220
+ """
221
+ summary = self.module_metrics.get_summary()
222
+ lines = []
223
+
224
+ lines.append(f"Module: {summary['module_name']}")
225
+ lines.append(f"Optimization Level: {summary['optimization_level']}")
226
+ lines.append(f"Total Passes: {summary['total_passes']}")
227
+ lines.append(f"Total Time: {summary['total_time_ms']:.2f}ms")
228
+ lines.append("")
229
+
230
+ if summary["passes_applied"]:
231
+ lines.append("Passes Applied:")
232
+ for pass_name in summary["passes_applied"]:
233
+ lines.append(f" - {pass_name}")
234
+ lines.append("")
235
+
236
+ if summary["improvements"]:
237
+ lines.append("Improvements:")
238
+ for metric, improvement in summary["improvements"].items():
239
+ if improvement > 0:
240
+ lines.append(f" {metric}: {improvement:.1f}% reduction")
241
+ lines.append("")
242
+
243
+ if summary["total_metrics"]:
244
+ lines.append("Total Changes:")
245
+ for metric, value in summary["total_metrics"].items():
246
+ if value > 0:
247
+ lines.append(f" {metric}: {value}")
248
+
249
+ return "\n".join(lines)
250
+
251
+ def generate_detailed_report(self) -> str:
252
+ """Generate a detailed report with per-pass statistics.
253
+
254
+ Returns:
255
+ Detailed text report.
256
+ """
257
+ lines = []
258
+ lines.append("=" * 60)
259
+ lines.append("OPTIMIZATION REPORT")
260
+ lines.append("=" * 60)
261
+ lines.append("")
262
+
263
+ # Summary
264
+ lines.append(self.generate_summary())
265
+ lines.append("")
266
+ lines.append("=" * 60)
267
+ lines.append("DETAILED PASS STATISTICS")
268
+ lines.append("=" * 60)
269
+
270
+ # Per-pass details
271
+ for metrics in self.module_metrics.pass_metrics:
272
+ lines.append("")
273
+ lines.append(f"Pass: {metrics.pass_name}")
274
+ lines.append(f"Phase: {metrics.phase}")
275
+ lines.append(f"Time: {metrics.time_ms:.2f}ms")
276
+
277
+ if metrics.metrics:
278
+ lines.append("Metrics:")
279
+ for key, value in metrics.metrics.items():
280
+ if value > 0:
281
+ lines.append(f" {key}: {value}")
282
+
283
+ # Show improvements
284
+ improvements = []
285
+ for key in metrics.before_stats:
286
+ if key in metrics.after_stats:
287
+ improvement = metrics.get_improvement(key)
288
+ if improvement > 0:
289
+ improvements.append(
290
+ f" {key}: {metrics.before_stats[key]} → "
291
+ f"{metrics.after_stats[key]} "
292
+ f"({improvement:.1f}% reduction)"
293
+ )
294
+
295
+ if improvements:
296
+ lines.append("Improvements:")
297
+ lines.extend(improvements)
298
+
299
+ lines.append("-" * 40)
300
+
301
+ # Function-specific metrics if available
302
+ if self.module_metrics.function_metrics:
303
+ lines.append("")
304
+ lines.append("=" * 60)
305
+ lines.append("FUNCTION METRICS")
306
+ lines.append("=" * 60)
307
+
308
+ for func_name, func_metrics in self.module_metrics.function_metrics.items():
309
+ lines.append("")
310
+ lines.append(f"Function: {func_name}")
311
+ for key, value in func_metrics.items():
312
+ lines.append(f" {key}: {value}")
313
+
314
+ return "\n".join(lines)