machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,398 @@
1
+ """Common Subexpression Elimination (CSE) optimization pass.
2
+
3
+ This module implements CSE at the MIR level, eliminating redundant
4
+ computations by reusing previously computed values.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any
9
+
10
+ from machine_dialect.mir.basic_block import BasicBlock
11
+ from machine_dialect.mir.mir_function import MIRFunction
12
+ from machine_dialect.mir.mir_instructions import (
13
+ BinaryOp,
14
+ Call,
15
+ Copy,
16
+ LoadConst,
17
+ MIRInstruction,
18
+ StoreVar,
19
+ UnaryOp,
20
+ )
21
+ from machine_dialect.mir.mir_transformer import MIRTransformer
22
+ from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
23
+ from machine_dialect.mir.optimization_pass import (
24
+ OptimizationPass,
25
+ PassInfo,
26
+ PassType,
27
+ PreservationLevel,
28
+ )
29
+
30
+
31
@dataclass(frozen=True)
class Expression:
    """Hashable key describing one computation considered by CSE.

    Two instructions that produce equal ``Expression`` objects compute the
    same thing, so the later one can reuse the earlier result.

    Attributes:
        op: Operation type.
        operands: Tuple of operands.
    """

    op: str
    operands: tuple[Any, ...]

    def __hash__(self) -> int:
        """Hash the expression from its operation and operands."""
        identity = (self.op, self.operands)
        return hash(identity)
46
+
47
+
48
class AvailableExpressions:
    """Tracks available expressions in a block.

    ``expressions`` maps every available expression to the value that holds
    its result. ``definitions`` is the inverse index, mapping each holder to
    the expressions it currently holds. The two maps are kept consistent by
    every mutating method.
    """

    def __init__(self) -> None:
        """Initialize an empty set of available expressions."""
        # Map from expression to the value containing it
        self.expressions: dict[Expression, MIRValue] = {}
        # Map from value to expressions it defines
        self.definitions: dict[MIRValue, set[Expression]] = {}

    @staticmethod
    def _operand_key(value: MIRValue) -> Any:
        """Return the normalized form under which ``value`` appears in operands.

        Expression operands are stored in normalized form, so a raw value must
        be converted before it can be matched against them. This mirrors
        ``CommonSubexpressionElimination._normalize_value`` and must be kept
        in sync with it.

        Args:
            value: The value to normalize.

        Returns:
            The normalized operand representation.
        """
        if isinstance(value, Constant):
            return ("const", value.value, value.type)
        if isinstance(value, Variable):
            return ("var", value.name)
        if isinstance(value, Temp):
            return ("temp", value.id)
        return str(value)

    def _unlink(self, expr: Expression, holder: MIRValue) -> None:
        """Remove ``expr`` from the inverse index entry of ``holder``.

        Args:
            expr: The expression that ``holder`` no longer holds.
            holder: The value that used to hold the expression.
        """
        held = self.definitions.get(holder)
        if held is None:
            return
        held.discard(expr)
        if not held:
            del self.definitions[holder]

    def add(self, expr: Expression, value: MIRValue) -> None:
        """Add an available expression.

        If the expression was already held by a different value, that older
        association is dropped so the inverse index never goes stale.

        Args:
            expr: The expression.
            value: The value containing the expression.
        """
        previous = self.expressions.get(expr)
        if previous is not None and previous != value:
            self._unlink(expr, previous)
        self.expressions[expr] = value
        self.definitions.setdefault(value, set()).add(expr)

    def get(self, expr: Expression) -> MIRValue | None:
        """Get the value for an expression.

        Args:
            expr: The expression to look up.

        Returns:
            The value containing the expression or None.
        """
        return self.expressions.get(expr)

    def invalidate(self, value: MIRValue) -> None:
        """Invalidate expressions involving a value.

        Removes every expression that reads ``value`` as an operand and every
        expression whose result is held by ``value``.

        Args:
            value: The value that changed.
        """
        # Operands are stored normalized, so match on the normalized key; the
        # raw comparison is kept for operands that were stored un-normalized.
        key = self._operand_key(value)
        readers = [expr for expr in self.expressions if key in expr.operands or value in expr.operands]
        for expr in readers:
            holder = self.expressions.pop(expr)
            self._unlink(expr, holder)

        # Remove expressions defined by this value. Only drop an expression
        # if this value is still its holder.
        for expr in self.definitions.pop(value, ()):
            if expr in self.expressions and self.expressions[expr] == value:
                del self.expressions[expr]

    def copy(self) -> "AvailableExpressions":
        """Create a copy of available expressions.

        Returns:
            A copy whose maps can be mutated independently of this one.
        """
        new = AvailableExpressions()
        new.expressions = self.expressions.copy()
        new.definitions = {holder: held.copy() for holder, held in self.definitions.items()}
        return new
113
+
114
+
115
class CommonSubexpressionElimination(OptimizationPass):
    """Common subexpression elimination optimization pass.

    The pass first removes redundant computations inside each basic block
    (local CSE), then runs a forward available-expressions dataflow analysis
    and removes computations that are redundant across blocks (global CSE).
    A redundant instruction is rewritten into a ``Copy`` from the value that
    already holds the result.
    """

    # Binary operators whose operands are sorted when building expression keys.
    _COMMUTATIVE_OPS = ("+", "*", "==", "!=", "and", "or")

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="cse",
            description="Eliminate common subexpressions",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.CFG,
        )

    def run_on_function(self, function: MIRFunction) -> bool:
        """Run CSE on a function.

        Args:
            function: The function to optimize.

        Returns:
            True if the function was modified.
        """
        transformer = MIRTransformer(function)

        # Perform local CSE within each block
        for block in function.cfg.blocks.values():
            self._local_cse(block, transformer)

        # Perform global CSE across blocks
        self._global_cse(function, transformer)

        return transformer.modified

    def _local_cse(self, block: BasicBlock, transformer: MIRTransformer) -> None:
        """Perform local CSE within a block.

        Args:
            block: The block to optimize.
            transformer: MIR transformer.
        """
        self._eliminate_in_block(block, AvailableExpressions(), transformer, "local_cse_eliminated")

    def _global_cse(self, function: MIRFunction, transformer: MIRTransformer) -> None:
        """Perform global CSE across blocks.

        Solves the available-expressions equations by iterating to a fixed
        point: the entry set of a block is the intersection of the exit sets
        of its predecessors, and the exit set is the entry set pushed through
        the block's instructions. All sets start empty, so the solution only
        contains facts that were actually derived along every incoming path.

        Args:
            function: The function to optimize.
            transformer: MIR transformer.
        """
        blocks = list(function.cfg.blocks.values())
        entry_sets: dict[BasicBlock, AvailableExpressions] = {b: AvailableExpressions() for b in blocks}
        exit_sets: dict[BasicBlock, AvailableExpressions] = {b: AvailableExpressions() for b in blocks}

        # NOTE(review): assumes the CFG exposes its entry block as
        # ``entry_block``; if the attribute is absent this guard is inert.
        # Nothing is available on function entry, even if the entry block is
        # also a jump target. Confirm the attribute name.
        entry_block = getattr(function.cfg, "entry_block", None)

        # Iterate until fixed point
        changed = True
        while changed:
            changed = False

            for block in blocks:
                # Available at entry = intersection of predecessor EXIT sets.
                if block.predecessors and block is not entry_block:
                    incoming = self._intersect_available(
                        [exit_sets.get(pred, AvailableExpressions()) for pred in block.predecessors]
                    )
                else:
                    incoming = AvailableExpressions()

                # Available at exit = entry set pushed through the block.
                outgoing = incoming.copy()
                for inst in block.instructions:
                    self._flow_through(inst, outgoing)

                if self._available_changed(entry_sets[block], incoming):
                    entry_sets[block] = incoming
                    changed = True
                if self._available_changed(exit_sets[block], outgoing):
                    exit_sets[block] = outgoing
                    changed = True

        # Apply CSE based on available expressions
        for block in blocks:
            self._eliminate_in_block(block, entry_sets[block].copy(), transformer, "global_cse_eliminated")

    def _eliminate_in_block(
        self,
        block: BasicBlock,
        available: AvailableExpressions,
        transformer: MIRTransformer,
        stat_key: str,
    ) -> None:
        """Replace redundant computations in one block with copies.

        Args:
            block: The block to rewrite.
            available: Expressions available on entry; updated in place while
                walking the block.
            transformer: MIR transformer used to perform replacements.
            stat_key: Statistics counter incremented per elimination.
        """
        # Iterate over a snapshot because instructions are replaced in place.
        for inst in list(block.instructions):
            expr = self._get_expression(inst)
            result = self._get_result(inst) if expr is not None else None
            existing = available.get(expr) if expr is not None else None

            if result is not None and existing is not None and existing != result:
                # Replace with copy of existing value
                source_loc = getattr(inst, "source_location", (0, 0))
                transformer.replace_instruction(block, inst, Copy(result, existing, source_loc))
                self.stats[stat_key] = self.stats.get(stat_key, 0) + 1
                # ``existing`` stays the holder; only facts about the value
                # that was just overwritten become stale.
                self._apply_effects(inst, available, None, result)
            else:
                self._apply_effects(inst, available, expr, result)

    def _flow_through(self, inst: MIRInstruction, available: AvailableExpressions) -> None:
        """Update ``available`` with the effect of executing ``inst``.

        Args:
            inst: The instruction.
            available: Available expressions to update.
        """
        expr = self._get_expression(inst)
        result = self._get_result(inst) if expr is not None else None
        self._apply_effects(inst, available, expr, result)

    def _apply_effects(
        self,
        inst: MIRInstruction,
        available: AvailableExpressions,
        expr: Expression | None,
        result: MIRValue | None,
    ) -> None:
        """Kill facts invalidated by ``inst`` and record what it computes.

        Args:
            inst: The instruction.
            available: Available expressions to update.
            expr: Expression computed by ``inst``, or None to record nothing.
            result: Value receiving the result of ``expr``, if any.
        """
        # Every value written by the instruction invalidates expressions that
        # read it and expressions it used to hold.
        for defined in inst.get_defs() or ():
            self._kill_value(defined, available)

        if expr is not None and result is not None:
            # ``x = x + 1`` leaves nothing reusable: once x is overwritten the
            # expression no longer describes the value stored in x.
            if self._normalize_value(result) not in expr.operands:
                holder = available.get(expr)
                if holder is not None and holder != result:
                    self._forget(expr, available)
                available.add(expr, result)

        # Update available expressions based on side effects
        self._update_available(inst, available)

    def _get_expression(self, inst: MIRInstruction) -> Expression | None:
        """Extract expression from an instruction.

        Args:
            inst: The instruction.

        Returns:
            The expression, or None if the instruction is not a CSE candidate
            or its expression cannot be hashed.
        """
        expr: Expression | None = None

        if isinstance(inst, BinaryOp):
            left = self._normalize_value(inst.left)
            right = self._normalize_value(inst.right)
            if inst.op in self._COMMUTATIVE_OPS:
                # Sort operands so that a + b and b + a share one key.
                operands = tuple(sorted([left, right], key=str))
            else:
                operands = (left, right)
            expr = Expression(f"binary_{inst.op}", operands)

        elif isinstance(inst, UnaryOp):
            expr = Expression(f"unary_{inst.op}", (self._normalize_value(inst.operand),))

        elif isinstance(inst, LoadConst):
            # Constants are their own expressions
            expr = Expression("const", (inst.constant.value, inst.constant.type))

        if expr is None:
            return None

        try:
            hash(expr)
        except TypeError:
            # An unhashable constant payload cannot be used as a dict key, so
            # the instruction is simply not a CSE candidate.
            return None
        return expr

    def _normalize_value(self, value: MIRValue) -> Any:
        """Normalize a value for expression comparison.

        Args:
            value: The value to normalize.

        Returns:
            Normalized representation.
        """
        if isinstance(value, Constant):
            return ("const", value.value, value.type)
        elif isinstance(value, Variable):
            return ("var", value.name)
        elif isinstance(value, Temp):
            return ("temp", value.id)
        else:
            return str(value)

    @staticmethod
    def _is_variable_key(operand: Any) -> bool:
        """Tell whether a normalized operand refers to a variable.

        Args:
            operand: An element of ``Expression.operands``.

        Returns:
            True if the operand has the ``("var", name)`` shape.
        """
        return isinstance(operand, tuple) and len(operand) == 2 and operand[0] == "var"

    def _get_result(self, inst: MIRInstruction) -> MIRValue | None:
        """Get the result value of an instruction.

        Args:
            inst: The instruction.

        Returns:
            The single value defined by the instruction, or None if it
            defines no value or more than one.
        """
        defs = inst.get_defs()
        if defs and len(defs) == 1:
            return defs[0]
        return None

    def _forget(self, expr: Expression, available: AvailableExpressions) -> None:
        """Remove one expression, keeping both maps of ``available`` in sync.

        Args:
            expr: The expression to remove.
            available: Available expressions to update.
        """
        holder = available.expressions.pop(expr, None)
        if holder is None:
            return
        held = available.definitions.get(holder)
        if held is not None:
            held.discard(expr)
            if not held:
                del available.definitions[holder]

    def _kill_value(self, value: MIRValue, available: AvailableExpressions) -> None:
        """Remove every fact that depends on the old contents of ``value``.

        Args:
            value: The value that changed.
            available: Available expressions to update.
        """
        # Operands are stored normalized, so match on the normalized key
        # rather than on the raw MIR value.
        key = self._normalize_value(value)
        for expr in [e for e in available.expressions if key in e.operands]:
            self._forget(expr, available)
        # Drops the expressions whose result is held by ``value``.
        available.invalidate(value)

    def _update_available(
        self,
        inst: MIRInstruction,
        available: AvailableExpressions,
    ) -> None:
        """Update available expressions after an instruction.

        Args:
            inst: The instruction.
            available: Available expressions to update.
        """
        # Invalidate expressions if instruction has side effects
        if isinstance(inst, StoreVar):
            # Invalidate expressions using this variable
            self._kill_value(inst.var, available)
        elif isinstance(inst, Call):
            # Conservative: invalidate all expressions with variables
            # (calls might modify globals or have other side effects)
            reads_variable = [
                expr
                for expr in available.expressions
                if expr.op != "const" and any(self._is_variable_key(operand) for operand in expr.operands)
            ]
            for expr in reads_variable:
                self._forget(expr, available)
            for value in list(available.definitions):
                if isinstance(value, Variable):
                    available.invalidate(value)

    def _intersect_available(
        self,
        sets: list[AvailableExpressions],
    ) -> AvailableExpressions:
        """Compute intersection of available expression sets.

        An expression survives only if every set contains it with the same
        holder value.

        Args:
            sets: List of available expression sets.

        Returns:
            The intersection.
        """
        result = AvailableExpressions()
        if not sets:
            return result

        first, rest = sets[0], sets[1:]
        for expr, value in first.expressions.items():
            if all(expr in other.expressions and other.expressions[expr] == value for other in rest):
                result.add(expr, value)

        return result

    def _available_changed(
        self,
        old: AvailableExpressions,
        new: AvailableExpressions,
    ) -> bool:
        """Check if available expressions changed.

        Args:
            old: Old available expressions.
            new: New available expressions.

        Returns:
            True if changed.
        """
        if len(old.expressions) != len(new.expressions):
            return True

        return any(
            expr not in new.expressions or new.expressions[expr] != value
            for expr, value in old.expressions.items()
        )

    def finalize(self) -> None:
        """Finalize the pass."""
        pass
@@ -0,0 +1,288 @@
1
+ """Dead code elimination optimization pass.
2
+
3
+ This module implements dead code elimination (DCE) at the MIR level,
4
+ removing instructions and blocks that have no effect on program output.
5
+ """
6
+
7
+ from machine_dialect.mir.analyses.use_def_chains import UseDefChains
8
+ from machine_dialect.mir.basic_block import BasicBlock
9
+ from machine_dialect.mir.mir_function import MIRFunction
10
+ from machine_dialect.mir.mir_instructions import (
11
+ Assert,
12
+ Call,
13
+ ConditionalJump,
14
+ Jump,
15
+ MIRInstruction,
16
+ Print,
17
+ Return,
18
+ Scope,
19
+ StoreVar,
20
+ )
21
+ from machine_dialect.mir.mir_transformer import MIRTransformer
22
+ from machine_dialect.mir.mir_values import Temp, Variable
23
+ from machine_dialect.mir.optimization_pass import (
24
+ OptimizationPass,
25
+ PassInfo,
26
+ PassType,
27
+ PreservationLevel,
28
+ )
29
+
30
+
31
class DeadCodeElimination(OptimizationPass):
    """Dead code elimination optimization pass.

    Removes, in order: side-effect-free instructions whose results are
    dead, stores whose value is never read, unreachable blocks, and then
    asks the transformer to simplify the control-flow graph.
    """

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        # NOTE(review): the pass removes blocks and simplifies the CFG while
        # declaring PreservationLevel.CFG - confirm this is the intended level.
        return PassInfo(
            name="dce",
            description="Eliminate dead code and unreachable blocks",
            pass_type=[PassType.OPTIMIZATION, PassType.CLEANUP],
            requires=["use-def-chains"],
            preserves=PreservationLevel.CFG,
        )

    def run_on_function(self, function: MIRFunction) -> bool:
        """Run dead code elimination on a function.

        Args:
            function: The function to optimize.

        Returns:
            True if the function was modified.
        """
        transformer = MIRTransformer(function)

        # Get use-def chains
        use_def_chains: UseDefChains = self.get_analysis("use-def-chains", function)

        # Phase 1: Remove dead instructions
        for block, inst in self._find_dead_instructions(function, use_def_chains):
            transformer.remove_instruction(block, inst)
            self._bump_stat("dead_instructions_removed")

        # Phase 2: Remove dead stores.
        # A store reported here has no reachable read, so it is simply
        # dropped. The variable's uses are deliberately NOT rewritten to the
        # store's source: for an overwritten store that would redirect reads
        # that belong to the later store (or reads that precede this one).
        for block, inst in self._find_dead_stores(function, use_def_chains):
            transformer.remove_instruction(block, inst)
            self._bump_stat("dead_stores_removed")

        # Phase 3: Remove unreachable blocks.
        # Accumulated like the other counters so the total is not clobbered
        # by the last function processed.
        num_unreachable = transformer.eliminate_unreachable_blocks()
        self._bump_stat("unreachable_blocks_removed", num_unreachable)

        # Phase 4: Simplify control flow
        transformer.simplify_cfg()

        return transformer.modified

    def _bump_stat(self, key: str, amount: int = 1) -> None:
        """Add ``amount`` to the counter ``key`` in ``self.stats``.

        Args:
            key: Statistic name.
            amount: Increment to apply (the key is created at 0 if missing).
        """
        self.stats[key] = self.stats.get(key, 0) + amount

    def _find_dead_instructions(
        self,
        function: MIRFunction,
        use_def_chains: UseDefChains,
    ) -> list[tuple[BasicBlock, MIRInstruction]]:
        """Find dead instructions that can be removed.

        An instruction qualifies when it has no side effects, defines at
        least one value, and every ``Temp``/``Variable`` it defines is
        reported dead by the use-def chains.

        Args:
            function: The function to analyze.
            use_def_chains: Use-def chain information.

        Returns:
            List of (block, instruction) pairs to remove.
        """
        dead: list[tuple[BasicBlock, MIRInstruction]] = []

        for block in function.cfg.blocks.values():
            for inst in block.instructions:
                # Instructions with side effects must stay regardless of uses.
                if self._has_side_effects(inst):
                    continue

                defs = inst.get_defs()
                if not defs:
                    continue

                # Only Temp/Variable definitions are checked for liveness.
                if all(
                    use_def_chains.is_dead(value)
                    for value in defs
                    if isinstance(value, Temp | Variable)
                ):
                    dead.append((block, inst))

        return dead

    def _find_dead_stores(
        self,
        function: MIRFunction,
        use_def_chains: UseDefChains,
    ) -> list[tuple[BasicBlock, MIRInstruction]]:
        """Find dead store instructions.

        Two kinds of stores are reported per block: a store overwritten by
        a later store to the same variable with no read in between, and the
        last store to a variable when nothing reads it afterwards.

        Args:
            function: The function to analyze.
            use_def_chains: Use-def chain information.

        Returns:
            List of (block, instruction) pairs to remove.
        """
        dead_stores: list[tuple[BasicBlock, MIRInstruction]] = []

        for block in function.cfg.blocks.values():
            # Most recent store to each variable seen so far in this block.
            last_stores: dict[Variable, MIRInstruction] = {}

            for inst in block.instructions:
                if not isinstance(inst, StoreVar):
                    continue

                prev_store = last_stores.get(inst.var)
                if prev_store is not None and self._is_dead_store(
                    prev_store,
                    inst,
                    block,
                    use_def_chains,
                ):
                    dead_stores.append((block, prev_store))

                last_stores[inst.var] = inst

            # The surviving (last) store of each variable is dead only if no
            # later read can observe it.
            for store in last_stores.values():
                if self._is_store_dead_at_end(store, block, use_def_chains):
                    dead_stores.append((block, store))

        return dead_stores

    def _is_dead_store(
        self,
        store1: MIRInstruction,
        store2: MIRInstruction,
        block: BasicBlock,
        use_def_chains: UseDefChains,
    ) -> bool:
        """Check if store1 is dead because of store2.

        Args:
            store1: First store instruction.
            store2: Second store instruction.
            block: Containing block.
            use_def_chains: Use-def chain information (currently unused).

        Returns:
            True if store1 precedes store2 in ``block`` and the stored
            variable is not read after store1 up to and including store2.
        """
        idx1 = block.instructions.index(store1)
        idx2 = block.instructions.index(store2)

        if idx1 >= idx2:
            return False

        if isinstance(store1, StoreVar):
            var = store1.var
            # store2 itself is included: if the overwriting store reads the
            # variable (e.g. ``x = x``), it still observes store1's value.
            for i in range(idx1 + 1, idx2 + 1):
                if var in block.instructions[i].get_uses():
                    return False

        return True

    def _is_store_dead_at_end(
        self,
        store: MIRInstruction,
        block: BasicBlock,
        use_def_chains: UseDefChains,
    ) -> bool:
        """Check if the last store to a variable in a block is dead.

        The store is dead when the variable is read neither later in the
        same block nor in any block reachable from it before being stored
        to again.

        Args:
            store: Store instruction.
            block: Containing block.
            use_def_chains: Use-def chain information (currently unused).

        Returns:
            True if the store is dead.
        """
        if not isinstance(store, StoreVar):
            return False

        var = store.var

        # Reads that follow the store inside its own block keep it alive.
        store_seen = False
        for inst in block.instructions:
            if store_seen:
                if var in inst.get_uses():
                    return False
            elif inst is store:
                store_seen = True
        if not store_seen:
            # The store is not part of this block; nothing can be concluded.
            return False

        # Walk every block reachable from here until the variable is
        # overwritten on that path.
        visited = set()
        worklist = list(block.successors)

        while worklist:
            succ = worklist.pop()
            if succ in visited:
                continue
            visited.add(succ)

            # Any phi operand naming the variable counts as a read. In a
            # transitively reached block the value arrives through a
            # different predecessor label, so the label is not compared.
            for phi in succ.phi_nodes:
                for val, _pred_label in phi.incoming:
                    if val == var:
                        return False

            for inst in succ.instructions:
                if var in inst.get_uses():
                    return False
                # A fresh store hides ours from everything past this point.
                if isinstance(inst, StoreVar) and inst.var == var:
                    break
            else:
                # Not overwritten in this block: keep following the path.
                worklist.extend(succ.successors)

        return True

    def _has_side_effects(self, inst: MIRInstruction) -> bool:
        """Check if an instruction has side effects.

        Args:
            inst: Instruction to check.

        Returns:
            True if the instruction has side effects.
        """
        # Control flow (Jump, ConditionalJump, Return), I/O (Print), calls
        # (conservatively assumed effectful), memory writes (StoreVar), and
        # assertions/scopes must never be removed as dead code.
        return isinstance(
            inst,
            Jump | ConditionalJump | Return | Print | Call | StoreVar | Assert | Scope,
        )

    def finalize(self) -> None:
        """Finalize the pass."""
        pass