kumi 0.0.15 → 0.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/golden/cascade_logic/schema.kumi +3 -1
- data/lib/kumi/analyzer.rb +11 -9
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +0 -81
- data/lib/kumi/core/analyzer/passes/ir_dependency_pass.rb +18 -20
- data/lib/kumi/core/analyzer/passes/ir_execution_schedule_pass.rb +67 -0
- data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +0 -36
- data/lib/kumi/core/analyzer/passes/toposorter.rb +1 -39
- data/lib/kumi/core/analyzer/passes/unsat_detector.rb +8 -191
- data/lib/kumi/core/compiler/access_builder.rb +20 -10
- data/lib/kumi/core/compiler/access_codegen.rb +61 -0
- data/lib/kumi/core/compiler/access_emit/base.rb +173 -0
- data/lib/kumi/core/compiler/access_emit/each_indexed.rb +56 -0
- data/lib/kumi/core/compiler/access_emit/materialize.rb +45 -0
- data/lib/kumi/core/compiler/access_emit/ravel.rb +50 -0
- data/lib/kumi/core/compiler/access_emit/read.rb +32 -0
- data/lib/kumi/core/ir/execution_engine/interpreter.rb +36 -181
- data/lib/kumi/core/ir/execution_engine/values.rb +8 -8
- data/lib/kumi/core/ir/execution_engine.rb +3 -19
- data/lib/kumi/dev/parse.rb +12 -12
- data/lib/kumi/runtime/executable.rb +22 -175
- data/lib/kumi/runtime/run.rb +105 -0
- data/lib/kumi/schema.rb +8 -13
- data/lib/kumi/version.rb +1 -1
- data/lib/kumi.rb +3 -2
- metadata +10 -25
- data/BACKLOG.md +0 -34
- data/config/functions.yaml +0 -352
- data/docs/functions/analyzer_integration.md +0 -199
- data/docs/functions/signatures.md +0 -171
- data/examples/hash_objects_demo.rb +0 -138
- data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +0 -199
- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +0 -48
- data/lib/kumi/core/functions/dimension.rb +0 -98
- data/lib/kumi/core/functions/dtypes.rb +0 -20
- data/lib/kumi/core/functions/errors.rb +0 -11
- data/lib/kumi/core/functions/kernel_adapter.rb +0 -45
- data/lib/kumi/core/functions/loader.rb +0 -119
- data/lib/kumi/core/functions/registry_v2.rb +0 -68
- data/lib/kumi/core/functions/shape.rb +0 -70
- data/lib/kumi/core/functions/signature.rb +0 -122
- data/lib/kumi/core/functions/signature_parser.rb +0 -86
- data/lib/kumi/core/functions/signature_resolver.rb +0 -272
- data/lib/kumi/kernels/ruby/aggregate_core.rb +0 -105
- data/lib/kumi/kernels/ruby/datetime_scalar.rb +0 -21
- data/lib/kumi/kernels/ruby/mask_scalar.rb +0 -15
- data/lib/kumi/kernels/ruby/scalar_core.rb +0 -63
- data/lib/kumi/kernels/ruby/string_scalar.rb +0 -19
- data/lib/kumi/kernels/ruby/vector_struct.rb +0 -39
@@ -1,138 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
# Hash Objects Demo - Demonstrates Kumi's structured input syntax
|
4
|
-
# This example shows how to use hash objects for organizing related input fields
|
5
|
-
|
6
|
-
require_relative "../lib/kumi"
|
7
|
-
|
8
|
-
# Compensation schema whose inputs are grouped into three hash objects
# (employee, company_config, benefits); every trait and value below reads the
# nested fields with dotted access such as input.employee.base_salary.
module HashObjectsDemo
  extend Kumi::Schema

  schema do
    input do
      # Per-employee facts; `domain:` ranges constrain accepted values.
      hash(:employee) do
        string(:name)
        integer(:age, domain: 18..65)
        float(:base_salary, domain: 30_000.0..200_000.0)
        boolean(:is_manager)
        integer(:years_experience, domain: 0..40)
      end

      # Company-wide settings consumed by the bonus and manager rules.
      hash(:company_config) do
        string(:name)
        float(:bonus_percentage, domain: 0.0..0.50)
        float(:manager_multiplier, domain: 1.0..2.0)
        integer(:current_year, domain: 2020..2030)
      end

      # Benefit elections and rates.
      hash(:benefits) do
        boolean(:health_insurance)
        boolean(:dental_coverage)
        float(:retirement_match, domain: 0.0..0.10)
        integer(:vacation_days, domain: 10..30)
      end
    end

    # Eligibility flags derived from the nested inputs.
    trait(:is_senior, input.employee.years_experience >= 5)
    trait(:eligible_for_bonus, is_senior & input.employee.is_manager)
    trait(:has_full_benefits, input.benefits.health_insurance & input.benefits.dental_coverage)

    # Cash compensation.
    value(:base_annual_salary, input.employee.base_salary)

    # Bonus only for senior managers; zero otherwise.
    value(:bonus_amount) do
      on(eligible_for_bonus, base_annual_salary * input.company_config.bonus_percentage)
      base(0.0)
    end

    trait(:is_manager_trait, input.employee.is_manager == true)

    # Managers get the company multiplier; everyone else 1.0.
    value(:manager_adjustment) do
      on(is_manager_trait, input.company_config.manager_multiplier)
      base(1.0)
    end

    value(:total_compensation, (base_annual_salary * manager_adjustment) + bonus_amount)

    # Benefit valuations.
    value(:retirement_contribution, total_compensation * input.benefits.retirement_match)

    value(:benefits_package_value) do
      on(has_full_benefits, 5_000.0 + input.benefits.vacation_days * 150.0)
      base(input.benefits.vacation_days * 100.0)
    end

    # Grand total and summary figures.
    value(:total_package_value, total_compensation + retirement_contribution + benefits_package_value)

    value(:years_to_retirement, 65 - input.employee.age)
  end
