kumi 0.0.6 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. checksums.yaml +4 -4
  2. data/CLAUDE.md +34 -177
  3. data/README.md +41 -7
  4. data/docs/SYNTAX.md +2 -7
  5. data/docs/features/array-broadcasting.md +1 -1
  6. data/docs/schema_metadata/broadcasts.md +53 -0
  7. data/docs/schema_metadata/cascades.md +45 -0
  8. data/docs/schema_metadata/declarations.md +54 -0
  9. data/docs/schema_metadata/dependencies.md +57 -0
  10. data/docs/schema_metadata/evaluation_order.md +29 -0
  11. data/docs/schema_metadata/examples.md +95 -0
  12. data/docs/schema_metadata/inferred_types.md +46 -0
  13. data/docs/schema_metadata/inputs.md +86 -0
  14. data/docs/schema_metadata.md +108 -0
  15. data/examples/game_of_life.rb +1 -1
  16. data/examples/static_analysis_errors.rb +7 -7
  17. data/lib/kumi/analyzer.rb +20 -20
  18. data/lib/kumi/compiler.rb +44 -50
  19. data/lib/kumi/core/analyzer/analysis_state.rb +39 -0
  20. data/lib/kumi/core/analyzer/constant_evaluator.rb +59 -0
  21. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +248 -0
  22. data/lib/kumi/core/analyzer/passes/declaration_validator.rb +45 -0
  23. data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +153 -0
  24. data/lib/kumi/core/analyzer/passes/input_collector.rb +139 -0
  25. data/lib/kumi/core/analyzer/passes/name_indexer.rb +26 -0
  26. data/lib/kumi/core/analyzer/passes/pass_base.rb +52 -0
  27. data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +111 -0
  28. data/lib/kumi/core/analyzer/passes/toposorter.rb +110 -0
  29. data/lib/kumi/core/analyzer/passes/type_checker.rb +162 -0
  30. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +48 -0
  31. data/lib/kumi/core/analyzer/passes/type_inferencer.rb +236 -0
  32. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +406 -0
  33. data/lib/kumi/core/analyzer/passes/visitor_pass.rb +44 -0
  34. data/lib/kumi/core/atom_unsat_solver.rb +396 -0
  35. data/lib/kumi/core/compiled_schema.rb +43 -0
  36. data/lib/kumi/core/constraint_relationship_solver.rb +641 -0
  37. data/lib/kumi/core/domain/enum_analyzer.rb +55 -0
  38. data/lib/kumi/core/domain/range_analyzer.rb +85 -0
  39. data/lib/kumi/core/domain/validator.rb +82 -0
  40. data/lib/kumi/core/domain/violation_formatter.rb +42 -0
  41. data/lib/kumi/core/error_reporter.rb +166 -0
  42. data/lib/kumi/core/error_reporting.rb +97 -0
  43. data/lib/kumi/core/errors.rb +120 -0
  44. data/lib/kumi/core/evaluation_wrapper.rb +40 -0
  45. data/lib/kumi/core/explain.rb +295 -0
  46. data/lib/kumi/core/export/deserializer.rb +41 -0
  47. data/lib/kumi/core/export/errors.rb +14 -0
  48. data/lib/kumi/core/export/node_builders.rb +142 -0
  49. data/lib/kumi/core/export/node_registry.rb +54 -0
  50. data/lib/kumi/core/export/node_serializers.rb +158 -0
  51. data/lib/kumi/core/export/serializer.rb +25 -0
  52. data/lib/kumi/core/export.rb +35 -0
  53. data/lib/kumi/core/function_registry/collection_functions.rb +202 -0
  54. data/lib/kumi/core/function_registry/comparison_functions.rb +33 -0
  55. data/lib/kumi/core/function_registry/conditional_functions.rb +38 -0
  56. data/lib/kumi/core/function_registry/function_builder.rb +95 -0
  57. data/lib/kumi/core/function_registry/logical_functions.rb +44 -0
  58. data/lib/kumi/core/function_registry/math_functions.rb +74 -0
  59. data/lib/kumi/core/function_registry/string_functions.rb +57 -0
  60. data/lib/kumi/core/function_registry/type_functions.rb +53 -0
  61. data/lib/kumi/{function_registry.rb → core/function_registry.rb} +28 -36
  62. data/lib/kumi/core/input/type_matcher.rb +97 -0
  63. data/lib/kumi/core/input/validator.rb +51 -0
  64. data/lib/kumi/core/input/violation_creator.rb +52 -0
  65. data/lib/kumi/core/json_schema/generator.rb +65 -0
  66. data/lib/kumi/core/json_schema/validator.rb +27 -0
  67. data/lib/kumi/core/json_schema.rb +16 -0
  68. data/lib/kumi/core/ruby_parser/build_context.rb +27 -0
  69. data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +38 -0
  70. data/lib/kumi/core/ruby_parser/dsl.rb +14 -0
  71. data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +138 -0
  72. data/lib/kumi/core/ruby_parser/expression_converter.rb +128 -0
  73. data/lib/kumi/core/ruby_parser/guard_rails.rb +45 -0
  74. data/lib/kumi/core/ruby_parser/input_builder.rb +127 -0
  75. data/lib/kumi/core/ruby_parser/input_field_proxy.rb +48 -0
  76. data/lib/kumi/core/ruby_parser/input_proxy.rb +31 -0
  77. data/lib/kumi/core/ruby_parser/nested_input.rb +17 -0
  78. data/lib/kumi/core/ruby_parser/parser.rb +71 -0
  79. data/lib/kumi/core/ruby_parser/schema_builder.rb +175 -0
  80. data/lib/kumi/core/ruby_parser/sugar.rb +263 -0
  81. data/lib/kumi/core/ruby_parser.rb +12 -0
  82. data/lib/kumi/core/schema_instance.rb +111 -0
  83. data/lib/kumi/core/types/builder.rb +23 -0
  84. data/lib/kumi/core/types/compatibility.rb +96 -0
  85. data/lib/kumi/core/types/formatter.rb +26 -0
  86. data/lib/kumi/core/types/inference.rb +42 -0
  87. data/lib/kumi/core/types/normalizer.rb +72 -0
  88. data/lib/kumi/core/types/validator.rb +37 -0
  89. data/lib/kumi/core/types.rb +66 -0
  90. data/lib/kumi/core/vectorization_metadata.rb +110 -0
  91. data/lib/kumi/errors.rb +1 -112
  92. data/lib/kumi/registry.rb +37 -0
  93. data/lib/kumi/schema.rb +13 -7
  94. data/lib/kumi/schema_metadata.rb +524 -0
  95. data/lib/kumi/syntax/array_expression.rb +6 -6
  96. data/lib/kumi/syntax/call_expression.rb +4 -4
  97. data/lib/kumi/syntax/cascade_expression.rb +4 -4
  98. data/lib/kumi/syntax/case_expression.rb +4 -4
  99. data/lib/kumi/syntax/declaration_reference.rb +4 -4
  100. data/lib/kumi/syntax/hash_expression.rb +4 -4
  101. data/lib/kumi/syntax/input_declaration.rb +5 -5
  102. data/lib/kumi/syntax/input_element_reference.rb +5 -5
  103. data/lib/kumi/syntax/input_reference.rb +5 -5
  104. data/lib/kumi/syntax/literal.rb +4 -4
  105. data/lib/kumi/syntax/node.rb +34 -34
  106. data/lib/kumi/syntax/root.rb +6 -6
  107. data/lib/kumi/syntax/trait_declaration.rb +4 -4
  108. data/lib/kumi/syntax/value_declaration.rb +4 -4
  109. data/lib/kumi/version.rb +1 -1
  110. data/lib/kumi.rb +14 -0
  111. data/migrate_to_core_iterative.rb +938 -0
  112. data/scripts/generate_function_docs.rb +9 -9
  113. metadata +85 -69
  114. data/lib/generators/trait_engine/templates/schema_spec.rb.erb +0 -27
  115. data/lib/kumi/analyzer/analysis_state.rb +0 -37
  116. data/lib/kumi/analyzer/constant_evaluator.rb +0 -57
  117. data/lib/kumi/analyzer/passes/broadcast_detector.rb +0 -251
  118. data/lib/kumi/analyzer/passes/declaration_validator.rb +0 -43
  119. data/lib/kumi/analyzer/passes/dependency_resolver.rb +0 -151
  120. data/lib/kumi/analyzer/passes/input_collector.rb +0 -137
  121. data/lib/kumi/analyzer/passes/name_indexer.rb +0 -24
  122. data/lib/kumi/analyzer/passes/pass_base.rb +0 -50
  123. data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +0 -110
  124. data/lib/kumi/analyzer/passes/toposorter.rb +0 -108
  125. data/lib/kumi/analyzer/passes/type_checker.rb +0 -162
  126. data/lib/kumi/analyzer/passes/type_consistency_checker.rb +0 -46
  127. data/lib/kumi/analyzer/passes/type_inferencer.rb +0 -232
  128. data/lib/kumi/analyzer/passes/unsat_detector.rb +0 -406
  129. data/lib/kumi/analyzer/passes/visitor_pass.rb +0 -42
  130. data/lib/kumi/atom_unsat_solver.rb +0 -394
  131. data/lib/kumi/compiled_schema.rb +0 -41
  132. data/lib/kumi/constraint_relationship_solver.rb +0 -638
  133. data/lib/kumi/domain/enum_analyzer.rb +0 -53
  134. data/lib/kumi/domain/range_analyzer.rb +0 -83
  135. data/lib/kumi/domain/validator.rb +0 -80
  136. data/lib/kumi/domain/violation_formatter.rb +0 -40
  137. data/lib/kumi/error_reporter.rb +0 -164
  138. data/lib/kumi/error_reporting.rb +0 -95
  139. data/lib/kumi/evaluation_wrapper.rb +0 -38
  140. data/lib/kumi/explain.rb +0 -281
  141. data/lib/kumi/export/deserializer.rb +0 -39
  142. data/lib/kumi/export/errors.rb +0 -12
  143. data/lib/kumi/export/node_builders.rb +0 -140
  144. data/lib/kumi/export/node_registry.rb +0 -52
  145. data/lib/kumi/export/node_serializers.rb +0 -156
  146. data/lib/kumi/export/serializer.rb +0 -23
  147. data/lib/kumi/export.rb +0 -33
  148. data/lib/kumi/function_registry/collection_functions.rb +0 -200
  149. data/lib/kumi/function_registry/comparison_functions.rb +0 -31
  150. data/lib/kumi/function_registry/conditional_functions.rb +0 -36
  151. data/lib/kumi/function_registry/function_builder.rb +0 -93
  152. data/lib/kumi/function_registry/logical_functions.rb +0 -42
  153. data/lib/kumi/function_registry/math_functions.rb +0 -72
  154. data/lib/kumi/function_registry/string_functions.rb +0 -54
  155. data/lib/kumi/function_registry/type_functions.rb +0 -51
  156. data/lib/kumi/input/type_matcher.rb +0 -95
  157. data/lib/kumi/input/validator.rb +0 -49
  158. data/lib/kumi/input/violation_creator.rb +0 -50
  159. data/lib/kumi/parser/build_context.rb +0 -25
  160. data/lib/kumi/parser/declaration_reference_proxy.rb +0 -36
  161. data/lib/kumi/parser/dsl.rb +0 -12
  162. data/lib/kumi/parser/dsl_cascade_builder.rb +0 -136
  163. data/lib/kumi/parser/expression_converter.rb +0 -126
  164. data/lib/kumi/parser/guard_rails.rb +0 -43
  165. data/lib/kumi/parser/input_builder.rb +0 -125
  166. data/lib/kumi/parser/input_field_proxy.rb +0 -46
  167. data/lib/kumi/parser/input_proxy.rb +0 -29
  168. data/lib/kumi/parser/nested_input.rb +0 -15
  169. data/lib/kumi/parser/parser.rb +0 -68
  170. data/lib/kumi/parser/schema_builder.rb +0 -173
  171. data/lib/kumi/parser/sugar.rb +0 -261
  172. data/lib/kumi/schema_instance.rb +0 -109
  173. data/lib/kumi/types/builder.rb +0 -21
  174. data/lib/kumi/types/compatibility.rb +0 -94
  175. data/lib/kumi/types/formatter.rb +0 -24
  176. data/lib/kumi/types/inference.rb +0 -40
  177. data/lib/kumi/types/normalizer.rb +0 -70
  178. data/lib/kumi/types/validator.rb +0 -35
  179. data/lib/kumi/types.rb +0 -64
  180. data/lib/kumi/vectorization_metadata.rb +0 -108
