kumi 0.0.10 → 0.0.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. checksums.yaml +4 -4
  2. data/.rubocop.yml +1 -1
  3. data/CHANGELOG.md +23 -0
  4. data/CLAUDE.md +7 -231
  5. data/README.md +5 -5
  6. data/docs/SYNTAX.md +66 -0
  7. data/docs/VECTOR_SEMANTICS.md +286 -0
  8. data/docs/features/hierarchical-broadcasting.md +67 -1
  9. data/docs/features/input-declaration-system.md +16 -0
  10. data/docs/features/s-expression-printer.md +2 -2
  11. data/lib/kumi/analyzer.rb +34 -12
  12. data/lib/kumi/compiler.rb +2 -12
  13. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +157 -64
  14. data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +1 -1
  15. data/lib/kumi/core/analyzer/passes/input_access_planner_pass.rb +47 -0
  16. data/lib/kumi/core/analyzer/passes/input_collector.rb +123 -101
  17. data/lib/kumi/core/analyzer/passes/join_reduce_planning_pass.rb +293 -0
  18. data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +993 -0
  19. data/lib/kumi/core/analyzer/passes/pass_base.rb +2 -2
  20. data/lib/kumi/core/analyzer/passes/scope_resolution_pass.rb +346 -0
  21. data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +2 -1
  22. data/lib/kumi/core/analyzer/passes/toposorter.rb +9 -3
  23. data/lib/kumi/core/analyzer/passes/type_checker.rb +3 -3
  24. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +2 -2
  25. data/lib/kumi/core/analyzer/passes/{type_inferencer.rb → type_inferencer_pass.rb} +4 -4
  26. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +2 -2
  27. data/lib/kumi/core/analyzer/plans.rb +52 -0
  28. data/lib/kumi/core/analyzer/structs/access_plan.rb +20 -0
  29. data/lib/kumi/core/analyzer/structs/input_meta.rb +29 -0
  30. data/lib/kumi/core/compiler/access_builder.rb +36 -0
  31. data/lib/kumi/core/compiler/access_planner.rb +219 -0
  32. data/lib/kumi/core/compiler/accessors/base.rb +69 -0
  33. data/lib/kumi/core/compiler/accessors/each_indexed_accessor.rb +84 -0
  34. data/lib/kumi/core/compiler/accessors/materialize_accessor.rb +55 -0
  35. data/lib/kumi/core/compiler/accessors/ravel_accessor.rb +73 -0
  36. data/lib/kumi/core/compiler/accessors/read_accessor.rb +41 -0
  37. data/lib/kumi/core/compiler_base.rb +2 -2
  38. data/lib/kumi/core/error_reporter.rb +6 -5
  39. data/lib/kumi/core/errors.rb +4 -0
  40. data/lib/kumi/core/explain.rb +157 -205
  41. data/lib/kumi/core/export/node_builders.rb +2 -2
  42. data/lib/kumi/core/export/node_serializers.rb +1 -1
  43. data/lib/kumi/core/function_registry/collection_functions.rb +21 -10
  44. data/lib/kumi/core/function_registry/conditional_functions.rb +14 -4
  45. data/lib/kumi/core/function_registry/function_builder.rb +142 -55
  46. data/lib/kumi/core/function_registry/logical_functions.rb +5 -5
  47. data/lib/kumi/core/function_registry/stat_functions.rb +2 -2
  48. data/lib/kumi/core/function_registry.rb +126 -108
  49. data/lib/kumi/core/input/validator.rb +1 -1
  50. data/lib/kumi/core/ir/execution_engine/combinators.rb +117 -0
  51. data/lib/kumi/core/ir/execution_engine/interpreter.rb +336 -0
  52. data/lib/kumi/core/ir/execution_engine/values.rb +46 -0
  53. data/lib/kumi/core/ir/execution_engine.rb +50 -0
  54. data/lib/kumi/core/ir.rb +58 -0
  55. data/lib/kumi/core/ruby_parser/build_context.rb +2 -2
  56. data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +0 -12
  57. data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +36 -15
  58. data/lib/kumi/core/ruby_parser/input_builder.rb +30 -9
  59. data/lib/kumi/core/ruby_parser/parser.rb +1 -1
  60. data/lib/kumi/core/ruby_parser/schema_builder.rb +2 -2
  61. data/lib/kumi/core/ruby_parser/sugar.rb +7 -0
  62. data/lib/kumi/core/types/validator.rb +1 -1
  63. data/lib/kumi/registry.rb +14 -79
  64. data/lib/kumi/runtime/executable.rb +213 -0
  65. data/lib/kumi/schema.rb +14 -3
  66. data/lib/kumi/schema_metadata.rb +2 -2
  67. data/lib/kumi/support/ir_dump.rb +491 -0
  68. data/lib/kumi/support/s_expression_printer.rb +1 -1
  69. data/lib/kumi/syntax/location.rb +5 -0
  70. data/lib/kumi/syntax/node.rb +0 -1
  71. data/lib/kumi/syntax/root.rb +2 -2
  72. data/lib/kumi/version.rb +1 -1
  73. data/lib/kumi.rb +6 -15
  74. metadata +37 -19
  75. data/lib/kumi/core/cascade_executor_builder.rb +0 -132
  76. data/lib/kumi/core/compiled_schema.rb +0 -43
  77. data/lib/kumi/core/compiler/expression_compiler.rb +0 -146
  78. data/lib/kumi/core/compiler/function_invoker.rb +0 -55
  79. data/lib/kumi/core/compiler/path_traversal_compiler.rb +0 -158
  80. data/lib/kumi/core/compiler/reference_compiler.rb +0 -46
  81. data/lib/kumi/core/evaluation_wrapper.rb +0 -40
  82. data/lib/kumi/core/nested_structure_utils.rb +0 -78
  83. data/lib/kumi/core/schema_instance.rb +0 -115
  84. data/lib/kumi/core/vectorized_function_builder.rb +0 -88
  85. data/lib/kumi/js/compiler.rb +0 -878
  86. data/lib/kumi/js/function_registry.rb +0 -333
  87. data/migrate_to_core_iterative.rb +0 -938
