torch-rb 0.16.0 → 0.17.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +6 -6
- data/codegen/native_functions.yaml +269 -161
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/nn/functional.rb +11 -1
- data/lib/torch/nn/functional_attention.rb +5 -5
- data/lib/torch/nn/module.rb +24 -4
- data/lib/torch/tensor.rb +10 -4
- data/lib/torch/version.rb +1 -1
- metadata +11 -4
data/ext/torch/utils.h
CHANGED
data/lib/torch/nn/functional.rb
CHANGED
@@ -134,7 +134,7 @@ module Torch
           raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
 
           if mode == "constant"
-
+            Torch.constant_pad_nd(input, pad, value)
           else
             raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0
 
@@ -481,6 +481,16 @@ module Torch
           Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
         end
 
+        def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil)
+          if out.nil?
+            denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input)
+            input / denom
+          else
+            denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input)
+            Torch.div(input, denom, out: out)
+          end
+        end
+
         # vision
 
         def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
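The new `normalize` mirrors PyTorch's `F.normalize`: each slice along `dim` is divided by its p-norm, clamped below at `eps` so zero vectors are not divided by zero. A minimal usage sketch (input values are illustrative):

require "torch"

x = Torch.tensor([[3.0, 4.0], [0.0, 0.0]])
Torch::NN::Functional.normalize(x, p: 2.0, dim: 1)
# first row becomes [0.6, 0.8]; the all-zero row stays zero thanks to the eps clamp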
data/lib/torch/nn/functional_attention.rb
CHANGED
@@ -5,10 +5,10 @@ module Torch
         def in_projection_packed(q, k, v, w, b: nil)
           e = q.size(-1)
 
-          if k.eql?
-            if q.eql?
+          if k.eql?(v)
+            if q.eql?(k)
               # self-attention
-
+              linear(q, w, b).chunk(3, dim: -1)
             else
               # encoder-decoder attention
               w_q, w_kv = w.split_with_sizes([e, e * 2])
@@ -18,7 +18,7 @@ module Torch
                 b_q, b_kv = b.split_with_sizes([e, e * 2])
               end
 
-
+              [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
             end
           else
             w_q, w_k, w_v = w.chunk(3)
@@ -28,7 +28,7 @@ module Torch
               b_q, b_k, b_v = b.chunk(3)
             end
 
-
+            [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
           end
         end
 
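`in_projection_packed` computes the q/k/v input projections from a single packed weight matrix, using `eql?` (object identity) to pick the cheapest path: when q, k, and v are the same tensor, one `linear` call is chunked into three. A sketch with illustrative shapes (the variable names are just examples):

require "torch"

embed_dim = 8
x = Torch.randn(2, 4, embed_dim)          # (batch, seq, embed)
w = Torch.randn(3 * embed_dim, embed_dim) # packed q/k/v projection weights

# passing the same object as q, k, and v takes the self-attention fast path
q, k, v = Torch::NN::Functional.in_projection_packed(x, x, x, w)
q.shape # => [2, 4, 8]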
data/lib/torch/nn/module.rb
CHANGED
@@ -10,16 +10,23 @@ module Torch
         @parameters = {}
         @buffers = {}
         @modules = {}
+        @non_persistent_buffers_set = Set.new
       end
 
       def forward
         raise NotImplementedError
       end
 
-      def register_buffer(name, tensor)
+      def register_buffer(name, tensor, persistent: true)
         # TODO add checks
         @buffers[name] = tensor
         instance_variable_set("@#{name}", tensor)
+
+        if persistent
+          @non_persistent_buffers_set.delete(name)
+        else
+          @non_persistent_buffers_set << name
+        end
       end
 
       def register_parameter(name, param)
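`register_buffer` still stores the tensor in `@buffers` and exposes it as an instance variable; the new `persistent:` flag only records whether the name belongs to the non-persistent set (the `delete` call allows re-registering a buffer as persistent). A minimal sketch with a hypothetical `RunningStats` module:

require "torch"

class RunningStats < Torch::NN::Module
  def initialize
    super
    register_buffer("count", Torch.zeros(1))                      # persistent (the default)
    register_buffer("scratch", Torch.zeros(1), persistent: false) # runtime-only
  end
end

stats = RunningStats.new
stats.buffers.size # => 2 (both buffers exist on the module at runtime)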
@@ -190,8 +197,18 @@ module Torch
         named_buffers.values
       end
 
-      def named_buffers
-        @buffers || {}
+      # TODO set recurse: true in 0.18.0
+      def named_buffers(prefix: "", recurse: false)
+        buffers = {}
+        if recurse
+          named_children.each do |name, mod|
+            buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
+          end
+        end
+        (@buffers || {}).each do |k, v|
+          buffers[[prefix, k].join] = v
+        end
+        buffers
       end
 
       def children
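Note the compatibility default of `recurse: false` (with a TODO to flip it in 0.18.0): buffers of child modules must be requested explicitly. A sketch wrapping the hypothetical `RunningStats` module from above, assuming child modules assigned to instance variables are picked up by `named_children`:

class Wrapper < Torch::NN::Module
  def initialize
    super
    @stats = RunningStats.new # discovered as the child module "stats"
  end
end

m = Wrapper.new
m.named_buffers.keys                # => [] (Wrapper has no buffers of its own)
m.named_buffers(recurse: true).keys # => ["stats.count", "stats.scratch"]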
@@ -390,7 +407,10 @@ module Torch
           destination[prefix + k] = v
         end
         named_buffers.each do |k, v|
-          destination[prefix + k] = v
+          # TODO exclude v.nil?
+          if !@non_persistent_buffers_set.include?(k)
+            destination[prefix + k] = v
+          end
         end
       end
 
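With this check, non-persistent buffers become runtime-only state: visible through `named_buffers` but omitted from the saved state. Continuing the hypothetical `RunningStats` example:

stats = RunningStats.new
stats.named_buffers.keys # => ["count", "scratch"]
stats.state_dict.keys    # => ["count"] ("scratch" is excluded as non-persistent)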
data/lib/torch/tensor.rb
CHANGED
@@ -57,7 +57,7 @@ module Torch
       if shape.empty?
         arr
       else
-        shape[1..-1].
+        shape[1..-1].reverse_each do |dim|
           arr = arr.each_slice(dim)
         end
         arr.to_a
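The nested-array reconstruction slices the flat data by each trailing dimension, innermost first, which is why the dimensions are walked with `reverse_each`. The same idea in plain Ruby:

flat = (1..12).to_a
shape = [2, 2, 3]

arr = flat
shape[1..-1].reverse_each do |dim|
  arr = arr.each_slice(dim) # wrap in slices of the current dimension, innermost first
end
arr.to_a # => [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]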
@@ -132,9 +132,13 @@ module Torch
 
     # TODO read directly from memory
     def numo
-      cls = Torch._dtype_to_numo[dtype]
-      raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.from_string(_data_str).reshape(*shape)
+      if dtype == :bool
+        Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
+      else
+        cls = Torch._dtype_to_numo[dtype]
+        raise Error, "Cannot convert #{dtype} to Numo" unless cls
+        cls.from_string(_data_str).reshape(*shape)
+      end
     end
 
     def requires_grad=(requires_grad)
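Numo has no byte-backed boolean type, so the fix reinterprets the tensor's bytes as `Numo::UInt8` and compares against zero, producing a `Numo::Bit` mask. A usage sketch, assuming the numo-narray gem is available:

require "torch"
require "numo/narray"

t = Torch.tensor([[true, false], [false, true]])
t.numo # => a 2x2 Numo::Bit array with bits [[1, 0], [0, 1]]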
@@ -160,6 +164,7 @@ module Torch
     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
+      indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
       _index(indexes)
     end
 
@@ -167,6 +172,7 @@ module Torch
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def []=(*indexes, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+      indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
       value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
       _index_put_custom(indexes, value)
     end
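Both `[]` and `[]=` now convert plain Ruby arrays to index tensors before delegating to the native indexing, enabling NumPy-style fancy indexing without constructing the index tensor by hand. A sketch:

require "torch"

x = Torch.tensor([10, 20, 30, 40])
x[[0, 2]]     # => tensor([10, 30]); the array is converted to an index tensor
x[[1, 3]] = 0 # scalar value broadcast into positions 1 and 3
x.to_a        # => [10, 0, 30, 0]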
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.16.0
+  version: 0.17.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-
+date: 2024-08-19 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -16,14 +16,14 @@ dependencies:
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: 4.1
+        version: '4.1'
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: 4.1
+        version: '4.1'
 description:
 email: andrew@ankane.org
 executables: []
@@ -43,17 +43,24 @@ files:
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
 - ext/torch/fft.cpp
+- ext/torch/fft_functions.h
 - ext/torch/generator.cpp
 - ext/torch/ivalue.cpp
 - ext/torch/linalg.cpp
+- ext/torch/linalg_functions.h
 - ext/torch/nn.cpp
+- ext/torch/nn_functions.h
 - ext/torch/random.cpp
 - ext/torch/ruby_arg_parser.cpp
 - ext/torch/ruby_arg_parser.h
+- ext/torch/sparse_functions.h
 - ext/torch/special.cpp
+- ext/torch/special_functions.h
 - ext/torch/templates.h
 - ext/torch/tensor.cpp
+- ext/torch/tensor_functions.h
 - ext/torch/torch.cpp
+- ext/torch/torch_functions.h
 - ext/torch/utils.h
 - ext/torch/wrap_outputs.h
 - lib/torch-rb.rb