end
|
76
|
-
|
77
|
-
# Example usage
|
78
|
-
if __FILE__ == $PROGRAM_NAME
  puts "Hash Objects Demo - Employee Compensation Calculator"
  puts "=" * 55

  # Sample input whose nested hashes mirror the schema's hash objects.
  employee_data = {
    employee: {
      name: "Alice Johnson",
      age: 32,
      base_salary: 85_000.0,
      is_manager: true,
      years_experience: 8
    },
    company_config: {
      name: "Tech Solutions Inc",
      bonus_percentage: 0.15,
      manager_multiplier: 1.25,
      current_year: 2024
    },
    benefits: {
      health_insurance: true,
      dental_coverage: true,
      retirement_match: 0.06,
      vacation_days: 25
    }
  }

  result = HashObjectsDemo.from(employee_data)

  # Whole-dollar amount with comma thousands separators, e.g. 85000.0 -> "$85,000".
  def format_currency(amount)
    digits = amount.round(0).to_s
    "$#{digits.gsub(/\B(?=(\d{3})+(?!\d))/, ',')}"
  end

  employee, company, benefits = employee_data.values_at(:employee, :company_config, :benefits)

  # Small printing helpers; each line is emitted as soon as its value is computed.
  bullet = ->(label, value) { puts "- #{label}: #{value}" }
  yes_no = ->(flag) { flag ? "Yes" : "No" }
  percent = ->(ratio) { "#{(ratio * 100).round(1)}%" }

  puts "\nEmployee Information:"
  bullet.("Name", employee[:name])
  bullet.("Age", employee[:age])
  bullet.("Experience", "#{employee[:years_experience]} years")
  bullet.("Manager", yes_no.(employee[:is_manager]))

  puts "\nCompany: #{company[:name]}"
  bullet.("Bonus Rate", percent.(company[:bonus_percentage]))
  bullet.("Manager Multiplier", "#{company[:manager_multiplier]}x")

  puts "\nBenefits:"
  bullet.("Health Insurance", yes_no.(benefits[:health_insurance]))
  bullet.("Dental Coverage", yes_no.(benefits[:dental_coverage]))
  bullet.("Retirement Match", percent.(benefits[:retirement_match]))
  bullet.("Vacation Days", benefits[:vacation_days])

  puts "\nCompensation Breakdown:"
  bullet.("Base Salary", format_currency(result[:base_annual_salary]))
  bullet.("Manager Adjustment", "#{result[:manager_adjustment]}x")
  bullet.("Bonus", format_currency(result[:bonus_amount]))
  bullet.("Total Compensation", format_currency(result[:total_compensation]))
  bullet.("Retirement Contribution", format_currency(result[:retirement_contribution]))
  bullet.("Benefits Package Value", format_currency(result[:benefits_package_value]))

  puts "\nTotal Package: #{format_currency(result[:total_package_value])}"
  puts "Years to Retirement: #{result[:years_to_retirement]}"
