kumi 0.0.7 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CLAUDE.md +1 -1
  3. data/README.md +8 -5
  4. data/examples/game_of_life.rb +1 -1
  5. data/examples/static_analysis_errors.rb +7 -7
  6. data/lib/kumi/analyzer.rb +15 -15
  7. data/lib/kumi/compiler.rb +6 -6
  8. data/lib/kumi/core/analyzer/analysis_state.rb +39 -0
  9. data/lib/kumi/core/analyzer/constant_evaluator.rb +59 -0
  10. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +248 -0
  11. data/lib/kumi/core/analyzer/passes/declaration_validator.rb +45 -0
  12. data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +153 -0
  13. data/lib/kumi/core/analyzer/passes/input_collector.rb +139 -0
  14. data/lib/kumi/core/analyzer/passes/name_indexer.rb +26 -0
  15. data/lib/kumi/core/analyzer/passes/pass_base.rb +52 -0
  16. data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +111 -0
  17. data/lib/kumi/core/analyzer/passes/toposorter.rb +110 -0
  18. data/lib/kumi/core/analyzer/passes/type_checker.rb +162 -0
  19. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +48 -0
  20. data/lib/kumi/core/analyzer/passes/type_inferencer.rb +236 -0
  21. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +406 -0
  22. data/lib/kumi/core/analyzer/passes/visitor_pass.rb +44 -0
  23. data/lib/kumi/core/atom_unsat_solver.rb +396 -0
  24. data/lib/kumi/core/compiled_schema.rb +43 -0
  25. data/lib/kumi/core/constraint_relationship_solver.rb +641 -0
  26. data/lib/kumi/core/domain/enum_analyzer.rb +55 -0
  27. data/lib/kumi/core/domain/range_analyzer.rb +85 -0
  28. data/lib/kumi/core/domain/validator.rb +82 -0
  29. data/lib/kumi/core/domain/violation_formatter.rb +42 -0
  30. data/lib/kumi/core/error_reporter.rb +166 -0
  31. data/lib/kumi/core/error_reporting.rb +97 -0
  32. data/lib/kumi/core/errors.rb +120 -0
  33. data/lib/kumi/core/evaluation_wrapper.rb +40 -0
  34. data/lib/kumi/core/explain.rb +295 -0
  35. data/lib/kumi/core/export/deserializer.rb +41 -0
  36. data/lib/kumi/core/export/errors.rb +14 -0
  37. data/lib/kumi/core/export/node_builders.rb +142 -0
  38. data/lib/kumi/core/export/node_registry.rb +54 -0
  39. data/lib/kumi/core/export/node_serializers.rb +158 -0
  40. data/lib/kumi/core/export/serializer.rb +25 -0
  41. data/lib/kumi/core/export.rb +35 -0
  42. data/lib/kumi/core/function_registry/collection_functions.rb +202 -0
  43. data/lib/kumi/core/function_registry/comparison_functions.rb +33 -0
  44. data/lib/kumi/core/function_registry/conditional_functions.rb +38 -0
  45. data/lib/kumi/core/function_registry/function_builder.rb +95 -0
  46. data/lib/kumi/core/function_registry/logical_functions.rb +44 -0
  47. data/lib/kumi/core/function_registry/math_functions.rb +74 -0
  48. data/lib/kumi/core/function_registry/string_functions.rb +57 -0
  49. data/lib/kumi/core/function_registry/type_functions.rb +53 -0
  50. data/lib/kumi/{function_registry.rb → core/function_registry.rb} +28 -36
  51. data/lib/kumi/core/input/type_matcher.rb +97 -0
  52. data/lib/kumi/core/input/validator.rb +51 -0
  53. data/lib/kumi/core/input/violation_creator.rb +52 -0
  54. data/lib/kumi/core/json_schema/generator.rb +65 -0
  55. data/lib/kumi/core/json_schema/validator.rb +27 -0
  56. data/lib/kumi/core/json_schema.rb +16 -0
  57. data/lib/kumi/core/ruby_parser/build_context.rb +27 -0
  58. data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +38 -0
  59. data/lib/kumi/core/ruby_parser/dsl.rb +14 -0
  60. data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +138 -0
  61. data/lib/kumi/core/ruby_parser/expression_converter.rb +128 -0
  62. data/lib/kumi/core/ruby_parser/guard_rails.rb +45 -0
  63. data/lib/kumi/core/ruby_parser/input_builder.rb +127 -0
  64. data/lib/kumi/core/ruby_parser/input_field_proxy.rb +48 -0
  65. data/lib/kumi/core/ruby_parser/input_proxy.rb +31 -0
  66. data/lib/kumi/core/ruby_parser/nested_input.rb +17 -0
  67. data/lib/kumi/core/ruby_parser/parser.rb +71 -0
  68. data/lib/kumi/core/ruby_parser/schema_builder.rb +175 -0
  69. data/lib/kumi/core/ruby_parser/sugar.rb +263 -0
  70. data/lib/kumi/core/ruby_parser.rb +12 -0
  71. data/lib/kumi/core/schema_instance.rb +111 -0
  72. data/lib/kumi/core/types/builder.rb +23 -0
  73. data/lib/kumi/core/types/compatibility.rb +96 -0
  74. data/lib/kumi/core/types/formatter.rb +26 -0
  75. data/lib/kumi/core/types/inference.rb +42 -0
  76. data/lib/kumi/core/types/normalizer.rb +72 -0
  77. data/lib/kumi/core/types/validator.rb +37 -0
  78. data/lib/kumi/core/types.rb +66 -0
  79. data/lib/kumi/core/vectorization_metadata.rb +110 -0
  80. data/lib/kumi/errors.rb +1 -112
  81. data/lib/kumi/registry.rb +37 -0
  82. data/lib/kumi/schema.rb +5 -5
  83. data/lib/kumi/schema_metadata.rb +3 -3
  84. data/lib/kumi/syntax/array_expression.rb +6 -6
  85. data/lib/kumi/syntax/call_expression.rb +4 -4
  86. data/lib/kumi/syntax/cascade_expression.rb +4 -4
  87. data/lib/kumi/syntax/case_expression.rb +4 -4
  88. data/lib/kumi/syntax/declaration_reference.rb +4 -4
  89. data/lib/kumi/syntax/hash_expression.rb +4 -4
  90. data/lib/kumi/syntax/input_declaration.rb +5 -5
  91. data/lib/kumi/syntax/input_element_reference.rb +5 -5
  92. data/lib/kumi/syntax/input_reference.rb +5 -5
  93. data/lib/kumi/syntax/literal.rb +4 -4
  94. data/lib/kumi/syntax/node.rb +34 -34
  95. data/lib/kumi/syntax/root.rb +6 -6
  96. data/lib/kumi/syntax/trait_declaration.rb +4 -4
  97. data/lib/kumi/syntax/value_declaration.rb +4 -4
  98. data/lib/kumi/version.rb +1 -1
  99. data/migrate_to_core_iterative.rb +938 -0
  100. data/scripts/generate_function_docs.rb +9 -9
  101. metadata +75 -72
  102. data/lib/kumi/analyzer/analysis_state.rb +0 -37
  103. data/lib/kumi/analyzer/constant_evaluator.rb +0 -57
  104. data/lib/kumi/analyzer/passes/broadcast_detector.rb +0 -246
  105. data/lib/kumi/analyzer/passes/declaration_validator.rb +0 -43
  106. data/lib/kumi/analyzer/passes/dependency_resolver.rb +0 -151
  107. data/lib/kumi/analyzer/passes/input_collector.rb +0 -137
  108. data/lib/kumi/analyzer/passes/name_indexer.rb +0 -24
  109. data/lib/kumi/analyzer/passes/pass_base.rb +0 -50
  110. data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +0 -109
  111. data/lib/kumi/analyzer/passes/toposorter.rb +0 -108
  112. data/lib/kumi/analyzer/passes/type_checker.rb +0 -160
  113. data/lib/kumi/analyzer/passes/type_consistency_checker.rb +0 -46
  114. data/lib/kumi/analyzer/passes/type_inferencer.rb +0 -232
  115. data/lib/kumi/analyzer/passes/unsat_detector.rb +0 -404
  116. data/lib/kumi/analyzer/passes/visitor_pass.rb +0 -42
  117. data/lib/kumi/atom_unsat_solver.rb +0 -394
  118. data/lib/kumi/compiled_schema.rb +0 -41
  119. data/lib/kumi/constraint_relationship_solver.rb +0 -638
  120. data/lib/kumi/domain/enum_analyzer.rb +0 -53
  121. data/lib/kumi/domain/range_analyzer.rb +0 -83
  122. data/lib/kumi/domain/validator.rb +0 -80
  123. data/lib/kumi/domain/violation_formatter.rb +0 -40
  124. data/lib/kumi/error_reporter.rb +0 -164
  125. data/lib/kumi/error_reporting.rb +0 -95
  126. data/lib/kumi/evaluation_wrapper.rb +0 -38
  127. data/lib/kumi/explain.rb +0 -293
  128. data/lib/kumi/export/deserializer.rb +0 -39
  129. data/lib/kumi/export/errors.rb +0 -12
  130. data/lib/kumi/export/node_builders.rb +0 -140
  131. data/lib/kumi/export/node_registry.rb +0 -52
  132. data/lib/kumi/export/node_serializers.rb +0 -156
  133. data/lib/kumi/export/serializer.rb +0 -23
  134. data/lib/kumi/export.rb +0 -33
  135. data/lib/kumi/function_registry/collection_functions.rb +0 -200
  136. data/lib/kumi/function_registry/comparison_functions.rb +0 -31
  137. data/lib/kumi/function_registry/conditional_functions.rb +0 -36
  138. data/lib/kumi/function_registry/function_builder.rb +0 -93
  139. data/lib/kumi/function_registry/logical_functions.rb +0 -42
  140. data/lib/kumi/function_registry/math_functions.rb +0 -72
  141. data/lib/kumi/function_registry/string_functions.rb +0 -54
  142. data/lib/kumi/function_registry/type_functions.rb +0 -51
  143. data/lib/kumi/input/type_matcher.rb +0 -95
  144. data/lib/kumi/input/validator.rb +0 -49
  145. data/lib/kumi/input/violation_creator.rb +0 -50
  146. data/lib/kumi/json_schema/generator.rb +0 -63
  147. data/lib/kumi/json_schema/validator.rb +0 -25
  148. data/lib/kumi/json_schema.rb +0 -14
  149. data/lib/kumi/ruby_parser/build_context.rb +0 -25
  150. data/lib/kumi/ruby_parser/declaration_reference_proxy.rb +0 -36
  151. data/lib/kumi/ruby_parser/dsl.rb +0 -12
  152. data/lib/kumi/ruby_parser/dsl_cascade_builder.rb +0 -136
  153. data/lib/kumi/ruby_parser/expression_converter.rb +0 -126
  154. data/lib/kumi/ruby_parser/guard_rails.rb +0 -43
  155. data/lib/kumi/ruby_parser/input_builder.rb +0 -125
  156. data/lib/kumi/ruby_parser/input_field_proxy.rb +0 -46
  157. data/lib/kumi/ruby_parser/input_proxy.rb +0 -29
  158. data/lib/kumi/ruby_parser/nested_input.rb +0 -15
  159. data/lib/kumi/ruby_parser/parser.rb +0 -69
  160. data/lib/kumi/ruby_parser/schema_builder.rb +0 -173
  161. data/lib/kumi/ruby_parser/sugar.rb +0 -261
  162. data/lib/kumi/ruby_parser.rb +0 -10
  163. data/lib/kumi/schema_instance.rb +0 -109
  164. data/lib/kumi/types/builder.rb +0 -21
  165. data/lib/kumi/types/compatibility.rb +0 -94
  166. data/lib/kumi/types/formatter.rb +0 -24
  167. data/lib/kumi/types/inference.rb +0 -40
  168. data/lib/kumi/types/normalizer.rb +0 -70
  169. data/lib/kumi/types/validator.rb +0 -35
  170. data/lib/kumi/types.rb +0 -64
  171. data/lib/kumi/vectorization_metadata.rb +0 -108
