kumi 0.0.12 → 0.0.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. checksums.yaml +4 -4
  2. data/.rspec +0 -1
  3. data/BACKLOG.md +34 -0
  4. data/CHANGELOG.md +15 -0
  5. data/CLAUDE.md +4 -6
  6. data/README.md +0 -18
  7. data/config/functions.yaml +352 -0
  8. data/docs/dev/analyzer-debug.md +52 -0
  9. data/docs/dev/parse-command.md +64 -0
  10. data/docs/functions/analyzer_integration.md +199 -0
  11. data/docs/functions/signatures.md +171 -0
  12. data/examples/hash_objects_demo.rb +138 -0
  13. data/golden/array_operations/schema.kumi +17 -0
  14. data/golden/cascade_logic/schema.kumi +16 -0
  15. data/golden/mixed_nesting/schema.kumi +42 -0
  16. data/golden/simple_math/schema.kumi +10 -0
  17. data/lib/kumi/analyzer.rb +72 -21
  18. data/lib/kumi/core/analyzer/checkpoint.rb +72 -0
  19. data/lib/kumi/core/analyzer/debug.rb +167 -0
  20. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +1 -3
  21. data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +199 -0
  22. data/lib/kumi/core/analyzer/passes/load_input_cse.rb +120 -0
  23. data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +99 -151
  24. data/lib/kumi/core/analyzer/passes/toposorter.rb +37 -1
  25. data/lib/kumi/core/analyzer/state_serde.rb +64 -0
  26. data/lib/kumi/core/analyzer/structs/access_plan.rb +12 -10
  27. data/lib/kumi/core/compiler/access_planner.rb +3 -2
  28. data/lib/kumi/core/function_registry/collection_functions.rb +3 -1
  29. data/lib/kumi/core/functions/dimension.rb +98 -0
  30. data/lib/kumi/core/functions/dtypes.rb +20 -0
  31. data/lib/kumi/core/functions/errors.rb +11 -0
  32. data/lib/kumi/core/functions/kernel_adapter.rb +45 -0
  33. data/lib/kumi/core/functions/loader.rb +119 -0
  34. data/lib/kumi/core/functions/registry_v2.rb +68 -0
  35. data/lib/kumi/core/functions/shape.rb +70 -0
  36. data/lib/kumi/core/functions/signature.rb +122 -0
  37. data/lib/kumi/core/functions/signature_parser.rb +86 -0
  38. data/lib/kumi/core/functions/signature_resolver.rb +272 -0
  39. data/lib/kumi/core/ir/execution_engine/interpreter.rb +98 -7
  40. data/lib/kumi/core/ir/execution_engine/profiler.rb +202 -0
  41. data/lib/kumi/core/ir/execution_engine.rb +30 -1
  42. data/lib/kumi/dev/ir.rb +75 -0
  43. data/lib/kumi/dev/parse.rb +105 -0
  44. data/lib/kumi/dev/runner.rb +83 -0
  45. data/lib/kumi/frontends/ruby.rb +28 -0
  46. data/lib/kumi/frontends/text.rb +46 -0
  47. data/lib/kumi/frontends.rb +29 -0
  48. data/lib/kumi/kernels/ruby/aggregate_core.rb +105 -0
  49. data/lib/kumi/kernels/ruby/datetime_scalar.rb +21 -0
  50. data/lib/kumi/kernels/ruby/mask_scalar.rb +15 -0
  51. data/lib/kumi/kernels/ruby/scalar_core.rb +63 -0
  52. data/lib/kumi/kernels/ruby/string_scalar.rb +19 -0
  53. data/lib/kumi/kernels/ruby/vector_struct.rb +39 -0
  54. data/lib/kumi/runtime/executable.rb +63 -20
  55. data/lib/kumi/schema.rb +4 -4
  56. data/lib/kumi/support/diff.rb +22 -0
  57. data/lib/kumi/support/ir_render.rb +61 -0
  58. data/lib/kumi/version.rb +1 -1
  59. data/lib/kumi.rb +2 -0
  60. data/performance_results.txt +63 -0
  61. data/scripts/test_mixed_nesting_performance.rb +206 -0
  62. metadata +45 -5
  63. data/docs/features/javascript-transpiler.md +0 -148
  64. data/lib/kumi/js.rb +0 -23
  65. data/lib/kumi/support/ir_dump.rb +0 -491
