kumi 0.0.6 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. checksums.yaml +4 -4
  2. data/CLAUDE.md +33 -176
  3. data/README.md +33 -2
  4. data/docs/SYNTAX.md +2 -7
  5. data/docs/features/array-broadcasting.md +1 -1
  6. data/docs/schema_metadata/broadcasts.md +53 -0
  7. data/docs/schema_metadata/cascades.md +45 -0
  8. data/docs/schema_metadata/declarations.md +54 -0
  9. data/docs/schema_metadata/dependencies.md +57 -0
  10. data/docs/schema_metadata/evaluation_order.md +29 -0
  11. data/docs/schema_metadata/examples.md +95 -0
  12. data/docs/schema_metadata/inferred_types.md +46 -0
  13. data/docs/schema_metadata/inputs.md +86 -0
  14. data/docs/schema_metadata.md +108 -0
  15. data/lib/kumi/analyzer/passes/broadcast_detector.rb +52 -57
  16. data/lib/kumi/analyzer/passes/dependency_resolver.rb +8 -8
  17. data/lib/kumi/analyzer/passes/input_collector.rb +2 -2
  18. data/lib/kumi/analyzer/passes/name_indexer.rb +2 -2
  19. data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +15 -16
  20. data/lib/kumi/analyzer/passes/toposorter.rb +23 -23
  21. data/lib/kumi/analyzer/passes/type_checker.rb +7 -9
  22. data/lib/kumi/analyzer/passes/type_consistency_checker.rb +2 -2
  23. data/lib/kumi/analyzer/passes/type_inferencer.rb +24 -24
  24. data/lib/kumi/analyzer/passes/unsat_detector.rb +11 -13
  25. data/lib/kumi/analyzer.rb +5 -5
  26. data/lib/kumi/compiler.rb +39 -45
  27. data/lib/kumi/error_reporting.rb +1 -1
  28. data/lib/kumi/explain.rb +12 -0
  29. data/lib/kumi/export/node_registry.rb +2 -2
  30. data/lib/kumi/json_schema/generator.rb +63 -0
  31. data/lib/kumi/json_schema/validator.rb +25 -0
  32. data/lib/kumi/json_schema.rb +14 -0
  33. data/lib/kumi/{parser → ruby_parser}/build_context.rb +1 -1
  34. data/lib/kumi/{parser → ruby_parser}/declaration_reference_proxy.rb +3 -3
  35. data/lib/kumi/{parser → ruby_parser}/dsl.rb +1 -1
  36. data/lib/kumi/{parser → ruby_parser}/dsl_cascade_builder.rb +2 -2
  37. data/lib/kumi/{parser → ruby_parser}/expression_converter.rb +14 -14
  38. data/lib/kumi/{parser → ruby_parser}/guard_rails.rb +1 -1
  39. data/lib/kumi/{parser → ruby_parser}/input_builder.rb +1 -1
  40. data/lib/kumi/{parser → ruby_parser}/input_field_proxy.rb +4 -4
  41. data/lib/kumi/{parser → ruby_parser}/input_proxy.rb +1 -1
  42. data/lib/kumi/{parser → ruby_parser}/nested_input.rb +1 -1
  43. data/lib/kumi/{parser → ruby_parser}/parser.rb +11 -10
  44. data/lib/kumi/{parser → ruby_parser}/schema_builder.rb +1 -1
  45. data/lib/kumi/{parser → ruby_parser}/sugar.rb +1 -1
  46. data/lib/kumi/ruby_parser.rb +10 -0
  47. data/lib/kumi/schema.rb +10 -4
  48. data/lib/kumi/schema_instance.rb +6 -6
  49. data/lib/kumi/schema_metadata.rb +524 -0
  50. data/lib/kumi/vectorization_metadata.rb +4 -4
  51. data/lib/kumi/version.rb +1 -1
  52. data/lib/kumi.rb +14 -0
  53. metadata +28 -15
  54. data/lib/generators/trait_engine/templates/schema_spec.rb.erb +0 -27
@@ -24,7 +24,7 @@ module Kumi
24
24
 