@@ -0,0 +1,162 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Validate function call arity and argument types against FunctionRegistry
        # DEPENDENCIES: :inferred_types from TypeInferencer (required);
        #               :inputs from InputCollector and :broadcasts from BroadcastDetector (optional)
        # PRODUCES: None (validation only)
        # INTERFACE: new(schema, state).run(errors)
        class TypeChecker < VisitorPass
          # Visits every CallExpression in the schema and records problems in +errors+.
          # Returns the analysis state untouched.
          def run(errors)
            visit_nodes_of_type(Kumi::Syntax::CallExpression, errors: errors) do |call, _decl, errs|
              validate_function_call(call, errs)
            end
            state
          end

          private

          # Unknown functions are reported while looking up the signature; known
          # functions are then checked for arity followed by argument types.
          def validate_function_call(node, errors)
            signature = get_function_signature(node, errors)
            if signature
              validate_arity(node, signature, errors)
              validate_argument_types(node, signature, errors)
            end
          end

          # Returns the registry signature for the call, or nil (after reporting an
          # error) when the function is not registered.
          def get_function_signature(node, errors)
            Kumi::Registry.signature(node.fn_name)
          rescue Kumi::Errors::UnknownFunction
            # Historical "unsupported operator" wording is kept; node.loc points at the call site.
            report_error(errors, "unsupported operator `#{node.fn_name}`", location: node.loc, type: :type)
            nil
          end

          # Reports a mismatch between the declared and actual argument count.
          # A negative declared arity marks a variadic function and is never checked.
          def validate_arity(node, signature, errors)
            expected = signature[:arity]
            given = node.args.size
            return if expected.negative?
            return if expected == given

            report_error(errors, "operator `#{node.fn_name}` expects #{expected} args, got #{given}",
                         location: node.loc, type: :type)
          end

          # Checks each argument against the positional entry of :param_types.
          # Calls without param types, empty variadic calls, and calls that the
          # broadcast metadata marks as vectorized are skipped.
          def validate_argument_types(node, signature, errors)
            param_types = signature[:param_types]
            return if param_types.nil?
            return if signature[:arity].negative? && node.args.empty?

            broadcasts = get_state(:broadcasts, required: false)
            return if broadcasts && is_part_of_vectorized_operation?(node, broadcasts)

            node.args.each.with_index do |arg, position|
              validate_argument_type(arg, position, param_types[position], node.fn_name, errors)
            end
          end

          # Heuristic (no context tracking): a call counts as vectorized when any of its
          # arguments is a vectorized/reduction declaration or an element of an array input.
          def is_part_of_vectorized_operation?(node, broadcast_meta)
            node.args.any? { |arg| vectorized_argument?(arg, broadcast_meta) }
          end

          # True-ish when a single argument is flagged by the broadcast metadata.
          def vectorized_argument?(arg, broadcast_meta)
            if arg.is_a?(Kumi::Syntax::DeclarationReference)
              name = arg.name
              broadcast_meta[:vectorized_operations]&.key?(name) ||
                broadcast_meta[:reduction_operations]&.key?(name)
            elsif arg.is_a?(Kumi::Syntax::InputElementReference)
              broadcast_meta[:array_fields]&.key?(arg.path.first)
            else
              false
            end
          end

          # Reports an argument whose type is incompatible with the expected type.
          # A nil or ANY expectation accepts everything.
          def validate_argument_type(arg, index, expected_type, fn_name, errors)
            return if expected_type.nil?
            return if expected_type == Kumi::Core::Types::ANY

            actual_type = get_expression_type(arg)
            return if Kumi::Core::Types.compatible?(actual_type, expected_type)

            message = "argument #{index + 1} of `fn(:#{fn_name})` expects #{expected_type}, " \
                      "got #{describe_expression_type(arg, actual_type)}"
            report_error(errors, message, location: arg.loc, type: :type)
          end

          # Best-known type of an argument expression. Only literals, input references
          # and declaration references are typed here; anything else is ANY.
          def get_expression_type(expr)
            case expr
            when Kumi::Syntax::Literal then Kumi::Core::Types.infer_from_value(expr.value)
            when Kumi::Syntax::InputReference then get_declared_field_type(expr.name)
            when Kumi::Syntax::DeclarationReference then get_inferred_declaration_type(expr.name)
            else Kumi::Core::Types::ANY
            end
          end

          # Type the user declared for an input field, or ANY when none was declared.
          def get_declared_field_type(field_name)
            declared = declared_input_meta(field_name)&.dig(:type)
            declared || Kumi::Core::Types::ANY
          end

          # Type TypeInferencer produced for a declaration, or ANY when it has none.
          def get_inferred_declaration_type(decl_name)
            inferred = get_state(:inferred_types, required: true)
            inferred[decl_name] || Kumi::Core::Types::ANY
          end

          # Input metadata entry for +field_name+ (nil when the :inputs state is absent
          # or the field is unknown).
          def declared_input_meta(field_name)
            (get_state(:inputs, required: false) || {})[field_name]
          end

          # Human-readable description of where an argument's type came from, used in
          # error messages.
          def describe_expression_type(expr, type)
            case expr
            when Kumi::Syntax::Literal
              "`#{expr.value}` of type #{type} (literal value)"
            when Kumi::Syntax::InputReference
              describe_input_reference(expr, type)
            when Kumi::Syntax::DeclarationReference
              "reference to declaration `#{expr.name}` of inferred type #{type}"
            when Kumi::Syntax::CallExpression
              "result of function `#{expr.fn_name}` returning #{type}"
            when Kumi::Syntax::ArrayExpression
              "list expression of type #{type}"
            when Kumi::Syntax::CascadeExpression
              "cascade expression of type #{type}"
            else
              "expression of type #{type}"
            end
          end

          # Describes an input reference, distinguishing declared fields (with optional
          # domain) from undeclared ones.
          def describe_input_reference(expr, type)
            meta = declared_input_meta(expr.name)
            return "undeclared input field `#{expr.name}` (inferred as #{type})" unless meta&.dig(:type)

            domain = meta[:domain]
            suffix = domain ? " (domain: #{domain})" : ""
            "input field `#{expr.name}` of declared type #{type}#{suffix}"
          end
        end
      end
    end
  end
