machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,447 @@
1
+ """Type specialization optimization pass.
2
+
3
+ This module implements type specialization to generate optimized versions
4
+ of functions for specific type combinations based on profiling data.
5
+ """
6
+
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass
9
+
10
+ from machine_dialect.mir.basic_block import BasicBlock
11
+ from machine_dialect.mir.mir_function import MIRFunction
12
+ from machine_dialect.mir.mir_instructions import (
13
+ BinaryOp,
14
+ Call,
15
+ )
16
+ from machine_dialect.mir.mir_module import MIRModule
17
+ from machine_dialect.mir.mir_types import MIRType, MIRUnionType
18
+ from machine_dialect.mir.mir_values import Constant, MIRValue, Variable
19
+ from machine_dialect.mir.optimization_pass import (
20
+ ModulePass,
21
+ PassInfo,
22
+ PassType,
23
+ PreservationLevel,
24
+ )
25
+ from machine_dialect.mir.profiling.profile_data import ProfileData
26
+
27
+
28
@dataclass
class TypeSignature:
    """Parameter and return types that identify one concrete call shape.

    Attributes:
        param_types: Types of the function parameters, in positional order.
        return_type: Type of the value the function returns.
    """

    param_types: tuple[MIRType | MIRUnionType, ...]
    return_type: MIRType | MIRUnionType

    def __hash__(self) -> int:
        """Hash both fields so a signature can be used as a dictionary key."""
        key = (self.param_types, self.return_type)
        return hash(key)

    def __str__(self) -> str:
        """Render the signature as ``(param, param) -> return``."""
        rendered_params = ", ".join(map(str, self.param_types))
        return f"({rendered_params}) -> {self.return_type}"
48
+
49
+
50
@dataclass
class SpecializationCandidate:
    """A (function, signature) pair being considered for specialization.

    Attributes:
        function_name: Name of the function.
        signature: Type signature to specialize for.
        call_count: Number of calls with this signature.
        benefit: Estimated benefit of specialization.
    """

    function_name: str
    signature: TypeSignature
    call_count: int
    benefit: float

    def specialized_name(self) -> str:
        """Build the name used for the specialized clone.

        The result is the original function name, a double underscore, and
        the lower-cased parameter type names joined by single underscores.
        A union parameter contributes ``union_`` followed by the lower-cased
        names of its member types, also joined by underscores. Only the
        parameter types take part in the name.

        Returns:
            The specialized function name.
        """

        def label(param_type: MIRType | MIRUnionType) -> str:
            # Plain types use their own name; unions list every member.
            if not isinstance(param_type, MIRUnionType):
                return param_type.name.lower()
            members = "_".join(member.name.lower() for member in param_type.types)
            return "union_" + members

        suffix = "_".join(label(param_type) for param_type in self.signature.param_types)
        return f"{self.function_name}__{suffix}"
