kumi 0.0.15 → 0.0.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/golden/cascade_logic/schema.kumi +3 -1
  4. data/lib/kumi/analyzer.rb +11 -9
  5. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +0 -81
  6. data/lib/kumi/core/analyzer/passes/ir_dependency_pass.rb +18 -20
  7. data/lib/kumi/core/analyzer/passes/ir_execution_schedule_pass.rb +67 -0
  8. data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +0 -36
  9. data/lib/kumi/core/analyzer/passes/toposorter.rb +1 -39
  10. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +8 -191
  11. data/lib/kumi/core/compiler/access_builder.rb +20 -10
  12. data/lib/kumi/core/compiler/access_codegen.rb +61 -0
  13. data/lib/kumi/core/compiler/access_emit/base.rb +173 -0
  14. data/lib/kumi/core/compiler/access_emit/each_indexed.rb +56 -0
  15. data/lib/kumi/core/compiler/access_emit/materialize.rb +45 -0
  16. data/lib/kumi/core/compiler/access_emit/ravel.rb +50 -0
  17. data/lib/kumi/core/compiler/access_emit/read.rb +32 -0
  18. data/lib/kumi/core/ir/execution_engine/interpreter.rb +36 -181
  19. data/lib/kumi/core/ir/execution_engine/values.rb +8 -8
  20. data/lib/kumi/core/ir/execution_engine.rb +3 -19
  21. data/lib/kumi/dev/parse.rb +12 -12
  22. data/lib/kumi/runtime/executable.rb +22 -175
  23. data/lib/kumi/runtime/run.rb +105 -0
  24. data/lib/kumi/schema.rb +8 -13
  25. data/lib/kumi/version.rb +1 -1
  26. data/lib/kumi.rb +3 -2
  27. metadata +10 -25
  28. data/BACKLOG.md +0 -34
  29. data/config/functions.yaml +0 -352
  30. data/docs/functions/analyzer_integration.md +0 -199
  31. data/docs/functions/signatures.md +0 -171
  32. data/examples/hash_objects_demo.rb +0 -138
  33. data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +0 -199
  34. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +0 -48
  35. data/lib/kumi/core/functions/dimension.rb +0 -98
  36. data/lib/kumi/core/functions/dtypes.rb +0 -20
  37. data/lib/kumi/core/functions/errors.rb +0 -11
  38. data/lib/kumi/core/functions/kernel_adapter.rb +0 -45
  39. data/lib/kumi/core/functions/loader.rb +0 -119
  40. data/lib/kumi/core/functions/registry_v2.rb +0 -68
  41. data/lib/kumi/core/functions/shape.rb +0 -70
  42. data/lib/kumi/core/functions/signature.rb +0 -122
  43. data/lib/kumi/core/functions/signature_parser.rb +0 -86
  44. data/lib/kumi/core/functions/signature_resolver.rb +0 -272
  45. data/lib/kumi/kernels/ruby/aggregate_core.rb +0 -105
  46. data/lib/kumi/kernels/ruby/datetime_scalar.rb +0 -21
  47. data/lib/kumi/kernels/ruby/mask_scalar.rb +0 -15
  48. data/lib/kumi/kernels/ruby/scalar_core.rb +0 -63
  49. data/lib/kumi/kernels/ruby/string_scalar.rb +0 -19
  50. data/lib/kumi/kernels/ruby/vector_struct.rb +0 -39
