torch-rb 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +46 -0
- data/README.md +426 -0
- data/ext/torch/ext.cpp +839 -0
- data/ext/torch/extconf.rb +25 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +422 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +85 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +37 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +100 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +85 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +14 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/tensor.rb +196 -0
- data/lib/torch/utils/data/data_loader.rb +27 -0
- data/lib/torch/utils/data/tensor_dataset.rb +22 -0
- data/lib/torch/version.rb +3 -0
- metadata +169 -0
@@ -0,0 +1,25 @@
|
|
1
|
+
require "mkmf-rice"

# The extension cannot be built without the C++ standard library, so stop early.
have_library("stdc++") || abort("Missing stdc++")

# Compiler flags: C++11 mode, plus -Wno-deprecated-register to silence the
# ruby/intern.h warning.
# needed for Linux pre-cxx11 ABI version:
#   $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=0"
[" -std=c++11", " -Wno-deprecated-register"].each { |flag| $CXXFLAGS << flag }

# Include/lib directories come from dir_config("torch"); either one falls back
# to the /usr/local location when dir_config returns nil for it.
torch_inc, torch_lib = dir_config("torch")
torch_inc ||= "/usr/local/include"
torch_lib ||= "/usr/local/lib"

# Header search paths: the install root and the nested torch/csrc/api/include dir.
$INCFLAGS << " -I#{torch_inc}" << " -I#{torch_inc}/torch/csrc/api/include"

# Linker flags: rpath and -L for the lib dir, then the torch and c10 libraries.
$LDFLAGS << " -Wl,-rpath,#{torch_lib}" << " -L#{torch_lib}" << " -ltorch -lc10"

