torch-rb 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +28 -0
  3. data/LICENSE.txt +46 -0
  4. data/README.md +426 -0
  5. data/ext/torch/ext.cpp +839 -0
  6. data/ext/torch/extconf.rb +25 -0
  7. data/lib/torch-rb.rb +1 -0
  8. data/lib/torch.rb +422 -0
  9. data/lib/torch/ext.bundle +0 -0
  10. data/lib/torch/inspector.rb +85 -0
  11. data/lib/torch/nn/alpha_dropout.rb +9 -0
  12. data/lib/torch/nn/conv2d.rb +37 -0
  13. data/lib/torch/nn/convnd.rb +41 -0
  14. data/lib/torch/nn/dropout.rb +9 -0
  15. data/lib/torch/nn/dropout2d.rb +9 -0
  16. data/lib/torch/nn/dropout3d.rb +9 -0
  17. data/lib/torch/nn/dropoutnd.rb +15 -0
  18. data/lib/torch/nn/embedding.rb +52 -0
  19. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  20. data/lib/torch/nn/functional.rb +100 -0
  21. data/lib/torch/nn/init.rb +30 -0
  22. data/lib/torch/nn/linear.rb +36 -0
  23. data/lib/torch/nn/module.rb +85 -0
  24. data/lib/torch/nn/mse_loss.rb +13 -0
  25. data/lib/torch/nn/parameter.rb +14 -0
  26. data/lib/torch/nn/relu.rb +13 -0
  27. data/lib/torch/nn/sequential.rb +29 -0
  28. data/lib/torch/optim/adadelta.rb +57 -0
  29. data/lib/torch/optim/adagrad.rb +71 -0
  30. data/lib/torch/optim/adam.rb +81 -0
  31. data/lib/torch/optim/adamax.rb +68 -0
  32. data/lib/torch/optim/adamw.rb +82 -0
  33. data/lib/torch/optim/asgd.rb +65 -0
  34. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  35. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  36. data/lib/torch/optim/optimizer.rb +62 -0
  37. data/lib/torch/optim/rmsprop.rb +76 -0
  38. data/lib/torch/optim/rprop.rb +68 -0
  39. data/lib/torch/optim/sgd.rb +60 -0
  40. data/lib/torch/tensor.rb +196 -0
  41. data/lib/torch/utils/data/data_loader.rb +27 -0
  42. data/lib/torch/utils/data/tensor_dataset.rb +22 -0
  43. data/lib/torch/version.rb +3 -0
  44. metadata +169 -0
data/ext/torch/extconf.rb
@@ -0,0 +1,25 @@
+ require "mkmf-rice"
+
+ abort "Missing stdc++" unless have_library("stdc++")
+
+ $CXXFLAGS << " -std=c++11"
+
+ # needed for Linux pre-cxx11 ABI version
+ # $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=0"
+
+ # silence ruby/intern.h warning
+ $CXXFLAGS << " -Wno-deprecated-register"
+
+ inc, lib = dir_config("torch")
+
+ inc ||= "/usr/local/include"
+ lib ||= "/usr/local/lib"
+
+ $INCFLAGS << " -I#{inc}"
+ $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
+
+ $LDFLAGS << " -Wl,-rpath,#{lib}"
+ $LDFLAGS << " -L#{lib}"
+ $LDFLAGS << " -ltorch -lc10"
+
+ create_makefile("torch/ext")
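Because the script calls dir_config("torch"), a LibTorch that lives outside /usr/local can be supplied at install time via the standard mkmf flags, e.g. gem install torch-rb -- --with-torch-dir=/path/to/libtorch (or --with-torch-include and --with-torch-lib individually).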
data/lib/torch-rb.rb
@@ -0,0 +1 @@
+ require "torch"
data/lib/torch.rb
@@ -0,0 +1,422 @@
+ # ext
+ require "torch/ext"
+
+ # modules
+ require "torch/inspector"
+ require "torch/tensor"
+ require "torch/version"
+
+ # optim
+ require "torch/optim/optimizer"
+ require "torch/optim/adadelta"
+ require "torch/optim/adagrad"
+ require "torch/optim/adam"
+ require "torch/optim/adamax"
+ require "torch/optim/adamw"
+ require "torch/optim/asgd"
+ require "torch/optim/rmsprop"
+ require "torch/optim/rprop"
+ require "torch/optim/sgd"
+
+ # optim lr_scheduler
+ require "torch/optim/lr_scheduler/lr_scheduler"
+ require "torch/optim/lr_scheduler/step_lr"
+
+ # nn base classes
+ require "torch/nn/module"
+ require "torch/nn/convnd"
+ require "torch/nn/dropoutnd"
+
+ # nn
+ require "torch/nn/alpha_dropout"
+ require "torch/nn/conv2d"
+ require "torch/nn/dropout"
+ require "torch/nn/dropout2d"
+ require "torch/nn/dropout3d"
+ require "torch/nn/embedding"
+ require "torch/nn/feature_alpha_dropout"
+ require "torch/nn/functional"
+ require "torch/nn/init"
+ require "torch/nn/linear"
+ require "torch/nn/mse_loss"
+ require "torch/nn/parameter"
+ require "torch/nn/relu"
+ require "torch/nn/sequential"
+
+ # utils
+ require "torch/utils/data/data_loader"
+ require "torch/utils/data/tensor_dataset"
+
+ module Torch
+   class Error < StandardError; end
+   class NotImplementedYet < StandardError
+     def message
+       "This feature has not been implemented yet. Consider submitting a PR."
+     end
+   end
+
+   # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
+   # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
+   # complex and quantized types not supported by PyTorch yet
+   DTYPE_TO_ENUM = {
+     uint8: 0,
+     int8: 1,
+     short: 2,
+     int16: 2,
+     int: 3,
+     int32: 3,
+     long: 4,
+     int64: 4,
+     half: 5,
+     float16: 5,
+     float: 6,
+     float32: 6,
+     double: 7,
+     float64: 7,
+     # complex_half: 8,
+     # complex_float: 9,
+     # complex_double: 10,
+     bool: 11,
+     # qint8: 12,
+     # quint8: 13,
+     # qint32: 14,
+     # bfloat16: 15
+   }
+   ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
+
+   class << self
+     # Torch.float, Torch.long, etc
+     DTYPE_TO_ENUM.each_key do |dtype|
+       define_method(dtype) do
+         dtype
+       end
+
+       Tensor.define_method(dtype) do
+         type(dtype)
+       end
+     end
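These dynamically defined methods let dtypes be written as calls, both on the module and on tensors. A minimal usage sketch (assuming a working build):

    Torch.float          # => :float
    x = Torch.ones(2, 3) # float32 by default
    x.int64              # shorthand for x.type(:int64)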
+
+     # https://pytorch.org/docs/stable/torch.html
+
+     def tensor?(obj)
+       obj.is_a?(Tensor)
+     end
+
+     def from_numo(ndarray)
+       dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
+       raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
+       options = tensor_options(device: "cpu", dtype: dtype[0])
+       # TODO pass pointer to array instead of creating string
+       str = ndarray.to_string
+       tensor = _from_blob(str, ndarray.shape, options)
+       # from_blob does not own the data, so we need to keep
+       # a reference to it for the duration of the tensor
+       # can remove when passing pointer directly
+       tensor.instance_variable_set("@_numo_str", str)
+       tensor
+     end
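For example, converting a Numo array (assumes the numo-narray gem; note the data is copied through a string buffer, per the TODO above):

    require "numo/narray"
    a = Numo::SFloat.new(2, 3).seq
    t = Torch.from_numo(a) # 2x3 float32 tensor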
+
+     # private
+     # use method for cases when Numo not available
+     # or available after Torch loaded
+     def _dtype_to_numo
+       {
+         uint8: Numo::UInt8,
+         int8: Numo::Int8,
+         int16: Numo::Int16,
+         int32: Numo::Int32,
+         int64: Numo::Int64,
+         float32: Numo::SFloat,
+         float64: Numo::DFloat
+       }
+     end
+
+     # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
+
+     def arange(start, finish = nil, step = 1, **options)
+       # ruby doesn't support start = 0, finish, step = 1, ...
+       if finish.nil?
+         finish = start
+         start = 0
+       end
+       _arange(start, finish, step, tensor_options(**options))
+     end
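Both call shapes therefore work, despite Ruby lacking PyTorch-style leading optional arguments:

    Torch.arange(5)        # 0, 1, 2, 3, 4
    Torch.arange(1, 10, 2) # 1, 3, 5, 7, 9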
+
+     def empty(*size, **options)
+       _empty(tensor_size(size), tensor_options(**options))
+     end
+
+     def eye(n, m = nil, **options)
+       _eye(n, m || n, tensor_options(**options))
+     end
+
+     def full(size, fill_value, **options)
+       _full(size, fill_value, tensor_options(**options))
+     end
+
+     def linspace(start, finish, steps = 100, **options)
+       _linspace(start, finish, steps, tensor_options(**options))
+     end
+
+     def logspace(start, finish, steps = 100, base = 10.0, **options)
+       _logspace(start, finish, steps, base, tensor_options(**options))
+     end
+
+     def ones(*size, **options)
+       _ones(tensor_size(size), tensor_options(**options))
+     end
+
+     def rand(*size, **options)
+       _rand(tensor_size(size), tensor_options(**options))
+     end
+
+     def randint(low = 0, high, size, **options)
+       _randint(low, high, size, tensor_options(**options))
+     end
+
+     def randn(*size, **options)
+       _randn(tensor_size(size), tensor_options(**options))
+     end
+
+     def randperm(n, **options)
+       _randperm(n, tensor_options(**options))
+     end
+
+     def zeros(*size, **options)
+       _zeros(tensor_size(size), tensor_options(**options))
+     end
+
+     def tensor(data, **options)
+       size = []
+       if data.respond_to?(:to_a)
+         data = data.to_a
+         d = data
+         while d.is_a?(Array)
+           size << d.size
+           d = d.first
+         end
+         data = data.flatten
+       else
+         data = [data].compact
+       end
+
+       if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
+         options[:dtype] = :int64
+       end
+
+       _tensor(data, size, tensor_options(**options))
+     end
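The nested-array walk records one dimension per nesting level before flattening, and all-Integer data defaults to int64 to match PyTorch:

    t = Torch.tensor([[1, 2, 3], [4, 5, 6]])
    t.size  # => [2, 3]
    t.dtype # => :int64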
+
+     # --- begin like ---
+
+     def ones_like(input, **options)
+       ones(input.size, like_options(input, options))
+     end
+
+     def empty_like(input, **options)
+       empty(input.size, like_options(input, options))
+     end
+
+     def full_like(input, fill_value, **options)
+       full(input.size, fill_value, like_options(input, options))
+     end
+
+     def rand_like(input, **options)
+       rand(input.size, like_options(input, options))
+     end
+
+     def randint_like(input, low, high = nil, **options)
+       # ruby doesn't support input, low = 0, high, ...
+       if high.nil?
+         high = low
+         low = 0
+       end
+       randint(low, high, input.size, like_options(input, options))
+     end
+
+     def randn_like(input, **options)
+       randn(input.size, like_options(input, options))
+     end
+
+     def zeros_like(input, **options)
+       zeros(input.size, like_options(input, options))
+     end
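A quick sketch of the _like family, which copies dtype, layout, and device from the source tensor unless overridden:

    x = Torch.rand(2, 2, dtype: :float64)
    Torch.zeros_like(x).dtype          # => :float64
    Torch.zeros_like(x, dtype: :int32) # explicit option wins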
+
+     # --- begin operations ---
+
+     %w(add sub mul div remainder).each do |op|
+       define_method(op) do |input, other, **options|
+         execute_op(op, input, other, **options)
+       end
+     end
+
+     def neg(input)
+       _neg(input)
+     end
+
+     def no_grad
+       previous_value = grad_enabled?
+       begin
+         _set_grad_enabled(false)
+         yield
+       ensure
+         _set_grad_enabled(previous_value)
+       end
+     end
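Usage mirrors PyTorch's no_grad context; the ensure clause restores the previous setting even if the block raises:

    w = Torch.ones(3, requires_grad: true)
    Torch.no_grad do
      Torch.add(w, 1) # not recorded for autograd
    end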
+
+     # TODO support out
+     def mean(input, dim = nil, keepdim: false)
+       if dim
+         _mean_dim(input, dim, keepdim)
+       else
+         _mean(input)
+       end
+     end
+
+     # TODO support dtype
+     def sum(input, dim = nil, keepdim: false)
+       if dim
+         _sum_dim(input, dim, keepdim)
+       else
+         _sum(input)
+       end
+     end
+
+     def argmax(input, dim = nil, keepdim: false)
+       if dim
+         _argmax_dim(input, dim, keepdim)
+       else
+         _argmax(input)
+       end
+     end
+
+     def eq(input, other)
+       _eq(input, other)
+     end
+
+     def norm(input)
+       _norm(input)
+     end
+
+     def pow(input, exponent)
+       _pow(input, exponent)
+     end
+
+     def min(input)
+       _min(input)
+     end
+
+     def max(input, dim = nil, keepdim: false, out: nil)
+       if dim
+         raise NotImplementedYet unless out
+         _max_out(out[0], out[1], input, dim, keepdim)
+       else
+         _max(input)
+       end
+     end
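A sketch of the dim form, which currently requires preallocated out tensors for the values and (int64) indices:

    x = Torch.rand(3, 4)
    values = Torch.empty(3)
    indices = Torch.empty(3, dtype: :int64)
    Torch.max(x, 1, out: [values, indices])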
+
+     def exp(input)
+       _exp(input)
+     end
+
+     def log(input)
+       _log(input)
+     end
+
+     def sign(input)
+       _sign(input)
+     end
+
+     def gt(input, other)
+       _gt(input, other)
+     end
+
+     def lt(input, other)
+       _lt(input, other)
+     end
+
+     def unsqueeze(input, dim)
+       _unsqueeze(input, dim)
+     end
+
+     def dot(input, tensor)
+       _dot(input, tensor)
+     end
+
+     def cat(tensors, dim = 0)
+       _cat(tensors, dim)
+     end
+
+     def matmul(input, other)
+       _matmul(input, other)
+     end
+
+     def reshape(input, shape)
+       _reshape(input, shape)
+     end
+
+     def flatten(input, start_dim: 0, end_dim: -1)
+       _flatten(input, start_dim, end_dim)
+     end
+
+     def sqrt(input)
+       _sqrt(input)
+     end
+
+     def abs(input)
+       _abs(input)
+     end
+
+     def device(str)
+       Device.new(str)
+     end
+
+     private
+
+     def execute_op(op, input, other, out: nil)
+       scalar = other.is_a?(Numeric)
+       if out
+         # TODO make work with scalars
+         raise Error, "out not supported with scalar yet" if scalar
+         send("_#{op}_out", out, input, other)
+       else
+         if scalar
+           send("_#{op}_scalar", input, other)
+         else
+           send("_#{op}", input, other)
+         end
+       end
+     end
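The net effect is that the arithmetic entry points accept either a tensor or a Ruby numeric on the right-hand side:

    x = Torch.ones(2)
    Torch.add(x, x) # dispatches to _add
    Torch.add(x, 5) # dispatches to _add_scalar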
+
+     def tensor_size(size)
+       size.flatten
+     end
+
+     def tensor_options(dtype: nil, layout: nil, device: nil, requires_grad: nil)
+       options = TensorOptions.new
+       unless dtype.nil?
+         type = DTYPE_TO_ENUM[dtype]
+         raise Error, "Unknown dtype: #{dtype.inspect}" unless type
+         options = options.dtype(type)
+       end
+       unless device.nil?
+         options = options.device(device.to_s)
+       end
+       unless layout.nil?
+         options = options.layout(layout.to_s)
+       end
+       unless requires_grad.nil?
+         options = options.requires_grad(requires_grad)
+       end
+       options
+     end
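Every creation method funnels its keywords through this builder, so the options compose uniformly, e.g.:

    Torch.zeros(2, 2, dtype: :float64, requires_grad: true)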
+
+     def like_options(input, options)
+       options = options.dup
+       options[:dtype] ||= input.dtype
+       options[:layout] ||= input.layout
+       options[:device] ||= input.device
+       options
+     end
+   end
+ end
data/lib/torch/ext.bundle
Binary file