machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,302 @@
|
|
1
|
+
"""Virtual register allocation for MIR.
|
2
|
+
|
3
|
+
This module implements register allocation using linear scan algorithm.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
9
|
+
from machine_dialect.mir.mir_instructions import MIRInstruction
|
10
|
+
from machine_dialect.mir.mir_values import MIRValue, Temp, Variable
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class LiveInterval:
    """Span of instruction positions over which one MIR value must be kept.

    Attributes:
        value: The MIR value this interval belongs to.
        start: First position (inclusive) at which the value is live.
        end: Last position (inclusive) at which the value is live.
        register: Register/slot number assigned by the allocator, or ``None``
            while unassigned or after the value has been spilled.
    """

    value: MIRValue  # the tracked value
    start: int  # inclusive lower bound of the live range
    end: int  # inclusive upper bound of the live range
    register: int | None = None  # set by the allocator; None means "no register"
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class RegisterAllocation:
    """Outcome of running the register allocator over one function.

    Attributes:
        allocations: Location chosen for each allocated MIR value; non-negative
            numbers are registers, negative numbers denote spill slots.
        spilled_values: Values that were spilled to memory.
        max_registers: Number of registers actually needed, i.e. one more than
            the highest non-negative register number handed out (0 if none).
    """

    allocations: dict[MIRValue, int]  # value -> register (>= 0) or spill slot (< 0)
    spilled_values: set[MIRValue]  # values kept in memory instead of a register
    max_registers: int  # count of registers in use, not the configured limit
|
43
|
+
|
44
|
+
|
45
|
+
class RegisterAllocator:
    """Allocates virtual registers for MIR values using linear scan.

    Every ``Temp`` and ``Variable`` mentioned by the function's instructions
    receives a live interval ``[start, end]`` over linear instruction
    positions. Intervals are visited in order of ``start``. Each one takes the
    lowest-numbered free register. When no register is free, the interval
    either takes the register of the active interval that ends last (that
    interval is then spilled) or is spilled itself. Register numbers are
    ``>= 0``. Spilled values are mapped to negative spill-slot numbers
    ``-(max_registers + k)`` with ``k = 1, 2, ...``.

    NOTE(review): positions follow the iteration order of
    ``function.cfg.blocks``, and each value gets a single contiguous interval.
    Liveness across loop back edges is therefore not modelled. Confirm this
    is acceptable for the CFGs the compiler produces.
    """

    def __init__(self, function: MIRFunction, max_registers: int = 256) -> None:
        """Initialize the register allocator.

        Args:
            function: The MIR function to allocate registers for.
            max_registers: Maximum number of available registers.
        """
        self.function = function
        self.max_registers = max_registers
        self.live_intervals: list[LiveInterval] = []
        self.active_intervals: list[LiveInterval] = []  # kept sorted by ``end``
        self.free_registers: list[int] = list(range(max_registers))  # kept sorted ascending
        self.instruction_positions: dict[MIRInstruction, int] = {}
        self.spilled_values: set[MIRValue] = set()
        self.next_spill_slot = 0  # number of spill slots handed out so far

    def allocate(self) -> RegisterAllocation:
        """Perform register allocation.

        Per-run state is reset first. Calling this method again therefore
        recomputes the allocation from scratch instead of mixing in
        intervals and register state left over from an earlier run.

        Returns:
            The register allocation result.
        """
        self._reset_state()

        # Build instruction positions
        self._build_instruction_positions()

        # Compute live intervals
        self._compute_live_intervals()

        # Linear scan visits intervals in order of their start position.
        self.live_intervals.sort(key=lambda x: x.start)

        # Perform linear scan allocation
        allocations = self._linear_scan()

        # Only count actual registers; spill slots are negative.
        max_reg_used = max((reg + 1 for reg in allocations.values() if reg >= 0), default=0)

        return RegisterAllocation(
            allocations=allocations, spilled_values=self.spilled_values, max_registers=max_reg_used
        )

    def _reset_state(self) -> None:
        """Discard all state produced by a previous allocation run."""
        self.live_intervals = []
        self.active_intervals = []
        self.free_registers = list(range(self.max_registers))
        self.instruction_positions = {}
        # A fresh set, so a previously returned RegisterAllocation is not mutated.
        self.spilled_values = set()
        self.next_spill_slot = 0

    def _build_instruction_positions(self) -> None:
        """Number every instruction consecutively, block by block, from 0."""
        position = 0
        for block in self.function.cfg.blocks.values():
            for inst in block.instructions:
                self.instruction_positions[inst] = position
                position += 1

    def _compute_live_intervals(self) -> None:
        """Compute one live interval per allocatable value.

        An interval runs from the value's first definition to its last
        mention (use or redefinition). A value that is read before any
        definition, such as a parameter or an externally provided value, is
        treated as live from position 0.
        """
        first_def: dict[MIRValue, int] = {}
        last_use: dict[MIRValue, int] = {}

        for block in self.function.cfg.blocks.values():
            for inst in block.instructions:
                position = self.instruction_positions[inst]
                uses = [use_val for use_val in inst.get_uses() if self._should_allocate(use_val)]

                # Process definitions
                for def_val in inst.get_defs():
                    if not self._should_allocate(def_val):
                        continue
                    if def_val not in first_def:
                        # An instruction reads its operands before writing its
                        # result. If the first definition also reads the same
                        # value (e.g. ``x = x + 1`` on a parameter), the value
                        # was already live on entry.
                        first_def[def_val] = 0 if def_val in uses else position
                    last_use[def_val] = position  # A def also extends the interval

                # Process uses
                for use_val in uses:
                    last_use[use_val] = position
                    if use_val not in first_def:
                        # Value used before defined (parameter or external)
                        first_def[use_val] = 0

        # Create intervals
        for value, start in first_def.items():
            interval = LiveInterval(value=value, start=start, end=last_use.get(value, start))
            self.live_intervals.append(interval)

    def _should_allocate(self, value: MIRValue) -> bool:
        """Check if a value needs register allocation.

        Args:
            value: The value to check.

        Returns:
            True if the value is a ``Temp`` or ``Variable``.
        """
        return isinstance(value, Temp | Variable)

    def _linear_scan(self) -> dict[MIRValue, int]:
        """Perform linear scan register allocation.

        Expects ``self.live_intervals`` to be sorted by ``start``.

        Returns:
            Mapping from values to register numbers (>= 0) or spill slots (< 0).
        """
        allocations: dict[MIRValue, int] = {}

        for interval in self.live_intervals:
            # Release registers whose intervals ended before this one starts.
            self._expire_old_intervals(interval.start)

            if self.free_registers:
                # Take the lowest-numbered free register.
                register = self.free_registers.pop(0)
                interval.register = register
                allocations[interval.value] = register
                self._activate(interval)
                continue

            # Need to spill - all registers are in use.
            evicted = self._spill_at_interval(interval)
            if evicted is not None:
                # The evicted value handed its register to `interval`. It must
                # not keep pointing at that register, otherwise two overlapping
                # values would share it.
                allocations[evicted.value] = self._next_spill_location()

            if interval.register is not None:
                # Got a register through spilling
                allocations[interval.value] = interval.register
                self._activate(interval)
            else:
                # This interval was spilled to memory
                self.spilled_values.add(interval.value)
                allocations[interval.value] = self._next_spill_location()

        return allocations

    def _activate(self, interval: LiveInterval) -> None:
        """Mark ``interval`` as holding a register, keeping the list sorted by end."""
        self.active_intervals.append(interval)
        self.active_intervals.sort(key=lambda x: x.end)

    def _next_spill_location(self) -> int:
        """Reserve a new spill slot and return its (negative) location number."""
        self.next_spill_slot += 1
        return -(self.max_registers + self.next_spill_slot)

    def _expire_old_intervals(self, current_position: int) -> None:
        """Release the registers of intervals that ended before ``current_position``.

        Args:
            current_position: Start position of the interval being allocated.
        """
        expired = []
        for interval in self.active_intervals:
            if interval.end >= current_position:
                break  # Sorted by end, so every later interval is still live
            expired.append(interval)

        for interval in expired:
            self.active_intervals.remove(interval)
            if interval.register is not None and interval.register >= 0:
                self.free_registers.append(interval.register)
                self.free_registers.sort()

    def _spill_at_interval(self, interval: LiveInterval) -> LiveInterval | None:
        """Resolve register pressure when no register is free.

        If the active interval that ends last outlives ``interval``, it is
        evicted and its register is handed to ``interval``. Otherwise
        ``interval`` itself is spilled and its ``register`` stays ``None``.
        In both cases the spilled value is added to ``spilled_values``.

        Args:
            interval: The interval that needs a register.

        Returns:
            The evicted active interval, or ``None`` if ``interval`` itself
            was spilled.
        """
        if not self.active_intervals:
            # No active intervals (only reachable when no registers exist,
            # e.g. max_registers <= 0): must spill current
            self.spilled_values.add(interval.value)
            interval.register = None
            return None

        # The interval with the furthest end point is the last one, since
        # active_intervals is sorted by end.
        spill_candidate = self.active_intervals[-1]

        if spill_candidate.end > interval.end:
            # Spill the furthest interval and give its register to current
            self.active_intervals.remove(spill_candidate)
            interval.register = spill_candidate.register
            self.spilled_values.add(spill_candidate.value)
            spill_candidate.register = None
            return spill_candidate

        # Current interval ends later (or at the same point), spill it instead
        self.spilled_values.add(interval.value)
        interval.register = None
        return None
|
224
|
+
|
225
|
+
|
226
|
+
class LifetimeAnalyzer:
    """Computes coarse lifetimes of temporaries and variables in a function.

    A lifetime is the pair ``(first, last)`` of linear instruction positions
    at which a ``Temp`` or ``Variable`` is defined or used. Instructions are
    numbered consecutively from 0 while walking ``function.cfg.blocks`` in
    dictionary order.
    """

    def __init__(self, function: MIRFunction) -> None:
        """Create an analyzer for a single function.

        Args:
            function: The function to analyze.
        """
        self.function = function
        self.lifetimes: dict[MIRValue, tuple[int, int]] = {}

    def analyze(self) -> dict[MIRValue, tuple[int, int]]:
        """Record the first and last position at which each value appears.

        Results are accumulated in ``self.lifetimes``, which is also returned.

        Returns:
            Mapping from values to (first_use, last_use) positions.
        """
        instruction_stream = (
            inst for block in self.function.cfg.blocks.values() for inst in block.instructions
        )
        for pos, inst in enumerate(instruction_stream):
            # Definitions are visited before uses; both extend the lifetime.
            for operand_source in (inst.get_defs, inst.get_uses):
                for val in operand_source():
                    if not isinstance(val, Temp | Variable):
                        continue
                    seen = self.lifetimes.get(val)
                    self.lifetimes[val] = (pos, pos) if seen is None else (seen[0], pos)

        return self.lifetimes

    def find_reusable_slots(self) -> list[set[MIRValue]]:
        """Greedily group values whose lifetimes never overlap.

        Values are visited in order of lifetime start. Each value joins the
        first existing group in which no member overlaps it; otherwise it
        starts a new group. Lifetimes that share a position count as
        overlapping. With no recorded lifetimes (``analyze`` not yet called)
        the result is empty.

        Returns:
            List of sets where each set contains values that can share a slot.
        """
        groups: list[set[MIRValue]] = []

        by_start = sorted(self.lifetimes.items(), key=lambda item: item[1][0])
        for value, (start, end) in by_start:
            for group in groups:
                if all(
                    end < self.lifetimes[member][0] or start > self.lifetimes[member][1]
                    for member in group
                ):
                    group.add(value)
                    break
            else:
                # No compatible group: open a new one
                groups.append({value})

        return groups
|
@@ -0,0 +1,17 @@
|
|
1
|
+
"""MIR optimization reporting infrastructure."""
|
2
|
+
|
3
|
+
from machine_dialect.mir.reporting.optimization_reporter import OptimizationReporter
|
4
|
+
from machine_dialect.mir.reporting.report_formatter import (
|
5
|
+
HTMLReportFormatter,
|
6
|
+
JSONReportFormatter,
|
7
|
+
ReportFormatter,
|
8
|
+
TextReportFormatter,
|
9
|
+
)
|
10
|
+
|
11
|
+
# Names re-exported as the public API of the reporting package.
__all__ = [
    "HTMLReportFormatter",
    "JSONReportFormatter",
    "OptimizationReporter",
    "ReportFormatter",
    "TextReportFormatter",
]
|
@@ -0,0 +1,314 @@
|
|
1
|
+
"""Optimization reporter for collecting and aggregating pass statistics.
|
2
|
+
|
3
|
+
This module provides infrastructure for collecting optimization statistics
|
4
|
+
from various passes and generating comprehensive reports.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from dataclasses import dataclass, field
|
8
|
+
from enum import Enum
|
9
|
+
from typing import Any
|
10
|
+
|
11
|
+
|
12
|
+
class MetricType(Enum):
    """Kinds of metric values the reporter can describe."""

    COUNT = "count"  # plain tally, e.g. number of instructions removed
    PERCENTAGE = "percentage"  # value expressed as a percentage
    SIZE = "size"  # size in bytes
    TIME = "time"  # duration in milliseconds
    RATIO = "ratio"  # quotient of two quantities
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class PassMetrics:
    """Statistics recorded for one run of an optimization pass.

    Attributes:
        pass_name: Name of the optimization pass.
        phase: Phase label the pass ran in (defaults to ``"main"``).
        metrics: Pass-specific counters, keyed by metric name.
        before_stats: Statistics taken before the pass ran.
        after_stats: Statistics taken after the pass ran.
        time_ms: Time taken by the pass, in milliseconds.
    """

    pass_name: str
    phase: str = "main"
    metrics: dict[str, int] = field(default_factory=dict)
    before_stats: dict[str, int] = field(default_factory=dict)
    after_stats: dict[str, int] = field(default_factory=dict)
    time_ms: float = 0.0

    def get_improvement(self, metric: str) -> float:
        """Return the percentage by which ``metric`` shrank during the pass.

        Missing statistics count as 0. A zero (or missing) baseline yields 0.0
        rather than dividing by zero.

        Args:
            metric: Metric name.

        Returns:
            Improvement percentage (positive means reduction).
        """
        baseline = self.before_stats.get(metric, 0)
        if baseline == 0:
            return 0.0
        remaining = self.after_stats.get(metric, 0)
        return (baseline - remaining) / baseline * 100
