torch-rb 0.2.3 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -2
- data/README.md +22 -7
- data/ext/torch/ext.cpp +44 -22
- data/lib/torch.rb +7 -5
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -2
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +24 -5
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +74 -37
- data/lib/torch/utils/data/data_loader.rb +11 -6
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -5
data/lib/torch/native/parser.rb
CHANGED
data/lib/torch/nn/batch_norm.rb
CHANGED
@@ -70,6 +70,11 @@ module Torch
|
|
70
70
|
momentum: exponential_average_factor, eps: @eps
|
71
71
|
)
|
72
72
|
end
|
73
|
+
|
74
|
+
def extra_inspect
|
75
|
+
s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
|
76
|
+
format(s, **dict)
|
77
|
+
end
|
73
78
|
end
|
74
79
|
end
|
75
80
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,9 +18,15 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
|
-
|
22
|
+
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
23
|
+
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
24
|
+
s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
|
25
|
+
s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
|
26
|
+
s += ", groups: %{groups}" if @groups != 1
|
27
|
+
s += ", bias: false" unless @bias
|
28
|
+
s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
|
29
|
+
format(s, **dict)
|
24
30
|
end
|
25
31
|
end
|
26
32
|
end
|
data/lib/torch/nn/convnd.rb
CHANGED
data/lib/torch/nn/max_poolnd.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -145,7 +145,7 @@ module Torch
|
|
145
145
|
params = {}
|
146
146
|
if recurse
|
147
147
|
named_children.each do |name, mod|
|
148
|
-
params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
|
148
|
+
params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
|
149
149
|
end
|
150
150
|
end
|
151
151
|
instance_variables.each do |name|
|
@@ -186,8 +186,22 @@ module Torch
|
|
186
186
|
named_modules.values
|
187
187
|
end
|
188
188
|
|
189
|
-
|
190
|
-
|
189
|
+
# TODO return enumerator?
|
190
|
+
def named_modules(memo: nil, prefix: "")
|
191
|
+
ret = {}
|
192
|
+
memo ||= Set.new
|
193
|
+
unless memo.include?(self)
|
194
|
+
memo << self
|
195
|
+
ret[prefix] = self
|
196
|
+
named_children.each do |name, mod|
|
197
|
+
next unless mod.is_a?(Module)
|
198
|
+
submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
|
199
|
+
mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
|
200
|
+
ret[m[0]] = m[1]
|
201
|
+
end
|
202
|
+
end
|
203
|
+
end
|
204
|
+
ret
|
191
205
|
end
|
192
206
|
|
193
207
|
def train(mode = true)
|
@@ -230,7 +244,9 @@ module Torch
|
|
230
244
|
str = String.new
|
231
245
|
str << "#{name}(\n"
|
232
246
|
named_children.each do |name, mod|
|
233
|
-
|
247
|
+
mod_str = mod.inspect
|
248
|
+
mod_str = mod_str.lines.join(" ")
|
249
|
+
str << " (#{name}): #{mod_str}\n"
|
234
250
|
end
|
235
251
|
str << ")"
|
236
252
|
end
|
@@ -270,8 +286,11 @@ module Torch
|
|
270
286
|
str % vars
|
271
287
|
end
|
272
288
|
|
289
|
+
# used for format
|
290
|
+
# remove tensors for performance
|
291
|
+
# so we can skip call to inspect
|
273
292
|
def dict
|
274
|
-
instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
293
|
+
instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
275
294
|
end
|
276
295
|
end
|
277
296
|
end
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state
|
46
|
-
param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,14 +26,36 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
def each
|
30
|
+
return enum_for(:each) unless block_given?
|
31
|
+
|
32
|
+
size(0).times do |i|
|
33
|
+
yield self[i]
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
37
|
+
# TODO make more performant
|
28
38
|
def to_a
|
29
|
-
|
39
|
+
arr = _flat_data
|
40
|
+
if shape.empty?
|
41
|
+
arr
|
42
|
+
else
|
43
|
+
shape[1..-1].reverse.each do |dim|
|
44
|
+
arr = arr.each_slice(dim)
|
45
|
+
end
|
46
|
+
arr.to_a
|
47
|
+
end
|
30
48
|
end
|
31
49
|
|
32
|
-
|
33
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
34
52
|
device = Device.new(device) if device.is_a?(String)
|
35
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
36
59
|
end
|
37
60
|
|
38
61
|
def cpu
|
@@ -64,7 +87,15 @@ module Torch
|
|
64
87
|
if numel != 1
|
65
88
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
66
89
|
end
|
67
|
-
|
90
|
+
to_a.first
|
91
|
+
end
|
92
|
+
|
93
|
+
def to_i
|
94
|
+
item.to_i
|
95
|
+
end
|
96
|
+
|
97
|
+
def to_f
|
98
|
+
item.to_f
|
68
99
|
end
|
69
100
|
|
70
101
|
# unsure if this is correct
|
@@ -80,7 +111,7 @@ module Torch
|
|
80
111
|
def numo
|
81
112
|
cls = Torch._dtype_to_numo[dtype]
|
82
113
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
83
|
-
cls.
|
114
|
+
cls.from_string(_data_str).reshape(*shape)
|
84
115
|
end
|
85
116
|
|
86
117
|
def new_ones(*size, **options)
|
@@ -108,15 +139,6 @@ module Torch
|
|
108
139
|
_view(size)
|
109
140
|
end
|
110
141
|
|
111
|
-
# value and other are swapped for some methods
|
112
|
-
def add!(value = 1, other)
|
113
|
-
if other.is_a?(Numeric)
|
114
|
-
_add__scalar(other, value)
|
115
|
-
else
|
116
|
-
_add__tensor(other, value)
|
117
|
-
end
|
118
|
-
end
|
119
|
-
|
120
142
|
def +(other)
|
121
143
|
add(other)
|
122
144
|
end
|
@@ -145,12 +167,25 @@ module Torch
|
|
145
167
|
neg
|
146
168
|
end
|
147
169
|
|
170
|
+
def &(other)
|
171
|
+
logical_and(other)
|
172
|
+
end
|
173
|
+
|
174
|
+
def |(other)
|
175
|
+
logical_or(other)
|
176
|
+
end
|
177
|
+
|
178
|
+
def ^(other)
|
179
|
+
logical_xor(other)
|
180
|
+
end
|
181
|
+
|
148
182
|
# TODO better compare?
|
149
183
|
def <=>(other)
|
150
184
|
item <=> other
|
151
185
|
end
|
152
186
|
|
153
|
-
# based on python_variable_indexing.cpp
|
187
|
+
# based on python_variable_indexing.cpp and
|
188
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
154
189
|
def [](*indexes)
|
155
190
|
result = self
|
156
191
|
dim = 0
|
@@ -162,6 +197,8 @@ module Torch
|
|
162
197
|
finish += 1 unless index.exclude_end?
|
163
198
|
result = result._slice_tensor(dim, index.begin, finish, 1)
|
164
199
|
dim += 1
|
200
|
+
elsif index.is_a?(Tensor)
|
201
|
+
result = result.index([index])
|
165
202
|
elsif index.nil?
|
166
203
|
result = result.unsqueeze(dim)
|
167
204
|
dim += 1
|
@@ -175,24 +212,37 @@ module Torch
|
|
175
212
|
result
|
176
213
|
end
|
177
214
|
|
178
|
-
#
|
179
|
-
#
|
215
|
+
# based on python_variable_indexing.cpp and
|
216
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
180
217
|
def []=(index, value)
|
181
218
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
182
219
|
|
183
|
-
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
220
|
+
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
184
221
|
|
185
222
|
if index.is_a?(Numeric)
|
186
|
-
|
223
|
+
index_put!([Torch.tensor(index)], value)
|
187
224
|
elsif index.is_a?(Range)
|
188
225
|
finish = index.end
|
189
226
|
finish += 1 unless index.exclude_end?
|
190
|
-
|
227
|
+
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
228
|
+
elsif index.is_a?(Tensor)
|
229
|
+
index_put!([index], value)
|
191
230
|
else
|
192
231
|
raise Error, "Unsupported index type: #{index.class.name}"
|
193
232
|
end
|
194
233
|
end
|
195
234
|
|
235
|
+
# native functions that need manually defined
|
236
|
+
|
237
|
+
# value and other are swapped for some methods
|
238
|
+
def add!(value = 1, other)
|
239
|
+
if other.is_a?(Numeric)
|
240
|
+
_add__scalar(other, value)
|
241
|
+
else
|
242
|
+
_add__tensor(other, value)
|
243
|
+
end
|
244
|
+
end
|
245
|
+
|
196
246
|
# native functions overlap, so need to handle manually
|
197
247
|
def random!(*args)
|
198
248
|
case args.size
|
@@ -205,22 +255,9 @@ module Torch
|
|
205
255
|
end
|
206
256
|
end
|
207
257
|
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
dst.copy!(src)
|
212
|
-
end
|
213
|
-
|
214
|
-
def reshape_arr(arr, dims)
|
215
|
-
if dims.empty?
|
216
|
-
arr
|
217
|
-
else
|
218
|
-
arr = arr.flatten
|
219
|
-
dims[1..-1].reverse.each do |dim|
|
220
|
-
arr = arr.each_slice(dim)
|
221
|
-
end
|
222
|
-
arr.to_a
|
223
|
-
end
|
258
|
+
def clamp!(min, max)
|
259
|
+
_clamp_min_(min)
|
260
|
+
_clamp_max_(max)
|
224
261
|
end
|
225
262
|
end
|
226
263
|
end
|
@@ -6,9 +6,10 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
|
+
@shuffle = shuffle
|
12
13
|
end
|
13
14
|
|
14
15
|
def each
|
@@ -16,11 +17,15 @@ module Torch
|
|
16
17
|
# this makes it easy to compare results
|
17
18
|
base_seed = Torch.empty([], dtype: :int64).random!.item
|
18
19
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
20
|
+
indexes =
|
21
|
+
if @shuffle
|
22
|
+
Torch.randperm(@dataset.size).to_a
|
23
|
+
else
|
24
|
+
@dataset.size.times
|
25
|
+
end
|
26
|
+
|
27
|
+
indexes.each_slice(@batch_size) do |idx|
|
28
|
+
batch = idx.map { |i| @dataset[i] }
|
24
29
|
yield collate(batch)
|
25
30
|
end
|
26
31
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-07-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -95,19 +95,19 @@ dependencies:
|
|
95
95
|
- !ruby/object:Gem::Version
|
96
96
|
version: '0'
|
97
97
|
- !ruby/object:Gem::Dependency
|
98
|
-
name:
|
98
|
+
name: torchvision
|
99
99
|
requirement: !ruby/object:Gem::Requirement
|
100
100
|
requirements:
|
101
101
|
- - ">="
|
102
102
|
- !ruby/object:Gem::Version
|
103
|
-
version:
|
103
|
+
version: 0.1.1
|
104
104
|
type: :development
|
105
105
|
prerelease: false
|
106
106
|
version_requirements: !ruby/object:Gem::Requirement
|
107
107
|
requirements:
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
|
-
version:
|
110
|
+
version: 0.1.1
|
111
111
|
description:
|
112
112
|
email: andrew@chartkick.com
|
113
113
|
executables: []
|
@@ -260,6 +260,7 @@ files:
|
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
262
|
- lib/torch/utils/data/data_loader.rb
|
263
|
+
- lib/torch/utils/data/dataset.rb
|
263
264
|
- lib/torch/utils/data/tensor_dataset.rb
|
264
265
|
- lib/torch/version.rb
|
265
266
|
homepage: https://github.com/ankane/torch.rb
|