kumi 0.0.6 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. checksums.yaml +4 -4
  2. data/CLAUDE.md +34 -177
  3. data/README.md +41 -7
  4. data/docs/SYNTAX.md +2 -7
  5. data/docs/features/array-broadcasting.md +1 -1
  6. data/docs/schema_metadata/broadcasts.md +53 -0
  7. data/docs/schema_metadata/cascades.md +45 -0
  8. data/docs/schema_metadata/declarations.md +54 -0
  9. data/docs/schema_metadata/dependencies.md +57 -0
  10. data/docs/schema_metadata/evaluation_order.md +29 -0
  11. data/docs/schema_metadata/examples.md +95 -0
  12. data/docs/schema_metadata/inferred_types.md +46 -0
  13. data/docs/schema_metadata/inputs.md +86 -0
  14. data/docs/schema_metadata.md +108 -0
  15. data/examples/game_of_life.rb +1 -1
  16. data/examples/static_analysis_errors.rb +7 -7
  17. data/lib/kumi/analyzer.rb +20 -20
  18. data/lib/kumi/compiler.rb +44 -50
  19. data/lib/kumi/core/analyzer/analysis_state.rb +39 -0
  20. data/lib/kumi/core/analyzer/constant_evaluator.rb +59 -0
  21. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +248 -0
  22. data/lib/kumi/core/analyzer/passes/declaration_validator.rb +45 -0
  23. data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +153 -0
  24. data/lib/kumi/core/analyzer/passes/input_collector.rb +139 -0
  25. data/lib/kumi/core/analyzer/passes/name_indexer.rb +26 -0
  26. data/lib/kumi/core/analyzer/passes/pass_base.rb +52 -0
  27. data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +111 -0
  28. data/lib/kumi/core/analyzer/passes/toposorter.rb +110 -0
  29. data/lib/kumi/core/analyzer/passes/type_checker.rb +162 -0
  30. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +48 -0
  31. data/lib/kumi/core/analyzer/passes/type_inferencer.rb +236 -0
  32. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +406 -0
  33. data/lib/kumi/core/analyzer/passes/visitor_pass.rb +44 -0
  34. data/lib/kumi/core/atom_unsat_solver.rb +396 -0
  35. data/lib/kumi/core/compiled_schema.rb +43 -0
  36. data/lib/kumi/core/constraint_relationship_solver.rb +641 -0
  37. data/lib/kumi/core/domain/enum_analyzer.rb +55 -0
  38. data/lib/kumi/core/domain/range_analyzer.rb +85 -0
  39. data/lib/kumi/core/domain/validator.rb +82 -0
  40. data/lib/kumi/core/domain/violation_formatter.rb +42 -0
  41. data/lib/kumi/core/error_reporter.rb +166 -0
  42. data/lib/kumi/core/error_reporting.rb +97 -0
  43. data/lib/kumi/core/errors.rb +120 -0
  44. data/lib/kumi/core/evaluation_wrapper.rb +40 -0
  45. data/lib/kumi/core/explain.rb +295 -0
  46. data/lib/kumi/core/export/deserializer.rb +41 -0
  47. data/lib/kumi/core/export/errors.rb +14 -0
  48. data/lib/kumi/core/export/node_builders.rb +142 -0
  49. data/lib/kumi/core/export/node_registry.rb +54 -0
  50. data/lib/kumi/core/export/node_serializers.rb +158 -0
  51. data/lib/kumi/core/export/serializer.rb +25 -0
  52. data/lib/kumi/core/export.rb +35 -0
  53. data/lib/kumi/core/function_registry/collection_functions.rb +202 -0
  54. data/lib/kumi/core/function_registry/comparison_functions.rb +33 -0
  55. data/lib/kumi/core/function_registry/conditional_functions.rb +38 -0
  56. data/lib/kumi/core/function_registry/function_builder.rb +95 -0
  57. data/lib/kumi/core/function_registry/logical_functions.rb +44 -0
  58. data/lib/kumi/core/function_registry/math_functions.rb +74 -0
  59. data/lib/kumi/core/function_registry/string_functions.rb +57 -0
  60. data/lib/kumi/core/function_registry/type_functions.rb +53 -0
  61. data/lib/kumi/{function_registry.rb → core/function_registry.rb} +28 -36
  62. data/lib/kumi/core/input/type_matcher.rb +97 -0
  63. data/lib/kumi/core/input/validator.rb +51 -0
  64. data/lib/kumi/core/input/violation_creator.rb +52 -0
  65. data/lib/kumi/core/json_schema/generator.rb +65 -0
  66. data/lib/kumi/core/json_schema/validator.rb +27 -0
  67. data/lib/kumi/core/json_schema.rb +16 -0
  68. data/lib/kumi/core/ruby_parser/build_context.rb +27 -0
  69. data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +38 -0
  70. data/lib/kumi/core/ruby_parser/dsl.rb +14 -0
  71. data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +138 -0
  72. data/lib/kumi/core/ruby_parser/expression_converter.rb +128 -0
  73. data/lib/kumi/core/ruby_parser/guard_rails.rb +45 -0
  74. data/lib/kumi/core/ruby_parser/input_builder.rb +127 -0
  75. data/lib/kumi/core/ruby_parser/input_field_proxy.rb +48 -0
  76. data/lib/kumi/core/ruby_parser/input_proxy.rb +31 -0
  77. data/lib/kumi/core/ruby_parser/nested_input.rb +17 -0
  78. data/lib/kumi/core/ruby_parser/parser.rb +71 -0
  79. data/lib/kumi/core/ruby_parser/schema_builder.rb +175 -0
  80. data/lib/kumi/core/ruby_parser/sugar.rb +263 -0
  81. data/lib/kumi/core/ruby_parser.rb +12 -0
  82. data/lib/kumi/core/schema_instance.rb +111 -0
  83. data/lib/kumi/core/types/builder.rb +23 -0
  84. data/lib/kumi/core/types/compatibility.rb +96 -0
  85. data/lib/kumi/core/types/formatter.rb +26 -0
  86. data/lib/kumi/core/types/inference.rb +42 -0
  87. data/lib/kumi/core/types/normalizer.rb +72 -0
  88. data/lib/kumi/core/types/validator.rb +37 -0
  89. data/lib/kumi/core/types.rb +66 -0
  90. data/lib/kumi/core/vectorization_metadata.rb +110 -0
  91. data/lib/kumi/errors.rb +1 -112
  92. data/lib/kumi/registry.rb +37 -0
  93. data/lib/kumi/schema.rb +13 -7
  94. data/lib/kumi/schema_metadata.rb +524 -0
  95. data/lib/kumi/syntax/array_expression.rb +6 -6
  96. data/lib/kumi/syntax/call_expression.rb +4 -4
  97. data/lib/kumi/syntax/cascade_expression.rb +4 -4
  98. data/lib/kumi/syntax/case_expression.rb +4 -4
  99. data/lib/kumi/syntax/declaration_reference.rb +4 -4
  100. data/lib/kumi/syntax/hash_expression.rb +4 -4
  101. data/lib/kumi/syntax/input_declaration.rb +5 -5
  102. data/lib/kumi/syntax/input_element_reference.rb +5 -5
  103. data/lib/kumi/syntax/input_reference.rb +5 -5
  104. data/lib/kumi/syntax/literal.rb +4 -4
  105. data/lib/kumi/syntax/node.rb +34 -34
  106. data/lib/kumi/syntax/root.rb +6 -6
  107. data/lib/kumi/syntax/trait_declaration.rb +4 -4
  108. data/lib/kumi/syntax/value_declaration.rb +4 -4
  109. data/lib/kumi/version.rb +1 -1
  110. data/lib/kumi.rb +14 -0
  111. data/migrate_to_core_iterative.rb +938 -0
  112. data/scripts/generate_function_docs.rb +9 -9
  113. metadata +85 -69
  114. data/lib/generators/trait_engine/templates/schema_spec.rb.erb +0 -27
  115. data/lib/kumi/analyzer/analysis_state.rb +0 -37
  116. data/lib/kumi/analyzer/constant_evaluator.rb +0 -57
  117. data/lib/kumi/analyzer/passes/broadcast_detector.rb +0 -251
  118. data/lib/kumi/analyzer/passes/declaration_validator.rb +0 -43
  119. data/lib/kumi/analyzer/passes/dependency_resolver.rb +0 -151
  120. data/lib/kumi/analyzer/passes/input_collector.rb +0 -137
  121. data/lib/kumi/analyzer/passes/name_indexer.rb +0 -24
  122. data/lib/kumi/analyzer/passes/pass_base.rb +0 -50
  123. data/lib/kumi/analyzer/passes/semantic_constraint_validator.rb +0 -110
  124. data/lib/kumi/analyzer/passes/toposorter.rb +0 -108
  125. data/lib/kumi/analyzer/passes/type_checker.rb +0 -162
  126. data/lib/kumi/analyzer/passes/type_consistency_checker.rb +0 -46
  127. data/lib/kumi/analyzer/passes/type_inferencer.rb +0 -232
  128. data/lib/kumi/analyzer/passes/unsat_detector.rb +0 -406
  129. data/lib/kumi/analyzer/passes/visitor_pass.rb +0 -42
  130. data/lib/kumi/atom_unsat_solver.rb +0 -394
  131. data/lib/kumi/compiled_schema.rb +0 -41
  132. data/lib/kumi/constraint_relationship_solver.rb +0 -638
  133. data/lib/kumi/domain/enum_analyzer.rb +0 -53
  134. data/lib/kumi/domain/range_analyzer.rb +0 -83
  135. data/lib/kumi/domain/validator.rb +0 -80
  136. data/lib/kumi/domain/violation_formatter.rb +0 -40
  137. data/lib/kumi/error_reporter.rb +0 -164
  138. data/lib/kumi/error_reporting.rb +0 -95
  139. data/lib/kumi/evaluation_wrapper.rb +0 -38
  140. data/lib/kumi/explain.rb +0 -281
  141. data/lib/kumi/export/deserializer.rb +0 -39
  142. data/lib/kumi/export/errors.rb +0 -12
  143. data/lib/kumi/export/node_builders.rb +0 -140
  144. data/lib/kumi/export/node_registry.rb +0 -52
  145. data/lib/kumi/export/node_serializers.rb +0 -156
  146. data/lib/kumi/export/serializer.rb +0 -23
  147. data/lib/kumi/export.rb +0 -33
  148. data/lib/kumi/function_registry/collection_functions.rb +0 -200
  149. data/lib/kumi/function_registry/comparison_functions.rb +0 -31
  150. data/lib/kumi/function_registry/conditional_functions.rb +0 -36
  151. data/lib/kumi/function_registry/function_builder.rb +0 -93
  152. data/lib/kumi/function_registry/logical_functions.rb +0 -42
  153. data/lib/kumi/function_registry/math_functions.rb +0 -72
  154. data/lib/kumi/function_registry/string_functions.rb +0 -54
  155. data/lib/kumi/function_registry/type_functions.rb +0 -51
  156. data/lib/kumi/input/type_matcher.rb +0 -95
  157. data/lib/kumi/input/validator.rb +0 -49
  158. data/lib/kumi/input/violation_creator.rb +0 -50
  159. data/lib/kumi/parser/build_context.rb +0 -25
  160. data/lib/kumi/parser/declaration_reference_proxy.rb +0 -36
  161. data/lib/kumi/parser/dsl.rb +0 -12
  162. data/lib/kumi/parser/dsl_cascade_builder.rb +0 -136
  163. data/lib/kumi/parser/expression_converter.rb +0 -126
  164. data/lib/kumi/parser/guard_rails.rb +0 -43
  165. data/lib/kumi/parser/input_builder.rb +0 -125
  166. data/lib/kumi/parser/input_field_proxy.rb +0 -46
  167. data/lib/kumi/parser/input_proxy.rb +0 -29
  168. data/lib/kumi/parser/nested_input.rb +0 -15
  169. data/lib/kumi/parser/parser.rb +0 -68
  170. data/lib/kumi/parser/schema_builder.rb +0 -173
  171. data/lib/kumi/parser/sugar.rb +0 -261
  172. data/lib/kumi/schema_instance.rb +0 -109
  173. data/lib/kumi/types/builder.rb +0 -21
  174. data/lib/kumi/types/compatibility.rb +0 -94
  175. data/lib/kumi/types/formatter.rb +0 -24
  176. data/lib/kumi/types/inference.rb +0 -40
  177. data/lib/kumi/types/normalizer.rb +0 -70
  178. data/lib/kumi/types/validator.rb +0 -35
  179. data/lib/kumi/types.rb +0 -64
  180. data/lib/kumi/vectorization_metadata.rb +0 -108
