torch-rb 0.2.6 → 0.3.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -2
- data/README.md +9 -2
- data/ext/torch/ext.cpp +49 -7
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +25 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +60 -46
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,8 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
86
88
|
when "float"
|
87
89
|
v.is_a?(Numeric)
|
88
90
|
when /int\[.*\]/
|
@@ -126,9 +128,11 @@ module Torch
|
|
126
128
|
end
|
127
129
|
|
128
130
|
func = candidates.first
|
131
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
132
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
133
|
{
|
130
134
|
name: func.cpp_name,
|
131
|
-
args:
|
135
|
+
args: args
|
132
136
|
}
|
133
137
|
end
|
134
138
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,7 +18,6 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
22
|
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
23
|
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -373,7 +373,8 @@ module Torch
|
|
373
373
|
end
|
374
374
|
|
375
375
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
376
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
377
|
+
ret
|
377
378
|
end
|
378
379
|
|
379
380
|
# distance functions
|
@@ -426,6 +427,9 @@ module Torch
|
|
426
427
|
end
|
427
428
|
|
428
429
|
def mse_loss(input, target, reduction: "mean")
|
430
|
+
if target.size != input.size
|
431
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
|
+
end
|
429
433
|
NN.mse_loss(input, target, reduction)
|
430
434
|
end
|
431
435
|
|
data/lib/torch/nn/module.rb
CHANGED
@@ -286,8 +286,11 @@ module Torch
|
|
286
286
|
str % vars
|
287
287
|
end
|
288
288
|
|
289
|
+
# used for format
|
290
|
+
# remove tensors for performance
|
291
|
+
# so we can skip call to inspect
|
289
292
|
def dict
|
290
|
-
instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
293
|
+
instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
291
294
|
end
|
292
295
|
end
|
293
296
|
end
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state
|
46
|
-
param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,6 +26,14 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
def each
|
30
|
+
return enum_for(:each) unless block_given?
|
31
|
+
|
32
|
+
size(0).times do |i|
|
33
|
+
yield self[i]
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
28
37
|
# TODO make more performant
|
29
38
|
def to_a
|
30
39
|
arr = _flat_data
|
@@ -38,10 +47,15 @@ module Torch
|
|
38
47
|
end
|
39
48
|
end
|
40
49
|
|
41
|
-
|
42
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
43
52
|
device = Device.new(device) if device.is_a?(String)
|
44
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
45
59
|
end
|
46
60
|
|
47
61
|
def cpu
|
@@ -89,8 +103,9 @@ module Torch
|
|
89
103
|
Torch.empty(0, dtype: dtype)
|
90
104
|
end
|
91
105
|
|
92
|
-
def backward(gradient = nil)
|
93
|
-
|
106
|
+
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
+
retain_graph = create_graph if retain_graph.nil?
|
108
|
+
_backward(gradient, retain_graph, create_graph)
|
94
109
|
end
|
95
110
|
|
96
111
|
# TODO read directly from memory
|
@@ -153,6 +168,18 @@ module Torch
|
|
153
168
|
neg
|
154
169
|
end
|
155
170
|
|
171
|
+
def &(other)
|
172
|
+
logical_and(other)
|
173
|
+
end
|
174
|
+
|
175
|
+
def |(other)
|
176
|
+
logical_or(other)
|
177
|
+
end
|
178
|
+
|
179
|
+
def ^(other)
|
180
|
+
logical_xor(other)
|
181
|
+
end
|
182
|
+
|
156
183
|
# TODO better compare?
|
157
184
|
def <=>(other)
|
158
185
|
item <=> other
|
@@ -161,49 +188,15 @@ module Torch
|
|
161
188
|
# based on python_variable_indexing.cpp and
|
162
189
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
163
190
|
def [](*indexes)
|
164
|
-
|
165
|
-
dim = 0
|
166
|
-
indexes.each do |index|
|
167
|
-
if index.is_a?(Numeric)
|
168
|
-
result = result._select_int(dim, index)
|
169
|
-
elsif index.is_a?(Range)
|
170
|
-
finish = index.end
|
171
|
-
finish += 1 unless index.exclude_end?
|
172
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
173
|
-
dim += 1
|
174
|
-
elsif index.is_a?(Tensor)
|
175
|
-
result = result.index([index])
|
176
|
-
elsif index.nil?
|
177
|
-
result = result.unsqueeze(dim)
|
178
|
-
dim += 1
|
179
|
-
elsif index == true
|
180
|
-
result = result.unsqueeze(dim)
|
181
|
-
# TODO handle false
|
182
|
-
else
|
183
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
184
|
-
end
|
185
|
-
end
|
186
|
-
result
|
191
|
+
_index(tensor_indexes(indexes))
|
187
192
|
end
|
188
193
|
|
189
194
|
# based on python_variable_indexing.cpp and
|
190
195
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
191
|
-
def []=(
|
196
|
+
def []=(*indexes, value)
|
192
197
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
193
|
-
|
194
198
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
195
|
-
|
196
|
-
if index.is_a?(Numeric)
|
197
|
-
copy_to(_select_int(0, index), value)
|
198
|
-
elsif index.is_a?(Range)
|
199
|
-
finish = index.end
|
200
|
-
finish += 1 unless index.exclude_end?
|
201
|
-
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
202
|
-
elsif index.is_a?(Tensor)
|
203
|
-
index_put!([index], value)
|
204
|
-
else
|
205
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
206
|
-
end
|
199
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
207
200
|
end
|
208
201
|
|
209
202
|
# native functions that need manually defined
|
@@ -217,13 +210,13 @@ module Torch
|
|
217
210
|
end
|
218
211
|
end
|
219
212
|
|
220
|
-
#
|
213
|
+
# parser can't handle overlap, so need to handle manually
|
221
214
|
def random!(*args)
|
222
215
|
case args.size
|
223
216
|
when 1
|
224
217
|
_random__to(*args)
|
225
218
|
when 2
|
226
|
-
|
219
|
+
_random__from(*args)
|
227
220
|
else
|
228
221
|
_random_(*args)
|
229
222
|
end
|
@@ -236,8 +229,29 @@ module Torch
|
|
236
229
|
|
237
230
|
private
|
238
231
|
|
239
|
-
def
|
240
|
-
|
232
|
+
def tensor_indexes(indexes)
|
233
|
+
indexes.map do |index|
|
234
|
+
case index
|
235
|
+
when Integer
|
236
|
+
TensorIndex.integer(index)
|
237
|
+
when Range
|
238
|
+
finish = index.end
|
239
|
+
if finish == -1 && !index.exclude_end?
|
240
|
+
finish = nil
|
241
|
+
else
|
242
|
+
finish += 1 unless index.exclude_end?
|
243
|
+
end
|
244
|
+
TensorIndex.slice(index.begin, finish)
|
245
|
+
when Tensor
|
246
|
+
TensorIndex.tensor(index)
|
247
|
+
when nil
|
248
|
+
TensorIndex.none
|
249
|
+
when true, false
|
250
|
+
TensorIndex.boolean(index)
|
251
|
+
else
|
252
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
+
end
|
254
|
+
end
|
241
255
|
end
|
242
256
|
end
|
243
257
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,8 +37,8 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
@@ -36,7 +48,7 @@ module Torch
|
|
36
48
|
|
37
49
|
private
|
38
50
|
|
39
|
-
def
|
51
|
+
def default_convert(batch)
|
40
52
|
elem = batch[0]
|
41
53
|
case elem
|
42
54
|
when Tensor
|
@@ -44,11 +56,15 @@ module Torch
|
|
44
56
|
when Integer
|
45
57
|
Torch.tensor(batch)
|
46
58
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
59
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
60
|
else
|
49
|
-
raise
|
61
|
+
raise NotImplementedYet
|
50
62
|
end
|
51
63
|
end
|
64
|
+
|
65
|
+
def auto_collation?
|
66
|
+
!@batch_sampler.nil?
|
67
|
+
end
|
52
68
|
end
|
53
69
|
end
|
54
70
|
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|