end
@@ -0,0 +1,48 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Validate consistency between declared and inferred types
        # DEPENDENCIES: :inputs from InputCollector, :inferred_types from TypeInferencer
        # PRODUCES: None (validation only)
        # INTERFACE: new(schema, state).run(errors)
        class TypeConsistencyChecker < PassBase
          # Reports every input field whose declared type is not a valid Kumi type.
          # Deeper usage-based consistency analysis is not implemented yet; the
          # analysis state is returned unchanged.
          def run(errors)
            validate_declared_types(get_state(:inputs, required: false) || {}, errors)
            state
          end

          private

          # Walks the collected input metadata; fields without a declared type are
          # skipped, invalid declarations are reported at the field's source location
          # when it can be found.
          def validate_declared_types(input_meta, errors)
            input_meta.each do |field_name, meta|
              declared_type = meta[:type]
              next if !declared_type || Kumi::Core::Types.valid_type?(declared_type)

              location = find_input_field_declaration(field_name)&.loc
              report_type_error(errors, "Invalid type declaration for field :#{field_name}: #{declared_type.inspect}",
                                location: location)
            end
          end

          # The schema's input declaration named +field_name+, or nil when there is no
          # schema or no such declaration.
          def find_input_field_declaration(field_name)
            schema.inputs.find { |decl| decl.name == field_name } if schema
          end
        end
      end
    end
  end