@@ -0,0 +1,98 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "errors"
4
+
5
+ module Kumi
6
+ module Core
7
+ module Functions
8
# One axis of a function signature, following the NEP 20 notation.
#
# The +name+ identifies the axis and is either
#   * a Symbol  - a named axis such as :i, :j or :n, or
#   * an Integer - a fixed-size axis such as 2, 3 or 10.
#
# Two optional modifiers are supported:
#   * flexible (rendered as "?")       - the axis may be omitted
#   * broadcastable (rendered as "|1") - the axis may broadcast against size 1
#
# Instances are validated on construction and frozen afterwards.
#
# Examples:
#   Dimension.new(:i)                       # named axis 'i'
#   Dimension.new(3)                        # fixed-size axis of size 3
#   Dimension.new(:n, flexible: true)       # axis 'n' that may be omitted
#   Dimension.new(:i, broadcastable: true)  # axis 'i' that may broadcast
class Dimension
  attr_reader :name, :flexible, :broadcastable

  # @param name [Symbol, Integer] axis name or fixed size
  # @param flexible [Boolean] whether the axis carries the "?" modifier
  # @param broadcastable [Boolean] whether the axis carries the "|1" modifier
  # @raise [SignatureError] when the combination of arguments is invalid
  def initialize(name, flexible: false, broadcastable: false)
    @name = name
    @flexible = flexible
    @broadcastable = broadcastable

    validate!
    freeze
  end

  # True when the axis is identified by an Integer size.
  def fixed_size? = name.is_a?(Integer)

  # True when the axis is identified by a Symbol.
  def named? = name.is_a?(Symbol)

  def flexible? = flexible

  def broadcastable? = broadcastable

  # @return [Integer, nil] the fixed size, or nil for a named axis
  def size
    name if fixed_size?
  end

  # Dimensions are equal when name and both modifiers agree.
  def ==(other)
    return false unless other.is_a?(Dimension)

    [name, flexible, broadcastable] == [other.name, other.flexible, other.broadcastable]
  end

  def eql?(other) = self == other

  # Kept consistent with #== so dimensions work as Hash keys.
  def hash
    [name, flexible, broadcastable].hash
  end

  # Renders the axis in signature notation, e.g. "i", "n?", "i|1".
  def to_s
    "#{name}#{'?' if flexible?}#{'|1' if broadcastable?}"
  end

  def inspect
    "#<Dimension #{self}>"
  end

  private

  # Enforces the constructor invariants; raises SignatureError on violation.
  def validate!
    case name
    when Symbol
      nil
    when Integer
      raise SignatureError, "fixed-size dimension must be positive, got: #{name}" unless name.positive?
    else
      raise SignatureError, "dimension name must be a symbol or integer, got: #{name.inspect}"
    end

    raise SignatureError, "dimension cannot be both flexible and broadcastable" if flexible? && broadcastable?
    raise SignatureError, "fixed-size dimension cannot be flexible" if fixed_size? && flexible?
  end
end
96
+ end
97
+ end
98
+ end
@@ -0,0 +1,20 @@
1
+ module Kumi::Core::Functions
2
# Descriptor for an element data type, identified by its +name+ symbol.
DType = Struct.new(:name, keyword_init: true)