|
56
|
+
|
57
|
+
|
58
|
+
@dataclass
class ModuleMetrics:
    """Aggregated optimization metrics for a whole module.

    Attributes:
        module_name: Name of the module.
        function_metrics: Arbitrary per-function metrics, keyed by function name.
        pass_metrics: One entry per recorded pass, in recording order.
        total_time_ms: Optimization time accumulated by ``add_pass_metrics``.
        optimization_level: Optimization level used.
    """

    module_name: str
    function_metrics: dict[str, dict[str, Any]] = field(default_factory=dict)
    pass_metrics: list[PassMetrics] = field(default_factory=list)
    total_time_ms: float = 0.0
    optimization_level: int = 0

    def add_pass_metrics(self, metrics: PassMetrics) -> None:
        """Record one pass and add its running time to the module total.

        Args:
            metrics: Pass metrics to add.
        """
        self.pass_metrics.append(metrics)
        self.total_time_ms += metrics.time_ms

    def get_summary(self) -> dict[str, Any]:
        """Summarize all recorded passes.

        Returns:
            A dictionary with the module name, optimization level, pass count,
            total time, names of the applied passes, per-statistic improvement
            percentages (``"improvements"``) and per-metric totals
            (``"total_metrics"``).
        """
        passes = self.pass_metrics
        summary: dict[str, Any] = {
            "module_name": self.module_name,
            "optimization_level": self.optimization_level,
            "total_passes": len(passes),
            "total_time_ms": self.total_time_ms,
            "passes_applied": [p.pass_name for p in passes],
        }

        # Only statistics present both before and after a pass contribute.
        # NOTE(review): per-pass percentages are simply added up. That is not
        # the overall reduction when passes run in sequence: 50% followed by
        # 50% sums to 100%, while the real reduction is 75%.
        improvements: dict[str, float] = {}
        for p in passes:
            for stat in p.before_stats:
                if stat not in p.after_stats:
                    continue
                improvements[stat] = improvements.get(stat, 0.0) + p.get_improvement(stat)
        summary["improvements"] = improvements

        # Pass-specific counters are summed across all passes.
        totals: dict[str, int] = {}
        for p in passes:
            for name, amount in p.metrics.items():
                totals[name] = totals.get(name, 0) + amount
        summary["total_metrics"] = totals

        return summary
