machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,906 @@
|
|
1
|
+
"""Type-specific MIR optimization pass.
|
2
|
+
|
3
|
+
This module implements type-aware optimizations that leverage type information
|
4
|
+
from variable definitions to generate more efficient MIR code.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from machine_dialect.mir.analyses.dominance_analysis import DominanceAnalysis
|
8
|
+
from machine_dialect.mir.analyses.use_def_chains import UseDefChains, UseDefChainsAnalysis
|
9
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
10
|
+
from machine_dialect.mir.dataflow import DataFlowAnalysis, Range, TypeContext
|
11
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
12
|
+
from machine_dialect.mir.mir_instructions import (
|
13
|
+
BinaryOp,
|
14
|
+
ConditionalJump,
|
15
|
+
Copy,
|
16
|
+
LoadConst,
|
17
|
+
MaxOp,
|
18
|
+
MinOp,
|
19
|
+
MIRInstruction,
|
20
|
+
SaturatingAddOp,
|
21
|
+
ShiftOp,
|
22
|
+
UnaryOp,
|
23
|
+
)
|
24
|
+
from machine_dialect.mir.mir_types import MIRType, MIRUnionType
|
25
|
+
from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
|
26
|
+
from machine_dialect.mir.optimization_pass import (
|
27
|
+
FunctionPass,
|
28
|
+
PassInfo,
|
29
|
+
PassType,
|
30
|
+
PreservationLevel,
|
31
|
+
)
|
32
|
+
from machine_dialect.mir.ssa_construction import DominanceInfo
|
33
|
+
|
34
|
+
|
35
|
+
class TypeInference(DataFlowAnalysis[dict[MIRValue, TypeContext]]):
    """Type inference using dataflow analysis framework.

    The analysis state is a mapping from MIR values to the ``TypeContext``
    (base type, nullability, provenance and optional numeric range) currently
    known for them.
    """

    def initial_state(self) -> dict[MIRValue, TypeContext]:
        """Get initial type state.

        Returns:
            An empty mapping: no value has a known type yet.
        """
        return {}

    def transfer(self, inst: MIRInstruction, state: dict[MIRValue, TypeContext]) -> dict[MIRValue, TypeContext]:
        """Transfer function for type inference.

        The input state is never modified; a shallow copy is updated and
        returned. Instructions other than the kinds handled below leave the
        state unchanged.

        Args:
            inst: The instruction to process.
            state: Input type state.

        Returns:
            Output type state after processing the instruction.
        """
        new_state = state.copy()

        # LoadConst establishes exact type and range
        if isinstance(inst, LoadConst):
            # Ensure we have a MIRType, not MIRUnionType
            base_type = inst.constant.type if not isinstance(inst.constant.type, MIRUnionType) else MIRType.UNKNOWN
            ctx = TypeContext(base_type=base_type, nullable=False, provenance="constant")
            # For numeric constants, set exact range [value, value].
            # NOTE: bool is a subclass of int, so boolean constants get a range too.
            if isinstance(inst.constant.value, int | float):
                ctx.range = Range(inst.constant.value, inst.constant.value)
            new_state[inst.dest] = ctx

        # BinaryOp propagates type information
        elif isinstance(inst, BinaryOp):
            left_ctx = state.get(inst.left)
            right_ctx = state.get(inst.right)

            # Determine result type
            if inst.op in ["==", "!=", "<", "<=", ">", ">=", "and", "or"]:
                # Comparisons and logical operators are typed BOOL regardless of operands.
                result_type = MIRType.BOOL
                ctx = TypeContext(base_type=result_type, nullable=False)
            elif left_ctx and right_ctx:
                # Numeric operations: FLOAT wins if either side is FLOAT,
                # otherwise the left operand's type is used.
                if left_ctx.base_type == MIRType.FLOAT or right_ctx.base_type == MIRType.FLOAT:
                    result_type = MIRType.FLOAT
                else:
                    result_type = left_ctx.base_type
                ctx = TypeContext(base_type=result_type)

                # Compute range for arithmetic operations. A bound is only
                # produced when both contributing bounds are present.
                if inst.op == "+" and left_ctx.range and right_ctx.range:
                    if left_ctx.range.min is not None and right_ctx.range.min is not None:
                        new_min = left_ctx.range.min + right_ctx.range.min
                    else:
                        new_min = None
                    if left_ctx.range.max is not None and right_ctx.range.max is not None:
                        new_max = left_ctx.range.max + right_ctx.range.max
                    else:
                        new_max = None
                    ctx.range = Range(new_min, new_max)
                elif inst.op == "-" and left_ctx.range and right_ctx.range:
                    # Smallest difference: left.min - right.max; largest: left.max - right.min.
                    if left_ctx.range.min is not None and right_ctx.range.max is not None:
                        new_min = left_ctx.range.min - right_ctx.range.max
                    else:
                        new_min = None
                    if left_ctx.range.max is not None and right_ctx.range.min is not None:
                        new_max = left_ctx.range.max - right_ctx.range.min
                    else:
                        new_max = None
                    ctx.range = Range(new_min, new_max)
            else:
                # At least one operand has no known context.
                ctx = TypeContext(base_type=MIRType.UNKNOWN)

            new_state[inst.dest] = ctx

        # Copy propagates type information
        elif isinstance(inst, Copy):
            if inst.source in state:
                # The destination shares the source's context object.
                new_state[inst.dest] = state[inst.source]

        # UnaryOp
        elif isinstance(inst, UnaryOp):
            operand_ctx = state.get(inst.operand)
            if operand_ctx:
                if inst.op == "not":
                    ctx = TypeContext(base_type=MIRType.BOOL, nullable=False)
                else:
                    ctx = TypeContext(base_type=operand_ctx.base_type)
                    # Negation flips range: [min, max] becomes [-max, -min]
                    if inst.op == "-" and operand_ctx.range:
                        ctx.range = Range(
                            -operand_ctx.range.max if operand_ctx.range.max is not None else None,
                            -operand_ctx.range.min if operand_ctx.range.min is not None else None,
                        )
                new_state[inst.dest] = ctx

        # New specialized instructions
        elif isinstance(inst, MinOp):
            left_ctx = state.get(inst.left)
            right_ctx = state.get(inst.right)
            if left_ctx and right_ctx:
                ctx = TypeContext(base_type=left_ctx.base_type)
                # Result range is constrained by both inputs
                if left_ctx.range and right_ctx.range:
                    new_min = (
                        min(left_ctx.range.min, right_ctx.range.min)
                        if left_ctx.range.min is not None and right_ctx.range.min is not None
                        else None
                    )
                    new_max = (
                        min(left_ctx.range.max, right_ctx.range.max)
                        if left_ctx.range.max is not None and right_ctx.range.max is not None
                        else None
                    )
                    ctx.range = Range(new_min, new_max)
                new_state[inst.dest] = ctx

        elif isinstance(inst, MaxOp):
            left_ctx = state.get(inst.left)
            right_ctx = state.get(inst.right)
            if left_ctx and right_ctx:
                ctx = TypeContext(base_type=left_ctx.base_type)
                # Result range is constrained by both inputs
                if left_ctx.range and right_ctx.range:
                    new_min = (
                        max(left_ctx.range.min, right_ctx.range.min)
                        if left_ctx.range.min is not None and right_ctx.range.min is not None
                        else None
                    )
                    new_max = (
                        max(left_ctx.range.max, right_ctx.range.max)
                        if left_ctx.range.max is not None and right_ctx.range.max is not None
                        else None
                    )
                    ctx.range = Range(new_min, new_max)
                new_state[inst.dest] = ctx

        return new_state

    def meet(self, states: list[dict[MIRValue, TypeContext]]) -> dict[MIRValue, TypeContext]:
        """Meet operation for type states.

        Values that appear in only some of the states are carried over as they
        are. For a value that appears in several states, the ranges are
        combined; the other fields are taken from the first state that
        mentions the value.

        The input states and the ``TypeContext`` objects they hold are never
        modified: a merged context is always a fresh shallow copy.

        Args:
            states: Type states to join.

        Returns:
            The joined type state.
        """
        # Local import, in the same style as the other function-level imports in this module.
        import copy

        if not states:
            return {}

        result = states[0].copy()
        for state in states[1:]:
            for value, ctx in state.items():
                if value in result:
                    existing = result[value]
                    # Work on a copy: ``existing`` still belongs to an input
                    # state (and possibly to aliases created by Copy), so
                    # assigning to it would corrupt those states.
                    merged = copy.copy(existing)
                    if existing.range and ctx.range:
                        # Merge type contexts - union of ranges
                        merged.range = existing.range.union(ctx.range)
                    else:
                        # At least one side has no usable range, so no bound
                        # holds on every path: adopt that side's "no range" marker.
                        merged.range = ctx.range if existing.range else existing.range
                    # NOTE(review): base_type / nullable are not reconciled here;
                    # confirm whether differing base types should widen to UNKNOWN.
                    result[value] = merged
                else:
                    result[value] = ctx

        return result
|
195
|
+
|
196
|
+
|
197
|
+
class TypeSpecificOptimization(FunctionPass):
|
198
|
+
"""Type-specific optimization pass using modern dataflow framework.
|
199
|
+
|
200
|
+
This pass performs optimizations based on known type information:
|
201
|
+
- Type-aware constant folding with rich metadata
|
202
|
+
- Range-based optimizations
|
203
|
+
- Pattern-based transformations
|
204
|
+
- Cross-block type propagation
|
205
|
+
- Advanced algebraic simplifications
|
206
|
+
"""
|
207
|
+
|
208
|
+
def __init__(self) -> None:
    """Create the pass with empty analysis caches and zeroed statistics."""
    super().__init__()

    # Names of the counters kept in ``self.stats``; each one starts at zero.
    counter_names = (
        "constant_folded",
        "range_optimized",
        "patterns_matched",
        "cross_block_optimized",
        "specialized_instructions",
        "strength_reduced",
        "dead_code_eliminated",
        "boolean_optimized",
        "type_checks_eliminated",
        "integer_optimized",
        "float_optimized",
        "string_optimized",
        "instructions_removed",
    )

    # Type inference helper and the per-block type contexts it produces.
    self.type_analysis = TypeInference()
    self.type_contexts: dict[BasicBlock, dict[MIRValue, TypeContext]] = {}
    # Analysis results; unset until run_on_function computes them.
    self.use_def_chains: UseDefChains | None = None
    self.dominance_info: DominanceInfo | None = None
    self.stats = dict.fromkeys(counter_names, 0)
|
230
|
+
|
231
|
+
def get_info(self) -> PassInfo:
    """Describe this pass.

    Returns:
        A ``PassInfo`` giving the pass name, description, pass type, the
        analyses it requires and its preservation level.
    """
    required_analyses = ["use-def-chains", "dominance"]
    info = PassInfo(
        name="type-specific-optimization",
        description="Optimize MIR using type information and dataflow analysis",
        pass_type=PassType.OPTIMIZATION,
        requires=required_analyses,
        preserves=PreservationLevel.CFG,
    )
    return info
|
244
|
+
|
245
|
+
def finalize(self) -> None:
    """Finalize the pass after running.

    This pass has nothing to clean up, so the hook does nothing.
    """
    return None
|
248
|
+
|
249
|
+
def run_on_function(self, function: MIRFunction) -> bool:
    """Run type-specific optimizations on a function.

    The analyses are recomputed on every call and stored on the pass before
    any block is rewritten.

    Args:
        function: The function to optimize.

    Returns:
        True if the function was modified.
    """
    # Type inference result, stored on the pass for cross-block access.
    self.type_contexts = self.type_analysis.analyze(function)

    # Use-def chains and dominance info.
    self.use_def_chains = UseDefChainsAnalysis().run_on_function(function)
    self.dominance_info = DominanceAnalysis().run_on_function(function)

    # Optimize each block. The results are collected into a list first so
    # that every block is visited even after one of them reports a change.
    block_results = [self._optimize_block(block) for block in function.cfg.blocks.values()]
    changed = any(block_results)

    # Cross-block optimizations
    if self._optimize_cross_block(function):
        changed = True

    return changed
|
283
|
+
|
284
|
+
def _optimize_block(self, block: BasicBlock) -> bool:
    """Optimize a basic block.

    Each instruction is passed through ``_optimize_instruction``. The
    block's instruction list is replaced only if at least one instruction
    was rewritten or dropped.

    Args:
        block: The block to optimize.

    Returns:
        True if the block was modified.
    """
    # Type contexts recorded for this block (empty if none were recorded).
    known_types = self.type_contexts.get(block, {})
    rewritten = []
    changed = False

    for original in block.instructions:
        replacement = self._optimize_instruction(original, known_types)

        if replacement != original:
            changed = True

            if replacement is None:
                # The instruction was eliminated: count it and emit nothing.
                self.stats["dead_code_eliminated"] += 1
                continue

            # Update metadata on the optimized instruction from what is known
            # about the original destination.
            if hasattr(original, "dest") and original.dest in known_types:
                dest_ctx = known_types[original.dest]
                replacement.result_type = dest_ctx.base_type
                if hasattr(original.dest, "known_range"):
                    original.dest.known_range = dest_ctx.range
            rewritten.append(replacement)
        else:
            rewritten.append(original)

    if changed:
        block.instructions = rewritten

    return changed
|
321
|
+
|
322
|
+
def _optimize_instruction(
    self, inst: MIRInstruction, type_contexts: dict[MIRValue, TypeContext]
) -> MIRInstruction | None:
    """Optimize a single instruction.

    Handles three instruction kinds; any other instruction, and any
    instruction for which no rule applies, is returned unchanged.

    * ``BinaryOp``: boolean identities for ``and`` / ``or``, constant
      folding, pattern, range and strength-reduction rewrites (in that
      order), then ``x == x`` folding.
    * ``UnaryOp``: double ``not`` / double ``-`` elimination, inversion of a
      negated comparison, then constant folding.
    * ``ConditionalJump``: a constant condition becomes an unconditional
      ``Jump`` to the taken label.

    The relevant ``self.stats`` counters are incremented for every rewrite.

    Args:
        inst: The instruction to optimize.
        type_contexts: Type contexts for values.

    Returns:
        Optimized instruction or None if eliminated. No rule in this method
        itself returns None.
    """
    # Constant folding with type awareness
    if isinstance(inst, BinaryOp):
        # Boolean short-circuit optimizations
        if inst.op == "and":
            # False and x => False
            if isinstance(inst.left, Constant) and inst.left.value is False:
                self.stats["boolean_optimized"] += 1
                return LoadConst(inst.dest, Constant(False, MIRType.BOOL), inst.source_location)
            # x and False => False
            if isinstance(inst.right, Constant) and inst.right.value is False:
                self.stats["boolean_optimized"] += 1
                return LoadConst(inst.dest, Constant(False, MIRType.BOOL), inst.source_location)
            # True and x => x
            if isinstance(inst.left, Constant) and inst.left.value is True:
                self.stats["boolean_optimized"] += 1
                return Copy(inst.dest, inst.right, inst.source_location)
            # x and True => x
            if isinstance(inst.right, Constant) and inst.right.value is True:
                self.stats["boolean_optimized"] += 1
                return Copy(inst.dest, inst.left, inst.source_location)
            # x and x => x (idempotent)
            if inst.left == inst.right:
                self.stats["boolean_optimized"] += 1
                return Copy(inst.dest, inst.left, inst.source_location)
        elif inst.op == "or":
            # True or x => True
            if isinstance(inst.left, Constant) and inst.left.value is True:
                self.stats["boolean_optimized"] += 1
                return LoadConst(inst.dest, Constant(True, MIRType.BOOL), inst.source_location)
            # x or True => True
            if isinstance(inst.right, Constant) and inst.right.value is True:
                self.stats["boolean_optimized"] += 1
                return LoadConst(inst.dest, Constant(True, MIRType.BOOL), inst.source_location)
            # False or x => x
            if isinstance(inst.left, Constant) and inst.left.value is False:
                self.stats["boolean_optimized"] += 1
                return Copy(inst.dest, inst.right, inst.source_location)
            # x or False => x
            if isinstance(inst.right, Constant) and inst.right.value is False:
                self.stats["boolean_optimized"] += 1
                return Copy(inst.dest, inst.left, inst.source_location)

        # Try constant folding
        if isinstance(inst.left, Constant) and isinstance(inst.right, Constant):
            folded = self._fold_binary_constant(inst.op, inst.left, inst.right)
            if folded:
                self.stats["constant_folded"] += 1
                return LoadConst(inst.dest, folded, inst.source_location)

        # Pattern-based optimizations
        pattern_opt = self._optimize_patterns(inst, type_contexts)
        if pattern_opt and pattern_opt != inst:
            self.stats["patterns_matched"] += 1
            return pattern_opt

        # Range-based optimizations
        range_opt = self._optimize_with_ranges(inst, type_contexts)
        if range_opt and range_opt != inst:
            self.stats["range_optimized"] += 1
            return range_opt

        # Strength reduction
        strength_opt = self._apply_strength_reduction(inst)
        if strength_opt and strength_opt != inst:
            self.stats["strength_reduced"] += 1
            # Check if this is also an integer optimization: either an
            # operand's inferred type is INT, or an operand is an INT constant.
            if isinstance(inst.left, Constant | Variable | Temp) or isinstance(
                inst.right, Constant | Variable | Temp
            ):
                left_ctx = type_contexts.get(inst.left)
                right_ctx = type_contexts.get(inst.right)
                if (left_ctx and left_ctx.base_type == MIRType.INT) or (
                    right_ctx and right_ctx.base_type == MIRType.INT
                ):
                    self.stats["integer_optimized"] += 1
                elif isinstance(inst.left, Constant) and inst.left.type == MIRType.INT:
                    self.stats["integer_optimized"] += 1
                elif isinstance(inst.right, Constant) and inst.right.type == MIRType.INT:
                    self.stats["integer_optimized"] += 1
            return strength_opt

        # Self-equality optimization (x == x => True)
        if inst.op == "==" and inst.left == inst.right:
            self.stats["boolean_optimized"] += 1
            return LoadConst(inst.dest, Constant(True, MIRType.BOOL), inst.source_location)

    elif isinstance(inst, UnaryOp):
        # Double negation elimination
        if inst.op == "not":
            # Check if operand is result of another not operation
            if self.use_def_chains:
                def_inst = self.use_def_chains.get_definition(inst.operand)
                if isinstance(def_inst, UnaryOp) and def_inst.op == "not":
                    # not(not(x)) -> x
                    self.stats["boolean_optimized"] += 1
                    return Copy(inst.dest, def_inst.operand, inst.source_location)
                # Check for comparison inversion: not(x op y) -> x inv_op y
                elif isinstance(def_inst, BinaryOp):
                    inverted_op = self._invert_comparison(def_inst.op)
                    if inverted_op:
                        self.stats["boolean_optimized"] += 1
                        return BinaryOp(inst.dest, inverted_op, def_inst.left, def_inst.right, inst.source_location)
        elif inst.op == "-":
            # Check for double negation: -(-x) -> x
            if self.use_def_chains:
                def_inst = self.use_def_chains.get_definition(inst.operand)
                if isinstance(def_inst, UnaryOp) and def_inst.op == "-":
                    # -(-x) -> x
                    self.stats["integer_optimized"] += 1
                    return Copy(inst.dest, def_inst.operand, inst.source_location)

        # Constant folding
        if isinstance(inst.operand, Constant):
            folded = self._fold_unary_constant(inst.op, inst.operand)
            if folded:
                self.stats["constant_folded"] += 1
                return LoadConst(inst.dest, folded, inst.source_location)

    elif isinstance(inst, ConditionalJump):
        # Optimize conditional jumps with known conditions
        if isinstance(inst.condition, Constant):
            from machine_dialect.mir.mir_instructions import Jump

            # A truthy constant always takes the true label, a falsy one the false label.
            target = inst.true_label if inst.condition.value else inst.false_label
            if target:
                # Count the fold only when a Jump is actually produced; with a
                # missing label the instruction is left as it is.
                self.stats["constant_folded"] += 1
                # TODO: Verify if using inst.source_location is correct for optimization-generated instructions
                return Jump(target, inst.source_location)

    return inst
|
473
|
+
|
474
|
+
def _optimize_patterns(self, inst: BinaryOp, type_contexts: dict[MIRValue, TypeContext]) -> MIRInstruction | None:
|
475
|
+
"""Apply pattern-based optimizations.
|
476
|
+
|
477
|
+
Args:
|
478
|
+
inst: The binary operation.
|
479
|
+
type_contexts: Type contexts.
|
480
|
+
|
481
|
+
Returns:
|
482
|
+
Optimized instruction or None.
|
483
|
+
"""
|
484
|
+
# Bit manipulation patterns
|
485
|
+
if inst.op == "&" and self.use_def_chains:
|
486
|
+
# Pattern: x & (x - 1) - clears the lowest set bit
|
487
|
+
# Check if right operand is x - 1
|
488
|
+
right_def = self.use_def_chains.get_definition(inst.right)
|
489
|
+
if isinstance(right_def, BinaryOp) and right_def.op == "-":
|
490
|
+
if right_def.left == inst.left and isinstance(right_def.right, Constant) and right_def.right.value == 1:
|
491
|
+
# Found x & (x - 1) pattern
|
492
|
+
# This could be replaced with a PopCountOp or specialized instruction
|
493
|
+
# For now, just mark that we found it
|
494
|
+
self.stats["patterns_matched"] += 1
|
495
|
+
# Could optimize to a specialized instruction here
|
496
|
+
return inst
|
497
|
+
# Check if left operand is x - 1 (commutative)
|
498
|
+
left_def = self.use_def_chains.get_definition(inst.left)
|
499
|
+
if isinstance(left_def, BinaryOp) and left_def.op == "-":
|
500
|
+
if left_def.left == inst.right and isinstance(left_def.right, Constant) and left_def.right.value == 1:
|
501
|
+
# Found (x - 1) & x pattern
|
502
|
+
self.stats["patterns_matched"] += 1
|
503
|
+
return inst
|
504
|
+
|
505
|
+
# Min/max pattern detection
|
506
|
+
# Pattern: (a < b) ? a : b => min(a, b)
|
507
|
+
# Pattern: (a > b) ? a : b => max(a, b)
|
508
|
+
|
509
|
+
# For now, convert comparisons that will be used in select patterns
|
510
|
+
if inst.op in ["<", ">", "<=", ">="] and self.use_def_chains:
|
511
|
+
# Check if this comparison is used in a conditional
|
512
|
+
uses = self.use_def_chains.get_uses(inst.dest)
|
513
|
+
for use in uses:
|
514
|
+
if isinstance(use, ConditionalJump):
|
515
|
+
# Could potentially be converted to min/max
|
516
|
+
# This would require more complex pattern matching
|
517
|
+
pass
|
518
|
+
|
519
|
+
# Saturating arithmetic patterns
|
520
|
+
# Pattern: min(a + b, MAX_INT) => saturating_add(a, b)
|
521
|
+
if inst.op == "+" and self._is_saturating_pattern(inst, type_contexts):
|
522
|
+
self.stats["specialized_instructions"] += 1
|
523
|
+
# SaturatingAddOp doesn't take source_location - it's an optimization-generated instruction
|
524
|
+
return SaturatingAddOp(inst.dest, inst.left, inst.right)
|
525
|
+
|
526
|
+
# Identity operations
|
527
|
+
if self._is_identity_operation(inst):
|
528
|
+
return Copy(inst.dest, inst.left if inst.op in ["+", "-", "*", "/"] else inst.right, inst.source_location)
|
529
|
+
|
530
|
+
return inst
|
531
|
+
|
532
|
+
def _optimize_with_ranges(
|
533
|
+
self, inst: BinaryOp, type_contexts: dict[MIRValue, TypeContext]
|
534
|
+
) -> MIRInstruction | None:
|
535
|
+
"""Optimize using range information.
|
536
|
+
|
537
|
+
Args:
|
538
|
+
inst: The binary operation.
|
539
|
+
type_contexts: Type contexts with ranges.
|
540
|
+
|
541
|
+
Returns:
|
542
|
+
Optimized instruction or None.
|
543
|
+
"""
|
544
|
+
left_ctx = type_contexts.get(inst.left)
|
545
|
+
right_ctx = type_contexts.get(inst.right)
|
546
|
+
|
547
|
+
if not left_ctx or not right_ctx:
|
548
|
+
return inst
|
549
|
+
|
550
|
+
# Range-based comparison optimization
|
551
|
+
if inst.op in ["<", "<=", ">", ">="]:
|
552
|
+
if left_ctx.range and right_ctx.range:
|
553
|
+
# Check if comparison result is statically known
|
554
|
+
if inst.op == "<" and left_ctx.range.max is not None and right_ctx.range.min is not None:
|
555
|
+
if left_ctx.range.max < right_ctx.range.min:
|
556
|
+
# Always true
|
557
|
+
return LoadConst(inst.dest, Constant(True, MIRType.BOOL), inst.source_location)
|
558
|
+
elif left_ctx.range.min is not None and right_ctx.range.max is not None:
|
559
|
+
if left_ctx.range.min >= right_ctx.range.max:
|
560
|
+
# Always false
|
561
|
+
return LoadConst(inst.dest, Constant(False, MIRType.BOOL), inst.source_location)
|
562
|
+
|
563
|
+
# Division by power of 2 optimization
|
564
|
+
if inst.op == "/" and right_ctx.range and right_ctx.range.is_constant():
|
565
|
+
val = right_ctx.range.min
|
566
|
+
if isinstance(val, int) and val > 0 and (val & (val - 1)) == 0:
|
567
|
+
# Power of 2 - use shift
|
568
|
+
shift_amount = val.bit_length() - 1
|
569
|
+
return ShiftOp(inst.dest, inst.left, Constant(shift_amount, MIRType.INT), ">>", inst.source_location)
|
570
|
+
|
571
|
+
return inst
|
572
|
+
|
573
|
+
def _apply_strength_reduction(self, inst: BinaryOp) -> MIRInstruction | None:
    """Apply strength reduction optimizations.

    Rewrites operations with an integer constant on the right (multiply,
    divide and modulo by powers of two, small powers) and operations whose
    two operands are equal (``x - x``, ``x / x``, ``x ^ x``, ``x % x``)
    into cheaper instructions.

    NOTE(review): the shift and mask rewrites are only valid for integer
    left operands, and all folded results are typed ``MIRType.INT``; the
    left operand's type is not checked here - confirm callers guarantee it.

    Args:
        inst: The binary operation.

    Returns:
        Reduced instruction or original.
    """
    # Check for power-of-2 optimizations with right constant
    if isinstance(inst.right, Constant):
        val = inst.right.value
        if isinstance(val, int):
            # A positive power of two has exactly one bit set.
            is_power_of_two = val > 0 and (val & (val - 1)) == 0

            # Multiplication optimizations
            if inst.op == "*":
                if val == 0:
                    return LoadConst(inst.dest, Constant(0, MIRType.INT), inst.source_location)
                elif val == 1:
                    return Copy(inst.dest, inst.left, inst.source_location)
                elif val == 2:
                    # x * 2 -> x + x
                    return BinaryOp(inst.dest, "+", inst.left, inst.left, inst.source_location)
                elif val == -1:
                    return UnaryOp(inst.dest, "-", inst.left, inst.source_location)
                elif val > 2 and is_power_of_two:
                    # Power of 2 - use shift
                    shift_amount = val.bit_length() - 1
                    return ShiftOp(
                        inst.dest, inst.left, Constant(shift_amount, MIRType.INT), "<<", inst.source_location
                    )

            # Division by power of 2
            elif inst.op == "/" and is_power_of_two:
                shift_amount = val.bit_length() - 1
                return ShiftOp(
                    inst.dest, inst.left, Constant(shift_amount, MIRType.INT), ">>", inst.source_location
                )

            # Modulo by power of 2: keep only the low bits
            elif inst.op == "%" and is_power_of_two:
                mask_val = val - 1
                return BinaryOp(inst.dest, "&", inst.left, Constant(mask_val, MIRType.INT), inst.source_location)

            # Power optimizations
            elif inst.op == "**":
                if val == 0:
                    # x ** 0 -> 1
                    return LoadConst(inst.dest, Constant(1, MIRType.INT), inst.source_location)
                elif val == 1:
                    # x ** 1 -> x
                    return Copy(inst.dest, inst.left, inst.source_location)
                elif val == 2:
                    # x ** 2 -> x * x
                    return BinaryOp(inst.dest, "*", inst.left, inst.left, inst.source_location)

    # Self operations
    if inst.left == inst.right:
        if inst.op in ("-", "^"):
            # x - x => 0, x ^ x => 0
            return LoadConst(inst.dest, Constant(0, MIRType.INT), inst.source_location)
        # x / x => 1 and x % x => 0, but never for a literal zero operand:
        # folding 0 / 0 or 0 % 0 would hide a division-by-zero error.
        if inst.op in ("/", "%") and inst.left != Constant(0, MIRType.INT):
            folded_value = 1 if inst.op == "/" else 0
            return LoadConst(inst.dest, Constant(folded_value, MIRType.INT), inst.source_location)

    return inst
|
639
|
+
|
640
|
+
def _optimize_cross_block(self, function: MIRFunction) -> bool:
|
641
|
+
"""Perform cross-block optimizations.
|
642
|
+
|
643
|
+
Args:
|
644
|
+
function: The function to optimize.
|
645
|
+
|
646
|
+
Returns:
|
647
|
+
True if modified.
|
648
|
+
"""
|
649
|
+
modified = False
|
650
|
+
|
651
|
+
# Use dominance information for more aggressive optimizations
|
652
|
+
if not self.dominance_info:
|
653
|
+
return False
|
654
|
+
|
655
|
+
for block in function.cfg.blocks.values():
|
656
|
+
# Find values that are constant along all paths to this block
|
657
|
+
for inst in block.instructions:
|
658
|
+
if isinstance(inst, BinaryOp):
|
659
|
+
# Check if operands have consistent values from dominators
|
660
|
+
if self._has_consistent_value_from_dominators(inst.left, block):
|
661
|
+
# Can optimize based on dominator information
|
662
|
+
self.stats["cross_block_optimized"] += 1
|
663
|
+
modified = True
|
664
|
+
|
665
|
+
return modified
|
666
|
+
|
667
|
+
def _has_consistent_value_from_dominators(self, value: MIRValue, block: BasicBlock) -> bool:
|
668
|
+
"""Check if a value has consistent type/range from dominators.
|
669
|
+
|
670
|
+
Args:
|
671
|
+
value: The value to check.
|
672
|
+
block: The current block.
|
673
|
+
|
674
|
+
Returns:
|
675
|
+
True if value is consistent.
|
676
|
+
"""
|
677
|
+
if not self.dominance_info or not self.use_def_chains:
|
678
|
+
return False
|
679
|
+
|
680
|
+
# Get definition of value
|
681
|
+
def_inst = self.use_def_chains.get_definition(value)
|
682
|
+
if not def_inst:
|
683
|
+
return False
|
684
|
+
|
685
|
+
# Check if definition dominates this block
|
686
|
+
def_block = self._find_block_for_instruction(def_inst)
|
687
|
+
if def_block and self.dominance_info.dominates(def_block, block):
|
688
|
+
return True
|
689
|
+
|
690
|
+
return False
|
691
|
+
|
692
|
+
def _find_block_for_instruction(self, inst: MIRInstruction) -> BasicBlock | None:
|
693
|
+
"""Find which block contains an instruction.
|
694
|
+
|
695
|
+
Args:
|
696
|
+
inst: The instruction to find.
|
697
|
+
|
698
|
+
Returns:
|
699
|
+
The containing block or None.
|
700
|
+
"""
|
701
|
+
# This would need access to the function's CFG
|
702
|
+
# For now, return None
|
703
|
+
return None
|
704
|
+
|
705
|
+
def _is_saturating_pattern(self, inst: BinaryOp, type_contexts: dict[MIRValue, TypeContext]) -> bool:
|
706
|
+
"""Check if this is a saturating arithmetic pattern.
|
707
|
+
|
708
|
+
Args:
|
709
|
+
inst: The instruction.
|
710
|
+
type_contexts: Type contexts.
|
711
|
+
|
712
|
+
Returns:
|
713
|
+
True if saturating pattern.
|
714
|
+
"""
|
715
|
+
# Simple heuristic for now
|
716
|
+
if inst.op != "+":
|
717
|
+
return False
|
718
|
+
|
719
|
+
# Check if result is used in a min operation with a constant
|
720
|
+
if self.use_def_chains:
|
721
|
+
uses = self.use_def_chains.get_uses(inst.dest)
|
722
|
+
for use in uses:
|
723
|
+
if isinstance(use, MinOp):
|
724
|
+
return True
|
725
|
+
|
726
|
+
return False
|
727
|
+
|
728
|
+
def _is_identity_operation(self, inst: BinaryOp) -> bool:
    """Recognise operations that leave one operand unchanged.

    Matches ``x + 0``, ``x - 0``, ``x * 1`` and ``x / 1`` with the constant
    on the right, and ``0 + x`` and ``1 * x`` with the constant on the left.

    Args:
        inst: The binary operation.

    Returns:
        True if the operation is one of the identities above.
    """
    operator_symbol = inst.op

    rhs = inst.right
    if isinstance(rhs, Constant):
        rhs_value = rhs.value
        # Additive identities use 0, multiplicative ones use 1.
        if operator_symbol in ("+", "-") and rhs_value == 0:
            return True
        if operator_symbol in ("*", "/") and rhs_value == 1:
            return True

    lhs = inst.left
    if isinstance(lhs, Constant):
        lhs_value = lhs.value
        # Only the commutative operators have a left-hand identity.
        if operator_symbol == "+" and lhs_value == 0:
            return True
        if operator_symbol == "*" and lhs_value == 1:
            return True

    return False
|
756
|
+
|
757
|
+
def _fold_binary_constant(self, op: str, left: Constant, right: Constant) -> Constant | None:
    """Fold a binary operation on two constants at compile time.

    Supported operand combinations are INT/INT, any pair involving a FLOAT
    (both values are converted with ``float()``), BOOL/BOOL and
    STRING/STRING. Integer ``/`` folds as floor division.

    Folding is refused (``None`` is returned) whenever evaluating the
    operation would fail or would produce a value that does not match the
    declared result type: division or modulo by zero, an integer power with
    a negative exponent (the result would be a float), a float power with a
    complex result, numeric overflow, or incompatible operand values.

    Args:
        op: The operator.
        left: Left constant.
        right: Right constant.

    Returns:
        Folded constant or None.
    """
    import operator

    comparisons = {
        "==": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
    }
    int_arithmetic = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.floordiv,
        "%": operator.mod,
        "**": operator.pow,
        "&": operator.and_,
        "|": operator.or_,
        "^": operator.xor,
        "<<": operator.lshift,
        ">>": operator.rshift,
    }
    float_arithmetic = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.truediv,
        "**": operator.pow,
    }

    try:
        left_val = left.value
        right_val = right.value

        # Integer operations
        if left.type == MIRType.INT and right.type == MIRType.INT:
            if op in ("/", "%") and right_val == 0:
                return None
            if op == "**" and right_val < 0:
                # int ** negative int evaluates to a float (or raises for a
                # zero base); it cannot be represented as an INT constant.
                return None
            if op in int_arithmetic:
                return Constant(int_arithmetic[op](left_val, right_val), MIRType.INT)
            if op in comparisons:
                return Constant(comparisons[op](left_val, right_val), MIRType.BOOL)

        # Float operations
        elif left.type == MIRType.FLOAT or right.type == MIRType.FLOAT:
            left_val = float(left_val)
            right_val = float(right_val)

            if op == "/" and right_val == 0.0:
                return None
            if op in float_arithmetic:
                result = float_arithmetic[op](left_val, right_val)
                if isinstance(result, complex):
                    # Negative base with a fractional exponent: not a FLOAT.
                    return None
                return Constant(result, MIRType.FLOAT)
            if op in comparisons:
                return Constant(comparisons[op](left_val, right_val), MIRType.BOOL)

        # Boolean operations
        elif left.type == MIRType.BOOL and right.type == MIRType.BOOL:
            if op == "and":
                return Constant(left_val and right_val, MIRType.BOOL)
            if op == "or":
                return Constant(left_val or right_val, MIRType.BOOL)
            if op in ("==", "!="):
                return Constant(comparisons[op](left_val, right_val), MIRType.BOOL)

        # String operations
        elif left.type == MIRType.STRING and right.type == MIRType.STRING:
            if op == "+":
                return Constant(str(left_val) + str(right_val), MIRType.STRING)
            if op in ("==", "!="):
                return Constant(comparisons[op](left_val, right_val), MIRType.BOOL)

    except (ValueError, TypeError, ZeroDivisionError, OverflowError):
        # Best effort: an operation that cannot be evaluated here (bad
        # operand values, negative shift count, float overflow, ...) is
        # simply left for runtime.
        pass

    return None
|
863
|
+
|
864
|
+
def _invert_comparison(self, op: str) -> str | None:
|
865
|
+
"""Get the inverted comparison operator.
|
866
|
+
|
867
|
+
Args:
|
868
|
+
op: The comparison operator.
|
869
|
+
|
870
|
+
Returns:
|
871
|
+
The inverted operator or None if not a comparison.
|
872
|
+
"""
|
873
|
+
inversions = {
|
874
|
+
"==": "!=",
|
875
|
+
"!=": "==",
|
876
|
+
"<": ">=",
|
877
|
+
"<=": ">",
|
878
|
+
">": "<=",
|
879
|
+
">=": "<",
|
880
|
+
}
|
881
|
+
return inversions.get(op)
|
882
|
+
|
883
|
+
def _fold_unary_constant(self, op: str, operand: Constant) -> Constant | None:
|
884
|
+
"""Fold unary operation on constant.
|
885
|
+
|
886
|
+
Args:
|
887
|
+
op: The operator.
|
888
|
+
operand: The operand.
|
889
|
+
|
890
|
+
Returns:
|
891
|
+
Folded constant or None.
|
892
|
+
"""
|
893
|
+
try:
|
894
|
+
if op == "-":
|
895
|
+
if operand.type == MIRType.INT:
|
896
|
+
return Constant(-operand.value, MIRType.INT)
|
897
|
+
elif operand.type == MIRType.FLOAT:
|
898
|
+
return Constant(-operand.value, MIRType.FLOAT)
|
899
|
+
elif op == "not":
|
900
|
+
return Constant(not operand.value, MIRType.BOOL)
|
901
|
+
elif op == "~" and operand.type == MIRType.INT:
|
902
|
+
return Constant(~operand.value, MIRType.INT)
|
903
|
+
except (ValueError, TypeError):
|
904
|
+
pass
|
905
|
+
|
906
|
+
return None
|