# Shared DType instances, one per supported element type.
#
# Struct instances are mutable by default; these are process-wide constants,
# so each one is frozen to stop an accidental `DTypes::INT.name = ...` from
# changing the type for every user of the constant.
module DTypes
  BOOL = DType.new(name: :bool).freeze
  INT = DType.new(name: :int).freeze
  FLOAT = DType.new(name: :float).freeze
  STRING = DType.new(name: :string).freeze
  DATETIME = DType.new(name: :datetime).freeze
  ANY = DType.new(name: :any).freeze
end
11
+
12
# Pairwise dtype promotion rules, keyed by dtype name symbols.
module Promotion
  # super simple table; extend later
  #
  # Frozen because it is a shared constant: promotion rules should change by
  # editing this literal, not by mutating the table at runtime.
  TABLE = {
    %i[int int] => :int,
    %i[int float] => :float,
    %i[float int] => :float,
    %i[float float] => :float,
    %i[bool int] => :int,
    %i[bool float] => :float,
    %i[bool bool] => :bool
  }.freeze

  # Returns the promoted dtype name for the pair (a, b).
  #
  # The pair is looked up as given, then with the operands swapped; any pair
  # missing from TABLE promotes to :any.
  def self.promote(a, b) = TABLE[[a, b]] || TABLE[[b, a]] || :any
end
20
+ end
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Kumi
4
+ module Core
5
+ module Functions
6
# Base class for every signature-related failure in this namespace.
class SignatureError < StandardError
end

# A signature string could not be parsed.
class SignatureParseError < SignatureError
end

# No signature matched the supplied arguments.
class SignatureMatchError < SignatureError
end
9
+ end
10
+ end
11
+ end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Kumi
4
+ module Core
5
+ module Functions
6
# Bundle describing a resolved kernel: its kind, the callable that implements
# it, the null policy and any function options.
KernelHandle = Struct.new(:kind, :callable, :null_policy, :options, keyword_init: true)

# Maps function descriptors onto the Ruby kernel modules that implement them.
module KernelAdapter
  module_function

  # Builds a KernelHandle for +function+ using the implementation named by
  # +backend_entry.impl+.
  #
  # @raise [Kumi::Core::Errors::CompilationError] when the kernel module does
  #   not respond to the implementation name, or the domain is unknown
  def build_for(function, backend_entry)
    kernel_name = backend_entry.impl
    kernel_module = ruby_module_for(function)

    unless kernel_module.respond_to?(kernel_name)
      raise Kumi::Core::Errors::CompilationError, "Missing Ruby kernel #{kernel_name} for #{function.name}"
    end

    KernelHandle.new(
      kind: function.class_sym,
      callable: kernel_module.method(kernel_name),
      null_policy: function.null_policy,
      options: function.options || {}
    )
  end

  # Selects the kernel module for the function's domain. Within the :core
  # domain, aggregates and all other classes live in separate modules.
  def ruby_module_for(function)
    domain = function.domain.to_sym

    if domain == :core
      return Kumi::Kernels::Ruby::AggregateCore if function.class_sym == :aggregate

      return Kumi::Kernels::Ruby::ScalarCore
    end

    case domain
    when :string then Kumi::Kernels::Ruby::StringScalar
    when :datetime then Kumi::Kernels::Ruby::DatetimeScalar
    when :struct then Kumi::Kernels::Ruby::VectorStruct
    when :mask then Kumi::Kernels::Ruby::MaskScalar
    else
      raise Kumi::Core::Errors::CompilationError, "Unknown domain #{function.domain}"
    end
  end
end
43
+ end
44
+ end
45
+ end
@@ -0,0 +1,119 @@
1
+ require "yaml"
2
+ module Kumi::Core::Functions
3
# Immutable record for one function entry of the YAML catalogue.
Function = Data.define(:name, :domain, :opset, :class_sym, :signatures,
                       :type_vars, :dtypes, :null_policy, :algebra, :vectorization,
                       :options, :traits, :shape_fn, :kernels)

# One backend implementation of a function. +when_+ holds the optional
# selection conditions (symbol-keyed Hash) or nil when none were given.
KernelEntry = Data.define(:backend, :impl, :priority, :when_)

