kumi 0.0.9 → 0.0.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/CLAUDE.md +18 -258
  4. data/README.md +188 -121
  5. data/docs/AST.md +1 -1
  6. data/docs/FUNCTIONS.md +52 -8
  7. data/docs/VECTOR_SEMANTICS.md +286 -0
  8. data/docs/compiler_design_principles.md +86 -0
  9. data/docs/features/README.md +15 -2
  10. data/docs/features/hierarchical-broadcasting.md +349 -0
  11. data/docs/features/javascript-transpiler.md +148 -0
  12. data/docs/features/performance.md +1 -3
  13. data/docs/features/s-expression-printer.md +2 -2
  14. data/docs/schema_metadata.md +7 -7
  15. data/examples/deep_schema_compilation_and_evaluation_benchmark.rb +21 -15
  16. data/examples/game_of_life.rb +2 -4
  17. data/lib/kumi/analyzer.rb +34 -14
  18. data/lib/kumi/compiler.rb +4 -283
  19. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +717 -66
  20. data/lib/kumi/core/analyzer/passes/dependency_resolver.rb +1 -1
  21. data/lib/kumi/core/analyzer/passes/input_access_planner_pass.rb +47 -0
  22. data/lib/kumi/core/analyzer/passes/input_collector.rb +118 -99
  23. data/lib/kumi/core/analyzer/passes/join_reduce_planning_pass.rb +293 -0
  24. data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +993 -0
  25. data/lib/kumi/core/analyzer/passes/pass_base.rb +2 -2
  26. data/lib/kumi/core/analyzer/passes/scope_resolution_pass.rb +346 -0
  27. data/lib/kumi/core/analyzer/passes/semantic_constraint_validator.rb +28 -0
  28. data/lib/kumi/core/analyzer/passes/toposorter.rb +9 -3
  29. data/lib/kumi/core/analyzer/passes/type_checker.rb +9 -5
  30. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +2 -2
  31. data/lib/kumi/core/analyzer/passes/{type_inferencer.rb → type_inferencer_pass.rb} +4 -4
  32. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +92 -48
  33. data/lib/kumi/core/analyzer/plans.rb +52 -0
  34. data/lib/kumi/core/analyzer/structs/access_plan.rb +20 -0
  35. data/lib/kumi/core/analyzer/structs/input_meta.rb +29 -0
  36. data/lib/kumi/core/compiler/access_builder.rb +36 -0
  37. data/lib/kumi/core/compiler/access_planner.rb +219 -0
  38. data/lib/kumi/core/compiler/accessors/base.rb +69 -0
  39. data/lib/kumi/core/compiler/accessors/each_indexed_accessor.rb +84 -0
  40. data/lib/kumi/core/compiler/accessors/materialize_accessor.rb +55 -0
  41. data/lib/kumi/core/compiler/accessors/ravel_accessor.rb +73 -0
  42. data/lib/kumi/core/compiler/accessors/read_accessor.rb +41 -0
  43. data/lib/kumi/core/compiler_base.rb +137 -0
  44. data/lib/kumi/core/error_reporter.rb +6 -5
  45. data/lib/kumi/core/errors.rb +4 -0
  46. data/lib/kumi/core/explain.rb +157 -205
  47. data/lib/kumi/core/export/node_builders.rb +2 -2
  48. data/lib/kumi/core/export/node_serializers.rb +1 -1
  49. data/lib/kumi/core/function_registry/collection_functions.rb +100 -6
  50. data/lib/kumi/core/function_registry/conditional_functions.rb +14 -4
  51. data/lib/kumi/core/function_registry/function_builder.rb +142 -53
  52. data/lib/kumi/core/function_registry/logical_functions.rb +173 -3
  53. data/lib/kumi/core/function_registry/stat_functions.rb +156 -0
  54. data/lib/kumi/core/function_registry.rb +138 -98
  55. data/lib/kumi/core/ir/execution_engine/combinators.rb +117 -0
  56. data/lib/kumi/core/ir/execution_engine/interpreter.rb +336 -0
  57. data/lib/kumi/core/ir/execution_engine/values.rb +46 -0
  58. data/lib/kumi/core/ir/execution_engine.rb +50 -0
  59. data/lib/kumi/core/ir.rb +58 -0
  60. data/lib/kumi/core/ruby_parser/build_context.rb +2 -2
  61. data/lib/kumi/core/ruby_parser/declaration_reference_proxy.rb +0 -12
  62. data/lib/kumi/core/ruby_parser/dsl_cascade_builder.rb +37 -16
  63. data/lib/kumi/core/ruby_parser/input_builder.rb +61 -8
  64. data/lib/kumi/core/ruby_parser/parser.rb +1 -1
  65. data/lib/kumi/core/ruby_parser/schema_builder.rb +2 -2
  66. data/lib/kumi/core/ruby_parser/sugar.rb +7 -0
  67. data/lib/kumi/errors.rb +2 -0
  68. data/lib/kumi/js.rb +23 -0
  69. data/lib/kumi/registry.rb +17 -22
  70. data/lib/kumi/runtime/executable.rb +213 -0
  71. data/lib/kumi/schema.rb +15 -4
  72. data/lib/kumi/schema_metadata.rb +2 -2
  73. data/lib/kumi/support/ir_dump.rb +491 -0
  74. data/lib/kumi/support/s_expression_printer.rb +17 -16
  75. data/lib/kumi/syntax/array_expression.rb +6 -6
  76. data/lib/kumi/syntax/call_expression.rb +4 -4
  77. data/lib/kumi/syntax/cascade_expression.rb +4 -4
  78. data/lib/kumi/syntax/case_expression.rb +4 -4
  79. data/lib/kumi/syntax/declaration_reference.rb +4 -4
  80. data/lib/kumi/syntax/hash_expression.rb +4 -4
  81. data/lib/kumi/syntax/input_declaration.rb +6 -5
  82. data/lib/kumi/syntax/input_element_reference.rb +5 -5
  83. data/lib/kumi/syntax/input_reference.rb +5 -5
  84. data/lib/kumi/syntax/literal.rb +4 -4
  85. data/lib/kumi/syntax/location.rb +5 -0
  86. data/lib/kumi/syntax/node.rb +33 -34
  87. data/lib/kumi/syntax/root.rb +6 -6
  88. data/lib/kumi/syntax/trait_declaration.rb +4 -4
  89. data/lib/kumi/syntax/value_declaration.rb +4 -4
  90. data/lib/kumi/version.rb +1 -1
  91. data/lib/kumi.rb +6 -15
  92. data/scripts/analyze_broadcast_methods.rb +68 -0
  93. data/scripts/analyze_cascade_methods.rb +74 -0
  94. data/scripts/check_broadcasting_coverage.rb +51 -0
  95. data/scripts/find_dead_code.rb +114 -0
  96. metadata +36 -9
  97. data/docs/features/array-broadcasting.md +0 -170
  98. data/lib/kumi/cli.rb +0 -449
  99. data/lib/kumi/core/compiled_schema.rb +0 -43
  100. data/lib/kumi/core/evaluation_wrapper.rb +0 -40
  101. data/lib/kumi/core/schema_instance.rb +0 -111
  102. data/lib/kumi/core/vectorization_metadata.rb +0 -110
  103. data/migrate_to_core_iterative.rb +0 -938
