kumi 0.0.15 → 0.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/golden/cascade_logic/schema.kumi +3 -1
- data/lib/kumi/analyzer.rb +11 -9
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +0 -81
- data/lib/kumi/core/analyzer/passes/ir_dependency_pass.rb +18 -20
- data/lib/kumi/core/analyzer/passes/ir_execution_schedule_pass.rb +67 -0
- data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +0 -36
- data/lib/kumi/core/analyzer/passes/toposorter.rb +1 -39
- data/lib/kumi/core/analyzer/passes/unsat_detector.rb +8 -191
- data/lib/kumi/core/compiler/access_builder.rb +20 -10
- data/lib/kumi/core/compiler/access_codegen.rb +61 -0
- data/lib/kumi/core/compiler/access_emit/base.rb +173 -0
- data/lib/kumi/core/compiler/access_emit/each_indexed.rb +56 -0
- data/lib/kumi/core/compiler/access_emit/materialize.rb +45 -0
- data/lib/kumi/core/compiler/access_emit/ravel.rb +50 -0
- data/lib/kumi/core/compiler/access_emit/read.rb +32 -0
- data/lib/kumi/core/ir/execution_engine/interpreter.rb +36 -181
- data/lib/kumi/core/ir/execution_engine/values.rb +8 -8
- data/lib/kumi/core/ir/execution_engine.rb +3 -19
- data/lib/kumi/dev/parse.rb +12 -12
- data/lib/kumi/runtime/executable.rb +22 -175
- data/lib/kumi/runtime/run.rb +105 -0
- data/lib/kumi/schema.rb +8 -13
- data/lib/kumi/version.rb +1 -1
- data/lib/kumi.rb +3 -2
- metadata +10 -25
- data/BACKLOG.md +0 -34
- data/config/functions.yaml +0 -352
- data/docs/functions/analyzer_integration.md +0 -199
- data/docs/functions/signatures.md +0 -171
- data/examples/hash_objects_demo.rb +0 -138
- data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +0 -199
- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +0 -48
- data/lib/kumi/core/functions/dimension.rb +0 -98
- data/lib/kumi/core/functions/dtypes.rb +0 -20
- data/lib/kumi/core/functions/errors.rb +0 -11
- data/lib/kumi/core/functions/kernel_adapter.rb +0 -45
- data/lib/kumi/core/functions/loader.rb +0 -119
- data/lib/kumi/core/functions/registry_v2.rb +0 -68
- data/lib/kumi/core/functions/shape.rb +0 -70
- data/lib/kumi/core/functions/signature.rb +0 -122
- data/lib/kumi/core/functions/signature_parser.rb +0 -86
- data/lib/kumi/core/functions/signature_resolver.rb +0 -272
- data/lib/kumi/kernels/ruby/aggregate_core.rb +0 -105
- data/lib/kumi/kernels/ruby/datetime_scalar.rb +0 -21
- data/lib/kumi/kernels/ruby/mask_scalar.rb +0 -15
- data/lib/kumi/kernels/ruby/scalar_core.rb +0 -63
- data/lib/kumi/kernels/ruby/string_scalar.rb +0 -19
- data/lib/kumi/kernels/ruby/vector_struct.rb +0 -39
@@ -1,68 +0,0 @@
|
|
1
|
-
module Kumi
  module Core
    module Functions
      # Registry of function definitions grouped by name. Every opset version of a
      # function is kept, sorted by opset, so the newest version is the default.
      class RegistryV2
        # @param functions [Array] function definitions; each must respond to #name and #opset
        def initialize(functions:)
          @by_name = functions.group_by(&:name)
                              .transform_values { |versions| versions.sort_by(&:opset).freeze }
                              .freeze
        end

        # Factory method to load from YAML configuration.
        #
        # @param path [String, nil] defaults to config/functions.yaml four directories above this file
        def self.load_from_file(path = nil)
          path ||= File.join(__dir__, "../../../..", "config", "functions.yaml")
          new(functions: Loader.load_file(path))
        end

        # Look up a function definition.
        #
        # @param name the function name
        # @param opset [Object, nil] a specific opset version; the newest is returned when omitted
        # @raise [KeyError] when the name is unknown, or the requested opset does not exist
        #   (previously a missing opset silently returned nil)
        def fetch(name, opset: nil)
          versions = @by_name.fetch(name) { raise KeyError, "unknown function #{name}" }
          return versions.last unless opset

          versions.find { |fn| fn.opset == opset } ||
            raise(KeyError, "unknown opset #{opset.inspect} for function #{name}")
        end

        # Signature strings for NEP-20 signature resolution.
        # This bridges RegistryV2 with the SignatureResolver.
        #
        # @return [Array<String>] empty when the function (or opset) is unknown, so callers
        #   can fall back to the legacy registry
        def get_function_signatures(name, opset: nil)
          fetch(name, opset: opset).signatures.map(&:to_signature_string)
        rescue KeyError
          []
        end

        # Resolve the best-matching signature of +fn+ for the given argument shapes.
        #
        # @return [Hash] { function:, signature:, plan: } where plan is the resolver result
        # @raise [ArgumentError] when no signature matches
        def choose_signature(fn, arg_shapes)
          # Round-trip through strings so the resolver always sees parser-built signatures.
          parsed_sigs = fn.signatures.map { |sig| SignatureParser.parse(sig.to_signature_string) }
          plan = SignatureResolver.choose(signatures: parsed_sigs, arg_shapes: arg_shapes)

          { function: fn, signature: plan[:signature], plan: plan }
        rescue SignatureMatchError => e
          raise ArgumentError, "no matching signature for #{fn.name} with shapes #{arg_shapes.inspect}: #{e.message}"
        end

        # Pick the highest-priority kernel for +backend+ that satisfies +conditions+.
        #
        # A kernel satisfies a condition when its +when_+ constraints either omit that
        # key or hold the same value; a kernel with no +when_+ at all is unconstrained.
        #
        # @raise [RuntimeError] when no kernel qualifies
        def resolve_kernel(fn, backend:, conditions: {})
          wanted = backend.to_sym
          kernels = fn.kernels.select { |k| k.backend == wanted }
          unless conditions.empty?
            kernels = kernels.select do |k|
              constraints = k.when_ || {}
              conditions.all? { |key, value| constraints.fetch(key, value) == value }
            end
          end

          kernels.max_by(&:priority) || raise("no kernel for #{fn.name} backend=#{backend}")
        end

        # Introspection methods

        def all_function_names
          @by_name.keys
        end

        def function_exists?(name)
          @by_name.key?(name)
        end

        # All definitions, every opset, grouped by name in registration order.
        def all_functions
          @by_name.values.flatten(1)
        end
      end
    end
  end
