torch-rb 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (44)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +28 -0
  3. data/LICENSE.txt +46 -0
  4. data/README.md +426 -0
  5. data/ext/torch/ext.cpp +839 -0
  6. data/ext/torch/extconf.rb +25 -0
  7. data/lib/torch-rb.rb +1 -0
  8. data/lib/torch.rb +422 -0
  9. data/lib/torch/ext.bundle +0 -0
  10. data/lib/torch/inspector.rb +85 -0
  11. data/lib/torch/nn/alpha_dropout.rb +9 -0
  12. data/lib/torch/nn/conv2d.rb +37 -0
  13. data/lib/torch/nn/convnd.rb +41 -0
  14. data/lib/torch/nn/dropout.rb +9 -0
  15. data/lib/torch/nn/dropout2d.rb +9 -0
  16. data/lib/torch/nn/dropout3d.rb +9 -0
  17. data/lib/torch/nn/dropoutnd.rb +15 -0
  18. data/lib/torch/nn/embedding.rb +52 -0
  19. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  20. data/lib/torch/nn/functional.rb +100 -0
  21. data/lib/torch/nn/init.rb +30 -0
  22. data/lib/torch/nn/linear.rb +36 -0
  23. data/lib/torch/nn/module.rb +85 -0
  24. data/lib/torch/nn/mse_loss.rb +13 -0
  25. data/lib/torch/nn/parameter.rb +14 -0
  26. data/lib/torch/nn/relu.rb +13 -0
  27. data/lib/torch/nn/sequential.rb +29 -0
  28. data/lib/torch/optim/adadelta.rb +57 -0
  29. data/lib/torch/optim/adagrad.rb +71 -0
  30. data/lib/torch/optim/adam.rb +81 -0
  31. data/lib/torch/optim/adamax.rb +68 -0
  32. data/lib/torch/optim/adamw.rb +82 -0
  33. data/lib/torch/optim/asgd.rb +65 -0
  34. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  35. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  36. data/lib/torch/optim/optimizer.rb +62 -0
  37. data/lib/torch/optim/rmsprop.rb +76 -0
  38. data/lib/torch/optim/rprop.rb +68 -0
  39. data/lib/torch/optim/sgd.rb +60 -0
  40. data/lib/torch/tensor.rb +196 -0
  41. data/lib/torch/utils/data/data_loader.rb +27 -0
  42. data/lib/torch/utils/data/tensor_dataset.rb +22 -0
  43. data/lib/torch/version.rb +3 -0
  44. metadata +169 -0
@@ -0,0 +1,25 @@
1
require "mkmf-rice"

# The extension is C++ and must link libstdc++; stop early with a clear message otherwise.
abort "Missing stdc++" unless have_library("stdc++")

# C++ compiler flags:
#   -std=c++11                language standard the extension is compiled with
#   -Wno-deprecated-register  silence ruby/intern.h warning
# For a Linux libtorch built with the pre-cxx11 ABI, also append:
#   " -D_GLIBCXX_USE_CXX11_ABI=0"
$CXXFLAGS << " -std=c++11 -Wno-deprecated-register"

# Locate libtorch. dir_config("torch") reads the --with-torch-* configure
# options; anything not supplied falls back to /usr/local.
torch_include, torch_lib = dir_config("torch")
torch_include ||= "/usr/local/include"
torch_lib ||= "/usr/local/lib"

# Header search paths: the libtorch root and its C++ frontend API headers.
$INCFLAGS << " -I#{torch_include} -I#{torch_include}/torch/csrc/api/include"

# Linker flags: record an rpath to the libtorch directory, add it to the
# library search path, and link libtorch and libc10.
$LDFLAGS << " -Wl,-rpath,#{torch_lib} -L#{torch_lib} -ltorch -lc10"