@@ -1,68 +0,0 @@
1
- module Kumi::Core::Functions
2
- class RegistryV2
3
- def initialize(functions:)
4
- @by_name = functions.group_by(&:name).transform_values { |v| v.sort_by(&:opset).freeze }.freeze
5
- end
6
-
7
- # Factory method to load from YAML configuration
8
- def self.load_from_file(path = nil)
9
- path ||= File.join(__dir__, "../../../..", "config", "functions.yaml")
10
- functions = Loader.load_file(path)
11
- new(functions: functions)
12
- end
13
-
14
- def fetch(name, opset: nil)
15
- list = @by_name[name] or raise KeyError, "unknown function #{name}"
16
- opset ? list.find { |f| f.opset == opset } : list.last
17
- end
18
-
19
- # Get function signatures for NEP-20 signature resolution
20
- # This bridges RegistryV2 with our existing SignatureResolver
21
- def get_function_signatures(name, opset: nil)
22
- begin
23
- fn = fetch(name, opset: opset)
24
- # Convert Signature objects to string representations for NEP-20 parser
25
- fn.signatures.map(&:to_signature_string)
26
- rescue KeyError
27
- [] # Function not found in RegistryV2 - fall back to legacy registry
28
- end
29
- end
30
-
31
- # Enhanced signature resolution using NEP-20 resolver
32
- def choose_signature(fn, arg_shapes)
33
- # Use our NEP-20 SignatureResolver for proper dimension handling
34
- sig_strings = fn.signatures.map(&:to_signature_string)
35
- parsed_sigs = sig_strings.map { |s| SignatureParser.parse(s) }
36
-
37
- plan = SignatureResolver.choose(signatures: parsed_sigs, arg_shapes: arg_shapes)
38
-
39
- # Return both the original function signature and the resolution plan
40
- {
41
- function: fn,
42
- signature: plan[:signature],
43
- plan: plan
44
- }
45
- rescue SignatureMatchError => e
46
- raise ArgumentError, "no matching signature for #{fn.name} with shapes #{arg_shapes.inspect}: #{e.message}"
47
- end
48
-
49
- def resolve_kernel(fn, backend:, conditions: {})
50
- ks = fn.kernels.select { |k| k.backend == backend.to_sym }
51
- ks = ks.select { |k| conditions.all? { |ck, cv| k.when_&.fetch(ck, cv) == cv } } unless conditions.empty?
52
- ks.max_by(&:priority) or raise "no kernel for #{fn.name} backend=#{backend}"
53
- end
54
-
55
- # Introspection methods
56
- def all_function_names
57
- @by_name.keys
58
- end
59
-
60
- def function_exists?(name)
61
- @by_name.key?(name)
62
- end
63
-
64
- def all_functions
65
- @by_name.values.flatten
66
- end
67
- end
68
- end
@@ -1,70 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require_relative "dimension"
4
-
5
- module Kumi
6
- module Core
7
- module Functions
8
- # Shape utilities with NEP 20 support. A "shape" is an Array<Dimension> of dimension objects.
9
- # [] == scalar, [Dimension.new(:i)] == vector along :i, etc.
10
- module Shape
11
- module_function
12
-
13
- def scalar?(shape) = shape.empty?
14
-
15
- def equal?(a, b) = a.map(&:name) == b.map(&:name)
16
-
17
- # NEP 20 broadcast rules:
18
- # - scalar can broadcast to any expected shape
19
- # - fixed-size dimensions must match exactly
20
- # - broadcastable dimensions with |1 modifier can broadcast against size-1
21
- # - flexible dimensions with ? can be omitted if not present in all operands
22
- def broadcastable?(got:, expected:)
23
- return true if scalar?(got)
24
- return false if got.length != expected.length
25
-
26
- got.zip(expected).all? do |got_dim, exp_dim|
27
- broadcastable_dimension?(got: got_dim, expected: exp_dim)
28
- end
29
- end
30
-
31
- def broadcastable_dimension?(got:, expected:)
32
- # Same name and modifiers
33
- return true if got == expected
34
-
35
- # Same name, different modifiers - check compatibility
36
- if got.name == expected.name
37
- # Fixed-size dimensions must match exactly
38
- if got.fixed_size? || expected.fixed_size?
39
- return got.size == expected.size if got.fixed_size? && expected.fixed_size?
40
- return false # one fixed, one not - incompatible
41
- end
42
-
43
- # Both named dimensions with same name are compatible
44
- return true
45
- end
46
-
47
- # Different names - only broadcastable with |1 modifier
48
- expected.broadcastable? && scalar?([got])
49
- end
50
-
51
- # Check if a dimension can be omitted (NEP 20 flexible dimensions)
52
- def flexible?(dim)
53
- dim.is_a?(Dimension) && dim.flexible?
54
- end
55
-
56
- # Check if a dimension can broadcast (NEP 20 broadcastable dimensions)
57
- def broadcastable_dimension?(dim)
58
- dim.is_a?(Dimension) && dim.broadcastable?
59
- end
60
-
61
- # Convenience: find dimensions in set a that are not in set b
62
- def dimensions_minus(a, b)
63
- a_names = a.map(&:name).to_set
64
- b_names = b.map(&:name).to_set
65
- (a_names - b_names).to_a
66
- end
67
- end
68
- end
69
- end
70
- end
@@ -1,122 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Functions
6
- # Signature is a small immutable value object describing a function's
7
- # vectorization contract with NEP 20 support.
8
- #
9
- # in_shapes: Array<Array<Dimension>> e.g., [[Dimension.new(:i)], [Dimension.new(:i)]]
10
- # out_shape: Array<Dimension> e.g., [Dimension.new(:i)], [Dimension.new(:i), Dimension.new(:j)], or []
11
- # join_policy: nil | :zip | :product
12
- # raw: original string, for diagnostics (optional)
13
- class Signature
14
- attr_reader :in_shapes, :out_shape, :join_policy, :raw
15
-
16
- def initialize(in_shapes:, out_shape:, join_policy: nil, raw: nil)
17
- @in_shapes = deep_dup(in_shapes).freeze
18
- @out_shape = out_shape.dup.freeze
19
- @join_policy = join_policy&.to_sym
20
- @raw = raw
21
- validate!
22
- freeze
23
- end
24
-
25
- def arity = @in_shapes.length
26
-
27
- # Dimensions that appear in any input but not in output (i.e., reduced/dropped).
28
- def dropped_axes
29
- input_names = @in_shapes.flatten.map(&:name)
30
- output_names = @out_shape.map(&:name)
31
- (input_names - output_names).uniq.freeze
32
- end
33
-
34
- # True if any axis from inputs is dropped (common in aggregates).
35
- def reduction?
36
- !dropped_axes.empty?
37
- end
38
-
39
- def to_h
40
- {
41
- in_shapes: in_shapes.map(&:dup),
42
- out_shape: out_shape.dup,
43
- join_policy: join_policy,
44
- raw: raw
45
- }
46
- end
47
-
48
- def inspect
49
- "#<Signature #{format_signature}#{" @#{join_policy}" if join_policy}>"
50
- end
51
-
52
- def format_signature
53
- lhs = in_shapes.map { |dims| "(#{dims.map(&:to_s).join(',')})" }.join(",")
54
- rhs = "(#{out_shape.map(&:to_s).join(',')})"
55
- "#{lhs}->#{rhs}"
56
- end
57
-
58
- # Convert back to string representation for NEP-20 parser compatibility
59
- def to_signature_string
60
- sig_str = format_signature
61
- join_policy ? "#{sig_str}@#{join_policy}" : sig_str
62
- end
63
-
64
- private
65
-
66
- def validate!
67
- unless @in_shapes.is_a?(Array) && @in_shapes.all? { |s| s.is_a?(Array) }
68
- raise SignatureError, "in_shapes must be an array of dimension arrays"
69
- end
70
-
71
- @in_shapes.each_with_index do |dims, idx|
72
- validate_dimensions!(dims, where: "in_shapes[#{idx}]")
73
- end
74
-
75
- validate_dimensions!(@out_shape, where: "out_shape")
76
-
77
- unless [nil, :zip, :product].include?(@join_policy)
78
- raise SignatureError, "join_policy must be nil, :zip, or :product; got #{@join_policy.inspect}"
79
- end
80
-
81
- # Validate NEP 20 constraints
82
- validate_nep20_constraints!
83
- end
84
-
85
- def validate_dimensions!(dims, where: "shape")
86
- unless dims.is_a?(Array) && dims.all? { |d| d.is_a?(Dimension) }
87
- raise SignatureError, "#{where}: must be an array of Dimension objects, got: #{dims.inspect}"
88
- end
89
-
90
- # Check for duplicate dimension names within a single argument
91
- names = dims.map(&:name)
92
- duplicates = names.group_by { |n| n }.select { |_, v| v.size > 1 }.keys
93
- raise SignatureError, "#{where}: duplicate dimension names #{duplicates.inspect}" unless duplicates.empty?
94
-
95
- true
96
- end
97
-
98
- def validate_nep20_constraints!
99
- # Broadcastable dimensions should only appear in inputs, not outputs
100
- @out_shape.each do |dim|
101
- raise SignatureError, "output dimension #{dim} cannot be broadcastable" if dim.broadcastable?
102
- end
103
-
104
- # Fixed-size dimensions in outputs must match corresponding input dimensions
105
- all_input_dims = @in_shapes.flatten
106
- @out_shape.each do |out_dim|
107
- next unless out_dim.fixed_size?
108
-
109
- matching_inputs = all_input_dims.select { |in_dim| in_dim.name == out_dim.name }
110
- matching_inputs.each do |in_dim|
111
- if in_dim.fixed_size? && in_dim.size != out_dim.size
112
- raise SignatureError, "fixed-size dimension #{out_dim.name} has inconsistent sizes: #{in_dim.size} vs #{out_dim.size}"
113
- end
114
- end
115
- end
116
- end
117
-
118
- def deep_dup(arr) = arr.map { |x| x.is_a?(Array) ? x.dup : x }
119
- end
120
- end
121
- end
122
- end
@@ -1,86 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require_relative "dimension"
4
- require_relative "signature"
5
-
6
- module Kumi
7
- module Core
8
- module Functions
9
- # Parses NEP 20 conformant signature strings like:
10
- # "(),()->()" # scalar operations
11
- # "(i),(i)->(i)" # vector operations
12
- # "(i),(j)->(i,j)@product" # matrix operations with join policy
13
- # "(i,j)->(i)" # reduction of :j
14
- # "(3),(3)->(3)" # fixed-size 3-vectors (cross product)
15
- # "(i?),(i?)->(i?)" # flexible dimensions
16
- # "(i|1),(i|1)->()" # broadcastable dimensions
17
- # "(m?,n),(n,p?)->(m?,p?)" # matmul signature
18
- class SignatureParser
19
- class << self
20
- def parse(str)
21
- raise SignatureParseError, "empty signature" if str.nil? || str.strip.empty?
22
- lhs, rhs = str.split("->", 2)&.map!(&:strip)
23
- raise SignatureParseError, "signature must contain '->': #{str.inspect}" unless rhs
24
-
25
- out_spec, policy = rhs.split("@", 2)&.map!(&:strip)
26
- in_shapes = parse_many(lhs)
27
- out_shape = parse_axes(out_spec)
28
- join_policy = policy&.to_sym
29
-
30
- Signature.new(in_shapes: in_shapes, out_shape: out_shape, join_policy: join_policy, raw: str)
31
- rescue SignatureError => e
32
- raise
33
- rescue StandardError => e
34
- raise SignatureParseError, "invalid signature #{str.inspect}: #{e.message}"
35
- end
36
-
37
- private
38
-
39
- def parse_many(lhs)
40
- # Handle zero arguments case
41
- return [] if lhs.strip.empty?
42
-
43
- # split by commas that are *between* groups, not inside them
44
- # simpler approach: split by '),', re-add ')' where needed
45
- tokens = lhs.split("),").map { |t| t.strip.end_with?(")") ? t.strip : "#{t.strip})" }
46
- tokens.map { |tok| parse_axes(tok) }
47
- end
48
-
49
- def parse_axes(spec)
50
- spec = spec.strip
51
- raise SignatureParseError, "missing parentheses in #{spec.inspect}" unless spec.start_with?("(") && spec.end_with?(")")
52
-
53
- inner = spec[1..-2].strip
54
- return [] if inner.empty?
55
-
56
- inner.split(",").map { |dim_str| parse_dimension(dim_str.strip) }
57
- end
58
-
59
- # Parse a single dimension with NEP 20 modifiers
60
- # Examples: "i", "3", "n?", "i|1"
61
- def parse_dimension(dim_str)
62
- return Dimension.new(:empty) if dim_str.empty?
63
-
64
- # Extract modifiers
65
- flexible = dim_str.end_with?("?")
66
- dim_str = dim_str.chomp("?") if flexible
67
-
68
- broadcastable = dim_str.end_with?("|1")
69
- dim_str = dim_str.chomp("|1") if broadcastable
70
-
71
- # Parse name (symbol or integer)
72
- name = if dim_str.match?(/^\d+$/)
73
- dim_str.to_i
74
- else
75
- dim_str.to_sym
76
- end
77
-
78
- Dimension.new(name, flexible: flexible, broadcastable: broadcastable)
79
- rescue StandardError => e
80
- raise SignatureParseError, "invalid dimension #{dim_str.inspect}: #{e.message}"
81
- end
82
- end
83
- end
84
- end
85
- end
86
- end
@@ -1,272 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require_relative "errors"
4
- require_relative "shape"
5
- require_relative "signature"
6
-
7
- module Kumi
8
- module Core
9
- module Functions
10
- # Given a set of signatures and actual argument shapes, pick the best match.
11
- # Supports NEP 20 extensions: fixed-size, flexible, and broadcastable dimensions.
12
- #
13
- # Inputs:
14
- # signatures : Array<Signature> (with Dimension objects)
15
- # arg_shapes : Array<Array<Symbol|Integer>> e.g., [[:i], [:i]] or [[], [3]] or [[2, :i]]
16
- #
17
- # Returns:
18
- # { signature:, result_axes:, join_policy:, dropped_axes:, effective_signature: }
19
- #
20
- # NEP 20 Matching rules:
21
- # - Arity must match exactly (before flexible dimension resolution).
22
- # - Fixed-size dimensions (integers) must match exactly.
23
- # - Flexible dimensions (?) can be omitted if not present in all operands.
24
- # - Broadcastable dimensions (|1) can match scalar or size-1 dimensions.
25
- # - For each param position, shapes are checked according to NEP 20 rules.
26
- # - We prefer exact matches, then flexible matches, then broadcast matches.
27
- class SignatureResolver
28
- class << self
29
- def choose(signatures:, arg_shapes:)
30
- # Handle empty arg_shapes for zero-arity functions
31
- arg_shapes = [] if arg_shapes.nil?
32
- sanity_check_args!(arg_shapes)
33
-
34
- candidates = signatures.map do |sig|
35
- score = match_score(sig, arg_shapes)
36
- next if score.nil?
37
-
38
- # Convert arg_shapes to normalized Dimension arrays for environment building
39
- normalized_args = arg_shapes.map { |shape| normalize_shape(shape) }
40
- env = build_dimension_environment(sig, normalized_args)
41
- next if env.nil? # Skip candidates with dimension conflicts
42
-
43
- {
44
- signature: sig,
45
- score: score,
46
- result_axes: sig.out_shape.map(&:name), # Convert Dimension objects to names for backward compatibility
47
- join_policy: sig.join_policy,
48
- dropped_axes: sig.dropped_axes.map { |name| name.is_a?(Symbol) ? name : name.to_sym }, # Convert to symbols
49
- env: env
50
- }
51
- end.compact
52
-
53
- raise SignatureMatchError, mismatch_message(signatures, arg_shapes) if candidates.empty?
54
-
55
- # Lower score is better: 0 = exact-everywhere, then number of broadcasts
56
- best = candidates.min_by { |c| c[:score] }
57
-
58
- # Add effective signature and environment for analyzer/lowering
59
- best[:effective_signature] = {
60
- in_shapes: best[:signature].in_shapes.map { |dims| dims.map(&:name) },
61
- out_shape: best[:signature].out_shape.map(&:name),
62
- join_policy: best[:signature].join_policy
63
- }
64
- # env is already included from candidate building
65
-
66
- best
67
- end
68
-
69
- private
70
-
71
- def sanity_check_args!(arg_shapes)
72
- unless arg_shapes.is_a?(Array) &&
73
- arg_shapes.all? { |s| s.is_a?(Array) && s.all? { |a| a.is_a?(Symbol) || a.is_a?(Integer) } }
74
- raise SignatureMatchError, "arg_shapes must be an array of dimension arrays (symbols or integers), got: #{arg_shapes.inspect}"
75
- end
76
- end
77
-
78
- # Returns an integer "broadcast cost" or nil if not matchable.
79
- # Lower score = better match: 0 = exact, then increasing cost for broadcasts/flexibility
80
- def match_score(sig, arg_shapes)
81
- return nil unless sig.arity == arg_shapes.length
82
-
83
- # Convert arg_shapes to normalized Dimension arrays for comparison
84
- normalized_args = arg_shapes.map { |shape| normalize_shape(shape) }
85
-
86
- # Try to match each argument against its expected signature shape
87
- cost = 0
88
- sig.in_shapes.each_with_index do |expected_dims, idx|
89
- got_dims = normalized_args[idx]
90
- arg_cost = match_argument_cost(got: got_dims, expected: expected_dims)
91
- return nil if arg_cost.nil?
92
-
93
- cost += arg_cost
94
- end
95
-
96
- # Additional checks for join_policy constraints
97
- return nil unless valid_join_policy?(sig, normalized_args)
98
-
99
- cost
100
- end
101
-
102
- private
103
-
104
- # Convert a shape array (symbols/integers) to normalized Dimension array
105
- def normalize_shape(shape)
106
- shape.map do |dim|
107
- case dim
108
- when Symbol
109
- Dimension.new(dim)
110
- when Integer
111
- Dimension.new(dim)
112
- else
113
- raise SignatureMatchError, "Invalid dimension type: #{dim.class}"
114
- end
115
- end
116
- end
117
-
118
- # Calculate cost of matching one argument against expected dimensions
119
- def match_argument_cost(got:, expected:)
120
- # Handle scalar first
121
- if got.empty?
122
- return expected.empty? ? 0 : (expected.any?(&:flexible?) ? 10 : 1) # scalar broadcast or flexible-tail
123
- end
124
-
125
- # Try strict matching first if no flexible dimensions
126
- if !expected.any?(&:flexible?) && got.length == expected.length
127
- total = 0
128
- got.zip(expected).each do |g, e|
129
- c = match_dimension_cost(got: g, expected: e)
130
- return nil if c.nil?
131
- total += c
132
- end
133
- return total
134
- end
135
-
136
- # Use right-aligned flexible matching
137
- right_align_match(got: got, expected: expected)
138
- end
139
-
140
- # Right-aligned matching for flexible dimensions (NEP 20 ? modifier)
141
- def right_align_match(got:, expected:)
142
- gi = got.length - 1
143
- ei = expected.length - 1
144
- cost = 0
145
-
146
- while ei >= 0
147
- exp = expected[ei]
148
-
149
- if exp.flexible? && gi < 0
150
- # optional tail dimension that we don't have → ok, consume expected only
151
- ei -= 1
152
- cost += 10
153
- next
154
- end
155
-
156
- return nil if gi < 0 # ran out of got dims and exp wasn't flexible
157
-
158
- got_dim = got[gi]
159
- dim_cost = match_dimension_cost(got: got_dim, expected: exp)
160
- if dim_cost.nil?
161
- # if exp is flexible, we can try to drop it
162
- if exp.flexible?
163
- ei -= 1
164
- cost += 10
165
- next
166
- else
167
- return nil
168
- end
169
- else
170
- cost += dim_cost
171
- gi -= 1
172
- ei -= 1
173
- end
174
- end
175
-
176
- # if we still have leftover got dims, argument is longer than expected → not a match
177
- return nil if gi >= 0
178
-
179
- cost
180
- end
181
-
182
- # Calculate cost of matching one dimension against another
183
- def match_dimension_cost(got:, expected:)
184
- return 0 if got == expected # Exact match
185
-
186
- # Fixed-size equality
187
- if got.fixed_size? && expected.fixed_size?
188
- return got.size == expected.size ? 0 : nil
189
- end
190
-
191
- # Same symbolic name (ignoring modifiers) → ok unless one is fixed and the other isn't (penalize)
192
- if got.named? && expected.named? && got.name == expected.name
193
- return (got.fixed_size? || expected.fixed_size?) ? 2 : 0
194
- end
195
-
196
- # Broadcastable expected dim accepts scalar or size-1
197
- if expected.broadcastable?
198
- # scalar at argument level would have been handled in match_argument_cost
199
- # so here we check for size-1 fixed dimensions
200
- return 3 if got.fixed_size? && got.size == 1
201
- # Named dimensions that could be size-1 at runtime also get broadcast cost
202
- return 3 if got.named?
203
- end
204
-
205
- nil # No match possible
206
- end
207
-
208
- # Check if join_policy constraints are satisfied
209
- def valid_join_policy?(sig, normalized_args)
210
- return true if sig.join_policy # :zip or :product allows different axes
211
-
212
- # nil join_policy: check if dimension names are consistent
213
- non_scalar_args = normalized_args.reject { |a| Shape.scalar?(a) }
214
- return true if non_scalar_args.empty?
215
-
216
- # For nil join_policy, we allow different dimension names if:
217
- # 1. All args have same dimension names (element-wise operations), OR
218
- # 2. The constraint solver can validate cross-dimensional consistency (like matmul)
219
- first_names = non_scalar_args.first.map(&:name)
220
- same_names = non_scalar_args.all? { |arg| arg.map(&:name) == first_names }
221
-
222
- return true if same_names
223
-
224
- # If dimension names differ, check if constraint solver can handle it
225
- # This allows operations like matmul where dimensions are linked across arguments
226
- env = build_dimension_environment(sig, normalized_args)
227
- !env.nil?
228
- end
229
-
230
- def mismatch_message(signatures, arg_shapes)
231
- sigs = signatures.map(&:inspect).join(", ")
232
- "no matching signature for shapes #{pp_shapes(arg_shapes)} among [#{sigs}]"
233
- end
234
-
235
- def pp_shapes(shapes)
236
- shapes.map { |ax| "(#{ax.join(',')})" }.join(", ")
237
- end
238
-
239
- # Build dimension environment by checking consistency of named dimensions across arguments
240
- def build_dimension_environment(sig, normalized_args)
241
- env = {}
242
-
243
- # Walk all expected dimensions across all arguments
244
- sig.in_shapes.each_with_index do |expected_shape, arg_idx|
245
- got_shape = normalized_args[arg_idx] || []
246
-
247
- expected_shape.each_with_index do |exp_dim, dim_idx|
248
- next unless exp_dim.named? && dim_idx < got_shape.length
249
-
250
- got_dim = got_shape[dim_idx]
251
- dim_name = exp_dim.name
252
-
253
- # Check for consistency: same dimension name must map to same concrete value
254
- if env.key?(dim_name)
255
- # If we've seen this dimension name before, it must match
256
- if env[dim_name] != got_dim
257
- return nil # Inconsistent binding - signature doesn't match
258
- end
259
- else
260
- # First time seeing this dimension name - record the binding
261
- env[dim_name] = got_dim
262
- end
263
- end
264
- end
265
-
266
- env
267
- end
268
- end
269
- end
270
- end
271
- end
272
- end