kumi 0.0.13 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rspec +0 -1
- data/BACKLOG.md +34 -0
- data/CLAUDE.md +4 -6
- data/README.md +0 -18
- data/config/functions.yaml +352 -0
- data/docs/dev/analyzer-debug.md +52 -0
- data/docs/dev/parse-command.md +64 -0
- data/docs/functions/analyzer_integration.md +199 -0
- data/docs/functions/signatures.md +171 -0
- data/examples/hash_objects_demo.rb +138 -0
- data/golden/array_operations/schema.kumi +17 -0
- data/golden/cascade_logic/schema.kumi +16 -0
- data/golden/mixed_nesting/schema.kumi +42 -0
- data/golden/simple_math/schema.kumi +10 -0
- data/lib/kumi/analyzer.rb +72 -21
- data/lib/kumi/core/analyzer/checkpoint.rb +72 -0
- data/lib/kumi/core/analyzer/debug.rb +167 -0
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +1 -3
- data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +199 -0
- data/lib/kumi/core/analyzer/passes/load_input_cse.rb +120 -0
- data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +72 -157
- data/lib/kumi/core/analyzer/passes/toposorter.rb +37 -1
- data/lib/kumi/core/analyzer/state_serde.rb +64 -0
- data/lib/kumi/core/analyzer/structs/access_plan.rb +12 -10
- data/lib/kumi/core/compiler/access_planner.rb +3 -2
- data/lib/kumi/core/function_registry/collection_functions.rb +3 -1
- data/lib/kumi/core/functions/dimension.rb +98 -0
- data/lib/kumi/core/functions/dtypes.rb +20 -0
- data/lib/kumi/core/functions/errors.rb +11 -0
- data/lib/kumi/core/functions/kernel_adapter.rb +45 -0
- data/lib/kumi/core/functions/loader.rb +119 -0
- data/lib/kumi/core/functions/registry_v2.rb +68 -0
- data/lib/kumi/core/functions/shape.rb +70 -0
- data/lib/kumi/core/functions/signature.rb +122 -0
- data/lib/kumi/core/functions/signature_parser.rb +86 -0
- data/lib/kumi/core/functions/signature_resolver.rb +272 -0
- data/lib/kumi/core/ir/execution_engine/interpreter.rb +98 -7
- data/lib/kumi/core/ir/execution_engine/profiler.rb +202 -0
- data/lib/kumi/dev/ir.rb +75 -0
- data/lib/kumi/dev/parse.rb +105 -0
- data/lib/kumi/dev/runner.rb +83 -0
- data/lib/kumi/frontends/ruby.rb +28 -0
- data/lib/kumi/frontends/text.rb +46 -0
- data/lib/kumi/frontends.rb +29 -0
- data/lib/kumi/kernels/ruby/aggregate_core.rb +105 -0
- data/lib/kumi/kernels/ruby/datetime_scalar.rb +21 -0
- data/lib/kumi/kernels/ruby/mask_scalar.rb +15 -0
- data/lib/kumi/kernels/ruby/scalar_core.rb +63 -0
- data/lib/kumi/kernels/ruby/string_scalar.rb +19 -0
- data/lib/kumi/kernels/ruby/vector_struct.rb +39 -0
- data/lib/kumi/runtime/executable.rb +57 -26
- data/lib/kumi/schema.rb +4 -4
- data/lib/kumi/support/diff.rb +22 -0
- data/lib/kumi/support/ir_render.rb +61 -0
- data/lib/kumi/version.rb +1 -1
- data/lib/kumi.rb +2 -0
- data/performance_results.txt +63 -0
- data/scripts/test_mixed_nesting_performance.rb +206 -0
- metadata +45 -5
- data/docs/features/javascript-transpiler.md +0 -148
- data/lib/kumi/js.rb +0 -23
- data/lib/kumi/support/ir_dump.rb +0 -491
@@ -0,0 +1,98 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "errors"
|
4
|
+
|
5
|
+
module Kumi
  module Core
    module Functions
      # One axis of a function signature, with NEP 20 modifiers.
      #
      # The axis is identified by its +name+, which is either:
      #   - a Symbol for a named axis (:i, :j, :n), or
      #   - a positive Integer for a fixed-size axis (2, 3, 10).
      #
      # Optional modifiers:
      #   - flexible ("?"): the axis can be omitted if not present in all operands
      #   - broadcastable ("|1"): the axis can broadcast against size-1 dimensions
      #
      # A dimension cannot be both flexible and broadcastable, and a fixed-size
      # dimension cannot be flexible. Instances are frozen on construction.
      #
      # Examples:
      #   Dimension.new(:i)                       # named dimension 'i'
      #   Dimension.new(3)                        # fixed-size dimension of size 3
      #   Dimension.new(:n, flexible: true)       # dimension 'n' that can be omitted
      #   Dimension.new(:i, broadcastable: true)  # dimension 'i' that can broadcast
      class Dimension
        attr_reader :name, :flexible, :broadcastable

        # @param name [Symbol, Integer] axis name or fixed size
        # @param flexible [Boolean] whether the axis carries the "?" modifier
        # @param broadcastable [Boolean] whether the axis carries the "|1" modifier
        # @raise [SignatureError] when the combination is invalid
        def initialize(name, flexible: false, broadcastable: false)
          @name = name
          @flexible = flexible
          @broadcastable = broadcastable

          validate!
          freeze
        end

        # True when the axis is identified by an Integer size.
        def fixed_size? = @name.is_a?(Integer)

        # True when the axis is identified by a Symbol.
        def named? = @name.is_a?(Symbol)

        def flexible? = @flexible

        def broadcastable? = @broadcastable

        # The fixed size, or nil for a named axis.
        def size
          return nil unless fixed_size?

          @name
        end

        # Value equality over name and both modifiers.
        def ==(other)
          return false unless other.is_a?(Dimension)

          [name, flexible, broadcastable] == [other.name, other.flexible, other.broadcastable]
        end

        def eql?(other) = self == other

        # Consistent with #== so instances can be used as Hash keys.
        def hash = [name, flexible, broadcastable].hash

        # Renders the signature notation, e.g. "i", "3", "n?", "i|1".
        def to_s
          "#{name}#{'?' if flexible?}#{'|1' if broadcastable?}"
        end

        def inspect = "#<Dimension #{to_s}>"

        private

        # Checks are performed in a fixed order: name type, positivity of a
        # fixed size, then the two modifier conflicts.
        def validate!
          case name
          when Integer
            raise SignatureError, "fixed-size dimension must be positive, got: #{name}" unless name.positive?
          when Symbol
            nil
          else
            raise SignatureError, "dimension name must be a symbol or integer, got: #{name.inspect}"
          end

          raise SignatureError, "dimension cannot be both flexible and broadcastable" if flexible? && broadcastable?

          raise SignatureError, "fixed-size dimension cannot be flexible" if fixed_size? && flexible?
        end
      end
    end
  end
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
module Kumi
  module Core
    module Functions
      # Minimal descriptor for a data type, identified by its +name+ symbol.
      DType = Struct.new(:name, keyword_init: true)

      # The built-in data type descriptors.
      module DTypes
        BOOL     = DType.new(name: :bool)
        INT      = DType.new(name: :int)
        FLOAT    = DType.new(name: :float)
        STRING   = DType.new(name: :string)
        DATETIME = DType.new(name: :datetime)
        ANY      = DType.new(name: :any)
      end

      # Binary type promotion over dtype name symbols.
      module Promotion
        # super simple table; extend later
        # Frozen so the shared lookup table cannot be mutated at runtime.
        TABLE = {
          %i[int int] => :int, %i[int float] => :float, %i[float int] => :float, %i[float float] => :float,
          %i[bool int] => :int, %i[bool float] => :float, %i[bool bool] => :bool
        }.freeze

        # Returns the promoted dtype name for the pair (a, b).
        #
        # The table is consulted in both argument orders; pairs that appear in
        # neither order promote to :any.
        #
        # @param a [Symbol] dtype name
        # @param b [Symbol] dtype name
        # @return [Symbol]
        def self.promote(a, b) = TABLE[[a, b]] || TABLE[[b, a]] || :any
      end
    end
  end
