kumi 0.0.21 → 0.0.22

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,349 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # Plans per-declaration execution scope and join/lift needs.
        #
        # Determines the dimensional scope (the array container dimensions along an
        # input path, as symbols) for each declaration from vectorization/reduction
        # metadata and input paths, then widens the scopes of declarations that are
        # referenced from array and cascade expressions.
        #
        # Set DEBUG_SCOPE_RESOLUTION in the environment (any value) to trace this
        # pass on stdout.
        #
        # DEPENDENCIES: :declarations, :input_metadata, :broadcasts
        #               (:dependencies is read only for debug tracing)
        # PRODUCES: :scope_plans, :decl_shapes
        class ScopeResolutionPass < PassBase
          include Kumi::Core::Analyzer::Plans

          DEBUG_ENV_KEY = "DEBUG_SCOPE_RESOLUTION"
          private_constant :DEBUG_ENV_KEY

          # Computes scope plans and declaration shapes for every declaration.
          #
          # @param _errors [Object] unused; this pass reports no errors
          # @return new state with :scope_plans and :decl_shapes, both frozen
          #   hashes keyed by declaration name
          def run(_errors)
            declarations = get_state(:declarations, required: true)
            input_metadata = get_state(:input_metadata, required: true)
            broadcasts = get_state(:broadcasts) || {}

            trace { "Available dependencies: #{(get_state(:dependencies) || {}).keys.inspect}" }

            # First pass: each declaration's own scope, before cross-declaration widening.
            initial_scopes = {}
            declarations.each do |name, decl|
              debug_output(name, decl) if debug?
              target_scope = infer_target_scope(name, decl, broadcasts, input_metadata)
              initial_scopes[name] = target_scope
              debug_result(target_scope, determine_result_kind(name, target_scope, broadcasts)) if debug?
            end

            final_scopes = propagate_scope_constraints(initial_scopes, declarations, input_metadata)

            scope_plans = {}
            decl_shapes = {}
            final_scopes.each do |name, target_scope|
              scope_plans[name] = build_scope_plan(target_scope)
              decl_shapes[name] = {
                scope: target_scope,
                result: determine_result_kind(name, target_scope, broadcasts)
              }.freeze
            end

            state.with(:scope_plans, scope_plans.freeze)
                 .with(:decl_shapes, decl_shapes.freeze)
          end

          private

          # True when tracing is enabled (the env var is set, to any value).
          def debug?
            ENV.key?(DEBUG_ENV_KEY)
          end

          # Prints the block's result to stdout when tracing is enabled. The block
          # is only evaluated while tracing, so message building is free otherwise.
          def trace
            puts yield if debug?
          end

          # Widens scopes implied by array and cascade expressions.
          #
          # Walks the declarations once, in iteration order. A referenced
          # declaration's scope is only ever replaced by a strictly longer one.
          #
          # @return [Hash] a copy of +initial_scopes+ with widened entries
          def propagate_scope_constraints(initial_scopes, declarations, input_metadata)
            scopes = initial_scopes.dup
            trace { "\n=== Propagating scope constraints ===" }

            declarations.each do |name, decl|
              expression = decl.expression
              case expression
              when Kumi::Syntax::ArrayExpression
                propagate_from_array_expression(name, expression, scopes, declarations, input_metadata)
              when Kumi::Syntax::CascadeExpression
                propagate_from_cascade_expression(name, expression, scopes, declarations, input_metadata)
              end
            end

            trace { "Final propagated scopes: #{scopes.inspect}" }
            scopes
          end

          # Uses the deepest input-element reference inside an array expression as
          # an anchor scope, and raises every declaration referenced directly in
          # that array to the anchor when the anchor is deeper.
          def propagate_from_array_expression(name, array_expr, scopes, declarations, input_metadata)
            trace { "Analyzing array expression in #{name}: #{array_expr.elements.map(&:class)}" }

            anchor_scope = nil
            declaration_refs = []

            array_expr.elements.each do |element|
              case element
              when Kumi::Syntax::InputElementReference
                path_scope = dims_from_path(element.path, input_metadata)
                trace { "Found input anchor: #{element.path} -> scope #{path_scope}" }
                # Keep the deepest scope seen; on ties the earlier anchor wins.
                anchor_scope = path_scope if path_scope.length > (anchor_scope&.length || 0)
              when Kumi::Syntax::DeclarationReference
                declaration_refs << element.name
              end
            end

            return if anchor_scope.nil? || anchor_scope.empty?

            declaration_refs.each do |ref_name|
              current_scope = scopes[ref_name] || []
              next unless anchor_scope.length > current_scope.length

              trace { "Propagating scope #{anchor_scope} to #{ref_name} (was #{current_scope})" }
              scopes[ref_name] = anchor_scope
              propagate_to_dependencies(ref_name, anchor_scope, scopes, declarations, input_metadata)
            end
          end

          # Pushes a cascade's own (non-empty) scope down to each declaration
          # referenced in its case conditions, then one level further through
          # propagate_to_dependencies.
          def propagate_from_cascade_expression(name, cascade_expr, scopes, declarations, input_metadata)
            trace { "Analyzing cascade expression in #{name}" }

            cascade_scope = scopes[name] || []
            return if cascade_scope.empty?

            trace { "Propagating cascade scope #{cascade_scope} to condition dependencies" }

            cascade_expr.cases.each do |case_expr|
              find_declaration_references(case_expr.condition).each do |ref_name|
                current_scope = scopes[ref_name] || []
                next unless cascade_scope.length > current_scope.length

                trace { "Propagating scope #{cascade_scope} to cascade condition #{ref_name} (was #{current_scope})" }
                scopes[ref_name] = cascade_scope
                propagate_to_dependencies(ref_name, cascade_scope, scopes, declarations, input_metadata)
              end
            end
          end

          # When +decl_name+ is a cascade, raises each declaration referenced in its
          # case conditions to +required_scope+ (only if that is deeper). It does not
          # recurse further: the follow-up hook below currently only traces.
          def propagate_to_dependencies(decl_name, required_scope, scopes, declarations, input_metadata)
            decl = declarations[decl_name]
            return unless decl

            trace { "Propagating #{required_scope} into dependencies of #{decl_name}" }
            return unless decl.expression.is_a?(Kumi::Syntax::CascadeExpression)

            decl.expression.cases.each do |case_expr|
              find_declaration_references(case_expr.condition).each do |ref_name|
                current_scope = scopes[ref_name] || []
                next unless required_scope.length > current_scope.length

                trace { "Propagating scope #{required_scope} to trait dependency #{ref_name}" }
                scopes[ref_name] = required_scope
                update_reduction_scope_if_needed(ref_name, required_scope, declarations, input_metadata)
              end
            end
          end

          # Names of declarations referenced anywhere inside call arguments and
          # array elements of +expr+ (cascades are not descended into).
          def find_declaration_references(expr)
            case expr
            when Kumi::Syntax::DeclarationReference then [expr.name]
            when Kumi::Syntax::CallExpression then expr.args.flat_map { |arg| find_declaration_references(arg) }
            when Kumi::Syntax::ArrayExpression then expr.elements.flat_map { |elem| find_declaration_references(elem) }
            else []
            end
          end

          # Placeholder hook: currently only traces and does not modify any scope.
          def update_reduction_scope_if_needed(decl_name, required_scope, declarations, _input_metadata)
            decl = declarations[decl_name]
            return unless decl

            trace { "Checking if #{decl_name} needs reduction scope update for #{required_scope}" }
          end

          def debug_output(name, decl)
            puts "\n=== Resolving scope for #{name} ==="
            puts "Declaration: #{decl.inspect}"
          end

          def debug_result(target_scope, result_kind)
            puts "Target scope: #{target_scope.inspect}"
            puts "Result kind: #{result_kind.inspect}"
          end

          def build_scope_plan(target_scope)
            Scope.new(
              scope: target_scope,
              lifts: [],       # Will be computed during IR lowering per call-site
              join_hint: nil,  # Will be set to :zip when multiple vectorized args exist
              arg_shapes: {}   # Optional: filled during lowering
            )
          end

          # :scalar for reductions and empty scopes, otherwise a dense array.
          #
          # NOTE(review): a reduction is reported as :scalar even when
          # infer_target_scope kept an outer dimension for it (partial reduction).
          # Confirm that consumers of :decl_shapes expect this.
          def determine_result_kind(name, target_scope, broadcasts)
            return :scalar if broadcasts.dig(:reduction_operations, name)
            return :scalar if target_scope.empty?

            { array: :dense }
          end

          # Derives scope from vectorization metadata, else from reduction
          # metadata, else the declaration is scalar ([]).
          def infer_target_scope(name, decl, broadcasts, input_metadata)
            vec = broadcasts.dig(:vectorized_operations, name)
            return scope_from_vectorization(vec, decl, input_metadata) if vec

            red = broadcasts.dig(:reduction_operations, name)
            return [] unless red

            trace { "Reduction info: #{red.inspect}" }
            # For fn(:any?, input.players.score_matrices.session.points > 1000) we
            # reduce over the inner dimensions but keep the players dimension.
            scope = infer_reduction_target_scope(decl.expression, input_metadata)
            trace { "Inferred reduction target scope: #{scope.inspect}" }
            scope
          end

          # Scope for a declaration listed in broadcasts[:vectorized_operations],
          # chosen by the vectorization entry's :source.
          def scope_from_vectorization(vec, decl, input_metadata)
            trace { "Vectorization info: #{vec.inspect}" }

            case vec[:source]
            when :nested_array_access, :array_field_access
              dims_from_path(vec[:path] || [], input_metadata)
            when :cascade_with_vectorized_conditions_or_results,
                 :cascade_condition_with_vectorized_trait
              # Fallback: derive from the first input path seen in the expression.
              dims_from_path(find_first_input_path(decl.expression) || [], input_metadata)
            else
              []
            end
          end

          # Scope kept by a reduction. For a reducer call, keeps the outermost
          # dimension of its first argument (if that argument spans more than one
          # dimension). For other calls, returns the first non-empty scope found
          # among the arguments, e.g. for (fn(:sum, ...) >= 3500).
          def infer_reduction_target_scope(expr, input_metadata)
            return [] unless expr.is_a?(Kumi::Syntax::CallExpression)

            if reducer_function?(expr.fn_name)
              arg = expr.args.first
              return [] unless arg

              full_scope = infer_scope_from_argument(arg, input_metadata)
              return full_scope.length > 1 ? full_scope[0..0] : []
            end

            expr.args.each do |arg|
              nested_scope = infer_reduction_target_scope(arg, input_metadata)
              return nested_scope unless nested_scope.empty?
            end
            []
          end

          def reducer_function?(fn_name)
            Kumi::Registry.entry(fn_name)&.reducer == true
          end

          # Full scope of a reducer argument; for call expressions, uses the
          # deepest input path that appears inside them.
          def infer_scope_from_argument(arg, input_metadata)
            case arg
            when Kumi::Syntax::InputElementReference
              dims_from_path(arg.path, input_metadata)
            when Kumi::Syntax::InputReference
              dims_from_path([arg.name], input_metadata)
            when Kumi::Syntax::CallExpression
              deepest_path = find_deepest_input_path(arg)
              deepest_path ? dims_from_path(deepest_path, input_metadata) : []
            else
              []
            end
          end

          # Longest input path under +expr+, or nil when there is none.
          def find_deepest_input_path(expr)
            collect_input_paths(expr).max_by(&:length)
          end

          # Every input path (as an array of segments) under call arguments and
          # array elements of +expr+.
          def collect_input_paths(expr)
            case expr
            when Kumi::Syntax::InputElementReference then [expr.path]
            when Kumi::Syntax::InputReference then [[expr.name]]
            when Kumi::Syntax::CallExpression then expr.args.flat_map { |arg| collect_input_paths(arg) }
            when Kumi::Syntax::ArrayExpression then expr.elements.flat_map { |elem| collect_input_paths(elem) }
            else []
            end
          end

          # Depth-first search for the first input path. Looks at call arguments,
          # then cascade conditions/results, then any wrapped `expression`.
          def find_first_input_path(expr)
            return nil unless expr
            return expr.path if expr.is_a?(Kumi::Syntax::InputElementReference)
            return [expr.name] if expr.is_a?(Kumi::Syntax::InputReference)

            if expr.is_a?(Kumi::Syntax::CallExpression) && expr.args
              expr.args.each do |arg|
                path = find_first_input_path(arg)
                return path if path
              end
            end

            if expr.is_a?(Kumi::Syntax::CascadeExpression) && expr.cases
              expr.cases.each do |case_item|
                path = find_first_input_path(case_item.condition) || find_first_input_path(case_item.result)
                return path if path
              end
            end

            # Nodes that wrap another expression (e.g. a declaration) are unwrapped.
            return find_first_input_path(expr.expression) if expr.respond_to?(:expression)

            nil
          end

          # Maps an input path like [:regions, :offices, :salary] to its container
          # dims [:regions, :offices]: the segments whose metadata :type is :array.
          # Stops at the first segment missing from the metadata.
          def dims_from_path(path, input_metadata)
            dims = []
            meta = input_metadata

            path.each do |seg|
              field = meta[seg] || meta[seg.to_sym] || meta[seg.to_s]
              break unless field

              dims << seg.to_sym if field[:type] == :array
              meta = field[:children] || {}
            end

            dims
          end
        end
      end
    end
  end
