kumi 0.0.6 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CLAUDE.md +33 -176
- data/README.md +33 -2
- data/docs/SYNTAX.md +2 -7
- data/docs/features/array-broadcasting.md +1 -1
- data/docs/schema_metadata/broadcasts.md +53 -0
- data/docs/schema_metadata/cascades.md +45 -0
- data/docs/schema_metadata/declarations.md +54 -0
- data/docs/schema_metadata/dependencies.md +57 -0
- data/docs/schema_metadata/evaluation_order.md +29 -0
- data/docs/schema_metadata/examples.md +95 -0
- data/docs/schema_metadata/inferred_types.md +46 -0
- data/docs/schema_metadata/inputs.md +86 -0
- data/docs/schema_metadata.md +108 -0
- data/lib/kumi/analyzer/passes/broadcast_detector.rb +52 -57
- data/lib/kumi/analyzer/passes/dependency_resolver.rb +8 -8
- data/lib/kumi/analyzer/passes/input_collector.rb +2 -2
- data/lib/kumi/analyzer/passes/name_indexer.rb +2 -2
- data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +15 -16
- data/lib/kumi/analyzer/passes/toposorter.rb +23 -23
- data/lib/kumi/analyzer/passes/type_checker.rb +7 -9
- data/lib/kumi/analyzer/passes/type_consistency_checker.rb +2 -2
- data/lib/kumi/analyzer/passes/type_inferencer.rb +24 -24
- data/lib/kumi/analyzer/passes/unsat_detector.rb +11 -13
- data/lib/kumi/analyzer.rb +5 -5
- data/lib/kumi/compiler.rb +39 -45
- data/lib/kumi/error_reporting.rb +1 -1
- data/lib/kumi/explain.rb +12 -0
- data/lib/kumi/export/node_registry.rb +2 -2
- data/lib/kumi/json_schema/generator.rb +63 -0
- data/lib/kumi/json_schema/validator.rb +25 -0
- data/lib/kumi/json_schema.rb +14 -0
- data/lib/kumi/{parser → ruby_parser}/build_context.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/declaration_reference_proxy.rb +3 -3
- data/lib/kumi/{parser → ruby_parser}/dsl.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/dsl_cascade_builder.rb +2 -2
- data/lib/kumi/{parser → ruby_parser}/expression_converter.rb +14 -14
- data/lib/kumi/{parser → ruby_parser}/guard_rails.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/input_builder.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/input_field_proxy.rb +4 -4
- data/lib/kumi/{parser → ruby_parser}/input_proxy.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/nested_input.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/parser.rb +11 -10
- data/lib/kumi/{parser → ruby_parser}/schema_builder.rb +1 -1
- data/lib/kumi/{parser → ruby_parser}/sugar.rb +1 -1
- data/lib/kumi/ruby_parser.rb +10 -0
- data/lib/kumi/schema.rb +10 -4
- data/lib/kumi/schema_instance.rb +6 -6
- data/lib/kumi/schema_metadata.rb +524 -0
- data/lib/kumi/vectorization_metadata.rb +4 -4
- data/lib/kumi/version.rb +1 -1
- data/lib/kumi.rb +14 -0
- metadata +28 -15
- data/lib/generators/trait_engine/templates/schema_spec.rb.erb +0 -27
@@ -24,7 +24,7 @@ module Kumi
|
|
24
24
|
|
25
25
|
private
|
26
26
|
|
27
|
-
def validate_semantic_constraints(node,
|
27
|
+
def validate_semantic_constraints(node, _decl, errors)
|
28
28
|
case node
|
29
29
|
when Kumi::Syntax::TraitDeclaration
|
30
30
|
validate_trait_expression(node, errors)
|
@@ -48,20 +48,20 @@ module Kumi
|
|
48
48
|
|
49
49
|
def validate_cascade_condition(when_case, errors)
|
50
50
|
condition = when_case.condition
|
51
|
-
|
51
|
+
|
52
52
|
case condition
|
53
53
|
when Kumi::Syntax::DeclarationReference
|
54
54
|
# Valid: trait reference
|
55
|
-
|
55
|
+
nil
|
56
56
|
when Kumi::Syntax::CallExpression
|
57
57
|
# Valid if it's a boolean composition of traits (all?, any?, none?)
|
58
58
|
return if boolean_trait_composition?(condition)
|
59
|
-
|
59
|
+
|
60
60
|
# For now, allow other CallExpressions - they'll be validated by other passes
|
61
|
-
|
61
|
+
nil
|
62
62
|
when Kumi::Syntax::Literal
|
63
63
|
# Allow literal conditions (like true/false) - they might be valid
|
64
|
-
|
64
|
+
nil
|
65
65
|
else
|
66
66
|
# Only reject truly invalid conditions like InputReference or complex expressions
|
67
67
|
report_error(
|
@@ -75,10 +75,10 @@ module Kumi
|
|
75
75
|
|
76
76
|
def validate_function_call(call_expr, errors)
|
77
77
|
fn_name = call_expr.fn_name
|
78
|
-
|
78
|
+
|
79
79
|
# Skip validation if FunctionRegistry is being mocked for testing
|
80
80
|
return if function_registry_mocked?
|
81
|
-
|
81
|
+
|
82
82
|
return if FunctionRegistry.supported?(fn_name)
|
83
83
|
|
84
84
|
report_error(
|
@@ -96,15 +96,14 @@ module Kumi
|
|
96
96
|
|
97
97
|
def function_registry_mocked?
|
98
98
|
# Check if FunctionRegistry is being mocked (for tests)
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
end
|
99
|
+
|
100
|
+
# Try to access a method that doesn't exist in the real registry
|
101
|
+
# If it's mocked, this won't raise an error
|
102
|
+
FunctionRegistry.respond_to?(:confirm_support!)
|
103
|
+
rescue StandardError
|
104
|
+
false
|
106
105
|
end
|
107
106
|
end
|
108
107
|
end
|
109
108
|
end
|
110
|
-
end
|
109
|
+
end
|
@@ -4,16 +4,16 @@ module Kumi
|
|
4
4
|
module Analyzer
|
5
5
|
module Passes
|
6
6
|
# RESPONSIBILITY: Compute topological ordering of declarations, allowing safe conditional cycles
|
7
|
-
# DEPENDENCIES: :
|
8
|
-
# PRODUCES: :
|
7
|
+
# DEPENDENCIES: :dependencies from DependencyResolver, :declarations from NameIndexer, :cascades from UnsatDetector
|
8
|
+
# PRODUCES: :evaluation_order - Array of declaration names in evaluation order
|
9
9
|
# INTERFACE: new(schema, state).run(errors)
|
10
10
|
class Toposorter < PassBase
|
11
11
|
def run(errors)
|
12
|
-
dependency_graph = get_state(:
|
13
|
-
definitions = get_state(:
|
12
|
+
dependency_graph = get_state(:dependencies, required: false) || {}
|
13
|
+
definitions = get_state(:declarations, required: false) || {}
|
14
14
|
|
15
15
|
order = compute_topological_order(dependency_graph, definitions, errors)
|
16
|
-
state.with(:
|
16
|
+
state.with(:evaluation_order, order)
|
17
17
|
end
|
18
18
|
|
19
19
|
private
|
@@ -22,7 +22,7 @@ module Kumi
|
|
22
22
|
temp_marks = Set.new
|
23
23
|
perm_marks = Set.new
|
24
24
|
order = []
|
25
|
-
|
25
|
+
cascades = get_state(:cascades) || {}
|
26
26
|
|
27
27
|
visit_node = lambda do |node, path = []|
|
28
28
|
return if perm_marks.include?(node)
|
@@ -30,13 +30,13 @@ module Kumi
|
|
30
30
|
if temp_marks.include?(node)
|
31
31
|
# Check if this is a safe conditional cycle
|
32
32
|
cycle_path = path + [node]
|
33
|
-
if safe_conditional_cycle?(cycle_path, graph,
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
33
|
+
return if safe_conditional_cycle?(cycle_path, graph, cascades)
|
34
|
+
|
35
|
+
# Allow this cycle - it's safe due to cascade mutual exclusion
|
36
|
+
|
37
|
+
report_unexpected_cycle(temp_marks, node, errors)
|
38
|
+
|
39
|
+
return
|
40
40
|
end
|
41
41
|
|
42
42
|
temp_marks << node
|
@@ -59,30 +59,30 @@ module Kumi
|
|
59
59
|
order.freeze
|
60
60
|
end
|
61
61
|
|
62
|
-
def safe_conditional_cycle?(cycle_path, graph,
|
62
|
+
def safe_conditional_cycle?(cycle_path, graph, cascades)
|
63
63
|
return false if cycle_path.nil? || cycle_path.size < 2
|
64
|
-
|
64
|
+
|
65
65
|
# Find where the cycle starts - look for the first occurrence of the repeated node
|
66
66
|
last_node = cycle_path.last
|
67
67
|
return false if last_node.nil?
|
68
|
-
|
68
|
+
|
69
69
|
cycle_start = cycle_path.index(last_node)
|
70
70
|
return false unless cycle_start && cycle_start < cycle_path.size - 1
|
71
|
-
|
72
|
-
cycle_nodes = cycle_path[cycle_start
|
73
|
-
|
71
|
+
|
72
|
+
cycle_nodes = cycle_path[cycle_start..]
|
73
|
+
|
74
74
|
# Check if all edges in the cycle are conditional
|
75
75
|
cycle_nodes.each_cons(2) do |from, to|
|
76
76
|
edges = graph[from] || []
|
77
77
|
edge = edges.find { |e| e.to == to }
|
78
|
-
|
78
|
+
|
79
79
|
return false unless edge&.conditional
|
80
|
-
|
80
|
+
|
81
81
|
# Check if the cascade has mutually exclusive conditions
|
82
|
-
cascade_meta =
|
82
|
+
cascade_meta = cascades[edge.cascade_owner]
|
83
83
|
return false unless cascade_meta&.dig(:all_mutually_exclusive)
|
84
84
|
end
|
85
|
-
|
85
|
+
|
86
86
|
true
|
87
87
|
end
|
88
88
|
|
@@ -4,7 +4,7 @@ module Kumi
|
|
4
4
|
module Analyzer
|
5
5
|
module Passes
|
6
6
|
# RESPONSIBILITY: Validate function call arity and argument types against FunctionRegistry
|
7
|
-
# DEPENDENCIES: :
|
7
|
+
# DEPENDENCIES: :inferred_types from TypeInferencer
|
8
8
|
# PRODUCES: None (validation only)
|
9
9
|
# INTERFACE: new(schema, state).run(errors)
|
10
10
|
class TypeChecker < VisitorPass
|
@@ -48,10 +48,8 @@ module Kumi
|
|
48
48
|
return if types.nil? || (signature[:arity].negative? && node.args.empty?)
|
49
49
|
|
50
50
|
# Skip type checking for vectorized operations
|
51
|
-
broadcast_meta = get_state(:
|
52
|
-
if broadcast_meta && is_part_of_vectorized_operation?(node, broadcast_meta)
|
53
|
-
return
|
54
|
-
end
|
51
|
+
broadcast_meta = get_state(:broadcasts, required: false)
|
52
|
+
return if broadcast_meta && is_part_of_vectorized_operation?(node, broadcast_meta)
|
55
53
|
|
56
54
|
node.args.each_with_index do |arg, i|
|
57
55
|
validate_argument_type(arg, i, types[i], node.fn_name, errors)
|
@@ -65,7 +63,7 @@ module Kumi
|
|
65
63
|
case arg
|
66
64
|
when Kumi::Syntax::DeclarationReference
|
67
65
|
broadcast_meta[:vectorized_operations]&.key?(arg.name) ||
|
68
|
-
|
66
|
+
broadcast_meta[:reduction_operations]&.key?(arg.name)
|
69
67
|
when Kumi::Syntax::InputElementReference
|
70
68
|
broadcast_meta[:array_fields]&.key?(arg.path.first)
|
71
69
|
else
|
@@ -110,14 +108,14 @@ module Kumi
|
|
110
108
|
|
111
109
|
def get_declared_field_type(field_name)
|
112
110
|
# Get explicitly declared type from input metadata
|
113
|
-
input_meta = get_state(:
|
111
|
+
input_meta = get_state(:inputs, required: false) || {}
|
114
112
|
field_meta = input_meta[field_name]
|
115
113
|
field_meta&.dig(:type) || Kumi::Types::ANY
|
116
114
|
end
|
117
115
|
|
118
116
|
def get_inferred_declaration_type(decl_name)
|
119
117
|
# Get inferred type from type inference results
|
120
|
-
decl_types = get_state(:
|
118
|
+
decl_types = get_state(:inferred_types, required: true)
|
121
119
|
decl_types[decl_name] || Kumi::Types::ANY
|
122
120
|
end
|
123
121
|
|
@@ -127,7 +125,7 @@ module Kumi
|
|
127
125
|
"`#{expr.value}` of type #{type} (literal value)"
|
128
126
|
|
129
127
|
when Kumi::Syntax::InputReference
|
130
|
-
input_meta = get_state(:
|
128
|
+
input_meta = get_state(:inputs, required: false) || {}
|
131
129
|
field_meta = input_meta[expr.name]
|
132
130
|
|
133
131
|
if field_meta&.dig(:type)
|
@@ -4,12 +4,12 @@ module Kumi
|
|
4
4
|
module Analyzer
|
5
5
|
module Passes
|
6
6
|
# RESPONSIBILITY: Validate consistency between declared and inferred types
|
7
|
-
# DEPENDENCIES: :
|
7
|
+
# DEPENDENCIES: :inputs from InputCollector, :inferred_types from TypeInferencer
|
8
8
|
# PRODUCES: None (validation only)
|
9
9
|
# INTERFACE: new(schema, state).run(errors)
|
10
10
|
class TypeConsistencyChecker < PassBase
|
11
11
|
def run(errors)
|
12
|
-
input_meta = get_state(:
|
12
|
+
input_meta = get_state(:inputs, required: false) || {}
|
13
13
|
|
14
14
|
# First, validate that all declared types are valid
|
15
15
|
validate_declared_types(input_meta, errors)
|
@@ -4,17 +4,17 @@ module Kumi
|
|
4
4
|
module Analyzer
|
5
5
|
module Passes
|
6
6
|
# RESPONSIBILITY: Infer types for all declarations based on expression analysis
|
7
|
-
# DEPENDENCIES: Toposorter (needs
|
8
|
-
# PRODUCES:
|
7
|
+
# DEPENDENCIES: Toposorter (needs evaluation_order), DeclarationValidator (needs declarations)
|
8
|
+
# PRODUCES: inferred_types hash mapping declaration names to inferred types
|
9
9
|
# INTERFACE: new(schema, state).run(errors)
|
10
10
|
class TypeInferencer < PassBase
|
11
11
|
def run(errors)
|
12
12
|
types = {}
|
13
|
-
topo_order = get_state(:
|
14
|
-
definitions = get_state(:
|
15
|
-
|
13
|
+
topo_order = get_state(:evaluation_order)
|
14
|
+
definitions = get_state(:declarations)
|
15
|
+
|
16
16
|
# Get broadcast metadata from broadcast detector
|
17
|
-
broadcast_meta = get_state(:
|
17
|
+
broadcast_meta = get_state(:broadcasts, required: false) || {}
|
18
18
|
|
19
19
|
# Process declarations in topological order to ensure dependencies are resolved
|
20
20
|
topo_order.each do |name|
|
@@ -37,7 +37,7 @@ module Kumi
|
|
37
37
|
end
|
38
38
|
end
|
39
39
|
|
40
|
-
state.with(:
|
40
|
+
state.with(:inferred_types, types)
|
41
41
|
end
|
42
42
|
|
43
43
|
private
|
@@ -48,7 +48,7 @@ module Kumi
|
|
48
48
|
Types.infer_from_value(expr.value)
|
49
49
|
when InputReference
|
50
50
|
# Look up type from field metadata
|
51
|
-
input_meta = get_state(:
|
51
|
+
input_meta = get_state(:inputs, required: false) || {}
|
52
52
|
meta = input_meta[expr.name]
|
53
53
|
meta&.dig(:type) || :any
|
54
54
|
when DeclarationReference
|
@@ -68,7 +68,7 @@ module Kumi
|
|
68
68
|
end
|
69
69
|
|
70
70
|
def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
|
71
|
-
fn_name = call_expr.fn_name
|
71
|
+
fn_name = call_expr.fn_name
|
72
72
|
args = call_expr.args
|
73
73
|
|
74
74
|
# Check broadcast metadata first
|
@@ -116,7 +116,7 @@ module Kumi
|
|
116
116
|
signature[:return_type] || :any
|
117
117
|
end
|
118
118
|
|
119
|
-
def infer_vectorized_element_type(call_expr,
|
119
|
+
def infer_vectorized_element_type(call_expr, _type_context, _broadcast_metadata)
|
120
120
|
# For vectorized arithmetic operations, infer the element type
|
121
121
|
# For now, assume arithmetic operations on floats produce floats
|
122
122
|
case call_expr.fn_name
|
@@ -127,10 +127,10 @@ module Kumi
|
|
127
127
|
end
|
128
128
|
end
|
129
129
|
|
130
|
-
def infer_function_return_type(fn_name,
|
130
|
+
def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
|
131
131
|
# Get the function signature
|
132
132
|
return :any unless FunctionRegistry.supported?(fn_name)
|
133
|
-
|
133
|
+
|
134
134
|
signature = FunctionRegistry.signature(fn_name)
|
135
135
|
signature[:return_type] || :any
|
136
136
|
end
|
@@ -153,15 +153,15 @@ module Kumi
|
|
153
153
|
case expr
|
154
154
|
when InputElementReference
|
155
155
|
# Get the field type from metadata
|
156
|
-
input_meta = get_state(:
|
156
|
+
input_meta = get_state(:inputs, required: false) || {}
|
157
157
|
array_name = expr.path.first
|
158
158
|
field_name = expr.path[1]
|
159
|
-
|
159
|
+
|
160
160
|
array_meta = input_meta[array_name]
|
161
161
|
return :any unless array_meta&.dig(:type) == :array
|
162
|
-
|
162
|
+
|
163
163
|
array_meta.dig(:children, field_name, :type) || :any
|
164
|
-
|
164
|
+
|
165
165
|
when CallExpression
|
166
166
|
# For arithmetic operations, infer from operands
|
167
167
|
if %i[add subtract multiply divide].include?(expr.fn_name)
|
@@ -181,13 +181,13 @@ module Kumi
|
|
181
181
|
infer_expression_type(arg, type_context, vectorization_meta)
|
182
182
|
end
|
183
183
|
end
|
184
|
-
|
184
|
+
|
185
185
|
# Unify types for arithmetic
|
186
186
|
Types.unify(*arg_types) || :float
|
187
187
|
else
|
188
188
|
:any
|
189
189
|
end
|
190
|
-
|
190
|
+
|
191
191
|
else
|
192
192
|
:any
|
193
193
|
end
|
@@ -195,19 +195,19 @@ module Kumi
|
|
195
195
|
|
196
196
|
def infer_element_reference_type(expr)
|
197
197
|
# Get array field metadata
|
198
|
-
input_meta = get_state(:
|
199
|
-
|
198
|
+
input_meta = get_state(:inputs, required: false) || {}
|
199
|
+
|
200
200
|
return :any unless expr.path.size >= 2
|
201
|
-
|
201
|
+
|
202
202
|
array_name = expr.path.first
|
203
203
|
field_name = expr.path[1]
|
204
|
-
|
204
|
+
|
205
205
|
array_meta = input_meta[array_name]
|
206
206
|
return :any unless array_meta&.dig(:type) == :array
|
207
|
-
|
207
|
+
|
208
208
|
# Get the field type from children metadata
|
209
209
|
field_type = array_meta.dig(:children, field_name, :type) || :any
|
210
|
-
|
210
|
+
|
211
211
|
# Return array of field type (vectorized)
|
212
212
|
{ array: field_type }
|
213
213
|
end
|
@@ -4,8 +4,8 @@ module Kumi
|
|
4
4
|
module Analyzer
|
5
5
|
module Passes
|
6
6
|
# RESPONSIBILITY: Detect unsatisfiable constraints and analyze cascade mutual exclusion
|
7
|
-
# DEPENDENCIES: :
|
8
|
-
# PRODUCES: :
|
7
|
+
# DEPENDENCIES: :declarations from NameIndexer, :inputs from InputCollector
|
8
|
+
# PRODUCES: :cascades - Hash of cascade mutual exclusion analysis results
|
9
9
|
# INTERFACE: new(schema, state).run(errors)
|
10
10
|
class UnsatDetector < VisitorPass
|
11
11
|
include Syntax
|
@@ -14,21 +14,19 @@ module Kumi
|
|
14
14
|
Atom = Kumi::AtomUnsatSolver::Atom
|
15
15
|
|
16
16
|
def run(errors)
|
17
|
-
definitions = get_state(:
|
18
|
-
@input_meta = get_state(:
|
17
|
+
definitions = get_state(:declarations)
|
18
|
+
@input_meta = get_state(:inputs) || {}
|
19
19
|
@definitions = definitions
|
20
20
|
@evaluator = ConstantEvaluator.new(definitions)
|
21
21
|
|
22
22
|
# First pass: analyze cascade conditions for mutual exclusion
|
23
|
-
|
23
|
+
cascades = {}
|
24
24
|
each_decl do |decl|
|
25
|
-
|
26
|
-
end
|
25
|
+
cascades[decl.name] = analyze_cascade_mutual_exclusion(decl, definitions) if decl.expression.is_a?(CascadeExpression)
|
27
26
|
|
28
|
-
|
27
|
+
# Store cascade metadata for later passes
|
29
28
|
|
30
|
-
|
31
|
-
each_decl do |decl|
|
29
|
+
# Second pass: check for unsatisfiable constraints
|
32
30
|
if decl.expression.is_a?(CascadeExpression)
|
33
31
|
# Special handling for cascade expressions
|
34
32
|
check_cascade_expression(decl, definitions, errors)
|
@@ -51,7 +49,7 @@ module Kumi
|
|
51
49
|
report_error(errors, "conjunction `#{decl.name}` is impossible", location: decl.loc) if impossible
|
52
50
|
end
|
53
51
|
end
|
54
|
-
state.with(:
|
52
|
+
state.with(:cascades, cascades)
|
55
53
|
end
|
56
54
|
|
57
55
|
private
|
@@ -93,7 +91,7 @@ module Kumi
|
|
93
91
|
end
|
94
92
|
end
|
95
93
|
|
96
|
-
all_mutually_exclusive =
|
94
|
+
all_mutually_exclusive = total_pairs.positive? && (exclusive_pairs == total_pairs)
|
97
95
|
|
98
96
|
{
|
99
97
|
condition_traits: condition_traits,
|
@@ -130,7 +128,7 @@ module Kumi
|
|
130
128
|
val1.value != val2.value
|
131
129
|
end
|
132
130
|
|
133
|
-
def check_or_expression(or_expr, definitions,
|
131
|
+
def check_or_expression(or_expr, definitions, _errors)
|
134
132
|
# For OR expressions: A | B is impossible only if BOTH A AND B are impossible
|
135
133
|
# If either side is satisfiable, the OR is satisfiable
|
136
134
|
left_side, right_side = or_expr.args
|
data/lib/kumi/analyzer.rb
CHANGED
@@ -52,11 +52,11 @@ module Kumi
|
|
52
52
|
|
53
53
|
def self.create_analysis_result(state)
|
54
54
|
Result.new(
|
55
|
-
definitions: state[:
|
56
|
-
dependency_graph: state[:
|
57
|
-
leaf_map: state[:
|
58
|
-
topo_order: state[:
|
59
|
-
decl_types: state[:
|
55
|
+
definitions: state[:declarations],
|
56
|
+
dependency_graph: state[:dependencies],
|
57
|
+
leaf_map: state[:leaves],
|
58
|
+
topo_order: state[:evaluation_order],
|
59
|
+
decl_types: state[:inferred_types],
|
60
60
|
state: state.to_h
|
61
61
|
)
|
62
62
|
end
|
data/lib/kumi/compiler.rb
CHANGED
@@ -27,7 +27,6 @@ module Kumi
|
|
27
27
|
end
|
28
28
|
end
|
29
29
|
|
30
|
-
|
31
30
|
def compile_binding_node(expr)
|
32
31
|
name = expr.name
|
33
32
|
# Handle forward references in cycles by deferring binding lookup to runtime
|
@@ -45,7 +44,7 @@ module Kumi
|
|
45
44
|
def compile_call(expr)
|
46
45
|
fn_name = expr.fn_name
|
47
46
|
arg_fns = expr.args.map { |a| compile_expr(a) }
|
48
|
-
|
47
|
+
|
49
48
|
# Check if this is a vectorized operation
|
50
49
|
if vectorized_operation?(expr)
|
51
50
|
->(ctx) { invoke_vectorized_function(fn_name, arg_fns, ctx, expr.loc) }
|
@@ -56,40 +55,39 @@ module Kumi
|
|
56
55
|
|
57
56
|
def compile_cascade(expr)
|
58
57
|
# Check if current declaration is vectorized
|
59
|
-
broadcast_meta = @analysis.state[:
|
58
|
+
broadcast_meta = @analysis.state[:broadcasts]
|
60
59
|
is_vectorized = @current_declaration && broadcast_meta&.dig(:vectorized_operations, @current_declaration)
|
61
|
-
|
62
|
-
|
60
|
+
|
63
61
|
# For vectorized cascades, we need to transform conditions that use all?
|
64
|
-
if is_vectorized
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
62
|
+
pairs = if is_vectorized
|
63
|
+
expr.cases.map do |c|
|
64
|
+
condition_fn = transform_vectorized_condition(c.condition)
|
65
|
+
result_fn = compile_expr(c.result)
|
66
|
+
[condition_fn, result_fn]
|
67
|
+
end
|
68
|
+
else
|
69
|
+
expr.cases.map { |c| [compile_expr(c.condition), compile_expr(c.result)] }
|
70
|
+
end
|
71
|
+
|
74
72
|
if is_vectorized
|
75
73
|
lambda do |ctx|
|
76
74
|
# This cascade can be vectorized - check if we actually need to at runtime
|
77
75
|
# Evaluate all conditions and results to check for arrays
|
78
76
|
cond_results = pairs.map { |cond, _res| cond.call(ctx) }
|
79
77
|
res_results = pairs.map { |_cond, res| res.call(ctx) }
|
80
|
-
|
78
|
+
|
81
79
|
# Check if any conditions or results are arrays (vectorized)
|
82
|
-
has_vectorized_data = (cond_results + res_results).any?
|
83
|
-
|
80
|
+
has_vectorized_data = (cond_results + res_results).any?(Array)
|
81
|
+
|
84
82
|
if has_vectorized_data
|
85
83
|
# Apply element-wise cascade evaluation
|
86
|
-
array_length = cond_results.find { |v| v.is_a?(Array) }&.length ||
|
87
|
-
|
88
|
-
|
84
|
+
array_length = cond_results.find { |v| v.is_a?(Array) }&.length ||
|
85
|
+
res_results.find { |v| v.is_a?(Array) }&.length || 1
|
86
|
+
|
89
87
|
(0...array_length).map do |i|
|
90
|
-
pairs.each_with_index do |(
|
88
|
+
pairs.each_with_index do |(_cond, _res), pair_idx|
|
91
89
|
cond_val = cond_results[pair_idx].is_a?(Array) ? cond_results[pair_idx][i] : cond_results[pair_idx]
|
92
|
-
|
90
|
+
|
93
91
|
if cond_val
|
94
92
|
res_val = res_results[pair_idx].is_a?(Array) ? res_results[pair_idx][i] : res_results[pair_idx]
|
95
93
|
break res_val
|
@@ -98,7 +96,7 @@ module Kumi
|
|
98
96
|
end
|
99
97
|
else
|
100
98
|
# All data is scalar - use regular cascade evaluation
|
101
|
-
pairs.each_with_index do |(
|
99
|
+
pairs.each_with_index do |(_cond, _res), pair_idx|
|
102
100
|
return res_results[pair_idx] if cond_results[pair_idx]
|
103
101
|
end
|
104
102
|
nil
|
@@ -114,17 +112,17 @@ module Kumi
|
|
114
112
|
|
115
113
|
def transform_vectorized_condition(condition_expr)
|
116
114
|
# If this is fn(:all?, [trait_ref]), extract the trait_ref for vectorized cascades
|
117
|
-
if condition_expr.is_a?(Kumi::Syntax::CallExpression) &&
|
118
|
-
condition_expr.fn_name == :all? &&
|
115
|
+
if condition_expr.is_a?(Kumi::Syntax::CallExpression) &&
|
116
|
+
condition_expr.fn_name == :all? &&
|
119
117
|
condition_expr.args.length == 1
|
120
|
-
|
118
|
+
|
121
119
|
arg = condition_expr.args.first
|
122
120
|
if arg.is_a?(Kumi::Syntax::ArrayExpression) && arg.elements.length == 1
|
123
121
|
trait_ref = arg.elements.first
|
124
122
|
return compile_expr(trait_ref)
|
125
123
|
end
|
126
124
|
end
|
127
|
-
|
125
|
+
|
128
126
|
# Otherwise compile normally
|
129
127
|
compile_expr(condition_expr)
|
130
128
|
end
|
@@ -216,14 +214,12 @@ module Kumi
|
|
216
214
|
|
217
215
|
def vectorized_operation?(expr)
|
218
216
|
# Check if this operation uses vectorized inputs
|
219
|
-
broadcast_meta = @analysis.state[:
|
217
|
+
broadcast_meta = @analysis.state[:broadcasts]
|
220
218
|
return false unless broadcast_meta
|
221
|
-
|
219
|
+
|
222
220
|
# Reduction functions are NOT vectorized operations - they consume arrays
|
223
|
-
if FunctionRegistry.reducer?(expr.fn_name)
|
224
|
-
|
225
|
-
end
|
226
|
-
|
221
|
+
return false if FunctionRegistry.reducer?(expr.fn_name)
|
222
|
+
|
227
223
|
expr.args.any? do |arg|
|
228
224
|
case arg
|
229
225
|
when Kumi::Syntax::InputElementReference
|
@@ -235,15 +231,14 @@ module Kumi
|
|
235
231
|
end
|
236
232
|
end
|
237
233
|
end
|
238
|
-
|
239
|
-
|
234
|
+
|
240
235
|
def invoke_vectorized_function(name, arg_fns, ctx, loc)
|
241
236
|
# Evaluate arguments
|
242
237
|
values = arg_fns.map { |fn| fn.call(ctx) }
|
243
|
-
|
238
|
+
|
244
239
|
# Check if any argument is vectorized (array)
|
245
|
-
has_vectorized_args = values.any?
|
246
|
-
|
240
|
+
has_vectorized_args = values.any?(Array)
|
241
|
+
|
247
242
|
if has_vectorized_args
|
248
243
|
# Apply function with broadcasting to all vectorized arguments
|
249
244
|
vectorized_function_call(name, values)
|
@@ -259,27 +254,26 @@ module Kumi
|
|
259
254
|
runtime_error.define_singleton_method(:cause) { e }
|
260
255
|
raise runtime_error
|
261
256
|
end
|
262
|
-
|
257
|
+
|
263
258
|
def vectorized_function_call(fn_name, values)
|
264
259
|
# Get the function from registry
|
265
260
|
fn = FunctionRegistry.fetch(fn_name)
|
266
|
-
|
261
|
+
|
267
262
|
# Find array dimensions for broadcasting
|
268
263
|
array_values = values.select { |v| v.is_a?(Array) }
|
269
264
|
return fn.call(*values) if array_values.empty?
|
270
|
-
|
265
|
+
|
271
266
|
# All arrays should have the same length (validation could be added)
|
272
267
|
array_length = array_values.first.size
|
273
|
-
|
268
|
+
|
274
269
|
# Broadcast and apply function element-wise
|
275
270
|
(0...array_length).map do |i|
|
276
271
|
element_args = values.map do |v|
|
277
|
-
v.is_a?(Array) ? v[i] : v
|
272
|
+
v.is_a?(Array) ? v[i] : v # Broadcast scalars
|
278
273
|
end
|
279
274
|
fn.call(*element_args)
|
280
275
|
end
|
281
276
|
end
|
282
|
-
|
283
277
|
|
284
278
|
def invoke_function(name, arg_fns, ctx, loc)
|
285
279
|
fn = FunctionRegistry.fetch(name)
|
data/lib/kumi/error_reporting.rb
CHANGED
@@ -54,7 +54,7 @@ module Kumi
|
|
54
54
|
|
55
55
|
# Immediately raise a syntax error
|
56
56
|
def raise_syntax_error(message, location: nil, context: {})
|
57
|
-
raise_localized_error(message, location: location, error_class: Errors::SyntaxError, type: :syntax, context: context)
|
57
|
+
raise_localized_error(message, location: location, error_class: Kumi::Errors::SyntaxError, type: :syntax, context: context)
|
58
58
|
end
|
59
59
|
|
60
60
|
# Immediately raise a type error
|
data/lib/kumi/explain.rb
CHANGED
@@ -274,6 +274,18 @@ module Kumi
|
|
274
274
|
|
275
275
|
raise ArgumentError, "Schema not found or not compiled" unless syntax_tree && analyzer_result
|
276
276
|
|
277
|
+
metadata = analyzer_result.state
|
278
|
+
|
279
|
+
# Create a minimal analyzer result structure for compatibility
|
280
|
+
analyzer_result = OpenStruct.new(
|
281
|
+
definitions: metadata[:declarations] || {},
|
282
|
+
dependency_graph: metadata[:dependencies] || {},
|
283
|
+
leaf_map: metadata[:leaves] || {},
|
284
|
+
topo_order: metadata[:evaluation_order] || [],
|
285
|
+
decl_types: metadata[:inferred_types] || {},
|
286
|
+
state: metadata
|
287
|
+
)
|
288
|
+
|
277
289
|
generator = ExplanationGenerator.new(syntax_tree, analyzer_result, inputs)
|
278
290
|
generator.explain(target_name)
|
279
291
|
end
|