end
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Kumi
  module Core
    module Functions
      # Result of binding a function to a Ruby kernel: the function class
      # (+kind+), the bound Method (+callable+), and the function's null
      # policy and options.
      KernelHandle = Struct.new(:kind, :callable, :null_policy, :options, keyword_init: true)

      # Binds function descriptors to the Ruby kernel modules implementing them.
      module KernelAdapter
        module_function

        # Builds a KernelHandle for +function+ using the implementation named
        # by +backend_entry.impl+.
        #
        # @param function [#name, #domain, #class_sym, #null_policy, #options]
        # @param backend_entry [#impl]
        # @return [KernelHandle]
        # @raise [Kumi::Core::Errors::CompilationError] when the kernel module
        #   does not respond to the implementation name, or the domain is unknown
        def build_for(function, backend_entry)
          kernel_name = backend_entry.impl
          host = ruby_module_for(function)

          unless host.respond_to?(kernel_name)
            raise Kumi::Core::Errors::CompilationError, "Missing Ruby kernel #{kernel_name} for #{function.name}"
          end

          KernelHandle.new(
            kind: function.class_sym,
            callable: host.method(kernel_name),
            null_policy: function.null_policy,
            options: function.options || {}
          )
        end

        # Selects the kernel module for the function's domain. Within the
        # :core domain, aggregates and all other classes use separate modules.
        #
        # @raise [Kumi::Core::Errors::CompilationError] for an unknown domain
        def ruby_module_for(function)
          case function.domain.to_sym
          when :core
            if function.class_sym == :aggregate
              Kumi::Kernels::Ruby::AggregateCore
            else
              Kumi::Kernels::Ruby::ScalarCore
            end
          when :string   then Kumi::Kernels::Ruby::StringScalar
          when :datetime then Kumi::Kernels::Ruby::DatetimeScalar
          when :struct   then Kumi::Kernels::Ruby::VectorStruct
          when :mask     then Kumi::Kernels::Ruby::MaskScalar
          else
            raise Kumi::Core::Errors::CompilationError, "Unknown domain #{function.domain}"
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,119 @@
|
|
1
|
+
require "yaml"
|
2
|
+
module Kumi
  module Core
    module Functions
      # Immutable record for one function entry of the YAML catalogue.
      Function = Data.define(:name, :domain, :opset, :class_sym, :signatures,
                             :type_vars, :dtypes, :null_policy, :algebra, :vectorization,
                             :options, :traits, :shape_fn, :kernels)

      # Immutable record for one kernel implementation of a function.
      KernelEntry = Data.define(:backend, :impl, :priority, :when_)

      # Reads function definitions from a YAML file and turns them into
      # Function records.
      class Loader
        # One parenthesised, non-nested shape group such as "(i,j)" or "()".
        SHAPE_GROUP = /\([^()]*\)/

        # Loads every function defined in the YAML file at +path+.
        #
        # @param path [String] path of the YAML file
        # @return [Array<Function>] frozen list of functions
        # @raise [ArgumentError] when the document is not a list
        # @raise [RuntimeError] on duplicate functions or functions without kernels
        def self.load_file(path)
          doc = YAML.load_file(path)
          raise ArgumentError, "#{path}: expected a YAML list of functions, got #{doc.class}" unless doc.is_a?(Array)

          functions = doc.map { |h| build_function(h) }
          validate!(functions)
          functions.freeze
        end

        # Builds a Function from one string-keyed hash of the YAML document.
        # The input hash is not modified.
        #
        # @param h [Hash] requires "name", "domain", "opset", "class" and "signature"
        # @return [Function]
        # @raise [KeyError] when a required key is missing
        def self.build_function(h)
          sigs = Array(h.fetch("signature")).map { |s| parse_signature(s) }
          kernels = Array(h.fetch("kernels", [])).map do |k|
            KernelEntry.new(backend: k.fetch("backend").to_sym,
                            impl: k.fetch("impl").to_sym,
                            priority: k.fetch("priority", 0).to_i,
                            when_: k["when"]&.transform_keys(&:to_sym))
          end

          Function.new(
            name: h.fetch("name"),
            domain: h.fetch("domain"),
            opset: h.fetch("opset").to_i,
            class_sym: h.fetch("class").to_sym,
            signatures: sigs,
            type_vars: h["type_vars"] || {},
            dtypes: h["dtypes"] || {},
            null_policy: (h["null_policy"] || "propagate").to_sym,
            algebra: (h["algebra"] || {}).transform_keys(&:to_sym),
            vectorization: (h["vectorization"] || {}).transform_keys(&:to_sym),
            options: h["options"] || {},
            traits: (h["traits"] || {}).transform_keys(&:to_sym),
            shape_fn: h["shape_fn"],
            kernels: kernels
          )
        end

        # " (i),(j)->(i,j)@product " → Signature
        #
        # @param s [String] signature text; an optional "@policy" suffix on the
        #   output names the join policy
        # @return [Signature]
        # @raise [ArgumentError] when the text has no "->" separator
        def self.parse_signature(s)
          lhs, rhs = s.split("->").map(&:strip)
          raise ArgumentError, "signature must contain '->': #{s.inspect}" if rhs.nil?

          out_axes, policy = rhs.split("@").map(&:strip)

          Signature.new(in_shapes: parse_input_shapes(lhs.to_s),
                        # out_axes is nil when the output side is blank
                        out_shape: parse_axes(out_axes.to_s),
                        join_policy: policy&.to_sym)
        end

        # Parse "(i),(j)" or "(i,j),(k,l)" into one shape per parenthesised
        # group. Text outside the groups (separators, whitespace) is ignored.
        #
        # @param lhs [String]
        # @return [Array<Array<Dimension>>]
        def self.parse_input_shapes(lhs)
          lhs.scan(SHAPE_GROUP).map { |group| parse_axes(group) }
        end

        # Parses one shape such as "(i,j)" into Dimension objects. "()" and
        # blank text yield the empty (scalar) shape.
        #
        # @param txt [String]
        # @return [Array<Dimension>]
        def self.parse_axes(txt)
          inner = txt.strip.sub("(", "").sub(")", "").strip
          return [] if inner.empty?

          inner.split(",").map { |token| parse_dimension(token.strip) }
        end

        # Rejects duplicate (domain, name, opset) triples and functions that
        # declare no kernels.
        #
        # @param fns [Array<Function>]
        # @raise [RuntimeError]
        def self.validate!(fns)
          seen = {}
          fns.each do |f|
            key = [f.domain, f.name, f.opset]
            raise "duplicate function #{key}" if seen[key]

            seen[key] = true
            raise "no kernels for #{f.name}" if f.kernels.empty?
          end
        end

        # Parses one dimension token. A trailing "?" marks it flexible, a
        # trailing "|1" marks it broadcastable, and an all-digit name denotes
        # a fixed size; the same rules SignatureParser applies.
        def self.parse_dimension(token)
          flexible = token.end_with?("?")
          token = token.chomp("?") if flexible

          broadcastable = token.end_with?("|1")
          token = token.chomp("|1") if broadcastable

          name = token.match?(/\A\d+\z/) ? token.to_i : token.to_sym
          Dimension.new(name, flexible: flexible, broadcastable: broadcastable)
        end
        private_class_method :parse_dimension
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
module Kumi
  module Core
    module Functions
      # Lookup table of function definitions, indexed by name. Each name maps
      # to its versions sorted by ascending opset.
      class RegistryV2
        # @param functions [Array<#name, #opset>] function records
        def initialize(functions:)
          @by_name = functions.group_by(&:name).transform_values { |v| v.sort_by(&:opset).freeze }.freeze
        end

        # Factory method to load from YAML configuration.
        #
        # @param path [String, nil] YAML file; defaults to config/functions.yaml
        #   resolved four directories above this file
        # @return [RegistryV2]
        def self.load_from_file(path = nil)
          path ||= File.join(__dir__, "../../../..", "config", "functions.yaml")
          new(functions: Loader.load_file(path))
        end

        # Returns the function registered under +name+.
        #
        # @param name [Object] function name, as it appears on the records
        # @param opset [Integer, nil] exact opset to select; when nil the
        #   highest opset is returned
        # @raise [KeyError] when the name is unknown, or when no version with
        #   the requested opset exists
        def fetch(name, opset: nil)
          versions = @by_name[name]
          raise KeyError, "unknown function #{name}" unless versions
          return versions.last unless opset

          versions.find { |f| f.opset == opset } ||
            raise(KeyError, "unknown function #{name} at opset #{opset}")
        end

        # Get function signatures for NEP-20 signature resolution.
        # This bridges RegistryV2 with the existing SignatureResolver.
        #
        # @return [Array<String>] signature strings; empty when the function
        #   (or the requested opset) is not in this registry, so callers can
        #   fall back to the legacy registry
        def get_function_signatures(name, opset: nil)
          fetch(name, opset: opset).signatures.map(&:to_signature_string)
        rescue KeyError
          []
        end

        # Enhanced signature resolution using the NEP-20 resolver.
        #
        # Signatures are round-tripped through their string form and parsed by
        # SignatureParser before resolution.
        #
        # @return [Hash] :function, :signature (taken from the plan) and :plan
        # @raise [ArgumentError] when no signature matches +arg_shapes+
        def choose_signature(fn, arg_shapes)
          parsed_sigs = fn.signatures.map { |sig| SignatureParser.parse(sig.to_signature_string) }

          plan = SignatureResolver.choose(signatures: parsed_sigs, arg_shapes: arg_shapes)

          # Return both the original function and the resolution plan
          {
            function: fn,
            signature: plan[:signature],
            plan: plan
          }
        rescue SignatureMatchError => e
          raise ArgumentError, "no matching signature for #{fn.name} with shapes #{arg_shapes.inspect}: #{e.message}"
        end

        # Picks the highest-priority kernel of +fn+ for +backend+.
        #
        # When +conditions+ are given, a kernel qualifies unless its `when`
        # clause declares a different value for one of the condition keys.
        # Kernels without a `when` clause therefore match any conditions.
        #
        # @param backend [#to_sym]
        # @param conditions [Hash]
        # @raise [RuntimeError] when no kernel qualifies
        def resolve_kernel(fn, backend:, conditions: {})
          wanted = backend.to_sym
          candidates = fn.kernels.select do |kernel|
            kernel.backend == wanted && kernel_matches?(kernel, conditions)
          end

          candidates.max_by(&:priority) || raise("no kernel for #{fn.name} backend=#{backend}")
        end

        # Introspection methods
        def all_function_names
          @by_name.keys
        end

        def function_exists?(name)
          @by_name.key?(name)
        end

        def all_functions
          @by_name.values.flatten
        end

        private

        # A missing `when` clause is treated as an empty one: keys the kernel
        # does not declare never disqualify it.
        def kernel_matches?(kernel, conditions)
          declared = kernel.when_ || {}
          conditions.all? { |key, value| declared.fetch(key, value) == value }
        end
      end
    end
  end
