torch-rb 0.16.0 → 0.17.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +6 -6
- data/codegen/native_functions.yaml +269 -161
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/nn/functional.rb +11 -1
- data/lib/torch/nn/functional_attention.rb +5 -5
- data/lib/torch/tensor.rb +3 -1
- data/lib/torch/version.rb +1 -1
- metadata +9 -2
data/ext/torch/utils.h
CHANGED
data/lib/torch/nn/functional.rb
CHANGED
@@ -134,7 +134,7 @@ module Torch
|
|
134
134
|
raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
|
135
135
|
|
136
136
|
if mode == "constant"
|
137
|
-
|
137
|
+
Torch.constant_pad_nd(input, pad, value)
|
138
138
|
else
|
139
139
|
raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
|
140
140
|
|
@@ -481,6 +481,16 @@ module Torch
|
|
481
481
|
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
|
482
482
|
end
|
483
483
|
|
484
|
+
def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
|
485
|
+
if out.nil?
|
486
|
+
denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
|
487
|
+
input / denom
|
488
|
+
else
|
489
|
+
denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
|
490
|
+
Torch.div(input, denom, out: out)
|
491
|
+
end
|
492
|
+
end
|
493
|
+
|
484
494
|
# vision
|
485
495
|
|
486
496
|
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
|
@@ -5,10 +5,10 @@ module Torch
|
|
5
5
|
def in_projection_packed(q, k, v, w, b: nil)
|
6
6
|
e = q.size(-1)
|
7
7
|
|
8
|
-
if k.eql?
|
9
|
-
if q.eql?
|
8
|
+
if k.eql?(v)
|
9
|
+
if q.eql?(k)
|
10
10
|
# self-attention
|
11
|
-
|
11
|
+
linear(q, w, b).chunk(3, dim: -1)
|
12
12
|
else
|
13
13
|
# encoder-decoder attention
|
14
14
|
w_q, w_kv = w.split_with_sizes([e, e * 2])
|
@@ -18,7 +18,7 @@ module Torch
|
|
18
18
|
b_q, b_kv = b.split_with_sizes([e, e * 2])
|
19
19
|
end
|
20
20
|
|
21
|
-
|
21
|
+
[linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
|
22
22
|
end
|
23
23
|
else
|
24
24
|
w_q, w_k, w_v = w.chunk(3)
|
@@ -28,7 +28,7 @@ module Torch
|
|
28
28
|
b_q, b_k, b_v = b.chunk(3)
|
29
29
|
end
|
30
30
|
|
31
|
-
|
31
|
+
[linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
|
32
32
|
end
|
33
33
|
end
|
34
34
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -57,7 +57,7 @@ module Torch
|
|
57
57
|
if shape.empty?
|
58
58
|
arr
|
59
59
|
else
|
60
|
-
shape[1..-1].
|
60
|
+
shape[1..-1].reverse_each do |dim|
|
61
61
|
arr = arr.each_slice(dim)
|
62
62
|
end
|
63
63
|
arr.to_a
|
@@ -160,6 +160,7 @@ module Torch
|
|
160
160
|
# based on python_variable_indexing.cpp and
|
161
161
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
162
162
|
def [](*indexes)
|
163
|
+
indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
|
163
164
|
_index(indexes)
|
164
165
|
end
|
165
166
|
|
@@ -167,6 +168,7 @@ module Torch
|
|
167
168
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
168
169
|
def []=(*indexes, value)
|
169
170
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
171
|
+
indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
|
170
172
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
171
173
|
_index_put_custom(indexes, value)
|
172
174
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.17.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-07-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -43,17 +43,24 @@ files:
|
|
43
43
|
- ext/torch/ext.cpp
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
|
+
- ext/torch/fft_functions.h
|
46
47
|
- ext/torch/generator.cpp
|
47
48
|
- ext/torch/ivalue.cpp
|
48
49
|
- ext/torch/linalg.cpp
|
50
|
+
- ext/torch/linalg_functions.h
|
49
51
|
- ext/torch/nn.cpp
|
52
|
+
- ext/torch/nn_functions.h
|
50
53
|
- ext/torch/random.cpp
|
51
54
|
- ext/torch/ruby_arg_parser.cpp
|
52
55
|
- ext/torch/ruby_arg_parser.h
|
56
|
+
- ext/torch/sparse_functions.h
|
53
57
|
- ext/torch/special.cpp
|
58
|
+
- ext/torch/special_functions.h
|
54
59
|
- ext/torch/templates.h
|
55
60
|
- ext/torch/tensor.cpp
|
61
|
+
- ext/torch/tensor_functions.h
|
56
62
|
- ext/torch/torch.cpp
|
63
|
+
- ext/torch/torch_functions.h
|
57
64
|
- ext/torch/utils.h
|
58
65
|
- ext/torch/wrap_outputs.h
|
59
66
|
- lib/torch-rb.rb
|