machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,398 @@
|
|
1
|
+
"""Common Subexpression Elimination (CSE) optimization pass.
|
2
|
+
|
3
|
+
This module implements CSE at the MIR level, eliminating redundant
|
4
|
+
computations by reusing previously computed values.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
11
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
12
|
+
from machine_dialect.mir.mir_instructions import (
|
13
|
+
BinaryOp,
|
14
|
+
Call,
|
15
|
+
Copy,
|
16
|
+
LoadConst,
|
17
|
+
MIRInstruction,
|
18
|
+
StoreVar,
|
19
|
+
UnaryOp,
|
20
|
+
)
|
21
|
+
from machine_dialect.mir.mir_transformer import MIRTransformer
|
22
|
+
from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
|
23
|
+
from machine_dialect.mir.optimization_pass import (
|
24
|
+
OptimizationPass,
|
25
|
+
PassInfo,
|
26
|
+
PassType,
|
27
|
+
PreservationLevel,
|
28
|
+
)
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass(frozen=True)
class Expression:
    """Hashable key describing a pure computation for CSE.

    Two instructions that produce equal ``Expression`` keys compute the same
    value, so the later one can reuse the earlier one's result.  Because the
    dataclass is frozen (and compares by value), ``@dataclass`` generates
    ``__hash__`` from the ``(op, operands)`` fields, which makes instances
    usable as dictionary keys.

    Attributes:
        op: Operation tag, e.g. ``"binary_+"``, ``"unary_-"`` or ``"const"``.
        operands: Tuple of (normalised) operands.
    """

    op: str
    operands: tuple[Any, ...]
|
46
|
+
|
47
|
+
|
48
|
+
class AvailableExpressions:
    """Tracks the expressions available at a program point.

    Two views are kept consistent with each other:

    * ``expressions`` maps each available expression to the value that holds
      its result;
    * ``definitions`` is the reverse index, mapping each holder value to the
      set of expressions whose result it currently holds.  Holders with no
      remaining expressions are removed from it.
    """

    def __init__(self) -> None:
        """Initialize an empty set of available expressions."""
        # Expression -> value holding its result.
        self.expressions: dict[Expression, MIRValue] = {}
        # Holder value -> expressions whose result it currently holds.
        self.definitions: dict[MIRValue, set[Expression]] = {}

    def add(self, expr: Expression, value: MIRValue) -> None:
        """Record that ``value`` holds the result of ``expr``.

        If ``expr`` was already held by some value, that older association is
        replaced in both views.  This keeps a later ``invalidate`` of the old
        holder from removing the new association.

        Args:
            expr: The expression.
            value: The value containing the expression.
        """
        if expr in self.expressions:
            self._unlink(expr)
        self.expressions[expr] = value
        self.definitions.setdefault(value, set()).add(expr)

    def get(self, expr: Expression) -> MIRValue | None:
        """Get the value for an expression.

        Args:
            expr: The expression to look up.

        Returns:
            The value containing the expression or None.
        """
        return self.expressions.get(expr)

    def invalidate(self, value: MIRValue) -> None:
        """Invalidate expressions involving a value.

        Removes every expression whose stored operands contain ``value``
        (compared with ``==``, exactly as the operands were stored) and every
        expression whose result is currently held by ``value``.

        Args:
            value: The value that changed.
        """
        stale = [
            expr
            for expr, holder in self.expressions.items()
            if value in expr.operands or holder == value
        ]
        for expr in stale:
            self._unlink(expr)
            del self.expressions[expr]

    def copy(self) -> "AvailableExpressions":
        """Create a copy of available expressions.

        The per-holder expression sets are copied as well, so mutating the
        copy never affects the original.

        Returns:
            A copy of this available expressions set.
        """
        new = AvailableExpressions()
        new.expressions = self.expressions.copy()
        new.definitions = {k: v.copy() for k, v in self.definitions.items()}
        return new

    def _unlink(self, expr: Expression) -> None:
        """Remove ``expr`` from the reverse-index entry of its current holder.

        ``expr`` must currently be present in ``expressions``; the mapping in
        ``expressions`` itself is left for the caller to update.
        """
        holder = self.expressions[expr]
        held = self.definitions.get(holder)
        if held is not None:
            held.discard(expr)
            if not held:
                del self.definitions[holder]
|
113
|
+
|
114
|
+
|
115
|
+
class CommonSubexpressionElimination(OptimizationPass):
    """Common subexpression elimination optimization pass.

    Pure computations (binary operations, unary operations and constant loads)
    are keyed by an ``Expression``.  An instruction may compute an expression
    whose result is already held in another value.  In that case it is
    replaced by a ``Copy`` from that value.

    The pass works in two stages:

    * local CSE within each basic block, and
    * global CSE.  This first runs an available-expressions dataflow analysis
      over the CFG.  It then repeats the elimination in every block, starting
      from the expressions available on entry to that block.

    An expression stops being available in two cases.  The first is when any
    value it reads, or the value holding its result, is redefined.  The second
    is at a ``Call``, which conservatively invalidates every expression that
    reads, or is held in, a ``Variable``.
    """

    # Binary operators whose operands can be reordered without changing the result.
    _COMMUTATIVE_OPS = frozenset({"+", "*", "==", "!=", "and", "or"})

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="cse",
            description="Eliminate common subexpressions",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.CFG,
        )

    def run_on_function(self, function: MIRFunction) -> bool:
        """Run CSE on a function.

        Args:
            function: The function to optimize.

        Returns:
            True if the function was modified.
        """
        transformer = MIRTransformer(function)

        # Perform local CSE within each block
        for block in function.cfg.blocks.values():
            self._local_cse(block, transformer)

        # Perform global CSE across blocks
        self._global_cse(function, transformer)

        return transformer.modified

    def _local_cse(self, block: BasicBlock, transformer: MIRTransformer) -> None:
        """Perform local CSE within a block.

        Args:
            block: The block to optimize.
            transformer: MIR transformer used to rewrite instructions.
        """
        self._eliminate_in_block(block, AvailableExpressions(), transformer, "local_cse_eliminated")

    def _global_cse(self, function: MIRFunction, transformer: MIRTransformer) -> None:
        """Perform global CSE across blocks.

        The expressions available on entry to a block are the intersection of
        those available at the *exit* of each predecessor.  An expression
        survives only if every predecessor holds it in the same value.  All
        sets start empty, so the fixed point reached is conservative: an
        expression is never reported available when it is not.  However,
        expressions that only reach a loop header through its back edge may be
        missed.

        Args:
            function: The function to optimize.
            transformer: MIR transformer used to rewrite instructions.
        """
        blocks = list(function.cfg.blocks.values())
        entry_sets: dict[BasicBlock, AvailableExpressions] = {block: AvailableExpressions() for block in blocks}
        exit_sets: dict[BasicBlock, AvailableExpressions] = {block: AvailableExpressions() for block in blocks}

        # Iterate until fixed point
        changed = True
        while changed:
            changed = False
            for block in blocks:
                if block.predecessors:
                    new_entry = self._intersect_available(
                        [exit_sets.get(pred, AvailableExpressions()) for pred in block.predecessors]
                    )
                else:
                    new_entry = AvailableExpressions()

                if self._available_changed(entry_sets[block], new_entry):
                    entry_sets[block] = new_entry
                    changed = True

                # Successors intersect these exit sets, so they must be kept.
                exit_sets[block] = self._compute_exit(block, new_entry)

        # Apply CSE based on the expressions available on block entry
        for block in blocks:
            self._eliminate_in_block(block, entry_sets[block].copy(), transformer, "global_cse_eliminated")

    def _eliminate_in_block(
        self,
        block: BasicBlock,
        available: AvailableExpressions,
        transformer: MIRTransformer,
        stat_key: str,
    ) -> None:
        """Replace recomputations in ``block`` with copies of available values.

        Args:
            block: The block to rewrite.
            available: Expressions available on entry to the block; updated in place.
            transformer: MIR transformer used to rewrite instructions.
            stat_key: Statistics counter incremented once per eliminated instruction.
        """
        for inst in list(block.instructions):
            expr = self._get_expression(inst)
            result = self._get_result(inst)
            # Look the expression up before applying kills: the instruction
            # reads its operands before it writes its result.
            existing = available.get(expr) if expr is not None else None

            if result is not None and existing is not None and existing != result:
                new_inst = Copy(result, existing, self._source_location(inst))
                transformer.replace_instruction(block, inst, new_inst)
                self.stats[stat_key] = self.stats.get(stat_key, 0) + 1
                self._update_available(inst, available)
                continue

            self._update_available(inst, available)
            if expr is not None and result is not None:
                self._generate(expr, result, available)

    def _compute_exit(self, block: BasicBlock, entry: AvailableExpressions) -> AvailableExpressions:
        """Compute the expressions available at the end of ``block``.

        Args:
            block: The block to simulate.
            entry: Expressions available on entry to the block; not modified.

        Returns:
            A new set describing the expressions available after the block.
        """
        available = entry.copy()
        for inst in block.instructions:
            expr = self._get_expression(inst)
            result = self._get_result(inst)
            self._update_available(inst, available)
            if expr is not None and result is not None:
                self._generate(expr, result, available)
        return available

    def _generate(self, expr: Expression, result: MIRValue, available: AvailableExpressions) -> None:
        """Record that ``result`` now holds ``expr``.

        Nothing is recorded when ``result`` is also one of the expression's
        operands (e.g. ``t = t + 1``).  After such an instruction the operand
        has changed, so the expression no longer describes the held value.

        Args:
            expr: The expression computed by the instruction.
            result: The value the instruction writes.
            available: Available expressions to update in place.
        """
        if self._normalize_value(result) in expr.operands:
            return
        available.add(expr, result)

    def _get_expression(self, inst: MIRInstruction) -> Expression | None:
        """Extract expression from an instruction.

        Args:
            inst: The instruction.

        Returns:
            The expression or None.
        """
        if isinstance(inst, BinaryOp):
            left = self._normalize_value(inst.left)
            right = self._normalize_value(inst.right)
            if inst.op in self._COMMUTATIVE_OPS:
                # Canonical operand order so ``a + b`` and ``b + a`` share a key.
                operands = tuple(sorted([left, right], key=str))
            else:
                operands = (left, right)
            return Expression(f"binary_{inst.op}", operands)

        if isinstance(inst, UnaryOp):
            return Expression(f"unary_{inst.op}", (self._normalize_value(inst.operand),))

        if isinstance(inst, LoadConst):
            # Constants are their own expressions
            return Expression("const", (inst.constant.value, inst.constant.type))

        return None

    def _normalize_value(self, value: MIRValue) -> Any:
        """Normalize a value for expression comparison.

        Args:
            value: The value to normalize.

        Returns:
            Normalized representation.
        """
        if isinstance(value, Constant):
            return ("const", value.value, value.type)
        elif isinstance(value, Variable):
            return ("var", value.name)
        elif isinstance(value, Temp):
            return ("temp", value.id)
        else:
            return str(value)

    def _get_result(self, inst: MIRInstruction) -> MIRValue | None:
        """Get the result value of an instruction.

        Args:
            inst: The instruction.

        Returns:
            The single value the instruction defines, or None if it defines
            zero or several values.
        """
        defs = inst.get_defs()
        if defs and len(defs) == 1:
            return defs[0]
        return None

    @staticmethod
    def _source_location(inst: MIRInstruction) -> Any:
        """Return ``inst.source_location``, or ``(0, 0)`` if it has none."""
        return getattr(inst, "source_location", (0, 0))

    def _update_available(
        self,
        inst: MIRInstruction,
        available: AvailableExpressions,
    ) -> None:
        """Remove the expressions that ``inst`` makes unavailable.

        Every value written by the instruction invalidates two kinds of
        expression: those that read it and those whose result it held.
        ``StoreVar`` also invalidates its target variable.  ``Call``
        conservatively invalidates every expression that reads, or is held in,
        a ``Variable``, because calls might modify globals or have other side
        effects.

        Args:
            inst: The instruction.
            available: Available expressions to update in place.
        """
        for value in inst.get_defs() or ():
            self._kill_value(value, available)

        if isinstance(inst, StoreVar):
            self._kill_value(inst.var, available)
        elif isinstance(inst, Call):
            stale = [
                expr
                for expr, holder in available.expressions.items()
                if isinstance(holder, Variable) or any(self._is_variable_key(op) for op in expr.operands)
            ]
            self._drop_expressions(available, stale)

    def _kill_value(self, value: MIRValue, available: AvailableExpressions) -> None:
        """Drop expressions that read ``value`` or whose result ``value`` holds.

        Operands are stored in normalised form, so ``value`` is normalised
        before it is searched for among them.

        Args:
            value: The value being redefined.
            available: Available expressions to update in place.
        """
        key = self._normalize_value(value)
        stale = [
            expr
            for expr, holder in available.expressions.items()
            if holder == value or key in expr.operands
        ]
        self._drop_expressions(available, stale)

    @staticmethod
    def _is_variable_key(operand: Any) -> bool:
        """Return True if ``operand`` is a normalised ``Variable`` key ``("var", name)``."""
        return isinstance(operand, tuple) and len(operand) == 2 and operand[0] == "var"

    @staticmethod
    def _drop_expressions(available: AvailableExpressions, stale: list[Expression]) -> None:
        """Remove ``stale`` expressions from both views of ``available``.

        Args:
            available: Available expressions to update in place.
            stale: Expressions currently present in ``available.expressions``.
        """
        for expr in stale:
            holder = available.expressions.pop(expr)
            held = available.definitions.get(holder)
            if held is not None:
                held.discard(expr)
                if not held:
                    del available.definitions[holder]

    def _intersect_available(
        self,
        sets: list[AvailableExpressions],
    ) -> AvailableExpressions:
        """Compute intersection of available expression sets.

        An expression is kept only if every set contains it *and* maps it to
        the same holder value.

        Args:
            sets: List of available expression sets.

        Returns:
            The intersection (a new object; the inputs are not modified).
        """
        result = AvailableExpressions()
        if not sets:
            return result

        first, rest = sets[0], sets[1:]
        for expr, value in first.expressions.items():
            if all(expr in other.expressions and other.expressions[expr] == value for other in rest):
                result.add(expr, value)

        return result

    def _available_changed(
        self,
        old: AvailableExpressions,
        new: AvailableExpressions,
    ) -> bool:
        """Check if available expressions changed.

        Args:
            old: Old available expressions.
            new: New available expressions.

        Returns:
            True if the expression -> holder mappings differ.
        """
        return old.expressions != new.expressions

    def finalize(self) -> None:
        """Finalize the pass."""
        pass
|
@@ -0,0 +1,288 @@
|
|
1
|
+
"""Dead code elimination optimization pass.
|
2
|
+
|
3
|
+
This module implements dead code elimination (DCE) at the MIR level,
|
4
|
+
removing instructions and blocks that have no effect on program output.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from machine_dialect.mir.analyses.use_def_chains import UseDefChains
|
8
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
9
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
10
|
+
from machine_dialect.mir.mir_instructions import (
|
11
|
+
Assert,
|
12
|
+
Call,
|
13
|
+
ConditionalJump,
|
14
|
+
Jump,
|
15
|
+
MIRInstruction,
|
16
|
+
Print,
|
17
|
+
Return,
|
18
|
+
Scope,
|
19
|
+
StoreVar,
|
20
|
+
)
|
21
|
+
from machine_dialect.mir.mir_transformer import MIRTransformer
|
22
|
+
from machine_dialect.mir.mir_values import Temp, Variable
|
23
|
+
from machine_dialect.mir.optimization_pass import (
|
24
|
+
OptimizationPass,
|
25
|
+
PassInfo,
|
26
|
+
PassType,
|
27
|
+
PreservationLevel,
|
28
|
+
)
|
29
|
+
|
30
|
+
|
31
|
+
class DeadCodeElimination(OptimizationPass):
|
32
|
+
"""Dead code elimination optimization pass."""
|
33
|
+
|
34
|
+
def get_info(self) -> PassInfo:
    """Describe this pass to the pass manager.

    Returns:
        Metadata naming the pass ``"dce"``, tagging it as both an optimization
        and a cleanup pass, requiring the ``"use-def-chains"`` analysis, and
        declaring ``PreservationLevel.CFG``.
    """
    # NOTE(review): run_on_function removes unreachable blocks and calls
    # simplify_cfg, which looks like it changes the CFG. Confirm whether
    # declaring PreservationLevel.CFG here can leave stale CFG-based analyses
    # cached.
    pass_types = [PassType.OPTIMIZATION, PassType.CLEANUP]
    required_analyses = ["use-def-chains"]
    return PassInfo(
        name="dce",
        description="Eliminate dead code and unreachable blocks",
        pass_type=pass_types,
        requires=required_analyses,
        preserves=PreservationLevel.CFG,
    )
