kumi 0.0.7 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CLAUDE.md +1 -1
  3. data/README.md +8 -5
  4. data/examples/game_of_life.rb +1 -1
  5. data/examples/static_analysis_errors.rb +7 -7
  6. data/lib/kumi/analyzer.rb +15 -15
  7. data/lib/kumi/compiler.rb +6 -6
  8. data/lib/kumi/core/analyzer/analysis_state.rb +39 -0
  9. data/lib/kumi/core/analyzer/constant_evaluator.rb +59 -0
  10. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +248 -0
  11. data/lib/kumi/core/analyzer/passes/declaration_validator.rb +45 -0
  12. data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +153 -0
  13. data/lib/kumi/core/analyzer/passes/input_collector.rb +139 -0
  14. data/lib/kumi/core/analyzer/passes/name_indexer.rb +26 -0
  15. data/lib/kumi/core/analyzer/passes/pass_base.rb +52 -0
  16. data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +111 -0
  17. data/lib/kumi/core/analyzer/passes/toposorter.rb +110 -0
  18. data/lib/kumi/core/analyzer/passes/type_checker.rb +162 -0
  19. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +48 -0
  20. data/lib/kumi/core/analyzer/passes/type_inferencer.rb +236 -0
  21. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +406 -0
  22. data/lib/kumi/core/analyzer/passes/visitor_pass.rb +44 -0
  23. data/lib/kumi/core/atom_unsat_solver.rb +396 -0
  24. data/lib/kumi/core/compiled_schema.rb +43 -0
  25. data/lib/kumi/core/constraint_relationship_solver.rb +641 -0
  26. data/lib/kumi/core/domain/enum_analyzer.rb +55 -0
  27. data/lib/kumi/core/domain/range_analyzer.rb +85 -0
  28. data/lib/kumi/core/domain/validator.rb +82 -0
  29. data/lib/kumi/core/domain/violation_formatter.rb +42 -0
  30. data/lib/kumi/core/error_reporter.rb +166 -0
  31. data/lib/kumi/core/error_reporting.rb +97 -0
  32. data/lib/kumi/core/errors.rb +120 -0
  33. data/lib/kumi/core/evaluation_wrapper.rb +40 -0
  34. data/lib/kumi/core/explain.rb +295 -0
  35. data/lib/kumi/core/export/deserializer.rb +41 -0
  36. data/lib/kumi/core/export/errors.rb +14 -0
  37. data/lib/kumi/core/export/node_builders.rb +142 -0
  38. data/lib/kumi/core/export/node_registry.rb +54 -0
  39. data/lib/kumi/core/export/node_serializers.rb +158 -0
  40. data/lib/kumi/core/export/serializer.rb +25 -0
  41. data/lib/kumi/core/export.rb +35 -0
  42. data/lib/kumi/core/function_registry/collection_functions.rb +202 -0
  43. data/lib/kumi/core/function_registry/comparison_functions.rb +33 -0
  44. data/lib/kumi/core/function_registry/conditional_functions.rb +38 -0
  45. data/lib/kumi/core/function_registry/function_builder.rb +95 -0
  46. data/lib/kumi/core/function_registry/logical_functions.rb +44 -0
  47. data/lib/kumi/core/function_registry/math_functions.rb +74 -0
  48. data/lib/kumi/core/function_registry/string_functions.rb +57 -0
  49. data/lib/kumi/core/function_registry/type_functions.rb +53 -0
  50. data/lib/kumi/{function_registry.rb → core/function_registry.rb} +28 -36
  51. data/lib/kumi/core/input/type_matcher.rb +97 -0
  52. data/lib/kumi/core/input/validator.rb +51 -0
  53. data/lib/kumi/core/input/violation_creator.rb +52 -0
  54. data/lib/kumi/core/json_schema/generator.rb +65 -0
  55. data/lib/kumi/core/json_schema/validator.rb +27 -0
  56. data/lib/kumi/core/json_schema.rb +16 -0
  57. data/lib/kumi/core/ruby_parser/build_context.rb +27 -0
  58. data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +38 -0
  59. data/lib/kumi/core/ruby_parser/dsl.rb +14 -0
  60. data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +138 -0
  61. data/lib/kumi/core/ruby_parser/expression_converter.rb +128 -0
  62. data/lib/kumi/core/ruby_parser/guard_rails.rb +45 -0
  63. data/lib/kumi/core/ruby_parser/input_builder.rb +127 -0
  64. data/lib/kumi/core/ruby_parser/input_field_proxy.rb +48 -0
  65. data/lib/kumi/core/ruby_parser/input_proxy.rb +31 -0
  66. data/lib/kumi/core/ruby_parser/nested_input.rb +17 -0
  67. data/lib/kumi/core/ruby_parser/parser.rb +71 -0
  68. data/lib/kumi/core/ruby_parser/schema_builder.rb +175 -0
  69. data/lib/kumi/core/ruby_parser/sugar.rb +263 -0
  70. data/lib/kumi/core/ruby_parser.rb +12 -0
  71. data/lib/kumi/core/schema_instance.rb +111 -0
  72. data/lib/kumi/core/types/builder.rb +23 -0
  73. data/lib/kumi/core/types/compatibility.rb +96 -0
  74. data/lib/kumi/core/types/formatter.rb +26 -0
  75. data/lib/kumi/core/types/inference.rb +42 -0
  76. data/lib/kumi/core/types/normalizer.rb +72 -0
  77. data/lib/kumi/core/types/validator.rb +37 -0
  78. data/lib/kumi/core/types.rb +66 -0
  79. data/lib/kumi/core/vectorization_metadata.rb +110 -0
  80. data/lib/kumi/errors.rb +1 -112
  81. data/lib/kumi/registry.rb +37 -0
  82. data/lib/kumi/schema.rb +5 -5
  83. data/lib/kumi/schema_metadata.rb +3 -3
  84. data/lib/kumi/syntax/array_expression.rb +6 -6
  85. data/lib/kumi/syntax/call_expression.rb +4 -4
  86. data/lib/kumi/syntax/cascade_expression.rb +4 -4
  87. data/lib/kumi/syntax/case_expression.rb +4 -4
  88. data/lib/kumi/syntax/declaration_reference.rb +4 -4
  89. data/lib/kumi/syntax/hash_expression.rb +4 -4
  90. data/lib/kumi/syntax/input_declaration.rb +5 -5
  91. data/lib/kumi/syntax/input_element_reference.rb +5 -5
  92. data/lib/kumi/syntax/input_reference.rb +5 -5
  93. data/lib/kumi/syntax/literal.rb +4 -4
  94. data/lib/kumi/syntax/node.rb +34 -34
  95. data/lib/kumi/syntax/root.rb +6 -6
  96. data/lib/kumi/syntax/trait_declaration.rb +4 -4
  97. data/lib/kumi/syntax/value_declaration.rb +4 -4
  98. data/lib/kumi/version.rb +1 -1
  99. data/migrate_to_core_iterative.rb +938 -0
  100. data/scripts/generate_function_docs.rb +9 -9
  101. metadata +75 -72
  102. data/lib/kumi/analyzer/analysis_state.rb +0 -37
  103. data/lib/kumi/analyzer/constant_evaluator.rb +0 -57
  104. data/lib/kumi/analyzer/passes/broadcast_detector.rb +0 -246
  105. data/lib/kumi/analyzer/passes/declaration_validator.rb +0 -43
  106. data/lib/kumi/analyzer/passes/dependency_resolver.rb +0 -151
  107. data/lib/kumi/analyzer/passes/input_collector.rb +0 -137
  108. data/lib/kumi/analyzer/passes/name_indexer.rb +0 -24
  109. data/lib/kumi/analyzer/passes/pass_base.rb +0 -50
  110. data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +0 -109
  111. data/lib/kumi/analyzer/passes/toposorter.rb +0 -108
  112. data/lib/kumi/analyzer/passes/type_checker.rb +0 -160
  113. data/lib/kumi/analyzer/passes/type_consistency_checker.rb +0 -46
  114. data/lib/kumi/analyzer/passes/type_inferencer.rb +0 -232
  115. data/lib/kumi/analyzer/passes/unsat_detector.rb +0 -404
  116. data/lib/kumi/analyzer/passes/visitor_pass.rb +0 -42
  117. data/lib/kumi/atom_unsat_solver.rb +0 -394
  118. data/lib/kumi/compiled_schema.rb +0 -41
  119. data/lib/kumi/constraint_relationship_solver.rb +0 -638
  120. data/lib/kumi/domain/enum_analyzer.rb +0 -53
  121. data/lib/kumi/domain/range_analyzer.rb +0 -83
  122. data/lib/kumi/domain/validator.rb +0 -80
  123. data/lib/kumi/domain/violation_formatter.rb +0 -40
  124. data/lib/kumi/error_reporter.rb +0 -164
  125. data/lib/kumi/error_reporting.rb +0 -95
  126. data/lib/kumi/evaluation_wrapper.rb +0 -38
  127. data/lib/kumi/explain.rb +0 -293
  128. data/lib/kumi/export/deserializer.rb +0 -39
  129. data/lib/kumi/export/errors.rb +0 -12
  130. data/lib/kumi/export/node_builders.rb +0 -140
  131. data/lib/kumi/export/node_registry.rb +0 -52
  132. data/lib/kumi/export/node_serializers.rb +0 -156
  133. data/lib/kumi/export/serializer.rb +0 -23
  134. data/lib/kumi/export.rb +0 -33
  135. data/lib/kumi/function_registry/collection_functions.rb +0 -200
  136. data/lib/kumi/function_registry/comparison_functions.rb +0 -31
  137. data/lib/kumi/function_registry/conditional_functions.rb +0 -36
  138. data/lib/kumi/function_registry/function_builder.rb +0 -93
  139. data/lib/kumi/function_registry/logical_functions.rb +0 -42
  140. data/lib/kumi/function_registry/math_functions.rb +0 -72
  141. data/lib/kumi/function_registry/string_functions.rb +0 -54
  142. data/lib/kumi/function_registry/type_functions.rb +0 -51
  143. data/lib/kumi/input/type_matcher.rb +0 -95
  144. data/lib/kumi/input/validator.rb +0 -49
  145. data/lib/kumi/input/violation_creator.rb +0 -50
  146. data/lib/kumi/json_schema/generator.rb +0 -63
  147. data/lib/kumi/json_schema/validator.rb +0 -25
  148. data/lib/kumi/json_schema.rb +0 -14
  149. data/lib/kumi/ruby_parser/build_context.rb +0 -25
  150. data/lib/kumi/ruby_parser/declaration_reference_proxy.rb +0 -36
  151. data/lib/kumi/ruby_parser/dsl.rb +0 -12
  152. data/lib/kumi/ruby_parser/dsl_cascade_builder.rb +0 -136
  153. data/lib/kumi/ruby_parser/expression_converter.rb +0 -126
  154. data/lib/kumi/ruby_parser/guard_rails.rb +0 -43
  155. data/lib/kumi/ruby_parser/input_builder.rb +0 -125
  156. data/lib/kumi/ruby_parser/input_field_proxy.rb +0 -46
  157. data/lib/kumi/ruby_parser/input_proxy.rb +0 -29
  158. data/lib/kumi/ruby_parser/nested_input.rb +0 -15
  159. data/lib/kumi/ruby_parser/parser.rb +0 -69
  160. data/lib/kumi/ruby_parser/schema_builder.rb +0 -173
  161. data/lib/kumi/ruby_parser/sugar.rb +0 -261
  162. data/lib/kumi/ruby_parser.rb +0 -10
  163. data/lib/kumi/schema_instance.rb +0 -109
  164. data/lib/kumi/types/builder.rb +0 -21
  165. data/lib/kumi/types/compatibility.rb +0 -94
  166. data/lib/kumi/types/formatter.rb +0 -24
  167. data/lib/kumi/types/inference.rb +0 -40
  168. data/lib/kumi/types/normalizer.rb +0 -70
  169. data/lib/kumi/types/validator.rb +0 -35
  170. data/lib/kumi/types.rb +0 -64
  171. data/lib/kumi/vectorization_metadata.rb +0 -108