# Builds Function records from the YAML function catalogue.
class Loader
  # Loads and validates every function described in the YAML file at +path+.
  #
  # @param path [String] path to the YAML catalogue
  # @return [Array<Function>] frozen list of functions
  # @raise [RuntimeError] on duplicate (domain, name, opset) or missing kernels
  def self.load_file(path)
    # An empty YAML document parses to a falsy value; treat it as "no functions".
    doc = YAML.load_file(path) || []
    functions = doc.map { |h| build_function(h) }
    validate!(functions)
    functions.freeze
  end

  # Converts one parsed YAML mapping into a Function.
  #
  # The input hash (and its nested hashes) is left untouched: keys are
  # symbolized on copies, not in place.
  #
  # @param h [Hash{String=>Object}] one catalogue entry
  # @return [Function]
  # @raise [KeyError] when a required key is missing
  def self.build_function(h)
    sigs = Array(h.fetch("signature")).map { |s| parse_signature(s) }
    kernels = Array(h.fetch("kernels", [])).map do |k|
      KernelEntry.new(backend: k.fetch("backend").to_sym,
                      impl: k.fetch("impl").to_sym,
                      priority: k.fetch("priority", 0).to_i,
                      when_: k["when"]&.transform_keys(&:to_sym))
    end

    Function.new(
      name: h.fetch("name"),
      domain: h.fetch("domain"),
      opset: h.fetch("opset").to_i,
      class_sym: h.fetch("class").to_sym,
      signatures: sigs,
      type_vars: h["type_vars"] || {},
      dtypes: h["dtypes"] || {},
      null_policy: (h["null_policy"] || "propagate").to_sym,
      algebra: (h["algebra"] || {}).transform_keys(&:to_sym),
      vectorization: (h["vectorization"] || {}).transform_keys(&:to_sym),
      options: h["options"] || {},
      traits: (h["traits"] || {}).transform_keys(&:to_sym),
      shape_fn: h["shape_fn"],
      kernels: kernels
    )
  end

  # " (i),(j)->(i,j)@product " → Signature
  #
  # @raise [SignatureParseError] when the arrow or the output shape is missing
  def self.parse_signature(s)
    lhs, rhs = s.split("->", 2).map(&:strip)
    raise SignatureParseError, "signature must contain '->': #{s.inspect}" if rhs.nil?

    out_axes, policy = rhs.split("@", 2).map(&:strip)
    raise SignatureParseError, "signature is missing an output shape: #{s.inspect}" if out_axes.nil?

    Signature.new(in_shapes: parse_input_shapes(lhs),
                  out_shape: parse_axes(out_axes),
                  join_policy: policy&.to_sym)
  end

  # Parses "(i),(j)" or "(i,j),(k,l)" into one shape per parenthesized group.
  # Text outside balanced parentheses is ignored.
  def self.parse_input_shapes(lhs)
    lhs.scan(/\([^()]*\)/).map { |group| parse_axes(group) }
  end

  # Parses a single "(a,b,...)" group into an Array of Dimension objects.
  # "()" and blank text both denote a scalar and yield [].
  def self.parse_axes(txt)
    inner = txt.strip.delete_prefix("(").delete_suffix(")").strip
    return [] if inner.empty?

    inner.split(",").map { |token| parse_dimension(token) }
  end

  # Parses one axis token, honouring the notation Dimension#to_s emits:
  # a trailing "?" marks a flexible axis, a trailing "|1" a broadcastable
  # one, and an all-digit name is a fixed-size (Integer) axis.
  def self.parse_dimension(token)
    text = token.strip
    flexible = text.end_with?("?")
    text = text.delete_suffix("?") if flexible
    broadcastable = text.end_with?("|1")
    text = text.delete_suffix("|1") if broadcastable

    name = text.match?(/\A\d+\z/) ? text.to_i : text.to_sym
    Dimension.new(name, flexible: flexible, broadcastable: broadcastable)
  end
  private_class_method :parse_dimension

  # Rejects duplicate (domain, name, opset) triples and functions that have
  # no kernels at all.
  def self.validate!(fns)
    seen = {}
    fns.each do |f|
      key = [f.domain, f.name, f.opset]
      raise "duplicate function #{key}" if seen[key]

      seen[key] = true
      raise "no kernels for #{f.name}" if f.kernels.empty?
    end
  end