25
25
  private
26
26
 
27
- def validate_semantic_constraints(node, decl, errors)
27
+ def validate_semantic_constraints(node, _decl, errors)
28
28
  case node
29
29
  when Kumi::Syntax::TraitDeclaration
30
30
  validate_trait_expression(node, errors)
@@ -48,20 +48,20 @@ module Kumi
48
48
 
49
49
  def validate_cascade_condition(when_case, errors)
50
50
  condition = when_case.condition
51
-
51
+
52
52
  case condition
53
53
  when Kumi::Syntax::DeclarationReference
54
54
  # Valid: trait reference
55
- return
55
+ nil
56
56
  when Kumi::Syntax::CallExpression
57
57
  # Valid if it's a boolean composition of traits (all?, any?, none?)
58
58
  return if boolean_trait_composition?(condition)
59
-
59
+
60
60
  # For now, allow other CallExpressions - they'll be validated by other passes
61
- return
61
+ nil
62
62
  when Kumi::Syntax::Literal
63
63
  # Allow literal conditions (like true/false) - they might be valid
64
- return
64
+ nil
65
65
  else
66
66
  # Only reject truly invalid conditions like InputReference or complex expressions
67
67
  report_error(
@@ -75,10 +75,10 @@ module Kumi
75
75
 
76
76
  def validate_function_call(call_expr, errors)
77
77
  fn_name = call_expr.fn_name
78
-
78
+
79
79
  # Skip validation if FunctionRegistry is being mocked for testing
80
80
  return if function_registry_mocked?
81
-
81
+
82
82
  return if FunctionRegistry.supported?(fn_name)
83
83
 
84
84
  report_error(
@@ -96,15 +96,14 @@ module Kumi
96
96
 
97
97
  def function_registry_mocked?
98
98
  # Check if FunctionRegistry is being mocked (for tests)
99
- begin
100
- # Try to access a method that doesn't exist in the real registry
101
- # If it's mocked, this won't raise an error
102
- FunctionRegistry.respond_to?(:confirm_support!)
103
- rescue
104
- false
105
- end
99
+
100
+ # Try to access a method that doesn't exist in the real registry
101
+ # If it's mocked, this won't raise an error
102
+ FunctionRegistry.respond_to?(:confirm_support!)
103
+ rescue StandardError
104
+ false
106
105
  end
107
106
  end
108
107
  end
109
108
  end
110
- end
109
+ end
@@ -4,16 +4,16 @@ module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
6
  # RESPONSIBILITY: Compute topological ordering of declarations, allowing safe conditional cycles
7
- # DEPENDENCIES: :dependency_graph from DependencyResolver, :definitions from NameIndexer, :cascade_metadata from UnsatDetector
8
- # PRODUCES: :topo_order - Array of declaration names in evaluation order
7
+ # DEPENDENCIES: :dependencies from DependencyResolver, :declarations from NameIndexer, :cascades from UnsatDetector
8
+ # PRODUCES: :evaluation_order - Array of declaration names in evaluation order
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class Toposorter < PassBase
11
11
  def run(errors)
12
- dependency_graph = get_state(:dependency_graph, required: false) || {}
13
- definitions = get_state(:definitions, required: false) || {}
12
+ dependency_graph = get_state(:dependencies, required: false) || {}
13
+ definitions = get_state(:declarations, required: false) || {}
14
14
 
15
15
  order = compute_topological_order(dependency_graph, definitions, errors)
16
- state.with(:topo_order, order)
16
+ state.with(:evaluation_order, order)
17
17
  end
18
18
 
19
19
  private
@@ -22,7 +22,7 @@ module Kumi
22
22
  temp_marks = Set.new
23
23
  perm_marks = Set.new
24
24
  order = []
25
- cascade_metadata = get_state(:cascade_metadata) || {}
25
+ cascades = get_state(:cascades) || {}
26
26
 
27
27
  visit_node = lambda do |node, path = []|
28
28
  return if perm_marks.include?(node)
@@ -30,13 +30,13 @@ module Kumi
30
30
  if temp_marks.include?(node)
31
31
  # Check if this is a safe conditional cycle
32
32
  cycle_path = path + [node]
33
- if safe_conditional_cycle?(cycle_path, graph, cascade_metadata)
34
- # Allow this cycle - it's safe due to cascade mutual exclusion
35
- return
36
- else
37
- report_unexpected_cycle(temp_marks, node, errors)
38
- return
39
- end
33
+ return if safe_conditional_cycle?(cycle_path, graph, cascades)
34
+
35
+ # Allow this cycle - it's safe due to cascade mutual exclusion
36
+
37
+ report_unexpected_cycle(temp_marks, node, errors)
38
+
39
+ return
40
40
  end
41
41
 
42
42
  temp_marks << node
@@ -59,30 +59,30 @@ module Kumi
59
59
  order.freeze
60
60
  end
61
61
 
62
- def safe_conditional_cycle?(cycle_path, graph, cascade_metadata)
62
+ def safe_conditional_cycle?(cycle_path, graph, cascades)
63
63
  return false if cycle_path.nil? || cycle_path.size < 2
64
-
64
+
65
65
  # Find where the cycle starts - look for the first occurrence of the repeated node
66
66
  last_node = cycle_path.last
67
67
  return false if last_node.nil?
68
-
68
+
69
69
  cycle_start = cycle_path.index(last_node)
70
70
  return false unless cycle_start && cycle_start < cycle_path.size - 1
71
-
72
- cycle_nodes = cycle_path[cycle_start..-1]
73
-
71
+
72
+ cycle_nodes = cycle_path[cycle_start..]
73
+
74
74
  # Check if all edges in the cycle are conditional
75
75
  cycle_nodes.each_cons(2) do |from, to|
76
76
  edges = graph[from] || []
77
77
  edge = edges.find { |e| e.to == to }
78
-
78
+
79
79
  return false unless edge&.conditional
80
-
80
+
81
81
  # Check if the cascade has mutually exclusive conditions
82
- cascade_meta = cascade_metadata[edge.cascade_owner]
82
+ cascade_meta = cascades[edge.cascade_owner]
83
83
  return false unless cascade_meta&.dig(:all_mutually_exclusive)
84
84
  end
85
-
85
+
86
86
  true
87
87
  end
88
88
 
@@ -4,7 +4,7 @@ module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
6
  # RESPONSIBILITY: Validate function call arity and argument types against FunctionRegistry
7
- # DEPENDENCIES: :decl_types from TypeInferencer
7
+ # DEPENDENCIES: :inferred_types from TypeInferencer
8
8
  # PRODUCES: None (validation only)
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class TypeChecker < VisitorPass
@@ -48,10 +48,8 @@ module Kumi
48
48
  return if types.nil? || (signature[:arity].negative? && node.args.empty?)
49
49
 
50
50
  # Skip type checking for vectorized operations
51
- broadcast_meta = get_state(:broadcast_metadata, required: false)
52
- if broadcast_meta && is_part_of_vectorized_operation?(node, broadcast_meta)
53
- return
54
- end
51
+ broadcast_meta = get_state(:broadcasts, required: false)
52
+ return if broadcast_meta && is_part_of_vectorized_operation?(node, broadcast_meta)
55
53
 
56
54
  node.args.each_with_index do |arg, i|
57
55
  validate_argument_type(arg, i, types[i], node.fn_name, errors)
@@ -65,7 +63,7 @@ module Kumi
65
63
  case arg
66
64
  when Kumi::Syntax::DeclarationReference
67
65
  broadcast_meta[:vectorized_operations]&.key?(arg.name) ||
68
- broadcast_meta[:reduction_operations]&.key?(arg.name)
66
+ broadcast_meta[:reduction_operations]&.key?(arg.name)
69
67
  when Kumi::Syntax::InputElementReference
70
68
  broadcast_meta[:array_fields]&.key?(arg.path.first)
71
69
  else
@@ -110,14 +108,14 @@ module Kumi
110
108
 
111
109
  def get_declared_field_type(field_name)
112
110
  # Get explicitly declared type from input metadata
113
- input_meta = get_state(:input_meta, required: false) || {}
111
+ input_meta = get_state(:inputs, required: false) || {}
114
112
  field_meta = input_meta[field_name]
115
113
  field_meta&.dig(:type) || Kumi::Types::ANY
116
114
  end
117
115
 
118
116
  def get_inferred_declaration_type(decl_name)
119
117
  # Get inferred type from type inference results
120
- decl_types = get_state(:decl_types, required: true)
118
+ decl_types = get_state(:inferred_types, required: true)
121
119
  decl_types[decl_name] || Kumi::Types::ANY
122
120
  end
123
121
 
@@ -127,7 +125,7 @@ module Kumi
127
125
  "`#{expr.value}` of type #{type} (literal value)"
128
126
 
129
127
  when Kumi::Syntax::InputReference
130
- input_meta = get_state(:input_meta, required: false) || {}
128
+ input_meta = get_state(:inputs, required: false) || {}
131
129
  field_meta = input_meta[expr.name]
132
130
 
133
131
  if field_meta&.dig(:type)
@@ -4,12 +4,12 @@ module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
6
  # RESPONSIBILITY: Validate consistency between declared and inferred types
7
- # DEPENDENCIES: :input_meta from InputCollector, :decl_types from TypeInferencer
7
+ # DEPENDENCIES: :inputs from InputCollector, :inferred_types from TypeInferencer
8
8
  # PRODUCES: None (validation only)
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class TypeConsistencyChecker < PassBase
11
11
  def run(errors)
12
- input_meta = get_state(:input_meta, required: false) || {}
12
+ input_meta = get_state(:inputs, required: false) || {}
13
13
 
14
14
  # First, validate that all declared types are valid
15
15
  validate_declared_types(input_meta, errors)
@@ -4,17 +4,17 @@ module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
6
  # RESPONSIBILITY: Infer types for all declarations based on expression analysis
7
- # DEPENDENCIES: Toposorter (needs topo_order), DeclarationValidator (needs definitions)
8
- # PRODUCES: decl_types hash mapping declaration names to inferred types
7
+ # DEPENDENCIES: Toposorter (needs evaluation_order), DeclarationValidator (needs declarations)
8
+ # PRODUCES: inferred_types hash mapping declaration names to inferred types
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class TypeInferencer < PassBase
11
11
  def run(errors)
12
12
  types = {}
13
- topo_order = get_state(:topo_order)
14
- definitions = get_state(:definitions)
15
-
13
+ topo_order = get_state(:evaluation_order)
14
+ definitions = get_state(:declarations)
15
+
16
16
  # Get broadcast metadata from broadcast detector
17
- broadcast_meta = get_state(:broadcast_metadata, required: false) || {}
17
+ broadcast_meta = get_state(:broadcasts, required: false) || {}
18
18
 
19
19
  # Process declarations in topological order to ensure dependencies are resolved
20
20
  topo_order.each do |name|
@@ -37,7 +37,7 @@ module Kumi
37
37
  end
38
38
  end
39
39
 
40
- state.with(:decl_types, types)
40
+ state.with(:inferred_types, types)
41
41
  end
42
42
 
43
43
  private
@@ -48,7 +48,7 @@ module Kumi
48
48
  Types.infer_from_value(expr.value)
49
49
  when InputReference
50
50
  # Look up type from field metadata
51
- input_meta = get_state(:input_meta, required: false) || {}
51
+ input_meta = get_state(:inputs, required: false) || {}
52
52
  meta = input_meta[expr.name]
53
53
  meta&.dig(:type) || :any
54
54
  when DeclarationReference
@@ -68,7 +68,7 @@ module Kumi
68
68
  end
69
69
 
70
70
  def infer_call_type(call_expr, type_context, broadcast_metadata = {}, current_decl_name = nil)
71
- fn_name = call_expr.fn_name
71
+ fn_name = call_expr.fn_name
72
72
  args = call_expr.args
73
73
 
74
74
  # Check broadcast metadata first
@@ -116,7 +116,7 @@ module Kumi
116
116
  signature[:return_type] || :any
117
117
  end
118
118
 
119
- def infer_vectorized_element_type(call_expr, type_context, broadcast_metadata)
119
+ def infer_vectorized_element_type(call_expr, _type_context, _broadcast_metadata)
120
120
  # For vectorized arithmetic operations, infer the element type
121
121
  # For now, assume arithmetic operations on floats produce floats
122
122
  case call_expr.fn_name
@@ -127,10 +127,10 @@ module Kumi
127
127
  end
128
128
  end
129
129
 
130
- def infer_function_return_type(fn_name, args, type_context, broadcast_metadata)
130
+ def infer_function_return_type(fn_name, _args, _type_context, _broadcast_metadata)
131
131
  # Get the function signature
132
132
  return :any unless FunctionRegistry.supported?(fn_name)
133
-
133
+
134
134
  signature = FunctionRegistry.signature(fn_name)
135
135
  signature[:return_type] || :any
136
136
  end
@@ -153,15 +153,15 @@ module Kumi
153
153
  case expr
154
154
  when InputElementReference
155
155
  # Get the field type from metadata
156
- input_meta = get_state(:input_meta, required: false) || {}
156
+ input_meta = get_state(:inputs, required: false) || {}
157
157
  array_name = expr.path.first
158
158
  field_name = expr.path[1]
159
-
159
+
160
160
  array_meta = input_meta[array_name]
161
161
  return :any unless array_meta&.dig(:type) == :array
162
-
162
+
163
163
  array_meta.dig(:children, field_name, :type) || :any
164
-
164
+
165
165
  when CallExpression
166
166
  # For arithmetic operations, infer from operands
167
167
  if %i[add subtract multiply divide].include?(expr.fn_name)
@@ -181,13 +181,13 @@ module Kumi
181
181
  infer_expression_type(arg, type_context, vectorization_meta)
182
182
  end
183
183
  end
184
-
184
+
185
185
  # Unify types for arithmetic
186
186
  Types.unify(*arg_types) || :float
187
187
  else
188
188
  :any
189
189
  end
190
-
190
+
191
191
  else
192
192
  :any
193
193
  end
@@ -195,19 +195,19 @@ module Kumi
195
195
 
196
196
  def infer_element_reference_type(expr)
197
197
  # Get array field metadata
198
- input_meta = get_state(:input_meta, required: false) || {}
199
-
198
+ input_meta = get_state(:inputs, required: false) || {}
199
+
200
200
  return :any unless expr.path.size >= 2
201
-
201
+
202
202
  array_name = expr.path.first
203
203
  field_name = expr.path[1]
204
-
204
+
205
205
  array_meta = input_meta[array_name]
206
206
  return :any unless array_meta&.dig(:type) == :array
207
-
207
+
208
208
  # Get the field type from children metadata
209
209
  field_type = array_meta.dig(:children, field_name, :type) || :any
210
-
210
+
211
211
  # Return array of field type (vectorized)
212
212
  { array: field_type }
213
213
  end
@@ -4,8 +4,8 @@ module Kumi
4
4
  module Analyzer
5
5
  module Passes
6
6
  # RESPONSIBILITY: Detect unsatisfiable constraints and analyze cascade mutual exclusion
7
- # DEPENDENCIES: :definitions from NameIndexer, :input_meta from InputCollector
8
- # PRODUCES: :cascade_metadata - Hash of cascade mutual exclusion analysis results
7
+ # DEPENDENCIES: :declarations from NameIndexer, :inputs from InputCollector
8
+ # PRODUCES: :cascades - Hash of cascade mutual exclusion analysis results
9
9
  # INTERFACE: new(schema, state).run(errors)
10
10
  class UnsatDetector < VisitorPass
11
11
  include Syntax
@@ -14,21 +14,19 @@ module Kumi
14
14
  Atom = Kumi::AtomUnsatSolver::Atom
15
15
 
16
16
  def run(errors)
17
- definitions = get_state(:definitions)
18
- @input_meta = get_state(:input_meta) || {}
17
+ definitions = get_state(:declarations)
18
+ @input_meta = get_state(:inputs) || {}
19
19
  @definitions = definitions
20
20
  @evaluator = ConstantEvaluator.new(definitions)
21
21
 
22
22
  # First pass: analyze cascade conditions for mutual exclusion
23
- cascade_metadata = {}
23
+ cascades = {}
24
24
  each_decl do |decl|
25
- cascade_metadata[decl.name] = analyze_cascade_mutual_exclusion(decl, definitions) if decl.expression.is_a?(CascadeExpression)
26
- end
25
+ cascades[decl.name] = analyze_cascade_mutual_exclusion(decl, definitions) if decl.expression.is_a?(CascadeExpression)
27
26
 
28
- # Store cascade metadata for later passes
27
+ # Store cascade metadata for later passes
29
28
 
30
- # Second pass: check for unsatisfiable constraints
31
- each_decl do |decl|
29
+ # Second pass: check for unsatisfiable constraints
32
30
  if decl.expression.is_a?(CascadeExpression)
33
31
  # Special handling for cascade expressions
34
32
  check_cascade_expression(decl, definitions, errors)
@@ -51,7 +49,7 @@ module Kumi
51
49
  report_error(errors, "conjunction `#{decl.name}` is impossible", location: decl.loc) if impossible
52
50
  end
53
51
  end
54
- state.with(:cascade_metadata, cascade_metadata)
52
+ state.with(:cascades, cascades)
55
53
  end
56
54
 
57
55
  private
@@ -93,7 +91,7 @@ module Kumi
93
91
  end
94
92
  end
95
93
 
96
- all_mutually_exclusive = (total_pairs > 0) && (exclusive_pairs == total_pairs)
94
+ all_mutually_exclusive = total_pairs.positive? && (exclusive_pairs == total_pairs)
97
95
 
98
96
  {
99
97
  condition_traits: condition_traits,
@@ -130,7 +128,7 @@ module Kumi
130
128
  val1.value != val2.value
131
129
  end
132
130
 
133
- def check_or_expression(or_expr, definitions, errors)
131
+ def check_or_expression(or_expr, definitions, _errors)
134
132
  # For OR expressions: A | B is impossible only if BOTH A AND B are impossible
135
133
  # If either side is satisfiable, the OR is satisfiable
136
134
  left_side, right_side = or_expr.args
data/lib/kumi/analyzer.rb CHANGED
@@ -52,11 +52,11 @@ module Kumi
52
52
 
53
53
  def self.create_analysis_result(state)
54
54
  Result.new(
55
- definitions: state[:definitions],
56
- dependency_graph: state[:dependency_graph],
57
- leaf_map: state[:leaf_map],
58
- topo_order: state[:topo_order],
59
- decl_types: state[:decl_types],
55
+ definitions: state[:declarations],
56
+ dependency_graph: state[:dependencies],
57
+ leaf_map: state[:leaves],
58
+ topo_order: state[:evaluation_order],
59
+ decl_types: state[:inferred_types],
60
60
  state: state.to_h
61
61
  )
62
62
  end
data/lib/kumi/compiler.rb CHANGED
@@ -27,7 +27,6 @@ module Kumi
27
27
  end
28
28
  end
29
29
 
30
-
31
30
  def compile_binding_node(expr)
32
31
  name = expr.name
33
32
  # Handle forward references in cycles by deferring binding lookup to runtime
@@ -45,7 +44,7 @@ module Kumi
45
44
  def compile_call(expr)
46
45
  fn_name = expr.fn_name
47
46
  arg_fns = expr.args.map { |a| compile_expr(a) }
48
-
47
+
49
48
  # Check if this is a vectorized operation
50
49
  if vectorized_operation?(expr)
51
50
  ->(ctx) { invoke_vectorized_function(fn_name, arg_fns, ctx, expr.loc) }
@@ -56,40 +55,39 @@ module Kumi
56
55
 
57
56
  def compile_cascade(expr)
58
57
  # Check if current declaration is vectorized
59
- broadcast_meta = @analysis.state[:broadcast_metadata]
58
+ broadcast_meta = @analysis.state[:broadcasts]
60
59
  is_vectorized = @current_declaration && broadcast_meta&.dig(:vectorized_operations, @current_declaration)
61
-
62
-
60
+
63
61
  # For vectorized cascades, we need to transform conditions that use all?
64
- if is_vectorized
65
- pairs = expr.cases.map do |c|
66
- condition_fn = transform_vectorized_condition(c.condition)
67
- result_fn = compile_expr(c.result)
68
- [condition_fn, result_fn]
69
- end
70
- else
71
- pairs = expr.cases.map { |c| [compile_expr(c.condition), compile_expr(c.result)] }
72
- end
73
-
62
+ pairs = if is_vectorized
63
+ expr.cases.map do |c|
64
+ condition_fn = transform_vectorized_condition(c.condition)
65
+ result_fn = compile_expr(c.result)
66
+ [condition_fn, result_fn]
67
+ end
68
+ else
69
+ expr.cases.map { |c| [compile_expr(c.condition), compile_expr(c.result)] }
70
+ end
71
+
74
72
  if is_vectorized
75
73
  lambda do |ctx|
76
74
  # This cascade can be vectorized - check if we actually need to at runtime
77
75
  # Evaluate all conditions and results to check for arrays
78
76
  cond_results = pairs.map { |cond, _res| cond.call(ctx) }
79
77
  res_results = pairs.map { |_cond, res| res.call(ctx) }
80
-
78
+
81
79
  # Check if any conditions or results are arrays (vectorized)
82
- has_vectorized_data = (cond_results + res_results).any? { |v| v.is_a?(Array) }
83
-
80
+ has_vectorized_data = (cond_results + res_results).any?(Array)
81
+
84
82
  if has_vectorized_data
85
83
  # Apply element-wise cascade evaluation
86
- array_length = cond_results.find { |v| v.is_a?(Array) }&.length ||
87
- res_results.find { |v| v.is_a?(Array) }&.length || 1
88
-
84
+ array_length = cond_results.find { |v| v.is_a?(Array) }&.length ||
85
+ res_results.find { |v| v.is_a?(Array) }&.length || 1
86
+
89
87
  (0...array_length).map do |i|
90
- pairs.each_with_index do |(cond, res), pair_idx|
88
+ pairs.each_with_index do |(_cond, _res), pair_idx|
91
89
  cond_val = cond_results[pair_idx].is_a?(Array) ? cond_results[pair_idx][i] : cond_results[pair_idx]
92
-
90
+
93
91
  if cond_val
94
92
  res_val = res_results[pair_idx].is_a?(Array) ? res_results[pair_idx][i] : res_results[pair_idx]
95
93
  break res_val
@@ -98,7 +96,7 @@ module Kumi
98
96
  end
99
97
  else
100
98
  # All data is scalar - use regular cascade evaluation
101
- pairs.each_with_index do |(cond, res), pair_idx|
99
+ pairs.each_with_index do |(_cond, _res), pair_idx|
102
100
  return res_results[pair_idx] if cond_results[pair_idx]
103
101
  end
104
102
  nil
@@ -114,17 +112,17 @@ module Kumi
114
112
 
115
113
  def transform_vectorized_condition(condition_expr)
116
114
  # If this is fn(:all?, [trait_ref]), extract the trait_ref for vectorized cascades
117
- if condition_expr.is_a?(Kumi::Syntax::CallExpression) &&
118
- condition_expr.fn_name == :all? &&
115
+ if condition_expr.is_a?(Kumi::Syntax::CallExpression) &&
116
+ condition_expr.fn_name == :all? &&
119
117
  condition_expr.args.length == 1
120
-
118
+
121
119
  arg = condition_expr.args.first
122
120
  if arg.is_a?(Kumi::Syntax::ArrayExpression) && arg.elements.length == 1
123
121
  trait_ref = arg.elements.first
124
122
  return compile_expr(trait_ref)
125
123
  end
126
124
  end
127
-
125
+
128
126
  # Otherwise compile normally
129
127
  compile_expr(condition_expr)
130
128
  end
@@ -216,14 +214,12 @@ module Kumi
216
214
 
217
215
  def vectorized_operation?(expr)
218
216
  # Check if this operation uses vectorized inputs
219
- broadcast_meta = @analysis.state[:broadcast_metadata]
217
+ broadcast_meta = @analysis.state[:broadcasts]
220
218
  return false unless broadcast_meta
221
-
219
+
222
220
  # Reduction functions are NOT vectorized operations - they consume arrays
223
- if FunctionRegistry.reducer?(expr.fn_name)
224
- return false
225
- end
226
-
221
+ return false if FunctionRegistry.reducer?(expr.fn_name)
222
+
227
223
  expr.args.any? do |arg|
228
224
  case arg
229
225
  when Kumi::Syntax::InputElementReference
@@ -235,15 +231,14 @@ module Kumi
235
231
  end
236
232
  end
237
233
  end
238
-
239
-
234
+
240
235
  def invoke_vectorized_function(name, arg_fns, ctx, loc)
241
236
  # Evaluate arguments
242
237
  values = arg_fns.map { |fn| fn.call(ctx) }
243
-
238
+
244
239
  # Check if any argument is vectorized (array)
245
- has_vectorized_args = values.any? { |v| v.is_a?(Array) }
246
-
240
+ has_vectorized_args = values.any?(Array)
241
+
247
242
  if has_vectorized_args
248
243
  # Apply function with broadcasting to all vectorized arguments
249
244
  vectorized_function_call(name, values)
@@ -259,27 +254,26 @@ module Kumi
259
254
  runtime_error.define_singleton_method(:cause) { e }
260
255
  raise runtime_error
261
256
  end
262
-
257
+
263
258
  def vectorized_function_call(fn_name, values)
264
259
  # Get the function from registry
265
260
  fn = FunctionRegistry.fetch(fn_name)
266
-
261
+
267
262
  # Find array dimensions for broadcasting
268
263
  array_values = values.select { |v| v.is_a?(Array) }
269
264
  return fn.call(*values) if array_values.empty?
270
-
265
+
271
266
  # All arrays should have the same length (validation could be added)
272
267
  array_length = array_values.first.size
273
-
268
+
274
269
  # Broadcast and apply function element-wise
275
270
  (0...array_length).map do |i|
276
271
  element_args = values.map do |v|
277
- v.is_a?(Array) ? v[i] : v # Broadcast scalars
272
+ v.is_a?(Array) ? v[i] : v # Broadcast scalars
278
273
  end
279
274
  fn.call(*element_args)
280
275
  end
281
276
  end
282
-
283
277
 
284
278
  def invoke_function(name, arg_fns, ctx, loc)
285
279
  fn = FunctionRegistry.fetch(name)
@@ -54,7 +54,7 @@ module Kumi
54
54
 
55
55
  # Immediately raise a syntax error
56
56
  def raise_syntax_error(message, location: nil, context: {})
57
- raise_localized_error(message, location: location, error_class: Errors::SyntaxError, type: :syntax, context: context)
57
+ raise_localized_error(message, location: location, error_class: Kumi::Errors::SyntaxError, type: :syntax, context: context)
58
58
  end
59
59
 
60
60
  # Immediately raise a type error
data/lib/kumi/explain.rb CHANGED
@@ -274,6 +274,18 @@ module Kumi
274
274
 
275
275
  raise ArgumentError, "Schema not found or not compiled" unless syntax_tree && analyzer_result
276
276
 
277
+ metadata = analyzer_result.state
278
+
279
+ # Create a minimal analyzer result structure for compatibility
280
+ analyzer_result = OpenStruct.new(
281
+ definitions: metadata[:declarations] || {},
282
+ dependency_graph: metadata[:dependencies] || {},
283
+ leaf_map: metadata[:leaves] || {},
284
+ topo_order: metadata[:evaluation_order] || [],
285
+ decl_types: metadata[:inferred_types] || {},
286
+ state: metadata
287
+ )
288
+
277
289
  generator = ExplanationGenerator.new(syntax_tree, analyzer_result, inputs)
278
290
  generator.explain(target_name)
279
291
  end