end
|
@@ -0,0 +1,70 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "dimension"
|
4
|
+
|
5
|
+
module Kumi
  module Core
    module Functions
      # Shape utilities with NEP 20 support. A "shape" is an Array<Dimension> of dimension objects.
      # [] == scalar, [Dimension.new(:i)] == vector along :i, etc.
      module Shape
        module_function

        # A shape with no dimensions is a scalar.
        def scalar?(shape) = shape.empty?

        # Two shapes are equal when their dimension names match in order;
        # modifiers are not compared.
        def equal?(a, b) = a.map(&:name) == b.map(&:name)

        # NEP 20 broadcast rules:
        # - scalar can broadcast to any expected shape
        # - fixed-size dimensions must match exactly
        # - broadcastable dimensions with |1 modifier can broadcast against size-1
        #
        # Non-scalar shapes must have the same rank and be compatible
        # dimension by dimension.
        #
        # @param got [Array<Dimension>] shape of the actual argument
        # @param expected [Array<Dimension>] shape required by the signature
        # @return [Boolean]
        def broadcastable?(got:, expected:)
          return true if scalar?(got)
          return false if got.length != expected.length

          got.zip(expected).all? do |got_dim, exp_dim|
            broadcastable_dimension?(got: got_dim, expected: exp_dim)
          end
        end

        # Supports both call forms under one name:
        #
        #   broadcastable_dimension?(dim)
        #     true when +dim+ is a Dimension carrying the |1 modifier.
        #
        #   broadcastable_dimension?(got: d1, expected: d2)
        #     true when +d1+ can stand in for +d2+ under the NEP 20 rules.
        #
        # A single method is needed because a second `def` with the same name
        # would replace the first.
        def broadcastable_dimension?(dim = nil, got: nil, expected: nil)
          return dimension_pair_compatible?(got, expected) if got && expected

          dim.is_a?(Dimension) && dim.broadcastable?
        end

        # Check if a dimension can be omitted (NEP 20 flexible dimensions)
        def flexible?(dim)
          dim.is_a?(Dimension) && dim.flexible?
        end

        # Convenience: names of dimensions in a that are not in b, without
        # duplicates and in order of first appearance in a.
        def dimensions_minus(a, b)
          (a.map(&:name) - b.map(&:name)).uniq
        end

        # Pairwise compatibility check behind broadcastable_dimension?(got:, expected:).
        def dimension_pair_compatible?(got, expected)
          # Same name and modifiers
          return true if got == expected

          # Same name, different modifiers - check compatibility
          if got.name == expected.name
            # Fixed-size dimensions must match exactly
            if got.fixed_size? || expected.fixed_size?
              return got.size == expected.size if got.fixed_size? && expected.fixed_size?

              return false # one fixed, one not - incompatible
            end

            # Both named dimensions with same name are compatible
            return true
          end

          # Different names - only a size-1 dimension may broadcast against an
          # expected dimension carrying the |1 modifier.
          # NOTE(review): size-1 is read here as a fixed-size dimension of 1 - confirm.
          expected.broadcastable? && got.fixed_size? && got.size == 1
        end
        private_class_method :dimension_pair_compatible?
      end
    end
  end
