kumi 0.0.15 → 0.0.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/golden/cascade_logic/schema.kumi +3 -1
  4. data/lib/kumi/analyzer.rb +11 -9
  5. data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +0 -81
  6. data/lib/kumi/core/analyzer/passes/ir_dependency_pass.rb +18 -20
  7. data/lib/kumi/core/analyzer/passes/ir_execution_schedule_pass.rb +67 -0
  8. data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +0 -36
  9. data/lib/kumi/core/analyzer/passes/toposorter.rb +1 -39
  10. data/lib/kumi/core/analyzer/passes/unsat_detector.rb +8 -191
  11. data/lib/kumi/core/compiler/access_builder.rb +20 -10
  12. data/lib/kumi/core/compiler/access_codegen.rb +61 -0
  13. data/lib/kumi/core/compiler/access_emit/base.rb +173 -0
  14. data/lib/kumi/core/compiler/access_emit/each_indexed.rb +56 -0
  15. data/lib/kumi/core/compiler/access_emit/materialize.rb +45 -0
  16. data/lib/kumi/core/compiler/access_emit/ravel.rb +50 -0
  17. data/lib/kumi/core/compiler/access_emit/read.rb +32 -0
  18. data/lib/kumi/core/ir/execution_engine/interpreter.rb +36 -181
  19. data/lib/kumi/core/ir/execution_engine/values.rb +8 -8
  20. data/lib/kumi/core/ir/execution_engine.rb +3 -19
  21. data/lib/kumi/dev/parse.rb +12 -12
  22. data/lib/kumi/runtime/executable.rb +22 -175
  23. data/lib/kumi/runtime/run.rb +105 -0
  24. data/lib/kumi/schema.rb +8 -13
  25. data/lib/kumi/version.rb +1 -1
  26. data/lib/kumi.rb +3 -2
  27. metadata +10 -25
  28. data/BACKLOG.md +0 -34
  29. data/config/functions.yaml +0 -352
  30. data/docs/functions/analyzer_integration.md +0 -199
  31. data/docs/functions/signatures.md +0 -171
  32. data/examples/hash_objects_demo.rb +0 -138
  33. data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +0 -199
  34. data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +0 -48
  35. data/lib/kumi/core/functions/dimension.rb +0 -98
  36. data/lib/kumi/core/functions/dtypes.rb +0 -20
  37. data/lib/kumi/core/functions/errors.rb +0 -11
  38. data/lib/kumi/core/functions/kernel_adapter.rb +0 -45
  39. data/lib/kumi/core/functions/loader.rb +0 -119
  40. data/lib/kumi/core/functions/registry_v2.rb +0 -68
  41. data/lib/kumi/core/functions/shape.rb +0 -70
  42. data/lib/kumi/core/functions/signature.rb +0 -122
  43. data/lib/kumi/core/functions/signature_parser.rb +0 -86
  44. data/lib/kumi/core/functions/signature_resolver.rb +0 -272
  45. data/lib/kumi/kernels/ruby/aggregate_core.rb +0 -105
  46. data/lib/kumi/kernels/ruby/datetime_scalar.rb +0 -21
  47. data/lib/kumi/kernels/ruby/mask_scalar.rb +0 -15
  48. data/lib/kumi/kernels/ruby/scalar_core.rb +0 -63
  49. data/lib/kumi/kernels/ruby/string_scalar.rb +0 -19
  50. data/lib/kumi/kernels/ruby/vector_struct.rb +0 -39