end
|
@@ -1,70 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require_relative "dimension"
|
4
|
-
|
5
|
-
module Kumi
  module Core
    module Functions
      # Shape utilities with NEP 20 support. A "shape" is an Array<Dimension>.
      # [] == scalar, [Dimension.new(:i)] == vector along :i, etc.
      module Shape
        module_function

        def scalar?(shape) = shape.empty?

        # Shapes are equal when their dimension names match position by position.
        # NOTE(review): as a module_function this also defines Shape.equal?, which
        # shadows BasicObject#equal? on the Shape module object itself.
        def equal?(a, b) = a.map(&:name) == b.map(&:name)

        # NEP 20 broadcast rules:
        # - a scalar can broadcast to any expected shape
        # - otherwise ranks must match and every dimension pair must be compatible
        #   (see dimension_broadcastable_to?)
        def broadcastable?(got:, expected:)
          return true if scalar?(got)
          return false unless got.length == expected.length

          got.zip(expected).all? { |got_dim, exp_dim| dimension_broadcastable_to?(got: got_dim, expected: exp_dim) }
        end

        # Can a single dimension +got+ stand in for +expected+?
        # - identical dimensions are compatible
        # - same name: fixed sizes must agree; one fixed and one not is incompatible
        # - different names: only a |1 (broadcastable) expected dim accepts a size-1 dim
        #
        # Kept separate from broadcastable_dimension?(dim): the two used to share a
        # name, and the later one-argument definition silently replaced this one.
        def dimension_broadcastable_to?(got:, expected:)
          return true if got == expected

          if got.name == expected.name
            return true unless got.fixed_size? || expected.fixed_size?

            return got.fixed_size? && expected.fixed_size? && got.size == expected.size
          end

          expected.broadcastable? && got.fixed_size? && got.size == 1
        end

        # Check if a dimension can be omitted (NEP 20 flexible dimensions).
        def flexible?(dim) = dim.is_a?(Dimension) && dim.flexible?

        # Check if a dimension carries the |1 modifier (NEP 20 broadcastable dimensions).
        def broadcastable_dimension?(dim) = dim.is_a?(Dimension) && dim.broadcastable?

        # Names of dimensions in +a+ that do not appear in +b+, first-occurrence order, no repeats.
        def dimensions_minus(a, b)
          (a.map(&:name) - b.map(&:name)).uniq
        end
      end
    end
  end
