torch-rb 0.2.4 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +29 -2
- data/README.md +22 -7
- data/ext/torch/ext.cpp +46 -24
- data/ext/torch/extconf.rb +3 -4
- data/lib/torch.rb +7 -5
- data/lib/torch/hub.rb +48 -4
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/module.rb +5 -2
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +69 -39
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
data/lib/torch/native/parser.rb
CHANGED
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -18,7 +18,6 @@ module Torch
|
|
18
18
|
F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
|
19
19
|
end
|
20
20
|
|
21
|
-
# TODO add more parameters
|
22
21
|
def extra_inspect
|
23
22
|
s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
|
24
23
|
s += ", padding: %{padding}" if @padding != [0] * @padding.size
|
data/lib/torch/nn/module.rb
CHANGED
@@ -145,7 +145,7 @@ module Torch
|
|
145
145
|
params = {}
|
146
146
|
if recurse
|
147
147
|
named_children.each do |name, mod|
|
148
|
-
params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
|
148
|
+
params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
|
149
149
|
end
|
150
150
|
end
|
151
151
|
instance_variables.each do |name|
|
@@ -286,8 +286,11 @@ module Torch
|
|
286
286
|
str % vars
|
287
287
|
end
|
288
288
|
|
289
|
+
# used for format
|
290
|
+
# remove tensors for performance
|
291
|
+
# so we can skip call to inspect
|
289
292
|
def dict
|
290
|
-
instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
293
|
+
instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
291
294
|
end
|
292
295
|
end
|
293
296
|
end
|
@@ -32,9 +32,11 @@ module Torch
|
|
32
32
|
end
|
33
33
|
|
34
34
|
def state_dict
|
35
|
+
raise NotImplementedYet
|
36
|
+
|
35
37
|
pack_group = lambda do |group|
|
36
|
-
packed = group.select { |k, _| k != :params }.to_h
|
37
|
-
packed[
|
38
|
+
packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
|
39
|
+
packed["params"] = group[:params].map { |p| p.object_id }
|
38
40
|
packed
|
39
41
|
end
|
40
42
|
|
@@ -42,8 +44,8 @@ module Torch
|
|
42
44
|
packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
|
43
45
|
|
44
46
|
{
|
45
|
-
state
|
46
|
-
param_groups
|
47
|
+
"state" => packed_state,
|
48
|
+
"param_groups" => param_groups
|
47
49
|
}
|
48
50
|
end
|
49
51
|
|
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
module Torch
|
2
2
|
class Tensor
|
3
3
|
include Comparable
|
4
|
+
include Enumerable
|
4
5
|
include Inspector
|
5
6
|
|
6
7
|
alias_method :requires_grad?, :requires_grad
|
@@ -25,14 +26,36 @@ module Torch
|
|
25
26
|
inspect
|
26
27
|
end
|
27
28
|
|
29
|
+
def each
|
30
|
+
return enum_for(:each) unless block_given?
|
31
|
+
|
32
|
+
size(0).times do |i|
|
33
|
+
yield self[i]
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
37
|
+
# TODO make more performant
|
28
38
|
def to_a
|
29
|
-
|
39
|
+
arr = _flat_data
|
40
|
+
if shape.empty?
|
41
|
+
arr
|
42
|
+
else
|
43
|
+
shape[1..-1].reverse.each do |dim|
|
44
|
+
arr = arr.each_slice(dim)
|
45
|
+
end
|
46
|
+
arr.to_a
|
47
|
+
end
|
30
48
|
end
|
31
49
|
|
32
|
-
|
33
|
-
|
50
|
+
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
device ||= self.device
|
34
52
|
device = Device.new(device) if device.is_a?(String)
|
35
|
-
|
53
|
+
|
54
|
+
dtype ||= self.dtype
|
55
|
+
enum = DTYPE_TO_ENUM[dtype]
|
56
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
57
|
+
|
58
|
+
_to(device, enum, non_blocking, copy)
|
36
59
|
end
|
37
60
|
|
38
61
|
def cpu
|
@@ -64,7 +87,7 @@ module Torch
|
|
64
87
|
if numel != 1
|
65
88
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
66
89
|
end
|
67
|
-
|
90
|
+
to_a.first
|
68
91
|
end
|
69
92
|
|
70
93
|
def to_i
|
@@ -80,15 +103,16 @@ module Torch
|
|
80
103
|
Torch.empty(0, dtype: dtype)
|
81
104
|
end
|
82
105
|
|
83
|
-
def backward(gradient = nil)
|
84
|
-
|
106
|
+
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
+
retain_graph = create_graph if retain_graph.nil?
|
108
|
+
_backward(gradient, retain_graph, create_graph)
|
85
109
|
end
|
86
110
|
|
87
111
|
# TODO read directly from memory
|
88
112
|
def numo
|
89
113
|
cls = Torch._dtype_to_numo[dtype]
|
90
114
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
91
|
-
cls.
|
115
|
+
cls.from_string(_data_str).reshape(*shape)
|
92
116
|
end
|
93
117
|
|
94
118
|
def new_ones(*size, **options)
|
@@ -116,15 +140,6 @@ module Torch
|
|
116
140
|
_view(size)
|
117
141
|
end
|
118
142
|
|
119
|
-
# value and other are swapped for some methods
|
120
|
-
def add!(value = 1, other)
|
121
|
-
if other.is_a?(Numeric)
|
122
|
-
_add__scalar(other, value)
|
123
|
-
else
|
124
|
-
_add__tensor(other, value)
|
125
|
-
end
|
126
|
-
end
|
127
|
-
|
128
143
|
def +(other)
|
129
144
|
add(other)
|
130
145
|
end
|
@@ -153,12 +168,25 @@ module Torch
|
|
153
168
|
neg
|
154
169
|
end
|
155
170
|
|
171
|
+
def &(other)
|
172
|
+
logical_and(other)
|
173
|
+
end
|
174
|
+
|
175
|
+
def |(other)
|
176
|
+
logical_or(other)
|
177
|
+
end
|
178
|
+
|
179
|
+
def ^(other)
|
180
|
+
logical_xor(other)
|
181
|
+
end
|
182
|
+
|
156
183
|
# TODO better compare?
|
157
184
|
def <=>(other)
|
158
185
|
item <=> other
|
159
186
|
end
|
160
187
|
|
161
|
-
# based on python_variable_indexing.cpp
|
188
|
+
# based on python_variable_indexing.cpp and
|
189
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
162
190
|
def [](*indexes)
|
163
191
|
result = self
|
164
192
|
dim = 0
|
@@ -170,6 +198,8 @@ module Torch
|
|
170
198
|
finish += 1 unless index.exclude_end?
|
171
199
|
result = result._slice_tensor(dim, index.begin, finish, 1)
|
172
200
|
dim += 1
|
201
|
+
elsif index.is_a?(Tensor)
|
202
|
+
result = result.index([index])
|
173
203
|
elsif index.nil?
|
174
204
|
result = result.unsqueeze(dim)
|
175
205
|
dim += 1
|
@@ -183,24 +213,37 @@ module Torch
|
|
183
213
|
result
|
184
214
|
end
|
185
215
|
|
186
|
-
#
|
187
|
-
#
|
216
|
+
# based on python_variable_indexing.cpp and
|
217
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
188
218
|
def []=(index, value)
|
189
219
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
190
220
|
|
191
|
-
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
221
|
+
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
192
222
|
|
193
223
|
if index.is_a?(Numeric)
|
194
|
-
|
224
|
+
index_put!([Torch.tensor(index)], value)
|
195
225
|
elsif index.is_a?(Range)
|
196
226
|
finish = index.end
|
197
227
|
finish += 1 unless index.exclude_end?
|
198
|
-
|
228
|
+
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
229
|
+
elsif index.is_a?(Tensor)
|
230
|
+
index_put!([index], value)
|
199
231
|
else
|
200
232
|
raise Error, "Unsupported index type: #{index.class.name}"
|
201
233
|
end
|
202
234
|
end
|
203
235
|
|
236
|
+
# native functions that need manually defined
|
237
|
+
|
238
|
+
# value and other are swapped for some methods
|
239
|
+
def add!(value = 1, other)
|
240
|
+
if other.is_a?(Numeric)
|
241
|
+
_add__scalar(other, value)
|
242
|
+
else
|
243
|
+
_add__tensor(other, value)
|
244
|
+
end
|
245
|
+
end
|
246
|
+
|
204
247
|
# native functions overlap, so need to handle manually
|
205
248
|
def random!(*args)
|
206
249
|
case args.size
|
@@ -213,22 +256,9 @@ module Torch
|
|
213
256
|
end
|
214
257
|
end
|
215
258
|
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
dst.copy!(src)
|
220
|
-
end
|
221
|
-
|
222
|
-
def reshape_arr(arr, dims)
|
223
|
-
if dims.empty?
|
224
|
-
arr
|
225
|
-
else
|
226
|
-
arr = arr.flatten
|
227
|
-
dims[1..-1].reverse.each do |dim|
|
228
|
-
arr = arr.each_slice(dim)
|
229
|
-
end
|
230
|
-
arr.to_a
|
231
|
-
end
|
259
|
+
def clamp!(min, max)
|
260
|
+
_clamp_min_(min)
|
261
|
+
_clamp_max_(max)
|
232
262
|
end
|
233
263
|
end
|
234
264
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-08-17 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|