|
47
|
+
|
48
|
+
def run_on_function(self, function: MIRFunction) -> bool:
    """Run dead code elimination on a function.

    The pass runs four phases in order:

    1. remove side-effect-free instructions whose defined values are unused;
    2. remove the dead stores reported by ``_find_dead_stores``;
    3. remove unreachable blocks;
    4. simplify the control-flow graph.

    Counts are accumulated in ``self.stats`` under
    ``dead_instructions_removed``, ``dead_stores_removed`` and
    ``unreachable_blocks_removed``.

    Args:
        function: The function to optimize.

    Returns:
        True if the function was modified.
    """
    transformer = MIRTransformer(function)

    # Phases 1 and 2 both consult this snapshot, taken before any removal.
    use_def_chains: UseDefChains = self.get_analysis("use-def-chains", function)

    # Phase 1: Remove dead instructions
    for block, inst in self._find_dead_instructions(function, use_def_chains):
        transformer.remove_instruction(block, inst)
        self.stats["dead_instructions_removed"] = self.stats.get("dead_instructions_removed", 0) + 1

    # Phase 2: Remove dead stores
    for block, inst in self._find_dead_stores(function, use_def_chains):
        if isinstance(inst, StoreVar):
            # NOTE(review): replace_uses rewrites every use of inst.var in the
            # function, not only uses reached by this store. _find_dead_stores
            # appears to flag stores overwritten later in the same block. If so,
            # reads after the overwriting store would be redirected to this dead
            # store's source value. Confirm the intended semantics.
            transformer.replace_uses(inst.var, inst.source)
        transformer.remove_instruction(block, inst)
        self.stats["dead_stores_removed"] = self.stats.get("dead_stores_removed", 0) + 1

    # Phase 3: Remove unreachable blocks. Accumulate like the other counters
    # instead of overwriting the total from earlier functions.
    num_unreachable = transformer.eliminate_unreachable_blocks()
    self.stats["unreachable_blocks_removed"] = self.stats.get("unreachable_blocks_removed", 0) + num_unreachable

    # Phase 4: Simplify control flow
    transformer.simplify_cfg()

    return transformer.modified
