torch-rb 0.16.0 → 0.17.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +6 -6
- data/codegen/native_functions.yaml +269 -161
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/nn/functional.rb +11 -1
- data/lib/torch/nn/functional_attention.rb +5 -5
- data/lib/torch/nn/module.rb +24 -4
- data/lib/torch/tensor.rb +10 -4
- data/lib/torch/version.rb +1 -1
- metadata +11 -4
data/ext/torch/utils.h
CHANGED
data/lib/torch/nn/functional.rb
CHANGED
@@ -134,7 +134,7 @@ module Torch
|
|
134
134
|
raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
|
135
135
|
|
136
136
|
if mode == "constant"
|
137
|
-
|
137
|
+
Torch.constant_pad_nd(input, pad, value)
|
138
138
|
else
|
139
139
|
raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
|
140
140
|
|
@@ -481,6 +481,16 @@ module Torch
|
|
481
481
|
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
|
482
482
|
end
|
483
483
|
|
484
|
+
def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
|
485
|
+
if out.nil?
|
486
|
+
denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
|
487
|
+
input / denom
|
488
|
+
else
|
489
|
+
denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
|
490
|
+
Torch.div(input, denom, out: out)
|
491
|
+
end
|
492
|
+
end
|
493
|
+
|
484
494
|
# vision
|
485
495
|
|
486
496
|
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
|
@@ -5,10 +5,10 @@ module Torch
|
|
5
5
|
def in_projection_packed(q, k, v, w, b: nil)
|
6
6
|
e = q.size(-1)
|
7
7
|
|
8
|
-
if k.eql?
|
9
|
-
if q.eql?
|
8
|
+
if k.eql?(v)
|
9
|
+
if q.eql?(k)
|
10
10
|
# self-attention
|
11
|
-
|
11
|
+
linear(q, w, b).chunk(3, dim: -1)
|
12
12
|
else
|
13
13
|
# encoder-decoder attention
|
14
14
|
w_q, w_kv = w.split_with_sizes([e, e * 2])
|
@@ -18,7 +18,7 @@ module Torch
|
|
18
18
|
b_q, b_kv = b.split_with_sizes([e, e * 2])
|
19
19
|
end
|
20
20
|
|
21
|
-
|
21
|
+
[linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
|
22
22
|
end
|
23
23
|
else
|
24
24
|
w_q, w_k, w_v = w.chunk(3)
|
@@ -28,7 +28,7 @@ module Torch
|
|
28
28
|
b_q, b_k, b_v = b.chunk(3)
|
29
29
|
end
|
30
30
|
|
31
|
-
|
31
|
+
[linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
|
32
32
|
end
|
33
33
|
end
|
34
34
|
|
data/lib/torch/nn/module.rb
CHANGED
@@ -10,16 +10,23 @@ module Torch
|
|
10
10
|
@parameters = {}
|
11
11
|
@buffers = {}
|
12
12
|
@modules = {}
|
13
|
+
@non_persistent_buffers_set = Set.new
|
13
14
|
end
|
14
15
|
|
15
16
|
def forward
|
16
17
|
raise NotImplementedError
|
17
18
|
end
|
18
19
|
|
19
|
-
def register_buffer(name, tensor)
|
20
|
+
def register_buffer(name, tensor, persistent: true)
|
20
21
|
# TODO add checks
|
21
22
|
@buffers[name] = tensor
|
22
23
|
instance_variable_set("@#{name}", tensor)
|
24
|
+
|
25
|
+
if persistent
|
26
|
+
@non_persistent_buffers_set.delete(name)
|
27
|
+
else
|
28
|
+
@non_persistent_buffers_set << name
|
29
|
+
end
|
23
30
|
end
|
24
31
|
|
25
32
|
def register_parameter(name, param)
|
@@ -190,8 +197,18 @@ module Torch
|
|
190
197
|
named_buffers.values
|
191
198
|
end
|
192
199
|
|
193
|
-
|
194
|
-
|
200
|
+
# TODO set recurse: true in 0.18.0
|
201
|
+
def named_buffers(prefix: "", recurse: false)
|
202
|
+
buffers = {}
|
203
|
+
if recurse
|
204
|
+
named_children.each do |name, mod|
|
205
|
+
buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
|
206
|
+
end
|
207
|
+
end
|
208
|
+
(@buffers || {}).each do |k, v|
|
209
|
+
buffers[[prefix, k].join] = v
|
210
|
+
end
|
211
|
+
buffers
|
195
212
|
end
|
196
213
|
|
197
214
|
def children
|
@@ -390,7 +407,10 @@ module Torch
|
|
390
407
|
destination[prefix + k] = v
|
391
408
|
end
|
392
409
|
named_buffers.each do |k, v|
|
393
|
-
|
410
|
+
# TODO exclude v.nil?
|
411
|
+
if !@non_persistent_buffers_set.include?(k)
|
412
|
+
destination[prefix + k] = v
|
413
|
+
end
|
394
414
|
end
|
395
415
|
end
|
396
416
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -57,7 +57,7 @@ module Torch
|
|
57
57
|
if shape.empty?
|
58
58
|
arr
|
59
59
|
else
|
60
|
-
shape[1..-1].
|
60
|
+
shape[1..-1].reverse_each do |dim|
|
61
61
|
arr = arr.each_slice(dim)
|
62
62
|
end
|
63
63
|
arr.to_a
|
@@ -132,9 +132,13 @@ module Torch
|
|
132
132
|
|
133
133
|
# TODO read directly from memory
|
134
134
|
def numo
|
135
|
-
|
136
|
-
|
137
|
-
|
135
|
+
if dtype == :bool
|
136
|
+
Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
|
137
|
+
else
|
138
|
+
cls = Torch._dtype_to_numo[dtype]
|
139
|
+
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
140
|
+
cls.from_string(_data_str).reshape(*shape)
|
141
|
+
end
|
138
142
|
end
|
139
143
|
|
140
144
|
def requires_grad=(requires_grad)
|
@@ -160,6 +164,7 @@ module Torch
|
|
160
164
|
# based on python_variable_indexing.cpp and
|
161
165
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
162
166
|
def [](*indexes)
|
167
|
+
indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
|
163
168
|
_index(indexes)
|
164
169
|
end
|
165
170
|
|
@@ -167,6 +172,7 @@ module Torch
|
|
167
172
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
168
173
|
def []=(*indexes, value)
|
169
174
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
175
|
+
indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
|
170
176
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
171
177
|
_index_put_custom(indexes, value)
|
172
178
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.17.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-08-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: 4.1
|
19
|
+
version: '4.1'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: 4.1
|
26
|
+
version: '4.1'
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -43,17 +43,24 @@ files:
|
|
43
43
|
- ext/torch/ext.cpp
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
|
+
- ext/torch/fft_functions.h
|
46
47
|
- ext/torch/generator.cpp
|
47
48
|
- ext/torch/ivalue.cpp
|
48
49
|
- ext/torch/linalg.cpp
|
50
|
+
- ext/torch/linalg_functions.h
|
49
51
|
- ext/torch/nn.cpp
|
52
|
+
- ext/torch/nn_functions.h
|
50
53
|
- ext/torch/random.cpp
|
51
54
|
- ext/torch/ruby_arg_parser.cpp
|
52
55
|
- ext/torch/ruby_arg_parser.h
|
56
|
+
- ext/torch/sparse_functions.h
|
53
57
|
- ext/torch/special.cpp
|
58
|
+
- ext/torch/special_functions.h
|
54
59
|
- ext/torch/templates.h
|
55
60
|
- ext/torch/tensor.cpp
|
61
|
+
- ext/torch/tensor_functions.h
|
56
62
|
- ext/torch/torch.cpp
|
63
|
+
- ext/torch/torch_functions.h
|
57
64
|
- ext/torch/utils.h
|
58
65
|
- ext/torch/wrap_outputs.h
|
59
66
|
- lib/torch-rb.rb
|