torch-rb 0.1.0 → 0.1.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (94) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
data/lib/torch.rb CHANGED
@@ -1,32 +1,130 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # native functions
5
+ require "torch/native/generator"
6
+ require "torch/native/parser"
7
+ require "torch/native/dispatcher"
8
+
4
9
  # modules
5
10
  require "torch/inspector"
6
11
  require "torch/tensor"
7
12
  require "torch/version"
8
13
 
9
- # nn
14
+ # optim
15
+ require "torch/optim/optimizer"
16
+ require "torch/optim/adadelta"
17
+ require "torch/optim/adagrad"
18
+ require "torch/optim/adam"
19
+ require "torch/optim/adamax"
20
+ require "torch/optim/adamw"
21
+ require "torch/optim/asgd"
22
+ require "torch/optim/rmsprop"
23
+ require "torch/optim/rprop"
24
+ require "torch/optim/sgd"
25
+
26
+ # optim lr_scheduler
27
+ require "torch/optim/lr_scheduler/lr_scheduler"
28
+ require "torch/optim/lr_scheduler/step_lr"
29
+
30
+ # nn parameters
31
+ require "torch/nn/parameter"
32
+
33
+ # nn containers
10
34
  require "torch/nn/module"
11
- require "torch/nn/init"
35
+ require "torch/nn/sequential"
36
+
37
+ # nn convolution layers
38
+ require "torch/nn/convnd"
12
39
  require "torch/nn/conv2d"
13
- require "torch/nn/functional"
40
+
41
+ # nn pooling layers
42
+ require "torch/nn/max_poolnd"
43
+ require "torch/nn/max_pool2d"
44
+ require "torch/nn/avg_poolnd"
45
+ require "torch/nn/avg_pool2d"
46
+
47
+ # nn recurrent layers
48
+ require "torch/nn/rnn_base"
49
+ require "torch/nn/rnn"
50
+
51
+ # nn linear layers
52
+ require "torch/nn/bilinear"
53
+ require "torch/nn/identity"
14
54
  require "torch/nn/linear"
15
- require "torch/nn/parameter"
16
- require "torch/nn/sequential"
55
+
56
+ # nn dropout layers
57
+ require "torch/nn/dropoutnd"
58
+ require "torch/nn/alpha_dropout"
59
+ require "torch/nn/dropout"
60
+ require "torch/nn/dropout2d"
61
+ require "torch/nn/dropout3d"
62
+ require "torch/nn/feature_alpha_dropout"
63
+
64
+ # nn activations
65
+ require "torch/nn/leaky_relu"
66
+ require "torch/nn/prelu"
17
67
  require "torch/nn/relu"
68
+ require "torch/nn/sigmoid"
69
+ require "torch/nn/softplus"
70
+
71
+ # nn activations other
72
+ require "torch/nn/log_softmax"
73
+ require "torch/nn/softmax"
74
+ require "torch/nn/softmax2d"
75
+ require "torch/nn/softmin"
76
+
77
+ # nn sparse layers
78
+ require "torch/nn/embedding"
79
+ require "torch/nn/embedding_bag"
80
+
81
+ # nn distance functions
82
+ require "torch/nn/cosine_similarity"
83
+ require "torch/nn/pairwise_distance"
84
+
85
+ # nn loss functions
86
+ require "torch/nn/loss"
87
+ require "torch/nn/weighted_loss"
88
+ require "torch/nn/bce_loss"
89
+ require "torch/nn/bce_with_logits_loss"
90
+ require "torch/nn/cosine_embedding_loss"
91
+ require "torch/nn/cross_entropy_loss"
92
+ require "torch/nn/ctc_loss"
93
+ require "torch/nn/hinge_embedding_loss"
94
+ require "torch/nn/kl_div_loss"
95
+ require "torch/nn/l1_loss"
96
+ require "torch/nn/margin_ranking_loss"
18
97
  require "torch/nn/mse_loss"