data/lib/kumi/compiler.rb CHANGED
@@ -27,7 +27,6 @@ module Kumi
27
27
  end
28
28
  end
29
29
 
30
-
31
30
  def compile_binding_node(expr)
32
31
  name = expr.name
33
32
  # Handle forward references in cycles by deferring binding lookup to runtime
@@ -45,7 +44,7 @@ module Kumi
45
44
  def compile_call(expr)
46
45
  fn_name = expr.fn_name
47
46
  arg_fns = expr.args.map { |a| compile_expr(a) }
48
-
47
+
49
48
  # Check if this is a vectorized operation
50
49
  if vectorized_operation?(expr)
51
50
  ->(ctx) { invoke_vectorized_function(fn_name, arg_fns, ctx, expr.loc) }
@@ -56,40 +55,39 @@ module Kumi
56
55
 
57
56
  def compile_cascade(expr)
58
57
  # Check if current declaration is vectorized
59
- broadcast_meta = @analysis.state[:broadcast_metadata]
58
+ broadcast_meta = @analysis.state[:broadcasts]
60
59
  is_vectorized = @current_declaration && broadcast_meta&.dig(:vectorized_operations, @current_declaration)
61
-
62
-
60
+
63
61
  # For vectorized cascades, we need to transform conditions that use all?