end
|
@@ -1,122 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Core
    module Functions
      # Signature is a small immutable value object describing a function's
      # vectorization contract with NEP 20 support.
      #
      #   in_shapes:   Array<Array<Dimension>>  e.g., [[Dimension.new(:i)], [Dimension.new(:i)]]
      #   out_shape:   Array<Dimension>         e.g., [Dimension.new(:i)], or [] for a scalar result
      #   join_policy: nil | :zip | :product
      #   raw:         original string, for diagnostics (optional)
      class Signature
        attr_reader :in_shapes, :out_shape, :join_policy, :raw

        # @raise [SignatureError] when the shapes or join_policy are malformed
        def initialize(in_shapes:, out_shape:, join_policy: nil, raw: nil)
          @in_shapes = deep_dup(in_shapes).freeze
          @out_shape = out_shape.dup.freeze
          @join_policy = join_policy&.to_sym
          @raw = raw
          validate!
          freeze
        end

        def arity = @in_shapes.length

        # Dimensions that appear in any input but not in output (i.e., reduced/dropped).
        def dropped_axes
          input_names = @in_shapes.flatten.map(&:name)
          output_names = @out_shape.map(&:name)
          (input_names - output_names).uniq.freeze
        end

        # True if any axis from inputs is dropped (common in aggregates).
        def reduction?
          !dropped_axes.empty?
        end

        # Mutable copy of the signature's parts.
        def to_h
          {
            in_shapes: in_shapes.map(&:dup),
            out_shape: out_shape.dup,
            join_policy: join_policy,
            raw: raw
          }
        end

        def inspect
          "#<Signature #{format_signature}#{" @#{join_policy}" if join_policy}>"
        end

        # e.g. "(i),(i)->(i)"
        def format_signature
          lhs = in_shapes.map { |dims| "(#{dims.map(&:to_s).join(',')})" }.join(",")
          rhs = "(#{out_shape.map(&:to_s).join(',')})"
          "#{lhs}->#{rhs}"
        end

        # Convert back to string representation for NEP-20 parser compatibility,
        # e.g. "(i),(j)->(i,j)@product".
        def to_signature_string
          sig_str = format_signature
          join_policy ? "#{sig_str}@#{join_policy}" : sig_str
        end

        private

        def validate!
          unless @in_shapes.is_a?(Array) && @in_shapes.all? { |s| s.is_a?(Array) }
            raise SignatureError, "in_shapes must be an array of dimension arrays"
          end

          @in_shapes.each_with_index do |dims, idx|
            validate_dimensions!(dims, where: "in_shapes[#{idx}]")
          end

          validate_dimensions!(@out_shape, where: "out_shape")

          unless [nil, :zip, :product].include?(@join_policy)
            raise SignatureError, "join_policy must be nil, :zip, or :product; got #{@join_policy.inspect}"
          end

          # Validate NEP 20 constraints
          validate_nep20_constraints!
        end

        def validate_dimensions!(dims, where: "shape")
          unless dims.is_a?(Array) && dims.all? { |d| d.is_a?(Dimension) }
            raise SignatureError, "#{where}: must be an array of Dimension objects, got: #{dims.inspect}"
          end

          # Symbolic names must be unique within one argument. Fixed sizes are exempt:
          # a repeated size such as (3,3) is a legal 3x3 shape, not a naming clash.
          duplicates = dims.select(&:named?).map(&:name).tally.select { |_, count| count > 1 }.keys
          raise SignatureError, "#{where}: duplicate dimension names #{duplicates.inspect}" unless duplicates.empty?

          true
        end

        def validate_nep20_constraints!
          # Broadcastable dimensions should only appear in inputs, not outputs
          @out_shape.each do |dim|
            raise SignatureError, "output dimension #{dim} cannot be broadcastable" if dim.broadcastable?
          end

          # Fixed-size dimensions in outputs must match corresponding input dimensions
          all_input_dims = @in_shapes.flatten
          @out_shape.each do |out_dim|
            next unless out_dim.fixed_size?

            matching_inputs = all_input_dims.select { |in_dim| in_dim.name == out_dim.name }
            matching_inputs.each do |in_dim|
              if in_dim.fixed_size? && in_dim.size != out_dim.size
                raise SignatureError, "fixed-size dimension #{out_dim.name} has inconsistent sizes: #{in_dim.size} vs #{out_dim.size}"
              end
            end
          end
        end

        # Copy the outer array and freeze each inner shape so the value object is
        # immutable all the way down. Non-arrays pass through untouched so validate!
        # can report them as a SignatureError.
        def deep_dup(arr)
          return arr unless arr.is_a?(Array)

          arr.map { |x| x.is_a?(Array) ? x.dup.freeze : x }
        end
      end
    end
  end