98
+ require "torch/nn/multi_label_margin_loss"
99
+ require "torch/nn/multi_label_soft_margin_loss"
100
+ require "torch/nn/multi_margin_loss"
101
+ require "torch/nn/nll_loss"
102
+ require "torch/nn/poisson_nll_loss"
103
+ require "torch/nn/smooth_l1_loss"
104
+ require "torch/nn/soft_margin_loss"
105
+ require "torch/nn/triplet_margin_loss"
106
+
107
+ # nn other
108
+ require "torch/nn/functional"
109
+ require "torch/nn/init"
19
110
 
20
111
  # utils
21
112
  require "torch/utils/data/data_loader"
22
113
  require "torch/utils/data/tensor_dataset"
23
114
 
115
+ # random
116
+ require "torch/random"
117
+
24
118
  module Torch
25
119
  class Error < StandardError; end
120
+ class NotImplementedYet < StandardError
121
+ def message
122
+ "This feature has not been implemented yet. Consider submitting a PR."
123
+ end
124
+ end
26
125
 
27
126
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
28
127
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
29
- # complex and quantized types not supported by PyTorch yet
30
128
  DTYPE_TO_ENUM = {
31
129
  uint8: 0,
32
130
  int8: 1,
@@ -42,22 +140,61 @@ module Torch
42
140
  float32: 6,
43
141
  double: 7,
44
142
  float64: 7,
45
- # complex_half: 8,
46
- # complex_float: 9,
47
- # complex_double: 10,
143
+ complex_half: 8,
144
+ complex_float: 9,
145
+ complex_double: 10,
48
146
  bool: 11,
49
- # qint8: 12,
50
- # quint8: 13,
51
- # qint32: 14,
52
- # bfloat16: 15
147
+ qint8: 12,
148
+ quint8: 13,
149
+ qint32: 14,
150
+ bfloat16: 15
53
151
  }
54
152
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
55
153
 
154
+ def self._make_tensor_class(dtype, cuda = false)
155
+ cls = Class.new
156
+ device = cuda ? "cuda" : "cpu"
157
+ cls.define_singleton_method("new") do |*args|
158
+ if args.size == 1 && args.first.is_a?(Tensor)
159
+ args.first.send(dtype).to(device)
160
+ elsif args.size == 1 && args.first.is_a?(Array)
161
+ Torch.tensor(args.first, dtype: dtype, device: device)
162
+ else
163
+ Torch.empty(*args, dtype: dtype, device: device)
164
+ end
165
+ end
166
+ cls
167
+ end
168
+
169
+ FloatTensor = _make_tensor_class(:float32)
170
+ DoubleTensor = _make_tensor_class(:float64)
171
+ HalfTensor = _make_tensor_class(:float16)
172
+ ByteTensor = _make_tensor_class(:uint8)
173
+ CharTensor = _make_tensor_class(:int8)
174
+ ShortTensor = _make_tensor_class(:int16)
175
+ IntTensor = _make_tensor_class(:int32)
176
+ LongTensor = _make_tensor_class(:int64)
177
+ BoolTensor = _make_tensor_class(:bool)
178
+
179
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
180
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
181
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
182
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
183
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
184
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
185
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
186
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
187
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
188
+
56
189
  class << self
57
190
  # Torch.float, Torch.long, etc
58
- DTYPE_TO_ENUM.each_key do |type|
59
- define_method(type) do
60
- type
191
+ DTYPE_TO_ENUM.each_key do |dtype|
192
+ define_method(dtype) do
193
+ dtype
194
+ end
195
+
196
+ Tensor.define_method(dtype) do
197
+ type(dtype)
61
198
  end
62
199
  end
63
200
 
@@ -67,17 +204,26 @@ module Torch
67
204
  obj.is_a?(Tensor)
68
205
  end
69
206
 
70
- # TODO don't copy
71
207
  def from_numo(ndarray)
72
208
  dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