end
|
@@ -1,199 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Apply NEP-20 signature resolution to function calls
        # DEPENDENCIES: :node_index from Toposorter, :broadcast_metadata (optional)
        # PRODUCES: Signature metadata in node_index for CallExpression nodes
        # INTERFACE: new(schema, state).run(errors)
        class FunctionSignaturePass < PassBase
          # Resolve a signature for every CallExpression entry in :node_index.
          #
          # @param errors [Object] error collector forwarded to report_error
          # @return the pass state; node_index entries are updated in place
          def run(errors)
            node_index = get_state(:node_index, required: true)

            node_index.each do |object_id, entry|
              next unless entry[:type] == "CallExpression"

              resolve_function_signature(entry, object_id, errors)
            end

            state
          end

          private

          # Parse the candidate signatures for one call, choose the match for its
          # argument shapes and attach the resulting plan to the entry. Parse and
          # match failures are reported as :type errors and the entry is left as-is.
          def resolve_function_signature(entry, object_id, errors)
            node = entry[:node]

            # 1) Candidate signature strings; Array() tolerates a nil registry answer.
            sig_strings = Array(get_function_signatures(node))
            return if sig_strings.empty?

            begin
              sigs = parse_signatures(sig_strings)
            rescue Kumi::Core::Functions::SignatureError => e
              report_error(errors, "Invalid signature for function `#{node.fn_name}`: #{e.message}",
                           location: node.loc, type: :type)
              return
            end

            # 2) Argument shapes from the broadcast analysis (scalar when unknown).
            arg_shapes = build_argument_shapes(node, object_id)

            # 3) Resolve the signature.
            begin
              plan = Kumi::Core::Functions::SignatureResolver.choose(signatures: sigs, arg_shapes: arg_shapes)
            rescue Kumi::Core::Functions::SignatureMatchError => e
              report_error(errors,
                           "Signature mismatch for `#{node.fn_name}` with args #{format_shapes(arg_shapes)}. " \
                           "Candidates: #{format_sigs(sig_strings)}. #{e.message}",
                           location: node.loc, type: :type)
              return
            end

            # 4) Publish the plan on the node index entry.
            attach_signature_metadata(entry, plan)
          end

          # Candidate signature strings for node.fn_name. RegistryV2 is opt-in via
          # KUMI_FN_REGISTRY_V2=1; otherwise the legacy registry is used.
          def get_function_signatures(node)
            if registry_v2_enabled?
              registry_v2_signatures(node)
            else
              legacy_registry_signatures(node)
            end
          end

          # Best-effort lookup: any StandardError from RegistryV2 falls back to the
          # legacy registry.
          def registry_v2_signatures(node)
            registry_v2.get_function_signatures(node.fn_name)
          rescue StandardError
            legacy_registry_signatures(node)
          end

          # Signatures from Kumi::Registry: explicit :signatures when present as an
          # Array, otherwise basic signatures synthesized from :arity. Unknown
          # functions yield [] so that another pass reports them.
          def legacy_registry_signatures(node)
            meta = Kumi::Registry.signature(node.fn_name)

            signatures = meta[:signatures]
            return signatures if signatures.is_a?(Array)

            # Bridge until the registry carries full NEP-20 signatures.
            create_basic_signature(meta[:arity])
          rescue Kumi::Errors::UnknownFunction
            []
          end

          # Fallback signature strings for a fixed arity; nil/negative (variadic)
          # arity yields no candidates.
          def create_basic_signature(arity)
            return [] if arity.nil? || arity < 0

            case arity
            when 0
              ["()->()"] # Scalar function
            when 1
              ["()->()", "(i)->(i)"] # Scalar or element-wise
            when 2
              ["(),()->()", "(i),(i)->(i)"] # Scalar or element-wise binary
            else
              # Higher arity: scalar signature only.
              args = (["()"] * arity).join(",")
              ["#{args}->()"]
            end
          end

          # One normalized axes array per argument. The broadcast metadata state is
          # fetched once per call rather than once per argument.
          def build_argument_shapes(node, _object_id)
            broadcast_meta = get_state(:broadcast_metadata, required: false)
            node.args.map do |arg|
              normalize_shape(get_broadcast_metadata(arg.object_id, broadcast_meta))
            end
          end

          # nil/unknown -> [] (scalar); Array -> Integer axes kept, others symbolized.
          def normalize_shape(axes)
            case axes
            when nil
              [] # scalar
            when Array
              axes.map { |d| d.is_a?(Integer) ? d : d.to_sym }
            else
              [] # defensive fallback
            end
          end

          # The :axes recorded for a node in :broadcast_metadata, or nil.
          def get_broadcast_metadata(arg_object_id, broadcast_meta = get_state(:broadcast_metadata, required: false))
            return nil unless broadcast_meta

            broadcast_meta[arg_object_id]&.dig(:axes)
          end

          # Parses signature strings, memoizing per string for the pass instance.
          def parse_signatures(sig_strings)
            @sig_cache ||= {}
            sig_strings.map do |s|
              @sig_cache[s] ||= Kumi::Core::Functions::SignatureParser.parse(s)
            end
          end

          def format_shapes(shapes)
            shapes.map { |ax| "(#{ax.join(',')})" }.join(', ')
          end

          def format_sigs(sig_strings)
            sig_strings.join(" | ")
          end

          # Stores the plan on the entry's :metadata hash, creating the hash when
          # the entry does not have one yet.
          def attach_signature_metadata(entry, plan)
            metadata = (entry[:metadata] ||= {})

            attach_core_signature_data(metadata, plan)
            attach_shape_contract(metadata, plan)
          end

          def attach_core_signature_data(metadata, plan)
            metadata[:signature] = plan[:signature]
            metadata[:result_axes] = plan[:result_axes] # e.g., [:i, :j]
            metadata[:join_policy] = plan[:join_policy] # nil | :zip | :product
            metadata[:dropped_axes] = plan[:dropped_axes] # e.g., [:j] for reductions
            metadata[:effective_signature] = plan[:effective_signature] # Normalized for lowering
            metadata[:dim_env] = plan[:env] # Dimension bindings (for matmul)
            metadata[:signature_score] = plan[:score] # Match quality
          end

          # Shape contract copied from the effective signature for lowering.
          def attach_shape_contract(metadata, plan)
            metadata[:shape_contract] = {
              in: plan[:effective_signature][:in_shapes],
              out: plan[:effective_signature][:out_shape],
              join: plan[:effective_signature][:join_policy]
            }
          end

          def registry_v2_enabled?
            ENV["KUMI_FN_REGISTRY_V2"] == "1"
          end

          def registry_v2
            @registry_v2 ||= Kumi::Core::Functions::RegistryV2.load_from_file
          end

          def nep20_flex_enabled?
            ENV["KUMI_ENABLE_FLEX"] == "1"
          end

          def nep20_bcast1_enabled?
            ENV["KUMI_ENABLE_BCAST1"] == "1"
          end
        end
      end
    end
  end