@@ -1,162 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
module Kumi
  module Analyzer
    module Passes
      # RESPONSIBILITY: Validate function call arity and argument types against FunctionRegistry
      # DEPENDENCIES: :decl_types from TypeInferencer
      # PRODUCES: None (validation only)
      # INTERFACE: new(schema, state).run(errors)
      class TypeChecker < VisitorPass
        # Visits every CallExpression node and validates it against its registry signature.
        #
        # @param errors [Array] error accumulator forwarded to report_error
        # @return the analyzer state, unchanged (this pass only reports errors)
        def run(errors)
          visit_nodes_of_type(Kumi::Syntax::CallExpression, errors: errors) do |node, _decl, errs|
            validate_function_call(node, errs)
          end
          state
        end

        private

        # Validates arity and argument types of one call; unknown functions are
        # reported by get_function_signature and skipped here.
        def validate_function_call(node, errors)
          signature = get_function_signature(node, errors)
          return unless signature

          validate_arity(node, signature, errors)
          validate_argument_types(node, signature, errors)
        end

        # Returns the registry signature for the call, or nil (after reporting an
        # error) when the function name is not registered.
        def get_function_signature(node, errors)
          FunctionRegistry.signature(node.fn_name)
        rescue FunctionRegistry::UnknownFunction
          # Use old format for backward compatibility, but node.loc provides better location
          report_error(errors, "unsupported operator `#{node.fn_name}`", location: node.loc, type: :type)
          nil
        end

        # A negative signature arity is treated as "any number of arguments".
        def validate_arity(node, signature, errors)
          expected_arity = signature[:arity]
          actual_arity = node.args.size
          return if expected_arity.negative? || expected_arity == actual_arity

          report_error(errors, "operator `#{node.fn_name}` expects #{expected_arity} args, got #{actual_arity}",
                       location: node.loc, type: :type)
        end

        # Checks each argument against the positional parameter type from the signature.
        # Arguments beyond the declared param_types get a nil expected type and are skipped.
        def validate_argument_types(node, signature, errors)
          types = signature[:param_types]
          return if types.nil? || (signature[:arity].negative? && node.args.empty?)

          # Skip type checking for vectorized operations
          broadcast_meta = get_state(:broadcast_metadata, required: false)
          return if broadcast_meta && part_of_vectorized_operation?(node, broadcast_meta)

          node.args.each_with_index do |arg, i|
            validate_argument_type(arg, i, types[i], node.fn_name, errors)
          end
        end

        # True when any argument refers to a declaration recorded as a vectorized or
        # reduction operation, or to an element of a known array input field.
        # This is a simplified check - in a real implementation we'd need to track context
        def part_of_vectorized_operation?(node, broadcast_meta)
          node.args.any? do |arg|
            case arg
            when Kumi::Syntax::DeclarationReference
              broadcast_meta[:vectorized_operations]&.key?(arg.name) ||
                broadcast_meta[:reduction_operations]&.key?(arg.name)
            when Kumi::Syntax::InputElementReference
              broadcast_meta[:array_fields]&.key?(arg.path.first)
            else
              false
            end
          end
        end

        # Reports a type error when the argument's type is not compatible with the
        # expected parameter type. nil or ANY expected types accept anything.
        def validate_argument_type(arg, index, expected_type, fn_name, errors)
          return if expected_type.nil? || expected_type == Kumi::Types::ANY

          actual_type = get_expression_type(arg)
          return if Kumi::Types.compatible?(actual_type, expected_type)

          source_desc = describe_expression_type(arg, actual_type)
          report_error(errors, "argument #{index + 1} of `fn(:#{fn_name})` expects #{expected_type}, " \
                               "got #{source_desc}", location: arg.loc, type: :type)
        end

        # Best-effort type of an argument expression: literals are inferred from their
        # value, inputs use their declared type, declarations use TypeInferencer results.
        # Any other expression kind is treated as ANY.
        def get_expression_type(expr)
          case expr
          when Kumi::Syntax::Literal
            Kumi::Types.infer_from_value(expr.value)
          when Kumi::Syntax::InputReference
            get_declared_field_type(expr.name)
          when Kumi::Syntax::DeclarationReference
            get_inferred_declaration_type(expr.name)
          else
            # For complex expressions, we should have type inference results
            # This is a simplified approach - in reality we'd need to track types for all expressions
            Kumi::Types::ANY
          end
        end

        # Input metadata collected by InputCollector, or an empty hash when absent.
        def declared_input_meta
          get_state(:input_meta, required: false) || {}
        end

        # Explicitly declared type of an input field, or ANY when none was declared.
        def get_declared_field_type(field_name)
          declared_input_meta[field_name]&.dig(:type) || Kumi::Types::ANY
        end

        # Inferred type of a declaration; :decl_types is required to be in state.
        def get_inferred_declaration_type(decl_name)
          decl_types = get_state(:decl_types, required: true)
          decl_types[decl_name] || Kumi::Types::ANY
        end

        # Human-readable description of where an argument's type came from, used in
        # type-mismatch error messages.
        def describe_expression_type(expr, type)
          case expr
          when Kumi::Syntax::Literal
            "`#{expr.value}` of type #{type} (literal value)"
          when Kumi::Syntax::InputReference
            field_meta = declared_input_meta[expr.name]
            if field_meta&.dig(:type)
              domain_desc = field_meta[:domain] ? " (domain: #{field_meta[:domain]})" : ""
              "input field `#{expr.name}` of declared type #{type}#{domain_desc}"
            else
              "undeclared input field `#{expr.name}` (inferred as #{type})"
            end
          when Kumi::Syntax::DeclarationReference
            "reference to declaration `#{expr.name}` of inferred type #{type}"
          when Kumi::Syntax::CallExpression
            "result of function `#{expr.fn_name}` returning #{type}"
          when Kumi::Syntax::ArrayExpression
            "list expression of type #{type}"
          when Kumi::Syntax::CascadeExpression
            "cascade expression of type #{type}"
          else
            "expression of type #{type}"
          end
        end
      end
    end
  end