end
119
+ end
@@ -0,0 +1,68 @@
1
+ module Kumi::Core::Functions
2
# Registry of functions indexed by name, each name holding its opset versions
# sorted ascending.
class RegistryV2
  # @param functions [Enumerable] objects responding to #name and #opset
  def initialize(functions:)
    @by_name = functions.group_by(&:name).transform_values { |v| v.sort_by(&:opset).freeze }.freeze
  end

  # Factory method to load from YAML configuration.
  # With no +path+, a config/functions.yaml four directories above this
  # file's directory is used.
  def self.load_from_file(path = nil)
    path ||= File.join(__dir__, "../../../..", "config", "functions.yaml")
    new(functions: Loader.load_file(path))
  end

  # Looks up a function by name.
  #
  # @param name [Object] function name as registered
  # @param opset [Integer, nil] exact opset to fetch; the highest opset when nil
  # @return the matching function
  # @raise [KeyError] when the name is unknown, or the requested opset does
  #   not exist for that name (previously this case returned nil, which made
  #   callers fail later with NoMethodError)
  def fetch(name, opset: nil)
    versions = @by_name[name]
    raise KeyError, "unknown function #{name}" unless versions
    return versions.last unless opset

    versions.find { |f| f.opset == opset } ||
      raise(KeyError, "unknown opset #{opset} for function #{name}")
  end

  # Get function signatures for NEP-20 signature resolution.
  # This bridges RegistryV2 with the existing SignatureResolver.
  #
  # @return [Array<String>] signature strings, or [] when the function (or
  #   the requested opset) is not registered here
  def get_function_signatures(name, opset: nil)
    fetch(name, opset: opset).signatures.map(&:to_signature_string)
  rescue KeyError
    [] # Function not found in RegistryV2 - fall back to legacy registry
  end

  # Enhanced signature resolution using the NEP-20 resolver.
  #
  # @return [Hash] :function (the given fn), :signature (chosen signature)
  #   and :plan (the full resolution plan)
  # @raise [ArgumentError] when no signature matches +arg_shapes+
  def choose_signature(fn, arg_shapes)
    # Round-trip through the string form so the NEP-20 parser sees every signature.
    parsed_sigs = fn.signatures.map { |sig| SignatureParser.parse(sig.to_signature_string) }

    plan = SignatureResolver.choose(signatures: parsed_sigs, arg_shapes: arg_shapes)

    {
      function: fn,
      signature: plan[:signature],
      plan: plan
    }
  rescue SignatureMatchError => e
    raise ArgumentError, "no matching signature for #{fn.name} with shapes #{arg_shapes.inspect}: #{e.message}"
  end

  # Picks the highest-priority kernel of +fn+ for +backend+.
  #
  # A kernel satisfies +conditions+ when none of its `when` entries
  # contradicts them: a condition key missing from `when` counts as a match,
  # and so does a kernel with no `when` clause at all.
  #
  # @raise [RuntimeError] when no kernel qualifies
  def resolve_kernel(fn, backend:, conditions: {})
    wanted = backend.to_sym
    candidates = fn.kernels.select do |k|
      k.backend == wanted &&
        conditions.all? { |ck, cv| (k.when_ || {}).fetch(ck, cv) == cv }
    end

    candidates.max_by(&:priority) || raise("no kernel for #{fn.name} backend=#{backend}")
  end

  # Introspection methods
  def all_function_names
    @by_name.keys
  end

  def function_exists?(name)
    @by_name.key?(name)
  end

  def all_functions
    @by_name.values.flatten
  end