end
@@ -0,0 +1,236 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Infer types for all declarations based on expression analysis
        # DEPENDENCIES: Toposorter (needs evaluation_order), DeclarationValidator (needs declarations)
        # PRODUCES: inferred_types hash mapping declaration names to inferred types
        # INTERFACE: new(schema, state).run(errors)
        #
        # Declarations listed under :vectorized_operations in the optional :broadcasts state
        # are typed as { array: element_type } ({ array: :boolean } for traits); every other
        # declaration goes through regular expression inference. A failure while typing one
        # declaration is reported and the pass continues with the rest.
        class TypeInferencer < PassBase
          # Element-wise arithmetic whose element type is unified from its operand types.
          ARITHMETIC_FUNCTIONS = %i[add subtract multiply divide].freeze

          def run(errors)
            types = {}
            topo_order = get_state(:evaluation_order)
            definitions = get_state(:declarations)
            broadcast_meta = get_state(:broadcasts, required: false) || {}

            # Topological order guarantees referenced declarations are typed before their users.
            topo_order.each do |name|
              decl = definitions[name]
              next unless decl

              begin
                types[name] = infer_declaration_type(name, decl, types, broadcast_meta)
              rescue StandardError => e
                # Best effort: report and keep going so one bad declaration doesn't hide the others.
                report_type_error(errors, "Type inference failed: #{e.message}", location: decl.loc)
              end
            end

            state.with(:inferred_types, types)
          end

          private

          # Types a single declaration, honouring the broadcast detector's vectorization flags.
          def infer_declaration_type(name, decl, types, broadcast_meta)
            if broadcast_meta[:vectorized_operations]&.key?(name)
              element_type = infer_vectorized_element_type(decl.expression, types, broadcast_meta)
              decl.is_a?(Kumi::Syntax::TraitDeclaration) ? { array: :boolean } : { array: element_type }
            else
              infer_expression_type(decl.expression, types, broadcast_meta, name)
            end
          end

          # Infers the type of an arbitrary expression. +type_context+ maps already-typed
          # declaration names to their types. Unknown expression kinds yield :any.
          def infer_expression_type(expr, type_context = {}, broadcast_metadata = {}, current_decl_name = nil)
            case expr
            when Literal
              Types.infer_from_value(expr.value)
            when InputReference
              # Declared type from the input block, if any.
              input_meta = get_state(:inputs, required: false) || {}
              input_meta[expr.name]&.dig(:type) || :any
            when DeclarationReference
              type_context[expr.name] || :any
            when CallExpression
              infer_call_type(expr, type_context, broadcast_metadata, current_decl_name)
            when ArrayExpression
              infer_list_type(expr, type_context, broadcast_metadata, current_decl_name)
            when CascadeExpression
              infer_cascade_type(expr, type_context, broadcast_metadata, current_decl_name)
            when InputElementReference
              infer_element_reference_type(expr)
            else
              :any
            end
          end

          # Return type of a function call. Unknown functions and arity mismatches yield :any
          # here; TypeChecker is responsible for reporting them.
          def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
            fn_name = call_expr.fn_name
            args = call_expr.args

            # NOTE(review): run and TypeChecker key vectorization on :vectorized_operations /
            # :reduction_operations. Confirm the broadcast detector really emits
            # :vectorized_values / :reducer_values; otherwise these two branches are unreachable.
            if current_decl_name && broadcast_metadata[:vectorized_values]&.key?(current_decl_name)
              return { array: infer_vectorized_element_type(call_expr, type_context, broadcast_metadata) }
            end

            if current_decl_name && broadcast_metadata[:reducer_values]&.key?(current_decl_name)
              return infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
            end

            return :any unless Kumi::Registry.supported?(fn_name)

            signature = Kumi::Registry.signature(fn_name)
            # Fixed-arity function called with the wrong number of arguments.
            return :any if signature[:arity] >= 0 && args.size != signature[:arity]

            # Arguments are still inferred so failures inside nested expressions surface through
            # run's error reporting. Parameter compatibility itself is checked by TypeChecker.
            args.each { |arg| infer_expression_type(arg, type_context, broadcast_metadata, current_decl_name) }

            signature[:return_type] || :any
          end

          # Registry return type for +fn_name+, or :any for unknown functions.
          def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
            return :any unless Kumi::Registry.supported?(fn_name)

            signature = Kumi::Registry.signature(fn_name)
            signature[:return_type] || :any
          end

          # Array type whose element type unifies all element expressions; falls back to
          # array(:any) for empty lists or when unification fails.
          def infer_list_type(list_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
            return Types.array(:any) if list_expr.elements.empty?

            element_types = list_expr.elements.map do |elem|
              infer_expression_type(elem, type_context, broadcast_metadata, current_decl_name)
            end

            unified_type = element_types.reduce { |acc, type| Types.unify(acc, type) }
            Types.array(unified_type)
          rescue StandardError
            Types.array(:any)
          end

          # Element type produced by a vectorized expression: the declared field type for an
          # array element reference, the unified operand type for element-wise arithmetic
          # (defaulting to :float when unification yields nil), and :any otherwise.
          def infer_vectorized_element_type(expr, type_context, vectorization_meta)
            case expr
            when InputElementReference
              element_field_type(expr) || :any
            when CallExpression
              return :any unless ARITHMETIC_FUNCTIONS.include?(expr.fn_name)

              operand_types = expr.args.map { |arg| vectorized_operand_type(arg, type_context, vectorization_meta) }
              # Pairwise fold (as in infer_list_type) so any operand count is handled.
              operand_types.reduce { |acc, type| Types.unify(acc, type) } || :float
            else
              :any
            end
          end

          # Element-level type of one operand of a vectorized arithmetic call.
          def vectorized_operand_type(arg, type_context, vectorization_meta)
            case arg
            when InputElementReference
              infer_vectorized_element_type(arg, type_context, vectorization_meta)
            when DeclarationReference
              ref_type = type_context[arg.name]
              # Vectorized declarations are stored as { array: T }; arithmetic works on T.
              if ref_type.is_a?(Hash) && ref_type.key?(:array)
                ref_type[:array]
              else
                ref_type || :any
              end
            else
              infer_expression_type(arg, type_context, vectorization_meta)
            end
          end

          # Type of an element reference used as a whole: { array: field_type } when the root
          # is a declared :array input, :any otherwise.
          def infer_element_reference_type(expr)
            field_type = element_field_type(expr)
            field_type ? { array: field_type } : :any
          end

          # Declared type of the field an element reference targets (path = [array, field, ...]).
          # Returns nil when the path is shorter than two segments or its root is not an :array
          # input, and :any when the array exists but the field has no declared type.
          def element_field_type(expr)
            return nil if expr.path.size < 2

            input_meta = get_state(:inputs, required: false) || {}
            array_meta = input_meta[expr.path.first]
            return nil unless array_meta&.dig(:type) == :array

            array_meta.dig(:children, expr.path[1], :type) || :any
          end

          # Unified type of all cascade branch results; :any for an empty cascade or when
          # unification fails.
          def infer_cascade_type(cascade_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
            return :any if cascade_expr.cases.empty?

            result_types = cascade_expr.cases.map do |case_stmt|
              infer_expression_type(case_stmt.result, type_context, broadcast_metadata, current_decl_name)
            end

            result_types.reduce { |unified, type| Types.unify(unified, type) } || :any
          rescue StandardError
            # Unification failed: fall back to :any.
            # TODO: decide whether falling back is right or whether this should raise.
            :any
          end
        end
      end
    end
  end
end