torch-rb 0.2.2 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +19 -7
- data/ext/torch/ext.cpp +64 -19
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch.rb +6 -3
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +26 -7
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +76 -30
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -6
- data/lib/torch/random.rb +0 -10
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,8 +26,25 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
# Iterates over the first dimension, yielding one sub-tensor (self[i]) per row.
# Without a block, returns an Enumerator so Enumerable methods can chain.
# With a block, returns the row count (the result of Integer#times).
def each
  return to_enum(:each) unless block_given?

  rows = size(0)
  rows.times { |row| yield self[row] }
end
|
36
|
+
|
37
|
+
# Converts the tensor to a (nested) Ruby Array that follows its shape.
# For a zero-dim tensor the flat data is returned unchanged (presumably a
# single element; `item` takes its first entry).
# TODO make more performant
def to_a
  arr = _flat_data
  dims = shape
  return arr if dims.empty?

  # each_slice(0) raises ArgumentError, so a tensor with an empty dimension
  # (e.g. shape [2, 0]) is built directly as nested empty arrays instead.
  if dims.include?(0)
    nest = lambda do |ds|
      ds.first.zero? ? [] : Array.new(ds.first) { nest.call(ds.drop(1)) }
    end
    return nest.call(dims)
  end

  # Group the flat data from the innermost dimension outward.
  dims.drop(1).reverse_each { |dim| arr = arr.each_slice(dim).to_a }
  arr
end
|
31
49
|
|
32
50
|
# TODO support dtype
|
@@ -64,7 +82,15 @@ module Torch
|
|
64
82
|
if numel != 1
|
65
83
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
66
84
|
end
|
67
|
-
|
85
|
+
to_a.first
|
86
|
+
end
|
87
|
+
|
88
|
+
# Converts a single-element tensor to an Integer via its scalar value
# (`item` raises unless the tensor has exactly one element).
def to_i
  scalar = item
  scalar.to_i
end
|
91
|
+
|
92
|
+
# Converts a single-element tensor to a Float via its scalar value
# (`item` raises unless the tensor has exactly one element).
def to_f
  scalar = item
  scalar.to_f
end
|
69
95
|
|
70
96
|
# unsure if this is correct
|
@@ -80,7 +106,7 @@ module Torch
|
|
80
106
|
# Copies the tensor's raw data into the Numo array class registered for its
# dtype and reshapes it to the tensor's shape.
# Raises Error when the dtype has no Numo counterpart.
def numo
  narray_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless narray_class

  flat = narray_class.from_string(_data_str)
  flat.reshape(*shape)
end
|
85
111
|
|
86
112
|
def new_ones(*size, **options)
|
@@ -108,15 +134,6 @@ module Torch
|
|
108
134
|
_view(size)
|
109
135
|
end
|
110
136
|
|
111
|
-
# value and other are swapped for some methods
|
112
|
-
def add!(value = 1, other)
|
113
|
-
if other.is_a?(Numeric)
|
114
|
-
_add__scalar(other, value)
|
115
|
-
else
|
116
|
-
_add__tensor(other, value)
|
117
|
-
end
|
118
|
-
end
|
119
|
-
|
120
137
|
# Operator form of `add`: returns the result of add(rhs).
def +(rhs)
  add(rhs)
end
|
@@ -145,12 +162,25 @@ module Torch
|
|
145
162
|
neg
|
146
163
|
end
|
147
164
|
|
165
|
+
# Operator aliases for logical_and / logical_or / logical_xor.
# NOTE(review): PyTorch's Tensor.__and__/__or__/__xor__ map to bitwise_* ops,
# which differ from logical_* on integer tensors — confirm this is intended.
def &(rhs)
  logical_and(rhs)
end

def |(rhs)
  logical_or(rhs)
end

def ^(rhs)
  logical_xor(rhs)
end
|
176
|
+
|
148
177
|
# Compares by scalar value, so only single-element tensors are comparable
# (`item` raises otherwise). Used by the Comparable mixin.
# TODO better compare?
def <=>(other)
  scalar = item
  scalar <=> other