73
209
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
74
- tensor(ndarray.to_a, dtype: dtype[0])
210
+ options = tensor_options(device: "cpu", dtype: dtype[0])
211
+ # TODO pass pointer to array instead of creating string
212
+ str = ndarray.to_string
213
+ tensor = _from_blob(str, ndarray.shape, options)
214
+ # from_blob does not own the data, so we need to keep
215
+ # a reference to it for duration of tensor
216
+ # can remove when passing pointer directly
217
+ tensor.instance_variable_set("@_numo_str", str)
218
+ tensor
75
219
  end
76
220
 
77
221
  # private
78
222
  # use method for cases when Numo not available
79
223
  # or available after Torch loaded
80
224
  def _dtype_to_numo
225
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
226
+
81
227
  {
82
228
  uint8: Numo::UInt8,
83
229
  int8: Numo::Int8,
@@ -89,6 +235,20 @@ module Torch
89
235
  }
90
236
  end
91
237
 
238
+ def no_grad
239
+ previous_value = grad_enabled?
240
+ begin
241
+ _set_grad_enabled(false)
242
+ yield
243
+ ensure
244
+ _set_grad_enabled(previous_value)
245
+ end
246
+ end
247
+
248
+ def device(str)
249
+ Device.new(str)
250
+ end
251
+
92
252
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
93
253
 
94
254
  def arange(start, finish = nil, step = 1, **options)
@@ -158,8 +318,12 @@ module Torch
158
318
  data = [data].compact
159
319
  end
160
320
 
161
- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
162
- options[:dtype] = :int64
321
+ if options[:dtype].nil?
322
+ if data.all? { |v| v.is_a?(Integer) }
323
+ options[:dtype] = :int64
324
+ elsif data.all? { |v| v == true || v == false }
325
+ options[:dtype] = :bool
326
+ end
163
327
  end
164
328
 
165
329
  _tensor(data, size, tensor_options(**options))
@@ -189,7 +353,7 @@ module Torch
189
353
  high = low
190
354
  low = 0
191
355
  end
192
- rand(input.size, like_options(input, options))
356
+ randint(low, high, input.size, like_options(input, options))
193
357
  end
194
358
 
195
359
  def randn_like(input, **options)
@@ -202,26 +366,6 @@ module Torch
202
366
 
203
367
  # --- begin operations ---
204
368
 
205
- %w(add sub mul div remainder).each do |op|
206
- define_method(op) do |input, other, **options|
207
- execute_op(op, input, other, **options)
208
- end
209
- end
210
-
211
- def neg(input)
212
- _neg(input)
213
- end
214
-
215
- def no_grad
216
- previous_value = grad_enabled?
217
- begin
218
- _set_grad_enabled(false)
219
- yield
220
- ensure
221
- _set_grad_enabled(previous_value)
222
- end
223
- end
224
-
225
369
  # TODO support out
226
370
  def mean(input, dim = nil, keepdim: false)
227
371
  if dim
@@ -240,59 +384,30 @@ module Torch
240
384
  end
241
385
  end
242
386
 
243
- def norm(input)
244
- _norm(input)
387
+ def topk(input, k)
388
+ _topk(input, k)
245
389
  end
246
390
 
247
- def pow(input, exponent)
248
- _pow(input, exponent)
249
- end
250
-
251
- def min(input)
252
- _min(input)
253
- end
254
-
255
- def max(input)
256
- _max(input)
257
- end
258
-
259
- def exp(input)
260
- _exp(input)
261
- end
262
-
263
- def log(input)
264
- _log(input)
265
- end
266
-
267
- def unsqueeze(input, dim)
268
- _unsqueeze(input, dim)
391
+ def max(input, dim = nil, keepdim: false, out: nil)
392
+ if dim
393
+ raise NotImplementedYet unless out
394
+ _max_out(out[0], out[1], input, dim, keepdim)
395
+ else
396
+ _max(input)
397
+ end
269
398
  end
270
399
 
271
- def dot(input, tensor)
272
- _dot(input, tensor)
400
+ # TODO make dim keyword argument
401
+ def log_softmax(input, dim)
402
+ _log_softmax(input, dim)
273
403
  end
274
404
 
