machine-dialect 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. machine_dialect/__main__.py +667 -0
  2. machine_dialect/agent/__init__.py +5 -0
  3. machine_dialect/agent/agent.py +360 -0
  4. machine_dialect/ast/__init__.py +95 -0
  5. machine_dialect/ast/ast_node.py +35 -0
  6. machine_dialect/ast/call_expression.py +82 -0
  7. machine_dialect/ast/dict_extraction.py +60 -0
  8. machine_dialect/ast/expressions.py +439 -0
  9. machine_dialect/ast/literals.py +309 -0
  10. machine_dialect/ast/program.py +35 -0
  11. machine_dialect/ast/statements.py +1433 -0
  12. machine_dialect/ast/tests/test_ast_string_representation.py +62 -0
  13. machine_dialect/ast/tests/test_boolean_literal.py +29 -0
  14. machine_dialect/ast/tests/test_collection_hir.py +138 -0
  15. machine_dialect/ast/tests/test_define_statement.py +142 -0
  16. machine_dialect/ast/tests/test_desugar.py +541 -0
  17. machine_dialect/ast/tests/test_foreach_desugar.py +245 -0
  18. machine_dialect/cfg/__init__.py +6 -0
  19. machine_dialect/cfg/config.py +156 -0
  20. machine_dialect/cfg/examples.py +221 -0
  21. machine_dialect/cfg/generate_with_ai.py +187 -0
  22. machine_dialect/cfg/openai_generation.py +200 -0
  23. machine_dialect/cfg/parser.py +94 -0
  24. machine_dialect/cfg/tests/__init__.py +1 -0
  25. machine_dialect/cfg/tests/test_cfg_parser.py +252 -0
  26. machine_dialect/cfg/tests/test_config.py +188 -0
  27. machine_dialect/cfg/tests/test_examples.py +391 -0
  28. machine_dialect/cfg/tests/test_generate_with_ai.py +354 -0
  29. machine_dialect/cfg/tests/test_openai_generation.py +256 -0
  30. machine_dialect/codegen/__init__.py +5 -0
  31. machine_dialect/codegen/bytecode_module.py +89 -0
  32. machine_dialect/codegen/bytecode_serializer.py +300 -0
  33. machine_dialect/codegen/opcodes.py +101 -0
  34. machine_dialect/codegen/register_codegen.py +1996 -0
  35. machine_dialect/codegen/symtab.py +208 -0
  36. machine_dialect/codegen/tests/__init__.py +1 -0
  37. machine_dialect/codegen/tests/test_array_operations_codegen.py +295 -0
  38. machine_dialect/codegen/tests/test_bytecode_serializer.py +185 -0
  39. machine_dialect/codegen/tests/test_register_codegen_ssa.py +324 -0
  40. machine_dialect/codegen/tests/test_symtab.py +418 -0
  41. machine_dialect/codegen/vm_serializer.py +621 -0
  42. machine_dialect/compiler/__init__.py +18 -0
  43. machine_dialect/compiler/compiler.py +197 -0
  44. machine_dialect/compiler/config.py +149 -0
  45. machine_dialect/compiler/context.py +149 -0
  46. machine_dialect/compiler/phases/__init__.py +19 -0
  47. machine_dialect/compiler/phases/bytecode_optimization.py +90 -0
  48. machine_dialect/compiler/phases/codegen.py +40 -0
  49. machine_dialect/compiler/phases/hir_generation.py +39 -0
  50. machine_dialect/compiler/phases/mir_generation.py +86 -0
  51. machine_dialect/compiler/phases/optimization.py +110 -0
  52. machine_dialect/compiler/phases/parsing.py +39 -0
  53. machine_dialect/compiler/pipeline.py +143 -0
  54. machine_dialect/compiler/tests/__init__.py +1 -0
  55. machine_dialect/compiler/tests/test_compiler.py +568 -0
  56. machine_dialect/compiler/vm_runner.py +173 -0
  57. machine_dialect/errors/__init__.py +32 -0
  58. machine_dialect/errors/exceptions.py +369 -0
  59. machine_dialect/errors/messages.py +82 -0
  60. machine_dialect/errors/tests/__init__.py +0 -0
  61. machine_dialect/errors/tests/test_expected_token_errors.py +188 -0
  62. machine_dialect/errors/tests/test_name_errors.py +118 -0
  63. machine_dialect/helpers/__init__.py +0 -0
  64. machine_dialect/helpers/stopwords.py +225 -0
  65. machine_dialect/helpers/validators.py +30 -0
  66. machine_dialect/lexer/__init__.py +9 -0
  67. machine_dialect/lexer/constants.py +23 -0
  68. machine_dialect/lexer/lexer.py +907 -0
  69. machine_dialect/lexer/tests/__init__.py +0 -0
  70. machine_dialect/lexer/tests/helpers.py +86 -0
  71. machine_dialect/lexer/tests/test_apostrophe_identifiers.py +122 -0
  72. machine_dialect/lexer/tests/test_backtick_identifiers.py +140 -0
  73. machine_dialect/lexer/tests/test_boolean_literals.py +108 -0
  74. machine_dialect/lexer/tests/test_case_insensitive_keywords.py +188 -0
  75. machine_dialect/lexer/tests/test_comments.py +200 -0
  76. machine_dialect/lexer/tests/test_double_asterisk_keywords.py +127 -0
  77. machine_dialect/lexer/tests/test_lexer_position.py +113 -0
  78. machine_dialect/lexer/tests/test_list_tokens.py +282 -0
  79. machine_dialect/lexer/tests/test_stopwords.py +80 -0
  80. machine_dialect/lexer/tests/test_strict_equality.py +129 -0
  81. machine_dialect/lexer/tests/test_token.py +41 -0
  82. machine_dialect/lexer/tests/test_tokenization.py +294 -0
  83. machine_dialect/lexer/tests/test_underscore_literals.py +343 -0
  84. machine_dialect/lexer/tests/test_url_literals.py +169 -0
  85. machine_dialect/lexer/tokens.py +487 -0
  86. machine_dialect/linter/__init__.py +10 -0
  87. machine_dialect/linter/__main__.py +144 -0
  88. machine_dialect/linter/linter.py +154 -0
  89. machine_dialect/linter/rules/__init__.py +8 -0
  90. machine_dialect/linter/rules/base.py +112 -0
  91. machine_dialect/linter/rules/statement_termination.py +99 -0
  92. machine_dialect/linter/tests/__init__.py +1 -0
  93. machine_dialect/linter/tests/mdrules/__init__.py +0 -0
  94. machine_dialect/linter/tests/mdrules/test_md101_statement_termination.py +181 -0
  95. machine_dialect/linter/tests/test_linter.py +81 -0
  96. machine_dialect/linter/tests/test_rules.py +110 -0
  97. machine_dialect/linter/tests/test_violations.py +71 -0
  98. machine_dialect/linter/violations.py +51 -0
  99. machine_dialect/mir/__init__.py +69 -0
  100. machine_dialect/mir/analyses/__init__.py +20 -0
  101. machine_dialect/mir/analyses/alias_analysis.py +315 -0
  102. machine_dialect/mir/analyses/dominance_analysis.py +49 -0
  103. machine_dialect/mir/analyses/escape_analysis.py +286 -0
  104. machine_dialect/mir/analyses/loop_analysis.py +272 -0
  105. machine_dialect/mir/analyses/tests/test_type_analysis.py +736 -0
  106. machine_dialect/mir/analyses/type_analysis.py +448 -0
  107. machine_dialect/mir/analyses/use_def_chains.py +232 -0
  108. machine_dialect/mir/basic_block.py +385 -0
  109. machine_dialect/mir/dataflow.py +445 -0
  110. machine_dialect/mir/debug_info.py +208 -0
  111. machine_dialect/mir/hir_to_mir.py +1738 -0
  112. machine_dialect/mir/mir_dumper.py +366 -0
  113. machine_dialect/mir/mir_function.py +167 -0
  114. machine_dialect/mir/mir_instructions.py +1877 -0
  115. machine_dialect/mir/mir_interpreter.py +556 -0
  116. machine_dialect/mir/mir_module.py +225 -0
  117. machine_dialect/mir/mir_printer.py +480 -0
  118. machine_dialect/mir/mir_transformer.py +410 -0
  119. machine_dialect/mir/mir_types.py +367 -0
  120. machine_dialect/mir/mir_validation.py +455 -0
  121. machine_dialect/mir/mir_values.py +268 -0
  122. machine_dialect/mir/optimization_config.py +233 -0
  123. machine_dialect/mir/optimization_pass.py +251 -0
  124. machine_dialect/mir/optimization_pipeline.py +355 -0
  125. machine_dialect/mir/optimizations/__init__.py +84 -0
  126. machine_dialect/mir/optimizations/algebraic_simplification.py +733 -0
  127. machine_dialect/mir/optimizations/branch_prediction.py +372 -0
  128. machine_dialect/mir/optimizations/constant_propagation.py +634 -0
  129. machine_dialect/mir/optimizations/cse.py +398 -0
  130. machine_dialect/mir/optimizations/dce.py +288 -0
  131. machine_dialect/mir/optimizations/inlining.py +551 -0
  132. machine_dialect/mir/optimizations/jump_threading.py +487 -0
  133. machine_dialect/mir/optimizations/licm.py +405 -0
  134. machine_dialect/mir/optimizations/loop_unrolling.py +366 -0
  135. machine_dialect/mir/optimizations/strength_reduction.py +422 -0
  136. machine_dialect/mir/optimizations/tail_call.py +207 -0
  137. machine_dialect/mir/optimizations/tests/test_loop_unrolling.py +483 -0
  138. machine_dialect/mir/optimizations/type_narrowing.py +397 -0
  139. machine_dialect/mir/optimizations/type_specialization.py +447 -0
  140. machine_dialect/mir/optimizations/type_specific.py +906 -0
  141. machine_dialect/mir/optimize_mir.py +89 -0
  142. machine_dialect/mir/pass_manager.py +391 -0
  143. machine_dialect/mir/profiling/__init__.py +26 -0
  144. machine_dialect/mir/profiling/profile_collector.py +318 -0
  145. machine_dialect/mir/profiling/profile_data.py +372 -0
  146. machine_dialect/mir/profiling/profile_reader.py +272 -0
  147. machine_dialect/mir/profiling/profile_writer.py +226 -0
  148. machine_dialect/mir/register_allocation.py +302 -0
  149. machine_dialect/mir/reporting/__init__.py +17 -0
  150. machine_dialect/mir/reporting/optimization_reporter.py +314 -0
  151. machine_dialect/mir/reporting/report_formatter.py +289 -0
  152. machine_dialect/mir/ssa_construction.py +342 -0
  153. machine_dialect/mir/tests/__init__.py +1 -0
  154. machine_dialect/mir/tests/test_algebraic_associativity.py +204 -0
  155. machine_dialect/mir/tests/test_algebraic_complex_patterns.py +221 -0
  156. machine_dialect/mir/tests/test_algebraic_division.py +126 -0
  157. machine_dialect/mir/tests/test_algebraic_simplification.py +863 -0
  158. machine_dialect/mir/tests/test_basic_block.py +425 -0
  159. machine_dialect/mir/tests/test_branch_prediction.py +459 -0
  160. machine_dialect/mir/tests/test_call_lowering.py +168 -0
  161. machine_dialect/mir/tests/test_collection_lowering.py +604 -0
  162. machine_dialect/mir/tests/test_cross_block_constant_propagation.py +255 -0
  163. machine_dialect/mir/tests/test_custom_passes.py +166 -0
  164. machine_dialect/mir/tests/test_debug_info.py +285 -0
  165. machine_dialect/mir/tests/test_dict_extraction_lowering.py +192 -0
  166. machine_dialect/mir/tests/test_dictionary_lowering.py +299 -0
  167. machine_dialect/mir/tests/test_double_negation.py +231 -0
  168. machine_dialect/mir/tests/test_escape_analysis.py +233 -0
  169. machine_dialect/mir/tests/test_hir_to_mir.py +465 -0
  170. machine_dialect/mir/tests/test_hir_to_mir_complete.py +389 -0
  171. machine_dialect/mir/tests/test_hir_to_mir_simple.py +130 -0
  172. machine_dialect/mir/tests/test_inlining.py +435 -0
  173. machine_dialect/mir/tests/test_licm.py +472 -0
  174. machine_dialect/mir/tests/test_mir_dumper.py +313 -0
  175. machine_dialect/mir/tests/test_mir_instructions.py +445 -0
  176. machine_dialect/mir/tests/test_mir_module.py +860 -0
  177. machine_dialect/mir/tests/test_mir_printer.py +387 -0
  178. machine_dialect/mir/tests/test_mir_types.py +123 -0
  179. machine_dialect/mir/tests/test_mir_types_enhanced.py +132 -0
  180. machine_dialect/mir/tests/test_mir_validation.py +378 -0
  181. machine_dialect/mir/tests/test_mir_values.py +168 -0
  182. machine_dialect/mir/tests/test_one_based_indexing.py +202 -0
  183. machine_dialect/mir/tests/test_optimization_helpers.py +60 -0
  184. machine_dialect/mir/tests/test_optimization_pipeline.py +554 -0
  185. machine_dialect/mir/tests/test_optimization_reporter.py +318 -0
  186. machine_dialect/mir/tests/test_pass_manager.py +294 -0
  187. machine_dialect/mir/tests/test_pass_registration.py +64 -0
  188. machine_dialect/mir/tests/test_profiling.py +356 -0
  189. machine_dialect/mir/tests/test_register_allocation.py +307 -0
  190. machine_dialect/mir/tests/test_report_formatters.py +372 -0
  191. machine_dialect/mir/tests/test_ssa_construction.py +433 -0
  192. machine_dialect/mir/tests/test_tail_call.py +236 -0
  193. machine_dialect/mir/tests/test_type_annotated_instructions.py +192 -0
  194. machine_dialect/mir/tests/test_type_narrowing.py +277 -0
  195. machine_dialect/mir/tests/test_type_specialization.py +421 -0
  196. machine_dialect/mir/tests/test_type_specific_optimization.py +545 -0
  197. machine_dialect/mir/tests/test_type_specific_optimization_advanced.py +382 -0
  198. machine_dialect/mir/type_inference.py +368 -0
  199. machine_dialect/parser/__init__.py +12 -0
  200. machine_dialect/parser/enums.py +45 -0
  201. machine_dialect/parser/parser.py +3655 -0
  202. machine_dialect/parser/protocols.py +11 -0
  203. machine_dialect/parser/symbol_table.py +169 -0
  204. machine_dialect/parser/tests/__init__.py +0 -0
  205. machine_dialect/parser/tests/helper_functions.py +193 -0
  206. machine_dialect/parser/tests/test_action_statements.py +334 -0
  207. machine_dialect/parser/tests/test_boolean_literal_expressions.py +152 -0
  208. machine_dialect/parser/tests/test_call_statements.py +154 -0
  209. machine_dialect/parser/tests/test_call_statements_errors.py +187 -0
  210. machine_dialect/parser/tests/test_collection_mutations.py +264 -0
  211. machine_dialect/parser/tests/test_conditional_expressions.py +343 -0
  212. machine_dialect/parser/tests/test_define_integration.py +468 -0
  213. machine_dialect/parser/tests/test_define_statements.py +311 -0
  214. machine_dialect/parser/tests/test_dict_extraction.py +115 -0
  215. machine_dialect/parser/tests/test_empty_literal.py +155 -0
  216. machine_dialect/parser/tests/test_float_literal_expressions.py +163 -0
  217. machine_dialect/parser/tests/test_identifier_expressions.py +57 -0
  218. machine_dialect/parser/tests/test_if_empty_block.py +61 -0
  219. machine_dialect/parser/tests/test_if_statements.py +299 -0
  220. machine_dialect/parser/tests/test_illegal_tokens.py +86 -0
  221. machine_dialect/parser/tests/test_infix_expressions.py +680 -0
  222. machine_dialect/parser/tests/test_integer_literal_expressions.py +137 -0
  223. machine_dialect/parser/tests/test_interaction_statements.py +269 -0
  224. machine_dialect/parser/tests/test_list_literals.py +277 -0
  225. machine_dialect/parser/tests/test_no_none_in_ast.py +94 -0
  226. machine_dialect/parser/tests/test_panic_mode_recovery.py +171 -0
  227. machine_dialect/parser/tests/test_parse_errors.py +114 -0
  228. machine_dialect/parser/tests/test_possessive_syntax.py +182 -0
  229. machine_dialect/parser/tests/test_prefix_expressions.py +415 -0
  230. machine_dialect/parser/tests/test_program.py +13 -0
  231. machine_dialect/parser/tests/test_return_statements.py +89 -0
  232. machine_dialect/parser/tests/test_set_statements.py +152 -0
  233. machine_dialect/parser/tests/test_strict_equality.py +258 -0
  234. machine_dialect/parser/tests/test_symbol_table.py +217 -0
  235. machine_dialect/parser/tests/test_url_literal_expressions.py +209 -0
  236. machine_dialect/parser/tests/test_utility_statements.py +423 -0
  237. machine_dialect/parser/token_buffer.py +159 -0
  238. machine_dialect/repl/__init__.py +3 -0
  239. machine_dialect/repl/repl.py +426 -0
  240. machine_dialect/repl/tests/__init__.py +0 -0
  241. machine_dialect/repl/tests/test_repl.py +606 -0
  242. machine_dialect/semantic/__init__.py +12 -0
  243. machine_dialect/semantic/analyzer.py +906 -0
  244. machine_dialect/semantic/error_messages.py +189 -0
  245. machine_dialect/semantic/tests/__init__.py +1 -0
  246. machine_dialect/semantic/tests/test_analyzer.py +364 -0
  247. machine_dialect/semantic/tests/test_error_messages.py +104 -0
  248. machine_dialect/tests/edge_cases/__init__.py +10 -0
  249. machine_dialect/tests/edge_cases/test_boundary_access.py +256 -0
  250. machine_dialect/tests/edge_cases/test_empty_collections.py +166 -0
  251. machine_dialect/tests/edge_cases/test_invalid_operations.py +243 -0
  252. machine_dialect/tests/edge_cases/test_named_list_edge_cases.py +295 -0
  253. machine_dialect/tests/edge_cases/test_nested_structures.py +313 -0
  254. machine_dialect/tests/edge_cases/test_type_mixing.py +277 -0
  255. machine_dialect/tests/integration/test_array_operations_emulation.py +248 -0
  256. machine_dialect/tests/integration/test_list_compilation.py +395 -0
  257. machine_dialect/tests/integration/test_lists_and_dictionaries.py +322 -0
  258. machine_dialect/type_checking/__init__.py +21 -0
  259. machine_dialect/type_checking/tests/__init__.py +1 -0
  260. machine_dialect/type_checking/tests/test_type_system.py +230 -0
  261. machine_dialect/type_checking/type_system.py +270 -0
  262. machine_dialect-0.1.0a1.dist-info/METADATA +128 -0
  263. machine_dialect-0.1.0a1.dist-info/RECORD +268 -0
  264. machine_dialect-0.1.0a1.dist-info/WHEEL +5 -0
  265. machine_dialect-0.1.0a1.dist-info/entry_points.txt +3 -0
  266. machine_dialect-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  267. machine_dialect-0.1.0a1.dist-info/top_level.txt +2 -0
  268. machine_dialect_vm/__init__.pyi +15 -0