@@ -1,160 +0,0 @@
# frozen_string_literal: true

module Kumi
  module Analyzer
    module Passes
      # Validates every function call in the schema against the FunctionRegistry:
      # the function must exist, be called with the declared number of arguments,
      # and receive arguments whose types are compatible with the declared parameters.
      #
      # DEPENDENCIES: :inferred_types from TypeInferencer (required);
      #               :inputs and :broadcasts are read when present
      # PRODUCES: nothing (validation only; the incoming state is returned unchanged)
      # INTERFACE: new(schema, state).run(errors)
      class TypeChecker < VisitorPass
        # Visits each CallExpression node and records type errors into +errors+.
        # @return the analyzer state, unchanged
        def run(errors)
          visit_nodes_of_type(Kumi::Syntax::CallExpression, errors: errors) do |call, _decl, errs|
            validate_function_call(call, errs)
          end
          state
        end

        private

        # Runs the arity and argument-type checks for one call. An unknown function
        # is reported by #get_function_signature and skips the remaining checks.
        def validate_function_call(node, errors)
          signature = get_function_signature(node, errors)
          return unless signature

          validate_arity(node, signature, errors)
          validate_argument_types(node, signature, errors)
        end

        # @return [Hash, nil] the registry signature, or nil after reporting an unknown function
        def get_function_signature(node, errors)
          FunctionRegistry.signature(node.fn_name)
        rescue FunctionRegistry::UnknownFunction
          # Message wording kept from the older format; node.loc pinpoints the call site.
          report_error(errors, "unsupported operator `#{node.fn_name}`", location: node.loc, type: :type)
          nil
        end

        # A negative arity marks a variadic function, which accepts any argument count.
        def validate_arity(node, signature, errors)
          expected = signature[:arity]
          given = node.args.size
          return if expected.negative?
          return if expected == given

          message = "operator `#{node.fn_name}` expects #{expected} args, got #{given}"
          report_error(errors, message, location: node.loc, type: :type)
        end

        # Compares each argument with the parameter type at the same position.
        # Skipped when the signature declares no parameter types, for a variadic call
        # with no arguments, and for calls that touch vectorized/reduced data.
        def validate_argument_types(node, signature, errors)
          param_types = signature[:param_types]
          return if param_types.nil?
          return if signature[:arity].negative? && node.args.empty?

          broadcasts = get_state(:broadcasts, required: false)
          return if broadcasts && vectorized_argument?(node, broadcasts)

          node.args.each.with_index do |arg, position|
            validate_argument_type(arg, position, param_types[position], node.fn_name, errors)
          end
        end

        # Simplified heuristic (no deeper context tracking): true when any argument
        # references a vectorized or reduction declaration, or an element of an array field.
        def vectorized_argument?(node, broadcast_meta)
          node.args.any? do |arg|
            if arg.is_a?(Kumi::Syntax::DeclarationReference)
              broadcast_meta[:vectorized_operations]&.key?(arg.name) ||
                broadcast_meta[:reduction_operations]&.key?(arg.name)
            elsif arg.is_a?(Kumi::Syntax::InputElementReference)
              broadcast_meta[:array_fields]&.key?(arg.path.first)
            else
              false
            end
          end
        end

        # Reports a type error when +arg+'s type is incompatible with +expected_type+.
        # A nil or ANY expected type accepts everything.
        def validate_argument_type(arg, index, expected_type, fn_name, errors)
          return if expected_type.nil?
          return if expected_type == Kumi::Types::ANY

          actual_type = get_expression_type(arg)
          return if Kumi::Types.compatible?(actual_type, expected_type)

          description = describe_expression_type(arg, actual_type)
          message = "argument #{index + 1} of `fn(:#{fn_name})` expects #{expected_type}, got #{description}"
          report_error(errors, message, location: arg.loc, type: :type)
        end

        # Best-effort type of an argument expression. Anything other than a literal,
        # an input reference or a declaration reference is treated as ANY.
        def get_expression_type(expr)
          if expr.is_a?(Kumi::Syntax::Literal)
            Kumi::Types.infer_from_value(expr.value)
          elsif expr.is_a?(Kumi::Syntax::InputReference)
            get_declared_field_type(expr.name)
          elsif expr.is_a?(Kumi::Syntax::DeclarationReference)
            get_inferred_declaration_type(expr.name)
          else
            Kumi::Types::ANY
          end
        end

        # Type declared for the input field in the input block, or ANY when undeclared.
        def get_declared_field_type(field_name)
          collected_input_meta[field_name]&.dig(:type) || Kumi::Types::ANY
        end

        # Type produced by TypeInferencer for the declaration, or ANY when absent.
        def get_inferred_declaration_type(decl_name)
          inferred = get_state(:inferred_types, required: true)
          inferred[decl_name] || Kumi::Types::ANY
        end

        # Input metadata gathered by InputCollector; an empty Hash when not available.
        def collected_input_meta
          get_state(:inputs, required: false) || {}
        end

        # Human-readable description of where a mismatched type came from.
        def describe_expression_type(expr, type)
          case expr
          when Kumi::Syntax::Literal
            "`#{expr.value}` of type #{type} (literal value)"
          when Kumi::Syntax::InputReference
            describe_input_reference(expr, type)
          when Kumi::Syntax::DeclarationReference
            # The type was inferred from the referenced declaration's expression.
            "reference to declaration `#{expr.name}` of inferred type #{type}"
          when Kumi::Syntax::CallExpression
            "result of function `#{expr.fn_name}` returning #{type}"
          when Kumi::Syntax::ArrayExpression
            "list expression of type #{type}"
          when Kumi::Syntax::CascadeExpression
            "cascade expression of type #{type}"
          else
            "expression of type #{type}"
          end
        end

        # Describes an input reference, including its domain when one is declared.
        def describe_input_reference(expr, type)
          field_meta = collected_input_meta[expr.name]
          return "undeclared input field `#{expr.name}` (inferred as #{type})" unless field_meta&.dig(:type)

          domain_desc = field_meta[:domain] ? " (domain: #{field_meta[:domain]})" : ""
          "input field `#{expr.name}` of declared type #{type}#{domain_desc}"
        end
      end
    end
  end