end
@@ -1,179 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Validate function call arity and argument types against FunctionRegistry
        # DEPENDENCIES: :inferred_types from TypeInferencerPass; optionally reads
        #               :registry, :broadcasts and :input_metadata
        # PRODUCES: :functions_required - Set of function names used in the schema
        # INTERFACE: new(schema, state).run(errors)
        class TypeChecker < VisitorPass
          # Validates every call expression and collects the called function names.
          #
          # @param errors [Object] error sink passed through to report_error
          # @return new state with :functions_required (a Set). Names are recorded
          #   even for calls that failed validation.
          def run(errors)
            functions_required = Set.new
            @registry = state[:registry]

            visit_nodes_of_type(Kumi::Syntax::CallExpression, errors: errors) do |node, _decl, errs|
              validate_function_call(node, errs)
              functions_required.add(node.fn_name)
            end

            state.with(:functions_required, functions_required)
          end

          private

          def validate_function_call(node, errors)
            signature = get_function_signature(node, errors)
            return unless signature

            validate_arity(node, signature, errors)
            validate_argument_types(node, signature, errors)
          end

          # Resolves { arity:, param_types: } for +node+'s function.
          #
          # Prefers the state-provided registry. Otherwise falls back to
          # Kumi::Registry.signature.
          #
          # @return [Hash, nil] nil after reporting an "unsupported operator" error
          def get_function_signature(node, errors)
            func = lookup_registry_function(node.fn_name)
            if func
              return {
                arity: func.params.size,
                param_types: func.params.map { |p| p["type"] ? Kumi::Core::Types.parse(p["type"]) : nil }
              }
            end

            Kumi::Registry.signature(node.fn_name)
          rescue Kumi::Errors::UnknownFunction
            # Use old format for backward compatibility, but node.loc provides better location
            report_error(errors, "unsupported operator `#{node.fn_name}`", location: node.loc, type: :type)
            nil
          end

          # Best-effort lookup in the state-provided registry. Returns nil when no
          # registry is present or the lookup raises, so the caller can fall back
          # to the global registry.
          def lookup_registry_function(fn_name)
            return nil unless @registry

            @registry.function(fn_name)
          rescue StandardError
            nil
          end

          # A negative expected arity means variadic and is never reported.
          def validate_arity(node, signature, errors)
            expected_arity = signature[:arity]
            actual_arity = node.args.size

            return if expected_arity.negative? || expected_arity == actual_arity

            report_error(errors, "operator `#{node.fn_name}` expects #{expected_arity} args, got #{actual_arity}", location: node.loc,
                                                                                                                   type: :type)
          end

          def validate_argument_types(node, signature, errors)
            types = signature[:param_types]
            return if types.nil? || (signature[:arity].negative? && node.args.empty?)

            # Skip type checking for vectorized operations
            broadcast_meta = get_state(:broadcasts, required: false)
            return if broadcast_meta && vectorized_operation?(node, broadcast_meta)

            node.args.each_with_index do |arg, i|
              validate_argument_type(arg, i, types[i], node.fn_name, errors)
            end
          end

          # Heuristic: true when any argument references a vectorized or reduction
          # declaration, or an input path rooted at a broadcast array field. The
          # surrounding call context is not tracked.
          def vectorized_operation?(node, broadcast_meta)
            node.args.any? do |arg|
              case arg
              when Kumi::Syntax::DeclarationReference
                broadcast_meta[:vectorized_operations]&.key?(arg.name) ||
                  broadcast_meta[:reduction_operations]&.key?(arg.name)
              when Kumi::Syntax::InputElementReference
                broadcast_meta[:array_fields]&.key?(arg.path.first)
              else
                false
              end
            end
          end

          # A nil or ANY expected type accepts anything.
          def validate_argument_type(arg, index, expected_type, fn_name, errors)
            return if expected_type.nil? || expected_type == Kumi::Core::Types::ANY

            actual_type = get_expression_type(arg)
            return if Kumi::Core::Types.compatible?(actual_type, expected_type)

            source_desc = describe_expression_type(arg, actual_type)
            report_error(errors, "argument #{index + 1} of `fn(:#{fn_name})` expects #{expected_type}, " \
                                 "got #{source_desc}", location: arg.loc, type: :type)
          end

          # Best-known type of +expr+. Types of complex expressions are not tracked
          # here and are treated as ANY.
          def get_expression_type(expr)
            case expr
            when Kumi::Syntax::Literal
              Kumi::Core::Types.infer_from_value(expr.value)
            when Kumi::Syntax::InputReference
              get_declared_field_type(expr.name)
            when Kumi::Syntax::DeclarationReference
              get_inferred_declaration_type(expr.name)
            else
              Kumi::Core::Types::ANY
            end
          end

          # Input-block metadata for a top-level field, or nil when undeclared.
          def input_field_meta(field_name)
            input_meta = get_state(:input_metadata, required: false) || {}
            input_meta[field_name]
          end

          # Explicitly declared input type, or ANY when none was declared.
          def get_declared_field_type(field_name)
            input_field_meta(field_name)&.dig(:type) || Kumi::Core::Types::ANY
          end

          # Type from :inferred_types, or ANY when the declaration has no entry.
          def get_inferred_declaration_type(decl_name)
            decl_types = get_state(:inferred_types, required: true)
            decl_types[decl_name] || Kumi::Core::Types::ANY
          end

          # Human-readable description of where +type+ came from, used in errors.
          def describe_expression_type(expr, type)
            case expr
            when Kumi::Syntax::Literal
              "`#{expr.value}` of type #{type} (literal value)"
            when Kumi::Syntax::InputReference
              field_meta = input_field_meta(expr.name)
              if field_meta&.dig(:type)
                domain_desc = field_meta[:domain] ? " (domain: #{field_meta[:domain]})" : ""
                "input field `#{expr.name}` of declared type #{type}#{domain_desc}"
              else
                "undeclared input field `#{expr.name}` (inferred as #{type})"
              end
            when Kumi::Syntax::DeclarationReference
              "reference to declaration `#{expr.name}` of inferred type #{type}"
            when Kumi::Syntax::CallExpression
              "result of function `#{expr.fn_name}` returning #{type}"
            when Kumi::Syntax::ArrayExpression
              "list expression of type #{type}"
            when Kumi::Syntax::CascadeExpression
              "cascade expression of type #{type}"
            else
              "expression of type #{type}"
            end
          end
        end
      end
    end
  end
end