|
86
|
+
|
87
|
+
def _find_dead_instructions(
|
88
|
+
self,
|
89
|
+
function: MIRFunction,
|
90
|
+
use_def_chains: UseDefChains,
|
91
|
+
) -> list[tuple[BasicBlock, MIRInstruction]]:
|
92
|
+
"""Find dead instructions that can be removed.
|
93
|
+
|
94
|
+
Args:
|
95
|
+
function: The function to analyze.
|
96
|
+
use_def_chains: Use-def chain information.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
List of (block, instruction) pairs to remove.
|
100
|
+
"""
|
101
|
+
dead = []
|
102
|
+
|
103
|
+
for block in function.cfg.blocks.values():
|
104
|
+
for inst in block.instructions:
|
105
|
+
# Skip instructions with side effects
|
106
|
+
if self._has_side_effects(inst):
|
107
|
+
continue
|
108
|
+
|
109
|
+
# Check if all values defined by this instruction are dead
|
110
|
+
defs = inst.get_defs()
|
111
|
+
if not defs:
|
112
|
+
continue
|
113
|
+
|
114
|
+
all_dead = True
|
115
|
+
for value in defs:
|
116
|
+
if isinstance(value, Temp | Variable):
|
117
|
+
if not use_def_chains.is_dead(value):
|
118
|
+
all_dead = False
|
119
|
+
break
|
120
|
+
|
121
|
+
if all_dead:
|
122
|
+
dead.append((block, inst))
|
123
|
+
|
124
|
+
return dead
|
125
|
+
|
126
|
+
def _find_dead_stores(
|
127
|
+
self,
|
128
|
+
function: MIRFunction,
|
129
|
+
use_def_chains: UseDefChains,
|
130
|
+
) -> list[tuple[BasicBlock, MIRInstruction]]:
|
131
|
+
"""Find dead store instructions.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
function: The function to analyze.
|
135
|
+
use_def_chains: Use-def chain information.
|
136
|
+
|
137
|
+
Returns:
|
138
|
+
List of (block, instruction) pairs to remove.
|
139
|
+
"""
|
140
|
+
dead_stores = []
|
141
|
+
|
142
|
+
for block in function.cfg.blocks.values():
|
143
|
+
# Track last store to each variable in the block
|
144
|
+
last_stores: dict[Variable, MIRInstruction] = {}
|
145
|
+
|
146
|
+
for inst in block.instructions:
|
147
|
+
if isinstance(inst, StoreVar):
|
148
|
+
# Check if there's a previous store to the same variable
|
149
|
+
if inst.var in last_stores:
|
150
|
+
# Previous store might be dead if not used between stores
|
151
|
+
prev_store = last_stores[inst.var]
|
152
|
+
if self._is_dead_store(
|
153
|
+
prev_store,
|
154
|
+
inst,
|
155
|
+
block,
|
156
|
+
use_def_chains,
|
157
|
+
):
|
158
|
+
dead_stores.append((block, prev_store))
|
159
|
+
|
160
|
+
last_stores[inst.var] = inst
|
161
|
+
|
162
|
+
# Check if final stores are dead (no uses after block)
|
163
|
+
for _var, store in last_stores.items():
|
164
|
+
if self._is_store_dead_at_end(store, block, use_def_chains):
|
165
|
+
dead_stores.append((block, store))
|
166
|
+
|
167
|
+
return dead_stores
|
168
|
+
|
169
|
+
def _is_dead_store(
|
170
|
+
self,
|
171
|
+
store1: MIRInstruction,
|
172
|
+
store2: MIRInstruction,
|
173
|
+
block: BasicBlock,
|
174
|
+
use_def_chains: UseDefChains,
|
175
|
+
) -> bool:
|
176
|
+
"""Check if store1 is dead because of store2.
|
177
|
+
|
178
|
+
Args:
|
179
|
+
store1: First store instruction.
|
180
|
+
store2: Second store instruction.
|
181
|
+
block: Containing block.
|
182
|
+
use_def_chains: Use-def chain information.
|
183
|
+
|
184
|
+
Returns:
|
185
|
+
True if store1 is dead.
|
186
|
+
"""
|
187
|
+
# Find instructions between the two stores
|
188
|
+
idx1 = block.instructions.index(store1)
|
189
|
+
idx2 = block.instructions.index(store2)
|
190
|
+
|
191
|
+
if idx1 >= idx2:
|
192
|
+
return False
|
193
|
+
|
194
|
+
# Check if variable is used between the stores
|
195
|
+
if isinstance(store1, StoreVar):
|
196
|
+
var = store1.var
|
197
|
+
for i in range(idx1 + 1, idx2):
|
198
|
+
inst = block.instructions[i]
|
199
|
+
if var in inst.get_uses():
|
200
|
+
return False
|
201
|
+
|
202
|
+
return True
|
203
|
+
|
204
|
+
def _is_store_dead_at_end(
    self,
    store: MIRInstruction,
    block: BasicBlock,
    use_def_chains: UseDefChains,
) -> bool:
    """Check whether the value written by the last store in a block is never read.

    The store is live if its variable is read by:
      * a later instruction in ``block`` itself,
      * a phi node, on the specific CFG edge along which the value can flow, or
      * an instruction in a block reachable from ``block`` before another
        ``StoreVar`` to the same variable overwrites it.

    Args:
        store: Store instruction. Only ``StoreVar`` can be reported dead;
            anything else is treated as live.
        block: Block containing ``store``.
        use_def_chains: Use-def chain information (not consulted by this check).

    Returns:
        True if the store is dead.
    """
    if not isinstance(store, StoreVar):
        return False

    var = store.var

    # 1. Uses later in the store's own block read this value directly.
    #    The store is located by identity. If it is not found, this step is
    #    skipped and only the successor walk decides.
    after_store = False
    for inst in block.instructions:
        if inst is store:
            after_store = True
            continue
        if not after_store:
            continue
        if var in inst.get_uses():
            return False
        if isinstance(inst, StoreVar) and inst.var == var:
            # Overwritten before leaving the block with no read in between.
            return True

    # 2. Walk successors, remembering which predecessor each edge comes from.
    #    A phi operand is only read when control arrives over that edge.
    worklist = [(succ, block.label) for succ in block.successors]
    visited = set()

    while worklist:
        succ, pred_label = worklist.pop()

        # Check phi nodes for every incoming edge, even on blocks already
        # scanned, because a different edge may carry the variable.
        for phi in succ.phi_nodes:
            for val, incoming_label in phi.incoming:
                if val == var and incoming_label == pred_label:
                    return False

        if succ in visited:
            continue
        visited.add(succ)

        # Check instructions
        for inst in succ.instructions:
            if var in inst.get_uses():
                return False
            # Another store to the same variable kills our value on this path.
            if isinstance(inst, StoreVar) and inst.var == var:
                break
        else:
            # No store found: the value flows on into this block's successors.
            worklist.extend((nxt, succ.label) for nxt in succ.successors)

    return True
|
254
|
+
|
255
|
+
def _has_side_effects(self, inst: MIRInstruction) -> bool:
    """Report whether removing ``inst`` could change observable behavior.

    Control flow (``Jump``, ``ConditionalJump``, ``Return``), output
    (``Print``), every ``Call`` (conservatively treated as impure), variable
    writes (``StoreVar``), ``Assert`` and ``Scope`` are all kept.

    NOTE(review): any instruction class not listed here is treated as pure.
    Confirm that no other MIR instruction mutates state.

    Args:
        inst: Instruction to check.

    Returns:
        True if the instruction has side effects.
    """
    side_effecting = (
        Jump,
        ConditionalJump,
        Return,
        Print,
        Call,
        StoreVar,
        Assert,
        Scope,
    )
    return isinstance(inst, side_effecting)
|
285
|
+
|
286
|
+
def finalize(self) -> None:
    """Finalize the pass; dead code elimination has no finalization work to do."""
|