torch-rb 0.2.7 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -2
- data/README.md +9 -2
- data/ext/torch/ext.cpp +49 -7
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +30 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +39 -46
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,8 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
86
88
|
when "float"
|
87
89
|
v.is_a?(Numeric)
|
88
90
|
when /int\[.*\]/
|
@@ -126,9 +128,11 @@ module Torch
|
|
126
128
|
end
|
127
129
|
|
128
130
|
func = candidates.first
|
131
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
132
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
133
|
{
|
130
134
|
name: func.cpp_name,
|
131
|
-
args:
|
135
|
+
args: args
|
132
136
|
}
|
133
137
|
end
|
134
138
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,7 +18,6 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
22
|
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
23
|
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -373,7 +373,8 @@ module Torch
|
|
373
373
|
end
|
374
374
|
|
375
375
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
376
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
377
|
+
ret
|
377
378
|
end
|
378
379
|
|
379
380
|
# distance functions
|
@@ -426,6 +427,9 @@ module Torch
|
|
426
427
|
end
|
427
428
|
|
428
429
|
def mse_loss(input, target, reduction: "mean")
|
430
|
+
if target.size != input.size
|
431
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
|
+
end
|
429
433
|
NN.mse_loss(input, target, reduction)
|
430
434
|
end
|
431
435
|
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state
|
46
|
-
param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -47,10 +47,15 @@ module Torch
|
|
47
47
|
end
|
48
48
|
end
|
49
49
|
|
50
|
-
|
51
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
52
52
|
device = Device.new(device) if device.is_a?(String)
|
53
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
54
59
|
end
|
55
60
|
|
56
61
|
def cpu
|
@@ -98,8 +103,9 @@ module Torch
|
|
98
103
|
Torch.empty(0, dtype: dtype)
|
99
104
|
end
|
100
105
|
|
101
|
-
def backward(gradient = nil)
|
102
|
-
|
106
|
+
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
+
retain_graph = create_graph if retain_graph.nil?
|
108
|
+
_backward(gradient, retain_graph, create_graph)
|
103
109
|
end
|
104
110
|
|
105
111
|
# TODO read directly from memory
|
@@ -182,49 +188,15 @@ module Torch
|
|
182
188
|
# based on python_variable_indexing.cpp and
|
183
189
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
184
190
|
def [](*indexes)
|
185
|
-
|
186
|
-
dim = 0
|
187
|
-
indexes.each do |index|
|
188
|
-
if index.is_a?(Numeric)
|
189
|
-
result = result._select_int(dim, index)
|
190
|
-
elsif index.is_a?(Range)
|
191
|
-
finish = index.end
|
192
|
-
finish += 1 unless index.exclude_end?
|
193
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
194
|
-
dim += 1
|
195
|
-
elsif index.is_a?(Tensor)
|
196
|
-
result = result.index([index])
|
197
|
-
elsif index.nil?
|
198
|
-
result = result.unsqueeze(dim)
|
199
|
-
dim += 1
|
200
|
-
elsif index == true
|
201
|
-
result = result.unsqueeze(dim)
|
202
|
-
# TODO handle false
|
203
|
-
else
|
204
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
205
|
-
end
|
206
|
-
end
|
207
|
-
result
|
191
|
+
_index(tensor_indexes(indexes))
|
208
192
|
end
|
209
193
|
|
210
194
|
# based on python_variable_indexing.cpp and
|
211
195
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
212
|
-
def []=(
|
196
|
+
def []=(*indexes, value)
|
213
197
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
214
|
-
|
215
198
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
216
|
-
|
217
|
-
if index.is_a?(Numeric)
|
218
|
-
copy_to(_select_int(0, index), value)
|
219
|
-
elsif index.is_a?(Range)
|
220
|
-
finish = index.end
|
221
|
-
finish += 1 unless index.exclude_end?
|
222
|
-
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
223
|
-
elsif index.is_a?(Tensor)
|
224
|
-
index_put!([index], value)
|
225
|
-
else
|
226
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
227
|
-
end
|
199
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
228
200
|
end
|
229
201
|
|
230
202
|
# native functions that need manually defined
|
@@ -238,13 +210,13 @@ module Torch
|
|
238
210
|
end
|
239
211
|
end
|
240
212
|
|
241
|
-
#
|
213
|
+
# parser can't handle overlap, so need to handle manually
|
242
214
|
def random!(*args)
|
243
215
|
case args.size
|
244
216
|
when 1
|
245
217
|
_random__to(*args)
|
246
218
|
when 2
|
247
|
-
|
219
|
+
_random__from(*args)
|
248
220
|
else
|
249
221
|
_random_(*args)
|
250
222
|
end
|
@@ -257,8 +229,29 @@ module Torch
|
|
257
229
|
|
258
230
|
private
|
259
231
|
|
260
|
-
def
|
261
|
-
|
232
|
+
def tensor_indexes(indexes)
|
233
|
+
indexes.map do |index|
|
234
|
+
case index
|
235
|
+
when Integer
|
236
|
+
TensorIndex.integer(index)
|
237
|
+
when Range
|
238
|
+
finish = index.end
|
239
|
+
if finish == -1 && !index.exclude_end?
|
240
|
+
finish = nil
|
241
|
+
else
|
242
|
+
finish += 1 unless index.exclude_end?
|
243
|
+
end
|
244
|
+
TensorIndex.slice(index.begin, finish)
|
245
|
+
when Tensor
|
246
|
+
TensorIndex.tensor(index)
|
247
|
+
when nil
|
248
|
+
TensorIndex.none
|
249
|
+
when true, false
|
250
|
+
TensorIndex.boolean(index)
|
251
|
+
else
|
252
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
+
end
|
254
|
+
end
|
262
255
|
end
|
263
256
|
end
|
264
257
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,8 +37,8 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
@@ -36,7 +48,7 @@ module Torch
|
|
36
48
|
|
37
49
|
private
|
38
50
|
|
39
|
-
def
|
51
|
+
def default_convert(batch)
|
40
52
|
elem = batch[0]
|
41
53
|
case elem
|
42
54
|
when Tensor
|
@@ -44,11 +56,15 @@ module Torch
|
|
44
56
|
when Integer
|
45
57
|
Torch.tensor(batch)
|
46
58
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
59
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
60
|
else
|
49
|
-
raise
|
61
|
+
raise NotImplementedYet
|
50
62
|
end
|
51
63
|
end
|
64
|
+
|
65
|
+
def auto_collation?
|
66
|
+
!@batch_sampler.nil?
|
67
|
+
end
|
52
68
|
end
|
53
69
|
end
|
54
70
|
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.7
|
4
|
+
version: 0.3.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|