@@ -0,0 +1,421 @@
1
+ """Comprehensive tests for type specialization optimization pass."""
2
+
3
+ from unittest.mock import MagicMock
4
+
5
+ import pytest
6
+
7
+ from machine_dialect.mir.basic_block import BasicBlock
8
+ from machine_dialect.mir.mir_function import MIRFunction
9
+ from machine_dialect.mir.mir_instructions import (
10
+ BinaryOp,
11
+ Call,
12
+ Return,
13
+ )
14
+ from machine_dialect.mir.mir_module import MIRModule
15
+ from machine_dialect.mir.mir_types import MIRType
16
+ from machine_dialect.mir.mir_values import Constant, Temp, Variable
17
+ from machine_dialect.mir.optimization_pass import PassType, PreservationLevel
18
+ from machine_dialect.mir.optimizations.type_specialization import (
19
+ SpecializationCandidate,
20
+ TypeSignature,
21
+ TypeSpecialization,
22
+ )
23
+ from machine_dialect.mir.profiling.profile_data import ProfileData
24
+
25
+
26
class TestTypeSignature:
    """Tests for the TypeSignature value type."""

    def test_signature_creation(self) -> None:
        """A signature exposes the parameter and return types it was built with."""
        sig = TypeSignature(
            param_types=(MIRType.INT, MIRType.FLOAT),
            return_type=MIRType.INT,
        )
        assert sig.param_types == (MIRType.INT, MIRType.FLOAT)
        assert sig.return_type == MIRType.INT

    def test_signature_hash(self) -> None:
        """Signatures are hashable and usable as set/dict keys by value."""
        sig1 = TypeSignature(
            param_types=(MIRType.INT, MIRType.INT),
            return_type=MIRType.INT,
        )
        sig2 = TypeSignature(
            param_types=(MIRType.INT, MIRType.INT),
            return_type=MIRType.INT,
        )
        sig3 = TypeSignature(
            param_types=(MIRType.FLOAT, MIRType.INT),
            return_type=MIRType.INT,
        )

        # Equal signatures must compare equal and therefore hash alike.
        assert sig1 == sig2
        assert hash(sig1) == hash(sig2)

        # Different signatures must compare unequal. Their hashes are NOT
        # required to differ (collisions are legal), so check the behavior
        # that actually matters: hashed containers keep them apart.
        assert sig1 != sig3
        assert len({sig1, sig2, sig3}) == 2

    def test_signature_string_representation(self) -> None:
        """``str()`` renders the signature as ``(params) -> return``."""
        sig = TypeSignature(
            param_types=(MIRType.INT, MIRType.BOOL),
            return_type=MIRType.FLOAT,
        )
        assert str(sig) == "(int, bool) -> float"