end
@@ -1,46 +0,0 @@
# frozen_string_literal: true

module Kumi
  module Analyzer
    module Passes
      # Checks that every type declared on an input field is a type Kumi recognizes.
      #
      # DEPENDENCIES: :inputs from InputCollector (read when present)
      # PRODUCES: nothing (validation only; the incoming state is returned unchanged)
      # INTERFACE: new(schema, state).run(errors)
      class TypeConsistencyChecker < PassBase
        # Reports each invalid declared input type into +errors+.
        # @return the analyzer state, unchanged
        def run(errors)
          validate_declared_types(get_state(:inputs, required: false) || {}, errors)

          # Deeper declared-vs-inferred usage analysis is not implemented yet.
          state
        end

        private

        # Reports every input field whose declared type fails Kumi::Types.valid_type?.
        # Fields without a declared type are skipped.
        def validate_declared_types(input_meta, errors)
          input_meta.each do |field_name, meta|
            declared_type = meta[:type]
            next if !declared_type || Kumi::Types.valid_type?(declared_type)

            # Point the error at the field's declaration when it can be found.
            location = find_input_field_declaration(field_name)&.loc
            message = "Invalid type declaration for field :#{field_name}: #{declared_type.inspect}"
            report_type_error(errors, message, location: location)
          end
        end

        # The schema's input declaration named +field_name+, or nil when there is no schema
        # or no matching declaration.
        def find_input_field_declaration(field_name)
          return unless schema

          schema.inputs.find { |decl| decl.name == field_name }
        end
      end
    end
  end