275
- def matmul(input, other)
276
- _matmul(input, other)
405
+ def softmax(input, dim: nil)
406
+ _softmax(input, dim)
277
407
  end
278
408
 
279
409
  private
280
410
 
281
- def execute_op(op, input, other, out: nil)
282
- scalar = other.is_a?(Numeric)
283
- if out
284
- # TODO make work with scalars
285
- raise Error, "out not supported with scalar yet" if scalar
286
- send("_#{op}_out", out, input, other)
287
- else
288
- if scalar
289
- send("_#{op}_scalar", input, other)
290
- else
291
- send("_#{op}", input, other)
292
- end
293
- end
294
- end
295
-
296
411
  def tensor_size(size)
297
412
  size.flatten
298
413
  end
data/lib/torch/ext.bundle CHANGED
Binary file
@@ -1,41 +1,49 @@
1
1
  module Torch
2
2
  module Inspector
3
+ # TODO make more performance, especially when summarizing
4
+ # how? only read data that will be displayed
3
5
  def inspect
4
6
  data =
5
7
  if numel == 0
6
8
  "[]"
7
9
  elsif dim == 0
8
- to_a.first
10
+ item
9
11
  else
10
- values = to_a.flatten
11
- abs = values.select { |v| v != 0 }.map(&:abs)
12
- max = abs.max || 1
13
- min = abs.min || 1
14
-
15
- total = 0
16
- if values.any? { |v| v < 0 }
17
- total += 1
18
- end
12
+ summarize = numel > 1000
13
+
14
+ if dtype == :bool
15
+ fmt = "%s"
16
+ else
17
+ values = to_a.flatten
18
+ abs = values.select { |v| v != 0 }.map(&:abs)
19
+ max = abs.max || 1
20
+ min = abs.min || 1
19
21
 
20
- if floating_point?
21
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
22
+ total = 0
23
+ if values.any? { |v| v < 0 }
24
+ total += 1
25
+ end
26
+
27
+ if floating_point?
28
+ sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
22
29
 
23
- all_int = values.all? { |v| v == v.to_i }
24
- decimal = all_int ? 1 : 4
30
+ all_int = values.all? { |v| v.finite? && v == v.to_i }
31
+ decimal = all_int ? 1 : 4
25
32
 
26
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
33
+ total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
27
34
 
28
- if sci
29
- fmt = "%#{total}.4e"
35
+ if sci
36
+ fmt = "%#{total}.4e"
37
+ else
38
+ fmt = "%#{total}.#{decimal}f"
39
+ end
30
40
  else
31
- fmt = "%#{total}.#{decimal}f"
41
+ total += max.to_s.size
42
+ fmt = "%#{total}d"
32
43
  end
33
- else
34
- total += max.to_s.size
35
- fmt = "%#{total}d"
36
44
  end
37
45
 
38
- inspect_level(to_a, fmt, dim - 1)
46
+ inspect_level(to_a, fmt, dim - 1, 0, summarize)
39
47
  end
40
48
 
41
49
  attributes = []
@@ -51,11 +59,30 @@ module Torch
51
59
 
52
60
  private
53
61
 
54
- def inspect_level(arr, fmt, total, level = 0)
62
+ # TODO DRY code
63
+ def inspect_level(arr, fmt, total, level, summarize)
55
64
  if level == total
56
- "[#{arr.map { |v| fmt % v }.join(", ")}]"
65
+ cols =
66
+ if summarize && arr.size > 7
67
+ arr[0..2].map { |v| fmt % v } +
68
+ ["..."] +
69
+ arr[-3..-1].map { |v| fmt % v }
70
+ else
71
+ arr.map { |v| fmt % v }
72
+ end
73
+
74
+ "[#{cols.join(", ")}]"
57
75
  else
58
- "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
76
+ rows =
77
+ if summarize && arr.size > 7
78
+ arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
79
+ ["..."] +
80
+ arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
81
+ else
82
+ arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
83
+ end
84
+
85
+ "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
59
86
  end
60
87
  end
61
88
  end