end
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Kumi
  module Core
    module Functions
      # Signature is a small immutable value object describing a function's
      # vectorization contract with NEP 20 support.
      #
      # in_shapes:   Array<Array<Dimension>>  e.g., [[Dimension.new(:i)], [Dimension.new(:i)]]
      # out_shape:   Array<Dimension>         e.g., [Dimension.new(:i)], [Dimension.new(:i), Dimension.new(:j)], or []
      # join_policy: nil | :zip | :product
      # raw:         original string, for diagnostics (optional)
      #
      # The instance, the list of input shapes, each input shape and the
      # output shape are all frozen copies of what the caller passed in.
      class Signature
        attr_reader :in_shapes, :out_shape, :join_policy, :raw

        # @raise [SignatureError] when shapes are malformed, a shape repeats a
        #   dimension name, the join policy is unknown, or an output dimension
        #   is broadcastable
        def initialize(in_shapes:, out_shape:, join_policy: nil, raw: nil)
          @in_shapes = frozen_shapes(in_shapes)
          @out_shape = out_shape.dup.freeze
          @join_policy = join_policy&.to_sym
          @raw = raw
          validate!
          freeze
        end

        # Number of input arguments.
        def arity = @in_shapes.length

        # Dimension names that appear in any input but not in output (i.e., reduced/dropped).
        def dropped_axes
          input_names = @in_shapes.flatten.map(&:name)
          output_names = @out_shape.map(&:name)
          (input_names - output_names).uniq.freeze
        end

        # True if any axis from inputs is dropped (common in aggregates).
        def reduction?
          !dropped_axes.empty?
        end

        # Plain-hash view; the shape arrays are unfrozen copies.
        def to_h
          {
            in_shapes: in_shapes.map(&:dup),
            out_shape: out_shape.dup,
            join_policy: join_policy,
            raw: raw
          }
        end

        def inspect
          "#<Signature #{format_signature}#{" @#{join_policy}" if join_policy}>"
        end

        # Renders "(i),(j)->(i,j)" without the join policy.
        def format_signature
          lhs = in_shapes.map { |dims| "(#{dims.map(&:to_s).join(',')})" }.join(",")
          rhs = "(#{out_shape.map(&:to_s).join(',')})"
          "#{lhs}->#{rhs}"
        end

        # Convert back to string representation for NEP-20 parser compatibility
        def to_signature_string
          sig_str = format_signature
          join_policy ? "#{sig_str}@#{join_policy}" : sig_str
        end

        private

        def validate!
          unless @in_shapes.is_a?(Array) && @in_shapes.all? { |s| s.is_a?(Array) }
            raise SignatureError, "in_shapes must be an array of dimension arrays"
          end

          @in_shapes.each_with_index do |dims, idx|
            validate_dimensions!(dims, where: "in_shapes[#{idx}]")
          end

          validate_dimensions!(@out_shape, where: "out_shape")

          unless [nil, :zip, :product].include?(@join_policy)
            raise SignatureError, "join_policy must be nil, :zip, or :product; got #{@join_policy.inspect}"
          end

          # Validate NEP 20 constraints
          validate_nep20_constraints!
        end

        def validate_dimensions!(dims, where: "shape")
          unless dims.is_a?(Array) && dims.all? { |d| d.is_a?(Dimension) }
            raise SignatureError, "#{where}: must be an array of Dimension objects, got: #{dims.inspect}"
          end

          # Check for duplicate dimension names within a single argument
          duplicates = dims.map(&:name).tally.select { |_, count| count > 1 }.keys
          raise SignatureError, "#{where}: duplicate dimension names #{duplicates.inspect}" unless duplicates.empty?

          true
        end

        def validate_nep20_constraints!
          # Broadcastable dimensions should only appear in inputs, not outputs
          @out_shape.each do |dim|
            raise SignatureError, "output dimension #{dim} cannot be broadcastable" if dim.broadcastable?
          end

          # Fixed-size dimensions in outputs must match corresponding input dimensions
          all_input_dims = @in_shapes.flatten
          @out_shape.each do |out_dim|
            next unless out_dim.fixed_size?

            matching_inputs = all_input_dims.select { |in_dim| in_dim.name == out_dim.name }
            matching_inputs.each do |in_dim|
              if in_dim.fixed_size? && in_dim.size != out_dim.size
                raise SignatureError, "fixed-size dimension #{out_dim.name} has inconsistent sizes: #{in_dim.size} vs #{out_dim.size}"
              end
            end
          end
        end

        # Copies and freezes the outer list and every inner shape. A non-Array
        # argument is returned untouched so that validate! can reject it with
        # a SignatureError.
        def frozen_shapes(shapes)
          return shapes unless shapes.is_a?(Array)

          shapes.map { |shape| shape.is_a?(Array) ? shape.dup.freeze : shape }.freeze
        end
      end
    end
  end
