kumi 0.0.16 → 0.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/golden/cascade_logic/schema.kumi +3 -1
- data/lib/kumi/analyzer.rb +8 -11
- data/lib/kumi/core/analyzer/passes/broadcast_detector.rb +0 -81
- data/lib/kumi/core/analyzer/passes/lower_to_ir_pass.rb +0 -36
- data/lib/kumi/core/analyzer/passes/toposorter.rb +1 -36
- data/lib/kumi/core/analyzer/passes/unsat_detector.rb +8 -191
- data/lib/kumi/core/compiler/access_builder.rb +5 -8
- data/lib/kumi/version.rb +1 -1
- metadata +2 -25
- data/BACKLOG.md +0 -34
- data/config/functions.yaml +0 -352
- data/docs/functions/analyzer_integration.md +0 -199
- data/docs/functions/signatures.md +0 -171
- data/examples/hash_objects_demo.rb +0 -138
- data/lib/kumi/core/analyzer/passes/function_signature_pass.rb +0 -199
- data/lib/kumi/core/analyzer/passes/type_consistency_checker.rb +0 -48
- data/lib/kumi/core/functions/dimension.rb +0 -98
- data/lib/kumi/core/functions/dtypes.rb +0 -20
- data/lib/kumi/core/functions/errors.rb +0 -11
- data/lib/kumi/core/functions/kernel_adapter.rb +0 -45
- data/lib/kumi/core/functions/loader.rb +0 -119
- data/lib/kumi/core/functions/registry_v2.rb +0 -68
- data/lib/kumi/core/functions/shape.rb +0 -70
- data/lib/kumi/core/functions/signature.rb +0 -122
- data/lib/kumi/core/functions/signature_parser.rb +0 -86
- data/lib/kumi/core/functions/signature_resolver.rb +0 -272
- data/lib/kumi/kernels/ruby/aggregate_core.rb +0 -105
- data/lib/kumi/kernels/ruby/datetime_scalar.rb +0 -21
- data/lib/kumi/kernels/ruby/mask_scalar.rb +0 -15
- data/lib/kumi/kernels/ruby/scalar_core.rb +0 -63
- data/lib/kumi/kernels/ruby/string_scalar.rb +0 -19
- data/lib/kumi/kernels/ruby/vector_struct.rb +0 -39
@@ -1,272 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require_relative "errors"
|
4
|
-
require_relative "shape"
|
5
|
-
require_relative "signature"
|
6
|
-
|
7
|
-
module Kumi
  module Core
    module Functions
      # Given a set of signatures and the actual argument shapes, picks the
      # cheapest matching signature. Supports the NEP 20 extensions:
      # fixed-size, flexible (?) and broadcastable (|1) dimensions.
      #
      # Inputs:
      #   signatures : Array<Signature> (with Dimension objects)
      #   arg_shapes : Array<Array<Symbol|Integer>> e.g., [[:i], [:i]] or [[], [3]] or [[2, :i]]
      #                (nil is treated as [] for zero-arity functions)
      #
      # Returns the winning candidate Hash:
      #   { signature:, score:, result_axes:, join_policy:, dropped_axes:, env:, effective_signature: }
      #
      # Raises SignatureMatchError when arg_shapes is malformed or no signature matches.
      #
      # Matching rules:
      #   - Arity must match exactly.
      #   - Fixed-size dimensions (integers) must match exactly.
      #   - A scalar argument ([]) is accepted for any expected shape.
      #   - A flexible expected dimension may be dropped when the argument offers
      #     no matching dimension for it.
      #   - A broadcastable expected dimension accepts a size-1 or a named dimension.
      #   - Every named expected dimension must bind to one consistent argument
      #     dimension across all arguments (see build_dimension_environment).
      #
      # Ranking: the lowest total cost wins. Per-dimension costs are
      # exact = 0, scalar broadcast = 1, same name but fixed/non-fixed = 2,
      # dimension broadcast = 3, flexible dimension involved/dropped = 10.
      class SignatureResolver
        SCALAR_BROADCAST_COST = 1
        MIXED_FIXED_COST = 2
        DIMENSION_BROADCAST_COST = 3
        FLEXIBLE_COST = 10
        private_constant :SCALAR_BROADCAST_COST, :MIXED_FIXED_COST,
                         :DIMENSION_BROADCAST_COST, :FLEXIBLE_COST

        class << self
          def choose(signatures:, arg_shapes:)
            # Handle missing arg_shapes for zero-arity functions
            arg_shapes = [] if arg_shapes.nil?
            sanity_check_args!(arg_shapes)

            # The normalized shapes do not depend on the signature, so build them once.
            normalized_args = arg_shapes.map { |shape| normalize_shape(shape) }

            candidates = signatures.map { |sig| build_candidate(sig, normalized_args) }.compact
            raise SignatureMatchError, mismatch_message(signatures, arg_shapes) if candidates.empty?

            # Lower score is better: 0 = exact everywhere
            best = candidates.min_by { |candidate| candidate[:score] }
            signature = best[:signature]

            # Effective signature (dimension names only) for analyzer/lowering
            best[:effective_signature] = {
              in_shapes: signature.in_shapes.map { |dims| dims.map(&:name) },
              out_shape: signature.out_shape.map(&:name),
              join_policy: signature.join_policy
            }

            best
          end

          private

          # Returns the candidate Hash for +sig+, or nil when it cannot accept the arguments.
          #
          # No separate join-policy check is needed: a signature is only rejected on
          # join-policy grounds when its dimension bindings conflict, and the
          # environment check below already rejects exactly those signatures.
          def build_candidate(sig, normalized_args)
            score = match_score(sig, normalized_args)
            return if score.nil?

            env = build_dimension_environment(sig, normalized_args)
            return if env.nil? # conflicting dimension bindings

            {
              signature: sig,
              score: score,
              result_axes: sig.out_shape.map(&:name), # names, not Dimension objects
              join_policy: sig.join_policy,
              dropped_axes: sig.dropped_axes.map(&:to_sym), # Symbol#to_sym returns self
              env: env
            }
          end

          def sanity_check_args!(arg_shapes)
            well_formed = arg_shapes.is_a?(Array) && arg_shapes.all? do |shape|
              shape.is_a?(Array) && shape.all? { |axis| axis.is_a?(Symbol) || axis.is_a?(Integer) }
            end
            return if well_formed

            raise SignatureMatchError, "arg_shapes must be an array of dimension arrays (symbols or integers), got: #{arg_shapes.inspect}"
          end

          # Total cost of matching all arguments against +sig+, or nil if any
          # argument (or the arity) does not match.
          def match_score(sig, normalized_args)
            return nil unless sig.arity == normalized_args.length

            sig.in_shapes.each_with_index.reduce(0) do |total, (expected_dims, idx)|
              arg_cost = match_argument_cost(got: normalized_args[idx], expected: expected_dims)
              return nil if arg_cost.nil?

              total + arg_cost
            end
          end

          # Convert a shape array (symbols/integers) to an array of Dimension objects
          def normalize_shape(shape)
            shape.map do |dim|
              case dim
              when Symbol, Integer
                Dimension.new(dim)
              else
                raise SignatureMatchError, "Invalid dimension type: #{dim.class}"
              end
            end
          end

          # Cost of matching one argument shape against the expected dimensions, or nil.
          def match_argument_cost(got:, expected:)
            # A scalar matches anything: free against a scalar slot, otherwise it is
            # broadcast (or, when flexible dimensions are expected, costed as flexible).
            if got.empty?
              return 0 if expected.empty?

              return expected.any?(&:flexible?) ? FLEXIBLE_COST : SCALAR_BROADCAST_COST
            end

            # Without flexible dimensions and with equal rank, compare position by position.
            return strict_match(got, expected) if expected.none?(&:flexible?) && got.length == expected.length

            right_align_match(got: got, expected: expected)
          end

          # Pairwise cost for equal-length shapes; nil as soon as one pair fails.
          def strict_match(got, expected)
            got.zip(expected).reduce(0) do |total, (got_dim, exp_dim)|
              dim_cost = match_dimension_cost(got: got_dim, expected: exp_dim)
              return nil if dim_cost.nil?

              total + dim_cost
            end
          end

          # Right-aligned matching for flexible dimensions (NEP 20 ? modifier).
          # Walks the expected dimensions from the right; every expected dimension
          # either consumes one argument dimension or, if flexible, is dropped.
          def right_align_match(got:, expected:)
            gi = got.length - 1
            cost = 0

            expected.reverse_each do |exp|
              dim_cost = gi.negative? ? nil : match_dimension_cost(got: got[gi], expected: exp)

              if dim_cost
                cost += dim_cost
                gi -= 1
              elsif exp.flexible?
                cost += FLEXIBLE_COST # optional dimension we cannot fill: drop it
              else
                return nil # mandatory dimension without a matching argument dimension
              end
            end

            # Leftover argument dimensions mean the argument has higher rank than expected.
            gi.negative? ? cost : nil
          end

          # Cost of matching one dimension against another, or nil if incompatible.
          def match_dimension_cost(got:, expected:)
            return 0 if got == expected # exact match

            # Fixed-size equality
            return got.size == expected.size ? 0 : nil if got.fixed_size? && expected.fixed_size?

            # Same symbolic name (ignoring modifiers); penalized when only one side is fixed-size
            if got.named? && expected.named? && got.name == expected.name
              return (got.fixed_size? || expected.fixed_size?) ? MIXED_FIXED_COST : 0
            end

            # Broadcastable expected dim accepts a size-1 dimension, or a named
            # dimension (which could turn out to be size 1 at runtime).
            if expected.broadcastable? && ((got.fixed_size? && got.size == 1) || got.named?)
              return DIMENSION_BROADCAST_COST
            end

            nil
          end

          def mismatch_message(signatures, arg_shapes)
            sigs = signatures.map(&:inspect).join(", ")
            "no matching signature for shapes #{pp_shapes(arg_shapes)} among [#{sigs}]"
          end

          def pp_shapes(shapes)
            shapes.map { |axes| "(#{axes.join(',')})" }.join(", ")
          end

          # Binds every named expected dimension to the argument dimension at the
          # same index. Returns the bindings Hash (name => Dimension), or nil when
          # one name would bind to two different argument dimensions.
          #
          # NOTE(review): dimensions are paired by index from the left here, while
          # right_align_match pairs them from the right; when a flexible dimension
          # was dropped the two pairings can differ. Confirm this is intended.
          def build_dimension_environment(sig, normalized_args)
            env = {}

            sig.in_shapes.each_with_index do |expected_shape, arg_idx|
              got_shape = normalized_args[arg_idx] || []

              expected_shape.each_with_index do |exp_dim, dim_idx|
                next unless exp_dim.named? && dim_idx < got_shape.length

                got_dim = got_shape[dim_idx]
                dim_name = exp_dim.name

                if env.key?(dim_name)
                  return nil if env[dim_name] != got_dim # inconsistent binding
                else
                  env[dim_name] = got_dim
                end
              end
            end

            env
          end
        end
      end
    end
  end
