torch-rb 0.2.6 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -2
- data/README.md +9 -2
- data/ext/torch/ext.cpp +49 -7
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +25 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +60 -46
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,8 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
86
88
|
when "float"
|
87
89
|
v.is_a?(Numeric)
|
88
90
|
when /int\[.*\]/
|
@@ -126,9 +128,11 @@ module Torch
|
|
126
128
|
end
|
127
129
|
|
128
130
|
func = candidates.first
|
131
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
132
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
133
|
{
|
130
134
|
name: func.cpp_name,
|
131
|
-
args:
|
135
|
+
args: args
|
132
136
|
}
|
133
137
|
end
|
134
138
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,7 +18,6 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
22
|
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
23
|
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -373,7 +373,8 @@ module Torch
|
|
373
373
|
end
|
374
374
|
|
375
375
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
376
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
377
|
+
ret
|
377
378
|
end
|
378
379
|
|
379
380
|
# distance functions
|
@@ -426,6 +427,9 @@ module Torch
|
|
426
427
|
end
|
427
428
|
|
428
429
|
def mse_loss(input, target, reduction: "mean")
|
430
|
+
if target.size != input.size
|
431
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
|
+
end
|
429
433
|
NN.mse_loss(input, target, reduction)
|
430
434
|
end
|
431
435
|
|
data/lib/torch/nn/module.rb
CHANGED
@@ -286,8 +286,11 @@ module Torch
|
|
286
286
|
str % vars
|
287
287
|
end
|
288
288
|
|
289
|
+
# used for format
|
290
|
+
# remove tensors for performance
|
291
|
+
# so we can skip call to inspect
|
289
292
|
def dict
|
290
|
-
instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
293
|
+
instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
291
294
|
end
|
292
295
|
end
|
293
296
|
end
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state
|
46
|
-
param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,6 +26,14 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
def each
|
30
|
+
return enum_for(:each) unless block_given?
|
31
|
+
|
32
|
+
size(0).times do |i|
|
33
|
+
yield self[i]
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
28
37
|
# TODO make more performant
|
29
38
|
def to_a
|
30
39
|
arr = _flat_data
|
@@ -38,10 +47,15 @@ module Torch
|
|
38
47
|
end
|
39
48
|
end
|
40
49
|
|
41
|
-
|
42
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
43
52
|
device = Device.new(device) if device.is_a?(String)
|
44
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
45
59
|
end
|
46
60
|
|
47
61
|
def cpu
|
@@ -89,8 +103,9 @@ module Torch
|
|
89
103
|
Torch.empty(0, dtype: dtype)
|
90
104
|
end
|
91
105
|
|
92
|
-
def backward(gradient = nil)
|
93
|
-
|
106
|
+
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
+
retain_graph = create_graph if retain_graph.nil?
|
108
|
+
_backward(gradient, retain_graph, create_graph)
|
94
109
|
end
|
95
110
|
|
96
111
|
# TODO read directly from memory
|
@@ -153,6 +168,18 @@ module Torch
|
|
153
168
|
neg
|
154
169
|
end
|
155
170
|
|
171
|
+
def &(other)
|
172
|
+
logical_and(other)
|
173
|
+
end
|
174
|
+
|
175
|
+
def |(other)
|
176
|
+
logical_or(other)
|
177
|
+
end
|
178
|
+
|
179
|
+
def ^(other)
|
180
|
+
logical_xor(other)
|
181
|
+
end
|
182
|
+
|
156
183
|
# TODO better compare?
|
157
184
|
def <=>(other)
|
158
185
|
item <=> other
|
@@ -161,49 +188,15 @@ module Torch
|
|
161
188
|
# based on python_variable_indexing.cpp and
|
162
189
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
163
190
|
def [](*indexes)
|
164
|
-
|
165
|
-
dim = 0
|
166
|
-
indexes.each do |index|
|
167
|
-
if index.is_a?(Numeric)
|
168
|
-
result = result._select_int(dim, index)
|
169
|
-
elsif index.is_a?(Range)
|
170
|
-
finish = index.end
|
171
|
-
finish += 1 unless index.exclude_end?
|
172
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
173
|
-
dim += 1
|
174
|
-
elsif index.is_a?(Tensor)
|
175
|
-
result = result.index([index])
|
176
|
-
elsif index.nil?
|
177
|
-
result = result.unsqueeze(dim)
|
178
|
-
dim += 1
|
179
|
-
elsif index == true
|
180
|
-
result = result.unsqueeze(dim)
|
181
|
-
# TODO handle false
|
182
|
-
else
|
183
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
184
|
-
end
|
185
|
-
end
|
186
|
-
result
|
191
|
+
_index(tensor_indexes(indexes))
|
187
192
|
end
|
188
193
|
|
189
194
|
# based on python_variable_indexing.cpp and
|
190
195
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
191
|
-
def []=(
|
196
|
+
def []=(*indexes, value)
|
192
197
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
193
|
-
|
194
198
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
195
|
-
|
196
|
-
if index.is_a?(Numeric)
|
197
|
-
copy_to(_select_int(0, index), value)
|
198
|
-
elsif index.is_a?(Range)
|
199
|
-
finish = index.end
|
200
|
-
finish += 1 unless index.exclude_end?
|
201
|
-
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
202
|
-
elsif index.is_a?(Tensor)
|
203
|
-
index_put!([index], value)
|
204
|
-
else
|
205
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
206
|
-
end
|
199
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
207
200
|
end
|
208
201
|
|
209
202
|
# native functions that need manually defined
|
@@ -217,13 +210,13 @@ module Torch
|
|
217
210
|
end
|
218
211
|
end
|
219
212
|
|
220
|
-
#
|
213
|
+
# parser can't handle overlap, so need to handle manually
|
221
214
|
def random!(*args)
|
222
215
|
case args.size
|
223
216
|
when 1
|
224
217
|
_random__to(*args)
|
225
218
|
when 2
|
226
|
-
|
219
|
+
_random__from(*args)
|
227
220
|
else
|
228
221
|
_random_(*args)
|
229
222
|
end
|
@@ -236,8 +229,29 @@ module Torch
|
|
236
229
|
|
237
230
|
private
|
238
231
|
|
239
|
-
def
|
240
|
-
|
232
|
+
def tensor_indexes(indexes)
|
233
|
+
indexes.map do |index|
|
234
|
+
case index
|
235
|
+
when Integer
|
236
|
+
TensorIndex.integer(index)
|
237
|
+
when Range
|
238
|
+
finish = index.end
|
239
|
+
if finish == -1 && !index.exclude_end?
|
240
|
+
finish = nil
|
241
|
+
else
|
242
|
+
finish += 1 unless index.exclude_end?
|
243
|
+
end
|
244
|
+
TensorIndex.slice(index.begin, finish)
|
245
|
+
when Tensor
|
246
|
+
TensorIndex.tensor(index)
|
247
|
+
when nil
|
248
|
+
TensorIndex.none
|
249
|
+
when true, false
|
250
|
+
TensorIndex.boolean(index)
|
251
|
+
else
|
252
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
+
end
|
254
|
+
end
|
241
255
|
end
|
242
256
|
end
|
243
257
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,8 +37,8 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
@@ -36,7 +48,7 @@ module Torch
|
|
36
48
|
|
37
49
|
private
|
38
50
|
|
39
|
-
def
|
51
|
+
def default_convert(batch)
|
40
52
|
elem = batch[0]
|
41
53
|
case elem
|
42
54
|
when Tensor
|
@@ -44,11 +56,15 @@ module Torch
|
|
44
56
|
when Integer
|
45
57
|
Torch.tensor(batch)
|
46
58
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
59
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
60
|
else
|
49
|
-
raise
|
61
|
+
raise NotImplementedYet
|
50
62
|
end
|
51
63
|
end
|
64
|
+
|
65
|
+
def auto_collation?
|
66
|
+
!@batch_sampler.nil?
|
67
|
+
end
|
52
68
|
end
|
53
69
|
end
|
54
70
|
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|