end
|
@@ -1,48 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Core
    module Analyzer
      module Passes
        # RESPONSIBILITY: Validate consistency between declared and inferred types
        # DEPENDENCIES: :input_metadata from InputCollector, :inferred_types from TypeInferencerPass
        # PRODUCES: None (validation only)
        # INTERFACE: new(schema, state).run(errors)
        class TypeConsistencyChecker < PassBase
          # Reports every input whose declared type Kumi::Core::Types rejects.
          # Declared-vs-inferred usage checks are not implemented yet.
          #
          # @param errors [Object] error collector forwarded to report_type_error
          # @return the unchanged pass state
          def run(errors)
            validate_declared_types(get_state(:input_metadata, required: false) || {}, errors)
            state
          end

          private

          # Fields with a truthy :type that is not a valid type produce one error
          # each, in input_meta order. Fields without a :type are ignored.
          def validate_declared_types(input_meta, errors)
            invalid_fields = input_meta.select do |_field_name, meta|
              declared = meta[:type]
              declared && !Kumi::Core::Types.valid_type?(declared)
            end

            invalid_fields.each do |field_name, meta|
              location = find_input_field_declaration(field_name)&.loc
              report_type_error(errors, "Invalid type declaration for field :#{field_name}: #{meta[:type].inspect}", location: location)
            end
          end

          # The schema's input declaration named field_name, or nil when there is
          # no schema (or no such declaration).
          def find_input_field_declaration(field_name)
            schema ? schema.inputs.find { |decl| decl.name == field_name } : nil
          end
        end
      end
    end
  end