end
68
+ end
@@ -0,0 +1,70 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "dimension"
4
+
5
+ module Kumi
6
+ module Core
7
+ module Functions
8
# Shape utilities with NEP 20 support. A "shape" is an Array<Dimension> of dimension objects.
# [] == scalar, [Dimension.new(:i)] == vector along :i, etc.
module Shape
  module_function

  # A shape with no dimensions is a scalar.
  def scalar?(shape) = shape.empty?

  # Shapes are equal when their dimension names match position by position
  # (modifiers are ignored).
  def equal?(a, b) = a.map(&:name) == b.map(&:name)

  # NEP 20 broadcast rules:
  # - scalar can broadcast to any expected shape
  # - otherwise the ranks must match and every dimension pair must be
  #   compatible (see #broadcastable_dimension?)
  #
  # @param got [Array<Dimension>] shape of the supplied argument
  # @param expected [Array<Dimension>] shape required by the signature
  # @return [Boolean]
  def broadcastable?(got:, expected:)
    return true if scalar?(got)
    return false if got.length != expected.length

    got.zip(expected).all? do |got_dim, exp_dim|
      broadcastable_dimension?(got: got_dim, expected: exp_dim)
    end
  end

  # Two call forms share this name. They used to be two separate definitions,
  # where the second silently replaced the first and made #broadcastable?
  # false for every non-scalar shape.
  #
  #   broadcastable_dimension?(dim)
  #     true when +dim+ is a Dimension carrying the |1 modifier.
  #
  #   broadcastable_dimension?(got:, expected:)
  #     true when +got+ may stand in for +expected+:
  #     - identical dimensions, or dimensions with the same name, are compatible
  #       (equal names imply both are named or both are the same fixed size)
  #     - differently named dimensions are compatible only when +expected+ is
  #       broadcastable (|1) and +got+ is a fixed dimension of size 1
  #
  # @raise [ArgumentError] when only one of got:/expected: is supplied
  def broadcastable_dimension?(dim = nil, got: nil, expected: nil)
    return dim.is_a?(Dimension) && dim.broadcastable? if got.nil? && expected.nil?

    raise ArgumentError, "got: and expected: must be given together" if got.nil? || expected.nil?

    return true if got == expected || got.name == expected.name

    expected.broadcastable? && got.fixed_size? && got.size == 1
  end

  # Check if a dimension can be omitted (NEP 20 flexible dimensions)
  def flexible?(dim)
    dim.is_a?(Dimension) && dim.flexible?
  end

  # Convenience: names of dimensions present in a but not in b
  # (unique, in order of first appearance in a).
  def dimensions_minus(a, b)
    (a.map(&:name) - b.map(&:name)).uniq
  end