@@ -26,7 +26,7 @@ module Kumi
26
26
 
27
27
  def run(errors)
28
28
  definitions = get_state(:declarations)
29
- input_meta = get_state(:inputs)
29
+ input_meta = get_state(:input_metadata)
30
30
 
31
31
  dependency_graph = Hash.new { |h, k| h[k] = [] }
32
32
  reverse_dependencies = Hash.new { |h, k| h[k] = [] }
@@ -0,0 +1,47 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Kumi
4
+ module Core
5
+ module Analyzer
6
+ module Passes
7
# Analyzer pass that turns :input_metadata into per-path access plans.
#
# PRODUCES: :access_plans - frozen Hash of input path => list of plans
class InputAccessPlannerPass < PassBase
  # Plans access for every input path and records the result in the state.
  #
  # @param errors [Array] error accumulator; structural plan problems are appended here
  # @return new analyzer state carrying :access_plans
  def run(errors)
    metadata = get_state(:input_metadata)

    # TODO: allow policies per input definition, or at least a general policy definition.
    planner_opts = {
      on_missing: :error,
      key_policy: :indifferent
    }

    access_plans = Kumi::Core::Compiler::AccessPlanner.plan(metadata, planner_opts)
    validate_plans!(access_plans, errors)

    state.with(:access_plans, access_plans.freeze)
  end

  private

  # Sanity-checks the planner output: every path needs at least one plan, and each
  # plan must expose an Array under :operations and a Symbol under :mode.
  def validate_plans!(plans, errors)
    plans.each do |path, plan_list|
      if plan_list.nil? || plan_list.empty?
        add_error(errors, nil, "No access plans generated for path: #{path}")
      end
      next if plan_list.nil?

      plan_list.each { |plan| check_plan_shape(path, plan, errors) }
    end
  end

  # Reports a malformed :operations or :mode entry on a single plan.
  def check_plan_shape(path, plan, errors)
    operations = plan[:operations]
    unless operations.is_a?(Array)
      add_error(errors, nil, "Invalid operations for path #{path}: expected Array, got #{operations.class}")
    end

    mode = plan[:mode]
    return if mode.is_a?(Symbol)

    add_error(errors, nil, "Invalid mode for path #{path}: expected Symbol, got #{mode.class}")
  end