64
- if is_vectorized
65
- pairs = expr.cases.map do |c|
66
- condition_fn = transform_vectorized_condition(c.condition)
67
- result_fn = compile_expr(c.result)
68
- [condition_fn, result_fn]
69
- end
70
- else
71
- pairs = expr.cases.map { |c| [compile_expr(c.condition), compile_expr(c.result)] }
72
- end
73
-
62
+ pairs = if is_vectorized
63
+ expr.cases.map do |c|
64
+ condition_fn = transform_vectorized_condition(c.condition)
65
+ result_fn = compile_expr(c.result)
66
+ [condition_fn, result_fn]
67
+ end
68
+ else
69
+ expr.cases.map { |c| [compile_expr(c.condition), compile_expr(c.result)] }
70
+ end
71
+
74
72
  if is_vectorized
75
73
  lambda do |ctx|
76
74
  # This cascade can be vectorized - check if we actually need to at runtime
77
75
  # Evaluate all conditions and results to check for arrays
78
76
  cond_results = pairs.map { |cond, _res| cond.call(ctx) }
79
77
  res_results = pairs.map { |_cond, res| res.call(ctx) }
80
-
78
+
81
79
  # Check if any conditions or results are arrays (vectorized)
82
- has_vectorized_data = (cond_results + res_results).any? { |v| v.is_a?(Array) }
83
-
80
+ has_vectorized_data = (cond_results + res_results).any?(Array)
81
+
84
82
  if has_vectorized_data
