kumi 0.0.10 → 0.0.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/CLAUDE.md +7 -231
- data/README.md +1 -1
- data/docs/VECTOR_SEMANTICS.md +286 -0
- data/docs/features/hierarchical-broadcasting.md +1 -1
- data/docs/features/s-expression-printer.md +2 -2
- data/examples/deep_schema_compilation_and_evaluation_benchmark.rb +21 -15
- data/lib/kumi/analyzer.rb +34 -12
- data/lib/kumi/compiler.rb +2 -12
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +157 -64
- data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +1 -1
- data/lib/kumi/core/analyzer/passes/input_access_planner_pass.rb +47 -0
- data/lib/kumi/core/analyzer/passes/input_collector.rb +118 -101
- data/lib/kumi/core/analyzer/passes/join_reduce_planning_pass.rb +293 -0
- data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +993 -0
- data/lib/kumi/core/analyzer/passes/pass_base.rb +2 -2
- data/lib/kumi/core/analyzer/passes/scope_resolution_pass.rb +346 -0
- data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +2 -1
- data/lib/kumi/core/analyzer/passes/toposorter.rb +9 -3
- data/lib/kumi/core/analyzer/passes/type_checker.rb +3 -3
- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +2 -2
- data/lib/kumi/core/analyzer/passes/{type_inferencer.rb → type_inferencer_pass.rb} +4 -4
- data/lib/kumi/core/analyzer/passes/unsat_detector.rb +2 -2
- data/lib/kumi/core/analyzer/plans.rb +52 -0
- data/lib/kumi/core/analyzer/structs/access_plan.rb +20 -0
- data/lib/kumi/core/analyzer/structs/input_meta.rb +29 -0
- data/lib/kumi/core/compiler/access_builder.rb +36 -0
- data/lib/kumi/core/compiler/access_planner.rb +219 -0
- data/lib/kumi/core/compiler/accessors/base.rb +69 -0
- data/lib/kumi/core/compiler/accessors/each_indexed_accessor.rb +84 -0
- data/lib/kumi/core/compiler/accessors/materialize_accessor.rb +55 -0
- data/lib/kumi/core/compiler/accessors/ravel_accessor.rb +73 -0
- data/lib/kumi/core/compiler/accessors/read_accessor.rb +41 -0
- data/lib/kumi/core/compiler_base.rb +2 -2
- data/lib/kumi/core/error_reporter.rb +6 -5
- data/lib/kumi/core/errors.rb +4 -0
- data/lib/kumi/core/explain.rb +157 -205
- data/lib/kumi/core/export/node_builders.rb +2 -2
- data/lib/kumi/core/export/node_serializers.rb +1 -1
- data/lib/kumi/core/function_registry/collection_functions.rb +21 -10
- data/lib/kumi/core/function_registry/conditional_functions.rb +14 -4
- data/lib/kumi/core/function_registry/function_builder.rb +142 -55
- data/lib/kumi/core/function_registry/logical_functions.rb +5 -5
- data/lib/kumi/core/function_registry/stat_functions.rb +2 -2
- data/lib/kumi/core/function_registry.rb +126 -108
- data/lib/kumi/core/ir/execution_engine/combinators.rb +117 -0
- data/lib/kumi/core/ir/execution_engine/interpreter.rb +336 -0
- data/lib/kumi/core/ir/execution_engine/values.rb +46 -0
- data/lib/kumi/core/ir/execution_engine.rb +50 -0
- data/lib/kumi/core/ir.rb +58 -0
- data/lib/kumi/core/ruby_parser/build_context.rb +2 -2
- data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +0 -12
- data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +36 -15
- data/lib/kumi/core/ruby_parser/input_builder.rb +5 -5
- data/lib/kumi/core/ruby_parser/parser.rb +1 -1
- data/lib/kumi/core/ruby_parser/schema_builder.rb +2 -2
- data/lib/kumi/core/ruby_parser/sugar.rb +7 -0
- data/lib/kumi/registry.rb +14 -79
- data/lib/kumi/runtime/executable.rb +213 -0
- data/lib/kumi/schema.rb +14 -3
- data/lib/kumi/schema_metadata.rb +2 -2
- data/lib/kumi/support/ir_dump.rb +491 -0
- data/lib/kumi/support/s_expression_printer.rb +1 -1
- data/lib/kumi/syntax/location.rb +5 -0
- data/lib/kumi/syntax/node.rb +0 -1
- data/lib/kumi/syntax/root.rb +2 -2
- data/lib/kumi/version.rb +1 -1
- data/lib/kumi.rb +6 -15
- metadata +26 -15
- data/lib/kumi/core/cascade_executor_builder.rb +0 -132
- data/lib/kumi/core/compiled_schema.rb +0 -43
- data/lib/kumi/core/compiler/expression_compiler.rb +0 -146
- data/lib/kumi/core/compiler/function_invoker.rb +0 -55
- data/lib/kumi/core/compiler/path_traversal_compiler.rb +0 -158
- data/lib/kumi/core/compiler/reference_compiler.rb +0 -46
- data/lib/kumi/core/evaluation_wrapper.rb +0 -40
- data/lib/kumi/core/nested_structure_utils.rb +0 -78
- data/lib/kumi/core/schema_instance.rb +0 -115
- data/lib/kumi/core/vectorized_function_builder.rb +0 -88
- data/lib/kumi/js/compiler.rb +0 -878
- data/lib/kumi/js/function_registry.rb +0 -333
- data/migrate_to_core_iterative.rb +0 -938
@@ -4,135 +4,152 @@ module Kumi
|
|
4
4
|
module Core
|
5
5
|
module Analyzer
|
6
6
|
module Passes
# Emits per-node metadata:
#   :type, :domain
#   :container     => :scalar | :field | :array
#   :access_mode   => :field | :element     # how THIS node is read once reached
#   :enter_via     => :hash | :array        # how we HOP from parent to THIS node
#   :consume_alias => true | false          # inline array hop (alias is not a hash key)
#   :children      => { name => node_meta } # optional
#
# Invariants:
#   - Any nested array (child depth >= 1) must declare its element (i.e., have children).
#   - Depth-0 inputs always: enter_via :hash, consume_alias false; access_mode defaults
#     to :field unless the declaration sets one (array inputs use their access_mode to
#     decide how their children are entered).
class InputCollector < PassBase
  # Builds the metadata tree for every top-level input and stores it, deep-frozen,
  # in the analyzer state under :input_metadata.
  #
  # @param errors [Array] accumulator that report_error appends to
  # @return the new analyzer state produced by state.with
  def run(errors)
    input_meta = {}

    schema.inputs.each do |decl|
      name = decl.name
      input_meta[name] = collect_field_metadata(decl, errors, depth: 0, name: name)
    end

    # Nodes are mutated while edges are stamped, so freeze only once the whole
    # tree has been built.
    input_meta.each_value(&:deep_freeze!)
    state.with(:input_metadata, input_meta.freeze)
  end

  private

  # ---------- builders ----------

  # Recursively builds the metadata node for +decl+ (children first), then stamps
  # and validates the edges from this node to its children.
  #
  # @param decl declaration node exposing name/type/domain/access_mode/children
  # @param errors [Array] error accumulator
  # @param depth [Integer] nesting depth of +decl+ (0 for top-level inputs)
  # @param name [Symbol] declared name; not read here, kept for interface stability
  # @return [Structs::InputMeta] node whose children have already been stamped
  def collect_field_metadata(decl, errors, depth:, name:)
    children = nil
    if decl.children&.any?
      children = {}
      decl.children.each do |child|
        children[child.name] = collect_field_metadata(child, errors, depth: depth + 1, name: child.name)
      end
    end

    access_mode = decl.access_mode || :field

    meta = Structs::InputMeta.new(
      type: decl.type,
      domain: decl.domain,
      container: kind_from_type(decl.type),
      access_mode: access_mode,
      # Hash-hop defaults; the parent's stamp_edges_from! may override these.
      enter_via: :hash,
      consume_alias: false,
      children: children
    )
    stamp_edges_from!(meta, errors, parent_depth: depth)
    validate_access_modes!(meta, errors, parent_depth: depth)
    meta
  end

  # ---------- edge stamping + validation ----------
  #
  # Stamps each child's :enter_via, :consume_alias and :access_mode from the
  # parent's container kind, and reports nested arrays that do not declare
  # their element.
  #
  # Rules:
  #   - Common: any ARRAY child (child depth >= 1) must have children (no bare nested array).
  #   - Parent :field (object-like; :object is accepted as a synonym):
  #       child.enter_via = :hash; child.consume_alias = false; child.access_mode ||= :field
  #   - Parent :array, driven by the PARENT's own access_mode:
  #       * :field   -> array of objects: every child via :hash, consume_alias false, access_mode :field
  #       * :element -> the first child via :array, consume_alias true, access_mode :element;
  #                     validate_access_modes! reports it unless it is the single child and is
  #                     a scalar/array container
  #       * anything else raises
  #   - Parent :scalar: nothing is stamped.
  #
  # +parent_depth+ is currently unused.
  def stamp_edges_from!(parent_meta, errors, parent_depth:)
    kids = parent_meta.children || {}
    return if kids.empty?

    # Validate nested arrays anywhere below root
    kids.each do |kname, child|
      next unless child.container == :array
      next if child.children && !child.children.empty?

      report_error(errors, "Nested array at :#{kname} must declare its element", location: nil)
    end

    case parent_meta.container
    when :field, :object
      # kind_from_type reports object-like inputs as :field, so matching only
      # :object left this branch unreachable. Use ||= (the documented rule) so an
      # explicitly declared mode survives for validate_access_modes! to check.
      kids.each_value do |child|
        child.enter_via = :hash
        child.consume_alias = false
        child.access_mode ||= :field
      end
    when :array
      # Array parents MUST explicitly declare their access mode.
      # NOTE(review): collect_field_metadata defaults access_mode to :field, so this
      # guard only fires for metadata built some other way.
      access_mode = parent_meta.access_mode
      raise "Array must explicitly declare access_mode (:field or :element)" unless access_mode

      case access_mode
      when :field
        # Array of objects: all children are fields accessed via hash
        kids.each_value do |child|
          child.enter_via = :hash
          child.consume_alias = false
          child.access_mode = :field
        end
      when :element
        # Inline element hop. Only the first child is stamped; if there are several,
        # validate_access_modes! reports it because :element requires a single child.
        _name, only = kids.first
        only.enter_via = :array
        only.consume_alias = true
        only.access_mode = :element
      else
        raise "Invalid access_mode :#{access_mode} for array (must be :field or :element)"
      end
    end
  end

  # Enforce access_mode semantics are only used in valid contexts.
  #
  # Reports (does not raise) children whose access_mode is not :field/:element, and
  # :element modes that are not the single scalar/array child of an array parent.
  # +parent_depth+ is currently unused.
  def validate_access_modes!(parent_meta, errors, parent_depth:)
    kids = parent_meta.children || {}
    return if kids.empty?

    kids.each do |kname, child|
      mode = child.access_mode
      next unless mode

      unless %i[field element].include?(mode)
        report_error(errors, "Invalid access_mode for :#{kname}: #{mode.inspect}", location: nil)
        next
      end

      next unless mode == :element

      if parent_meta.container == :array
        single = (kids.size == 1)
        unless single && %i[scalar array].include?(child.container)
          report_error(errors, "access_mode :element only valid for single scalar/array element (at :#{kname})", location: nil)
        end
      else
        report_error(errors, "access_mode :element only valid under array parent (at :#{kname})", location: nil)
      end
    end
  end

  # Maps a declared input type to its container kind: :array and :field pass
  # through unchanged; every other type is treated as a :scalar leaf.
  def kind_from_type(t)
    case t
    when :array then :array
    when :field then :field
    else :scalar
    end
  end