create_makefile("torch/ext")
|
data/lib/torch-rb.rb
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
require "torch"
|
data/lib/torch.rb
ADDED
@@ -0,0 +1,422 @@
|
|
1
|
+
# ext
|
2
|
+
require "torch/ext"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "torch/inspector"
|
6
|
+
require "torch/tensor"
|
7
|
+
require "torch/version"
|
8
|
+
|
9
|
+
# optim
|
10
|
+
require "torch/optim/optimizer"
|
11
|
+
require "torch/optim/adadelta"
|
12
|
+
require "torch/optim/adagrad"
|
13
|
+
require "torch/optim/adam"
|
14
|
+
require "torch/optim/adamax"
|
15
|
+
require "torch/optim/adamw"
|
16
|
+
require "torch/optim/asgd"
|
17
|
+
require "torch/optim/rmsprop"
|
18
|
+
require "torch/optim/rprop"
|
19
|
+
require "torch/optim/sgd"
|
20
|
+
|
21
|
+
# optim lr_scheduler
|
22
|
+
require "torch/optim/lr_scheduler/lr_scheduler"
|
23
|
+
require "torch/optim/lr_scheduler/step_lr"
|
24
|
+
|
25
|
+
# nn base classes
|
26
|
+
require "torch/nn/module"
|
27
|
+
require "torch/nn/convnd"
|
28
|
+
require "torch/nn/dropoutnd"
|
29
|
+
|
30
|
+
# nn
|
31
|
+
require "torch/nn/alpha_dropout"
|
32
|
+
require "torch/nn/conv2d"
|
33
|
+
require "torch/nn/dropout"
|
34
|
+
require "torch/nn/dropout2d"
|
35
|
+
require "torch/nn/dropout3d"
|
36
|
+
require "torch/nn/embedding"
|
37
|
+
require "torch/nn/feature_alpha_dropout"
|
38
|
+
require "torch/nn/functional"
|
39
|
+
require "torch/nn/init"
|
40
|
+
require "torch/nn/linear"
|
41
|
+
require "torch/nn/mse_loss"
|
42
|
+
require "torch/nn/parameter"
|
43
|
+
require "torch/nn/relu"
|
44
|
+
require "torch/nn/sequential"
|
45
|
+
|
46
|
+
# utils
|
47
|
+
require "torch/utils/data/data_loader"
|
48
|
+
require "torch/utils/data/tensor_dataset"
|
49
|
+
|
50
|
+
module Torch
  # Error class raised by this module for invalid arguments and unsupported
  # conversions (unknown dtype, unconvertible Numo class, ...).
  class Error < StandardError; end

  # Raised for code paths that exist in the API but have no implementation yet.
  # #message is overridden, so the text is fixed regardless of what is passed
  # to raise.
  class NotImplementedYet < StandardError
    def message
      "This feature has not been implemented yet. Consider submitting a PR."
    end
  end

  # Maps dtype names (including aliases such as :short / :int16) to the integer
  # codes passed to TensorOptions#dtype.
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
  # complex and quantized types not supported by PyTorch yet
  DTYPE_TO_ENUM = {
    uint8: 0,
    int8: 1,
    short: 2,
    int16: 2,
    int: 3,
    int32: 3,
    long: 4,
    int64: 4,
    half: 5,
    float16: 5,
    float: 6,
    float32: 6,
    double: 7,
    float64: 7,
    # complex_half: 8,
    # complex_float: 9,
    # complex_double: 10,
    bool: 11,
    # qint8: 12,
    # quint8: 13,
    # qint32: 14,
    # bfloat16: 15
  }

  # Reverse lookup (code => name). Where two names share a code the later
  # entry wins, so 2 maps to :int16, 3 to :int32, 4 to :int64, and so on.
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

  class << self
    # Torch.float, Torch.long, etc
    # Each dtype name becomes a module method that returns the dtype symbol,
    # and a Tensor instance method that calls Tensor#type with that symbol.
    DTYPE_TO_ENUM.each_key do |dtype|
      define_method(dtype) do
        dtype
      end

      Tensor.define_method(dtype) do
        type(dtype)
      end
    end

    # https://pytorch.org/docs/stable/torch.html

    # Returns true when obj is a Tensor.
    def tensor?(obj)
      obj.is_a?(Tensor)
    end

    # Builds a CPU tensor from a Numo array.
    # Raises Error when the array's class is not listed in _dtype_to_numo.
    def from_numo(ndarray)
      dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
      raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
      options = tensor_options(device: "cpu", dtype: dtype[0])
      # TODO pass pointer to array instead of creating string
      str = ndarray.to_string
      tensor = _from_blob(str, ndarray.shape, options)
      # from_blob does not own the data, so we need to keep
      # a reference to it for duration of tensor
      # can remove when passing pointer directly
      tensor.instance_variable_set("@_numo_str", str)
      tensor
    end

    # private
    # use method for cases when Numo not available
    # or available after Torch loaded
    # Maps dtype names to the Numo classes from_numo can convert.
    def _dtype_to_numo
      {
        uint8: Numo::UInt8,
        int8: Numo::Int8,
        int16: Numo::Int16,
        int32: Numo::Int32,
        int64: Numo::Int64,
        float32: Numo::SFloat,
        float64: Numo::DFloat
      }
    end

    # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
    # The keyword options of every creation method are the ones accepted by
    # tensor_options: dtype, layout, device, requires_grad.

    # arange(finish) starts at 0; arange(start, finish, step) is the full form.
    def arange(start, finish = nil, step = 1, **options)
      # ruby doesn't support start = 0, finish, step = 1, ...
      if finish.nil?
        finish = start
        start = 0
      end
      _arange(start, finish, step, tensor_options(**options))
    end

    # size may be given as separate integers or as one array (it is flattened).
    def empty(*size, **options)
      _empty(tensor_size(size), tensor_options(**options))
    end

    # m defaults to n.
    def eye(n, m = nil, **options)
      _eye(n, m || n, tensor_options(**options))
    end

    def full(size, fill_value, **options)
      _full(size, fill_value, tensor_options(**options))
    end

    def linspace(start, finish, steps = 100, **options)
      _linspace(start, finish, steps, tensor_options(**options))
    end

    def logspace(start, finish, steps = 100, base = 10.0, **options)
      _logspace(start, finish, steps, base, tensor_options(**options))
    end

    # size may be given as separate integers or as one array (it is flattened).
    def ones(*size, **options)
      _ones(tensor_size(size), tensor_options(**options))
    end

    # size may be given as separate integers or as one array (it is flattened).
    def rand(*size, **options)
      _rand(tensor_size(size), tensor_options(**options))
    end

    # low is an optional leading argument: randint(high, size) uses low = 0.
    def randint(low = 0, high, size, **options)
      _randint(low, high, size, tensor_options(**options))
    end

    # size may be given as separate integers or as one array (it is flattened).
    def randn(*size, **options)
      _randn(tensor_size(size), tensor_options(**options))
    end

    def randperm(n, **options)
      _randperm(n, tensor_options(**options))
    end

    # size may be given as separate integers or as one array (it is flattened).
    def zeros(*size, **options)
      _zeros(tensor_size(size), tensor_options(**options))
    end

    # Builds a tensor from a scalar or from anything responding to to_a.
    # When no dtype is given and every element is an Integer, dtype is :int64.
    # NOTE(review): the shape is inferred from the first element at each
    # nesting level only; ragged input is not validated here - confirm how
    # _tensor treats a size that does not match the flattened data.
    def tensor(data, **options)
      size = []
      if data.respond_to?(:to_a)
        data = data.to_a
        d = data
        while d.is_a?(Array)
          size << d.size
          d = d.first
        end
        data = data.flatten
      else
        data = [data].compact
      end

      if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
        options[:dtype] = :int64
      end

      _tensor(data, size, tensor_options(**options))
    end

    # --- begin like ---
    # Each *_like method creates a tensor with input's size, taking dtype,
    # layout and device from input unless overridden (see like_options).
    # The merged options are splatted as keywords (**): on Ruby 3 a trailing
    # Hash is no longer converted to keyword arguments, so passing it
    # positionally would put it in *size or raise ArgumentError.

    def ones_like(input, **options)
      ones(input.size, **like_options(input, options))
    end

    def empty_like(input, **options)
      empty(input.size, **like_options(input, options))
    end

    def full_like(input, fill_value, **options)
      full(input.size, fill_value, **like_options(input, options))
    end

    def rand_like(input, **options)
      rand(input.size, **like_options(input, options))
    end

    # randint_like(input, high) uses low = 0.
    def randint_like(input, low, high = nil, **options)
      # ruby doesn't support input, low = 0, high, ...
      if high.nil?
        high = low
        low = 0
      end
      randint(low, high, input.size, **like_options(input, options))
    end

    def randn_like(input, **options)
      randn(input.size, **like_options(input, options))
    end

    def zeros_like(input, **options)
      zeros(input.size, **like_options(input, options))
    end

    # --- begin operations ---

    # Torch.add / sub / mul / div / remainder(input, other, out: nil),
    # all routed through execute_op.
    %w(add sub mul div remainder).each do |op|
      define_method(op) do |input, other, **options|
        execute_op(op, input, other, **options)
      end
    end

    def neg(input)
      _neg(input)
    end

    # Calls the block after _set_grad_enabled(false) and restores the previous
    # grad_enabled? value afterwards, even if the block raises.
    def no_grad
      previous_value = grad_enabled?
      begin
        _set_grad_enabled(false)
        yield
      ensure
        _set_grad_enabled(previous_value)
      end
    end

    # TODO support out
    # Without dim, calls _mean(input); with dim, _mean_dim.
    def mean(input, dim = nil, keepdim: false)
      if dim
        _mean_dim(input, dim, keepdim)
      else
        _mean(input)
      end
    end

    # TODO support dtype
    # Without dim, calls _sum(input); with dim, _sum_dim.
    def sum(input, dim = nil, keepdim: false)
      if dim
        _sum_dim(input, dim, keepdim)
      else
        _sum(input)
      end
    end

    # Without dim, calls _argmax(input); with dim, _argmax_dim.
    def argmax(input, dim = nil, keepdim: false)
      if dim
        _argmax_dim(input, dim, keepdim)
      else
        _argmax(input)
      end
    end

    def eq(input, other)
      _eq(input, other)
    end

    def norm(input)
      _norm(input)
    end

    def pow(input, exponent)
      _pow(input, exponent)
    end

    def min(input)
      _min(input)
    end

    # Without dim, calls _max(input). With dim, only the out: form is
    # implemented: out[0] and out[1] are passed to _max_out, and
    # NotImplementedYet is raised when out is missing.
    def max(input, dim = nil, keepdim: false, out: nil)
      if dim
        raise NotImplementedYet unless out
        _max_out(out[0], out[1], input, dim, keepdim)
      else
        _max(input)
      end
    end

    def exp(input)
      _exp(input)
    end

    def log(input)
      _log(input)
    end

    def sign(input)
      _sign(input)
    end

    def gt(input, other)
      _gt(input, other)
    end

    def lt(input, other)
      _lt(input, other)
    end

    def unsqueeze(input, dim)
      _unsqueeze(input, dim)
    end

    def dot(input, tensor)
      _dot(input, tensor)
    end

    def cat(tensors, dim = 0)
      _cat(tensors, dim)
    end

    def matmul(input, other)
      _matmul(input, other)
    end

    def reshape(input, shape)
      _reshape(input, shape)
    end

    def flatten(input, start_dim: 0, end_dim: -1)
      _flatten(input, start_dim, end_dim)
    end

    def sqrt(input)
      _sqrt(input)
    end

    def abs(input)
      _abs(input)
    end

    # Wraps str in a Device.
    def device(str)
      Device.new(str)
    end

    private

    # Dispatches a binary op to _<op>_out (when out is given), _<op>_scalar
    # (when other is Numeric) or _<op>.
    # Raises Error for out combined with a scalar, which is not supported yet.
    def execute_op(op, input, other, out: nil)
      scalar = other.is_a?(Numeric)
      if out
        # TODO make work with scalars
        raise Error, "out not supported with scalar yet" if scalar
        send("_#{op}_out", out, input, other)
      else
        if scalar
          send("_#{op}_scalar", input, other)
        else
          send("_#{op}", input, other)
        end
      end
    end

    # Lets callers write both ones(2, 3) and ones([2, 3]).
    def tensor_size(size)
      size.flatten
    end

    # Builds a TensorOptions from the keywords that are not nil.
    # Raises Error when dtype is not a key of DTYPE_TO_ENUM.
    def tensor_options(dtype: nil, layout: nil, device: nil, requires_grad: nil)
      options = TensorOptions.new
      unless dtype.nil?
        type = DTYPE_TO_ENUM[dtype]
        raise Error, "Unknown dtype: #{dtype.inspect}" unless type
        options = options.dtype(type)
      end
      unless device.nil?
        options = options.device(device.to_s)
      end
      unless layout.nil?
        options = options.layout(layout.to_s)
      end
      unless requires_grad.nil?
        options = options.requires_grad(requires_grad)
      end
      options
    end

    # Returns a copy of options in which dtype, layout and device default to
    # the values reported by input; the caller's hash is left untouched.
    def like_options(input, options)
      options = options.dup
      options[:dtype] ||= input.dtype
      options[:layout] ||= input.layout
      options[:device] ||= input.device
      options
    end
  end
end
|
Binary file
|