end
@@ -1,232 +0,0 @@
# frozen_string_literal: true

module Kumi
  module Analyzer
    module Passes
      # RESPONSIBILITY: Infer types for all declarations based on expression analysis
      # DEPENDENCIES: Toposorter (needs evaluation_order), DeclarationValidator (needs declarations);
      #               :broadcasts and :inputs are read when present
      # PRODUCES: inferred_types hash mapping declaration names to inferred types
      # INTERFACE: new(schema, state).run(errors)
      class TypeInferencer < PassBase
        # Arithmetic functions whose vectorized element type is derived from their operands.
        VECTORIZED_ARITHMETIC = %i[add subtract multiply divide].freeze
        private_constant :VECTORIZED_ARITHMETIC

        # Infers a type for every declaration in topological order, so referenced
        # declarations are typed before their dependents. A failure while typing one
        # declaration is reported as a type error and leaves that declaration untyped.
        # @return the state extended with :inferred_types
        def run(errors)
          types = {}
          topo_order = get_state(:evaluation_order)
          definitions = get_state(:declarations)
          broadcast_meta = get_state(:broadcasts, required: false) || {}

          topo_order.each do |name|
            decl = definitions[name]
            next unless decl

            begin
              types[name] = infer_declaration_type(name, decl, types, broadcast_meta)
            rescue StandardError => e
              report_type_error(errors, "Type inference failed: #{e.message}", location: decl.loc)
            end
          end

          state.with(:inferred_types, types)
        end

        private

        # Type of one declaration. Declarations marked vectorized by the broadcast detector
        # become arrays: traits always yield { array: :boolean }, values yield an array of
        # their inferred element type.
        def infer_declaration_type(name, decl, types, broadcast_meta)
          unless broadcast_meta[:vectorized_operations]&.key?(name)
            return infer_expression_type(decl.expression, types, broadcast_meta, name)
          end

          # Element type is computed for traits too, so inference errors surface consistently.
          element_type = infer_vectorized_element_type(decl.expression, types, broadcast_meta)
          decl.is_a?(Kumi::Syntax::TraitDeclaration) ? { array: :boolean } : { array: element_type }
        end

        def infer_expression_type(expr, type_context = {}, broadcast_metadata = {}, current_decl_name = nil)
          case expr
          when Literal
            Types.infer_from_value(expr.value)
          when InputReference
            # Declared type from the input block
            input_meta = get_state(:inputs, required: false) || {}
            input_meta[expr.name]&.dig(:type) || :any
          when DeclarationReference
            type_context[expr.name] || :any
          when CallExpression
            infer_call_type(expr, type_context, broadcast_metadata, current_decl_name)
          when ArrayExpression
            infer_list_type(expr, type_context, broadcast_metadata, current_decl_name)
          when CascadeExpression
            infer_cascade_type(expr, type_context, broadcast_metadata, current_decl_name)
          when InputElementReference
            infer_element_reference_type(expr)
          else
            :any
          end
        end

        # Return type of a function call according to the FunctionRegistry.
        # Unknown functions and arity mismatches yield :any; TypeChecker reports those errors.
        def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
          fn_name = call_expr.fn_name

          # NOTE(review): these keys (:vectorized_values / :reducer_values) differ from the
          # :vectorized_operations key read in #run (and :reduction_operations read by TypeChecker);
          # confirm which keys BroadcastDetector actually produces.
          if current_decl_name && broadcast_metadata[:vectorized_values]&.key?(current_decl_name)
            return { array: infer_vectorized_element_type(call_expr, type_context, broadcast_metadata) }
          end

          if current_decl_name && broadcast_metadata[:reducer_values]&.key?(current_decl_name)
            return infer_function_return_type(fn_name, call_expr.args, type_context, broadcast_metadata)
          end

          return :any unless FunctionRegistry.supported?(fn_name)

          signature = FunctionRegistry.signature(fn_name)
          arity = signature[:arity]
          # A negative arity means variadic; otherwise the argument count must match.
          return :any if arity >= 0 && call_expr.args.size != arity

          signature[:return_type] || :any
        end

        def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
          return :any unless FunctionRegistry.supported?(fn_name)

          FunctionRegistry.signature(fn_name)[:return_type] || :any
        end

        # Array type whose element type unifies all element expressions; an empty list or a
        # failed unification yields an array of :any.
        def infer_list_type(list_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
          return Types.array(:any) if list_expr.elements.empty?

          element_types = list_expr.elements.map do |elem|
            infer_expression_type(elem, type_context, broadcast_metadata, current_decl_name)
          end

          Types.array(element_types.reduce { |acc, type| Types.unify(acc, type) })
        rescue StandardError
          Types.array(:any)
        end

        # Element type produced by a vectorized expression: the declared child field type for
        # an element reference, or the unified operand type for elementwise arithmetic.
        def infer_vectorized_element_type(expr, type_context, vectorization_meta)
          case expr
          when InputElementReference
            element_field_type(expr) || :any
          when CallExpression
            return :any unless VECTORIZED_ARITHMETIC.include?(expr.fn_name)

            operand_types = expr.args.map { |arg| vectorized_operand_type(arg, type_context, vectorization_meta) }
            # Fold pairwise so any operand count works; identical to unify(a, b) for binary ops.
            operand_types.reduce { |acc, type| Types.unify(acc, type) } || :float
          else
            :any
          end
        end

        # Per-element type of one operand of a vectorized arithmetic call. Array-typed
        # declaration references contribute their element type.
        def vectorized_operand_type(arg, type_context, vectorization_meta)
          case arg
          when InputElementReference
            infer_vectorized_element_type(arg, type_context, vectorization_meta)
          when DeclarationReference
            ref_type = type_context[arg.name]
            if ref_type.is_a?(Hash) && ref_type.key?(:array)
              ref_type[:array]
            else
              ref_type || :any
            end
          else
            infer_expression_type(arg, type_context, vectorization_meta)
          end
        end

        # An element reference is vectorized over its array input, so it yields an array of
        # the child field's type; :any when the path is too short or path[0] is not an array.
        def infer_element_reference_type(expr)
          return :any unless expr.path.size >= 2

          field_type = element_field_type(expr)
          field_type ? { array: field_type } : :any
        end

        # Declared type of child field path[1] inside array input path[0]: :any when that
        # child has no declared type, nil when path[0] is not an input of type :array.
        def element_field_type(expr)
          input_meta = get_state(:inputs, required: false) || {}
          array_meta = input_meta[expr.path.first]
          return nil unless array_meta&.dig(:type) == :array

          array_meta.dig(:children, expr.path[1], :type) || :any
        end

        # Unified type of all cascade results; :any for an empty cascade or when unification fails.
        def infer_cascade_type(cascade_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
          return :any if cascade_expr.cases.empty?

          result_types = cascade_expr.cases.map do |case_stmt|
            infer_expression_type(case_stmt.result, type_context, broadcast_metadata, current_decl_name)
          end

          result_types.reduce { |unified, type| Types.unify(unified, type) } || :any
        rescue StandardError
          # TODO: decide whether a failed unification should fall back to :any or raise
          :any
        end
      end
    end
  end
end