end
|
@@ -1,105 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Kernels
    module Ruby
      # Null-aware reduction kernels over any object that responds to `each`.
      #
      # Shared keyword options:
      #   skip_nulls - when true (default) nil elements are ignored and not counted.
      #                When false, a nil element makes kumi_sum, kumi_min, kumi_max
      #                and kumi_mean return nil (the null propagates); kumi_any,
      #                kumi_all and kumi_count treat it as an ordinary falsy value.
      #   min_count  - minimum number of counted elements required; with fewer the
      #                result is nil.
      module AggregateCore
        module_function

        # Sum of the elements; 0 for an empty input.
        def kumi_sum(enum, skip_nulls: true, min_count: 0)
          total = 0
          count = 0
          complete = scan_values(enum, skip_nulls) do |x|
            total += x
            count += 1
          end
          return nil unless complete && count >= min_count

          total
        end

        # Smallest element (compared with <); nil for an empty input.
        def kumi_min(enum, skip_nulls: true, min_count: 0)
          best = nil
          count = 0
          complete = scan_values(enum, skip_nulls) do |x|
            best = x if best.nil? || x < best
            count += 1
          end
          return nil unless complete && count >= min_count

          best
        end

        # Largest element (compared with >); nil for an empty input.
        def kumi_max(enum, skip_nulls: true, min_count: 0)
          best = nil
          count = 0
          complete = scan_values(enum, skip_nulls) do |x|
            best = x if best.nil? || x > best
            count += 1
          end
          return nil unless complete && count >= min_count

          best
        end

        # Floating-point mean; nil when nothing was counted, so it never divides by zero.
        def kumi_mean(enum, skip_nulls: true, min_count: 0)
          total = 0.0
          count = 0
          complete = scan_values(enum, skip_nulls) do |x|
            total += x
            count += 1
          end
          return nil unless complete && count >= [min_count, 1].max

          total / count
        end

        # true as soon as a truthy element is seen (this short-circuits before
        # min_count is consulted); otherwise false, or nil if fewer than
        # min_count elements were counted.
        def kumi_any(enum, skip_nulls: true, min_count: 0)
          count = 0
          enum.each do |x|
            next if skip_nulls && x.nil?

            return true if x

            count += 1
          end
          return nil if count < min_count

          false
        end

        # false as soon as a falsy element is seen (this short-circuits before
        # min_count is consulted); otherwise true, or nil if fewer than
        # min_count elements were counted.
        def kumi_all(enum, skip_nulls: true, min_count: 0)
          count = 0
          enum.each do |x|
            next if skip_nulls && x.nil?

            return false unless x

            count += 1
          end
          return nil if count < min_count

          true
        end

        # Number of elements (nil elements are included only when skip_nulls is false).
        def kumi_count(enum, skip_nulls: true, min_count: 0)
          count = 0
          enum.each do |x|
            next if skip_nulls && x.nil?

            count += 1
          end
          return nil if count < min_count

          count
        end

        # Yields every non-nil element of +enum+. A nil element is skipped when
        # +skip_nulls+ is true; otherwise the scan stops and false is returned so
        # the caller can propagate the null. Returns true after a complete scan.
        def scan_values(enum, skip_nulls)
          enum.each do |x|
            if x.nil?
              next if skip_nulls

              return false
            end

            yield x
          end
          true
        end
        private_class_method :scan_values
      end
    end
  end