end
|
@@ -1,98 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require_relative "errors"
|
4
|
-
|
5
|
-
module Kumi
  module Core
    module Functions
      # One axis of a NEP-20 style signature.
      #
      # The name is either a Symbol (named axis such as :i) or a positive Integer
      # (fixed size). Two optional modifiers are supported, rendered by #to_s as:
      #   "n?"   flexible      - the axis may be omitted
      #   "i|1"  broadcastable - the axis may broadcast against size-1 axes
      # A dimension cannot be both, and a fixed-size dimension cannot be flexible.
      #
      # Instances are frozen after validation and usable as Hash keys.
      #
      # Examples:
      #   Dimension.new(:i)
      #   Dimension.new(3)
      #   Dimension.new(:n, flexible: true)
      #   Dimension.new(:i, broadcastable: true)
      class Dimension
        attr_reader :name, :flexible, :broadcastable

        # @raise [SignatureError] when the combination of name and flags is invalid
        def initialize(name, flexible: false, broadcastable: false)
          @name = name
          @flexible = flexible
          @broadcastable = broadcastable

          validate!
          freeze
        end

        def fixed_size? = @name.is_a?(Integer)

        def named? = @name.is_a?(Symbol)

        def flexible? = @flexible

        def broadcastable? = @broadcastable

        # The fixed size, or nil for a named dimension.
        def size
          @name if fixed_size?
        end

        def ==(other)
          return false unless other.is_a?(Dimension)

          [name, flexible, broadcastable] == [other.name, other.flexible, other.broadcastable]
        end

        def eql?(other) = self == other

        def hash = [name, flexible, broadcastable].hash

        def to_s = "#{name}#{'?' if flexible?}#{'|1' if broadcastable?}"

        def inspect = "#<Dimension #{self}>"

        private

        # Checks run in a fixed order so the first violated rule is reported.
        def validate!
          raise SignatureError, "dimension name must be a symbol or integer, got: #{name.inspect}" unless named? || fixed_size?
          raise SignatureError, "fixed-size dimension must be positive, got: #{name}" if fixed_size? && name <= 0
          raise SignatureError, "dimension cannot be both flexible and broadcastable" if flexible? && broadcastable?
          raise SignatureError, "fixed-size dimension cannot be flexible" if fixed_size? && flexible?
        end
      end
    end
  end
end
|
@@ -1,20 +0,0 @@
|
|
1
|
-
module Kumi
  module Core
    module Functions
      # Element-type descriptor identified by a Symbol name.
      DType = Struct.new(:name, keyword_init: true)

      # Canonical DType instances. They are frozen so that shared constants
      # cannot be renamed through the Struct setter.
      module DTypes
        BOOL = DType.new(name: :bool).freeze
        INT = DType.new(name: :int).freeze
        FLOAT = DType.new(name: :float).freeze
        STRING = DType.new(name: :string).freeze
        DATETIME = DType.new(name: :datetime).freeze
        ANY = DType.new(name: :any).freeze
      end

      # Binary dtype promotion.
      module Promotion
        # Pair -> promoted dtype name. Intentionally small; extend later.
        # Frozen so the table cannot be altered at runtime.
        TABLE = {
          %i[int int] => :int, %i[int float] => :float, %i[float int] => :float, %i[float float] => :float,
          %i[bool int] => :int, %i[bool float] => :float, %i[bool bool] => :bool
        }.freeze

        # Promoted dtype name for operands a and b (Symbols). Looks the pair up in
        # both orders, so the result does not depend on operand order. Returns
        # :any for pairs not in TABLE.
        def self.promote(a, b) = TABLE[[a, b]] || TABLE[[b, a]] || :any
      end
    end
  end