85
83
  # Apply element-wise cascade evaluation
86
- array_length = cond_results.find { |v| v.is_a?(Array) }&.length ||
87
- res_results.find { |v| v.is_a?(Array) }&.length || 1
88
-
84
+ array_length = cond_results.find { |v| v.is_a?(Array) }&.length ||
85
+ res_results.find { |v| v.is_a?(Array) }&.length || 1
86
+
89
87
  (0...array_length).map do |i|
90
- pairs.each_with_index do |(cond, res), pair_idx|
88
+ pairs.each_with_index do |(_cond, _res), pair_idx|
91
89
  cond_val = cond_results[pair_idx].is_a?(Array) ? cond_results[pair_idx][i] : cond_results[pair_idx]
92
-
90
+
93
91
  if cond_val
94
92
  res_val = res_results[pair_idx].is_a?(Array) ? res_results[pair_idx][i] : res_results[pair_idx]
95
93
  break res_val
@@ -98,7 +96,7 @@ module Kumi
98
96
  end
99
97
  else
100
98
  # All data is scalar - use regular cascade evaluation
101
- pairs.each_with_index do |(cond, res), pair_idx|
99
+ pairs.each_with_index do |(_cond, _res), pair_idx|
102
100
  return res_results[pair_idx] if cond_results[pair_idx]
103
101
  end
104
102
  nil
@@ -114,17 +112,17 @@ module Kumi
114
112
 
115
113
  def transform_vectorized_condition(condition_expr)
116
114
  # If this is fn(:all?, [trait_ref]), extract the trait_ref for vectorized cascades
117
- if condition_expr.is_a?(Kumi::Syntax::CallExpression) &&
118
- condition_expr.fn_name == :all? &&
115
+ if condition_expr.is_a?(Kumi::Syntax::CallExpression) &&
116
+ condition_expr.fn_name == :all? &&
119
117
  condition_expr.args.length == 1
120
-
118
+
121
119
  arg = condition_expr.args.first
122
120
  if arg.is_a?(Kumi::Syntax::ArrayExpression) && arg.elements.length == 1
123
121
  trait_ref = arg.elements.first
124
122
  return compile_expr(trait_ref)
125
123
  end
126
124
  end
127
-
125
+
128
126
  # Otherwise compile normally
129
127
  compile_expr(condition_expr)
130
128
  end
@@ -160,7 +158,7 @@ module Kumi
160
158
  compile_declaration(decl)
161
159
  end
162
160
 
163
- CompiledSchema.new(@bindings.freeze)
161
+ Core::CompiledSchema.new(@bindings.freeze)
164
162
  end
165
163
 
166
164
  private
@@ -216,14 +214,12 @@ module Kumi
216
214
 
217
215
  def vectorized_operation?(expr)
218
216
  # Check if this operation uses vectorized inputs
219
- broadcast_meta = @analysis.state[:broadcast_metadata]
217
+ broadcast_meta = @analysis.state[:broadcasts]
220
218
  return false unless broadcast_meta
221
-
219
+
222
220
  # Reduction functions are NOT vectorized operations - they consume arrays
223
- if FunctionRegistry.reducer?(expr.fn_name)
224
- return false
225
- end
226
-
221
+ return false if Kumi::Registry.reducer?(expr.fn_name)
222
+
227
223
  expr.args.any? do |arg|
228
224
  case arg
229
225
  when Kumi::Syntax::InputElementReference
