torch-rb 0.1.0 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
data/lib/torch.rb CHANGED
@@ -1,32 +1,130 @@
  # ext
  require "torch/ext"

+ # native functions
+ require "torch/native/generator"
+ require "torch/native/parser"
+ require "torch/native/dispatcher"
+
  # modules
  require "torch/inspector"
  require "torch/tensor"
  require "torch/version"

- # nn
+ # optim
+ require "torch/optim/optimizer"
+ require "torch/optim/adadelta"
+ require "torch/optim/adagrad"
+ require "torch/optim/adam"
+ require "torch/optim/adamax"
+ require "torch/optim/adamw"
+ require "torch/optim/asgd"
+ require "torch/optim/rmsprop"
+ require "torch/optim/rprop"
+ require "torch/optim/sgd"
+
+ # optim lr_scheduler
+ require "torch/optim/lr_scheduler/lr_scheduler"
+ require "torch/optim/lr_scheduler/step_lr"
+
+ # nn parameters
+ require "torch/nn/parameter"
+
+ # nn containers
  require "torch/nn/module"
- require "torch/nn/init"
+ require "torch/nn/sequential"
+
+ # nn convolution layers
+ require "torch/nn/convnd"
  require "torch/nn/conv2d"
- require "torch/nn/functional"
+
+ # nn pooling layers
+ require "torch/nn/max_poolnd"
+ require "torch/nn/max_pool2d"
+ require "torch/nn/avg_poolnd"
+ require "torch/nn/avg_pool2d"
+
+ # nn recurrent layers
+ require "torch/nn/rnn_base"
+ require "torch/nn/rnn"
+
+ # nn linear layers
+ require "torch/nn/bilinear"
+ require "torch/nn/identity"
  require "torch/nn/linear"
- require "torch/nn/parameter"
- require "torch/nn/sequential"
+
+ # nn dropout layers
+ require "torch/nn/dropoutnd"
+ require "torch/nn/alpha_dropout"
+ require "torch/nn/dropout"
+ require "torch/nn/dropout2d"
+ require "torch/nn/dropout3d"
+ require "torch/nn/feature_alpha_dropout"
+
+ # nn activations
+ require "torch/nn/leaky_relu"
+ require "torch/nn/prelu"
  require "torch/nn/relu"
+ require "torch/nn/sigmoid"
+ require "torch/nn/softplus"
+
+ # nn activations other
+ require "torch/nn/log_softmax"
+ require "torch/nn/softmax"
+ require "torch/nn/softmax2d"
+ require "torch/nn/softmin"
+
+ # nn sparse layers
+ require "torch/nn/embedding"
+ require "torch/nn/embedding_bag"
+
+ # nn distance functions
+ require "torch/nn/cosine_similarity"
+ require "torch/nn/pairwise_distance"
+
+ # nn loss functions
+ require "torch/nn/loss"
+ require "torch/nn/weighted_loss"
+ require "torch/nn/bce_loss"
+ require "torch/nn/bce_with_logits_loss"
+ require "torch/nn/cosine_embedding_loss"
+ require "torch/nn/cross_entropy_loss"
+ require "torch/nn/ctc_loss"
+ require "torch/nn/hinge_embedding_loss"
+ require "torch/nn/kl_div_loss"
+ require "torch/nn/l1_loss"
+ require "torch/nn/margin_ranking_loss"
  require "torch/nn/mse_loss"
+ require "torch/nn/multi_label_margin_loss"
+ require "torch/nn/multi_label_soft_margin_loss"
+ require "torch/nn/multi_margin_loss"
+ require "torch/nn/nll_loss"
+ require "torch/nn/poisson_nll_loss"
+ require "torch/nn/smooth_l1_loss"
+ require "torch/nn/soft_margin_loss"
+ require "torch/nn/triplet_margin_loss"
+
+ # nn other
+ require "torch/nn/functional"
+ require "torch/nn/init"

  # utils
  require "torch/utils/data/data_loader"
  require "torch/utils/data/tensor_dataset"

+ # random
+ require "torch/random"
+
  module Torch
  class Error < StandardError; end
+ class NotImplementedYet < StandardError
+ def message
+ "This feature has not been implemented yet. Consider submitting a PR."
+ end
+ end

  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
- # complex and quantized types not supported by PyTorch yet
  DTYPE_TO_ENUM = {
  uint8: 0,
  int8: 1,
@@ -42,22 +140,61 @@ module Torch
  float32: 6,
  double: 7,
  float64: 7,
- # complex_half: 8,
- # complex_float: 9,
- # complex_double: 10,
+ complex_half: 8,
+ complex_float: 9,
+ complex_double: 10,
  bool: 11,
- # qint8: 12,
- # quint8: 13,
- # qint32: 14,
- # bfloat16: 15
+ qint8: 12,
+ quint8: 13,
+ qint32: 14,
+ bfloat16: 15
  }
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

+ def self._make_tensor_class(dtype, cuda = false)
+ cls = Class.new
+ device = cuda ? "cuda" : "cpu"
+ cls.define_singleton_method("new") do |*args|
+ if args.size == 1 && args.first.is_a?(Tensor)
+ args.first.send(dtype).to(device)
+ elsif args.size == 1 && args.first.is_a?(Array)
+ Torch.tensor(args.first, dtype: dtype, device: device)
+ else
+ Torch.empty(*args, dtype: dtype, device: device)
+ end
+ end
+ cls
+ end
+
+ FloatTensor = _make_tensor_class(:float32)
+ DoubleTensor = _make_tensor_class(:float64)
+ HalfTensor = _make_tensor_class(:float16)
+ ByteTensor = _make_tensor_class(:uint8)
+ CharTensor = _make_tensor_class(:int8)
+ ShortTensor = _make_tensor_class(:int16)
+ IntTensor = _make_tensor_class(:int32)
+ LongTensor = _make_tensor_class(:int64)
+ BoolTensor = _make_tensor_class(:bool)
+
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
+
  class << self
  # Torch.float, Torch.long, etc
- DTYPE_TO_ENUM.each_key do |type|
- define_method(type) do
- type
+ DTYPE_TO_ENUM.each_key do |dtype|
+ define_method(dtype) do
+ dtype
+ end
+
+ Tensor.define_method(dtype) do
+ type(dtype)
  end
  end

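A rough usage sketch of the typed tensor constructors introduced above; the names and dispatch rules are taken from _make_tensor_class in this hunk, the comments describe expected behavior only:

    x = Torch::FloatTensor.new([1, 2, 3])    # Array argument routes to Torch.tensor(..., dtype: :float32, device: "cpu")
    y = Torch::LongTensor.new(2, 3)          # size arguments route to Torch.empty(2, 3, dtype: :int64)
    z = Torch::DoubleTensor.new(x)           # an existing Tensor is converted to the class's dtype and device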
@@ -67,17 +204,26 @@ module Torch
  obj.is_a?(Tensor)
  end

- # TODO don't copy
  def from_numo(ndarray)
  dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
- tensor(ndarray.to_a, dtype: dtype[0])
+ options = tensor_options(device: "cpu", dtype: dtype[0])
+ # TODO pass pointer to array instead of creating string
+ str = ndarray.to_string
+ tensor = _from_blob(str, ndarray.shape, options)
+ # from_blob does not own the data, so we need to keep
+ # a reference to it for duration of tensor
+ # can remove when passing pointer directly
+ tensor.instance_variable_set("@_numo_str", str)
+ tensor
  end

  # private
  # use method for cases when Numo not available
  # or available after Torch loaded
  def _dtype_to_numo
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
+
  {
  uint8: Numo::UInt8,
  int8: Numo::Int8,
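A minimal sketch of the reworked Torch.from_numo, assuming the numo-narray gem is installed and that Numo::SFloat is mapped to :float32 in _dtype_to_numo (only the uint8/int8 entries are visible in this hunk):

    require "numo/narray"

    ndarray = Numo::SFloat.new(2, 3).seq    # 2x3 Numo array filled with 0..5
    tensor = Torch.from_numo(ndarray)       # now wraps the raw buffer via _from_blob instead of copying via to_a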
@@ -89,6 +235,20 @@ module Torch
  }
  end

+ def no_grad
+ previous_value = grad_enabled?
+ begin
+ _set_grad_enabled(false)
+ yield
+ ensure
+ _set_grad_enabled(previous_value)
+ end
+ end
+
+ def device(str)
+ Device.new(str)
+ end
+
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

  def arange(start, finish = nil, step = 1, **options)
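A short sketch of the no_grad and device helpers added above:

    device = Torch.device("cpu")    # wraps Device.new("cpu")

    Torch.no_grad do
      # gradient tracking is disabled here and restored when the block exits,
      # even if an exception is raised (note the ensure above)
      Torch.ones(2, 3)
    end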
@@ -158,8 +318,12 @@ module Torch
  data = [data].compact
  end

- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
- options[:dtype] = :int64
+ if options[:dtype].nil?
+ if data.all? { |v| v.is_a?(Integer) }
+ options[:dtype] = :int64
+ elsif data.all? { |v| v == true || v == false }
+ options[:dtype] = :bool
+ end
  end

  _tensor(data, size, tensor_options(**options))
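With the change above, Torch.tensor now infers a :bool dtype for arrays of true/false values, while integer arrays still default to :int64; a quick sketch of the expected behavior:

    Torch.tensor([1, 2, 3]).dtype            # expected :int64
    Torch.tensor([true, false, true]).dtype  # expected :bool
    Torch.tensor([1.5, 2.5]).dtype           # falls through to the default float dtype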
@@ -189,7 +353,7 @@ module Torch
  high = low
  low = 0
  end
- rand(input.size, like_options(input, options))
+ randint(low, high, input.size, like_options(input, options))
  end

  def randn_like(input, **options)
@@ -202,26 +366,6 @@ module Torch

  # --- begin operations ---

- %w(add sub mul div remainder).each do |op|
- define_method(op) do |input, other, **options|
- execute_op(op, input, other, **options)
- end
- end
-
- def neg(input)
- _neg(input)
- end
-
- def no_grad
- previous_value = grad_enabled?
- begin
- _set_grad_enabled(false)
- yield
- ensure
- _set_grad_enabled(previous_value)
- end
- end
-
  # TODO support out
  def mean(input, dim = nil, keepdim: false)
  if dim
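The hand-written arithmetic wrappers removed above (add, sub, mul, div, remainder, neg) are presumably superseded by the generated native bindings (see data/lib/torch/native/dispatcher.rb and native_functions.yaml in the file list), so calls like the following are still expected to work:

    a = Torch.tensor([1, 2, 3])
    b = Torch.tensor([4, 5, 6])
    Torch.add(a, b)    # dispatched through the generated functions
    a + b              # Tensor operators delegate to the same bindings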
@@ -240,59 +384,30 @@ module Torch
  end
  end

- def norm(input)
- _norm(input)
+ def topk(input, k)
+ _topk(input, k)
  end

- def pow(input, exponent)
- _pow(input, exponent)
- end
-
- def min(input)
- _min(input)
- end
-
- def max(input)
- _max(input)
- end
-
- def exp(input)
- _exp(input)
- end
-
- def log(input)
- _log(input)
- end
-
- def unsqueeze(input, dim)
- _unsqueeze(input, dim)
+ def max(input, dim = nil, keepdim: false, out: nil)
+ if dim
+ raise NotImplementedYet unless out
+ _max_out(out[0], out[1], input, dim, keepdim)
+ else
+ _max(input)
+ end
  end

- def dot(input, tensor)
- _dot(input, tensor)
+ # TODO make dim keyword argument
+ def log_softmax(input, dim)
+ _log_softmax(input, dim)
  end

- def matmul(input, other)
- _matmul(input, other)
+ def softmax(input, dim: nil)
+ _softmax(input, dim)
  end

  private

- def execute_op(op, input, other, out: nil)
- scalar = other.is_a?(Numeric)
- if out
- # TODO make work with scalars
- raise Error, "out not supported with scalar yet" if scalar
- send("_#{op}_out", out, input, other)
- else
- if scalar
- send("_#{op}_scalar", input, other)
- else
- send("_#{op}", input, other)
- end
- end
- end
-
  def tensor_size(size)
  size.flatten
  end
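A brief sketch of the reworked reduction and softmax helpers above (values are illustrative, not output taken from the diff):

    t = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    Torch.max(t)               # reduces over all elements
    Torch.softmax(t, dim: 1)   # dim is now a keyword argument
    Torch.log_softmax(t, 1)    # dim is still positional here (see the TODO above)
    Torch.topk(t, 1)           # new top-k helper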
data/lib/torch/ext.bundle CHANGED
Binary file
@@ -1,41 +1,49 @@
  module Torch
  module Inspector
+ # TODO make more performant, especially when summarizing
+ # how? only read data that will be displayed
  def inspect
  data =
  if numel == 0
  "[]"
  elsif dim == 0
- to_a.first
+ item
  else
- values = to_a.flatten
- abs = values.select { |v| v != 0 }.map(&:abs)
- max = abs.max || 1
- min = abs.min || 1
-
- total = 0
- if values.any? { |v| v < 0 }
- total += 1
- end
+ summarize = numel > 1000
+
+ if dtype == :bool
+ fmt = "%s"
+ else
+ values = to_a.flatten
+ abs = values.select { |v| v != 0 }.map(&:abs)
+ max = abs.max || 1
+ min = abs.min || 1

- if floating_point?
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
+ total = 0
+ if values.any? { |v| v < 0 }
+ total += 1
+ end
+
+ if floating_point?
+ sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4

- all_int = values.all? { |v| v == v.to_i }
- decimal = all_int ? 1 : 4
+ all_int = values.all? { |v| v.finite? && v == v.to_i }
+ decimal = all_int ? 1 : 4

- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+ total += sci ? 10 : decimal + 1 + max.to_i.to_s.size

- if sci
- fmt = "%#{total}.4e"
+ if sci
+ fmt = "%#{total}.4e"
+ else
+ fmt = "%#{total}.#{decimal}f"
+ end
  else
- fmt = "%#{total}.#{decimal}f"
+ total += max.to_s.size
+ fmt = "%#{total}d"
  end
- else
- total += max.to_s.size
- fmt = "%#{total}d"
  end

- inspect_level(to_a, fmt, dim - 1)
+ inspect_level(to_a, fmt, dim - 1, 0, summarize)
  end

  attributes = []
@@ -51,11 +59,30 @@ module Torch

  private

- def inspect_level(arr, fmt, total, level = 0)
+ # TODO DRY code
+ def inspect_level(arr, fmt, total, level, summarize)
  if level == total
- "[#{arr.map { |v| fmt % v }.join(", ")}]"
+ cols =
+ if summarize && arr.size > 7
+ arr[0..2].map { |v| fmt % v } +
+ ["..."] +
+ arr[-3..-1].map { |v| fmt % v }
+ else
+ arr.map { |v| fmt % v }
+ end
+
+ "[#{cols.join(", ")}]"
  else
- "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+ rows =
+ if summarize && arr.size > 7
+ arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
+ ["..."] +
+ arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
+ else
+ arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
+ end
+
+ "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
  end
  end
  end
  end
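With summarization enabled for tensors over 1000 elements, inspect is expected to print only the first and last three entries per dimension and elide the middle with "..."; a small sketch (output not reproduced here):

    puts Torch.arange(0, 2000).inspect   # large 1-D tensor, middle values elided
    puts Torch.ones(2, 3).inspect        # small tensor, printed in full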