kumi 0.0.10 → 0.0.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +1 -1
- data/CHANGELOG.md +23 -0
- data/CLAUDE.md +7 -231
- data/README.md +5 -5
- data/docs/SYNTAX.md +66 -0
- data/docs/VECTOR_SEMANTICS.md +286 -0
- data/docs/features/hierarchical-broadcasting.md +67 -1
- data/docs/features/input-declaration-system.md +16 -0
- data/docs/features/s-expression-printer.md +2 -2
- data/lib/kumi/analyzer.rb +34 -12
- data/lib/kumi/compiler.rb +2 -12
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +157 -64
- data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +1 -1
- data/lib/kumi/core/analyzer/passes/input_access_planner_pass.rb +47 -0
- data/lib/kumi/core/analyzer/passes/input_collector.rb +123 -101
- data/lib/kumi/core/analyzer/passes/join_reduce_planning_pass.rb +293 -0
- data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +993 -0
- data/lib/kumi/core/analyzer/passes/pass_base.rb +2 -2
- data/lib/kumi/core/analyzer/passes/scope_resolution_pass.rb +346 -0
- data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +2 -1
- data/lib/kumi/core/analyzer/passes/toposorter.rb +9 -3
- data/lib/kumi/core/analyzer/passes/type_checker.rb +3 -3
- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +2 -2
- data/lib/kumi/core/analyzer/passes/{type_inferencer.rb → type_inferencer_pass.rb} +4 -4
- data/lib/kumi/core/analyzer/passes/unsat_detector.rb +2 -2
- data/lib/kumi/core/analyzer/plans.rb +52 -0
- data/lib/kumi/core/analyzer/structs/access_plan.rb +20 -0
- data/lib/kumi/core/analyzer/structs/input_meta.rb +29 -0
- data/lib/kumi/core/compiler/access_builder.rb +36 -0
- data/lib/kumi/core/compiler/access_planner.rb +219 -0
- data/lib/kumi/core/compiler/accessors/base.rb +69 -0
- data/lib/kumi/core/compiler/accessors/each_indexed_accessor.rb +84 -0
- data/lib/kumi/core/compiler/accessors/materialize_accessor.rb +55 -0
- data/lib/kumi/core/compiler/accessors/ravel_accessor.rb +73 -0
- data/lib/kumi/core/compiler/accessors/read_accessor.rb +41 -0
- data/lib/kumi/core/compiler_base.rb +2 -2
- data/lib/kumi/core/error_reporter.rb +6 -5
- data/lib/kumi/core/errors.rb +4 -0
- data/lib/kumi/core/explain.rb +157 -205
- data/lib/kumi/core/export/node_builders.rb +2 -2
- data/lib/kumi/core/export/node_serializers.rb +1 -1
- data/lib/kumi/core/function_registry/collection_functions.rb +21 -10
- data/lib/kumi/core/function_registry/conditional_functions.rb +14 -4
- data/lib/kumi/core/function_registry/function_builder.rb +142 -55
- data/lib/kumi/core/function_registry/logical_functions.rb +5 -5
- data/lib/kumi/core/function_registry/stat_functions.rb +2 -2
- data/lib/kumi/core/function_registry.rb +126 -108
- data/lib/kumi/core/input/validator.rb +1 -1
- data/lib/kumi/core/ir/execution_engine/combinators.rb +117 -0
- data/lib/kumi/core/ir/execution_engine/interpreter.rb +336 -0
- data/lib/kumi/core/ir/execution_engine/values.rb +46 -0
- data/lib/kumi/core/ir/execution_engine.rb +50 -0
- data/lib/kumi/core/ir.rb +58 -0
- data/lib/kumi/core/ruby_parser/build_context.rb +2 -2
- data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +0 -12
- data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +36 -15
- data/lib/kumi/core/ruby_parser/input_builder.rb +30 -9
- data/lib/kumi/core/ruby_parser/parser.rb +1 -1
- data/lib/kumi/core/ruby_parser/schema_builder.rb +2 -2
- data/lib/kumi/core/ruby_parser/sugar.rb +7 -0
- data/lib/kumi/core/types/validator.rb +1 -1
- data/lib/kumi/registry.rb +14 -79
- data/lib/kumi/runtime/executable.rb +213 -0
- data/lib/kumi/schema.rb +14 -3
- data/lib/kumi/schema_metadata.rb +2 -2
- data/lib/kumi/support/ir_dump.rb +491 -0
- data/lib/kumi/support/s_expression_printer.rb +1 -1
- data/lib/kumi/syntax/location.rb +5 -0
- data/lib/kumi/syntax/node.rb +0 -1
- data/lib/kumi/syntax/root.rb +2 -2
- data/lib/kumi/version.rb +1 -1
- data/lib/kumi.rb +6 -15
- metadata +37 -19
- data/lib/kumi/core/cascade_executor_builder.rb +0 -132
- data/lib/kumi/core/compiled_schema.rb +0 -43
- data/lib/kumi/core/compiler/expression_compiler.rb +0 -146
- data/lib/kumi/core/compiler/function_invoker.rb +0 -55
- data/lib/kumi/core/compiler/path_traversal_compiler.rb +0 -158
- data/lib/kumi/core/compiler/reference_compiler.rb +0 -46
- data/lib/kumi/core/evaluation_wrapper.rb +0 -40
- data/lib/kumi/core/nested_structure_utils.rb +0 -78
- data/lib/kumi/core/schema_instance.rb +0 -115
- data/lib/kumi/core/vectorized_function_builder.rb +0 -88
- data/lib/kumi/js/compiler.rb +0 -878
- data/lib/kumi/js/function_registry.rb +0 -333
- data/migrate_to_core_iterative.rb +0 -938
@@ -0,0 +1,293 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Kumi
  module Core
    module Analyzer
      module Passes
        # Plans join and reduce operations for declarations.
        # Determines reduction axes, flattening requirements, and join policies.
        #
        # DEPENDENCIES: :broadcasts, :scope_plans, :decl_shapes, :declarations, :input_metadata
        # PRODUCES: :join_reduce_plans
        #
        # NOTE(review): :decl_shapes is listed as a dependency but is never read
        # by this pass — confirm whether it is still required.
        #
        # Set the DEBUG_JOIN_REDUCE environment variable to print planning
        # decisions to stdout.
        class JoinReducePlanningPass < PassBase
          include Kumi::Core::Analyzer::Plans

          # Builds one plan per declaration name and stores the frozen
          # name => plan hash under :join_reduce_plans.
          #
          # Reductions are planned first. Join planning skips any name that
          # already has a plan, so a reduction plan always takes precedence.
          #
          # @param _errors [Object] unused by this pass (part of the pass interface)
          # @return the new state produced by `state.with` (PassBase helper,
          #   not visible here)
          def run(_errors)
            broadcasts = get_state(:broadcasts, required: false) || {}
            scope_plans = get_state(:scope_plans, required: false) || {}
            declarations = get_state(:declarations, required: true)
            input_metadata = get_state(:input_metadata, required: true)

            plans = {}
            process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)
            process_joins(broadcasts, scope_plans, declarations, plans)

            state.with(:join_reduce_plans, plans.freeze)
          end

          private

          # Adds a Reduce plan to +plans+ for every entry in
          # broadcasts[:reduction_operations].
          def process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)
            reduction_ops = broadcasts[:reduction_operations] || {}

            reduction_ops.each do |name, info|
              debug_reduction(name, info) if debug?

              source_scope = get_source_scope(name, info, scope_plans, declarations, input_metadata)
              axis, result_scope = determine_reduction_axis(source_scope, info, scope_plans, name)

              plan = Reduce.new(
                function: info[:function],
                axis: axis,
                source_scope: source_scope,
                result_scope: result_scope,
                flatten_args: determine_flatten_indices(info)
              )
              plans[name] = plan

              debug_reduction_plan(name, plan) if debug?
            end
          end

          # Adds Join plans for vectorized operations and for scalar
          # declarations that vectorized operations reference.
          def process_joins(broadcasts, scope_plans, declarations, plans)
            vectorized_ops = broadcasts[:vectorized_operations] || {}

            plan_vectorized_joins(vectorized_ops, scope_plans, declarations, plans)
            plan_scalar_broadcasts(vectorized_ops, scope_plans, declarations, plans)
          end

          # Adds a :zip Join for each vectorized op that has a scope plan, has
          # no plan yet, and satisfies requires_join?.
          def plan_vectorized_joins(vectorized_ops, scope_plans, declarations, plans)
            vectorized_ops.each do |name, info|
              next if plans.key?(name) # already planned as a reduction

              debug_join(name, info) if debug?

              scope_plan = scope_plans[name]
              next unless scope_plan
              next unless requires_join?(declarations[name], scope_plan)

              plan = Join.new(policy: :zip, target_scope: scope_plan.scope)
              plans[name] = plan

              debug_join_plan(name, plan) if debug?
            end
          end

          # Adds a :broadcast Join for each declaration that meets all of these:
          # it is not vectorized itself, its scope plan has a non-empty target
          # scope, and some vectorized op references it directly.
          def plan_scalar_broadcasts(vectorized_ops, scope_plans, declarations, plans)
            scope_plans.each do |name, scope_plan|
              next if plans.key?(name)
              next unless vectorized_scope?(scope_plan.scope)
              next if vectorized_ops.key?(name) # handled by plan_vectorized_joins
              next unless needs_scalar_to_vector_broadcast?(name, scope_plan, declarations, vectorized_ops)

              debug_scalar_broadcast(name, scope_plan) if debug?

              plan = Join.new(policy: :broadcast, target_scope: scope_plan.scope)
              plans[name] = plan

              debug_join_plan(name, plan) if debug?
            end
          end

          # Source scope of a reduction, inferred from its argument expression.
          # The name and scope_plans parameters are unused; they are kept for
          # call compatibility.
          def get_source_scope(_name, reduction_info, _scope_plans, declarations, input_metadata)
            infer_scope_from_argument(reduction_info[:argument], declarations, input_metadata)
          end

          # Returns [axis, result_scope] for a reduction, checking sources in
          # this order:
          #   1. An explicit reduction_info[:axis] (see compute_result_scope).
          #   2. A non-empty scope plan for +name+. Its scope is kept, and the
          #      remaining source dimensions become the axis.
          #   3. Otherwise, the innermost source dimension is reduced.
          # An empty source scope yields [[], []].
          def determine_reduction_axis(source_scope, reduction_info, scope_plans, name)
            return [[], []] if source_scope.empty?

            if (axis = reduction_info[:axis])
              return [axis, compute_result_scope(source_scope, axis)]
            end

            scope_plan = scope_plans[name]
            if vectorized_scope?(scope_plan&.scope)
              desired_result_scope = scope_plan.scope
              return [source_scope - desired_result_scope, desired_result_scope]
            end

            [[source_scope.last], source_scope[0...-1]]
          end

          # Removes the dimensions named by +axis+ from +source_scope+.
          # - :all removes every dimension.
          # - An Array removes the dimensions it lists.
          # - An Integer n removes the n innermost dimensions.
          # - Any other value leaves the scope unchanged.
          def compute_result_scope(source_scope, axis)
            case axis
            when :all
              []
            when Array
              source_scope - axis
            when Integer
              # Clamp so that 0 keeps the full scope (`source_scope[0...-0]`
              # would return []) and negative or oversized counts cannot wrap.
              drop = axis.clamp(0, source_scope.size)
              source_scope.first(source_scope.size - drop)
            else
              source_scope
            end
          end

          # Argument indices to flatten, taken from
          # reduction_info[:flatten_argument_indices] and always returned as an
          # Array.
          def determine_flatten_indices(reduction_info)
            Array(reduction_info[:flatten_argument_indices] || [])
          end

          # True when a declaration in a non-empty target scope needs join
          # planning. That covers multi-argument calls and all cascades.
          def requires_join?(declaration, scope_plan)
            return false unless declaration
            return false unless vectorized_scope?(scope_plan.scope)

            expr = declaration.expression
            case expr
            when Kumi::Syntax::CallExpression
              # Multiple arguments may live at different scopes and need aligning.
              expr.args.size > 1
            when Kumi::Syntax::CascadeExpression
              # Conditions and results may come from different scopes.
              true
            else
              false
            end
          end

          # Infers the array dimensions an expression iterates over:
          # - Input element references map to their array path segments.
          # - Declaration references resolve through the referenced expression.
          # - Calls take the deepest (longest) argument scope.
          # - Anything else is scalar ([]).
          #
          # NOTE(review): cascades and array literals fall into the scalar case.
          # Confirm that reductions never take them as arguments.
          def infer_scope_from_argument(arg, declarations, input_metadata)
            return [] unless arg

            case arg
            when Kumi::Syntax::InputElementReference
              dims_from_path(arg.path, input_metadata)
            when Kumi::Syntax::DeclarationReference
              decl = declarations[arg.name]
              decl ? infer_scope_from_argument(decl.expression, declarations, input_metadata) : []
            when Kumi::Syntax::CallExpression
              arg.args
                 .map { |a| infer_scope_from_argument(a, declarations, input_metadata) }
                 .max_by(&:size) || []
            else
              []
            end
          end

          # Walks +path+ through the nested input metadata. Each segment whose
          # field has type :array is collected as a Symbol. The walk stops at
          # the first segment with no metadata entry. Segments are looked up as
          # given, then as a Symbol, then as a String.
          def dims_from_path(path, input_metadata)
            dims = []
            meta = input_metadata

            path.each do |seg|
              field = meta[seg] || meta[seg.to_sym] || meta[seg.to_s]
              break unless field

              dims << seg.to_sym if field[:type] == :array
              meta = field[:children] || {}
            end

            dims
          end

          # True when any vectorized operation's expression references +name+
          # directly. Only direct references count; transitive ones are not
          # followed. The scope_plan parameter is unused and kept for call
          # compatibility.
          def needs_scalar_to_vector_broadcast?(name, _scope_plan, declarations, vectorized_ops)
            vectorized_ops.any? do |vec_name, _vec_info|
              vec_decl = declarations[vec_name]
              vec_decl ? declaration_references?(vec_decl.expression, name) : false
            end
          end

          # True when +expr+ contains a DeclarationReference to +target_name+.
          # The search recurses into call arguments, cascade conditions and
          # results, and array elements.
          def declaration_references?(expr, target_name)
            case expr
            when Kumi::Syntax::DeclarationReference
              expr.name == target_name
            when Kumi::Syntax::CallExpression
              expr.args.any? { |arg| declaration_references?(arg, target_name) }
            when Kumi::Syntax::CascadeExpression
              expr.cases.any? do |case_expr|
                declaration_references?(case_expr.condition, target_name) ||
                  declaration_references?(case_expr.result, target_name)
              end
            when Kumi::Syntax::ArrayExpression
              expr.elements.any? { |elem| declaration_references?(elem, target_name) }
            else
              false
            end
          end

          # True when +scope+ is present and has at least one dimension.
          def vectorized_scope?(scope)
            return false unless scope

            !scope.empty?
          end

          # True when the DEBUG_JOIN_REDUCE environment variable is set.
          def debug?
            ENV["DEBUG_JOIN_REDUCE"]
          end

          # Debug helpers (write to stdout)

          def debug_reduction(name, info)
            puts "\n=== Processing reduction: #{name} ==="
            puts "Function: #{info[:function]}"
            puts "Argument: #{info[:argument].class}"
          end

          def debug_reduction_plan(name, plan)
            puts "Reduction plan for #{name}:"
            puts "  Axis: #{plan.axis.inspect}"
            puts "  Source scope: #{plan.source_scope.inspect}"
            puts "  Result scope: #{plan.result_scope.inspect}"
          end

          def debug_join(name, info)
            puts "\n=== Processing join: #{name} ==="
            puts "Source: #{info[:source]}"
          end

          def debug_join_plan(name, plan)
            puts "Join plan for #{name}:"
            puts "  Target scope: #{plan.target_scope.inspect}"
            puts "  Policy: #{plan.policy}"
          end

          def debug_scalar_broadcast(name, scope_plan)
            puts "\n=== Processing scalar broadcast: #{name} ==="
            puts "Target scope: #{scope_plan.scope.inspect}"
          end
        end
      end
    end
  end
end
|