@@ -235,21 +231,20 @@ module Kumi
235
231
  end
236
232
  end
237
233
  end
238
-
239
-
234
+
240
235
  def invoke_vectorized_function(name, arg_fns, ctx, loc)
241
236
  # Evaluate arguments
242
237
  values = arg_fns.map { |fn| fn.call(ctx) }
243
-
238
+
244
239
  # Check if any argument is vectorized (array)
245
- has_vectorized_args = values.any? { |v| v.is_a?(Array) }
246
-
240
+ has_vectorized_args = values.any?(Array)
241
+
247
242
  if has_vectorized_args
248
243
  # Apply function with broadcasting to all vectorized arguments
249
244
  vectorized_function_call(name, values)
250
245
  else
251
246
  # All arguments are scalars - regular function call
252
- fn = FunctionRegistry.fetch(name)
247
+ fn = Kumi::Registry.fetch(name)
253
248
  fn.call(*values)
254
249
  end
255
250
  rescue StandardError => e
@@ -259,37 +254,36 @@ module Kumi
259
254
  runtime_error.define_singleton_method(:cause) { e }
260
255
  raise runtime_error
261
256
  end
262
-
257
+
263
258
  def vectorized_function_call(fn_name, values)
264
259
  # Get the function from registry
265
- fn = FunctionRegistry.fetch(fn_name)
266
-
260
+ fn = Kumi::Registry.fetch(fn_name)
261
+
267
262
  # Find array dimensions for broadcasting
268
263
  array_values = values.select { |v| v.is_a?(Array) }
269
264
  return fn.call(*values) if array_values.empty?
270
-
265
+
271
266
  # All arrays should have the same length (validation could be added)
272
267
  array_length = array_values.first.size
273
-
268
+
274
269
  # Broadcast and apply function element-wise
275
270
  (0...array_length).map do |i|
276
271
  element_args = values.map do |v|
277
- v.is_a?(Array) ? v[i] : v # Broadcast scalars
272
+ v.is_a?(Array) ? v[i] : v # Broadcast scalars
278
273
  end
279
274
  fn.call(*element_args)
280
275
  end
281
276
  end
282
-
283
277
 
284
278
  def invoke_function(name, arg_fns, ctx, loc)
285
- fn = FunctionRegistry.fetch(name)
279
+ fn = Kumi::Registry.fetch(name)
286
280
  values = arg_fns.map { |fn| fn.call(ctx) }
287
281
  fn.call(*values)
288
282
  rescue StandardError => e
289
283
  # Preserve original error class and backtrace while adding context
290
284
  enhanced_message = "Error calling fn(:#{name}) at #{loc}: #{e.message}"
291
285
 
292
- if e.is_a?(Kumi::Errors::Error)
286
+ if e.is_a?(Kumi::Core::Errors::Error)
293
287
  # Re-raise Kumi errors with enhanced message but preserve type
294
288
  e.define_singleton_method(:message) { enhanced_message }
295
289
  raise e
@@ -0,0 +1,39 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      # Read-only container for the data analyzer passes hand to one another.
      # The wrapped hash is a frozen shallow copy, so a pass cannot alter the
      # state in place; #with returns a brand-new instance instead.
      class AnalysisState
        # @param data [Hash] initial entries (copied, then frozen)
        def initialize(data = {})
          snapshot = data.dup
          @data = snapshot.freeze
        end

        # Hash-style read access.
        #
        # @param key [Object]
        # @return [Object, nil] whatever the wrapped hash returns for +key+
        def [](key)
          @data[key]
        end

        # @param key [Object]
        # @return [Boolean] true when +key+ is present in the wrapped hash
        def key?(key)
          @data.key?(key)
        end

        # @return [Array] every key currently stored, in insertion order
        def keys
          @data.keys
        end

        # Builds a new state containing all current entries plus +key+ => +value+.
        # The receiver is left untouched.
        #
        # @param key [Object]
        # @param value [Object]
        # @return [AnalysisState]
        def with(key, value)
          extended = @data.dup
          extended[key] = value
          AnalysisState.new(extended)
        end

        # @return [Hash] an unfrozen shallow copy of the wrapped data
        def to_h
          @data.dup
        end
      end
    end
  end
