machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,634 @@
|
|
1
|
+
"""Constant propagation and folding optimization pass.
|
2
|
+
|
3
|
+
This module implements constant propagation at the MIR level, replacing
|
4
|
+
variable uses with constants and folding constant expressions.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
10
|
+
from machine_dialect.mir.mir_instructions import (
|
11
|
+
BinaryOp,
|
12
|
+
ConditionalJump,
|
13
|
+
Copy,
|
14
|
+
Jump,
|
15
|
+
LoadConst,
|
16
|
+
LoadVar,
|
17
|
+
MIRInstruction,
|
18
|
+
Phi,
|
19
|
+
StoreVar,
|
20
|
+
UnaryOp,
|
21
|
+
)
|
22
|
+
from machine_dialect.mir.mir_transformer import MIRTransformer
|
23
|
+
from machine_dialect.mir.mir_types import MIRType
|
24
|
+
from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
|
25
|
+
from machine_dialect.mir.optimization_pass import (
|
26
|
+
OptimizationPass,
|
27
|
+
PassInfo,
|
28
|
+
PassType,
|
29
|
+
PreservationLevel,
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
class ConstantLattice:
    """Lattice for constant propagation analysis.

    Each tracked MIR value is in one of three states:
    - TOP: Unknown/uninitialized (the state of any value never recorded)
    - Constant value: Known constant
    - BOTTOM: Not a constant (conflicting values)

    ``set`` only ever lowers a state (TOP -> constant -> BOTTOM); it never
    raises one back up.
    """

    TOP = object()  # Unknown
    BOTTOM = object()  # Not constant

    def __init__(self) -> None:
        """Initialize the lattice."""
        # Maps a MIR value to its recorded state; a missing key means TOP.
        self.values: dict[MIRValue, Any] = {}

    @staticmethod
    def _same_constant(first: Any, second: Any) -> bool:
        """Check whether two lattice states denote the same constant.

        Args:
            first: First lattice state.
            second: Second lattice state.

        Returns:
            True if both are the same object, or have the same type and
            compare equal.
        """
        if first is second:
            return True
        # Plain ``==`` would treat 1, 1.0 and True as one constant; requiring
        # the same type keeps e.g. an integer and a boolean distinct.
        return type(first) is type(second) and bool(first == second)

    def get(self, value: MIRValue) -> Any:
        """Get the lattice value.

        Args:
            value: MIR value to query.

        Returns:
            Lattice value (TOP, BOTTOM, or constant).
        """
        return self.values.get(value, self.TOP)

    def set(self, value: MIRValue, lattice_val: Any) -> bool:
        """Set the lattice value, only ever moving down the lattice.

        Args:
            value: MIR value to set.
            lattice_val: New lattice value.

        Returns:
            True if the value changed.
        """
        # The sentinels are unique objects, so they are compared by identity.
        if lattice_val is self.TOP:
            return False  # Can't go back to TOP

        old = self.get(value)
        if old is self.BOTTOM:
            return False  # Can't change from BOTTOM

        if old is self.TOP:
            self.values[value] = lattice_val
            return True

        if self._same_constant(old, lattice_val):
            return False

        # A known constant meets either a different constant or BOTTOM.
        self.values[value] = self.BOTTOM
        return True

    def is_constant(self, value: MIRValue) -> bool:
        """Check if a value is a known constant.

        Args:
            value: MIR value to check.

        Returns:
            True if the value is a known constant.
        """
        val = self.get(value)
        return val is not self.TOP and val is not self.BOTTOM

    def get_constant(self, value: MIRValue) -> Any:
        """Get the constant value.

        Args:
            value: MIR value to query.

        Returns:
            The constant value, or None when the value is TOP or BOTTOM.
        """
        val = self.get(value)
        if val is self.TOP or val is self.BOTTOM:
            return None
        return val
|
116
|
+
|
117
|
+
|
118
|
+
class ConstantPropagation(OptimizationPass):
|
119
|
+
"""Constant propagation and folding optimization pass."""
|
120
|
+
|
121
|
+
def get_info(self) -> PassInfo:
    """Return the metadata describing this pass.

    Returns:
        A PassInfo naming the pass "constant-propagation", typed as an
        optimization pass with no required passes and CFG-level preservation.
    """
    info = PassInfo(
        name="constant-propagation",
        description="Propagate constants and fold constant expressions",
        pass_type=PassType.OPTIMIZATION,
        requires=[],
        preserves=PreservationLevel.CFG,
    )
    return info
|
134
|
+
|
135
|
+
def run_on_function(self, function: MIRFunction) -> bool:
    """Run constant propagation on a function.

    The work happens in four steps: analyze the function into a lattice,
    substitute known constants for their uses, fold constant expressions,
    and simplify control flow whose condition is known.

    Args:
        function: The function to optimize.

    Returns:
        True if the function was modified.
    """
    lattice = self._analyze_constants(function)
    transformer = MIRTransformer(function)

    for mir_value, known in lattice.values.items():
        # Only entries that settled on a concrete constant are substituted.
        if known == ConstantLattice.TOP or known == ConstantLattice.BOTTOM:
            continue
        if not isinstance(mir_value, Variable | Temp):
            continue

        replacement = Constant(known, self._infer_type(known))
        replaced = transformer.replace_uses(mir_value, replacement)
        self.stats["constants_propagated"] = self.stats.get("constants_propagated", 0) + replaced

    self._fold_constant_expressions(function, transformer)
    self._simplify_control_flow(function, transformer, lattice)

    return transformer.modified
|
166
|
+
|
167
|
+
def _analyze_constants(self, function: MIRFunction) -> ConstantLattice:
    """Analyze function to find constant values using iterative dataflow.

    Every block starts on the worklist. A block is re-examined whenever one
    of its predecessors changed a lattice value, until nothing changes.

    Args:
        function: Function to analyze.

    Returns:
        Constant lattice with analysis results. If the iteration budget runs
        out before a fixed point is reached, an empty lattice is returned so
        that no half-analyzed value is treated as a constant.
    """
    lattice = ConstantLattice()
    blocks = list(function.cfg.blocks.values())
    block_lattices: dict[Any, ConstantLattice] = {block: ConstantLattice() for block in blocks}

    # An insertion-ordered dict acts as a FIFO set: blocks are visited in a
    # reproducible order, unlike set.pop().
    worklist: dict[Any, None] = dict.fromkeys(blocks)

    # Safety net against infinite loops. The budget scales with the number of
    # blocks so that large functions still get every block visited.
    max_iterations = 100 * max(1, len(blocks))
    iteration_count = 0

    while worklist and iteration_count < max_iterations:
        iteration_count += 1
        block = next(iter(worklist))
        del worklist[block]

        # Merge lattice values from predecessors
        changed = self._merge_predecessors(block, block_lattices, lattice)

        # Process phi nodes with proper meet operation
        for phi in block.phi_nodes:
            if self._process_phi(phi, block, block_lattices, lattice):
                changed = True

        # Process instructions
        for inst in block.instructions:
            if self._process_instruction(inst, lattice):
                changed = True

        # If this block changed, its successors must be looked at again
        if changed:
            for successor in block.successors:
                worklist[successor] = None

    if worklist:
        # Not converged: some blocks still need a visit, so the recorded
        # constants cannot be trusted. Report nothing as constant.
        return ConstantLattice()

    return lattice
|
214
|
+
|
215
|
+
def _merge_predecessors(
    self, block: Any, block_lattices: dict[Any, ConstantLattice], lattice: ConstantLattice
) -> bool:
    """Fold every predecessor's recorded values into the global lattice.

    Args:
        block: Current block.
        block_lattices: Per-block lattice states.
        lattice: Global lattice.

    Returns:
        True if any values changed.
    """
    any_changed = False

    for predecessor in block.predecessors:
        incoming = block_lattices.get(predecessor)
        if not incoming:
            # No per-block lattice is registered for this predecessor.
            continue
        for mir_value, state in incoming.values.items():
            # Call set() first so it always runs, even once a change is known.
            any_changed = lattice.set(mir_value, state) or any_changed

    return any_changed
|
239
|
+
|
240
|
+
def _process_phi(
    self, phi: Phi, block: Any, block_lattices: dict[Any, ConstantLattice], lattice: ConstantLattice
) -> bool:
    """Compute the meet of a phi node's incoming values.

    Incoming values still at TOP are skipped. If the remaining values all
    agree on one constant, the phi's destination takes that constant. Any
    BOTTOM or disagreement makes the destination BOTTOM.

    Args:
        phi: Phi node to process.
        block: Current block (unused here).
        block_lattices: Per-block lattice states.
        lattice: Constant lattice.

    Returns:
        True if the phi's value changed.
    """
    top = ConstantLattice.TOP
    bottom = ConstantLattice.BOTTOM

    # TOP doubles as "no constant seen yet". Using None for that would
    # confuse the marker with a literal None constant.
    merged = top

    for value, pred_block in phi.incoming:
        if isinstance(value, Constant):
            val = value.value
        else:
            # Prefer the predecessor's own lattice, but fall back to the
            # global lattice when it has no entry. Otherwise a value already
            # known to be BOTTOM would read as TOP and be ignored.
            pred_lattice = block_lattices.get(pred_block)
            val = pred_lattice.get(value) if pred_lattice is not None else top
            if val is top:
                val = lattice.get(value)

        if val is top:
            continue
        if val is bottom:
            merged = bottom
            break
        if merged is top:
            merged = val
        elif not (type(merged) is type(val) and merged == val):
            # Two different constants (or equal values of different types,
            # such as 1 and True) cannot be merged into one.
            merged = bottom
            break

    if merged is top:
        return False
    return lattice.set(phi.dest, merged)
|
285
|
+
|
286
|
+
def _process_instruction(self, inst: MIRInstruction, lattice: ConstantLattice) -> bool:
    """Update the lattice for the value an instruction defines.

    Handles LoadConst, Copy, StoreVar, LoadVar, BinaryOp and UnaryOp; any
    other instruction leaves the lattice untouched.

    Args:
        inst: Instruction to process.
        lattice: Constant lattice.

    Returns:
        True if any value changed.
    """
    top = ConstantLattice.TOP
    bottom = ConstantLattice.BOTTOM

    def forward(source: Any, target: Any) -> bool:
        """Copy the state of ``source`` onto ``target`` unless it is unknown."""
        if isinstance(source, Constant):
            return lattice.set(target, source.value)
        if isinstance(source, Variable | Temp):
            state = lattice.get(source)
            if state is not top:
                return lattice.set(target, state)
        return False

    def is_bottom(operand: Any) -> bool:
        """Tell whether an operand is already known not to be a constant."""
        return isinstance(operand, Variable | Temp) and lattice.get(operand) is bottom

    if isinstance(inst, LoadConst):
        # LoadConst defines a constant
        return lattice.set(inst.dest, inst.constant.value)

    if isinstance(inst, Copy):
        # Copy propagates constants
        return forward(inst.source, inst.dest)

    if isinstance(inst, StoreVar):
        # Store propagates constants to variables
        return forward(inst.source, inst.var)

    if isinstance(inst, LoadVar):
        # Load from variable
        state = lattice.get(inst.var)
        if state is top:
            return False
        return lattice.set(inst.dest, state)

    if isinstance(inst, BinaryOp):
        # An operand that is not constant makes the result not constant too.
        # Without this, a result folded on an earlier visit would stay
        # recorded as a constant after its operand dropped to BOTTOM.
        if is_bottom(inst.left) or is_bottom(inst.right):
            return lattice.set(inst.dest, bottom)

        left_val = self._get_value(inst.left, lattice)
        right_val = self._get_value(inst.right, lattice)
        if left_val is None or right_val is None:
            return False

        result = self._fold_binary_op(inst.op, left_val, right_val)
        if result is None:
            return False
        return lattice.set(inst.dest, result)

    if isinstance(inst, UnaryOp):
        # Same reasoning as for binary operations.
        if is_bottom(inst.operand):
            return lattice.set(inst.dest, bottom)

        operand_val = self._get_value(inst.operand, lattice)
        if operand_val is None:
            return False

        result = self._fold_unary_op(inst.op, operand_val)
        if result is None:
            return False
        return lattice.set(inst.dest, result)

    return False
|
354
|
+
|
355
|
+
def _process_binary_op(self, inst: BinaryOp, lattice: ConstantLattice) -> None:
    """Record the folded result of a binary operation, when there is one.

    Nothing is recorded unless both operands resolve to constants and the
    operation can be folded.

    Args:
        inst: Binary operation instruction.
        lattice: Constant lattice.
    """
    lhs = self._get_value(inst.left, lattice)
    rhs = self._get_value(inst.right, lattice)

    if lhs is not None and rhs is not None:
        folded = self._fold_binary_op(inst.op, lhs, rhs)
        if folded is not None:
            lattice.set(inst.dest, folded)
|
373
|
+
|
374
|
+
def _process_unary_op(self, inst: UnaryOp, lattice: ConstantLattice) -> None:
    """Record the folded result of a unary operation, when there is one.

    Nothing is recorded unless the operand resolves to a constant and the
    operation can be folded.

    Args:
        inst: Unary operation instruction.
        lattice: Constant lattice.
    """
    resolved = self._get_value(inst.operand, lattice)

    if resolved is not None:
        folded = self._fold_unary_op(inst.op, resolved)
        if folded is not None:
            lattice.set(inst.dest, folded)
|
391
|
+
|
392
|
+
def _get_value(self, value: MIRValue, lattice: ConstantLattice) -> Any:
    """Resolve a MIR value to its constant payload.

    Args:
        value: MIR value to evaluate.
        lattice: Constant lattice.

    Returns:
        The payload of a Constant, or the constant recorded in the lattice
        for a Variable or Temp. None for any other kind of value and for
        lattice entries that are TOP or BOTTOM.
    """
    if isinstance(value, Constant):
        return value.value

    if not isinstance(value, Variable | Temp):
        return None

    state = lattice.get(value)
    if state == ConstantLattice.TOP or state == ConstantLattice.BOTTOM:
        return None
    return state
|
409
|
+
|
410
|
+
def _fold_binary_op(self, op: str, left: Any, right: Any) -> Any:
|
411
|
+
"""Fold a binary operation with constant operands.
|
412
|
+
|
413
|
+
Args:
|
414
|
+
op: Operation string.
|
415
|
+
left: Left operand value.
|
416
|
+
right: Right operand value.
|
417
|
+
|
418
|
+
Returns:
|
419
|
+
The folded result or None.
|
420
|
+
"""
|
421
|
+
try:
|
422
|
+
# Arithmetic operations
|
423
|
+
if op == "+":
|
424
|
+
# Handle string concatenation and numeric addition
|
425
|
+
if isinstance(left, str) or isinstance(right, str):
|
426
|
+
return str(left) + str(right)
|
427
|
+
return left + right
|
428
|
+
elif op == "-":
|
429
|
+
return left - right
|
430
|
+
elif op == "*":
|
431
|
+
return left * right
|
432
|
+
elif op == "/":
|
433
|
+
if right != 0:
|
434
|
+
# Integer division for integers
|
435
|
+
if isinstance(left, int) and isinstance(right, int):
|
436
|
+
return left // right
|
437
|
+
return left / right
|
438
|
+
elif op == "//":
|
439
|
+
if right != 0:
|
440
|
+
return left // right
|
441
|
+
elif op == "%":
|
442
|
+
if right != 0:
|
443
|
+
return left % right
|
444
|
+
elif op == "**":
|
445
|
+
return left**right
|
446
|
+
|
447
|
+
# Comparison operations
|
448
|
+
elif op == "<":
|
449
|
+
return left < right
|
450
|
+
elif op == "<=":
|
451
|
+
return left <= right
|
452
|
+
elif op == ">":
|
453
|
+
return left > right
|
454
|
+
elif op == ">=":
|
455
|
+
return left >= right
|
456
|
+
elif op == "==":
|
457
|
+
return left == right
|
458
|
+
elif op == "!=":
|
459
|
+
return left != right
|
460
|
+
elif op == "===": # Strict equality
|
461
|
+
return left is right
|
462
|
+
elif op == "!==": # Strict inequality
|
463
|
+
return left is not right
|
464
|
+
|
465
|
+
# Logical operations
|
466
|
+
elif op == "and":
|
467
|
+
return left and right
|
468
|
+
elif op == "or":
|
469
|
+
return left or right
|
470
|
+
|
471
|
+
# Bitwise operations
|
472
|
+
elif op == "&":
|
473
|
+
if isinstance(left, int) and isinstance(right, int):
|
474
|
+
return left & right
|
475
|
+
elif op == "|":
|
476
|
+
if isinstance(left, int) and isinstance(right, int):
|
477
|
+
return left | right
|
478
|
+
elif op == "^":
|
479
|
+
if isinstance(left, int) and isinstance(right, int):
|
480
|
+
return left ^ right
|
481
|
+
elif op == "<<":
|
482
|
+
if isinstance(left, int) and isinstance(right, int):
|
483
|
+
return left << right
|
484
|
+
elif op == ">>":
|
485
|
+
if isinstance(left, int) and isinstance(right, int):
|
486
|
+
return left >> right
|
487
|
+
|
488
|
+
# String operations
|
489
|
+
elif op == "in":
|
490
|
+
return left in right
|
491
|
+
elif op == "not in":
|
492
|
+
return left not in right
|
493
|
+
|
494
|
+
except (TypeError, ValueError, ZeroDivisionError, OverflowError):
|
495
|
+
pass
|
496
|
+
return None
|
497
|
+
|
498
|
+
def _fold_unary_op(self, op: str, operand: Any) -> Any:
|
499
|
+
"""Fold a unary operation with constant operand.
|
500
|
+
|
501
|
+
Args:
|
502
|
+
op: Operation string.
|
503
|
+
operand: Operand value.
|
504
|
+
|
505
|
+
Returns:
|
506
|
+
The folded result or None.
|
507
|
+
"""
|
508
|
+
try:
|
509
|
+
if op == "-":
|
510
|
+
return -operand
|
511
|
+
elif op == "not":
|
512
|
+
return not operand
|
513
|
+
elif op == "+":
|
514
|
+
return +operand
|
515
|
+
elif op == "~": # Bitwise NOT
|
516
|
+
if isinstance(operand, int):
|
517
|
+
return ~operand
|
518
|
+
elif op == "abs":
|
519
|
+
return abs(operand)
|
520
|
+
elif op == "len":
|
521
|
+
if hasattr(operand, "__len__"):
|
522
|
+
return len(operand)
|
523
|
+
except (TypeError, ValueError):
|
524
|
+
pass
|
525
|
+
return None
|
526
|
+
|
527
|
+
def _fold_constant_expressions(
    self,
    function: MIRFunction,
    transformer: MIRTransformer,
) -> None:
    """Replace constant-only binary and unary operations with constant loads.

    Every ``BinaryOp`` or ``UnaryOp`` whose operands are all ``Constant``
    values with a non-None payload is evaluated through the fold helpers.
    When a helper yields a non-None result, the instruction is replaced by
    a ``LoadConst`` into the same destination and the
    ``expressions_folded`` statistic is incremented.

    Args:
        function: Function to optimize.
        transformer: MIR transformer used to perform the replacement.
    """
    for block in function.cfg.blocks.values():
        # Walk a snapshot: instructions are replaced while we iterate.
        for inst in list(block.instructions):
            if isinstance(inst, BinaryOp):
                lhs = self._get_constant_value(inst.left)
                rhs = self._get_constant_value(inst.right)
                if lhs is None or rhs is None:
                    continue
                folded = self._fold_binary_op(inst.op, lhs, rhs)
            elif isinstance(inst, UnaryOp):
                value = self._get_constant_value(inst.operand)
                if value is None:
                    continue
                folded = self._fold_unary_op(inst.op, value)
            else:
                continue

            # None from a fold helper means the operation was not folded.
            if folded is None:
                continue

            const = Constant(folded, self._infer_type(folded))
            replacement = LoadConst(inst.dest, const, inst.source_location)
            transformer.replace_instruction(block, inst, replacement)
            self.stats["expressions_folded"] = self.stats.get("expressions_folded", 0) + 1
|
566
|
+
|
567
|
+
def _simplify_control_flow(
    self,
    function: MIRFunction,
    transformer: MIRTransformer,
    lattice: ConstantLattice,
) -> None:
    """Turn conditional jumps with a known condition into plain jumps.

    For each block terminated by a ``ConditionalJump`` whose condition
    resolves to a non-None value via ``_get_value``, the terminator is
    replaced by a ``Jump`` to the taken label and the
    ``branches_simplified`` statistic is incremented. A falsy condition
    with no false label is left untouched.

    Args:
        function: Function to optimize.
        transformer: MIR transformer used to perform the replacement.
        lattice: Constant lattice consulted for the condition value.
    """
    # Iterate over a snapshot of the blocks.
    for block in list(function.cfg.blocks.values()):
        term = block.get_terminator()
        if not isinstance(term, ConditionalJump):
            continue

        known = self._get_value(term.condition, lattice)
        if known is None:
            # Condition is not a known constant; nothing to simplify.
            continue

        if known:
            target = term.true_label
        else:
            target = term.false_label
            if target is None:
                # Can't simplify without false label
                continue

        new_jump = Jump(target, term.source_location)
        transformer.replace_instruction(block, term, new_jump)
        self.stats["branches_simplified"] = self.stats.get("branches_simplified", 0) + 1
|
596
|
+
|
597
|
+
def _get_constant_value(self, value: MIRValue) -> Any:
    """Extract the payload of a MIR value when it is a ``Constant``.

    Args:
        value: MIR value to inspect.

    Returns:
        ``value.value`` if ``value`` is a ``Constant``; otherwise None.
        A ``Constant`` whose payload is None is indistinguishable from a
        non-constant in the returned value.
    """
    return value.value if isinstance(value, Constant) else None
|
609
|
+
|
610
|
+
def _infer_type(self, value: Any) -> MIRType:
    """Map a Python value to the corresponding MIR type.

    Args:
        value: Python value.

    Returns:
        ``BOOL``, ``INT``, ``FLOAT`` or ``STRING`` for the matching Python
        type, ``EMPTY`` for None, and ``UNKNOWN`` for anything else.
    """
    if value is None:
        return MIRType.EMPTY

    # Order matters: bool is a subclass of int, so it must be tested first.
    type_table = (
        (bool, MIRType.BOOL),
        (int, MIRType.INT),
        (float, MIRType.FLOAT),
        (str, MIRType.STRING),
    )
    for python_type, mir_type in type_table:
        if isinstance(value, python_type):
            return mir_type
    return MIRType.UNKNOWN
|
631
|
+
|
632
|
+
def finalize(self) -> None:
    """Finalize the pass; this pass has nothing to clean up."""
|