78
+
79
+
80
class TypeSpecialization(ModulePass):
    """Type specialization optimization pass.

    This pass creates specialized versions of functions for frequently-used
    type combinations, enabling better optimization and reducing type checks.

    A run has four phases: collect the argument-type signatures seen at call
    sites, select the most beneficial (function, signature) pairs, clone each
    selected function under a signature-specific name, and retarget matching
    call sites to the clone.

    Attributes:
        profile_data: Optional profiling data used to weight call sites.
        threshold: Minimum accumulated count for a signature to be considered.
        max_function_size: Functions with more instructions than this are
            never specialized.
        max_specializations: Upper bound on candidates returned per run.
        stats: Counters describing the work done by the pass.
        type_signatures: Per function name, accumulated count per signature.
        specializations: Per function name, specialized name per signature.

    Note:
        ``type_signatures`` and ``specializations`` are never cleared by the
        pass itself, so they accumulate across ``run_on_module`` calls made
        on the same instance.
    """

    def __init__(
        self,
        profile_data: ProfileData | None = None,
        threshold: int = 100,
        *,
        max_function_size: int = 100,
        max_specializations: int = 10,
    ) -> None:
        """Initialize type specialization pass.

        Args:
            profile_data: Optional profiling data for hot type combinations.
            threshold: Minimum call count to consider specialization.
            max_function_size: Largest instruction count a function may have
                and still be specialized (guards against code bloat).
            max_specializations: Maximum number of candidates selected per
                run (guards against code bloat).

        Raises:
            ValueError: If ``max_function_size`` or ``max_specializations``
                is negative.
        """
        # A negative cap would silently turn the candidate slice into
        # "all but the last N", so reject it up front.
        if max_function_size < 0:
            raise ValueError("max_function_size must be non-negative")
        if max_specializations < 0:
            raise ValueError("max_specializations must be non-negative")

        super().__init__()
        self.profile_data = profile_data
        self.threshold = threshold
        self.max_function_size = max_function_size
        self.max_specializations = max_specializations
        self.stats = {
            "functions_analyzed": 0,
            "functions_specialized": 0,
            "specializations_created": 0,
            "type_checks_eliminated": 0,
        }

        # Track type signatures seen for each function
        self.type_signatures: dict[str, dict[TypeSignature, int]] = defaultdict(lambda: defaultdict(int))

        # Map of original to specialized functions
        self.specializations: dict[str, dict[TypeSignature, str]] = defaultdict(dict)

    def get_info(self) -> PassInfo:
        """Get pass information.

        Returns:
            Pass information.
        """
        return PassInfo(
            name="type-specialization",
            description="Create type-specialized function versions",
            pass_type=PassType.OPTIMIZATION,
            requires=[],
            preserves=PreservationLevel.NONE,
        )

    def finalize(self) -> None:
        """Finalize the pass after running (nothing to do)."""
        pass

    def run_on_module(self, module: MIRModule) -> bool:
        """Run type specialization on a module.

        Args:
            module: The module to optimize.

        Returns:
            True if the module was modified.
        """
        modified = False

        # Phase 1: Collect type signatures from call sites
        self._collect_type_signatures(module)

        # Phase 2: Identify specialization candidates
        candidates = self._identify_candidates(module)

        # Phase 3: Create specialized functions
        for candidate in candidates:
            if self._create_specialization(module, candidate):
                modified = True
                self.stats["specializations_created"] += 1

        # Phase 4: Update call sites to use specialized versions
        if modified:
            self._update_call_sites(module)

        return modified

    def _collect_type_signatures(self, module: MIRModule) -> None:
        """Collect type signatures from all call sites.

        Every ``Call`` whose argument types can all be inferred, and whose
        callee exposes a ``name``, adds to the callee's counter for that
        signature. When profile data has an entry for the callee, the call
        site adds that entry's ``call_count``; otherwise it adds 1.

        Args:
            module: The module to analyze.
        """
        for function in module.functions.values():
            self.stats["functions_analyzed"] += 1

            for block in function.cfg.blocks.values():
                for inst in block.instructions:
                    if not isinstance(inst, Call):
                        continue

                    # An empty tuple (no arguments) is skipped as well: there
                    # is nothing to specialize on.
                    arg_types = self._infer_arg_types(inst.args)
                    if not arg_types or not hasattr(inst.func, "name"):
                        continue

                    func_name = inst.func.name
                    signature = TypeSignature(arg_types, self._infer_return_type(inst))

                    # Use profile data if available
                    if self.profile_data and func_name in self.profile_data.functions:
                        weight = self.profile_data.functions[func_name].call_count
                    else:
                        weight = 1
                    self.type_signatures[func_name][signature] += weight

    def _infer_arg_types(self, args: list[MIRValue]) -> tuple[MIRType | MIRUnionType, ...] | None:
        """Infer types of function arguments.

        Args:
            args: List of argument values.

        Returns:
            Tuple of types (empty when ``args`` is empty), or None if any
            argument is neither a ``Constant`` nor a ``Variable``, or is a
            ``Variable`` whose type is ``MIRType.UNKNOWN``.
        """
        types = []
        for arg in args:
            if isinstance(arg, Constant):
                types.append(arg.type)
            elif isinstance(arg, Variable):
                if arg.type == MIRType.UNKNOWN:
                    return None  # Can't infer all types
                types.append(arg.type)
            else:
                return None  # Unknown value type

        return tuple(types)

    def _infer_return_type(self, call: Call) -> MIRType | MIRUnionType:
        """Infer return type of a call.

        Args:
            call: The call instruction.

        Returns:
            The type of the call's destination when it has one, otherwise
            ``MIRType.UNKNOWN``.
        """
        if call.dest and hasattr(call.dest, "type"):
            return call.dest.type
        return MIRType.UNKNOWN

    def _identify_candidates(self, module: MIRModule) -> list[SpecializationCandidate]:
        """Identify functions worth specializing.

        Args:
            module: The module to analyze.

        Returns:
            At most ``max_specializations`` candidates, highest benefit first.
        """
        candidates: list[SpecializationCandidate] = []

        for func_name, signatures in self.type_signatures.items():
            # Skip if function doesn't exist in module
            function = module.functions.get(func_name)
            if function is None:
                continue

            # Skip if function is too large (avoid code bloat)
            if self._count_instructions(function) > self.max_function_size:
                continue

            # Find hot type signatures
            for signature, count in signatures.items():
                if count < self.threshold:
                    continue

                # Benefit combines call frequency, how well the parameter
                # types lend themselves to optimization, and the number of
                # type checks that could be eliminated.
                benefit = self._calculate_benefit(signature, count, function)
                if benefit > 0:
                    candidates.append(
                        SpecializationCandidate(
                            function_name=func_name, signature=signature, call_count=count, benefit=benefit
                        )
                    )

        # Sort by benefit (highest first)
        candidates.sort(key=lambda c: c.benefit, reverse=True)

        # Limit number of specializations to avoid code bloat
        return candidates[: self.max_specializations]

    def _count_instructions(self, function: MIRFunction) -> int:
        """Count instructions in a function.

        Args:
            function: The function to count.

        Returns:
            Total instruction count across all blocks.
        """
        return sum(len(block.instructions) for block in function.cfg.blocks.values())

    def _calculate_benefit(self, signature: TypeSignature, call_count: int, function: MIRFunction) -> float:
        """Calculate specialization benefit.

        Args:
            signature: Type signature to specialize for.
            call_count: Number of calls with this signature.
            function: The function to specialize.

        Returns:
            Estimated benefit score, never negative.
        """
        # Benefit from call frequency
        benefit = call_count * 0.1

        # Benefit from numeric types (can use specialized instructions)
        for param_type in signature.param_types:
            if param_type in (MIRType.INT, MIRType.FLOAT):
                benefit += 20.0
            elif param_type == MIRType.BOOL:
                benefit += 10.0

        # Benefit from eliminating type checks
        benefit += self._count_type_checks(function) * 5.0

        # Penalty for code size
        benefit -= self._count_instructions(function) * 0.5

        return max(0.0, benefit)

    def _count_type_checks(self, function: MIRFunction) -> int:
        """Count potential type checks in a function.

        Args:
            function: The function to analyze.

        Returns:
            Number of ``BinaryOp`` instructions, used as a simple proxy for
            operations that might need type checking.
        """
        return sum(
            1
            for block in function.cfg.blocks.values()
            for inst in block.instructions
            if isinstance(inst, BinaryOp)
        )

    def _create_specialization(self, module: MIRModule, candidate: SpecializationCandidate) -> bool:
        """Create a specialized version of a function.

        Args:
            module: The module containing the function.
            candidate: The specialization candidate.

        Returns:
            True if specialization was created.
        """
        original_func = module.functions.get(candidate.function_name)
        if original_func is None:
            return False

        # Clone the function
        specialized_name = candidate.specialized_name()
        specialized_func = self._clone_function(original_func, specialized_name)

        # Apply type information to parameters; zip stops at the shorter of
        # the two sequences, so extra parameters keep their original type.
        for param, param_type in zip(specialized_func.params, candidate.signature.param_types):
            param.type = param_type

        # Optimize the specialized function
        self._optimize_specialized_function(specialized_func, candidate.signature)

        # Add to module
        module.add_function(specialized_func)

        # Track the specialization. "functions_specialized" counts distinct
        # functions, so it is bumped only for a function's first
        # specialization; run_on_module counts every specialization
        # separately. ``.get`` is used because it does not trigger the
        # defaultdict factory.
        first_for_function = not self.specializations.get(candidate.function_name)
        self.specializations[candidate.function_name][candidate.signature] = specialized_name
        if first_for_function:
            self.stats["functions_specialized"] += 1

        return True

    def _clone_function(self, original: MIRFunction, new_name: str) -> MIRFunction:
        """Clone a function with a new name.

        Parameters are copied as fresh ``Variable`` objects, and every block
        is recreated under the same name. Instruction objects are NOT copied:
        the clone's blocks reference the same instruction instances as the
        original.

        Args:
            original: The function to clone.
            new_name: Name for the cloned function.

        Returns:
            The cloned function.
        """
        # Create new function with same parameters
        cloned = MIRFunction(new_name, [Variable(p.name, p.type) for p in original.params])

        # Clone all blocks
        for block_name, block in original.cfg.blocks.items():
            new_block = BasicBlock(block_name)

            # NOTE(review): instructions are shared by reference with the
            # original function; a real deep copy is still needed here.
            for inst in block.instructions:
                new_block.add_instruction(inst)

            cloned.cfg.add_block(new_block)

        # Set entry block
        # NOTE(review): this assigns the ORIGINAL function's entry block
        # object rather than the clone's same-named block, and nothing in
        # this method copies any other block state. Confirm how consumers
        # use ``entry_block`` before switching it to the cloned block.
        if original.cfg.entry_block:
            cloned.cfg.entry_block = original.cfg.entry_block

        return cloned

    def _optimize_specialized_function(self, function: MIRFunction, signature: TypeSignature) -> None:
        """Apply type-specific optimizations to specialized function.

        Currently a scaffold: instructions are carried over unchanged, and
        ``stats["type_checks_eliminated"]`` is bumped once per ``BinaryOp``
        when every parameter type in the signature is ``MIRType.INT``.

        Args:
            function: The specialized function.
            signature: The type signature it's specialized for.
        """
        # With known types, we can:
        # 1. Eliminate type checks
        # 2. Use specialized instructions
        # 3. Constant fold more aggressively

        # The signature does not change while we walk the function, so
        # decide once whether integer-specialized ops would apply.
        all_int_params = all(t == MIRType.INT for t in signature.param_types)

        for block in function.cfg.blocks.values():
            new_instructions = []

            for inst in block.instructions:
                # Example: optimize integer operations
                if all_int_params and isinstance(inst, BinaryOp):
                    # Could replace with specialized integer instruction
                    self.stats["type_checks_eliminated"] += 1

                new_instructions.append(inst)

            block.instructions = new_instructions

    def _update_call_sites(self, module: MIRModule) -> None:
        """Update call sites to use specialized versions.

        For every ``Call`` whose callee has recorded specializations and
        whose inferred signature matches one of them, the callee's ``name``
        is overwritten in place with the specialized name.

        Args:
            module: The module to update.
        """
        for function in module.functions.values():
            for block in function.cfg.blocks.values():
                for inst in block.instructions:
                    if not isinstance(inst, Call) or not hasattr(inst.func, "name"):
                        continue

                    # Check if we have a specialization for this call
                    by_signature = self.specializations.get(inst.func.name)
                    if not by_signature:
                        continue

                    arg_types = self._infer_arg_types(inst.args)
                    if not arg_types:
                        continue

                    signature = TypeSignature(arg_types, self._infer_return_type(inst))
                    specialized_name = by_signature.get(signature)
                    if specialized_name is not None:
                        # NOTE(review): this mutates the callee object itself;
                        # if that object is shared by several call
                        # instructions, all of them are retargeted. Confirm.
                        inst.func.name = specialized_name