end
@@ -0,0 +1,59 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      # Folds expression nodes into plain Ruby values when every operand is
      # known ahead of time. Nodes that cannot be folded produce the sentinel
      # :unknown.
      class ConstantEvaluator
        include Syntax

        # Foldable function names mapped to the Ruby operator that is applied
        # left-to-right over the evaluated arguments.
        OPERATORS = {
          add: :+,
          subtract: :-,
          multiply: :*,
          divide: :/
        }.freeze

        # @param definitions [Hash] declaration name => object responding to
        #   #expression (looked up when a DeclarationReference is evaluated)
        def initialize(definitions)
          @definitions = definitions
          @memo = {}
        end

        # Evaluates +node+ to a constant.
        #
        # @param node [Object, nil] syntax node to fold
        # @param visited [Set] names of declarations currently being expanded;
        #   used to stop on reference cycles
        # @return [Object, Symbol] the folded value, or :unknown
        def evaluate(node, visited = Set.new)
          return :unknown unless node
          return @memo[node] if @memo.key?(node)
          return node.value if node.is_a?(Literal)

          result = case node
                   when DeclarationReference then evaluate_binding(node, visited)
                   when CallExpression then evaluate_call_expression(node, visited)
                   else :unknown
                   end

          # Only definite answers are cached; an :unknown caused by a cycle
          # guard may become known when reached through another path.
          @memo[node] = result unless result == :unknown
          result
        end

        private

        # Resolves a declaration reference by evaluating the referenced
        # declaration's expression.
        def evaluate_binding(node, visited)
          return :unknown if visited.include?(node.name)

          definition = @definitions[node.name]
          return :unknown unless definition

          visited << node.name
          begin
            evaluate(definition.expression, visited)
          ensure
            # Unwind so +visited+ tracks only the current expansion path.
            # Leaving the name in place would make a second, sibling reference
            # to the same declaration look like a cycle.
            visited.delete(node.name)
          end
        end

        # Folds a call to one of OPERATORS over its evaluated arguments.
        def evaluate_call_expression(node, visited)
          operator = OPERATORS[node.fn_name]
          return :unknown unless operator

          args = node.args.map { |arg| evaluate(arg, visited) }
          # No operands means nothing to fold (reduce would return nil).
          return :unknown if args.empty? || args.any?(:unknown)
          # Operands such as nil do not implement the operator at all.
          return :unknown unless args.all? { |value| value.respond_to?(operator) }

          args.reduce(operator)
        rescue ZeroDivisionError, TypeError
          # Integer division by zero or incompatible operand types: the
          # expression has no constant value the analyzer can rely on.
          :unknown
        end
      end
    end
  end