end
|
152
181
|
|
153
|
-
# based on python_variable_indexing.cpp
|
182
|
+
# based on python_variable_indexing.cpp and
|
183
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
154
184
|
def [](*indexes)
|
155
185
|
result = self
|
156
186
|
dim = 0
|
@@ -162,6 +192,8 @@ module Torch
|
|
162
192
|
finish += 1 unless index.exclude_end?
|
163
193
|
result = result._slice_tensor(dim, index.begin, finish, 1)
|
164
194
|
dim += 1
|
195
|
+
elsif index.is_a?(Tensor)
|
196
|
+
result = result.index([index])
|
165
197
|
elsif index.nil?
|
166
198
|
result = result.unsqueeze(dim)
|
167
199
|
dim += 1
|
@@ -175,12 +207,12 @@ module Torch
|
|
175
207
|
result
|
176
208
|
end
|
177
209
|
|
178
|
-
#
|
179
|
-
#
|
210
|
+
# based on python_variable_indexing.cpp and
|
211
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
180
212
|
def []=(index, value)
|
181
213
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
182
214
|
|
183
|
-
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
215
|
+
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
184
216
|
|
185
217
|
if index.is_a?(Numeric)
|
186
218
|
copy_to(_select_int(0, index), value)
|
@@ -188,13 +220,39 @@ module Torch
|
|
188
220
|
finish = index.end
|
189
221
|
finish += 1 unless index.exclude_end?
|
190
222
|
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
223
|
+
elsif index.is_a?(Tensor)
|
224
|
+
index_put!([index], value)
|
191
225
|
else
|
192
226
|
raise Error, "Unsupported index type: #{index.class.name}"
|
193
227
|
end
|
194
228
|
end
|
195
229
|
|
196
|
-
|
197
|
-
|
230
|
+
# native functions that need to be defined manually
|
231
|
+
|
232
|
+
# In-place add. The argument order is swapped relative to the native calls:
# the optional `value` (default 1) comes first and the addend `other` last.
# `value` is forwarded as the second native argument (presumably the alpha
# multiplier). Numeric addends use the scalar overload; anything else uses
# the tensor overload.
def add!(value = 1, other)
  return _add__tensor(other, value) unless other.is_a?(Numeric)

  _add__scalar(other, value)
end
|
240
|
+
|
241
|
+
# The native random_ overloads share one name, so dispatch on argument count:
# one argument -> _random__to, two -> _random__from_to, otherwise -> _random_.
def random!(*args)
  native =
    if args.length == 1
      :_random__to
    elsif args.length == 2
      :_random__from_to
    else
      :_random_
    end
  __send__(native, *args)
end
|
252
|
+
|
253
|
+
# Clamps values in place to [min, max]. Either bound may be nil to leave that
# side unbounded (as in PyTorch's clamp_), but at least one must be given.
# Returns the result of the last native clamp applied.
#
# Raises ArgumentError when both bounds are nil.
def clamp!(min = nil, max = nil)
  if min.nil? && max.nil?
    raise ArgumentError, "At least one of min or max must not be nil"
  end

  result = self
  result = _clamp_min_(min) unless min.nil?
  result = _clamp_max_(max) unless max.nil?
  result
end
|
199
257
|
|
200
258
|
private
|
@@ -202,17 +260,5 @@ module Torch
|
|
202
260
|
# Writes the contents of `source` into `target` in place via `copy!`.
def copy_to(target, source)
  target.copy!(source)