end
|
@@ -1,21 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require "date"
|
4
|
-
|
5
|
-
module Kumi
  module Kernels
    module Ruby
      # Day-based date arithmetic kernels.
      #
      # Date and DateTime arithmetic is already expressed in days. Time
      # arithmetic is expressed in seconds, so Time operands are converted
      # explicitly; otherwise "add 1 day" would add one second.
      module DatetimeScalar
        SECONDS_PER_DAY = 86_400

        module_function

        # Returns +d+ shifted forward by +n+ days (backward for negative +n+).
        # +d+ may be a Date, DateTime or Time; the result has the same class.
        def dt_add_days(d, n)
          return d + (n * SECONDS_PER_DAY) if d.is_a?(Time)

          d + n
        end

        # Returns the number of whole days from +d2+ to +d1+ (d1 - d2),
        # truncated toward zero.
        def dt_diff_days(d1, d2)
          delta = d1 - d2
          # Time - Time yields seconds; Date/DateTime subtraction yields days.
          delta /= SECONDS_PER_DAY if d1.is_a?(Time) && d2.is_a?(Time)
          delta.to_i
        end
      end
    end
  end
end
|
@@ -1,63 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Kernels
    module Ruby
      # Stateless scalar kernels. Each one applies the corresponding Ruby
      # operator to its operands and returns whatever that operator returns.
      module ScalarCore
        # -- arithmetic ------------------------------------------------------

        def kumi_add(a, b)
          a + b
        end

        def kumi_sub(a, b)
          a - b
        end

        def kumi_mul(a, b)
          a * b
        end

        # The divisor is converted with to_f first, so two Integers divide
        # as floats instead of truncating.
        def kumi_div(a, b)
          a / b.to_f
        end

        # -- comparison ------------------------------------------------------

        def kumi_eq(a, b)
          a == b
        end

        def kumi_ne(a, b)
          a != b
        end

        def kumi_gt(a, b)
          a > b
        end

        def kumi_gte(a, b)
          a >= b
        end

        def kumi_lt(a, b)
          a < b
        end

        def kumi_lte(a, b)
          a <= b
        end

        # -- logic -----------------------------------------------------------

        # Ruby && semantics: returns +a+ when it is falsy, otherwise +b+.
        def kumi_and(a, b)
          a && b
        end

        # Ruby || semantics: returns +a+ when it is truthy, otherwise +b+.
        def kumi_or(a, b)
          a || b
        end

        def kumi_not(a)
          !a
        end

        # Expose every kernel as a module-level function (and keep the
        # instance-level copies private).
        module_function(*%i[
          kumi_add kumi_sub kumi_mul kumi_div
          kumi_eq kumi_ne kumi_gt kumi_gte kumi_lt kumi_lte
          kumi_and kumi_or kumi_not
        ])
      end
    end
  end
end
|
@@ -1,39 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
module Kumi
  module Kernels
    module Ruby
      # Structural vector kernels. Only `size` is computed here; every other
      # entry point is a placeholder that raises NotImplementedError, with a
      # message saying the operation should be implemented in IR/VM.
      module VectorStruct
        module_function

        # Element count of +vec+, or nil when +vec+ is nil.
        def size(vec)
          return nil if vec.nil?

          vec.size
        end

        def join_zip(left, right)
          vm_only!("join operations")
        end

        def join_product(left, right)
          vm_only!("join operations")
        end

        def align_to(vec, target_axes)
          vm_only!("align_to")
        end

        def lift(vec, indices)
          vm_only!("lift")
        end

        def flatten(*args)
          vm_only!("flatten")
        end

        def take(values, indices)
          vm_only!("take")
        end

        # Shared failure path for the placeholder kernels above.
        def vm_only!(feature)
          raise NotImplementedError, "#{feature} should be implemented in IR/VM"
        end
        private_class_method :vm_only!
      end
    end
  end
end
|