end
@@ -0,0 +1,248 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # Detects which operations should be broadcast over arrays
        # DEPENDENCIES: :inputs, :declarations
        # PRODUCES: :broadcasts
        #
        # The produced :broadcasts hash has three keys:
        #   :array_fields          - array inputs with their element fields/types
        #   :vectorized_operations - declaration name => info for element-wise values
        #   :reduction_operations  - declaration name => info for reducer calls
        class BroadcastDetector < PassBase
          # Classifies every trait and value declaration and stores the result
          # under :broadcasts.
          #
          # @param errors [Array] collector passed on to report_error
          # @return the state returned by state.with(:broadcasts, ...)
          def run(errors)
            input_meta = get_state(:inputs) || {}
            definitions = get_state(:declarations) || {}

            # Find array fields with their element types
            array_fields = find_array_fields(input_meta)

            # Build compiler metadata
            compiler_metadata = {
              array_fields: array_fields,
              vectorized_operations: {},
              reduction_operations: {}
            }

            # Track which values are vectorized for type inference
            vectorized_values = {}

            # Analyze traits first, then values (to handle dependencies)
            traits = definitions.select { |_name, decl| decl.is_a?(Kumi::Syntax::TraitDeclaration) }
            values = definitions.select { |_name, decl| decl.is_a?(Kumi::Syntax::ValueDeclaration) }

            (traits.to_a + values.to_a).each do |name, decl|
              result = analyze_value_vectorization(name, decl.expression, array_fields, vectorized_values, errors)

              case result[:type]
              when :vectorized
                compiler_metadata[:vectorized_operations][name] = result[:info]
                # Store array source information for dimension checking
                array_source = extract_array_source(result[:info], array_fields)
                vectorized_values[name] = { vectorized: true, array_source: array_source }
              when :reduction
                compiler_metadata[:reduction_operations][name] = result[:info]
                # Reduction produces scalar, not vectorized
                vectorized_values[name] = { vectorized: false }
              end
            end

            state.with(:broadcasts, compiler_metadata.freeze)
          end

          private

          # Collects inputs declared as :array that also carry :children.
          #
          # @return [Hash] input name => { element_fields:, element_types: }
          def find_array_fields(input_meta)
            result = {}
            input_meta.each do |name, meta|
              next unless meta[:type] == :array && meta[:children]

              result[name] = {
                element_fields: meta[:children].keys,
                # Children without an explicit type are treated as :any
                element_types: meta[:children].transform_values { |v| v[:type] || :any }
              }
            end
            result
          end

          # Classifies an expression.
          #
          # @return [Hash] { type: :vectorized | :reduction | :scalar, info: Hash (absent for :scalar) }
          def analyze_value_vectorization(name, expr, array_fields, vectorized_values, errors)
            case expr
            when Kumi::Syntax::InputElementReference
              if array_fields.key?(expr.path.first)
                { type: :vectorized, info: { source: :array_field_access, path: expr.path } }
              else
                { type: :scalar }
              end

            when Kumi::Syntax::DeclarationReference
              # Check if this references a vectorized value
              vector_info = vectorized_values[expr.name]
              if vector_info && vector_info[:vectorized]
                { type: :vectorized, info: { source: :vectorized_declaration, name: expr.name } }
              else
                { type: :scalar }
              end

            when Kumi::Syntax::CallExpression
              analyze_call_vectorization(name, expr, array_fields, vectorized_values, errors)

            when Kumi::Syntax::CascadeExpression
              analyze_cascade_vectorization(name, expr, array_fields, vectorized_values, errors)

            else
              { type: :scalar }
            end
          end

          # Classifies a function call as a reduction, a vectorized operation or a scalar.
          # Reports a :semantic error when vectorized arguments come from different arrays.
          def analyze_call_vectorization(_name, expr, array_fields, vectorized_values, errors)
            # Check if this is a reduction function using function registry metadata
            if Kumi::Registry.reducer?(expr.fn_name)
              # Only treat as reduction if the argument is actually vectorized
              arg_info = analyze_argument_vectorization(expr.args.first, array_fields, vectorized_values)
              if arg_info[:vectorized]
                { type: :reduction, info: { function: expr.fn_name, source: arg_info[:source] } }
              else
                # Not a vectorized reduction - just a regular function call
                { type: :scalar }
              end

            else
              # Special case: all?, any?, none? functions with vectorized trait arguments should be treated as vectorized
              # for cascade condition purposes (they get transformed during compilation)
              if %i[all? any? none?].include?(expr.fn_name) && expr.args.length == 1
                arg = expr.args.first
                if arg.is_a?(Kumi::Syntax::ArrayExpression) && arg.elements.length == 1
                  trait_ref = arg.elements.first
                  if trait_ref.is_a?(Kumi::Syntax::DeclarationReference) && vectorized_values[trait_ref.name]&.[](:vectorized)
                    return { type: :vectorized, info: { source: :cascade_condition_with_vectorized_trait, trait: trait_ref.name } }
                  end
                end
              end

              # ANY function with vectorized arguments becomes vectorized (with broadcasting)
              arg_infos = expr.args.map { |arg| analyze_argument_vectorization(arg, array_fields, vectorized_values) }

              if arg_infos.any? { |info| info[:vectorized] }
                # Check for dimension mismatches when multiple arguments are vectorized.
                # Arguments without a known :array_source (nested calls, values derived
                # from operations) are dropped by filter_map and so never trigger a mismatch.
                vectorized_sources = arg_infos.select { |info| info[:vectorized] }.filter_map { |info| info[:array_source] }.uniq

                if vectorized_sources.length > 1
                  # Multiple different array sources - this is a dimension mismatch
                  # Generate enhanced error message with type information
                  enhanced_message = build_dimension_mismatch_error(expr, arg_infos, array_fields, vectorized_sources)

                  report_error(errors, enhanced_message, location: expr.loc, type: :semantic)
                  return { type: :scalar } # Treat as scalar to prevent further errors
                end

                # This is a vectorized operation - ANY function supports broadcasting
                { type: :vectorized, info: {
                  operation: expr.fn_name,
                  vectorized_args: arg_infos.map.with_index { |info, i| [i, info[:vectorized]] }.to_h
                } }
              else
                { type: :scalar }
              end
            end
          end

          # Describes a single call argument.
          #
          # @return [Hash] { vectorized: Boolean, source: Symbol, array_source: Object } -
          #   :source and :array_source are only present for some argument kinds
          def analyze_argument_vectorization(arg, array_fields, vectorized_values)
            case arg
            when Kumi::Syntax::InputElementReference
              if array_fields.key?(arg.path.first)
                { vectorized: true, source: :array_field, array_source: arg.path.first }
              else
                { vectorized: false }
              end

            when Kumi::Syntax::DeclarationReference
              # Check if this references a vectorized value
              vector_info = vectorized_values[arg.name]
              if vector_info && vector_info[:vectorized]
                array_source = vector_info[:array_source]
                { vectorized: true, source: :vectorized_value, array_source: array_source }
              else
                { vectorized: false }
              end

            when Kumi::Syntax::CallExpression
              # Recursively check
              # NOTE(review): errors found inside the nested call go to a throw-away
              # array, so a dimension mismatch in a nested argument is not reported
              # here - confirm whether that is intentional.
              result = analyze_value_vectorization(nil, arg, array_fields, vectorized_values, [])
              { vectorized: result[:type] == :vectorized, source: :expression }

            else
              { vectorized: false }
            end
          end

          # Returns the array input a vectorized value originates from, when known.
          # Only :array_field_access info yields a source; every other kind of info
          # (including call results, which carry no :source key) yields nil.
          def extract_array_source(info, _array_fields)
            case info[:source]
            when :array_field_access
              info[:path]&.first
            when :cascade_condition_with_vectorized_trait
              # For cascades, we'd need to trace back to the original source
              nil # TODO: Could be enhanced to trace through trait dependencies
            end
          end

          # Classifies a cascade expression.
          def analyze_cascade_vectorization(_name, expr, array_fields, vectorized_values, errors)
            # A cascade is vectorized if:
            # 1. Any of its result expressions are vectorized, OR
            # 2. Any of its conditions reference vectorized values (traits or arrays)
            vectorized_results = []
            vectorized_conditions = []

            expr.cases.each do |case_expr|
              # Check if result is vectorized
              result_info = analyze_value_vectorization(nil, case_expr.result, array_fields, vectorized_values, errors)
              vectorized_results << (result_info[:type] == :vectorized)

              # Check if condition is vectorized
              condition_info = analyze_value_vectorization(nil, case_expr.condition, array_fields, vectorized_values, errors)
              vectorized_conditions << (condition_info[:type] == :vectorized)
            end

            if vectorized_results.any? || vectorized_conditions.any?
              { type: :vectorized, info: { source: :cascade_with_vectorized_conditions_or_results } }
            else
              { type: :scalar }
            end
          end

          # Builds the message reported when a call mixes arrays from different inputs.
          #
          # @param arg_infos [Array<Hash>] one entry per call argument, in call order
          # @param array_fields [Hash] output of find_array_fields
          # @param vectorized_sources [Array] distinct array inputs involved
          # @return [String]
          def build_dimension_mismatch_error(_expr, arg_infos, array_fields, vectorized_sources)
            # Build detailed error message with type information
            summary = "Cannot broadcast operation across arrays from different sources: #{vectorized_sources.join(', ')}. "

            # Number operands by their position in the call, so the message still
            # points at the right argument when scalar operands precede the arrays.
            operand_lines = arg_infos.each_with_index.filter_map do |arg_info, position|
              next unless arg_info[:vectorized]

              array_source = arg_info[:array_source]
              next unless array_source && array_fields[array_source]

              # Determine the type based on array field metadata
              type_desc = determine_array_type(array_source, array_fields)
              " - Operand #{position + 1} resolves to #{type_desc} from array '#{array_source}'\n"
            end

            problem_desc = "Problem: Multiple operands are arrays from different sources:\n#{operand_lines.join}"

            explanation = "Direct operations on arrays from different sources is ambiguous and not supported. " \
                          "Vectorized operations can only work on fields from the same array input."

            "#{summary}#{problem_desc}#{explanation}"
          end

          # Renders a short type label for an array input: "array(<type>)" when all
          # element fields share one type, "array(mixed)" otherwise.
          def determine_array_type(array_source, array_fields)
            field_info = array_fields[array_source]
            return "array(any)" unless field_info[:element_types]

            # For nested arrays (like items.name where items is an array), this represents array(element_type)
            element_types = field_info[:element_types].values.uniq
            if element_types.length == 1
              "array(#{element_types.first})"
            else
              "array(mixed)"
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Check the local structure of every declaration:
        #   - a value declaration must have an expression
        #   - a trait declaration must wrap a CallExpression
        # PRODUCES: Nothing; the incoming state is returned as-is
        # INTERFACE: new(schema, state).run(errors)
        class DeclarationValidator < VisitorPass
          # Walks each declaration (and the nodes visit yields for it) and
          # records structural problems in +errors+.
          #
          # @param errors [Array] collector passed on to report_error
          # @return the unchanged analysis state
          def run(errors)
            each_decl do |declaration|
              visit(declaration) { |visited| validate_node(visited, errors) }
            end
            state
          end

          private

          # Routes declaration nodes to their specific check; every other node
          # type is ignored.
          def validate_node(node, errors)
            if node.is_a?(Kumi::Syntax::ValueDeclaration)
              validate_attribute(node, errors)
            elsif node.is_a?(Kumi::Syntax::TraitDeclaration)
              validate_trait(node, errors)
            end
          end

          # A value declaration without an expression is an error.
          def validate_attribute(node, errors)
            report_error(errors, "attribute `#{node.name}` requires an expression", location: node.loc) if node.expression.nil?
          end

          # A trait whose expression is anything other than a CallExpression is an error.
          def validate_trait(node, errors)
            well_formed = node.expression.is_a?(Kumi::Syntax::CallExpression)
            report_error(errors, "trait `#{node.name}` must wrap a CallExpression", location: node.loc) unless well_formed
          end
        end
      end
    end
  end
end