kumi 0.0.7 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CLAUDE.md +1 -1
- data/README.md +8 -5
- data/examples/game_of_life.rb +1 -1
- data/examples/static_analysis_errors.rb +7 -7
- data/lib/kumi/analyzer.rb +15 -15
- data/lib/kumi/compiler.rb +6 -6
- data/lib/kumi/core/analyzer/analysis_state.rb +39 -0
- data/lib/kumi/core/analyzer/constant_evaluator.rb +59 -0
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +248 -0
- data/lib/kumi/core/analyzer/passes/declaration_validator.rb +45 -0
- data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +153 -0
- data/lib/kumi/core/analyzer/passes/input_collector.rb +139 -0
- data/lib/kumi/core/analyzer/passes/name_indexer.rb +26 -0
- data/lib/kumi/core/analyzer/passes/pass_base.rb +52 -0
- data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +111 -0
- data/lib/kumi/core/analyzer/passes/toposorter.rb +110 -0
- data/lib/kumi/core/analyzer/passes/type_checker.rb +162 -0
- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +48 -0
- data/lib/kumi/core/analyzer/passes/type_inferencer.rb +236 -0
- data/lib/kumi/core/analyzer/passes/unsat_detector.rb +406 -0
- data/lib/kumi/core/analyzer/passes/visitor_pass.rb +44 -0
- data/lib/kumi/core/atom_unsat_solver.rb +396 -0
- data/lib/kumi/core/compiled_schema.rb +43 -0
- data/lib/kumi/core/constraint_relationship_solver.rb +641 -0
- data/lib/kumi/core/domain/enum_analyzer.rb +55 -0
- data/lib/kumi/core/domain/range_analyzer.rb +85 -0
- data/lib/kumi/core/domain/validator.rb +82 -0
- data/lib/kumi/core/domain/violation_formatter.rb +42 -0
- data/lib/kumi/core/error_reporter.rb +166 -0
- data/lib/kumi/core/error_reporting.rb +97 -0
- data/lib/kumi/core/errors.rb +120 -0
- data/lib/kumi/core/evaluation_wrapper.rb +40 -0
- data/lib/kumi/core/explain.rb +295 -0
- data/lib/kumi/core/export/deserializer.rb +41 -0
- data/lib/kumi/core/export/errors.rb +14 -0
- data/lib/kumi/core/export/node_builders.rb +142 -0
- data/lib/kumi/core/export/node_registry.rb +54 -0
- data/lib/kumi/core/export/node_serializers.rb +158 -0
- data/lib/kumi/core/export/serializer.rb +25 -0
- data/lib/kumi/core/export.rb +35 -0
- data/lib/kumi/core/function_registry/collection_functions.rb +202 -0
- data/lib/kumi/core/function_registry/comparison_functions.rb +33 -0
- data/lib/kumi/core/function_registry/conditional_functions.rb +38 -0
- data/lib/kumi/core/function_registry/function_builder.rb +95 -0
- data/lib/kumi/core/function_registry/logical_functions.rb +44 -0
- data/lib/kumi/core/function_registry/math_functions.rb +74 -0
- data/lib/kumi/core/function_registry/string_functions.rb +57 -0
- data/lib/kumi/core/function_registry/type_functions.rb +53 -0
- data/lib/kumi/{function_registry.rb → core/function_registry.rb} +28 -36
- data/lib/kumi/core/input/type_matcher.rb +97 -0
- data/lib/kumi/core/input/validator.rb +51 -0
- data/lib/kumi/core/input/violation_creator.rb +52 -0
- data/lib/kumi/core/json_schema/generator.rb +65 -0
- data/lib/kumi/core/json_schema/validator.rb +27 -0
- data/lib/kumi/core/json_schema.rb +16 -0
- data/lib/kumi/core/ruby_parser/build_context.rb +27 -0
- data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +38 -0
- data/lib/kumi/core/ruby_parser/dsl.rb +14 -0
- data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +138 -0
- data/lib/kumi/core/ruby_parser/expression_converter.rb +128 -0
- data/lib/kumi/core/ruby_parser/guard_rails.rb +45 -0
- data/lib/kumi/core/ruby_parser/input_builder.rb +127 -0
- data/lib/kumi/core/ruby_parser/input_field_proxy.rb +48 -0
- data/lib/kumi/core/ruby_parser/input_proxy.rb +31 -0
- data/lib/kumi/core/ruby_parser/nested_input.rb +17 -0
- data/lib/kumi/core/ruby_parser/parser.rb +71 -0
- data/lib/kumi/core/ruby_parser/schema_builder.rb +175 -0
- data/lib/kumi/core/ruby_parser/sugar.rb +263 -0
- data/lib/kumi/core/ruby_parser.rb +12 -0
- data/lib/kumi/core/schema_instance.rb +111 -0
- data/lib/kumi/core/types/builder.rb +23 -0
- data/lib/kumi/core/types/compatibility.rb +96 -0
- data/lib/kumi/core/types/formatter.rb +26 -0
- data/lib/kumi/core/types/inference.rb +42 -0
- data/lib/kumi/core/types/normalizer.rb +72 -0
- data/lib/kumi/core/types/validator.rb +37 -0
- data/lib/kumi/core/types.rb +66 -0
- data/lib/kumi/core/vectorization_metadata.rb +110 -0
- data/lib/kumi/errors.rb +1 -112
- data/lib/kumi/registry.rb +37 -0
- data/lib/kumi/schema.rb +5 -5
- data/lib/kumi/schema_metadata.rb +3 -3
- data/lib/kumi/syntax/array_expression.rb +6 -6
- data/lib/kumi/syntax/call_expression.rb +4 -4
- data/lib/kumi/syntax/cascade_expression.rb +4 -4
- data/lib/kumi/syntax/case_expression.rb +4 -4
- data/lib/kumi/syntax/declaration_reference.rb +4 -4
- data/lib/kumi/syntax/hash_expression.rb +4 -4
- data/lib/kumi/syntax/input_declaration.rb +5 -5
- data/lib/kumi/syntax/input_element_reference.rb +5 -5
- data/lib/kumi/syntax/input_reference.rb +5 -5
- data/lib/kumi/syntax/literal.rb +4 -4
- data/lib/kumi/syntax/node.rb +34 -34
- data/lib/kumi/syntax/root.rb +6 -6
- data/lib/kumi/syntax/trait_declaration.rb +4 -4
- data/lib/kumi/syntax/value_declaration.rb +4 -4
- data/lib/kumi/version.rb +1 -1
- data/migrate_to_core_iterative.rb +938 -0
- data/scripts/generate_function_docs.rb +9 -9
- metadata +75 -72
- data/lib/kumi/analyzer/analysis_state.rb +0 -37
- data/lib/kumi/analyzer/constant_evaluator.rb +0 -57
- data/lib/kumi/analyzer/passes/broadcast_detector.rb +0 -246
- data/lib/kumi/analyzer/passes/declaration_validator.rb +0 -43
- data/lib/kumi/analyzer/passes/dependency_resolver.rb +0 -151
- data/lib/kumi/analyzer/passes/input_collector.rb +0 -137
- data/lib/kumi/analyzer/passes/name_indexer.rb +0 -24
- data/lib/kumi/analyzer/passes/pass_base.rb +0 -50
- data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +0 -109
- data/lib/kumi/analyzer/passes/toposorter.rb +0 -108
- data/lib/kumi/analyzer/passes/type_checker.rb +0 -160
- data/lib/kumi/analyzer/passes/type_consistency_checker.rb +0 -46
- data/lib/kumi/analyzer/passes/type_inferencer.rb +0 -232
- data/lib/kumi/analyzer/passes/unsat_detector.rb +0 -404
- data/lib/kumi/analyzer/passes/visitor_pass.rb +0 -42
- data/lib/kumi/atom_unsat_solver.rb +0 -394
- data/lib/kumi/compiled_schema.rb +0 -41
- data/lib/kumi/constraint_relationship_solver.rb +0 -638
- data/lib/kumi/domain/enum_analyzer.rb +0 -53
- data/lib/kumi/domain/range_analyzer.rb +0 -83
- data/lib/kumi/domain/validator.rb +0 -80
- data/lib/kumi/domain/violation_formatter.rb +0 -40
- data/lib/kumi/error_reporter.rb +0 -164
- data/lib/kumi/error_reporting.rb +0 -95
- data/lib/kumi/evaluation_wrapper.rb +0 -38
- data/lib/kumi/explain.rb +0 -293
- data/lib/kumi/export/deserializer.rb +0 -39
- data/lib/kumi/export/errors.rb +0 -12
- data/lib/kumi/export/node_builders.rb +0 -140
- data/lib/kumi/export/node_registry.rb +0 -52
- data/lib/kumi/export/node_serializers.rb +0 -156
- data/lib/kumi/export/serializer.rb +0 -23
- data/lib/kumi/export.rb +0 -33
- data/lib/kumi/function_registry/collection_functions.rb +0 -200
- data/lib/kumi/function_registry/comparison_functions.rb +0 -31
- data/lib/kumi/function_registry/conditional_functions.rb +0 -36
- data/lib/kumi/function_registry/function_builder.rb +0 -93
- data/lib/kumi/function_registry/logical_functions.rb +0 -42
- data/lib/kumi/function_registry/math_functions.rb +0 -72
- data/lib/kumi/function_registry/string_functions.rb +0 -54
- data/lib/kumi/function_registry/type_functions.rb +0 -51
- data/lib/kumi/input/type_matcher.rb +0 -95
- data/lib/kumi/input/validator.rb +0 -49
- data/lib/kumi/input/violation_creator.rb +0 -50
- data/lib/kumi/json_schema/generator.rb +0 -63
- data/lib/kumi/json_schema/validator.rb +0 -25
- data/lib/kumi/json_schema.rb +0 -14
- data/lib/kumi/ruby_parser/build_context.rb +0 -25
- data/lib/kumi/ruby_parser/declaration_reference_proxy.rb +0 -36
- data/lib/kumi/ruby_parser/dsl.rb +0 -12
- data/lib/kumi/ruby_parser/dsl_cascade_builder.rb +0 -136
- data/lib/kumi/ruby_parser/expression_converter.rb +0 -126
- data/lib/kumi/ruby_parser/guard_rails.rb +0 -43
- data/lib/kumi/ruby_parser/input_builder.rb +0 -125
- data/lib/kumi/ruby_parser/input_field_proxy.rb +0 -46
- data/lib/kumi/ruby_parser/input_proxy.rb +0 -29
- data/lib/kumi/ruby_parser/nested_input.rb +0 -15
- data/lib/kumi/ruby_parser/parser.rb +0 -69
- data/lib/kumi/ruby_parser/schema_builder.rb +0 -173
- data/lib/kumi/ruby_parser/sugar.rb +0 -261
- data/lib/kumi/ruby_parser.rb +0 -10
- data/lib/kumi/schema_instance.rb +0 -109
- data/lib/kumi/types/builder.rb +0 -21
- data/lib/kumi/types/compatibility.rb +0 -94
- data/lib/kumi/types/formatter.rb +0 -24
- data/lib/kumi/types/inference.rb +0 -40
- data/lib/kumi/types/normalizer.rb +0 -70
- data/lib/kumi/types/validator.rb +0 -35
- data/lib/kumi/types.rb +0 -64
- data/lib/kumi/vectorization_metadata.rb +0 -108
@@ -0,0 +1,162 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Validate function call arity and argument types against FunctionRegistry
        # DEPENDENCIES: :inferred_types from TypeInferencer
        # PRODUCES: None (validation only)
        # INTERFACE: new(schema, state).run(errors)
        class TypeChecker < VisitorPass
          # Visits every CallExpression in the schema and reports unknown
          # functions, arity mismatches and argument type mismatches.
          #
          # @param errors [Array] error accumulator handed to report_error
          # @return the analysis state, unchanged (this pass only validates)
          def run(errors)
            visit_nodes_of_type(Kumi::Syntax::CallExpression, errors: errors) do |call, _decl, errs|
              validate_function_call(call, errs)
            end
            state
          end

          private

          # Resolves the call's signature; when the function is known, checks
          # arity first and then each argument's type.
          def validate_function_call(node, errors)
            if (signature = get_function_signature(node, errors))
              validate_arity(node, signature, errors)
              validate_argument_types(node, signature, errors)
            end
          end

          # Returns the registry signature for node.fn_name, or nil after
          # reporting an "unsupported operator" error for an unknown function.
          def get_function_signature(node, errors)
            Kumi::Registry.signature(node.fn_name)
          rescue Kumi::Errors::UnknownFunction
            # Legacy message wording is kept on purpose; node.loc supplies the location.
            report_error(errors, "unsupported operator `#{node.fn_name}`", location: node.loc, type: :type)
            nil
          end

          # Reports a mismatch between the signature's :arity and the number of
          # arguments actually passed. A negative arity is never reported.
          def validate_arity(node, signature, errors)
            expected = signature[:arity]
            given = node.args.size
            return if expected.negative?
            return if expected == given

            message = "operator `#{node.fn_name}` expects #{expected} args, got #{given}"
            report_error(errors, message, location: node.loc, type: :type)
          end

          # Checks each argument against the positional entry of
          # signature[:param_types]. Skipped when no param types are declared,
          # for argument-less negative-arity calls, and for calls that touch
          # vectorized/reduction data according to the :broadcasts state.
          def validate_argument_types(node, signature, errors)
            param_types = signature[:param_types]
            return if param_types.nil?
            return if signature[:arity].negative? && node.args.empty?

            broadcasts = get_state(:broadcasts, required: false)
            return if broadcasts && is_part_of_vectorized_operation?(node, broadcasts)

            node.args.each_with_index do |arg, position|
              validate_argument_type(arg, position, param_types[position], node.fn_name, errors)
            end
          end

          # True when any argument is a declaration listed under
          # :vectorized_operations / :reduction_operations, or an element
          # reference whose root is listed under :array_fields.
          # This is a shallow, per-argument heuristic (no context tracking).
          def is_part_of_vectorized_operation?(node, broadcast_meta)
            vectorized = broadcast_meta[:vectorized_operations]
            reductions = broadcast_meta[:reduction_operations]
            array_fields = broadcast_meta[:array_fields]

            node.args.any? do |arg|
              if arg.is_a?(Kumi::Syntax::DeclarationReference)
                vectorized&.key?(arg.name) || reductions&.key?(arg.name)
              elsif arg.is_a?(Kumi::Syntax::InputElementReference)
                array_fields&.key?(arg.path.first)
              else
                false
              end
            end
          end

          # Reports a type error when the argument's type is not compatible
          # with expected_type. A nil or ANY expectation accepts everything.
          def validate_argument_type(arg, index, expected_type, fn_name, errors)
            return if expected_type.nil? || expected_type == Kumi::Core::Types::ANY

            actual_type = get_expression_type(arg)
            return if Kumi::Core::Types.compatible?(actual_type, expected_type)

            description = describe_expression_type(arg, actual_type)
            report_error(
              errors,
              "argument #{index + 1} of `fn(:#{fn_name})` expects #{expected_type}, got #{description}",
              location: arg.loc,
              type: :type
            )
          end

          # Best-effort type of an argument expression: literals are inferred
          # from their value, inputs use their declared type, declarations use
          # the TypeInferencer result; anything else is ANY.
          def get_expression_type(expr)
            if expr.is_a?(Kumi::Syntax::Literal)
              Kumi::Core::Types.infer_from_value(expr.value)
            elsif expr.is_a?(Kumi::Syntax::InputReference)
              get_declared_field_type(expr.name)
            elsif expr.is_a?(Kumi::Syntax::DeclarationReference)
              get_inferred_declaration_type(expr.name)
            else
              # Complex expressions are not tracked here yet.
              Kumi::Core::Types::ANY
            end
          end

          # Declared :type of an input field, or ANY when absent.
          def get_declared_field_type(field_name)
            inputs = get_state(:inputs, required: false) || {}
            inputs[field_name]&.dig(:type) || Kumi::Core::Types::ANY
          end

          # Type inferred for a declaration by TypeInferencer, or ANY when absent.
          def get_inferred_declaration_type(decl_name)
            inferred = get_state(:inferred_types, required: true)
            inferred[decl_name] || Kumi::Core::Types::ANY
          end

          # Human-readable description of where an argument's type came from,
          # used in type-mismatch error messages.
          def describe_expression_type(expr, type)
            case expr
            when Kumi::Syntax::Literal then "`#{expr.value}` of type #{type} (literal value)"
            when Kumi::Syntax::InputReference then describe_input_reference(expr, type)
            when Kumi::Syntax::DeclarationReference then "reference to declaration `#{expr.name}` of inferred type #{type}"
            when Kumi::Syntax::CallExpression then "result of function `#{expr.fn_name}` returning #{type}"
            when Kumi::Syntax::ArrayExpression then "list expression of type #{type}"
            when Kumi::Syntax::CascadeExpression then "cascade expression of type #{type}"
            else "expression of type #{type}"
            end
          end

          # Describes an input reference, mentioning the domain when one is
          # declared and flagging fields that have no declared type.
          def describe_input_reference(expr, type)
            field_meta = (get_state(:inputs, required: false) || {})[expr.name]
            return "undeclared input field `#{expr.name}` (inferred as #{type})" unless field_meta&.dig(:type)

            domain = field_meta[:domain]
            domain_desc = domain ? " (domain: #{domain})" : ""
            "input field `#{expr.name}` of declared type #{type}#{domain_desc}"
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Validate consistency between declared and inferred types
        # DEPENDENCIES: :inputs from InputCollector, :inferred_types from TypeInferencer
        # PRODUCES: None (validation only)
        # INTERFACE: new(schema, state).run(errors)
        class TypeConsistencyChecker < PassBase
          # Reports input fields whose declared :type is not a valid type.
          # Usage-based declared-vs-inferred analysis is not implemented yet.
          #
          # @param errors [Array] accumulator handed to report_type_error
          # @return the analysis state, unchanged
          def run(errors)
            validate_declared_types(get_state(:inputs, required: false) || {}, errors)
            state
          end

          private

          # Emits one type error per field whose declared type is present but
          # rejected by Kumi::Core::Types.valid_type?. Fields without a declared
          # type are ignored.
          def validate_declared_types(input_meta, errors)
            input_meta.each do |name, field_meta|
              declared = field_meta[:type]
              invalid = declared && !Kumi::Core::Types.valid_type?(declared)
              next unless invalid

              # Point the error at the input declaration when it can be found.
              where = find_input_field_declaration(name)&.loc
              report_type_error(errors, "Invalid type declaration for field :#{name}: #{declared.inspect}", location: where)
            end
          end

          # The schema's input declaration named field_name, or nil when there
          # is no schema or no such declaration.
          def find_input_field_declaration(field_name)
            return nil unless schema

            schema.inputs.find { |decl| decl.name == field_name }
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,236 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Infer types for all declarations based on expression analysis
        # DEPENDENCIES: Toposorter (needs evaluation_order), DeclarationValidator (needs declarations)
        # PRODUCES: inferred_types hash mapping declaration names to inferred types
        # INTERFACE: new(schema, state).run(errors)
        class TypeInferencer < PassBase
          # Infers a type for every declaration, in :evaluation_order so that
          # references to earlier declarations can be looked up in `types`.
          #
          # @param errors [Array] accumulator; an exception while inferring one
          #   declaration is reported as "Type inference failed: ..." and the
          #   pass moves on to the next declaration
          # @return new state with :inferred_types set to { name => type }
          def run(errors)
            types = {}
            topo_order = get_state(:evaluation_order)
            definitions = get_state(:declarations)

            # Broadcast metadata is optional; an empty hash means nothing is vectorized.
            broadcast_meta = get_state(:broadcasts, required: false) || {}

            topo_order.each do |name|
              decl = definitions[name]
              next unless decl

              begin
                if broadcast_meta[:vectorized_operations]&.key?(name)
                  # Vectorized declarations produce one value per element, so the
                  # result is wrapped in { array: ... }; traits are element-wise booleans.
                  element_type = infer_vectorized_element_type(decl.expression, types, broadcast_meta)
                  types[name] = decl.is_a?(Kumi::Syntax::TraitDeclaration) ? { array: :boolean } : { array: element_type }
                else
                  types[name] = infer_expression_type(decl.expression, types, broadcast_meta, name)
                end
              rescue StandardError => e
                report_type_error(errors, "Type inference failed: #{e.message}", location: decl&.loc)
              end
            end

            state.with(:inferred_types, types)
          end

          private

          # Dispatches on the expression node class and returns its inferred
          # type; unknown node kinds infer as :any.
          def infer_expression_type(expr, type_context = {}, broadcast_metadata = {}, current_decl_name = nil)
            case expr
            when Literal
              Types.infer_from_value(expr.value)
            when InputReference
              # Use the declared type from input metadata when there is one.
              input_meta = get_state(:inputs, required: false) || {}
              input_meta[expr.name]&.dig(:type) || :any
            when DeclarationReference
              type_context[expr.name] || :any
            when CallExpression
              infer_call_type(expr, type_context, broadcast_metadata, current_decl_name)
            when ArrayExpression
              infer_list_type(expr, type_context, broadcast_metadata, current_decl_name)
            when CascadeExpression
              infer_cascade_type(expr, type_context, broadcast_metadata, current_decl_name)
            when InputElementReference
              infer_element_reference_type(expr)
            else
              :any
            end
          end

          # Return type of a function call, taken from the registry signature.
          # Unknown functions and arity mismatches infer as :any; reporting
          # them is left to TypeChecker.
          def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
            fn_name = call_expr.fn_name
            args = call_expr.args

            # NOTE(review): run() detects vectorized declarations via
            # :vectorized_operations, but these two checks read :vectorized_values
            # and :reducer_values. Confirm which keys BroadcastDetector produces;
            # if it never sets these, both branches are unreachable.
            if current_decl_name && broadcast_metadata[:vectorized_values]&.key?(current_decl_name)
              element_type = infer_vectorized_element_type(call_expr, type_context, broadcast_metadata)
              return { array: element_type }
            end

            if current_decl_name && broadcast_metadata[:reducer_values]&.key?(current_decl_name)
              return infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
            end

            return :any unless Kumi::Registry.supported?(fn_name)

            signature = Kumi::Registry.signature(fn_name)
            # A non-negative arity is fixed; a mismatch is TypeChecker's to report.
            return :any if signature[:arity] >= 0 && args.size != signature[:arity]

            # Nested argument types are inferred so that a failure inside a nested
            # expression still propagates to run() and is reported there. Checking
            # them against signature[:param_types] is TypeChecker's job.
            # TODO: emit warnings for incompatible argument types here if desired.
            args.each { |arg| infer_expression_type(arg, type_context, broadcast_metadata, current_decl_name) }

            signature[:return_type] || :any
          end

          # Registry return type for fn_name, or :any for unknown functions.
          def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
            return :any unless Kumi::Registry.supported?(fn_name)

            signature = Kumi::Registry.signature(fn_name)
            signature[:return_type] || :any
          end

          # Array type whose element type unifies all element types; falls back
          # to an array of :any when the list is empty or unification raises.
          def infer_list_type(list_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
            return Types.array(:any) if list_expr.elements.empty?

            element_types = list_expr.elements.map do |elem|
              infer_expression_type(elem, type_context, broadcast_metadata, current_decl_name)
            end

            unified_type = element_types.reduce { |acc, type| Types.unify(acc, type) }
            Types.array(unified_type)
          rescue StandardError
            Types.array(:any)
          end

          # Element type produced by a vectorized expression (the caller wraps
          # it in { array: ... }). Handles element references and the four
          # arithmetic functions; everything else infers as :any.
          def infer_vectorized_element_type(expr, type_context, vectorization_meta)
            case expr
            when InputElementReference
              input_meta = get_state(:inputs, required: false) || {}
              array_name = expr.path.first
              field_name = expr.path[1]

              array_meta = input_meta[array_name]
              return :any unless array_meta&.dig(:type) == :array

              array_meta.dig(:children, field_name, :type) || :any
            when CallExpression
              return :any unless %i[add subtract multiply divide].include?(expr.fn_name)

              arg_types = expr.args.map do |arg|
                if arg.is_a?(InputElementReference)
                  infer_vectorized_element_type(arg, type_context, vectorization_meta)
                elsif arg.is_a?(DeclarationReference)
                  # Unwrap a vectorized declaration to its element type.
                  ref_type = type_context[arg.name]
                  if ref_type.is_a?(Hash) && ref_type.key?(:array)
                    ref_type[:array]
                  else
                    ref_type || :any
                  end
                else
                  infer_expression_type(arg, type_context, vectorization_meta)
                end
              end

              # Fold pairwise, the same way Types.unify is called elsewhere in this
              # class, so operand counts other than two don't depend on unify's
              # arity. With no operands the result falls back to :float.
              arg_types.reduce { |acc, type| Types.unify(acc, type) } || :float
            else
              :any
            end
          end

          # Type of an input element reference (array_name, field_name, ...):
          # the child field's declared type, wrapped as { array: type } because
          # the reference yields one value per array element.
          def infer_element_reference_type(expr)
            input_meta = get_state(:inputs, required: false) || {}

            return :any unless expr.path.size >= 2

            array_name = expr.path.first
            field_name = expr.path[1]

            array_meta = input_meta[array_name]
            return :any unless array_meta&.dig(:type) == :array

            field_type = array_meta.dig(:children, field_name, :type) || :any
            { array: field_type }
          end

          # Unified type of all cascade branch results; :any when there are no
          # cases or unification raises.
          def infer_cascade_type(cascade_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
            return :any if cascade_expr.cases.empty?

            result_types = cascade_expr.cases.map do |case_stmt|
              infer_expression_type(case_stmt.result, type_context, broadcast_metadata, current_decl_name)
            end

            result_types.reduce { |unified, type| Types.unify(unified, type) } || :any
          rescue StandardError
            # TODO: decide whether a unification failure should fall back to :any or be raised.
            :any
          end
        end
      end
    end
  end
end
|