end
|
@@ -1,45 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Core
    module Functions
      # A resolved kernel: the function class symbol, the bound Ruby Method to
      # invoke, the function's null policy and its options hash.
      KernelHandle = Struct.new(:kind, :callable, :null_policy, :options, keyword_init: true)

      # Binds a function definition plus a backend kernel entry to a Ruby kernel
      # method in one of the Kumi::Kernels::Ruby modules.
      module KernelAdapter
        module_function

        # @param function [#name, #domain, #class_sym, #null_policy, #options]
        # @param backend_entry [#impl] method name of the kernel
        # @return [KernelHandle]
        # @raise [Kumi::Core::Errors::CompilationError] if the kernel module lacks the method
        def build_for(function, backend_entry)
          kernel_name = backend_entry.impl
          owner = ruby_module_for(function)

          unless owner.respond_to?(kernel_name)
            raise Kumi::Core::Errors::CompilationError, "Missing Ruby kernel #{kernel_name} for #{function.name}"
          end

          KernelHandle.new(
            kind: function.class_sym,
            callable: owner.method(kernel_name),
            null_policy: function.null_policy,
            options: function.options || {}
          )
        end

        # Kernel module for the function's domain. In the core domain, aggregate
        # functions and all other classes use different modules.
        #
        # @raise [Kumi::Core::Errors::CompilationError] for an unknown domain
        def ruby_module_for(function)
          case [function.domain.to_sym, function.class_sym]
          in [:core, :aggregate] then Kumi::Kernels::Ruby::AggregateCore
          in [:core, _] then Kumi::Kernels::Ruby::ScalarCore
          in [:string, _] then Kumi::Kernels::Ruby::StringScalar
          in [:datetime, _] then Kumi::Kernels::Ruby::DatetimeScalar
          in [:struct, _] then Kumi::Kernels::Ruby::VectorStruct
          in [:mask, _] then Kumi::Kernels::Ruby::MaskScalar
          else
            raise Kumi::Core::Errors::CompilationError, "Unknown domain #{function.domain}"
          end
        end
      end
    end
  end