end
44
+ end
45
+ end
46
+ end
47
+ end
@@ -4,133 +4,152 @@ module Kumi
4
4
  module Core
5
5
  module Analyzer
6
6
  module Passes
7
# Collects per-node metadata from input declarations.
#
# Emits one Structs::InputMeta per node with:
#   :type, :domain
#   :container     => :scalar | :field | :array
#   :access_mode   => :field | :element     # how THIS node is read once reached
#   :enter_via     => :hash | :array        # how we HOP from parent to THIS node
#   :consume_alias => true | false          # inline array hop (alias is not a hash key)
#   :children      => { name => node_meta } # nil when no children are declared
#
# Invariants:
#   - Any nested array (child depth >= 1) must declare its element (i.e., have children).
#   - Depth-0 inputs always: enter_via :hash, consume_alias false; their access_mode is
#     the declared one, defaulting to :field.
#
# PRODUCES: :input_metadata - frozen Hash of input name => deep-frozen InputMeta
class InputCollector < PassBase
  def run(errors)
    input_meta = {}

    schema.inputs.each do |decl|
      name = decl.name
      # NOTE(review): a later top-level declaration with the same name replaces the
      # earlier one here; confirm duplicates are rejected or merged upstream.
      input_meta[name] = collect_field_metadata(decl, errors, depth: 0, name: name)
    end

    input_meta.each_value(&:deep_freeze!)
    state.with(:input_metadata, input_meta.freeze)
  end

  private

  # ---------- builders ----------

  # Builds the InputMeta for +decl+ and, recursively, its children; then stamps the
  # parent->child edges and validates the children's access modes.
  #
  # @param decl input declaration node (name, type, domain, access_mode, children)
  # @param errors [Array] error accumulator
  # @param depth [Integer] nesting depth of +decl+ (0 for top-level inputs)
  # @param name [Symbol] declaration name (not otherwise used here)
  # @return [Structs::InputMeta]
  def collect_field_metadata(decl, errors, depth:, name:)
    children = nil
    if decl.children&.any?
      children = decl.children.each_with_object({}) do |child, acc|
        acc[child.name] = collect_field_metadata(child, errors, depth: depth + 1, name: child.name)
      end
    end

    meta = Structs::InputMeta.new(
      type: decl.type,
      domain: decl.domain,
      container: kind_from_type(decl.type),
      access_mode: decl.access_mode || :field,
      enter_via: :hash,
      consume_alias: false,
      children: children
    )
    stamp_edges_from!(meta, errors, parent_depth: depth)
    validate_access_modes!(meta, errors, parent_depth: depth)
    meta
  end

  # ---------- edge stamping + validation ----------
  #
  # Stamps each child's :enter_via, :consume_alias and :access_mode according to the
  # parent's container, and reports nested arrays that do not declare their element.
  #
  # Rules:
  #   - Common: any ARRAY child must have children (no bare nested array).
  #   - Parent :field (object/hash; container values come from kind_from_type):
  #     every child -> via :hash, consume_alias false, access_mode kept (it already
  #     defaults to :field), so validate_access_modes! can still reject a misplaced
  #     :element.
  #   - Parent :array with access_mode :field (array of objects): every child ->
  #     via :hash, consume_alias false, access_mode :field.
  #   - Parent :array with access_mode :element: the first (expected only) child ->
  #     via :array, consume_alias true, access_mode :element. validate_access_modes!
  #     rejects this unless that child is the sole child and a scalar or array.
  #   - Any other container: children keep the defaults set at construction.
  #
  # @raise [RuntimeError] if an array parent has no access_mode or an unknown one
  def stamp_edges_from!(parent_meta, errors, parent_depth:)
    kids = parent_meta.children || {}
    return if kids.empty?

    # Validate nested arrays anywhere below root
    kids.each do |kname, child|
      next unless child.container == :array
      next if child.children && !child.children.empty?

      report_error(errors, "Nested array at :#{kname} must declare its element", location: nil)
    end

    case parent_meta.container
    when :field, :object
      kids.each_value do |child|
        child.enter_via = :hash
        child.consume_alias = false
        child.access_mode ||= :field
      end

    when :array
      access_mode = parent_meta.access_mode
      # collect_field_metadata defaults access_mode to :field, so this guard only trips
      # for metadata that was built without going through it.
      raise "Array must explicitly declare access_mode (:field or :element)" unless access_mode

      case access_mode
      when :field
        # Array of objects: all children are fields accessed via hash
        kids.each_value do |child|
          child.enter_via = :hash
          child.consume_alias = false
          child.access_mode = :field
        end
      when :element
        _only_name, only = kids.first
        only.enter_via = :array
        only.consume_alias = true
        only.access_mode = :element
      else
        raise "Invalid access_mode :#{access_mode} for array (must be :field or :element)"
      end
    end
  end

  # Reports children whose access_mode is unknown, or that use :element outside its
  # only valid context (the sole scalar/array child of an array parent).
  def validate_access_modes!(parent_meta, errors, parent_depth:)
    kids = parent_meta.children || {}
    return if kids.empty?

    kids.each do |kname, child|
      mode = child.access_mode
      next unless mode

      unless %i[field element].include?(mode)
        report_error(errors, "Invalid access_mode for :#{kname}: #{mode.inspect}", location: nil)
        next
      end
      next unless mode == :element

      if parent_meta.container != :array
        report_error(errors, "access_mode :element only valid under array parent (at :#{kname})", location: nil)
      elsif kids.size != 1 || !%i[scalar array].include?(child.container)
        report_error(errors, "access_mode :element only valid for single scalar/array element (at :#{kname})", location: nil)
      end
    end
  end

  # Maps a declared type to its container kind: :array and :field map to themselves,
  # every other type is :scalar.
  def kind_from_type(t)
    case t
    when :array, :field then t
    else :scalar
    end
  end
