torch-rb 0.2.2 → 0.2.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +19 -7
- data/ext/torch/ext.cpp +64 -19
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch.rb +6 -3
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +26 -7
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +76 -30
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -6
- data/lib/torch/random.rb +0 -10
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,8 +26,25 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
def each
|
30
|
+
return enum_for(:each) unless block_given?
|
31
|
+
|
32
|
+
size(0).times do |i|
|
33
|
+
yield self[i]
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
37
|
+
# TODO make more performant
|
28
38
|
def to_a
|
29
|
-
|
39
|
+
arr = _flat_data
|
40
|
+
if shape.empty?
|
41
|
+
arr
|
42
|
+
else
|
43
|
+
shape[1..-1].reverse.each do |dim|
|
44
|
+
arr = arr.each_slice(dim)
|
45
|
+
end
|
46
|
+
arr.to_a
|
47
|
+
end
|
30
48
|
end
|
31
49
|
|
32
50
|
# TODO support dtype
|
@@ -64,7 +82,15 @@ module Torch
|
|
64
82
|
if numel != 1
|
65
83
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
66
84
|
end
|
67
|
-
|
85
|
+
to_a.first
|
86
|
+
end
|
87
|
+
|
88
|
+
def to_i
|
89
|
+
item.to_i
|
90
|
+
end
|
91
|
+
|
92
|
+
def to_f
|
93
|
+
item.to_f
|
68
94
|
end
|
69
95
|
|
70
96
|
# unsure if this is correct
|
@@ -80,7 +106,7 @@ module Torch
|
|
80
106
|
def numo
|
81
107
|
cls = Torch._dtype_to_numo[dtype]
|
82
108
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
83
|
-
cls.
|
109
|
+
cls.from_string(_data_str).reshape(*shape)
|
84
110
|
end
|
85
111
|
|
86
112
|
def new_ones(*size, **options)
|
@@ -108,15 +134,6 @@ module Torch
|
|
108
134
|
_view(size)
|
109
135
|
end
|
110
136
|
|
111
|
-
# value and other are swapped for some methods
|
112
|
-
def add!(value = 1, other)
|
113
|
-
if other.is_a?(Numeric)
|
114
|
-
_add__scalar(other, value)
|
115
|
-
else
|
116
|
-
_add__tensor(other, value)
|
117
|
-
end
|
118
|
-
end
|
119
|
-
|
120
137
|
def +(other)
|
121
138
|
add(other)
|
122
139
|
end
|
@@ -145,12 +162,25 @@ module Torch
|
|
145
162
|
neg
|
146
163
|
end
|
147
164
|
|
165
|
+
def &(other)
|
166
|
+
logical_and(other)
|
167
|
+
end
|
168
|
+
|
169
|
+
def |(other)
|
170
|
+
logical_or(other)
|
171
|
+
end
|
172
|
+
|
173
|
+
def ^(other)
|
174
|
+
logical_xor(other)
|
175
|
+
end
|
176
|
+
|
148
177
|
# TODO better compare?
|
149
178
|
def <=>(other)
|
150
179
|
item <=> other
|
151
180
|
end
|
152
181
|
|
153
|
-
# based on python_variable_indexing.cpp
|
182
|
+
# based on python_variable_indexing.cpp and
|
183
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
154
184
|
def [](*indexes)
|
155
185
|
result = self
|
156
186
|
dim = 0
|
@@ -162,6 +192,8 @@ module Torch
|
|
162
192
|
finish += 1 unless index.exclude_end?
|
163
193
|
result = result._slice_tensor(dim, index.begin, finish, 1)
|
164
194
|
dim += 1
|
195
|
+
elsif index.is_a?(Tensor)
|
196
|
+
result = result.index([index])
|
165
197
|
elsif index.nil?
|
166
198
|
result = result.unsqueeze(dim)
|
167
199
|
dim += 1
|
@@ -175,12 +207,12 @@ module Torch
|
|
175
207
|
result
|
176
208
|
end
|
177
209
|
|
178
|
-
#
|
179
|
-
#
|
210
|
+
# based on python_variable_indexing.cpp and
|
211
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
180
212
|
def []=(index, value)
|
181
213
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
182
214
|
|
183
|
-
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
215
|
+
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
184
216
|
|
185
217
|
if index.is_a?(Numeric)
|
186
218
|
copy_to(_select_int(0, index), value)
|
@@ -188,13 +220,39 @@ module Torch
|
|
188
220
|
finish = index.end
|
189
221
|
finish += 1 unless index.exclude_end?
|
190
222
|
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
223
|
+
elsif index.is_a?(Tensor)
|
224
|
+
index_put!([index], value)
|
191
225
|
else
|
192
226
|
raise Error, "Unsupported index type: #{index.class.name}"
|
193
227
|
end
|
194
228
|
end
|
195
229
|
|
196
|
-
|
197
|
-
|
230
|
+
# native functions that need manually defined
|
231
|
+
|
232
|
+
# value and other are swapped for some methods
|
233
|
+
def add!(value = 1, other)
|
234
|
+
if other.is_a?(Numeric)
|
235
|
+
_add__scalar(other, value)
|
236
|
+
else
|
237
|
+
_add__tensor(other, value)
|
238
|
+
end
|
239
|
+
end
|
240
|
+
|
241
|
+
# native functions overlap, so need to handle manually
|
242
|
+
def random!(*args)
|
243
|
+
case args.size
|
244
|
+
when 1
|
245
|
+
_random__to(*args)
|
246
|
+
when 2
|
247
|
+
_random__from_to(*args)
|
248
|
+
else
|
249
|
+
_random_(*args)
|
250
|
+
end
|
251
|
+
end
|
252
|
+
|
253
|
+
def clamp!(min, max)
|
254
|
+
_clamp_min_(min)
|
255
|
+
_clamp_max_(max)
|
198
256
|
end
|
199
257
|
|
200
258
|
private
|
@@ -202,17 +260,5 @@ module Torch
|
|
202
260
|
def copy_to(dst, src)
|
203
261
|
dst.copy!(src)
|
204
262
|
end
|
205
|
-
|
206
|
-
def reshape_arr(arr, dims)
|
207
|
-
if dims.empty?
|
208
|
-
arr
|
209
|
-
else
|
210
|
-
arr = arr.flatten
|
211
|
-
dims[1..-1].reverse.each do |dim|
|
212
|
-
arr = arr.each_slice(dim)
|
213
|
-
end
|
214
|
-
arr.to_a
|
215
|
-
end
|
216
|
-
end
|
217
263
|
end
|
218
264
|
end
|
@@ -6,21 +6,49 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
|
+
@shuffle = shuffle
|
12
13
|
end
|
13
14
|
|
14
15
|
def each
|
15
|
-
|
16
|
-
|
17
|
-
|
16
|
+
# try to keep the random number generator in sync with Python
|
17
|
+
# this makes it easy to compare results
|
18
|
+
base_seed = Torch.empty([], dtype: :int64).random!.item
|
19
|
+
|
20
|
+
indexes =
|
21
|
+
if @shuffle
|
22
|
+
Torch.randperm(@dataset.size).to_a
|
23
|
+
else
|
24
|
+
@dataset.size.times
|
25
|
+
end
|
26
|
+
|
27
|
+
indexes.each_slice(@batch_size) do |idx|
|
28
|
+
batch = idx.map { |i| @dataset[i] }
|
29
|
+
yield collate(batch)
|
18
30
|
end
|
19
31
|
end
|
20
32
|
|
21
33
|
def size
|
22
34
|
(@dataset.size / @batch_size.to_f).ceil
|
23
35
|
end
|
36
|
+
|
37
|
+
private
|
38
|
+
|
39
|
+
def collate(batch)
|
40
|
+
elem = batch[0]
|
41
|
+
case elem
|
42
|
+
when Tensor
|
43
|
+
Torch.stack(batch, 0)
|
44
|
+
when Integer
|
45
|
+
Torch.tensor(batch)
|
46
|
+
when Array
|
47
|
+
batch.transpose.map { |v| collate(v) }
|
48
|
+
else
|
49
|
+
raise NotImpelmentYet
|
50
|
+
end
|
51
|
+
end
|
24
52
|
end
|
25
53
|
end
|
26
54
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-06-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -95,19 +95,19 @@ dependencies:
|
|
95
95
|
- !ruby/object:Gem::Version
|
96
96
|
version: '0'
|
97
97
|
- !ruby/object:Gem::Dependency
|
98
|
-
name:
|
98
|
+
name: torchvision
|
99
99
|
requirement: !ruby/object:Gem::Requirement
|
100
100
|
requirements:
|
101
101
|
- - ">="
|
102
102
|
- !ruby/object:Gem::Version
|
103
|
-
version:
|
103
|
+
version: 0.1.1
|
104
104
|
type: :development
|
105
105
|
prerelease: false
|
106
106
|
version_requirements: !ruby/object:Gem::Requirement
|
107
107
|
requirements:
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
|
-
version:
|
110
|
+
version: 0.1.1
|
111
111
|
description:
|
112
112
|
email: andrew@chartkick.com
|
113
113
|
executables: []
|
@@ -258,9 +258,9 @@ files:
|
|
258
258
|
- lib/torch/optim/rmsprop.rb
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
|
-
- lib/torch/random.rb
|
262
261
|
- lib/torch/tensor.rb
|
263
262
|
- lib/torch/utils/data/data_loader.rb
|
263
|
+
- lib/torch/utils/data/dataset.rb
|
264
264
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
265
|
- lib/torch/version.rb
|
266
266
|
homepage: https://github.com/ankane/torch.rb
|