machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,634 @@
1
+ """Constant propagation and folding optimization pass.
2
+
3
+ This module implements constant propagation at the MIR level, replacing
4
+ variable uses with constants and folding constant expressions.
5
+ """
6
+
7
+ from typing import Any
8
+
9
+ from machine_dialect.mir.mir_function import MIRFunction
10
+ from machine_dialect.mir.mir_instructions import (
11
+ BinaryOp,
12
+ ConditionalJump,
13
+ Copy,
14
+ Jump,
15
+ LoadConst,
16
+ LoadVar,
17
+ MIRInstruction,
18
+ Phi,
19
+ StoreVar,
20
+ UnaryOp,
21
+ )
22
+ from machine_dialect.mir.mir_transformer import MIRTransformer
23
+ from machine_dialect.mir.mir_types import MIRType
24
+ from machine_dialect.mir.mir_values import Constant, MIRValue, Temp, Variable
25
+ from machine_dialect.mir.optimization_pass import (
26
+ OptimizationPass,
27
+ PassInfo,
28
+ PassType,
29
+ PreservationLevel,
30
+ )
31
+
32
+
33
class ConstantLattice:
    """Lattice for constant propagation analysis.

    Every tracked MIR value maps to one lattice element:

    - ``TOP``: no information yet (the default for values never set).
    - a constant Python value: the value always equals this constant.
    - ``BOTTOM``: the value is not a constant (conflicting values).

    Elements only move downwards (``TOP`` -> constant -> ``BOTTOM``), which is
    what lets iterative analyses built on this lattice terminate. The two
    sentinels are compared by identity, and two constants are only considered
    equal when they also have the same type, so ``1``, ``1.0`` and ``True``
    are distinct constants.
    """

    TOP = object()  # Unknown / not yet computed
    BOTTOM = object()  # Not constant

    def __init__(self) -> None:
        """Initialize an empty lattice; every value starts at TOP."""
        self.values: dict[MIRValue, Any] = {}

    @staticmethod
    def same_constant(left: Any, right: Any) -> bool:
        """Check whether two lattice elements denote the same constant.

        Plain ``==`` would treat ``1``, ``1.0`` and ``True`` as equal, which
        would let differently-typed constants merge into one; the type must
        match as well.

        Args:
            left: First lattice element.
            right: Second lattice element.

        Returns:
            True if both are the same object, or have the same type and
            compare equal.
        """
        return left is right or (type(left) is type(right) and bool(left == right))

    def get(self, value: "MIRValue") -> Any:
        """Get the lattice value.

        Args:
            value: MIR value to query.

        Returns:
            Lattice value (TOP, BOTTOM, or constant).
        """
        return self.values.get(value, self.TOP)

    def set(self, value: "MIRValue", lattice_val: Any) -> bool:
        """Lower the lattice value of ``value`` towards ``lattice_val``.

        Setting TOP is ignored, BOTTOM is final, a first constant replaces
        TOP, and a second, different constant (or BOTTOM) yields BOTTOM.

        Args:
            value: MIR value to set.
            lattice_val: New lattice value.

        Returns:
            True if the stored value changed.
        """
        old = self.get(value)

        # TOP carries no information, and nothing can move a BOTTOM entry.
        if lattice_val is self.TOP or old is self.BOTTOM:
            return False

        if old is self.TOP:
            self.values[value] = lattice_val
            return True

        if self.same_constant(old, lattice_val):
            return False

        # Conflicting constants (or an explicit BOTTOM): not a constant.
        self.values[value] = self.BOTTOM
        return True

    def is_constant(self, value: "MIRValue") -> bool:
        """Check if a value is a known constant.

        Args:
            value: MIR value to check.

        Returns:
            True if the value is a known constant.
        """
        val = self.get(value)
        return val is not self.TOP and val is not self.BOTTOM

    def get_constant(self, value: "MIRValue") -> Any:
        """Get the constant value.

        Args:
            value: MIR value to query.

        Returns:
            The constant value, or None if the value is TOP or BOTTOM (note
            that a constant ``None`` is indistinguishable here; use
            ``is_constant`` to tell them apart).
        """
        val = self.get(value)
        if val is not self.TOP and val is not self.BOTTOM:
            return val
        return None