@@ -0,0 +1,293 @@
1
# frozen_string_literal: true

module Kumi
  module Core
    module Analyzer
      module Passes
        # Plans join and reduce operations for declarations.
        # Determines reduction axes, flattening requirements, and join policies.
        #
        # DEPENDENCIES: :broadcasts, :scope_plans, :declarations, :input_metadata
        #   (:broadcasts and :scope_plans are optional and default to empty hashes;
        #   :decl_shapes is not read by this pass)
        # PRODUCES: :join_reduce_plans - a frozen Hash of declaration name => Reduce or Join plan
        #
        # Setting ENV["DEBUG_JOIN_REDUCE"] prints each planning decision to stdout.
        class JoinReducePlanningPass < PassBase
          include Kumi::Core::Analyzer::Plans

          # Builds every join/reduce plan and returns a new state carrying them.
          #
          # @param _errors [Object] unused; this pass does not report errors itself
          # @return the state produced by `state.with(:join_reduce_plans, plans)`
          def run(_errors)
            broadcasts = get_state(:broadcasts, required: false) || {}
            scope_plans = get_state(:scope_plans, required: false) || {}
            declarations = get_state(:declarations, required: true)
            input_metadata = get_state(:input_metadata, required: true)

            plans = {}

            # Reductions are planned first so the join phase can skip names that
            # already have a plan.
            process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)
            process_joins(broadcasts, scope_plans, declarations, plans)

            state.with(:join_reduce_plans, plans.freeze)
          end

          private

          # True when join/reduce debug tracing is enabled via the environment.
          def debug_join_reduce?
            ENV["DEBUG_JOIN_REDUCE"]
          end

          # Adds a Reduce plan to +plans+ for every entry in
          # broadcasts[:reduction_operations]. Each info hash is read for
          # :function, :argument, :axis and :flatten_argument_indices.
          def process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)
            reduction_ops = broadcasts[:reduction_operations] || {}

            reduction_ops.each do |name, info|
              debug_reduction(name, info) if debug_join_reduce?

              source_scope = get_source_scope(name, info, scope_plans, declarations, input_metadata)
              axis, result_scope = determine_reduction_axis(source_scope, info, scope_plans, name)

              plan = Reduce.new(
                function: info[:function],
                axis: axis,
                source_scope: source_scope,
                result_scope: result_scope,
                flatten_args: determine_flatten_indices(info)
              )
              plans[name] = plan

              debug_reduction_plan(name, plan) if debug_join_reduce?
            end
          end

          # Adds Join plans for vectorized operations and for scalar declarations
          # that vectorized operations reference. Names already present in +plans+
          # (e.g. reductions) are never overwritten.
          def process_joins(broadcasts, scope_plans, declarations, plans)
            vectorized_ops = broadcasts[:vectorized_operations] || {}

            plan_vectorized_joins(vectorized_ops, scope_plans, declarations, plans)
            plan_scalar_broadcasts(vectorized_ops, scope_plans, declarations, plans)
          end

          # Zip-joins vectorized operations whose expression may combine several
          # arguments (see requires_join?) at a non-empty target scope.
          def plan_vectorized_joins(vectorized_ops, scope_plans, declarations, plans)
            vectorized_ops.each do |name, info|
              next if plans.key?(name)

              debug_join(name, info) if debug_join_reduce?

              scope_plan = scope_plans[name]
              next unless scope_plan
              next unless requires_join?(declarations[name], scope_plan)

              plan = Join.new(
                policy: :zip, # Default to zip for array operations
                target_scope: scope_plan.scope
              )
              plans[name] = plan

              debug_join_plan(name, plan) if debug_join_reduce?
            end
          end

          # Broadcast-joins declarations that are not vectorized themselves but have
          # a non-empty scope plan and are referenced by some vectorized operation.
          def plan_scalar_broadcasts(vectorized_ops, scope_plans, declarations, plans)
            scope_plans.each do |name, scope_plan|
              next if plans.key?(name)
              next unless scope_plan.scope && !scope_plan.scope.empty?
              next if vectorized_ops.key?(name)
              next unless needs_scalar_to_vector_broadcast?(name, scope_plan, declarations, vectorized_ops)

              debug_scalar_broadcast(name, scope_plan) if debug_join_reduce?

              plan = Join.new(
                policy: :broadcast, # Use broadcast policy for scalar-to-vector
                target_scope: scope_plan.scope
              )
              plans[name] = plan

              debug_join_plan(name, plan) if debug_join_reduce?
            end
          end

          # Source scope of a reduction: always inferred from the reduction's
          # :argument, which yields its full dimensional scope. +_name+ and
          # +_scope_plans+ are accepted for call-site symmetry but not used.
          def get_source_scope(_name, reduction_info, _scope_plans, declarations, input_metadata)
            infer_scope_from_argument(reduction_info[:argument], declarations, input_metadata)
          end

          # Returns [axis, result_scope] for a reduction, checked in this order:
          # 1. empty source scope            -> [[], []]
          # 2. explicit reduction_info[:axis] -> axis returned as given (may be
          #    :all, an Integer or an Array), result scope via compute_result_scope
          # 3. non-empty scope plan for +name+ -> reduce every source dimension not
          #    in the plan's scope, which becomes the result scope
          # 4. otherwise                     -> reduce only the innermost dimension
          def determine_reduction_axis(source_scope, reduction_info, scope_plans, name)
            return [[], []] if source_scope.empty?

            if reduction_info[:axis]
              axis = reduction_info[:axis]
              return [axis, compute_result_scope(source_scope, axis)]
            end

            scope_plan = scope_plans[name]
            if scope_plan&.scope && !scope_plan.scope.empty?
              desired_result_scope = scope_plan.scope
              return [source_scope - desired_result_scope, desired_result_scope]
            end

            # Default: partial reduction over the innermost dimension.
            [[source_scope.last], source_scope[0...-1]]
          end

          # Removes the dimensions designated by +axis+ from +source_scope+:
          # :all removes everything, an Array removes the named dimensions, an
          # Integer removes that many innermost dimensions; anything else keeps
          # the scope unchanged.
          def compute_result_scope(source_scope, axis)
            case axis
            when :all
              []
            when Array
              source_scope - axis
            when Integer
              # Clamp the kept count: 0 (or a negative count) keeps the full scope
              # and a count beyond the rank removes everything. A bare
              # `source_scope[0...-axis]` would return [] for axis == 0.
              source_scope.take([source_scope.size - axis, 0].max)
            else
              source_scope
            end
          end

          # Argument indices that must be flattened, from
          # reduction_info[:flatten_argument_indices], coerced to an Array
          # (nil becomes []).
          def determine_flatten_indices(reduction_info)
            Array(reduction_info[:flatten_argument_indices])
          end

          # True when +declaration+ targets a non-empty scope and its expression is
          # either a call with more than one argument or a cascade.
          def requires_join?(declaration, scope_plan)
            return false unless declaration
            return false unless scope_plan.scope && !scope_plan.scope.empty?

            case declaration.expression
            when Kumi::Syntax::CallExpression
              # Multiple arguments suggest a potential join requirement
              declaration.expression.args.size > 1
            when Kumi::Syntax::CascadeExpression
              # Cascades at a vectorized scope need join planning to align
              # cross-scope conditions and results
              true
            else
              false
            end
          end

          # Infers the dimensional scope (array-typed input path segments) of an
          # expression by taking the deepest scope among its sub-expressions.
          # Declaration references are followed; +visiting+ holds the declaration
          # names on the current path and stops the walk at a reference cycle.
          # Expression kinds not handled here contribute [].
          def infer_scope_from_argument(arg, declarations, input_metadata, visiting = [])
            return [] unless arg

            case arg
            when Kumi::Syntax::InputElementReference
              dims_from_path(arg.path, input_metadata)
            when Kumi::Syntax::DeclarationReference
              return [] if visiting.include?(arg.name)

              decl = declarations[arg.name]
              return [] unless decl

              infer_scope_from_argument(decl.expression, declarations, input_metadata, visiting + [arg.name])
            when Kumi::Syntax::CallExpression
              deepest_scope(arg.args, declarations, input_metadata, visiting)
            when Kumi::Syntax::CascadeExpression
              # A cascade's scope is the deepest scope among its conditions and results.
              branches = arg.cases.flat_map { |case_expr| [case_expr.condition, case_expr.result] }
              deepest_scope(branches, declarations, input_metadata, visiting)
            else
              []
            end
          end

          # Largest (by dimension count) scope inferred from +exprs+, or [] when empty.
          def deepest_scope(exprs, declarations, input_metadata, visiting)
            scopes = exprs.map { |expr| infer_scope_from_argument(expr, declarations, input_metadata, visiting) }
            scopes.max_by(&:size) || []
          end

          # Walks +input_metadata+ along +path+, collecting (as Symbols) every
          # segment whose metadata has type :array. Each segment is looked up as
          # given, then as a Symbol, then as a String; the walk stops at the first
          # segment with no metadata.
          def dims_from_path(path, input_metadata)
            dims = []
            meta = input_metadata

            path.each do |seg|
              field = meta[seg] || meta[seg.to_sym] || meta[seg.to_s]
              break unless field

              dims << seg.to_sym if field[:type] == :array
              meta = field[:children] || {}
            end

            dims
          end

          # True when any vectorized operation's declaration expression references
          # the declaration +name+ (see declaration_references?).
          def needs_scalar_to_vector_broadcast?(name, _scope_plan, declarations, vectorized_ops)
            vectorized_ops.each_key.any? do |vec_name|
              vec_decl = declarations[vec_name]
              vec_decl && declaration_references?(vec_decl.expression, name)
            end
          end

          # Recursively checks whether +expr+ contains a DeclarationReference to
          # +target_name+, descending into call args, cascade conditions/results
          # and array elements.
          def declaration_references?(expr, target_name)
            case expr
            when Kumi::Syntax::DeclarationReference
              expr.name == target_name
            when Kumi::Syntax::CallExpression
              expr.args.any? { |arg| declaration_references?(arg, target_name) }
            when Kumi::Syntax::CascadeExpression
              expr.cases.any? do |case_expr|
                declaration_references?(case_expr.condition, target_name) ||
                  declaration_references?(case_expr.result, target_name)
              end
            when Kumi::Syntax::ArrayExpression
              expr.elements.any? { |elem| declaration_references?(elem, target_name) }
            else
              false
            end
          end

          # Debug helpers (stdout; only called when DEBUG_JOIN_REDUCE is set)
          def debug_reduction(name, info)
            puts "\n=== Processing reduction: #{name} ==="
            puts "Function: #{info[:function]}"
            puts "Argument: #{info[:argument].class}"
          end

          def debug_reduction_plan(name, plan)
            puts "Reduction plan for #{name}:"
            puts " Axis: #{plan.axis.inspect}"
            puts " Source scope: #{plan.source_scope.inspect}"
            puts " Result scope: #{plan.result_scope.inspect}"
          end

          def debug_join(name, info)
            puts "\n=== Processing join: #{name} ==="
            puts "Source: #{info[:source]}"
          end

          def debug_join_plan(name, plan)
            puts "Join plan for #{name}:"
            puts " Target scope: #{plan.target_scope.inspect}"
            puts " Policy: #{plan.policy}"
          end

          def debug_scalar_broadcast(name, scope_plan)
            puts "\n=== Processing scalar broadcast: #{name} ==="
            puts "Target scope: #{scope_plan.scope.inspect}"
          end
        end
      end
    end
  end
end