end
|
@@ -1,119 +0,0 @@
|
|
1
|
-
require "yaml"
|
2
|
-
module Kumi::Core::Functions
|
3
|
-
Function = Data.define(:name, :domain, :opset, :class_sym, :signatures,
|
4
|
-
:type_vars, :dtypes, :null_policy, :algebra, :vectorization,
|
5
|
-
:options, :traits, :shape_fn, :kernels)
|
6
|
-
|
7
|
-
KernelEntry = Data.define(:backend, :impl, :priority, :when_)
|
8
|
-
|
9
|
-
class Loader
|
10
|
-
def self.load_file(path)
|
11
|
-
doc = YAML.load_file(path)
|
12
|
-
functions = doc.map { |h| build_function(h) }
|
13
|
-
validate!(functions)
|
14
|
-
functions.freeze
|
15
|
-
end
|
16
|
-
|
17
|
-
def self.build_function(h)
|
18
|
-
sigs = Array(h.fetch("signature")).map { |s| parse_signature(s) }
|
19
|
-
kernels = Array(h.fetch("kernels", [])).map do |k|
|
20
|
-
KernelEntry.new(backend: k.fetch("backend").to_sym,
|
21
|
-
impl: k.fetch("impl").to_sym,
|
22
|
-
priority: k.fetch("priority", 0).to_i,
|
23
|
-
when_: k["when"]&.transform_keys!(&:to_sym))
|
24
|
-
end
|
25
|
-
Function.new(
|
26
|
-
name: h.fetch("name"),
|
27
|
-
domain: h.fetch("domain"),
|
28
|
-
opset: h.fetch("opset").to_i,
|
29
|
-
class_sym: h.fetch("class").to_sym,
|
30
|
-
signatures: sigs,
|
31
|
-
type_vars: h["type_vars"] || {},
|
32
|
-
dtypes: h["dtypes"] || {},
|
33
|
-
null_policy: (h["null_policy"] || "propagate").to_sym,
|
34
|
-
algebra: (h["algebra"] || {}).transform_keys!(&:to_sym),
|
35
|
-
vectorization: (h["vectorization"] || {}).transform_keys!(&:to_sym),
|
36
|
-
options: h["options"] || {},
|
37
|
-
traits: (h["traits"] || {}).transform_keys!(&:to_sym),
|
38
|
-
shape_fn: h["shape_fn"],
|
39
|
-
kernels: kernels
|
40
|
-
)
|
41
|
-
end
|
42
|
-
|
43
|
-
# " (i),(j)->(i,j)@product " → Signature
|
44
|
-
def self.parse_signature(s)
|
45
|
-
lhs, rhs = s.split("->").map(&:strip)
|
46
|
-
out_axes, policy = rhs.split("@").map(&:strip)
|
47
|
-
|
48
|
-
# Parse input shapes by splitting on commas between parentheses
|
49
|
-
in_shapes = parse_input_shapes(lhs)
|
50
|
-
|
51
|
-
Signature.new(in_shapes: in_shapes,
|
52
|
-
out_shape: parse_axes(out_axes),
|
53
|
-
join_policy: policy&.to_sym)
|
54
|
-
end
|
55
|
-
|
56
|
-
# Parse "(i),(j)" or "(i,j),(k,l)" properly
|
57
|
-
def self.parse_input_shapes(lhs)
|
58
|
-
# Find all parenthesized groups
|
59
|
-
shapes = []
|
60
|
-
current_pos = 0
|
61
|
-
|
62
|
-
while current_pos < lhs.length
|
63
|
-
# Find the next opening parenthesis
|
64
|
-
start_paren = lhs.index('(', current_pos)
|
65
|
-
break unless start_paren
|
66
|
-
|
67
|
-
# Find the matching closing parenthesis
|
68
|
-
paren_count = 0
|
69
|
-
end_paren = start_paren
|
70
|
-
|
71
|
-
(start_paren..lhs.length-1).each do |i|
|
72
|
-
case lhs[i]
|
73
|
-
when '('
|
74
|
-
paren_count += 1
|
75
|
-
when ')'
|
76
|
-
paren_count -= 1
|
77
|
-
if paren_count == 0
|
78
|
-
end_paren = i
|
79
|
-
break
|
80
|
-
end
|
81
|
-
end
|
82
|
-
end
|
83
|
-
|
84
|
-
# Extract the shape
|
85
|
-
shape_str = lhs[start_paren..end_paren]
|
86
|
-
shapes << parse_axes(shape_str)
|
87
|
-
current_pos = end_paren + 1
|
88
|
-
end
|
89
|
-
|
90
|
-
shapes
|
91
|
-
end
|
92
|
-
|
93
|
-
def self.parse_axes(txt)
|
94
|
-
txt = txt.strip
|
95
|
-
return [] if txt == "()" || txt.empty?
|
96
|
-
|
97
|
-
inner = txt.sub("(", "").sub(")", "")
|
98
|
-
if inner.empty?
|
99
|
-
[]
|
100
|
-
else
|
101
|
-
inner.split(",").map do |a|
|
102
|
-
dim_name = a.strip.to_sym
|
103
|
-
Dimension.new(dim_name) # Convert to Dimension objects
|
104
|
-
end
|
105
|
-
end
|
106
|
-
end
|
107
|
-
|
108
|
-
def self.validate!(fns)
|
109
|
-
names = {}
|
110
|
-
fns.each do |f|
|
111
|
-
key = [f.domain, f.name, f.opset]
|
112
|
-
raise "duplicate function #{key}" if names[key]
|
113
|
-
|
114
|
-
names[key] = true
|
115
|
-
raise "no kernels for #{f.name}" if f.kernels.empty?
|
116
|
-
end
|
117
|
-
end
|
118
|
-
end
|
119
|
-
end
|