end
@@ -1,46 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
module Kumi
  module Analyzer
    module Passes
      # RESPONSIBILITY: Validate consistency between declared and inferred types
      # DEPENDENCIES: :input_meta from InputCollector, :decl_types from TypeInferencer
      # PRODUCES: None (validation only)
      # INTERFACE: new(schema, state).run(errors)
      class TypeConsistencyChecker < PassBase
        # Reports every input field whose declared type is not a valid Kumi type.
        # Deeper usage-consistency analysis is not implemented yet.
        #
        # @param errors [Array] accumulator for type errors
        # @return the analyzer state, unchanged
        def run(errors)
          collected_meta = get_state(:input_meta, required: false) || {}
          validate_declared_types(collected_meta, errors)
          state
        end

        private

        # Fields without a declared type are ignored; fields whose declared type is
        # valid pass silently; every other field produces one type error.
        def validate_declared_types(input_meta, errors)
          input_meta.each do |name, field_meta|
            type_decl = field_meta[:type]
            next if !type_decl || Kumi::Types.valid_type?(type_decl)

            # Point the error at the input declaration when it can be found.
            loc = find_input_field_declaration(name)&.loc
            report_type_error(errors, "Invalid type declaration for field :#{name}: #{type_decl.inspect}", location: loc)
          end
        end

        # Top-level input declaration with the given name, or nil when there is no
        # schema or no such input.
        def find_input_field_declaration(field_name)
          schema ? schema.inputs.find { |decl| decl.name == field_name } : nil
        end
      end
    end
  end