--- data/examples/hash_objects_demo.rb
+++ /dev/null
@@ -1,138 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- # Hash Objects Demo - Demonstrates Kumi's structured input syntax
4
- # This example shows how to use hash objects for organizing related input fields
5
-
6
- require_relative "../lib/kumi"
7
-
8
- module HashObjectsDemo
9
- extend Kumi::Schema
10
-
11
- schema do
12
- input do
13
- # Employee information as hash object
14
- hash :employee do
15
- string :name
16
- integer :age, domain: 18..65
17
- float :base_salary, domain: 30_000.0..200_000.0
18
- boolean :is_manager
19
- integer :years_experience, domain: 0..40
20
- end
21
-
22
- # Company configuration as hash object
23
- hash :company_config do
24
- string :name
25
- float :bonus_percentage, domain: 0.0..0.50
26
- float :manager_multiplier, domain: 1.0..2.0
27
- integer :current_year, domain: 2020..2030
28
- end
29
-
30
- # Benefits configuration as hash object
31
- hash :benefits do
32
- boolean :health_insurance
33
- boolean :dental_coverage
34
- float :retirement_match, domain: 0.0..0.10
35
- integer :vacation_days, domain: 10..30
36
- end
37
- end
38
-
39
- # Traits using hash object access
40
- trait :is_senior, input.employee.years_experience >= 5
41
- trait :eligible_for_bonus, is_senior & input.employee.is_manager
42
- trait :has_full_benefits, input.benefits.health_insurance & input.benefits.dental_coverage
43
-
44
- # Salary calculations using structured data
45
- value :base_annual_salary, input.employee.base_salary
46
-
47
- value :bonus_amount do
48
- on eligible_for_bonus, base_annual_salary * input.company_config.bonus_percentage
49
- base 0.0
50
- end
51
-
52
- trait :is_manager_trait, input.employee.is_manager == true
53
-
54
- value :manager_adjustment do
55
- on is_manager_trait, input.company_config.manager_multiplier
56
- base 1.0
57
- end
58
-
59
- value :total_compensation, (base_annual_salary * manager_adjustment) + bonus_amount
60
-
61
- # Benefits calculations
62
- value :retirement_contribution, total_compensation * input.benefits.retirement_match
63
-
64
- value :benefits_package_value do
65
- on has_full_benefits, 5_000.0 + input.benefits.vacation_days * 150.0
66
- base input.benefits.vacation_days * 100.0
67
- end
68
-
69
- # Final totals
70
- value :total_package_value, total_compensation + retirement_contribution + benefits_package_value
71
-
72
- # Summary calculations
73
- value :years_to_retirement, 65 - input.employee.age
74
- end
75
- end
76
-
77
- # Example usage
78
- if __FILE__ == $0
79
- puts "Hash Objects Demo - Employee Compensation Calculator"
80
- puts "=" * 55
81
-
82
- # Sample data demonstrating hash objects structure
83
- employee_data = {
84
- employee: {
85
- name: "Alice Johnson",
86
- age: 32,
87
- base_salary: 85_000.0,
88
- is_manager: true,
89
- years_experience: 8
90
- },
91
- company_config: {
92
- name: "Tech Solutions Inc",
93
- bonus_percentage: 0.15,
94
- manager_multiplier: 1.25,
95
- current_year: 2024
96
- },
97
- benefits: {
98
- health_insurance: true,
99
- dental_coverage: true,
100
- retirement_match: 0.06,
101
- vacation_days: 25
102
- }
103
- }
104
-
105
- # Calculate compensation
106
- result = HashObjectsDemo.from(employee_data)
107
-
108
- def format_currency(amount)
109
- "$#{amount.round(0).to_s.gsub(/\B(?=(\d{3})+(?!\d))/, ',')}"
110
- end
111
-
112
- puts "\nEmployee Information:"
113
- puts "- Name: #{employee_data[:employee][:name]}"
114
- puts "- Age: #{employee_data[:employee][:age]}"
115
- puts "- Experience: #{employee_data[:employee][:years_experience]} years"
116
- puts "- Manager: #{employee_data[:employee][:is_manager] ? 'Yes' : 'No'}"
117
-
118
- puts "\nCompany: #{employee_data[:company_config][:name]}"
119
- puts "- Bonus Rate: #{(employee_data[:company_config][:bonus_percentage] * 100).round(1)}%"
120
- puts "- Manager Multiplier: #{employee_data[:company_config][:manager_multiplier]}x"
121
-
122
- puts "\nBenefits:"
123
- puts "- Health Insurance: #{employee_data[:benefits][:health_insurance] ? 'Yes' : 'No'}"
124
- puts "- Dental Coverage: #{employee_data[:benefits][:dental_coverage] ? 'Yes' : 'No'}"
125
- puts "- Retirement Match: #{(employee_data[:benefits][:retirement_match] * 100).round(1)}%"
126
- puts "- Vacation Days: #{employee_data[:benefits][:vacation_days]}"
127
-
128
- puts "\nCompensation Breakdown:"
129
- puts "- Base Salary: #{format_currency(result[:base_annual_salary])}"
130
- puts "- Manager Adjustment: #{result[:manager_adjustment]}x"
131
- puts "- Bonus: #{format_currency(result[:bonus_amount])}"
132
- puts "- Total Compensation: #{format_currency(result[:total_compensation])}"
133
- puts "- Retirement Contribution: #{format_currency(result[:retirement_contribution])}"
134
- puts "- Benefits Package Value: #{format_currency(result[:benefits_package_value])}"
135
-
136
- puts "\nTotal Package: #{format_currency(result[:total_package_value])}"
137
- puts "Years to Retirement: #{result[:years_to_retirement]}"
138
- end
--- data/lib/kumi/core/analyzer/passes/function_signature_pass.rb
+++ /dev/null
@@ -1,199 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Analyzer
6
- module Passes
7
- # RESPONSIBILITY: Apply NEP-20 signature resolution to function calls
8
- # DEPENDENCIES: :node_index from Toposorter, :broadcast_metadata (optional)
9
- # PRODUCES: Signature metadata in node_index for CallExpression nodes
10
- # INTERFACE: new(schema, state).run(errors)
11
- class FunctionSignaturePass < PassBase
12
- def run(errors)
13
- node_index = get_state(:node_index, required: true)
14
-
15
- # Process all CallExpression nodes in the index
16
- node_index.each do |object_id, entry|
17
- next unless entry[:type] == "CallExpression"
18
-
19
- resolve_function_signature(entry, object_id, errors)
20
- end
21
-
22
- state # Node index is modified in-place
23
- end
24
-
25
- private
26
-
27
- def resolve_function_signature(entry, object_id, errors)
28
- node = entry[:node]
29
-
30
- # 1) Gather candidate signatures from current registry
31
- sig_strings = get_function_signatures(node)
32
- return if sig_strings.empty?
33
-
34
- begin
35
- sigs = parse_signatures(sig_strings)
36
- rescue Kumi::Core::Functions::SignatureError => e
37
- report_error(errors, "Invalid signature for function `#{node.fn_name}`: #{e.message}",
38
- location: node.loc, type: :type)
39
- return
40
- end
41
-
42
- # 2) Build arg_shapes from current node context
43
- arg_shapes = build_argument_shapes(node, object_id)
44
-
45
- # 3) Resolve signature
46
- begin
47
- plan = Kumi::Core::Functions::SignatureResolver.choose(signatures: sigs, arg_shapes: arg_shapes)
48
- rescue Kumi::Core::Functions::SignatureMatchError => e
49
- report_error(errors,
50
- "Signature mismatch for `#{node.fn_name}` with args #{format_shapes(arg_shapes)}. Candidates: #{format_sigs(sig_strings)}. #{e.message}",
51
- location: node.loc, type: :type)
52
- return
53
- end
54
-
55
- # 4) Attach metadata to node index entry
56
- attach_signature_metadata(entry, plan)
57
- end
58
-
59
- def get_function_signatures(node)
60
- # Use RegistryV2 if enabled, otherwise fall back to legacy registry
61
- if registry_v2_enabled?
62
- registry_v2_signatures(node)
63
- else
64
- legacy_registry_signatures(node)
65
- end
66
- end
67
-
68
- def registry_v2_signatures(node)
69
- registry_v2.get_function_signatures(node.fn_name)
70
- rescue => e
71
- # If RegistryV2 fails, fall back to legacy
72
- legacy_registry_signatures(node)
73
- end
74
-
75
- def legacy_registry_signatures(node)
76
- # Try to get signatures from the current registry
77
- # For now, we'll create basic signatures from the current registry format
78
-
79
- meta = Kumi::Registry.signature(node.fn_name)
80
-
81
- # Check if the function already has NEP-20 signatures
82
- return meta[:signatures] if meta[:signatures] && meta[:signatures].is_a?(Array)
83
-
84
- # Otherwise, create a basic signature from arity
85
- # This is a bridge until we have full NEP-20 signatures in the registry
86
- create_basic_signature(meta[:arity])
87
- rescue Kumi::Errors::UnknownFunction
88
- # For now, return empty array - function existence will be caught by TypeChecker
89
- []
90
- end
91
-
92
- def create_basic_signature(arity)
93
- return [] if arity.nil? || arity < 0 # Variable arity - skip for now
94
-
95
- case arity
96
- when 0
97
- ["()->()"] # Scalar function
98
- when 1
99
- ["()->()", "(i)->(i)"] # Scalar or element-wise
100
- when 2
101
- ["(),()->()", "(i),(i)->(i)"] # Scalar or element-wise binary
102
- else
103
- # For higher arity, just provide scalar signature
104
- args = (["()"] * arity).join(",")
105
- ["#{args}->()"]
106
- end
107
- end
108
-
109
- def build_argument_shapes(node, object_id)
110
- # Build argument shapes from current analysis context
111
- node.args.map do |arg|
112
- axes = get_broadcast_metadata(arg.object_id)
113
- normalize_shape(axes)
114
- end
115
- end
116
-
117
- def normalize_shape(axes)
118
- case axes
119
- when nil
120
- [] # scalar
121
- when Array
122
- axes.map { |d| d.is_a?(Integer) ? d : d.to_sym }
123
- else
124
- [] # defensive fallback
125
- end
126
- end
127
-
128
- def get_broadcast_metadata(arg_object_id)
129
- # Try to get broadcast metadata from existing analysis state
130
- broadcast_meta = get_state(:broadcast_metadata, required: false)
131
- return nil unless broadcast_meta
132
-
133
- # Look up by node object_id
134
- broadcast_meta[arg_object_id]&.dig(:axes)
135
- end
136
-
137
- def parse_signatures(sig_strings)
138
- @sig_cache ||= {}
139
- sig_strings.map do |s|
140
- @sig_cache[s] ||= Kumi::Core::Functions::SignatureParser.parse(s)
141
- end
142
- end
143
-
144
- def format_shapes(shapes)
145
- shapes.map { |ax| "(#{ax.join(',')})" }.join(', ')
146
- end
147
-
148
- def format_sigs(sig_strings)
149
- sig_strings.join(" | ")
150
- end
151
-
152
- def attach_signature_metadata(entry, plan)
153
- # Attach signature resolution results to the node index entry
154
- # This way other passes can access the metadata via the node index
155
- metadata = entry[:metadata]
156
-
157
- attach_core_signature_data(metadata, plan)
158
- attach_shape_contract(metadata, plan)
159
- end
160
-
161
- def attach_core_signature_data(metadata, plan)
162
- metadata[:signature] = plan[:signature]
163
- metadata[:result_axes] = plan[:result_axes] # e.g., [:i, :j]
164
- metadata[:join_policy] = plan[:join_policy] # nil | :zip | :product
165
- metadata[:dropped_axes] = plan[:dropped_axes] # e.g., [:j] for reductions
166
- metadata[:effective_signature] = plan[:effective_signature] # Normalized for lowering
167
- metadata[:dim_env] = plan[:env] # Dimension bindings (for matmul)
168
- metadata[:signature_score] = plan[:score] # Match quality
169
- end
170
-
171
- def attach_shape_contract(metadata, plan)
172
- # Attach shape contract for lowering convenience
173
- metadata[:shape_contract] = {
174
- in: plan[:effective_signature][:in_shapes],
175
- out: plan[:effective_signature][:out_shape],
176
- join: plan[:effective_signature][:join_policy]
177
- }
178
- end
179
-
180
- def registry_v2_enabled?
181
- ENV["KUMI_FN_REGISTRY_V2"] == "1"
182
- end
183
-
184
- def registry_v2
185
- @registry_v2 ||= Kumi::Core::Functions::RegistryV2.load_from_file
186
- end
187
-
188
- def nep20_flex_enabled?
189
- ENV["KUMI_ENABLE_FLEX"] == "1"
190
- end
191
-
192
- def nep20_bcast1_enabled?
193
- ENV["KUMI_ENABLE_BCAST1"] == "1"
194
- end
195
- end
196
- end
197
- end
198
- end
199
- end
--- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb
+++ /dev/null
@@ -1,48 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Analyzer
6
- module Passes
7
- # RESPONSIBILITY: Validate consistency between declared and inferred types
8
- # DEPENDENCIES: :input_metadata from InputCollector, :inferred_types from TypeInferencerPass
9
- # PRODUCES: None (validation only)
10
- # INTERFACE: new(schema, state).run(errors)
11
- class TypeConsistencyChecker < PassBase
12
- def run(errors)
13
- input_meta = get_state(:input_metadata, required: false) || {}
14
-
15
- # First, validate that all declared types are valid
16
- validate_declared_types(input_meta, errors)
17
-
18
- # Then check basic consistency (placeholder for now)
19
- # In a full implementation, this would do sophisticated usage analysis
20
- state
21
- end
22
-
23
- private
24
-
25
- def validate_declared_types(input_meta, errors)
26
- input_meta.each do |field_name, meta|
27
- declared_type = meta[:type]
28
- next unless declared_type # Skip fields without declared types
29
- next if Kumi::Core::Types.valid_type?(declared_type)
30
-
31
- # Find the input field declaration for proper location information
32
- field_decl = find_input_field_declaration(field_name)
33
- location = field_decl&.loc
34
-
35
- report_type_error(errors, "Invalid type declaration for field :#{field_name}: #{declared_type.inspect}", location: location)
36
- end
37
- end
38
-
39
- def find_input_field_declaration(field_name)
40
- return nil unless schema
41
-
42
- schema.inputs.find { |input_decl| input_decl.name == field_name }
43
- end
44
- end
45
- end
46
- end
47
- end
48
- end
--- data/lib/kumi/core/functions/dimension.rb
+++ /dev/null
@@ -1,98 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require_relative "errors"
4
-
5
- module Kumi
6
- module Core
7
- module Functions
8
- # Represents a single dimension in a signature with NEP 20 support.
9
- #
10
- # A dimension can be:
11
- # - Named dimension (symbol): :i, :j, :n
12
- # - Fixed-size dimension (integer): 2, 3, 10
13
- # - With modifiers:
14
- # - flexible (?): can be omitted if not present in all operands
15
- # - broadcastable (|1): can broadcast against size-1 dimensions
16
- #
17
- # Examples:
18
- # Dimension.new(:i) # named dimension 'i'
19
- # Dimension.new(3) # fixed-size dimension of size 3
20
- # Dimension.new(:n, flexible: true) # dimension 'n' that can be omitted
21
- # Dimension.new(:i, broadcastable: true) # dimension 'i' that can broadcast
22
- class Dimension
23
- attr_reader :name, :flexible, :broadcastable
24
-
25
- def initialize(name, flexible: false, broadcastable: false)
26
- @name = name
27
- @flexible = flexible
28
- @broadcastable = broadcastable
29
-
30
- validate!
31
- freeze
32
- end
33
-
34
- def fixed_size?
35
- @name.is_a?(Integer)
36
- end
37
-
38
- def named?
39
- @name.is_a?(Symbol)
40
- end
41
-
42
- def flexible?
43
- @flexible
44
- end
45
-
46
- def broadcastable?
47
- @broadcastable
48
- end
49
-
50
- def size
51
- fixed_size? ? @name : nil
52
- end
53
-
54
- def ==(other)
55
- other.is_a?(Dimension) &&
56
- name == other.name &&
57
- flexible == other.flexible &&
58
- broadcastable == other.broadcastable
59
- end
60
-
61
- def eql?(other)
62
- self == other
63
- end
64
-
65
- def hash
66
- [name, flexible, broadcastable].hash
67
- end
68
-
69
- def to_s
70
- str = name.to_s
71
- str += "?" if flexible?
72
- str += "|1" if broadcastable?
73
- str
74
- end
75
-
76
- def inspect
77
- "#<Dimension #{self}>"
78
- end
79
-
80
- private
81
-
82
- def validate!
83
- unless name.is_a?(Symbol) || name.is_a?(Integer)
84
- raise SignatureError, "dimension name must be a symbol or integer, got: #{name.inspect}"
85
- end
86
-
87
- raise SignatureError, "fixed-size dimension must be positive, got: #{name}" if name.is_a?(Integer) && name <= 0
88
-
89
- raise SignatureError, "dimension cannot be both flexible and broadcastable" if flexible? && broadcastable?
90
-
91
- return unless fixed_size? && flexible?
92
-
93
- raise SignatureError, "fixed-size dimension cannot be flexible"
94
- end
95
- end
96
- end
97
- end
98
- end
--- data/lib/kumi/core/functions/dtypes.rb
+++ /dev/null
@@ -1,20 +0,0 @@
1
- module Kumi::Core::Functions
2
- DType = Struct.new(:name, keyword_init: true)
3
- module DTypes
4
- BOOL = DType.new(name: :bool)
5
- INT = DType.new(name: :int)
6
- FLOAT = DType.new(name: :float)
7
- STRING = DType.new(name: :string)
8
- DATETIME = DType.new(name: :datetime)
9
- ANY = DType.new(name: :any)
10
- end
11
-
12
- module Promotion
13
- # super simple table; extend later
14
- TABLE = {
15
- %i[int int] => :int, %i[int float] => :float, %i[float int] => :float, %i[float float] => :float,
16
- %i[bool int] => :int, %i[bool float] => :float, %i[bool bool] => :bool
17
- }
18
- def self.promote(a, b) = TABLE[[a, b]] || TABLE[[b, a]] || :any
19
- end
20
- end
--- data/lib/kumi/core/functions/errors.rb
+++ /dev/null
@@ -1,11 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Functions
6
- class SignatureError < StandardError; end
7
- class SignatureParseError < SignatureError; end
8
- class SignatureMatchError < SignatureError; end
9
- end
10
- end
11
- end
--- data/lib/kumi/core/functions/kernel_adapter.rb
+++ /dev/null
@@ -1,45 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Kumi
4
- module Core
5
- module Functions
6
- KernelHandle = Struct.new(:kind, :callable, :null_policy, :options, keyword_init: true)
7
-
8
- module KernelAdapter
9
- module_function
10
-
11
- def build_for(function, backend_entry)
12
- impl = backend_entry.impl
13
- mod = ruby_module_for(function)
14
-
15
- raise Kumi::Core::Errors::CompilationError, "Missing Ruby kernel #{impl} for #{function.name}" unless mod.respond_to?(impl)
16
-
17
- kind = function.class_sym
18
- KernelHandle.new(
19
- kind: kind,
20
- callable: mod.method(impl),
21
- null_policy: function.null_policy,
22
- options: function.options || {}
23
- )
24
- end
25
-
26
- def ruby_module_for(function)
27
- case function.domain.to_sym
28
- when :core
29
- function.class_sym == :aggregate ? Kumi::Kernels::Ruby::AggregateCore : Kumi::Kernels::Ruby::ScalarCore
30
- when :string
31
- Kumi::Kernels::Ruby::StringScalar
32
- when :datetime
33
- Kumi::Kernels::Ruby::DatetimeScalar
34
- when :struct
35
- Kumi::Kernels::Ruby::VectorStruct
36
- when :mask
37
- Kumi::Kernels::Ruby::MaskScalar
38
- else
39
- raise Kumi::Core::Errors::CompilationError, "Unknown domain #{function.domain}"
40
- end
41
- end
42
- end
43
- end
44
- end
45
- end
--- data/lib/kumi/core/functions/loader.rb
+++ /dev/null
@@ -1,119 +0,0 @@
1
- require "yaml"
2
- module Kumi::Core::Functions
3
- Function = Data.define(:name, :domain, :opset, :class_sym, :signatures,
4
- :type_vars, :dtypes, :null_policy, :algebra, :vectorization,
5
- :options, :traits, :shape_fn, :kernels)
6
-
7
- KernelEntry = Data.define(:backend, :impl, :priority, :when_)
8
-
9
- class Loader
10
- def self.load_file(path)
11
- doc = YAML.load_file(path)
12
- functions = doc.map { |h| build_function(h) }
13
- validate!(functions)
14
- functions.freeze
15
- end
16
-
17
- def self.build_function(h)
18
- sigs = Array(h.fetch("signature")).map { |s| parse_signature(s) }
19
- kernels = Array(h.fetch("kernels", [])).map do |k|
20
- KernelEntry.new(backend: k.fetch("backend").to_sym,
21
- impl: k.fetch("impl").to_sym,
22
- priority: k.fetch("priority", 0).to_i,
23
- when_: k["when"]&.transform_keys!(&:to_sym))
24
- end
25
- Function.new(
26
- name: h.fetch("name"),
27
- domain: h.fetch("domain"),
28
- opset: h.fetch("opset").to_i,
29
- class_sym: h.fetch("class").to_sym,
30
- signatures: sigs,
31
- type_vars: h["type_vars"] || {},
32
- dtypes: h["dtypes"] || {},
33
- null_policy: (h["null_policy"] || "propagate").to_sym,
34
- algebra: (h["algebra"] || {}).transform_keys!(&:to_sym),
35
- vectorization: (h["vectorization"] || {}).transform_keys!(&:to_sym),
36
- options: h["options"] || {},
37
- traits: (h["traits"] || {}).transform_keys!(&:to_sym),
38
- shape_fn: h["shape_fn"],
39
- kernels: kernels
40
- )
41
- end
42
-
43
- # " (i),(j)->(i,j)@product " → Signature
44
- def self.parse_signature(s)
45
- lhs, rhs = s.split("->").map(&:strip)
46
- out_axes, policy = rhs.split("@").map(&:strip)
47
-
48
- # Parse input shapes by splitting on commas between parentheses
49
- in_shapes = parse_input_shapes(lhs)
50
-
51
- Signature.new(in_shapes: in_shapes,
52
- out_shape: parse_axes(out_axes),
53
- join_policy: policy&.to_sym)
54
- end
55
-
56
- # Parse "(i),(j)" or "(i,j),(k,l)" properly
57
- def self.parse_input_shapes(lhs)
58
- # Find all parenthesized groups
59
- shapes = []
60
- current_pos = 0
61
-
62
- while current_pos < lhs.length
63
- # Find the next opening parenthesis
64
- start_paren = lhs.index('(', current_pos)
65
- break unless start_paren
66
-
67
- # Find the matching closing parenthesis
68
- paren_count = 0
69
- end_paren = start_paren
70
-
71
- (start_paren..lhs.length-1).each do |i|
72
- case lhs[i]
73
- when '('
74
- paren_count += 1
75
- when ')'
76
- paren_count -= 1
77
- if paren_count == 0
78
- end_paren = i
79
- break
80
- end
81
- end
82
- end
83
-
84
- # Extract the shape
85
- shape_str = lhs[start_paren..end_paren]
86
- shapes << parse_axes(shape_str)
87
- current_pos = end_paren + 1
88
- end
89
-
90
- shapes
91
- end
92
-
93
- def self.parse_axes(txt)
94
- txt = txt.strip
95
- return [] if txt == "()" || txt.empty?
96
-
97
- inner = txt.sub("(", "").sub(")", "")
98
- if inner.empty?
99
- []
100
- else
101
- inner.split(",").map do |a|
102
- dim_name = a.strip.to_sym
103
- Dimension.new(dim_name) # Convert to Dimension objects
104
- end
105
- end
106
- end
107
-
108
- def self.validate!(fns)
109
- names = {}
110
- fns.each do |f|
111
- key = [f.domain, f.name, f.opset]
112
- raise "duplicate function #{key}" if names[key]
113
-
114
- names[key] = true
115
- raise "no kernels for #{f.name}" if f.kernels.empty?
116
- end
117
- end
118
- end
119
- end