end
136
155
  end
@@ -0,0 +1,293 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Kumi
4
+ module Core
5
+ module Analyzer
6
+ module Passes
7
# Plans join and reduce operations for declarations.
# Determines reduction axes, flattening requirements, and join policies.
#
# DEPENDENCIES: :declarations, :input_metadata (required);
#               :broadcasts, :scope_plans (optional, treated as {} when absent)
# PRODUCES: :join_reduce_plans - frozen Hash of declaration name => Reduce or Join plan
#
# Setting ENV["DEBUG_JOIN_REDUCE"] prints each planning decision to stdout.
class JoinReducePlanningPass < PassBase
  include Kumi::Core::Analyzer::Plans

  # @param _errors [Array] unused; this pass reports no errors
  # @return new analyzer state carrying :join_reduce_plans
  def run(_errors)
    broadcasts = get_state(:broadcasts, required: false) || {}
    scope_plans = get_state(:scope_plans, required: false) || {}
    declarations = get_state(:declarations, required: true)
    input_metadata = get_state(:input_metadata, required: true)

    plans = {}

    # Reductions are planned first; process_joins skips any name already in +plans+.
    process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)

    # Joins for non-reduction vectorized operations and for broadcast scalars.
    process_joins(broadcasts, scope_plans, declarations, plans)

    state.with(:join_reduce_plans, plans.freeze)
  end

  private

  # True when join/reduce debug tracing is enabled.
  def debug_join_reduce?
    ENV.key?("DEBUG_JOIN_REDUCE")
  end

  # Adds a Reduce plan to +plans+ for every entry in broadcasts[:reduction_operations].
  def process_reductions(broadcasts, scope_plans, declarations, input_metadata, plans)
    reduction_ops = broadcasts[:reduction_operations] || {}

    reduction_ops.each do |name, info|
      debug_reduction(name, info) if debug_join_reduce?

      source_scope = get_source_scope(name, info, scope_plans, declarations, input_metadata)
      axis, result_scope = determine_reduction_axis(source_scope, info, scope_plans, name)

      plan = Reduce.new(
        function: info[:function],
        axis: axis,
        source_scope: source_scope,
        result_scope: result_scope,
        flatten_args: determine_flatten_indices(info)
      )
      plans[name] = plan

      debug_reduction_plan(name, plan) if debug_join_reduce?
    end
  end

  # Adds Join plans to +plans+:
  #   - policy :zip for vectorized, non-reduction declarations whose expression needs a
  #     join (see requires_join?);
  #   - policy :broadcast for non-vectorized declarations with a non-empty target scope
  #     that some vectorized declaration references.
  def process_joins(broadcasts, scope_plans, declarations, plans)
    vectorized_ops = broadcasts[:vectorized_operations] || {}

    vectorized_ops.each do |name, info|
      next if plans.key?(name) # already planned as a reduction

      debug_join(name, info) if debug_join_reduce?

      scope_plan = scope_plans[name]
      next unless scope_plan
      next unless requires_join?(declarations[name], scope_plan)

      plan = Join.new(policy: :zip, target_scope: scope_plan.scope)
      plans[name] = plan

      debug_join_plan(name, plan) if debug_join_reduce?
    end

    # Scalar declarations referenced by vectorized operations (e.g. cascades) need to be
    # broadcast to the vectorized scope even though they are not vectorized themselves.
    scope_plans.each do |name, scope_plan|
      next if plans.key?(name)
      next unless scope_plan.scope && !scope_plan.scope.empty?
      next if vectorized_ops.key?(name)
      next unless needs_scalar_to_vector_broadcast?(name, scope_plan, declarations, vectorized_ops)

      debug_scalar_broadcast(name, scope_plan) if debug_join_reduce?

      plan = Join.new(policy: :broadcast, target_scope: scope_plan.scope)
      plans[name] = plan

      debug_join_plan(name, plan) if debug_join_reduce?
    end
  end

  # Source scope of a reduction: the full dimensional scope inferred from its argument.
  # +_name+ and +_scope_plans+ are accepted for signature stability but not consulted.
  def get_source_scope(_name, reduction_info, _scope_plans, declarations, input_metadata)
    infer_scope_from_argument(reduction_info[:argument], declarations, input_metadata)
  end

  # Chooses [axis, result_scope] for a reduction, in priority order:
  #   1. an explicit reduction_info[:axis];
  #   2. the non-empty scope from scope_plans[name] as the scope to preserve;
  #   3. otherwise reduce only the innermost dimension.
  def determine_reduction_axis(source_scope, reduction_info, scope_plans, name)
    return [[], []] if source_scope.empty?

    if reduction_info[:axis]
      axis = reduction_info[:axis]
      return [axis, compute_result_scope(source_scope, axis)]
    end

    scope_plan = scope_plans[name]
    if scope_plan&.scope && !scope_plan.scope.empty?
      desired_result_scope = scope_plan.scope
      return [source_scope - desired_result_scope, desired_result_scope]
    end

    [[source_scope.last], source_scope[0...-1]]
  end

  # Removes the reduced dimensions from +source_scope+.
  #
  # @param axis [:all, Array<Symbol>, Integer]
  #   :all drops every dimension; an Array drops the named dimensions; an Integer drops
  #   that many innermost dimensions (0 keeps all, values >= the depth drop all,
  #   negative values keep all). Any other value leaves +source_scope+ unchanged.
  # @return [Array<Symbol>] the result scope
  def compute_result_scope(source_scope, axis)
    case axis
    when :all
      []
    when Array
      source_scope - axis
    when Integer
      # source_scope[0...-axis] is wrong for axis 0 (it yields []), so keep the
      # leading (depth - axis) dimensions explicitly, clamped at zero.
      source_scope.first([source_scope.size - axis, 0].max)
    else
      source_scope
    end
  end

  # Argument indices that must be flattened before reducing (explicit only).
  def determine_flatten_indices(reduction_info)
    Array(reduction_info[:flatten_argument_indices] || [])
  end

  # True when a vectorized declaration with a non-empty target scope needs join
  # planning: multi-argument calls and cascades.
  def requires_join?(declaration, scope_plan)
    return false unless declaration
    return false unless scope_plan.scope && !scope_plan.scope.empty?

    case declaration.expression
    when Kumi::Syntax::CallExpression
      declaration.expression.args.size > 1
    when Kumi::Syntax::CascadeExpression
      # Cross-scope conditions and results need join planning.
      true
    else
      false
    end
  end

  # Infers the array dimensions an expression iterates over. Input element references
  # map to their array path; declaration references recurse into the referenced
  # expression; calls take the deepest argument scope. Anything else is scalar ([]).
  def infer_scope_from_argument(arg, declarations, input_metadata)
    return [] unless arg

    case arg
    when Kumi::Syntax::InputElementReference
      dims_from_path(arg.path, input_metadata)
    when Kumi::Syntax::DeclarationReference
      decl = declarations[arg.name]
      decl ? infer_scope_from_argument(decl.expression, declarations, input_metadata) : []
    when Kumi::Syntax::CallExpression
      scopes = arg.args.map { |a| infer_scope_from_argument(a, declarations, input_metadata) }
      scopes.max_by(&:size) || []
    else
      []
    end
  end

  # Walks +path+ through the input metadata tree, collecting each segment whose field
  # has type :array. Stops at the first segment that is not found.
  def dims_from_path(path, input_metadata)
    dims = []
    meta = input_metadata

    path.each do |seg|
      field = meta[seg] || meta[seg.to_sym] || meta[seg.to_s]
      break unless field

      dims << seg.to_sym if field[:type] == :array
      meta = field[:children] || {}
    end

    dims
  end

  # True when any vectorized declaration's expression references declaration +name+.
  # +_scope_plan+ is accepted for signature stability but not consulted.
  def needs_scalar_to_vector_broadcast?(name, _scope_plan, declarations, vectorized_ops)
    vectorized_ops.each_key.any? do |vec_name|
      vec_decl = declarations[vec_name]
      vec_decl && declaration_references?(vec_decl.expression, name)
    end
  end

  # Recursively checks whether +expr+ references the declaration +target_name+ through
  # calls, cascade cases (condition and result), or array elements.
  def declaration_references?(expr, target_name)
    case expr
    when Kumi::Syntax::DeclarationReference
      expr.name == target_name
    when Kumi::Syntax::CallExpression
      expr.args.any? { |arg| declaration_references?(arg, target_name) }
    when Kumi::Syntax::CascadeExpression
      expr.cases.any? do |case_expr|
        declaration_references?(case_expr.condition, target_name) ||
          declaration_references?(case_expr.result, target_name)
      end
    when Kumi::Syntax::ArrayExpression
      expr.elements.any? { |elem| declaration_references?(elem, target_name) }
    else
      false
    end
  end

  # Debug helpers
  def debug_reduction(name, info)
    puts "\n=== Processing reduction: #{name} ==="
    puts "Function: #{info[:function]}"
    puts "Argument: #{info[:argument].class}"
  end

  def debug_reduction_plan(name, plan)
    puts "Reduction plan for #{name}:"
    puts " Axis: #{plan.axis.inspect}"
    puts " Source scope: #{plan.source_scope.inspect}"
    puts " Result scope: #{plan.result_scope.inspect}"
  end

  def debug_join(name, info)
    puts "\n=== Processing join: #{name} ==="
    puts "Source: #{info[:source]}"
  end

  def debug_join_plan(name, plan)
    puts "Join plan for #{name}:"
    puts " Target scope: #{plan.target_scope.inspect}"
    puts " Policy: #{plan.policy}"
  end

  def debug_scalar_broadcast(name, scope_plan)
    puts "\n=== Processing scalar broadcast: #{name} ==="
    puts "Target scope: #{scope_plan.scope.inspect}"
  end
end
290
+ end
291
+ end
292
+ end
293
+ end