end
|
138
155
|
end
|
@@ -0,0 +1,293 @@
# frozen_string_literal: true

module Kumi
  module Core
    module Analyzer
      module Passes
        # Plans join and reduce operations for declarations.
        # Determines reduction axes, flattening requirements, and join policies.
        #
        # DEPENDENCIES: :broadcasts, :scope_plans, :decl_shapes, :declarations, :input_metadata
        #   (:decl_shapes is listed but not currently read by this pass)
        # PRODUCES: :join_reduce_plans
        class JoinReducePlanningPass < PassBase
          include Kumi::Core::Analyzer::Plans

          # Builds a Reduce plan for every reduction operation and a Join plan for
          # vectorized (and broadcast-scalar) declarations, keyed by declaration name.
          #
          # @return the new analyzer state with :join_reduce_plans set (frozen hash)
          def run(_errors)
            broadcasts = get_state(:broadcasts, required: false) || {}
            scope_plans = get_state(:scope_plans, required: false) || {}
            declarations = get_state(:declarations, required: true)
            input_metadata = get_state(:input_metadata, required: true)

            plans = {}

            # Process reduction operations
            process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)

            # Process join operations (for non-reduction vectorized operations)
            process_joins(broadcasts, scope_plans, declarations, plans)

            # Return new state with join/reduce plans
            state.with(:join_reduce_plans, plans.freeze)
          end

          private

          # Adds a Reduce plan to +plans+ for each entry of broadcasts[:reduction_operations].
          def process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)
            reduction_ops = broadcasts[:reduction_operations] || {}

            reduction_ops.each do |name, info|
              debug_reduction(name, info) if ENV["DEBUG_JOIN_REDUCE"]

              # Get the source scope from scope_plans or infer from argument
              source_scope = get_source_scope(name, info, scope_plans, declarations, input_metadata)

              # Determine reduction axis and result scope
              axis, result_scope = determine_reduction_axis(source_scope, info, scope_plans, name)

              # Check for flattening requirements
              flatten_indices = determine_flatten_indices(info)

              plan = Reduce.new(
                function: info[:function],
                axis: axis,
                source_scope: source_scope,
                result_scope: result_scope,
                flatten_args: flatten_indices
              )

              plans[name] = plan

              debug_reduction_plan(name, plan) if ENV["DEBUG_JOIN_REDUCE"]
            end
          end

          # Adds Join plans to +plans+: :zip for vectorized operations that need a
          # join, and :broadcast for scalar declarations referenced by vectorized ones.
          # Names already present in +plans+ (e.g. reductions) are never overwritten.
          def process_joins(broadcasts, scope_plans, declarations, plans)
            vectorized_ops = broadcasts[:vectorized_operations] || {}

            # Process vectorized operations
            vectorized_ops.each do |name, info|
              # Skip if already processed as reduction
              next if plans.key?(name)

              debug_join(name, info) if ENV["DEBUG_JOIN_REDUCE"]

              scope_plan = scope_plans[name]
              next unless scope_plan

              # Only need join planning if multiple arguments at different scopes
              next unless requires_join?(declarations[name], scope_plan)

              plan = Join.new(
                policy: :zip, # Default to zip for array operations
                target_scope: scope_plan.scope
              )

              plans[name] = plan

              debug_join_plan(name, plan) if ENV["DEBUG_JOIN_REDUCE"]
            end

            # Process scalar declarations that need broadcasting to vectorized scopes
            # (These are referenced by vectorized cascades but aren't vectorized themselves)
            scope_plans.each do |name, scope_plan|
              # Skip if already processed
              next if plans.key?(name)

              # Skip if no vectorized target scope
              next unless scope_plan.scope && !scope_plan.scope.empty?

              # Skip if already vectorized (handled above)
              next if vectorized_ops.key?(name)

              # Check if this scalar declaration needs broadcasting
              next unless needs_scalar_to_vector_broadcast?(name, scope_plan, declarations, vectorized_ops)

              debug_scalar_broadcast(name, scope_plan) if ENV["DEBUG_JOIN_REDUCE"]

              plan = Join.new(
                policy: :broadcast, # Use broadcast policy for scalar-to-vector
                target_scope: scope_plan.scope
              )

              plans[name] = plan

              debug_join_plan(name, plan) if ENV["DEBUG_JOIN_REDUCE"]
            end
          end

          # Returns the full dimensional scope of the reduction's argument.
          # +name+ and +scope_plans+ are accepted but not currently used.
          def get_source_scope(name, reduction_info, scope_plans, declarations, input_metadata)
            # Always infer from the reduction argument - this is the full dimensional scope
            infer_scope_from_argument(reduction_info[:argument], declarations, input_metadata)
          end

          # Picks the reduced dimensions and the surviving dimensions.
          #
          # Precedence: explicit reduction_info[:axis], then the declaration's scope
          # plan (dimensions it wants to keep), then "reduce the innermost dimension".
          #
          # @return [Array(Object, Array<Symbol>)] [axis, result_scope]
          def determine_reduction_axis(source_scope, reduction_info, scope_plans, name)
            return [[], []] if source_scope.empty?

            # Check if explicit axis is specified
            if reduction_info[:axis]
              axis = reduction_info[:axis]
              result_scope = compute_result_scope(source_scope, axis)
              return [axis, result_scope]
            end

            # Check if there's a scope plan that specifies what to preserve (result_scope)
            scope_plan = scope_plans[name]
            if scope_plan&.scope && !scope_plan.scope.empty?
              desired_result_scope = scope_plan.scope
              # Compute axis by removing the desired result dimensions
              axis = source_scope - desired_result_scope
              return [axis, desired_result_scope]
            end

            # Default: reduce innermost dimension (partial reduction)
            axis = [source_scope.last]
            result_scope = source_scope[0...-1]

            [axis, result_scope]
          end

          # Removes the reduced dimensions from +source_scope+.
          #
          # @param source_scope [Array<Symbol>] dimensions of the reduced argument, outermost first
          # @param axis [:all, Array<Symbol>, Integer, Object] explicit reduction axis
          # @return [Array<Symbol>] dimensions that survive the reduction
          def compute_result_scope(source_scope, axis)
            case axis
            when :all
              []
            when Array
              source_scope - axis
            when Integer
              # Numeric axis: remove that many innermost dimensions. The count is
              # clamped to 0..rank: `source_scope[0...-0]` would wrongly drop every
              # dimension for axis 0, and negative counts have no meaning here.
              keep = source_scope.size - axis.clamp(0, source_scope.size)
              source_scope.first(keep)
            else
              source_scope
            end
          end

          # Argument indices that must be flattened before reducing (possibly empty).
          def determine_flatten_indices(reduction_info)
            # Check for explicit flatten requirements
            flatten = reduction_info[:flatten_argument_indices] || []
            Array(flatten)
          end

          # True when a vectorized declaration with a non-empty target scope is a
          # multi-argument call or a cascade.
          def requires_join?(declaration, scope_plan)
            return false unless declaration
            return false unless scope_plan.scope && !scope_plan.scope.empty?

            case declaration.expression
            when Kumi::Syntax::CallExpression
              # Multiple arguments suggest potential join requirement
              declaration.expression.args.size > 1
            when Kumi::Syntax::CascadeExpression
              # Cascades with vectorized target scope need join planning
              # to handle cross-scope conditions and results
              true
            else
              false
            end
          end

          # Infers the array dimensions an expression ranges over.
          #
          # Input references map to the array segments of their path; declaration
          # references are resolved recursively; calls take the deepest argument scope.
          # Anything else is treated as scalar ([]).
          #
          # @param memo [Hash{Symbol=>Array<Symbol>}] per-call cache of resolved
          #   declaration scopes. It also breaks reference cycles, which previously
          #   recursed until the stack overflowed.
          def infer_scope_from_argument(arg, declarations, input_metadata, memo = {})
            return [] unless arg

            case arg
            when Kumi::Syntax::InputElementReference
              dims_from_path(arg.path, input_metadata)
            when Kumi::Syntax::DeclarationReference
              ref_name = arg.name
              return memo[ref_name] if memo.key?(ref_name)

              # Seed before recursing so a cycle terminates with a scalar scope.
              memo[ref_name] = []
              decl = declarations[ref_name]
              memo[ref_name] =
                decl ? infer_scope_from_argument(decl.expression, declarations, input_metadata, memo) : []
            when Kumi::Syntax::CallExpression
              # For calls, use the deepest scope from arguments
              scopes = arg.args.map { |a| infer_scope_from_argument(a, declarations, input_metadata, memo) }
              scopes.max_by(&:size) || []
            else
              []
            end
          end

          # Walks +path+ through the input metadata tree, collecting (as symbols) the
          # segments whose node has type :array. Stops at the first unknown segment.
          def dims_from_path(path, input_metadata)
            dims = []
            meta = input_metadata

            path.each do |seg|
              field = meta[seg] || meta[seg.to_sym] || meta[seg.to_s]
              break unless field

              dims << seg.to_sym if field[:type] == :array

              meta = field[:children] || {}
            end

            dims
          end

          # True when any vectorized operation's declaration references +name+.
          # +scope_plan+ is accepted but not currently used.
          def needs_scalar_to_vector_broadcast?(name, scope_plan, declarations, vectorized_ops)
            vectorized_ops.any? do |vec_name, _vec_info|
              vec_decl = declarations[vec_name]
              vec_decl && declaration_references?(vec_decl.expression, name)
            end
          end

          # True when +expr+ (searched through calls, cascade conditions/results and
          # array literals) contains a DeclarationReference to +target_name+.
          def declaration_references?(expr, target_name)
            case expr
            when Kumi::Syntax::DeclarationReference
              expr.name == target_name
            when Kumi::Syntax::CallExpression
              expr.args.any? { |arg| declaration_references?(arg, target_name) }
            when Kumi::Syntax::CascadeExpression
              expr.cases.any? do |case_expr|
                declaration_references?(case_expr.condition, target_name) ||
                  declaration_references?(case_expr.result, target_name)
              end
            when Kumi::Syntax::ArrayExpression
              expr.elements.any? { |elem| declaration_references?(elem, target_name) }
            else
              false
            end
          end

          # Debug helpers (enabled with ENV["DEBUG_JOIN_REDUCE"])
          def debug_reduction(name, info)
            puts "\n=== Processing reduction: #{name} ==="
            puts "Function: #{info[:function]}"
            puts "Argument: #{info[:argument].class}"
          end

          def debug_reduction_plan(name, plan)
            puts "Reduction plan for #{name}:"
            puts "  Axis: #{plan.axis.inspect}"
            puts "  Source scope: #{plan.source_scope.inspect}"
            puts "  Result scope: #{plan.result_scope.inspect}"
          end

          def debug_join(name, info)
            puts "\n=== Processing join: #{name} ==="
            puts "Source: #{info[:source]}"
          end

          def debug_join_plan(name, plan)
            puts "Join plan for #{name}:"
            puts "  Target scope: #{plan.target_scope.inspect}"
            puts "  Policy: #{plan.policy}"
          end

          def debug_scalar_broadcast(name, scope_plan)
            puts "\n=== Processing scalar broadcast: #{name} ==="
            puts "Target scope: #{scope_plan.scope.inspect}"
          end
        end
      end
    end
  end
end