end
@@ -1,232 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
module Kumi
  module Analyzer
    module Passes
      # RESPONSIBILITY: Infer types for all declarations based on expression analysis
      # DEPENDENCIES: Toposorter (needs topo_order), DeclarationValidator (needs definitions)
      # PRODUCES: decl_types hash mapping declaration names to inferred types
      # INTERFACE: new(schema, state).run(errors)
      class TypeInferencer < PassBase
        # Infers a type for every declaration, walking them in topological order so that
        # referenced declarations already have an entry in +types+ when looked up.
        # Any StandardError raised while inferring one declaration is reported as a type
        # error and inference continues with the next declaration.
        #
        # @param errors [Array] accumulator for type errors
        # @return new state with :decl_types set to { name => type }
        def run(errors)
          types = {}
          topo_order = get_state(:topo_order)
          definitions = get_state(:definitions)

          # Get broadcast metadata from broadcast detector
          broadcast_meta = get_state(:broadcast_metadata, required: false) || {}

          topo_order.each do |name|
            decl = definitions[name]
            next unless decl

            begin
              if broadcast_meta[:vectorized_operations]&.key?(name)
                # Vectorized declarations are typed as arrays; traits are arrays of booleans.
                element_type = infer_vectorized_element_type(decl.expression, types, broadcast_meta)
                types[name] = decl.is_a?(Kumi::Syntax::TraitDeclaration) ? { array: :boolean } : { array: element_type }
              else
                types[name] = infer_expression_type(decl.expression, types, broadcast_meta, name)
              end
            rescue StandardError => e
              report_type_error(errors, "Type inference failed: #{e.message}", location: decl.loc)
            end
          end

          state.with(:decl_types, types)
        end

        private

        # Dispatches on expression node class; unknown node kinds infer as :any.
        def infer_expression_type(expr, type_context = {}, broadcast_metadata = {}, current_decl_name = nil)
          case expr
          when Literal
            Types.infer_from_value(expr.value)
          when InputReference
            # Look up type from field metadata
            input_meta = get_state(:input_meta, required: false) || {}
            input_meta[expr.name]&.dig(:type) || :any
          when DeclarationReference
            type_context[expr.name] || :any
          when CallExpression
            infer_call_type(expr, type_context, broadcast_metadata, current_decl_name)
          when ArrayExpression
            infer_list_type(expr, type_context, broadcast_metadata, current_decl_name)
          when CascadeExpression
            infer_cascade_type(expr, type_context, broadcast_metadata, current_decl_name)
          when InputElementReference
            infer_element_reference_type(expr)
          else
            :any
          end
        end

        # Return type of a function call, taken from the FunctionRegistry signature.
        # Unknown functions and arity mismatches infer as :any; TypeChecker reports those.
        def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
          fn_name = call_expr.fn_name
          args = call_expr.args

          # NOTE(review): these keys (:vectorized_values / :reducer_values) differ from the
          # :vectorized_operations key used in #run — confirm which keys the broadcast
          # detector actually produces.
          if current_decl_name && broadcast_metadata[:vectorized_values]&.key?(current_decl_name)
            element_type = infer_vectorized_element_type(call_expr, type_context, broadcast_metadata)
            return { array: element_type }
          end

          if current_decl_name && broadcast_metadata[:reducer_values]&.key?(current_decl_name)
            return infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
          end

          # Don't push errors here - let existing TypeChecker handle them
          return :any unless FunctionRegistry.supported?(fn_name)

          signature = FunctionRegistry.signature(fn_name)
          return :any if signature[:arity] >= 0 && args.size != signature[:arity]

          arg_types = args.map { |arg| infer_expression_type(arg, type_context, broadcast_metadata, current_decl_name) }

          # NOTE(review): placeholder — compatibility is computed but mismatches are not
          # reported or used yet ("Could add warning here in future").
          param_types = signature[:param_types] || []
          if signature[:arity] >= 0 && param_types.size.positive?
            arg_types.each_with_index do |arg_type, i|
              expected_type = param_types[i] || param_types.last
              next if expected_type.nil?

              Types.compatible?(arg_type, expected_type)
            end
          end

          signature[:return_type] || :any
        end

        # Registry return type of +fn_name+, or :any when the function is unknown.
        def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
          return :any unless FunctionRegistry.supported?(fn_name)

          FunctionRegistry.signature(fn_name)[:return_type] || :any
        end

        # Array type whose element type unifies all element types; empty lists and
        # unification failures fall back to an array of :any.
        def infer_list_type(list_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
          return Types.array(:any) if list_expr.elements.empty?

          element_types = list_expr.elements.map do |elem|
            infer_expression_type(elem, type_context, broadcast_metadata, current_decl_name)
          end
          Types.array(element_types.reduce { |acc, type| Types.unify(acc, type) })
        rescue StandardError
          Types.array(:any)
        end

        # Element type (not the array type) of a vectorized expression.
        # Input element references read the child field type from input metadata;
        # arithmetic calls unify the element types of their operands (:float when none).
        def infer_vectorized_element_type(expr, type_context, vectorization_meta)
          case expr
          when InputElementReference
            input_meta = get_state(:input_meta, required: false) || {}
            array_name = expr.path.first
            field_name = expr.path[1]

            array_meta = input_meta[array_name]
            return :any unless array_meta&.dig(:type) == :array

            array_meta.dig(:children, field_name, :type) || :any
          when CallExpression
            return :any unless %i[add subtract multiply divide].include?(expr.fn_name)

            arg_types = expr.args.map do |arg|
              case arg
              when InputElementReference
                infer_vectorized_element_type(arg, type_context, vectorization_meta)
              when DeclarationReference
                # Use the element type when the referenced declaration is an array type.
                ref_type = type_context[arg.name]
                ref_type.is_a?(Hash) && ref_type.key?(:array) ? ref_type[:array] : (ref_type || :any)
              else
                infer_expression_type(arg, type_context, vectorization_meta)
              end
            end

            # Pairwise fold: Types.unify is binary everywhere else in this class, so
            # fold instead of splatting to support any operand count.
            arg_types.reduce { |acc, type| Types.unify(acc, type) } || :float
          else
            :any
          end
        end

        # An element reference (array.field) is typed as an array of the child field type.
        def infer_element_reference_type(expr)
          input_meta = get_state(:input_meta, required: false) || {}
          return :any unless expr.path.size >= 2

          array_name = expr.path.first
          field_name = expr.path[1]

          array_meta = input_meta[array_name]
          return :any unless array_meta&.dig(:type) == :array

          { array: array_meta.dig(:children, field_name, :type) || :any }
        end

        # Unified type of all cascade branch results; empty cascades and unification
        # failures infer as :any.
        def infer_cascade_type(cascade_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
          return :any if cascade_expr.cases.empty?

          result_types = cascade_expr.cases.map do |case_stmt|
            infer_expression_type(case_stmt.result, type_context, broadcast_metadata, current_decl_name)
          end

          result_types.reduce { |unified, type| Types.unify(unified, type) } || :any
        rescue StandardError
          # TODO: understand if this right to fallback or we should raise
          :any
        end
      end
    end
  end
end