create_makefile("torch/ext")
@@ -0,0 +1 @@
1
+ require "torch"
@@ -0,0 +1,422 @@
1
+ # ext
2
+ require "torch/ext"
3
+
4
+ # modules
5
+ require "torch/inspector"
6
+ require "torch/tensor"
7
+ require "torch/version"
8
+
9
+ # optim
10
+ require "torch/optim/optimizer"
11
+ require "torch/optim/adadelta"
12
+ require "torch/optim/adagrad"
13
+ require "torch/optim/adam"
14
+ require "torch/optim/adamax"
15
+ require "torch/optim/adamw"
16
+ require "torch/optim/asgd"
17
+ require "torch/optim/rmsprop"
18
+ require "torch/optim/rprop"
19
+ require "torch/optim/sgd"
20
+
21
+ # optim lr_scheduler
22
+ require "torch/optim/lr_scheduler/lr_scheduler"
23
+ require "torch/optim/lr_scheduler/step_lr"
24
+
25
+ # nn base classes
26
+ require "torch/nn/module"
27
+ require "torch/nn/convnd"
28
+ require "torch/nn/dropoutnd"
29
+
30
+ # nn
31
+ require "torch/nn/alpha_dropout"
32
+ require "torch/nn/conv2d"
33
+ require "torch/nn/dropout"
34
+ require "torch/nn/dropout2d"
35
+ require "torch/nn/dropout3d"
36
+ require "torch/nn/embedding"
37
+ require "torch/nn/feature_alpha_dropout"
38
+ require "torch/nn/functional"
39
+ require "torch/nn/init"
40
+ require "torch/nn/linear"
41
+ require "torch/nn/mse_loss"
42
+ require "torch/nn/parameter"
43
+ require "torch/nn/relu"
44
+ require "torch/nn/sequential"
45
+
46
+ # utils
47
+ require "torch/utils/data/data_loader"
48
+ require "torch/utils/data/tensor_dataset"
49
+
50
module Torch
  # Generic error raised by the Ruby layer (unknown dtype, unsupported conversion, ...).
  class Error < StandardError; end

  # Raised when a PyTorch feature is not implemented by this gem yet.
  class NotImplementedYet < StandardError
    def message
      "This feature has not been implemented yet. Consider submitting a PR."
    end
  end

  # Maps dtype names (including aliases such as :float / :float32) to the
  # integer scalar-type values passed to TensorOptions#dtype.
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
  # complex and quantized types not supported by PyTorch yet
  DTYPE_TO_ENUM = {
    uint8: 0,
    int8: 1,
    short: 2,
    int16: 2,
    int: 3,
    int32: 3,
    long: 4,
    int64: 4,
    half: 5,
    float16: 5,
    float: 6,
    float32: 6,
    double: 7,
    float64: 7,
    # complex_half: 8,
    # complex_float: 9,
    # complex_double: 10,
    bool: 11,
    # qint8: 12,
    # quint8: 13,
    # qint32: 14,
    # bfloat16: 15
  }.freeze

  # Reverse lookup of DTYPE_TO_ENUM. When several names share a value, the
  # one listed last above wins (e.g. 6 => :float32, 4 => :int64).
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h.freeze

  class << self
    # Torch.float, Torch.long, etc. return the dtype symbol itself;
    # Tensor#float, Tensor#long, etc. convert a tensor via Tensor#type.
    DTYPE_TO_ENUM.each_key do |dtype|
      define_method(dtype) do
        dtype
      end

      Tensor.define_method(dtype) do
        type(dtype)
      end
    end

    # https://pytorch.org/docs/stable/torch.html

    # @return [Boolean] true if obj is a Torch::Tensor
    def tensor?(obj)
      obj.is_a?(Tensor)
    end

    # Builds a CPU tensor from a Numo array.
    #
    # @param ndarray [Object] an instance of one of the Numo classes in _dtype_to_numo
    # @return [Tensor]
    # @raise [Error] if the array's class has no matching dtype
    def from_numo(ndarray)
      dtype = _dtype_to_numo.find { |_k, v| ndarray.is_a?(v) }
      raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
      options = tensor_options(device: "cpu", dtype: dtype[0])
      # TODO pass pointer to array instead of creating string
      str = ndarray.to_string
      tensor = _from_blob(str, ndarray.shape, options)
      # from_blob does not own the data, so we need to keep
      # a reference to it for duration of tensor
      # can remove when passing pointer directly
      tensor.instance_variable_set("@_numo_str", str)
      tensor
    end

    # private
    # use method for cases when Numo not available
    # or available after Torch loaded
    def _dtype_to_numo
      {
        uint8: Numo::UInt8,
        int8: Numo::Int8,
        int16: Numo::Int16,
        int32: Numo::Int32,
        int64: Numo::Int64,
        float32: Numo::SFloat,
        float64: Numo::DFloat
      }
    end

    # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

    # arange(finish) or arange(start, finish, step = 1).
    def arange(start, finish = nil, step = 1, **options)
      # ruby doesn't support start = 0, finish, step = 1, ...
      if finish.nil?
        finish = start
        start = 0
      end
      _arange(start, finish, step, tensor_options(**options))
    end

    # Size may be given as varargs or as a single array (it is flattened).
    def empty(*size, **options)
      _empty(tensor_size(size), tensor_options(**options))
    end

    # n x m identity matrix; m defaults to n.
    def eye(n, m = nil, **options)
      _eye(n, m || n, tensor_options(**options))
    end

    # size is passed through as given (not flattened like the varargs creators).
    def full(size, fill_value, **options)
      _full(size, fill_value, tensor_options(**options))
    end

    def linspace(start, finish, steps = 100, **options)
      _linspace(start, finish, steps, tensor_options(**options))
    end

    def logspace(start, finish, steps = 100, base = 10.0, **options)
      _logspace(start, finish, steps, base, tensor_options(**options))
    end

    def ones(*size, **options)
      _ones(tensor_size(size), tensor_options(**options))
    end

    def rand(*size, **options)
      _rand(tensor_size(size), tensor_options(**options))
    end

    # randint(high, size) or randint(low, high, size); low defaults to 0.
    def randint(low = 0, high, size, **options)
      _randint(low, high, size, tensor_options(**options))
    end

    def randn(*size, **options)
      _randn(tensor_size(size), tensor_options(**options))
    end

    def randperm(n, **options)
      _randperm(n, tensor_options(**options))
    end

    def zeros(*size, **options)
      _zeros(tensor_size(size), tensor_options(**options))
    end

    # Builds a tensor from a scalar or a (nested) array-like.
    #
    # The shape is read from the first element at each nesting level, so
    # nested arrays are assumed to be rectangular. When no dtype is given and
    # every value is an Integer, the dtype is inferred as :int64.
    def tensor(data, **options)
      size = []
      if data.respond_to?(:to_a)
        data = data.to_a
        d = data
        while d.is_a?(Array)
          size << d.size
          d = d.first
        end
        data = data.flatten
      else
        data = [data].compact
      end

      # An empty array would vacuously satisfy all?; keep the default dtype
      # for it instead, as torch.tensor([]) does in PyTorch.
      if options[:dtype].nil? && !data.empty? && data.all? { |v| v.is_a?(Integer) }
        options[:dtype] = :int64
      end

      _tensor(data, size, tensor_options(**options))
    end

    # --- begin like ---
    # Each *_like method creates a tensor with input's size. dtype, layout and
    # device default to input's values unless overridden in options. The merged
    # options are forwarded with ** so they arrive as keywords; a bare hash
    # would be treated as a positional argument on Ruby 3.

    def ones_like(input, **options)
      ones(input.size, **like_options(input, options))
    end

    def empty_like(input, **options)
      empty(input.size, **like_options(input, options))
    end

    def full_like(input, fill_value, **options)
      full(input.size, fill_value, **like_options(input, options))
    end

    def rand_like(input, **options)
      rand(input.size, **like_options(input, options))
    end

    # randint_like(input, high) or randint_like(input, low, high).
    def randint_like(input, low, high = nil, **options)
      # A single bound is treated as high with low = 0, mirroring PyTorch's
      # randint_like(input, low=0, high).
      if high.nil?
        high = low
        low = 0
      end
      randint(low, high, input.size, **like_options(input, options))
    end

    def randn_like(input, **options)
      randn(input.size, **like_options(input, options))
    end

    def zeros_like(input, **options)
      zeros(input.size, **like_options(input, options))
    end

    # --- begin operations ---

    # add/sub/mul/div/remainder(input, other, out: nil); see execute_op.
    %w(add sub mul div remainder).each do |op|
      define_method(op) do |input, other, **options|
        execute_op(op, input, other, **options)
      end
    end

    def neg(input)
      _neg(input)
    end

    # Runs the block with gradient tracking disabled and restores the previous
    # setting afterwards, even if the block raises.
    def no_grad
      previous_value = grad_enabled?
      begin
        _set_grad_enabled(false)
        yield
      ensure
        _set_grad_enabled(previous_value)
      end
    end

    # TODO support out
    def mean(input, dim = nil, keepdim: false)
      if dim
        _mean_dim(input, dim, keepdim)
      else
        _mean(input)
      end
    end

    # TODO support dtype
    def sum(input, dim = nil, keepdim: false)
      if dim
        _sum_dim(input, dim, keepdim)
      else
        _sum(input)
      end
    end

    def argmax(input, dim = nil, keepdim: false)
      if dim
        _argmax_dim(input, dim, keepdim)
      else
        _argmax(input)
      end
    end

    def eq(input, other)
      _eq(input, other)
    end

    def norm(input)
      _norm(input)
    end

    def pow(input, exponent)
      _pow(input, exponent)
    end

    def min(input)
      _min(input)
    end

    # With dim, results are written into out (a two-element array passed as
    # out[0], out[1] to _max_out); calling with dim but without out is not
    # supported yet.
    def max(input, dim = nil, keepdim: false, out: nil)
      if dim
        raise NotImplementedYet unless out
        _max_out(out[0], out[1], input, dim, keepdim)
      else
        _max(input)
      end
    end

    def exp(input)
      _exp(input)
    end

    def log(input)
      _log(input)
    end

    def sign(input)
      _sign(input)
    end

    def gt(input, other)
      _gt(input, other)
    end

    def lt(input, other)
      _lt(input, other)
    end

    def unsqueeze(input, dim)
      _unsqueeze(input, dim)
    end

    def dot(input, tensor)
      _dot(input, tensor)
    end

    def cat(tensors, dim = 0)
      _cat(tensors, dim)
    end

    def matmul(input, other)
      _matmul(input, other)
    end

    def reshape(input, shape)
      _reshape(input, shape)
    end

    def flatten(input, start_dim: 0, end_dim: -1)
      _flatten(input, start_dim, end_dim)
    end

    def sqrt(input)
      _sqrt(input)
    end

    def abs(input)
      _abs(input)
    end

    def device(str)
      Device.new(str)
    end

    private

    # Dispatches a binary op to _<op>, _<op>_scalar (Numeric other) or
    # _<op>_out (when out is given).
    # @raise [Error] when out is combined with a scalar operand
    def execute_op(op, input, other, out: nil)
      scalar = other.is_a?(Numeric)
      if out
        # TODO make work with scalars
        raise Error, "out not supported with scalar yet" if scalar
        send("_#{op}_out", out, input, other)
      else
        if scalar
          send("_#{op}_scalar", input, other)
        else
          send("_#{op}", input, other)
        end
      end
    end

    # Accepts ones(2, 3) and ones([2, 3]) alike.
    def tensor_size(size)
      size.flatten
    end

    # Builds a TensorOptions, setting only the options that were given.
    # @raise [Error] if dtype is not a key of DTYPE_TO_ENUM
    def tensor_options(dtype: nil, layout: nil, device: nil, requires_grad: nil)
      options = TensorOptions.new
      unless dtype.nil?
        type = DTYPE_TO_ENUM[dtype]
        raise Error, "Unknown dtype: #{dtype.inspect}" unless type
        options = options.dtype(type)
      end
      unless device.nil?
        options = options.device(device.to_s)
      end
      unless layout.nil?
        options = options.layout(layout.to_s)
      end
      unless requires_grad.nil?
        options = options.requires_grad(requires_grad)
      end
      options
    end

    # Returns a copy of options with dtype, layout and device filled in from
    # input where not already set; the caller's hash is not mutated.
    def like_options(input, options)
      options = options.dup
      options[:dtype] ||= input.dtype
      options[:layout] ||= input.layout
      options[:device] ||= input.device
      options
    end
  end
end
Binary file