end
|
205
|
-
|
206
|
-
def reshape_arr(arr, dims)
|
207
|
-
if dims.empty?
|
208
|
-
arr
|
209
|
-
else
|
210
|
-
arr = arr.flatten
|
211
|
-
dims[1..-1].reverse.each do |dim|
|
212
|
-
arr = arr.each_slice(dim)
|
213
|
-
end
|
214
|
-
arr.to_a
|
215
|
-
end
|
216
|
-
end
|
217
263
|
end
|
218
264
|
end
|
@@ -6,21 +6,49 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
# Wraps a dataset (anything responding to #size and #[]) for batched iteration.
#
# batch_size - number of samples per batch; must be >= 1
# shuffle    - when true, samples are visited in a random permutation
#
# Raises ArgumentError for a non-numeric or non-positive batch_size.
def initialize(dataset, batch_size: 1, shuffle: false)
  # each_slice(0) and a zero divisor in #size would otherwise fail later with
  # confusing errors, so reject bad batch sizes up front.
  unless batch_size.is_a?(Numeric) && batch_size >= 1
    raise ArgumentError, "batch_size should be a positive integer value"
  end

  @dataset = dataset
  @batch_size = batch_size
  @shuffle = shuffle
end
|
13
14
|
|
14
15
|
# Yields one collated batch at a time, in dataset order, or in a random
# permutation when the loader was built with shuffle: true.
# Without a block, returns an Enumerator instead of failing on `yield`.
def each
  return enum_for(:each) unless block_given?

  # Try to keep the random number generator in sync with Python, which draws
  # a base seed this same way; this makes it easy to compare results. The
  # value itself is not needed, only the RNG draw.
  Torch.empty([], dtype: :int64).random!

  indexes =
    if @shuffle
      Torch.randperm(@dataset.size).to_a
    else
      @dataset.size.times
    end

  indexes.each_slice(@batch_size) do |idx|
    yield collate(idx.map { |i| @dataset[i] })
  end
end
|
20
32
|
|
21
33
|
# Number of batches per pass over the dataset; the last batch may hold fewer
# than batch_size samples, so the quotient is rounded up.
def size
  batches = @dataset.size / @batch_size.to_f
  batches.ceil
end
|
36
|
+
|
37
|
+
private
|
38
|
+
|
39
|
+
# Merges a list of samples into one batch, dispatching on the first sample:
# - Tensor  -> stacked along a new leading dimension
# - Integer -> a tensor of the integers
# - Float   -> a float64 tensor (keeps Ruby Float precision)
# - Array   -> each position (e.g. [input, target]) collated separately
# Raises NotImplementedYet for any other element type.
def collate(batch)
  elem = batch[0]
  case elem
  when Tensor
    Torch.stack(batch, 0)
  when Integer
    Torch.tensor(batch)
  when Float
    Torch.tensor(batch, dtype: :float64)
  when Array
    batch.transpose.map { |v| collate(v) }
  else
    # was `NotImpelmentYet`, a misspelled constant that raised NameError
    raise NotImplementedYet
  end
end
|
24
52
|
end
|
25
53
|
end
|
26
54
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-06-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -95,19 +95,19 @@ dependencies:
|
|
95
95
|
- !ruby/object:Gem::Version
|
96
96
|
version: '0'
|
97
97
|
- !ruby/object:Gem::Dependency
|
98
|
-
name:
|
98
|
+
name: torchvision
|
99
99
|
requirement: !ruby/object:Gem::Requirement
|
100
100
|
requirements:
|
101
101
|
- - ">="
|
102
102
|
- !ruby/object:Gem::Version
|
103
|
-
version:
|
103
|
+
version: 0.1.1
|
104
104
|
type: :development
|
105
105
|
prerelease: false
|
106
106
|
version_requirements: !ruby/object:Gem::Requirement
|
107
107
|
requirements:
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
|
-
version:
|
110
|
+
version: 0.1.1
|
111
111
|
description:
|
112
112
|
email: andrew@chartkick.com
|
113
113
|
executables: []
|
@@ -258,9 +258,9 @@ files:
|
|
258
258
|
- lib/torch/optim/rmsprop.rb
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
|
-
- lib/torch/random.rb
|
262
261
|
- lib/torch/tensor.rb
|
263
262
|
- lib/torch/utils/data/data_loader.rb
|
263
|
+
- lib/torch/utils/data/dataset.rb
|
264
264
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
265
|
- lib/torch/version.rb
|
266
266
|
homepage: https://github.com/ankane/torch.rb
|