65
+
66
+
67
class TestSpecializationCandidate:
    """Tests for the SpecializationCandidate record."""

    def test_candidate_creation(self) -> None:
        """Every field passed to the constructor is readable back unchanged."""
        int_pair = TypeSignature(
            param_types=(MIRType.INT, MIRType.INT),
            return_type=MIRType.INT,
        )
        cand = SpecializationCandidate(
            function_name="add",
            signature=int_pair,
            call_count=500,
            benefit=0.85,
        )

        # Compare identity-like fields and numeric fields in two groups.
        assert (cand.function_name, cand.signature) == ("add", int_pair)
        assert (cand.call_count, cand.benefit) == (500, 0.85)

    def test_specialized_name_generation(self) -> None:
        """The specialized name appends the parameter types to the base name."""
        mixed = TypeSignature(
            param_types=(MIRType.INT, MIRType.FLOAT),
            return_type=MIRType.FLOAT,
        )
        cand = SpecializationCandidate(
            function_name="multiply",
            signature=mixed,
            call_count=200,
            benefit=0.5,
        )
        generated = cand.specialized_name()
        assert generated == "multiply__int_float"

    def test_specialized_name_no_params(self) -> None:
        """With no parameters the name is the base name plus a bare ``__`` suffix."""
        nullary = TypeSignature(
            param_types=(),
            return_type=MIRType.INT,
        )
        cand = SpecializationCandidate(
            function_name="get_value",
            signature=nullary,
            call_count=100,
            benefit=0.3,
        )
        generated = cand.specialized_name()
        assert generated == "get_value__"