end
68
+ end
69
+ end
70
+ end
@@ -0,0 +1,122 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Kumi
4
+ module Core
5
+ module Functions
6
# Signature is a small immutable value object describing a function's
# vectorization contract with NEP 20 support.
#
# in_shapes:   Array<Array<Dimension>>  e.g., [[Dimension.new(:i)], [Dimension.new(:i)]]
# out_shape:   Array<Dimension>         e.g., [Dimension.new(:i)], [Dimension.new(:i), Dimension.new(:j)], or []
# join_policy: nil | :zip | :product
# raw:         original string, for diagnostics (optional)
#
# The shape arrays are copied on construction and the copies are frozen at
# every level, so neither the caller's arrays nor the signature's own state
# can be changed afterwards.
class Signature
  attr_reader :in_shapes, :out_shape, :join_policy, :raw

  # @raise [SignatureError] when shapes are not arrays of Dimension objects,
  #   a shape repeats a dimension name, the join policy is unknown, or a
  #   NEP 20 constraint is violated
  def initialize(in_shapes:, out_shape:, join_policy: nil, raw: nil)
    # Validate the raw arguments first, so malformed input fails with a
    # SignatureError instead of a NoMethodError from the copying below.
    validate_shapes!(in_shapes, out_shape)

    @in_shapes = in_shapes.map { |dims| dims.dup.freeze }.freeze
    @out_shape = out_shape.dup.freeze
    @join_policy = join_policy&.to_sym
    @raw = raw
    validate!
    freeze
  end

  # Number of input arguments.
  def arity = @in_shapes.length

  # Dimension names that appear in any input but not in output (i.e., reduced/dropped).
  def dropped_axes
    input_names = @in_shapes.flatten.map(&:name)
    output_names = @out_shape.map(&:name)
    (input_names - output_names).uniq.freeze
  end

  # True if any axis from inputs is dropped (common in aggregates).
  def reduction?
    !dropped_axes.empty?
  end

  # Plain-Hash view holding unfrozen copies of the shape arrays.
  def to_h
    {
      in_shapes: in_shapes.map(&:dup),
      out_shape: out_shape.dup,
      join_policy: join_policy,
      raw: raw
    }
  end

  def inspect
    "#<Signature #{format_signature}#{" @#{join_policy}" if join_policy}>"
  end

  # "(i),(j)->(i,j)" style rendering, without the join policy.
  def format_signature
    lhs = in_shapes.map { |dims| "(#{dims.map(&:to_s).join(',')})" }.join(",")
    rhs = "(#{out_shape.map(&:to_s).join(',')})"
    "#{lhs}->#{rhs}"
  end

  # Convert back to string representation for NEP-20 parser compatibility
  def to_signature_string
    sig_str = format_signature
    join_policy ? "#{sig_str}@#{join_policy}" : sig_str
  end

  private

  # Structural checks on the constructor arguments, run before copying.
  def validate_shapes!(in_shapes, out_shape)
    unless in_shapes.is_a?(Array) && in_shapes.all? { |s| s.is_a?(Array) }
      raise SignatureError, "in_shapes must be an array of dimension arrays"
    end

    in_shapes.each_with_index do |dims, idx|
      validate_dimensions!(dims, where: "in_shapes[#{idx}]")
    end

    validate_dimensions!(out_shape, where: "out_shape")
  end

  # Checks that depend on the normalized instance state.
  def validate!
    unless [nil, :zip, :product].include?(@join_policy)
      raise SignatureError, "join_policy must be nil, :zip, or :product; got #{@join_policy.inspect}"
    end

    # Validate NEP 20 constraints
    validate_nep20_constraints!
  end

  def validate_dimensions!(dims, where: "shape")
    unless dims.is_a?(Array) && dims.all? { |d| d.is_a?(Dimension) }
      raise SignatureError, "#{where}: must be an array of Dimension objects, got: #{dims.inspect}"
    end

    # Check for duplicate dimension names within a single argument
    duplicates = dims.map(&:name).tally.select { |_, count| count > 1 }.keys
    raise SignatureError, "#{where}: duplicate dimension names #{duplicates.inspect}" unless duplicates.empty?

    true
  end

  def validate_nep20_constraints!
    # Broadcastable dimensions should only appear in inputs, not outputs
    @out_shape.each do |dim|
      raise SignatureError, "output dimension #{dim} cannot be broadcastable" if dim.broadcastable?
    end

    # Fixed-size dimensions in outputs must match corresponding input dimensions
    all_input_dims = @in_shapes.flatten
    @out_shape.each do |out_dim|
      next unless out_dim.fixed_size?

      all_input_dims.each do |in_dim|
        next unless in_dim.name == out_dim.name
        next unless in_dim.fixed_size? && in_dim.size != out_dim.size

        raise SignatureError, "fixed-size dimension #{out_dim.name} has inconsistent sizes: #{in_dim.size} vs #{out_dim.size}"
      end
    end
  end
