machine-dialect 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machine_dialect/__main__.py +667 -0
- machine_dialect/agent/__init__.py +5 -0
- machine_dialect/agent/agent.py +360 -0
- machine_dialect/ast/__init__.py +95 -0
- machine_dialect/ast/ast_node.py +35 -0
- machine_dialect/ast/call_expression.py +82 -0
- machine_dialect/ast/dict_extraction.py +60 -0
- machine_dialect/ast/expressions.py +439 -0
- machine_dialect/ast/literals.py +309 -0
- machine_dialect/ast/program.py +35 -0
- machine_dialect/ast/statements.py +1433 -0
- machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
- machine_dialect/ast/tests/test_boolean_literal.py +29 -0
- machine_dialect/ast/tests/test_collection_hir.py +138 -0
- machine_dialect/ast/tests/test_define_statement.py +142 -0
- machine_dialect/ast/tests/test_desugar.py +541 -0
- machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
- machine_dialect/cfg/__init__.py +6 -0
- machine_dialect/cfg/config.py +156 -0
- machine_dialect/cfg/examples.py +221 -0
- machine_dialect/cfg/generate_with_ai.py +187 -0
- machine_dialect/cfg/openai_generation.py +200 -0
- machine_dialect/cfg/parser.py +94 -0
- machine_dialect/cfg/tests/__init__.py +1 -0
- machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
- machine_dialect/cfg/tests/test_config.py +188 -0
- machine_dialect/cfg/tests/test_examples.py +391 -0
- machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
- machine_dialect/cfg/tests/test_openai_generation.py +256 -0
- machine_dialect/codegen/__init__.py +5 -0
- machine_dialect/codegen/bytecode_module.py +89 -0
- machine_dialect/codegen/bytecode_serializer.py +300 -0
- machine_dialect/codegen/opcodes.py +101 -0
- machine_dialect/codegen/register_codegen.py +1996 -0
- machine_dialect/codegen/symtab.py +208 -0
- machine_dialect/codegen/tests/__init__.py +1 -0
- machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
- machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
- machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
- machine_dialect/codegen/tests/test_symtab.py +418 -0
- machine_dialect/codegen/vm_serializer.py +621 -0
- machine_dialect/compiler/__init__.py +18 -0
- machine_dialect/compiler/compiler.py +197 -0
- machine_dialect/compiler/config.py +149 -0
- machine_dialect/compiler/context.py +149 -0
- machine_dialect/compiler/phases/__init__.py +19 -0
- machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
- machine_dialect/compiler/phases/codegen.py +40 -0
- machine_dialect/compiler/phases/hir_generation.py +39 -0
- machine_dialect/compiler/phases/mir_generation.py +86 -0
- machine_dialect/compiler/phases/optimization.py +110 -0
- machine_dialect/compiler/phases/parsing.py +39 -0
- machine_dialect/compiler/pipeline.py +143 -0
- machine_dialect/compiler/tests/__init__.py +1 -0
- machine_dialect/compiler/tests/test_compiler.py +568 -0
- machine_dialect/compiler/vm_runner.py +173 -0
- machine_dialect/errors/__init__.py +32 -0
- machine_dialect/errors/exceptions.py +369 -0
- machine_dialect/errors/messages.py +82 -0
- machine_dialect/errors/tests/__init__.py +0 -0
- machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
- machine_dialect/errors/tests/test_name_errors.py +118 -0
- machine_dialect/helpers/__init__.py +0 -0
- machine_dialect/helpers/stopwords.py +225 -0
- machine_dialect/helpers/validators.py +30 -0
- machine_dialect/lexer/__init__.py +9 -0
- machine_dialect/lexer/constants.py +23 -0
- machine_dialect/lexer/lexer.py +907 -0
- machine_dialect/lexer/tests/__init__.py +0 -0
- machine_dialect/lexer/tests/helpers.py +86 -0
- machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
- machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
- machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
- machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
- machine_dialect/lexer/tests/test_comments.py +200 -0
- machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
- machine_dialect/lexer/tests/test_lexer_position.py +113 -0
- machine_dialect/lexer/tests/test_list_tokens.py +282 -0
- machine_dialect/lexer/tests/test_stopwords.py +80 -0
- machine_dialect/lexer/tests/test_strict_equality.py +129 -0
- machine_dialect/lexer/tests/test_token.py +41 -0
- machine_dialect/lexer/tests/test_tokenization.py +294 -0
- machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
- machine_dialect/lexer/tests/test_url_literals.py +169 -0
- machine_dialect/lexer/tokens.py +487 -0
- machine_dialect/linter/__init__.py +10 -0
- machine_dialect/linter/__main__.py +144 -0
- machine_dialect/linter/linter.py +154 -0
- machine_dialect/linter/rules/__init__.py +8 -0
- machine_dialect/linter/rules/base.py +112 -0
- machine_dialect/linter/rules/statement_termination.py +99 -0
- machine_dialect/linter/tests/__init__.py +1 -0
- machine_dialect/linter/tests/mdrules/__init__.py +0 -0
- machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
- machine_dialect/linter/tests/test_linter.py +81 -0
- machine_dialect/linter/tests/test_rules.py +110 -0
- machine_dialect/linter/tests/test_violations.py +71 -0
- machine_dialect/linter/violations.py +51 -0
- machine_dialect/mir/__init__.py +69 -0
- machine_dialect/mir/analyses/__init__.py +20 -0
- machine_dialect/mir/analyses/alias_analysis.py +315 -0
- machine_dialect/mir/analyses/dominance_analysis.py +49 -0
- machine_dialect/mir/analyses/escape_analysis.py +286 -0
- machine_dialect/mir/analyses/loop_analysis.py +272 -0
- machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
- machine_dialect/mir/analyses/type_analysis.py +448 -0
- machine_dialect/mir/analyses/use_def_chains.py +232 -0
- machine_dialect/mir/basic_block.py +385 -0
- machine_dialect/mir/dataflow.py +445 -0
- machine_dialect/mir/debug_info.py +208 -0
- machine_dialect/mir/hir_to_mir.py +1738 -0
- machine_dialect/mir/mir_dumper.py +366 -0
- machine_dialect/mir/mir_function.py +167 -0
- machine_dialect/mir/mir_instructions.py +1877 -0
- machine_dialect/mir/mir_interpreter.py +556 -0
- machine_dialect/mir/mir_module.py +225 -0
- machine_dialect/mir/mir_printer.py +480 -0
- machine_dialect/mir/mir_transformer.py +410 -0
- machine_dialect/mir/mir_types.py +367 -0
- machine_dialect/mir/mir_validation.py +455 -0
- machine_dialect/mir/mir_values.py +268 -0
- machine_dialect/mir/optimization_config.py +233 -0
- machine_dialect/mir/optimization_pass.py +251 -0
- machine_dialect/mir/optimization_pipeline.py +355 -0
- machine_dialect/mir/optimizations/__init__.py +84 -0
- machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
- machine_dialect/mir/optimizations/branch_prediction.py +372 -0
- machine_dialect/mir/optimizations/constant_propagation.py +634 -0
- machine_dialect/mir/optimizations/cse.py +398 -0
- machine_dialect/mir/optimizations/dce.py +288 -0
- machine_dialect/mir/optimizations/inlining.py +551 -0
- machine_dialect/mir/optimizations/jump_threading.py +487 -0
- machine_dialect/mir/optimizations/licm.py +405 -0
- machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
- machine_dialect/mir/optimizations/strength_reduction.py +422 -0
- machine_dialect/mir/optimizations/tail_call.py +207 -0
- machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
- machine_dialect/mir/optimizations/type_narrowing.py +397 -0
- machine_dialect/mir/optimizations/type_specialization.py +447 -0
- machine_dialect/mir/optimizations/type_specific.py +906 -0
- machine_dialect/mir/optimize_mir.py +89 -0
- machine_dialect/mir/pass_manager.py +391 -0
- machine_dialect/mir/profiling/__init__.py +26 -0
- machine_dialect/mir/profiling/profile_collector.py +318 -0
- machine_dialect/mir/profiling/profile_data.py +372 -0
- machine_dialect/mir/profiling/profile_reader.py +272 -0
- machine_dialect/mir/profiling/profile_writer.py +226 -0
- machine_dialect/mir/register_allocation.py +302 -0
- machine_dialect/mir/reporting/__init__.py +17 -0
- machine_dialect/mir/reporting/optimization_reporter.py +314 -0
- machine_dialect/mir/reporting/report_formatter.py +289 -0
- machine_dialect/mir/ssa_construction.py +342 -0
- machine_dialect/mir/tests/__init__.py +1 -0
- machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
- machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
- machine_dialect/mir/tests/test_algebraic_division.py +126 -0
- machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
- machine_dialect/mir/tests/test_basic_block.py +425 -0
- machine_dialect/mir/tests/test_branch_prediction.py +459 -0
- machine_dialect/mir/tests/test_call_lowering.py +168 -0
- machine_dialect/mir/tests/test_collection_lowering.py +604 -0
- machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
- machine_dialect/mir/tests/test_custom_passes.py +166 -0
- machine_dialect/mir/tests/test_debug_info.py +285 -0
- machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
- machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
- machine_dialect/mir/tests/test_double_negation.py +231 -0
- machine_dialect/mir/tests/test_escape_analysis.py +233 -0
- machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
- machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
- machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
- machine_dialect/mir/tests/test_inlining.py +435 -0
- machine_dialect/mir/tests/test_licm.py +472 -0
- machine_dialect/mir/tests/test_mir_dumper.py +313 -0
- machine_dialect/mir/tests/test_mir_instructions.py +445 -0
- machine_dialect/mir/tests/test_mir_module.py +860 -0
- machine_dialect/mir/tests/test_mir_printer.py +387 -0
- machine_dialect/mir/tests/test_mir_types.py +123 -0
- machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
- machine_dialect/mir/tests/test_mir_validation.py +378 -0
- machine_dialect/mir/tests/test_mir_values.py +168 -0
- machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
- machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
- machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
- machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
- machine_dialect/mir/tests/test_pass_manager.py +294 -0
- machine_dialect/mir/tests/test_pass_registration.py +64 -0
- machine_dialect/mir/tests/test_profiling.py +356 -0
- machine_dialect/mir/tests/test_register_allocation.py +307 -0
- machine_dialect/mir/tests/test_report_formatters.py +372 -0
- machine_dialect/mir/tests/test_ssa_construction.py +433 -0
- machine_dialect/mir/tests/test_tail_call.py +236 -0
- machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
- machine_dialect/mir/tests/test_type_narrowing.py +277 -0
- machine_dialect/mir/tests/test_type_specialization.py +421 -0
- machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
- machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
- machine_dialect/mir/type_inference.py +368 -0
- machine_dialect/parser/__init__.py +12 -0
- machine_dialect/parser/enums.py +45 -0
- machine_dialect/parser/parser.py +3655 -0
- machine_dialect/parser/protocols.py +11 -0
- machine_dialect/parser/symbol_table.py +169 -0
- machine_dialect/parser/tests/__init__.py +0 -0
- machine_dialect/parser/tests/helper_functions.py +193 -0
- machine_dialect/parser/tests/test_action_statements.py +334 -0
- machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
- machine_dialect/parser/tests/test_call_statements.py +154 -0
- machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
- machine_dialect/parser/tests/test_collection_mutations.py +264 -0
- machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
- machine_dialect/parser/tests/test_define_integration.py +468 -0
- machine_dialect/parser/tests/test_define_statements.py +311 -0
- machine_dialect/parser/tests/test_dict_extraction.py +115 -0
- machine_dialect/parser/tests/test_empty_literal.py +155 -0
- machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
- machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
- machine_dialect/parser/tests/test_if_empty_block.py +61 -0
- machine_dialect/parser/tests/test_if_statements.py +299 -0
- machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
- machine_dialect/parser/tests/test_infix_expressions.py +680 -0
- machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
- machine_dialect/parser/tests/test_interaction_statements.py +269 -0
- machine_dialect/parser/tests/test_list_literals.py +277 -0
- machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
- machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
- machine_dialect/parser/tests/test_parse_errors.py +114 -0
- machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
- machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
- machine_dialect/parser/tests/test_program.py +13 -0
- machine_dialect/parser/tests/test_return_statements.py +89 -0
- machine_dialect/parser/tests/test_set_statements.py +152 -0
- machine_dialect/parser/tests/test_strict_equality.py +258 -0
- machine_dialect/parser/tests/test_symbol_table.py +217 -0
- machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
- machine_dialect/parser/tests/test_utility_statements.py +423 -0
- machine_dialect/parser/token_buffer.py +159 -0
- machine_dialect/repl/__init__.py +3 -0
- machine_dialect/repl/repl.py +426 -0
- machine_dialect/repl/tests/__init__.py +0 -0
- machine_dialect/repl/tests/test_repl.py +606 -0
- machine_dialect/semantic/__init__.py +12 -0
- machine_dialect/semantic/analyzer.py +906 -0
- machine_dialect/semantic/error_messages.py +189 -0
- machine_dialect/semantic/tests/__init__.py +1 -0
- machine_dialect/semantic/tests/test_analyzer.py +364 -0
- machine_dialect/semantic/tests/test_error_messages.py +104 -0
- machine_dialect/tests/edge_cases/__init__.py +10 -0
- machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
- machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
- machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
- machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
- machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
- machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
- machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
- machine_dialect/tests/integration/test_list_compilation.py +395 -0
- machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
- machine_dialect/type_checking/__init__.py +21 -0
- machine_dialect/type_checking/tests/__init__.py +1 -0
- machine_dialect/type_checking/tests/test_type_system.py +230 -0
- machine_dialect/type_checking/type_system.py +270 -0
- machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
- machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
- machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
- machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
- machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
- machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
- machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,422 @@
|
|
1
|
+
"""Strength reduction optimization pass.
|
2
|
+
|
3
|
+
This module implements strength reduction at the MIR level, replacing
|
4
|
+
expensive operations with cheaper equivalents.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
8
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
9
|
+
from machine_dialect.mir.mir_instructions import (
|
10
|
+
BinaryOp,
|
11
|
+
Copy,
|
12
|
+
LoadConst,
|
13
|
+
MIRInstruction,
|
14
|
+
UnaryOp,
|
15
|
+
)
|
16
|
+
from machine_dialect.mir.mir_transformer import MIRTransformer
|
17
|
+
from machine_dialect.mir.mir_types import MIRType
|
18
|
+
from machine_dialect.mir.mir_values import Constant, MIRValue
|
19
|
+
from machine_dialect.mir.optimization_pass import (
|
20
|
+
OptimizationPass,
|
21
|
+
PassInfo,
|
22
|
+
PassType,
|
23
|
+
PreservationLevel,
|
24
|
+
)
|
25
|
+
|
26
|
+
|
27
|
+
class StrengthReduction(OptimizationPass):
|
28
|
+
"""Strength reduction optimization pass."""
|
29
|
+
|
30
|
+
def get_info(self) -> PassInfo:
    """Describe this pass.

    Returns:
        A ``PassInfo`` named ``"strength-reduction"`` of type
        ``PassType.OPTIMIZATION`` that requires no other pass and declares
        ``PreservationLevel.CFG``.
    """
    info = PassInfo(
        name="strength-reduction",
        description="Replace expensive operations with cheaper equivalents",
        pass_type=PassType.OPTIMIZATION,
        requires=[],
        preserves=PreservationLevel.CFG,
    )
    return info
|
43
|
+
|
44
|
+
def run_on_function(self, function: MIRFunction) -> bool:
    """Apply strength reduction to every basic block of a function.

    Args:
        function: The function to optimize.

    Returns:
        The ``modified`` flag of the transformer after all blocks in
        ``function.cfg.blocks`` have been visited.
    """
    rewriter = MIRTransformer(function)
    all_blocks = function.cfg.blocks
    for current in all_blocks.values():
        self._reduce_block(current, rewriter)
    return rewriter.modified
|
59
|
+
|
60
|
+
def _reduce_block(self, block: BasicBlock, transformer: MIRTransformer) -> None:
|
61
|
+
"""Apply strength reduction to a block.
|
62
|
+
|
63
|
+
Args:
|
64
|
+
block: The block to optimize.
|
65
|
+
transformer: MIR transformer.
|
66
|
+
"""
|
67
|
+
for inst in list(block.instructions):
|
68
|
+
if isinstance(inst, BinaryOp):
|
69
|
+
self._reduce_binary_op(inst, block, transformer)
|
70
|
+
elif isinstance(inst, UnaryOp):
|
71
|
+
self._reduce_unary_op(inst, block, transformer)
|
72
|
+
|
73
|
+
def _reduce_binary_op(
|
74
|
+
self,
|
75
|
+
inst: BinaryOp,
|
76
|
+
block: BasicBlock,
|
77
|
+
transformer: MIRTransformer,
|
78
|
+
) -> None:
|
79
|
+
"""Apply strength reduction to a binary operation.
|
80
|
+
|
81
|
+
Args:
|
82
|
+
inst: Binary operation instruction.
|
83
|
+
block: Containing block.
|
84
|
+
transformer: MIR transformer.
|
85
|
+
"""
|
86
|
+
# Check for multiplication by power of 2
|
87
|
+
if inst.op == "*":
|
88
|
+
# Check for x * 1 or 1 * x first (special cases)
|
89
|
+
if self._is_one(inst.right) or self._is_one(inst.left):
|
90
|
+
# Don't convert to shift, let algebraic simplifications handle it
|
91
|
+
pass
|
92
|
+
elif self._is_power_of_two_constant(inst.right):
|
93
|
+
shift = self._get_power_of_two(inst.right)
|
94
|
+
if shift is not None and shift > 0: # Only optimize for shift > 0
|
95
|
+
# Replace multiplication with left shift
|
96
|
+
shift_const = Constant(shift, MIRType.INT)
|
97
|
+
new_inst = BinaryOp(inst.dest, "<<", inst.left, shift_const, inst.source_location)
|
98
|
+
transformer.replace_instruction(block, inst, new_inst)
|
99
|
+
self.stats["multiply_to_shift"] = self.stats.get("multiply_to_shift", 0) + 1
|
100
|
+
return
|
101
|
+
elif self._is_power_of_two_constant(inst.left):
|
102
|
+
shift = self._get_power_of_two(inst.left)
|
103
|
+
if shift is not None and shift > 0: # Only optimize for shift > 0
|
104
|
+
# Replace multiplication with left shift (commutative)
|
105
|
+
shift_const = Constant(shift, MIRType.INT)
|
106
|
+
new_inst = BinaryOp(inst.dest, "<<", inst.right, shift_const, inst.source_location)
|
107
|
+
transformer.replace_instruction(block, inst, new_inst)
|
108
|
+
self.stats["multiply_to_shift"] = self.stats.get("multiply_to_shift", 0) + 1
|
109
|
+
return
|
110
|
+
|
111
|
+
# Check for division by power of 2
|
112
|
+
elif inst.op in ["/", "//"]:
|
113
|
+
# Check for x / 1 first (special case)
|
114
|
+
if self._is_one(inst.right):
|
115
|
+
# Don't convert to shift, let algebraic simplifications handle it
|
116
|
+
pass
|
117
|
+
elif self._is_power_of_two_constant(inst.right):
|
118
|
+
shift = self._get_power_of_two(inst.right)
|
119
|
+
if shift is not None and shift > 0: # Only optimize for shift > 0
|
120
|
+
# Replace division with right shift (for integers)
|
121
|
+
shift_const = Constant(shift, MIRType.INT)
|
122
|
+
new_inst = BinaryOp(inst.dest, ">>", inst.left, shift_const, inst.source_location)
|
123
|
+
transformer.replace_instruction(block, inst, new_inst)
|
124
|
+
self.stats["divide_to_shift"] = self.stats.get("divide_to_shift", 0) + 1
|
125
|
+
return
|
126
|
+
|
127
|
+
# Check for modulo by power of 2
|
128
|
+
elif inst.op == "%":
|
129
|
+
if self._is_power_of_two_constant(inst.right):
|
130
|
+
power = self._get_constant_value(inst.right)
|
131
|
+
if power is not None and power > 0:
|
132
|
+
# Replace modulo with bitwise AND (n % power = n & (power - 1))
|
133
|
+
mask = Constant(power - 1, MIRType.INT)
|
134
|
+
new_inst = BinaryOp(inst.dest, "&", inst.left, mask, inst.source_location)
|
135
|
+
transformer.replace_instruction(block, inst, new_inst)
|
136
|
+
self.stats["modulo_to_and"] = self.stats.get("modulo_to_and", 0) + 1
|
137
|
+
return
|
138
|
+
|
139
|
+
# Algebraic simplifications
|
140
|
+
self._apply_algebraic_simplifications(inst, block, transformer)
|
141
|
+
|
142
|
+
def _reduce_unary_op(
|
143
|
+
self,
|
144
|
+
inst: UnaryOp,
|
145
|
+
block: BasicBlock,
|
146
|
+
transformer: MIRTransformer,
|
147
|
+
) -> None:
|
148
|
+
"""Apply strength reduction to a unary operation.
|
149
|
+
|
150
|
+
Args:
|
151
|
+
inst: Unary operation instruction.
|
152
|
+
block: Containing block.
|
153
|
+
transformer: MIR transformer.
|
154
|
+
"""
|
155
|
+
# Double negation elimination
|
156
|
+
if inst.op == "-":
|
157
|
+
# Check if operand is result of another negation
|
158
|
+
# This would require tracking def-use chains
|
159
|
+
pass
|
160
|
+
|
161
|
+
# Boolean not simplification
|
162
|
+
elif inst.op == "not":
|
163
|
+
# Check for not(not(x)) pattern
|
164
|
+
pass
|
165
|
+
|
166
|
+
def _apply_algebraic_simplifications(
|
167
|
+
self,
|
168
|
+
inst: BinaryOp,
|
169
|
+
block: BasicBlock,
|
170
|
+
transformer: MIRTransformer,
|
171
|
+
) -> None:
|
172
|
+
"""Apply algebraic simplifications to binary operations.
|
173
|
+
|
174
|
+
Args:
|
175
|
+
inst: Binary operation instruction.
|
176
|
+
block: Containing block.
|
177
|
+
transformer: MIR transformer.
|
178
|
+
"""
|
179
|
+
new_inst: MIRInstruction
|
180
|
+
# Identity operations
|
181
|
+
if inst.op == "+":
|
182
|
+
# x + 0 = x
|
183
|
+
if self._is_zero(inst.right):
|
184
|
+
new_inst = Copy(inst.dest, inst.left, inst.source_location)
|
185
|
+
transformer.replace_instruction(block, inst, new_inst)
|
186
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
187
|
+
return
|
188
|
+
elif self._is_zero(inst.left):
|
189
|
+
new_inst = Copy(inst.dest, inst.right, inst.source_location)
|
190
|
+
transformer.replace_instruction(block, inst, new_inst)
|
191
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
192
|
+
return
|
193
|
+
|
194
|
+
elif inst.op == "-":
|
195
|
+
# x - 0 = x
|
196
|
+
if self._is_zero(inst.right):
|
197
|
+
new_inst = Copy(inst.dest, inst.left, inst.source_location)
|
198
|
+
transformer.replace_instruction(block, inst, new_inst)
|
199
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
200
|
+
return
|
201
|
+
# x - x = 0
|
202
|
+
elif self._values_equal(inst.left, inst.right):
|
203
|
+
zero = Constant(0, MIRType.INT)
|
204
|
+
new_inst = LoadConst(inst.dest, zero, inst.source_location)
|
205
|
+
transformer.replace_instruction(block, inst, new_inst)
|
206
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
207
|
+
return
|
208
|
+
|
209
|
+
elif inst.op == "*":
|
210
|
+
# x * 0 = 0
|
211
|
+
if self._is_zero(inst.right) or self._is_zero(inst.left):
|
212
|
+
zero = Constant(0, MIRType.INT)
|
213
|
+
new_inst = LoadConst(inst.dest, zero, inst.source_location)
|
214
|
+
transformer.replace_instruction(block, inst, new_inst)
|
215
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
216
|
+
return
|
217
|
+
# x * 1 = x
|
218
|
+
elif self._is_one(inst.right):
|
219
|
+
new_inst = Copy(inst.dest, inst.left, inst.source_location)
|
220
|
+
transformer.replace_instruction(block, inst, new_inst)
|
221
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
222
|
+
return
|
223
|
+
elif self._is_one(inst.left):
|
224
|
+
new_inst = Copy(inst.dest, inst.right, inst.source_location)
|
225
|
+
transformer.replace_instruction(block, inst, new_inst)
|
226
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
227
|
+
return
|
228
|
+
# x * -1 = -x
|
229
|
+
elif self._is_negative_one(inst.right):
|
230
|
+
new_inst = UnaryOp(inst.dest, "-", inst.left, inst.source_location)
|
231
|
+
transformer.replace_instruction(block, inst, new_inst)
|
232
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
233
|
+
return
|
234
|
+
elif self._is_negative_one(inst.left):
|
235
|
+
new_inst = UnaryOp(inst.dest, "-", inst.right, inst.source_location)
|
236
|
+
transformer.replace_instruction(block, inst, new_inst)
|
237
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
238
|
+
return
|
239
|
+
|
240
|
+
elif inst.op in ["/", "//"]:
|
241
|
+
# x / 1 = x
|
242
|
+
if self._is_one(inst.right):
|
243
|
+
new_inst = Copy(inst.dest, inst.left, inst.source_location)
|
244
|
+
transformer.replace_instruction(block, inst, new_inst)
|
245
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
246
|
+
return
|
247
|
+
# x / x = 1 (if x != 0)
|
248
|
+
elif self._values_equal(inst.left, inst.right):
|
249
|
+
one = Constant(1, MIRType.INT)
|
250
|
+
new_inst = LoadConst(inst.dest, one, inst.source_location)
|
251
|
+
transformer.replace_instruction(block, inst, new_inst)
|
252
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
253
|
+
return
|
254
|
+
|
255
|
+
# Boolean operations
|
256
|
+
elif inst.op == "and":
|
257
|
+
# x and False = False (check this first as it's stronger)
|
258
|
+
if self._is_false(inst.right) or self._is_false(inst.left):
|
259
|
+
false = Constant(False, MIRType.BOOL)
|
260
|
+
new_inst = LoadConst(inst.dest, false, inst.source_location)
|
261
|
+
transformer.replace_instruction(block, inst, new_inst)
|
262
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
263
|
+
return
|
264
|
+
# x and True = x
|
265
|
+
elif self._is_true(inst.right):
|
266
|
+
new_inst = Copy(inst.dest, inst.left, inst.source_location)
|
267
|
+
transformer.replace_instruction(block, inst, new_inst)
|
268
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
269
|
+
return
|
270
|
+
elif self._is_true(inst.left):
|
271
|
+
new_inst = Copy(inst.dest, inst.right, inst.source_location)
|
272
|
+
transformer.replace_instruction(block, inst, new_inst)
|
273
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
274
|
+
return
|
275
|
+
|
276
|
+
elif inst.op == "or":
|
277
|
+
# x or True = True (check this first as it's stronger)
|
278
|
+
if self._is_true(inst.right) or self._is_true(inst.left):
|
279
|
+
true = Constant(True, MIRType.BOOL)
|
280
|
+
new_inst = LoadConst(inst.dest, true, inst.source_location)
|
281
|
+
transformer.replace_instruction(block, inst, new_inst)
|
282
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
283
|
+
return
|
284
|
+
# x or False = x
|
285
|
+
elif self._is_false(inst.right):
|
286
|
+
new_inst = Copy(inst.dest, inst.left, inst.source_location)
|
287
|
+
transformer.replace_instruction(block, inst, new_inst)
|
288
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
289
|
+
return
|
290
|
+
elif self._is_false(inst.left):
|
291
|
+
new_inst = Copy(inst.dest, inst.right, inst.source_location)
|
292
|
+
transformer.replace_instruction(block, inst, new_inst)
|
293
|
+
self.stats["algebraic_simplified"] = self.stats.get("algebraic_simplified", 0) + 1
|
294
|
+
return
|
295
|
+
|
296
|
+
def _is_power_of_two_constant(self, value: MIRValue) -> bool:
    """Tell whether a value is a ``Constant`` holding a positive power of two.

    Python booleans are integers, so a ``True`` constant counts as ``2**0``.

    Args:
        value: Value to check.

    Returns:
        True if the value is a power of 2 constant.
    """
    if not isinstance(value, Constant):
        return False
    number = value.value
    if not isinstance(number, int) or number <= 0:
        return False
    # A number with exactly one bit set shares no bits with its predecessor.
    return (number & (number - 1)) == 0
|
311
|
+
|
312
|
+
def _get_power_of_two(self, value: MIRValue) -> int | None:
    """Return ``k`` when the value is a ``Constant`` holding ``2**k``.

    Args:
        value: Power of 2 constant.

    Returns:
        The exponent, or None when the value is not a constant positive
        integer power of two.
    """
    if not isinstance(value, Constant):
        return None
    number = value.value
    single_bit = isinstance(number, int) and number > 0 and (number & (number - 1)) == 0
    if not single_bit:
        return None
    # 2**k occupies k + 1 bits, so the exponent is the bit length minus one.
    return number.bit_length() - 1
|
331
|
+
|
332
|
+
def _get_constant_value(self, value: MIRValue) -> int | float | bool | None:
    """Unwrap the payload of a numeric or boolean ``Constant``.

    Args:
        value: MIR value.

    Returns:
        The wrapped int, float or bool, or None when the value is not a
        ``Constant`` or wraps some other kind of payload.
    """
    if not isinstance(value, Constant):
        return None
    payload = value.value
    return payload if isinstance(payload, (int, float, bool)) else None
|
346
|
+
|
347
|
+
def _is_zero(self, value: MIRValue) -> bool:
|
348
|
+
"""Check if value is zero.
|
349
|
+
|
350
|
+
Args:
|
351
|
+
value: Value to check.
|
352
|
+
|
353
|
+
Returns:
|
354
|
+
True if value is zero.
|
355
|
+
"""
|
356
|
+
val = self._get_constant_value(value)
|
357
|
+
return val == 0
|
358
|
+
|
359
|
+
def _is_one(self, value: MIRValue) -> bool:
|
360
|
+
"""Check if value is one.
|
361
|
+
|
362
|
+
Args:
|
363
|
+
value: Value to check.
|
364
|
+
|
365
|
+
Returns:
|
366
|
+
True if value is one.
|
367
|
+
"""
|
368
|
+
val = self._get_constant_value(value)
|
369
|
+
return val == 1
|
370
|
+
|
371
|
+
def _is_negative_one(self, value: MIRValue) -> bool:
|
372
|
+
"""Check if value is negative one.
|
373
|
+
|
374
|
+
Args:
|
375
|
+
value: Value to check.
|
376
|
+
|
377
|
+
Returns:
|
378
|
+
True if value is -1.
|
379
|
+
"""
|
380
|
+
val = self._get_constant_value(value)
|
381
|
+
return val == -1
|
382
|
+
|
383
|
+
def _is_true(self, value: MIRValue) -> bool:
|
384
|
+
"""Check if value is boolean true.
|
385
|
+
|
386
|
+
Args:
|
387
|
+
value: Value to check.
|
388
|
+
|
389
|
+
Returns:
|
390
|
+
True if value is boolean true.
|
391
|
+
"""
|
392
|
+
val = self._get_constant_value(value)
|
393
|
+
return val is True
|
394
|
+
|
395
|
+
def _is_false(self, value: MIRValue) -> bool:
|
396
|
+
"""Check if value is boolean false.
|
397
|
+
|
398
|
+
Args:
|
399
|
+
value: Value to check.
|
400
|
+
|
401
|
+
Returns:
|
402
|
+
True if value is boolean false.
|
403
|
+
"""
|
404
|
+
val = self._get_constant_value(value)
|
405
|
+
return val is False
|
406
|
+
|
407
|
+
def _values_equal(self, v1: MIRValue, v2: MIRValue) -> bool:
|
408
|
+
"""Check if two values are equal.
|
409
|
+
|
410
|
+
Args:
|
411
|
+
v1: First value.
|
412
|
+
v2: Second value.
|
413
|
+
|
414
|
+
Returns:
|
415
|
+
True if values are equal.
|
416
|
+
"""
|
417
|
+
# Simple equality check - could be enhanced
|
418
|
+
return v1 == v2
|
419
|
+
|
420
|
+
def finalize(self) -> None:
    """Finalize the pass; there is nothing to do for this pass."""
    return None
|
@@ -0,0 +1,207 @@
|
|
1
|
+
"""Tail call optimization pass.
|
2
|
+
|
3
|
+
This module implements tail call optimization to transform recursive calls
|
4
|
+
in tail position into jumps, eliminating stack growth for tail-recursive functions.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from machine_dialect.mir.basic_block import BasicBlock
|
8
|
+
from machine_dialect.mir.mir_function import MIRFunction
|
9
|
+
from machine_dialect.mir.mir_instructions import Call, Copy, Return
|
10
|
+
from machine_dialect.mir.mir_module import MIRModule
|
11
|
+
from machine_dialect.mir.optimization_pass import (
|
12
|
+
ModulePass,
|
13
|
+
PassInfo,
|
14
|
+
PassType,
|
15
|
+
PreservationLevel,
|
16
|
+
)
|
17
|
+
|
18
|
+
|
19
|
+
class TailCallOptimization(ModulePass):
    """Tail call optimization pass.

    This pass identifies function calls in tail position and marks them
    for optimization by setting ``is_tail_call`` on the ``Call``
    instruction. Within a single basic block, a call is recognized as a
    tail call when:

    1. It is immediately followed by a ``Return`` whose value equals the
       call's destination. A void call followed by a bare return also
       matches this rule, because both ``dest`` and ``value`` are ``None``.
    2. Its result is copied once and that copy is immediately returned.

    The pass only annotates calls. The actual transformation to jumps
    happens during bytecode generation.
    """

    def __init__(self) -> None:
        """Initialize the tail call optimization pass."""
        super().__init__()
        # Counters accumulated across every module this instance processes.
        self.stats: dict[str, int] = {
            "tail_calls_found": 0,
            "functions_optimized": 0,
            "recursive_tail_calls": 0,
        }

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="tail-call",
            description="Optimize tail calls into jumps",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.CFG,
        )

    def finalize(self) -> None:
        """Finalize the pass after running.

        Override from base class - no special finalization needed.
        """
        pass

    def run_on_module(self, module: MIRModule) -> bool:
        """Run tail call optimization on a module.

        Args:
            module: The module to optimize.

        Returns:
            True if the module was modified.
        """
        modified = False

        # Process each function; the name is needed to spot self-recursion.
        for func_name, function in module.functions.items():
            if self._optimize_tail_calls_in_function(function, func_name):
                modified = True
                self.stats["functions_optimized"] += 1

        return modified

    def _optimize_tail_calls_in_function(self, function: MIRFunction, func_name: str) -> bool:
        """Optimize tail calls in a single function.

        Args:
            function: The function to optimize.
            func_name: Name of the function (for recursive call detection).

        Returns:
            True if the function was modified.
        """
        modified = False

        # Every block must be visited (no short-circuit): marking is a side effect.
        for block in function.cfg.blocks.values():
            if self._optimize_tail_calls_in_block(block, func_name):
                modified = True

        return modified

    def _mark_tail_call(self, call: Call, func_name: str) -> None:
        """Flag a call as a tail call and update the statistics.

        Args:
            call: The call instruction found in tail position.
            func_name: Name of the containing function, used to count
                self-recursive tail calls.
        """
        call.is_tail_call = True
        self.stats["tail_calls_found"] += 1

        # The callee operand is only compared when it exposes a ``name``.
        if hasattr(call.func, "name") and call.func.name == func_name:
            self.stats["recursive_tail_calls"] += 1

    def _optimize_tail_calls_in_block(self, block: BasicBlock, func_name: str) -> bool:
        """Optimize tail calls in a basic block.

        Only patterns fully contained in ``block`` are recognized: a call
        directly followed by a return of its result, or a call whose result
        is copied once and the copy returned.

        Args:
            block: The basic block to process.
            func_name: Name of the containing function.

        Returns:
            True if the block was modified.
        """
        modified = False
        instructions = block.instructions
        count = len(instructions)

        for i, inst in enumerate(instructions):
            # Only calls that are not already marked are candidates.
            if not isinstance(inst, Call) or inst.is_tail_call:
                continue

            # A call that ends the block has no in-block return to pair with.
            if i + 1 >= count:
                continue
            next_inst = instructions[i + 1]

            if isinstance(next_inst, Return) and next_inst.value == inst.dest:
                # Direct return of the call result. This also covers a void
                # call followed by a bare return, since None == None.
                is_tail = True
            elif i + 2 < count and isinstance(next_inst, Copy):
                # Call result copied to a variable, then that variable returned.
                third_inst = instructions[i + 2]
                is_tail = (
                    isinstance(third_inst, Return)
                    and next_inst.source == inst.dest
                    and third_inst.value == next_inst.dest
                )
            else:
                is_tail = False

            if is_tail:
                self._mark_tail_call(inst, func_name)
                modified = True

        return modified

    def _is_tail_position(self, block: BasicBlock, instruction_index: int) -> bool:
        """Check if an instruction is in tail position.

        This is a block-local approximation: the instruction is considered
        to be in tail position when every instruction after it, up to the
        first ``Return``, is a ``Copy``. Copies are accepted without
        checking what they move, and successor blocks are not inspected.

        Args:
            block: The basic block containing the instruction.
            instruction_index: Index of the instruction in the block.

        Returns:
            True if the instruction is in tail position.
        """
        instructions = block.instructions

        # Scan the remaining instructions after this one.
        for i in range(instruction_index + 1, len(instructions)):
            inst = instructions[i]

            # Reaching a return means nothing else ran in between.
            if isinstance(inst, Return):
                return True

            # Copies are tolerated (assumed to just move the result).
            if isinstance(inst, Copy):
                continue

            # Any other instruction means not in tail position.
            return False

        # End of block without a return: deciding this would require
        # following successors in the CFG, which is not implemented.
        return False

    def get_statistics(self) -> dict[str, int]:
        """Get optimization statistics.

        Returns:
            Dictionary of statistics (the pass's own ``stats`` mapping,
            not a copy).
        """
        return self.stats
|