|
122
|
+
|
123
|
+
|
124
|
+
class OptimizationReporter:
|
125
|
+
"""Collects and reports optimization statistics.
|
126
|
+
|
127
|
+
This class aggregates statistics from multiple optimization passes
|
128
|
+
and generates comprehensive reports about the optimization process.
|
129
|
+
"""
|
130
|
+
|
131
|
+
def __init__(self, module_name: str = "unknown") -> None:
    """Set up an empty reporter for one module.

    Args:
        module_name: Name of the module being optimized.
    """
    metrics_for_module = ModuleMetrics(module_name=module_name)
    self.module_metrics = metrics_for_module
    # No pass is in flight until start_pass() is called.
    self.current_pass: PassMetrics | None = None
|
139
|
+
|
140
|
+
def start_pass(
|
141
|
+
self,
|
142
|
+
pass_name: str,
|
143
|
+
phase: str = "main",
|
144
|
+
before_stats: dict[str, int] | None = None,
|
145
|
+
) -> None:
|
146
|
+
"""Start tracking a new pass.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
pass_name: Name of the pass.
|
150
|
+
phase: Optimization phase.
|
151
|
+
before_stats: Statistics before the pass.
|
152
|
+
"""
|
153
|
+
self.current_pass = PassMetrics(
|
154
|
+
pass_name=pass_name,
|
155
|
+
phase=phase,
|
156
|
+
before_stats=before_stats or {},
|
157
|
+
)
|
158
|
+
|
159
|
+
def end_pass(
|
160
|
+
self,
|
161
|
+
metrics: dict[str, int] | None = None,
|
162
|
+
after_stats: dict[str, int] | None = None,
|
163
|
+
time_ms: float = 0.0,
|
164
|
+
) -> None:
|
165
|
+
"""End tracking the current pass.
|
166
|
+
|
167
|
+
Args:
|
168
|
+
metrics: Pass-specific metrics.
|
169
|
+
after_stats: Statistics after the pass.
|
170
|
+
time_ms: Time taken by the pass.
|
171
|
+
"""
|
172
|
+
if self.current_pass:
|
173
|
+
self.current_pass.metrics = metrics or {}
|
174
|
+
self.current_pass.after_stats = after_stats or {}
|
175
|
+
self.current_pass.time_ms = time_ms
|
176
|
+
self.module_metrics.add_pass_metrics(self.current_pass)
|
177
|
+
self.current_pass = None
|
178
|
+
|
179
|
+
def add_function_metrics(self, func_name: str, metrics: dict[str, Any]) -> None:
|
180
|
+
"""Add metrics for a specific function.
|
181
|
+
|
182
|
+
Args:
|
183
|
+
func_name: Function name.
|
184
|
+
metrics: Function metrics.
|
185
|
+
"""
|
186
|
+
self.module_metrics.function_metrics[func_name] = metrics
|
187
|
+
|
188
|
+
def add_custom_stats(self, pass_name: str, stats: dict[str, int]) -> None:
|
189
|
+
"""Add custom statistics for a pass.
|
190
|
+
|
191
|
+
Args:
|
192
|
+
pass_name: Name of the pass.
|
193
|
+
stats: Statistics to add.
|
194
|
+
"""
|
195
|
+
# Create a pass metrics entry for custom stats
|
196
|
+
metrics = PassMetrics(pass_name=pass_name, phase="bytecode", metrics=stats)
|
197
|
+
self.module_metrics.add_pass_metrics(metrics)
|
198
|
+
|
199
|
+
def set_optimization_level(self, level: int) -> None:
|
200
|
+
"""Set the optimization level.
|
201
|
+
|
202
|
+
Args:
|
203
|
+
level: Optimization level (0-3).
|
204
|
+
"""
|
205
|
+
self.module_metrics.optimization_level = level
|
206
|
+
|
207
|
+
def get_report_data(self) -> ModuleMetrics:
|
208
|
+
"""Get the collected metrics.
|
209
|
+
|
210
|
+
Returns:
|
211
|
+
Module metrics.
|
212
|
+
"""
|
213
|
+
return self.module_metrics
|
214
|
+
|
215
|
+
def generate_summary(self) -> str:
|
216
|
+
"""Generate a text summary of optimizations.
|
217
|
+
|
218
|
+
Returns:
|
219
|
+
Text summary.
|
220
|
+
"""
|
221
|
+
summary = self.module_metrics.get_summary()
|
222
|
+
lines = []
|
223
|
+
|
224
|
+
lines.append(f"Module: {summary['module_name']}")
|
225
|
+
lines.append(f"Optimization Level: {summary['optimization_level']}")
|
226
|
+
lines.append(f"Total Passes: {summary['total_passes']}")
|
227
|
+
lines.append(f"Total Time: {summary['total_time_ms']:.2f}ms")
|
228
|
+
lines.append("")
|
229
|
+
|
230
|
+
if summary["passes_applied"]:
|
231
|
+
lines.append("Passes Applied:")
|
232
|
+
for pass_name in summary["passes_applied"]:
|
233
|
+
lines.append(f" - {pass_name}")
|
234
|
+
lines.append("")
|
235
|
+
|
236
|
+
if summary["improvements"]:
|
237
|
+
lines.append("Improvements:")
|
238
|
+
for metric, improvement in summary["improvements"].items():
|
239
|
+
if improvement > 0:
|
240
|
+
lines.append(f" {metric}: {improvement:.1f}% reduction")
|
241
|
+
lines.append("")
|
242
|
+
|
243
|
+
if summary["total_metrics"]:
|
244
|
+
lines.append("Total Changes:")
|
245
|
+
for metric, value in summary["total_metrics"].items():
|
246
|
+
if value > 0:
|
247
|
+
lines.append(f" {metric}: {value}")
|
248
|
+
|
249
|
+
return "\n".join(lines)
|
250
|
+
|
251
|
+
def generate_detailed_report(self) -> str:
|
252
|
+
"""Generate a detailed report with per-pass statistics.
|
253
|
+
|
254
|
+
Returns:
|
255
|
+
Detailed text report.
|
256
|
+
"""
|
257
|
+
lines = []
|
258
|
+
lines.append("=" * 60)
|
259
|
+
lines.append("OPTIMIZATION REPORT")
|
260
|
+
lines.append("=" * 60)
|
261
|
+
lines.append("")
|
262
|
+
|
263
|
+
# Summary
|
264
|
+
lines.append(self.generate_summary())
|
265
|
+
lines.append("")
|
266
|
+
lines.append("=" * 60)
|
267
|
+
lines.append("DETAILED PASS STATISTICS")
|
268
|
+
lines.append("=" * 60)
|
269
|
+
|
270
|
+
# Per-pass details
|
271
|
+
for metrics in self.module_metrics.pass_metrics:
|
272
|
+
lines.append("")
|
273
|
+
lines.append(f"Pass: {metrics.pass_name}")
|
274
|
+
lines.append(f"Phase: {metrics.phase}")
|
275
|
+
lines.append(f"Time: {metrics.time_ms:.2f}ms")
|
276
|
+
|
277
|
+
if metrics.metrics:
|
278
|
+
lines.append("Metrics:")
|
279
|
+
for key, value in metrics.metrics.items():
|
280
|
+
if value > 0:
|
281
|
+
lines.append(f" {key}: {value}")
|
282
|
+
|
283
|
+
# Show improvements
|
284
|
+
improvements = []
|
285
|
+
for key in metrics.before_stats:
|
286
|
+
if key in metrics.after_stats:
|
287
|
+
improvement = metrics.get_improvement(key)
|
288
|
+
if improvement > 0:
|
289
|
+
improvements.append(
|
290
|
+
f" {key}: {metrics.before_stats[key]} → "
|
291
|
+
f"{metrics.after_stats[key]} "
|
292
|
+
f"({improvement:.1f}% reduction)"
|
293
|
+
)
|
294
|
+
|
295
|
+
if improvements:
|
296
|
+
lines.append("Improvements:")
|
297
|
+
lines.extend(improvements)
|
298
|
+
|
299
|
+
lines.append("-" * 40)
|
300
|
+
|
301
|
+
# Function-specific metrics if available
|
302
|
+
if self.module_metrics.function_metrics:
|
303
|
+
lines.append("")
|
304
|
+
lines.append("=" * 60)
|
305
|
+
lines.append("FUNCTION METRICS")
|
306
|
+
lines.append("=" * 60)
|
307
|
+
|
308
|
+
for func_name, func_metrics in self.module_metrics.function_metrics.items():
|
309
|
+
lines.append("")
|
310
|
+
lines.append(f"Function: {func_name}")
|
311
|
+
for key, value in func_metrics.items():
|
312
|
+
lines.append(f" {key}: {value}")
|
313
|
+
|
314
|
+
return "\n".join(lines)
|