114
+
115
+
116
class TestTypeSpecialization:
    """Tests for the TypeSpecialization optimization pass.

    Most tests need a parameterless "caller" function whose single entry block
    holds calls into the functions under test. The private static helpers
    below build that scaffolding so each test only spells out what is specific
    to it.
    """

    @staticmethod
    def _start_caller(name: str) -> "tuple[MIRFunction, BasicBlock]":
        """Create an empty parameterless function and a detached ``entry`` block.

        The block is not attached yet; populate it and then hand both objects
        to ``_finish_caller``.
        """
        return MIRFunction(name, [], MIRType.EMPTY), BasicBlock("entry")

    @staticmethod
    def _finish_caller(module: MIRModule, caller: MIRFunction, block: BasicBlock) -> None:
        """Make ``block`` the entry block of ``caller`` and add ``caller`` to ``module``."""
        caller.cfg.add_block(block)
        caller.cfg.entry_block = block
        module.add_function(caller)

    @staticmethod
    def _int_add_call() -> Call:
        """Build a fresh ``add(1, 2)`` call with INT operands and an INT result temp."""
        return Call(Temp(MIRType.INT), "add", [Constant(1, MIRType.INT), Constant(2, MIRType.INT)], (1, 1))

    @pytest.fixture
    def module(self) -> MIRModule:
        """Provide a MIRModule containing a single untyped ``add(a, b)`` function."""
        mir_module = MIRModule("test")

        # The function to specialize: all types start out as UNKNOWN.
        func = MIRFunction(
            "add",
            [Variable("a", MIRType.UNKNOWN), Variable("b", MIRType.UNKNOWN)],
            MIRType.UNKNOWN,
        )
        block = BasicBlock("entry")

        # Body: result = a + b; return result
        a = Variable("a", MIRType.UNKNOWN)
        b = Variable("b", MIRType.UNKNOWN)
        result = Temp(MIRType.UNKNOWN)
        block.add_instruction(BinaryOp(result, "+", a, b, (1, 1)))
        block.add_instruction(Return((1, 1), result))

        func.cfg.add_block(block)
        func.cfg.entry_block = block
        mir_module.add_function(func)

        return mir_module

    def test_pass_initialization(self) -> None:
        """A new pass stores its threshold, has no profile data and zeroed stats."""
        opt = TypeSpecialization(threshold=50)
        assert opt.profile_data is None
        assert opt.threshold == 50
        assert opt.stats["functions_analyzed"] == 0
        assert opt.stats["functions_specialized"] == 0

    def test_pass_info(self) -> None:
        """The pass reports its name, type and preservation level."""
        opt = TypeSpecialization()
        info = opt.get_info()
        assert info.name == "type-specialization"
        assert info.pass_type == PassType.OPTIMIZATION
        assert info.preserves == PreservationLevel.NONE

    def test_collect_type_signatures(self, module: MIRModule) -> None:
        """Call sites are grouped by callee and counted per distinct type signature."""
        opt = TypeSpecialization()

        caller, block = self._start_caller("caller")

        # add(1, 2) - both int
        block.add_instruction(self._int_add_call())
        # add(1.0, 2.0) - both float
        block.add_instruction(
            Call(Temp(MIRType.FLOAT), "add", [Constant(1.0, MIRType.FLOAT), Constant(2.0, MIRType.FLOAT)], (1, 1))
        )
        # add(1, 2) again - int
        block.add_instruction(self._int_add_call())

        self._finish_caller(module, caller, block)

        opt._collect_type_signatures(module)

        assert "add" in opt.type_signatures
        signatures = opt.type_signatures["add"]

        # Two distinct signatures were seen: (int, int) and (float, float).
        assert len(signatures) == 2

        # The int signature was used by two call sites.
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        assert int_sig in signatures
        assert signatures[int_sig] == 2

        # The float signature was used by one call site.
        float_sig = TypeSignature((MIRType.FLOAT, MIRType.FLOAT), MIRType.FLOAT)
        assert float_sig in signatures
        assert signatures[float_sig] == 1

    def test_identify_candidates(self, module: MIRModule) -> None:
        """Only signatures whose call count reaches the threshold become candidates."""
        opt = TypeSpecialization(threshold=2)

        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        float_sig = TypeSignature((MIRType.FLOAT, MIRType.FLOAT), MIRType.FLOAT)

        opt.type_signatures["add"][int_sig] = 10  # Above threshold
        opt.type_signatures["add"][float_sig] = 1  # Below threshold

        candidates = opt._identify_candidates(module)

        # Only the int signature qualifies.
        assert len(candidates) == 1
        candidate = candidates[0]
        assert candidate.function_name == "add"
        assert candidate.signature == int_sig
        assert candidate.call_count == 10

    def test_calculate_benefit(self, module: MIRModule) -> None:
        """A concrete signature scores positive and no lower than an all-UNKNOWN one."""
        opt = TypeSpecialization()
        func = module.functions["add"]

        # Concrete types.
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        benefit = opt._calculate_benefit(int_sig, 100, func)
        assert benefit > 0

        # UNKNOWN types must not score higher than concrete ones.
        any_sig = TypeSignature((MIRType.UNKNOWN, MIRType.UNKNOWN), MIRType.UNKNOWN)
        benefit_any = opt._calculate_benefit(any_sig, 100, func)
        assert benefit_any <= benefit

    def test_create_specialized_function(self, module: MIRModule) -> None:
        """Creating a specialization adds a typed copy of the function to the module."""
        opt = TypeSpecialization()

        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        candidate = SpecializationCandidate(
            function_name="add",
            signature=int_sig,
            call_count=100,
            benefit=0.8,
        )

        # _create_specialization reports success as a truthy value.
        created = opt._create_specialization(module, candidate)
        assert created

        # The specialized function is registered under its generated name.
        specialized_name = candidate.specialized_name()
        assert specialized_name in module.functions

        specialized = module.functions[specialized_name]
        assert specialized.name == "add__int_int"
        assert len(specialized.params) == 2
        assert specialized.params[0].type == MIRType.INT
        assert specialized.params[1].type == MIRType.INT
        # The exact return type chosen during specialization is not pinned
        # here; only require that one is set.
        assert specialized.return_type is not None

        # The single block of the original body was copied.
        assert len(specialized.cfg.blocks) == 1

    def test_update_call_sites(self, module: MIRModule) -> None:
        """Matching call sites are redirected to the specialized function."""
        opt = TypeSpecialization()

        # Register an existing specialization for add(int, int).
        int_sig = TypeSignature((MIRType.INT, MIRType.INT), MIRType.INT)
        opt.specializations["add"][int_sig] = "add__int_int"

        # A caller containing one call that matches that signature.
        caller, block = self._start_caller("caller")
        block.add_instruction(self._int_add_call())
        self._finish_caller(module, caller, block)

        opt._update_call_sites(module)

        # The first (only) instruction is still a Call, now targeting the
        # specialized function. ``func`` stringifies with an ``@`` prefix.
        updated_call = next(iter(block.instructions))
        assert isinstance(updated_call, Call)
        assert str(updated_call.func) == "@add__int_int"

    def test_run_on_module_with_profile(self, module: MIRModule) -> None:
        """Running the pass with (mocked) profile data analyzes functions."""
        profile = MagicMock(spec=ProfileData)
        profile.get_function_metrics = MagicMock(
            return_value={
                "call_count": 1000,
                "type_signatures": {
                    ((MIRType.INT, MIRType.INT), MIRType.INT): 800,
                    ((MIRType.FLOAT, MIRType.FLOAT), MIRType.FLOAT): 200,
                },
            }
        )

        opt = TypeSpecialization(profile_data=profile, threshold=100)

        changed = opt.run_on_module(module)

        # Functions are always analyzed, even when nothing gets specialized.
        assert opt.stats["functions_analyzed"] > 0
        # A reported change must be backed by at least one created specialization.
        if changed:
            assert opt.stats["specializations_created"] > 0

    def test_run_on_module_without_profile(self, module: MIRModule) -> None:
        """Running the pass without profile data relies on static call sites."""
        opt = TypeSpecialization(threshold=1)

        # A caller with several identical int calls to produce signatures.
        caller, block = self._start_caller("main")
        for _ in range(5):
            block.add_instruction(self._int_add_call())
        self._finish_caller(module, caller, block)

        changed = opt.run_on_module(module)

        assert opt.stats["functions_analyzed"] > 0

        # A reported change must be backed by at least one created specialization.
        if changed:
            assert opt.stats["specializations_created"] > 0

    def test_no_specialization_below_threshold(self, module: MIRModule) -> None:
        """A single call site never reaches a very high threshold."""
        opt = TypeSpecialization(threshold=1000)  # Very high threshold

        # A caller with just one call.
        caller, block = self._start_caller("main")
        block.add_instruction(self._int_add_call())
        self._finish_caller(module, caller, block)

        changed = opt.run_on_module(module)

        assert not changed
        assert opt.stats["specializations_created"] == 0

    def test_multiple_function_specialization(self, module: MIRModule) -> None:
        """Two hot functions in the same module are both specialized."""
        opt = TypeSpecialization(threshold=2)

        # A second untyped function next to the fixture's ``add``.
        mul_func = MIRFunction(
            "multiply",
            [Variable("x", MIRType.UNKNOWN), Variable("y", MIRType.UNKNOWN)],
            MIRType.UNKNOWN,
        )
        mul_block = BasicBlock("entry")
        x = Variable("x", MIRType.UNKNOWN)
        y = Variable("y", MIRType.UNKNOWN)
        result = Temp(MIRType.UNKNOWN)
        mul_block.add_instruction(BinaryOp(result, "*", x, y, (1, 1)))
        mul_block.add_instruction(Return((1, 1), result))
        mul_func.cfg.add_block(mul_block)
        mul_func.cfg.entry_block = mul_block
        module.add_function(mul_func)

        # A caller that invokes both functions above the threshold.
        caller, block = self._start_caller("main")

        for _ in range(3):
            block.add_instruction(self._int_add_call())

        for _ in range(3):
            block.add_instruction(
                Call(
                    Temp(MIRType.FLOAT),
                    "multiply",
                    [Constant(1.0, MIRType.FLOAT), Constant(2.0, MIRType.FLOAT)],
                    (1, 1),
                )
            )

        self._finish_caller(module, caller, block)

        changed = opt.run_on_module(module)

        # Both functions should have been specialized.
        assert changed
        assert opt.stats["functions_specialized"] >= 2