end
|
@@ -1,86 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require_relative "dimension"
|
4
|
-
require_relative "signature"
|
5
|
-
|
6
|
-
module Kumi
  module Core
    module Functions
      # Parses NEP 20 conformant signature strings like:
      #   "(),()->()"                   # scalar operations
      #   "(i),(i)->(i)"                # vector operations
      #   "(i),(j)->(i,j)@product"      # matrix operations with join policy
      #   "(i,j)->(i)"                  # reduction of :j
      #   "(3),(3)->(3)"                # fixed-size 3-vectors (cross product)
      #   "(i?),(i?)->(i?)"             # flexible dimensions
      #   "(i|1),(i|1)->()"             # broadcastable dimensions
      #   "(m?,n),(n,p?)->(m?,p?)"      # matmul signature
      class SignatureParser
        # One parenthesised argument group at the head of the string, followed by a
        # separating comma or the end of input.
        ARG_GROUP = /\A\(([^()]*)\)\s*(?:,\s*|\z)/
        # A dimension token once its ? / |1 modifiers are stripped.
        INTEGER_DIM = /\A\d+\z/
        IDENTIFIER_DIM = /\A[A-Za-z_]\w*\z/

        class << self
          # @param str [String] signature text
          # @return [Signature]
          # @raise [SignatureParseError] for malformed text; SignatureError from
          #   Signature validation is re-raised unchanged
          def parse(str)
            raise SignatureParseError, "empty signature" if str.nil? || str.strip.empty?

            lhs, rhs = str.split("->", 2).map(&:strip)
            raise SignatureParseError, "signature must contain '->': #{str.inspect}" unless rhs

            out_spec, policy = rhs.split("@", 2).map(&:strip)

            Signature.new(
              in_shapes: parse_many(lhs),
              out_shape: parse_axes(out_spec),
              join_policy: policy&.to_sym,
              raw: str
            )
          rescue SignatureError
            raise
          rescue StandardError => e
            raise SignatureParseError, "invalid signature #{str.inspect}: #{e.message}"
          end

          private

          # Split "(a,b),(c)" into one dimension list per argument. Groups are consumed
          # one at a time so whitespace around separators ("(i) , (j)") is tolerated and
          # malformed input such as "(i)(j)" is rejected rather than mis-split.
          def parse_many(lhs)
            rest = lhs.strip
            groups = []
            until rest.empty?
              match = ARG_GROUP.match(rest)
              raise SignatureParseError, "malformed argument list #{lhs.inspect}" unless match

              groups << parse_axes("(#{match[1]})")
              rest = match.post_match
            end
            groups
          end

          def parse_axes(spec)
            spec = spec.strip
            raise SignatureParseError, "missing parentheses in #{spec.inspect}" unless spec.start_with?("(") && spec.end_with?(")")

            inner = spec[1..-2].strip
            return [] if inner.empty?

            inner.split(",").map { |dim_str| parse_dimension(dim_str.strip) }
          end

          # Parse a single dimension with NEP 20 modifiers.
          # Examples: "i", "3", "n?", "i|1". An empty token yields Dimension.new(:empty).
          def parse_dimension(dim_str)
            return Dimension.new(:empty) if dim_str.empty?

            token = dim_str
            flexible = token.end_with?("?")
            token = token.chomp("?") if flexible

            broadcastable = token.end_with?("|1")
            token = token.chomp("|1") if broadcastable

            name = case token
                   when INTEGER_DIM then token.to_i
                   when IDENTIFIER_DIM then token.to_sym
                   else raise SignatureParseError, "invalid dimension #{dim_str.inspect}: expected an identifier or integer"
                   end

            Dimension.new(name, flexible: flexible, broadcastable: broadcastable)
          rescue SignatureParseError
            raise
          rescue StandardError => e
            raise SignatureParseError, "invalid dimension #{dim_str.inspect}: #{e.message}"
          end
        end
      end
    end
  end