116
+
117
+
118
class ConstantPropagation(OptimizationPass):
    """Constant propagation and folding optimization pass.

    The pass works in three steps:

    1. An optimistic, iterative dataflow analysis over one function-wide
       :class:`ConstantLattice` finds temporaries and variables that always
       hold the same constant.
    2. Every use of such a value is replaced by a :class:`Constant`, and
       binary/unary operations whose operands are literal constants are
       rewritten into ``LoadConst`` instructions.
    3. Conditional jumps on a known-constant condition become unconditional
       jumps.

    NOTE(review): variables are tracked flow-insensitively (one lattice entry
    per variable for the whole function). A variable counts as constant when
    every ``StoreVar`` the analysis sees writes the same constant; a read that
    happens before the first store (e.g. of a value set elsewhere) is not
    distinguished. Confirm this matches how HIR-to-MIR lowers variables.
    """

    # Folding limits: results larger than this are left to the runtime, so
    # that e.g. ``2 ** 10 ** 9`` or ``"a" * 10 ** 9`` cannot stall the
    # compiler or exhaust its memory.
    MAX_FOLDED_INT_BITS = 4096
    MAX_FOLDED_STR_LEN = 4096

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        # NOTE(review): _simplify_control_flow rewrites conditional jumps into
        # unconditional ones, which changes CFG edges; confirm that
        # PreservationLevel.CFG is still the right claim for this pass.
        return PassInfo(
            name="constant-propagation",
            description="Propagate constants and fold constant expressions",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.CFG,
        )

    def run_on_function(self, function: MIRFunction) -> bool:
        """Run constant propagation on a function.

        Args:
            function: The function to optimize.

        Returns:
            True if the function was modified.
        """
        # Perform constant propagation analysis
        lattice = self._analyze_constants(function)

        # Apply transformations based on analysis
        transformer = MIRTransformer(function)

        # Replace uses with constants
        for value, const_val in lattice.values.items():
            if const_val is ConstantLattice.TOP or const_val is ConstantLattice.BOTTOM:
                continue
            if isinstance(value, Variable | Temp):
                const = Constant(const_val, self._infer_type(const_val))
                count = transformer.replace_uses(value, const)
                self.stats["constants_propagated"] = self.stats.get("constants_propagated", 0) + count

        # Fold constant expressions
        self._fold_constant_expressions(function, transformer)

        # Simplify control flow with known conditions
        self._simplify_control_flow(function, transformer, lattice)

        return transformer.modified

    def _analyze_constants(self, function: MIRFunction) -> ConstantLattice:
        """Analyze function to find constant values using iterative dataflow.

        All values start at TOP and only move down the lattice, over a single
        function-wide lattice, in two phases:

        * optimistic phase: operands that are still TOP are skipped, so
          loop-carried values (phis fed by a back edge) can be resolved;
        * pessimistic phase: after the optimistic phase is stable, an operand
          that is *still* TOP was never given a value by a handled
          instruction (a parameter, a call result, ...), so it is read as
          BOTTOM and its consumers are lowered accordingly.

        If a phase fails to converge within its safety bound, an empty lattice
        is returned so that nothing is rewritten on unsound information.

        Args:
            function: Function to analyze.

        Returns:
            Constant lattice with analysis results.
        """
        lattice = ConstantLattice()
        blocks = list(function.cfg.blocks.values())

        # Parameters come from the caller and are never compile-time constants.
        # NOTE(review): assumes MIRFunction exposes its parameters as `params`;
        # the getattr makes this a no-op if it does not.
        for param in getattr(function, "params", None) or ():
            lattice.set(param, ConstantLattice.BOTTOM)

        for pessimistic in (False, True):
            if not self._iterate_to_fixed_point(blocks, lattice, pessimistic):
                return ConstantLattice()

        return lattice

    def _iterate_to_fixed_point(self, blocks: list[Any], lattice: ConstantLattice, pessimistic: bool) -> bool:
        """Re-evaluate every phi and instruction until the lattice is stable.

        Args:
            blocks: Blocks of the function, visited in this order each sweep.
            lattice: Function-wide lattice, updated in place.
            pessimistic: Read operands that are still TOP as BOTTOM.

        Returns:
            True if a fixed point was reached, False if the bound was hit.
        """
        # Each lattice entry changes at most twice (TOP -> constant -> BOTTOM),
        # entries are only created by the phis/instructions counted here, and
        # every non-final sweep changes at least one entry. So this bound is
        # never reached by monotone transfer functions; it is a safety net.
        definitions = sum(len(block.phi_nodes) + len(block.instructions) for block in blocks)
        max_sweeps = 2 * definitions + 2

        for _ in range(max_sweeps):
            changed = False
            for block in blocks:
                for phi in block.phi_nodes:
                    changed |= self._process_phi(phi, block, {}, lattice, pessimistic)
                for inst in block.instructions:
                    if isinstance(inst, Phi):
                        changed |= self._process_phi(inst, block, {}, lattice, pessimistic)
                    else:
                        changed |= self._process_instruction(inst, lattice, pessimistic)
            if not changed:
                return True

        return False

    def _merge_predecessors(
        self, block: Any, block_lattices: dict[Any, ConstantLattice], lattice: ConstantLattice
    ) -> bool:
        """Merge lattice values from predecessor blocks into ``lattice``.

        Retained for compatibility; the analysis itself works on a single
        function-wide lattice and does not call this.

        Args:
            block: Current block.
            block_lattices: Per-block lattice states.
            lattice: Global lattice.

        Returns:
            True if any values changed.
        """
        changed = False

        # For each predecessor, merge its output values
        for pred in block.predecessors:
            pred_lattice = block_lattices.get(pred)
            if pred_lattice:
                for value, const_val in pred_lattice.values.items():
                    if lattice.set(value, const_val):
                        changed = True

        return changed

    def _process_phi(
        self,
        phi: Phi,
        block: Any,
        block_lattices: dict[Any, ConstantLattice],
        lattice: ConstantLattice,
        pessimistic: bool = False,
    ) -> bool:
        """Evaluate a phi node as the lattice meet of its incoming values.

        Incoming values are read from the function-wide ``lattice``. The
        ``block`` and ``block_lattices`` arguments are accepted for
        compatibility and are not consulted.

        Args:
            phi: Phi node to process.
            block: Current block (unused).
            block_lattices: Per-block lattice states (unused).
            lattice: Constant lattice.
            pessimistic: Read incoming values that are still TOP as BOTTOM.

        Returns:
            True if the phi's value changed.
        """
        result: Any = ConstantLattice.TOP
        for value, _pred_block in phi.incoming:
            result = self._meet(result, self._lattice_value(value, lattice, pessimistic))
            if result is ConstantLattice.BOTTOM:
                break

        # ConstantLattice.set ignores TOP (no incoming value known yet).
        return lattice.set(phi.dest, result)

    def _process_instruction(self, inst: MIRInstruction, lattice: ConstantLattice, pessimistic: bool = False) -> bool:
        """Apply an instruction's transfer function to the lattice.

        Args:
            inst: Instruction to process.
            lattice: Constant lattice.
            pessimistic: Read operands that are still TOP as BOTTOM.

        Returns:
            True if any value changed.
        """
        # Every branch ends in ConstantLattice.set, which ignores TOP, so an
        # unresolved operand simply leaves the destination untouched.
        if isinstance(inst, LoadConst):
            return lattice.set(inst.dest, inst.constant.value)

        if isinstance(inst, Copy):
            return lattice.set(inst.dest, self._lattice_value(inst.source, lattice, pessimistic))

        if isinstance(inst, StoreVar):
            return lattice.set(inst.var, self._lattice_value(inst.source, lattice, pessimistic))

        if isinstance(inst, LoadVar):
            return lattice.set(inst.dest, self._lattice_value(inst.var, lattice, pessimistic))

        if isinstance(inst, BinaryOp):
            return lattice.set(inst.dest, self._evaluate_binary(inst, lattice, pessimistic))

        if isinstance(inst, UnaryOp):
            return lattice.set(inst.dest, self._evaluate_unary(inst, lattice, pessimistic))

        return False

    def _evaluate_binary(self, inst: BinaryOp, lattice: ConstantLattice, pessimistic: bool = False) -> Any:
        """Compute the lattice value produced by a binary operation.

        Args:
            inst: Binary operation instruction.
            lattice: Constant lattice.
            pessimistic: Read operands that are still TOP as BOTTOM.

        Returns:
            BOTTOM if an operand is BOTTOM or folding is not possible, TOP if
            an operand is still unknown, otherwise the folded constant.
        """
        left = self._lattice_value(inst.left, lattice, pessimistic)
        right = self._lattice_value(inst.right, lattice, pessimistic)

        if left is ConstantLattice.BOTTOM or right is ConstantLattice.BOTTOM:
            return ConstantLattice.BOTTOM
        if left is ConstantLattice.TOP or right is ConstantLattice.TOP:
            return ConstantLattice.TOP
        # Empty (None) operands are never folded, matching
        # _fold_constant_expressions.
        if left is None or right is None:
            return ConstantLattice.BOTTOM

        result = self._fold_binary_op(inst.op, left, right)
        return ConstantLattice.BOTTOM if result is None else result

    def _evaluate_unary(self, inst: UnaryOp, lattice: ConstantLattice, pessimistic: bool = False) -> Any:
        """Compute the lattice value produced by a unary operation.

        Args:
            inst: Unary operation instruction.
            lattice: Constant lattice.
            pessimistic: Read an operand that is still TOP as BOTTOM.

        Returns:
            BOTTOM if the operand is BOTTOM or folding is not possible, TOP if
            the operand is still unknown, otherwise the folded constant.
        """
        operand = self._lattice_value(inst.operand, lattice, pessimistic)

        if operand is ConstantLattice.BOTTOM:
            return ConstantLattice.BOTTOM
        if operand is ConstantLattice.TOP:
            return ConstantLattice.TOP
        if operand is None:
            return ConstantLattice.BOTTOM

        result = self._fold_unary_op(inst.op, operand)
        return ConstantLattice.BOTTOM if result is None else result

    def _process_binary_op(self, inst: BinaryOp, lattice: ConstantLattice) -> None:
        """Process a binary operation for constant folding.

        Args:
            inst: Binary operation instruction.
            lattice: Constant lattice.
        """
        lattice.set(inst.dest, self._evaluate_binary(inst, lattice))

    def _process_unary_op(self, inst: UnaryOp, lattice: ConstantLattice) -> None:
        """Process a unary operation for constant folding.

        Args:
            inst: Unary operation instruction.
            lattice: Constant lattice.
        """
        lattice.set(inst.dest, self._evaluate_unary(inst, lattice))

    @staticmethod
    def _lattice_value(value: MIRValue, lattice: ConstantLattice, pessimistic: bool = False) -> Any:
        """Get the lattice element of an operand.

        Args:
            value: MIR value to evaluate.
            lattice: Constant lattice.
            pessimistic: Report a value that is still TOP as BOTTOM.

        Returns:
            The constant for a :class:`Constant`, otherwise the lattice
            element (TOP, BOTTOM, or a constant).
        """
        if isinstance(value, Constant):
            return value.value
        val = lattice.get(value)
        if pessimistic and val is ConstantLattice.TOP:
            return ConstantLattice.BOTTOM
        return val

    @staticmethod
    def _same_constant(left: Any, right: Any) -> bool:
        """Check whether two constants are the same value of the same type.

        Args:
            left: First value.
            right: Second value.

        Returns:
            True if both are the same object, or have the same type and
            compare equal (so ``1``, ``1.0`` and ``True`` are all distinct).
        """
        return left is right or (type(left) is type(right) and bool(left == right))

    @classmethod
    def _meet(cls, left: Any, right: Any) -> Any:
        """Lattice meet: TOP is the identity, differing constants give BOTTOM.

        Args:
            left: First lattice element.
            right: Second lattice element.

        Returns:
            The greatest lower bound of the two elements.
        """
        if left is ConstantLattice.TOP:
            return right
        if right is ConstantLattice.TOP:
            return left
        if cls._same_constant(left, right):
            return left
        return ConstantLattice.BOTTOM

    def _get_value(self, value: MIRValue, lattice: ConstantLattice) -> Any:
        """Get the constant value of a MIR value.

        Args:
            value: MIR value to evaluate.
            lattice: Constant lattice.

        Returns:
            The constant value or None.
        """
        if isinstance(value, Constant):
            return value.value
        if isinstance(value, Variable | Temp):
            val = lattice.get(value)
            if val is not ConstantLattice.TOP and val is not ConstantLattice.BOTTOM:
                return val
        return None

    def _fold_binary_op(self, op: str, left: Any, right: Any) -> Any:
        """Fold a binary operation with constant operands.

        Args:
            op: Operation string.
            left: Left operand value.
            right: Right operand value.

        Returns:
            The folded result, or None when the operation cannot or should not
            be folded: unknown operator, type error, division by zero, a
            complex result, or a result beyond the folding size limits.
        """
        try:
            # Arithmetic operations
            if op == "+":
                # Handle string concatenation and numeric addition
                if isinstance(left, str) or isinstance(right, str):
                    return str(left) + str(right)
                return left + right
            elif op == "-":
                return left - right
            elif op == "*":
                # Do not materialize huge repeated strings at compile time.
                if isinstance(left, str) and isinstance(right, int):
                    if len(left) * right > self.MAX_FOLDED_STR_LEN:
                        return None
                elif isinstance(left, int) and isinstance(right, str):
                    if left * len(right) > self.MAX_FOLDED_STR_LEN:
                        return None
                return left * right
            elif op == "/":
                if right != 0:
                    # Integer division for integers
                    if isinstance(left, int) and isinstance(right, int):
                        return left // right
                    return left / right
            elif op == "//":
                if right != 0:
                    return left // right
            elif op == "%":
                if right != 0:
                    return left % right
            elif op == "**":
                # Do not compute enormous integer powers at compile time.
                if isinstance(left, int) and isinstance(right, int) and right > 0:
                    if left.bit_length() * right > self.MAX_FOLDED_INT_BITS:
                        return None
                result = left**right
                # e.g. (-8) ** 0.5 is complex, which has no MIR type.
                return None if isinstance(result, complex) else result

            # Comparison operations
            elif op == "<":
                return left < right
            elif op == "<=":
                return left <= right
            elif op == ">":
                return left > right
            elif op == ">=":
                return left >= right
            elif op == "==":
                return left == right
            elif op == "!=":
                return left != right
            elif op == "===":  # Strict equality: same type and same value
                # Python's `is` on values depends on int caching / string
                # interning, so it cannot be used here.
                return self._same_constant(left, right)
            elif op == "!==":  # Strict inequality
                return not self._same_constant(left, right)

            # Logical operations
            elif op == "and":
                return left and right
            elif op == "or":
                return left or right

            # Bitwise operations
            elif op == "&":
                if isinstance(left, int) and isinstance(right, int):
                    return left & right
            elif op == "|":
                if isinstance(left, int) and isinstance(right, int):
                    return left | right
            elif op == "^":
                if isinstance(left, int) and isinstance(right, int):
                    return left ^ right
            elif op == "<<":
                if isinstance(left, int) and isinstance(right, int):
                    # Do not build enormous integers at compile time.
                    if right > 0 and left.bit_length() + right > self.MAX_FOLDED_INT_BITS:
                        return None
                    return left << right
            elif op == ">>":
                if isinstance(left, int) and isinstance(right, int):
                    return left >> right

            # String operations
            elif op == "in":
                return left in right
            elif op == "not in":
                return left not in right

        except (TypeError, ValueError, ZeroDivisionError, OverflowError):
            pass
        return None

    def _fold_unary_op(self, op: str, operand: Any) -> Any:
        """Fold a unary operation with constant operand.

        Args:
            op: Operation string.
            operand: Operand value.

        Returns:
            The folded result or None.
        """
        try:
            if op == "-":
                return -operand
            elif op == "not":
                return not operand
            elif op == "+":
                return +operand
            elif op == "~":  # Bitwise NOT
                if isinstance(operand, int):
                    return ~operand
            elif op == "abs":
                return abs(operand)
            elif op == "len":
                if hasattr(operand, "__len__"):
                    return len(operand)
        except (TypeError, ValueError):
            pass
        return None

    def _fold_constant_expressions(
        self,
        function: MIRFunction,
        transformer: MIRTransformer,
    ) -> None:
        """Replace operations on literal constant operands with ``LoadConst``.

        Args:
            function: Function to optimize.
            transformer: MIR transformer.
        """
        for block in function.cfg.blocks.values():
            # Iterate over a snapshot: instructions are replaced in place.
            for inst in list(block.instructions):
                result = None
                if isinstance(inst, BinaryOp):
                    left_val = self._get_constant_value(inst.left)
                    right_val = self._get_constant_value(inst.right)
                    if left_val is not None and right_val is not None:
                        result = self._fold_binary_op(inst.op, left_val, right_val)
                elif isinstance(inst, UnaryOp):
                    operand_val = self._get_constant_value(inst.operand)
                    if operand_val is not None:
                        result = self._fold_unary_op(inst.op, operand_val)

                if result is not None:
                    const = Constant(result, self._infer_type(result))
                    new_inst = LoadConst(inst.dest, const, inst.source_location)
                    transformer.replace_instruction(block, inst, new_inst)
                    self.stats["expressions_folded"] = self.stats.get("expressions_folded", 0) + 1

    def _simplify_control_flow(
        self,
        function: MIRFunction,
        transformer: MIRTransformer,
        lattice: ConstantLattice,
    ) -> None:
        """Simplify control flow with known conditions.

        Args:
            function: Function to optimize.
            transformer: MIR transformer.
            lattice: Constant lattice.
        """
        for block in list(function.cfg.blocks.values()):
            term = block.get_terminator()
            if isinstance(term, ConditionalJump):
                # Check if condition is constant
                cond_val = self._get_value(term.condition, lattice)
                if cond_val is not None:
                    # Replace with unconditional jump
                    if cond_val:
                        new_jump = Jump(term.true_label, term.source_location)
                    elif term.false_label is not None:
                        new_jump = Jump(term.false_label, term.source_location)
                    else:
                        continue  # Can't simplify without false label

                    transformer.replace_instruction(block, term, new_jump)
                    self.stats["branches_simplified"] = self.stats.get("branches_simplified", 0) + 1

    def _get_constant_value(self, value: MIRValue) -> Any:
        """Get the constant value if it's a constant.

        Args:
            value: MIR value.

        Returns:
            Constant value or None.
        """
        if isinstance(value, Constant):
            return value.value
        return None

    def _infer_type(self, value: Any) -> MIRType:
        """Infer MIR type from a Python value.

        Args:
            value: Python value.

        Returns:
            Inferred MIR type.
        """
        # bool is checked before int because bool is a subclass of int.
        if isinstance(value, bool):
            return MIRType.BOOL
        elif isinstance(value, int):
            return MIRType.INT
        elif isinstance(value, float):
            return MIRType.FLOAT
        elif isinstance(value, str):
            return MIRType.STRING
        elif value is None:
            return MIRType.EMPTY
        else:
            return MIRType.UNKNOWN

    def finalize(self) -> None:
        """Finalize the pass."""
        pass