end
|
@@ -0,0 +1,86 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "dimension"
|
4
|
+
require_relative "signature"
|
5
|
+
|
6
|
+
module Kumi
  module Core
    module Functions
      # Parses NEP 20 conformant signature strings like:
      #   "(),()->()"                    # scalar operations
      #   "(i),(i)->(i)"                 # vector operations
      #   "(i),(j)->(i,j)@product"       # matrix operations with join policy
      #   "(i,j)->(i)"                   # reduction of :j
      #   "(3),(3)->(3)"                 # fixed-size 3-vectors (cross product)
      #   "(i?),(i?)->(i?)"              # flexible dimensions
      #   "(i|1),(i|1)->()"              # broadcastable dimensions
      #   "(m?,n),(n,p?)->(m?,p?)"       # matmul signature
      class SignatureParser
        # One parenthesised, non-nested shape group such as "(i,j)" or "()".
        SHAPE_GROUP = /\([^()]*\)/

        class << self
          # Parses +str+ into a Signature.
          #
          # @param str [String] signature text, optionally ending in "@policy"
          # @return [Signature]
          # @raise [SignatureParseError] for empty or malformed text
          # @raise [SignatureError] errors of this class are re-raised unchanged
          def parse(str)
            raise SignatureParseError, "empty signature" if str.nil? || str.strip.empty?

            lhs, rhs = str.split("->", 2).map(&:strip)
            raise SignatureParseError, "signature must contain '->': #{str.inspect}" unless rhs

            out_spec, policy = rhs.split("@", 2).map(&:strip)

            Signature.new(
              in_shapes: parse_many(lhs),
              # out_spec is nil when nothing follows "->"
              out_shape: parse_axes(out_spec.to_s),
              join_policy: policy&.to_sym,
              raw: str
            )
          rescue SignatureError
            raise
          rescue StandardError => e
            raise SignatureParseError, "invalid signature #{str.inspect}: #{e.message}"
          end

          private

          # Splits the input side into its parenthesised groups. Groups may be
          # separated by commas and any amount of whitespace; anything else
          # outside a group is rejected.
          def parse_many(lhs)
            # Handle zero arguments case
            return [] if lhs.strip.empty?

            leftover = lhs.gsub(SHAPE_GROUP, "").delete(",").strip
            raise SignatureParseError, "malformed input shapes in #{lhs.inspect}" unless leftover.empty?

            lhs.scan(SHAPE_GROUP).map { |group| parse_axes(group) }
          end

          # Parses one "(a,b,...)" group into Dimension objects.
          def parse_axes(spec)
            spec = spec.strip
            raise SignatureParseError, "missing parentheses in #{spec.inspect}" unless spec.start_with?("(") && spec.end_with?(")")

            inner = spec[1..-2].strip
            return [] if inner.empty?

            inner.split(",").map { |dim_str| parse_dimension(dim_str.strip) }
          end

          # Parse a single dimension with NEP 20 modifiers
          # Examples: "i", "3", "n?", "i|1"
          def parse_dimension(dim_str)
            return Dimension.new(:empty) if dim_str.empty?

            # Extract modifiers
            flexible = dim_str.end_with?("?")
            dim_str = dim_str.chomp("?") if flexible

            broadcastable = dim_str.end_with?("|1")
            dim_str = dim_str.chomp("|1") if broadcastable

            # Parse name (symbol or integer); anchored to the whole string
            name = dim_str.match?(/\A\d+\z/) ? dim_str.to_i : dim_str.to_sym

            Dimension.new(name, flexible: flexible, broadcastable: broadcastable)
          rescue StandardError => e
            raise SignatureParseError, "invalid dimension #{dim_str.inspect}: #{e.message}"
          end
        end
      end
    end
  end
end
|