end
|
@@ -1,272 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require_relative "errors"
|
4
|
-
require_relative "shape"
|
5
|
-
require_relative "signature"
|
6
|
-
|
7
|
-
module Kumi
  module Core
    module Functions
      # Given a set of signatures and actual argument shapes, pick the best match.
      # Supports NEP 20 extensions: fixed-size, flexible, and broadcastable dimensions.
      #
      # Inputs:
      #   signatures : Array<Signature> (with Dimension objects)
      #   arg_shapes : Array<Array<Symbol|Integer>>  e.g., [[:i], [:i]] or [[], [3]] or [[2, :i]]
      #
      # Returns:
      #   { signature:, score:, result_axes:, join_policy:, dropped_axes:, env:, effective_signature: }
      #
      # Matching rules:
      #   - Arity must match exactly.
      #   - Fixed-size dimensions (integers) must match exactly.
      #   - A scalar argument broadcasts against any expected shape.
      #   - Flexible dimensions (?) may be omitted; such arguments are aligned from the right.
      #   - Broadcastable dimensions (|1) accept size-1 fixed dimensions and named dimensions.
      #   - Every signature dimension name must bind to the same argument dimension
      #     everywhere it is matched (the dimension environment), else the signature is rejected.
      #
      # Cost model (lower wins; ties go to the earliest signature):
      #   exact 0; scalar argument 1 (10 if the expected shape has a flexible dim);
      #   same name but one side fixed-size 2; |1 broadcast 3; each omitted flexible dim 10.
      class SignatureResolver
        class << self
          # Pick the lowest-cost signature for the given argument shapes.
          #
          # @param signatures [Array<Signature>]
          # @param arg_shapes [Array<Array<Symbol, Integer>>, nil] nil means zero arguments
          # @return [Hash] the winning candidate plus :effective_signature
          # @raise [SignatureMatchError] when arg_shapes is malformed or nothing matches
          def choose(signatures:, arg_shapes:)
            # Handle empty arg_shapes for zero-arity functions
            arg_shapes = [] if arg_shapes.nil?
            sanity_check_args!(arg_shapes)

            # Normalize once; every candidate is matched against the same Dimension arrays.
            normalized_args = arg_shapes.map { |shape| normalize_shape(shape) }
            candidates = signatures.filter_map { |sig| build_candidate(sig, normalized_args) }

            raise SignatureMatchError, mismatch_message(signatures, arg_shapes) if candidates.empty?

            # min_by keeps the earliest signature on ties.
            best = candidates.min_by { |candidate| candidate[:score] }

            # Add effective signature for analyzer/lowering
            best[:effective_signature] = {
              in_shapes: best[:signature].in_shapes.map { |dims| dims.map(&:name) },
              out_shape: best[:signature].out_shape.map(&:name),
              join_policy: best[:signature].join_policy
            }
            best
          end

          private

          def sanity_check_args!(arg_shapes)
            unless arg_shapes.is_a?(Array) &&
                   arg_shapes.all? { |s| s.is_a?(Array) && s.all? { |a| a.is_a?(Symbol) || a.is_a?(Integer) } }
              raise SignatureMatchError, "arg_shapes must be an array of dimension arrays (symbols or integers), got: #{arg_shapes.inspect}"
            end
          end

          # Convert a shape array (symbols/integers) to a Dimension array.
          def normalize_shape(shape)
            shape.map do |dim|
              case dim
              when Symbol, Integer
                Dimension.new(dim)
              else
                raise SignatureMatchError, "Invalid dimension type: #{dim.class}"
              end
            end
          end

          # Match one signature and build its result hash, or nil when it does not apply.
          def build_candidate(sig, normalized_args)
            match = match_signature(sig, normalized_args)
            return nil unless match

            env = build_dimension_environment(match[:pairs])
            return nil unless env # a signature name bound to two different argument dimensions

            {
              signature: sig,
              score: match[:score],
              result_axes: sig.out_shape.map(&:name),
              join_policy: sig.join_policy,
              # Named axes are normalized to symbols; fixed sizes stay Integers
              # (Integer has no #to_sym, which previously crashed here).
              dropped_axes: sig.dropped_axes.map { |name| name.respond_to?(:to_sym) ? name.to_sym : name },
              env: env
            }
          end

          # Align every argument with its expected shape.
          # @return [Hash, nil] { score:, pairs: [[expected_dim, got_dim], ...] } or nil
          def match_signature(sig, normalized_args)
            return nil unless sig.arity == normalized_args.length

            score = 0
            pairs = []
            sig.in_shapes.each_with_index do |expected, idx|
              aligned = align_argument(got: normalized_args[idx], expected: expected)
              return nil if aligned.nil?

              score += aligned[:cost]
              pairs.concat(aligned[:pairs])
            end
            { score: score, pairs: pairs }
          end

          # Align one argument against its expected dimensions.
          # @return [Hash, nil] { cost:, pairs: } listing which expected dim each got dim matched
          def align_argument(got:, expected:)
            # A scalar argument broadcasts (or fills a flexible tail); nothing is bound.
            if got.empty?
              cost = if expected.empty? then 0
                     elsif expected.any?(&:flexible?) then 10
                     else 1
                     end
              return { cost: cost, pairs: [] }
            end

            # Strict positional matching when there is nothing optional to skip.
            if expected.none?(&:flexible?) && got.length == expected.length
              cost = 0
              got.zip(expected).each do |g, e|
                c = match_dimension_cost(got: g, expected: e)
                return nil if c.nil?

                cost += c
              end
              return { cost: cost, pairs: expected.zip(got) }
            end

            right_align(got: got, expected: expected)
          end

          # Right-aligned matching for flexible dimensions (NEP 20 ? modifier). Records the
          # actual pairing so the dimension environment uses the same alignment.
          def right_align(got:, expected:)
            gi = got.length - 1
            ei = expected.length - 1
            cost = 0
            pairs = []

            while ei >= 0
              exp = expected[ei]

              if gi.negative?
                # Out of argument dims: only an optional dimension may remain.
                return nil unless exp.flexible?

                cost += 10
                ei -= 1
                next
              end

              dim_cost = match_dimension_cost(got: got[gi], expected: exp)
              if dim_cost
                pairs << [exp, got[gi]]
                cost += dim_cost
                gi -= 1
                ei -= 1
              elsif exp.flexible?
                # Skip the optional expected dim and retry the same argument dim.
                cost += 10
                ei -= 1
              else
                return nil
              end
            end

            # Leftover argument dims mean the argument is longer than expected.
            return nil unless gi.negative?

            { cost: cost, pairs: pairs.reverse }
          end

          # Calculate cost of matching one dimension against another (nil = incompatible).
          def match_dimension_cost(got:, expected:)
            return 0 if got == expected # Exact match

            # Fixed-size equality
            if got.fixed_size? && expected.fixed_size?
              return got.size == expected.size ? 0 : nil
            end

            # Same symbolic name (ignoring modifiers) → ok unless one is fixed and the other isn't (penalize)
            if got.named? && expected.named? && got.name == expected.name
              return (got.fixed_size? || expected.fixed_size?) ? 2 : 0
            end

            # Broadcastable expected dim accepts size-1 dims, and named dims that could be size-1 at runtime
            if expected.broadcastable?
              return 3 if got.fixed_size? && got.size == 1
              return 3 if got.named?
            end

            nil # No match possible
          end

          # Bind each named signature dimension to the argument dimension it was aligned
          # with. Returns nil when one name would have to bind two different dimensions.
          def build_dimension_environment(pairs)
            pairs.each_with_object({}) do |(exp_dim, got_dim), env|
              next unless exp_dim.named?

              name = exp_dim.name
              if env.key?(name)
                return nil unless env[name] == got_dim
              else
                env[name] = got_dim
              end
            end
          end

          def mismatch_message(signatures, arg_shapes)
            sigs = signatures.map(&:inspect).join(", ")
            "no matching signature for shapes #{pp_shapes(arg_shapes)} among [#{sigs}]"
          end

          def pp_shapes(shapes)
            shapes.map { |ax| "(#{ax.join(',')})" }.join(", ")
          end
        end
      end
    end
  end
end
|