end
120
+ end
121
+ end
122
+ end
@@ -0,0 +1,86 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "dimension"
4
+ require_relative "signature"
5
+
6
+ module Kumi
7
+ module Core
8
+ module Functions
9
# Parses NEP 20 conformant signature strings like:
#   "(),()->()"                    # scalar operations
#   "(i),(i)->(i)"                 # vector operations
#   "(i),(j)->(i,j)@product"       # matrix operations with join policy
#   "(i,j)->(i)"                   # reduction of :j
#   "(3),(3)->(3)"                 # fixed-size 3-vectors (cross product)
#   "(i?),(i?)->(i?)"              # flexible dimensions
#   "(i|1),(i|1)->()"              # broadcastable dimensions
#   "(m?,n),(n,p?)->(m?,p?)"       # matmul signature
class SignatureParser
  class << self
    # Parses +str+ into a Signature.
    #
    # @param str [String] signature text; an optional "@policy" suffix after
    #   the output shape names the join policy
    # @return [Signature]
    # @raise [SignatureParseError] when the text is malformed
    # @raise [SignatureError] when the parsed signature fails validation
    def parse(str)
      raise SignatureParseError, "empty signature" if str.nil? || str.strip.empty?

      lhs, rhs = str.split("->", 2).map(&:strip)
      raise SignatureParseError, "signature must contain '->': #{str.inspect}" unless rhs

      # out_spec is nil when nothing follows the arrow; to_s lets parse_axes
      # report that as missing parentheses rather than a NoMethodError.
      out_spec, policy = rhs.split("@", 2).map(&:strip)
      in_shapes = parse_many(lhs)
      out_shape = parse_axes(out_spec.to_s)

      Signature.new(in_shapes: in_shapes, out_shape: out_shape, join_policy: policy&.to_sym, raw: str)
    rescue SignatureError
      # Already a signature-specific error (parse or validation): pass through.
      raise
    rescue StandardError => e
      raise SignatureParseError, "invalid signature #{str.inspect}: #{e.message}"
    end

    private

    # Splits the left-hand side into one shape per parenthesized group.
    def parse_many(lhs)
      # Handle zero arguments case
      return [] if lhs.strip.empty?

      # split by commas that are *between* groups, not inside them
      # simpler approach: split by '),', re-add ')' where needed
      tokens = lhs.split("),").map do |piece|
        group = piece.strip
        group.end_with?(")") ? group : "#{group})"
      end
      tokens.map { |tok| parse_axes(tok) }
    end

    # Parses one "(a,b,...)" group; "()" yields the scalar shape [].
    def parse_axes(spec)
      spec = spec.strip
      raise SignatureParseError, "missing parentheses in #{spec.inspect}" unless spec.start_with?("(") && spec.end_with?(")")

      inner = spec[1..-2].strip
      return [] if inner.empty?

      inner.split(",").map { |dim_str| parse_dimension(dim_str.strip) }
    end

    # Parse a single dimension with NEP 20 modifiers
    # Examples: "i", "3", "n?", "i|1"
    #
    # A token with no name (empty, or modifiers only) is rejected instead of
    # being turned into a placeholder dimension.
    def parse_dimension(dim_str)
      token = dim_str.strip

      # Extract modifiers
      flexible = token.end_with?("?")
      token = token.delete_suffix("?") if flexible

      broadcastable = token.end_with?("|1")
      token = token.delete_suffix("|1") if broadcastable

      raise SignatureParseError, "empty dimension name in #{dim_str.inspect}" if token.empty?

      # Parse name (symbol or integer); \A/\z anchor the whole token, whereas
      # ^/$ would accept digits on just one line of a multi-line token.
      name = token.match?(/\A\d+\z/) ? token.to_i : token.to_sym

      Dimension.new(name, flexible: flexible, broadcastable: broadcastable)
    rescue SignatureParseError
      raise
    rescue StandardError => e
      # Report the token as written, including any modifiers.
      raise SignatureParseError, "invalid dimension #{dim_str.inspect}: #{e.message}"
    end
  end
end
84
+ end
85
+ end
86
+ end