torch-rb 0.1.2 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
data/lib/torch.rb CHANGED
@@ -1,6 +1,11 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # native functions
5
+ require "torch/native/generator"
6
+ require "torch/native/parser"
7
+ require "torch/native/dispatcher"
8
+
4
9
  # modules
5
10
  require "torch/inspector"
6
11
  require "torch/tensor"
@@ -8,29 +13,169 @@ require "torch/version"
8
13
 
9
14
  # optim
10
15
  require "torch/optim/optimizer"
16
+ require "torch/optim/adadelta"
17
+ require "torch/optim/adagrad"
18
+ require "torch/optim/adam"
19
+ require "torch/optim/adamax"
20
+ require "torch/optim/adamw"
21
+ require "torch/optim/asgd"
22
+ require "torch/optim/rmsprop"
23
+ require "torch/optim/rprop"
11
24
  require "torch/optim/sgd"
12
25
 
13
- # nn
26
+ # optim lr_scheduler
27
+ require "torch/optim/lr_scheduler/lr_scheduler"
28
+ require "torch/optim/lr_scheduler/step_lr"
29
+
30
+ # nn parameters
31
+ require "torch/nn/parameter"
32
+ require "torch/nn/utils"
33
+
34
+ # nn containers
14
35
  require "torch/nn/module"
15
- require "torch/nn/init"
36
+ require "torch/nn/sequential"
37
+
38
+ # nn convolution layers
39
+ require "torch/nn/convnd"
40
+ require "torch/nn/conv1d"
16
41
  require "torch/nn/conv2d"
17
- require "torch/nn/functional"
42
+ require "torch/nn/conv3d"
43
+ require "torch/nn/unfold"
44
+ require "torch/nn/fold"
45
+
46
+ # nn pooling layers
47
+ require "torch/nn/max_poolnd"
48
+ require "torch/nn/max_pool1d"
49
+ require "torch/nn/max_pool2d"
50
+ require "torch/nn/max_pool3d"
51
+ require "torch/nn/max_unpoolnd"
52
+ require "torch/nn/max_unpool1d"
53
+ require "torch/nn/max_unpool2d"
54
+ require "torch/nn/max_unpool3d"
55
+ require "torch/nn/avg_poolnd"
56
+ require "torch/nn/avg_pool1d"
57
+ require "torch/nn/avg_pool2d"
58
+ require "torch/nn/avg_pool3d"
59
+ require "torch/nn/lp_poolnd"
60
+ require "torch/nn/lp_pool1d"
61
+ require "torch/nn/lp_pool2d"
62
+
63
+ # nn padding layers
64
+ require "torch/nn/reflection_padnd"
65
+ require "torch/nn/reflection_pad1d"
66
+ require "torch/nn/reflection_pad2d"
67
+ require "torch/nn/replication_padnd"
68
+ require "torch/nn/replication_pad1d"
69
+ require "torch/nn/replication_pad2d"
70
+ require "torch/nn/replication_pad3d"
71
+ require "torch/nn/constant_padnd"
72
+ require "torch/nn/constant_pad1d"
73
+ require "torch/nn/constant_pad2d"
74
+ require "torch/nn/constant_pad3d"
75
+ require "torch/nn/zero_pad2d"
76
+
77
+ # nn normalization layers
78
+ require "torch/nn/batch_norm"
79
+ require "torch/nn/batch_norm1d"
80
+ require "torch/nn/batch_norm2d"
81
+ require "torch/nn/batch_norm3d"
82
+ require "torch/nn/group_norm"
83
+ require "torch/nn/instance_norm"
84
+ require "torch/nn/instance_norm1d"
85
+ require "torch/nn/instance_norm2d"
86
+ require "torch/nn/instance_norm3d"
87
+ require "torch/nn/layer_norm"
88
+ require "torch/nn/local_response_norm"
89
+
90
+ # nn recurrent layers
91
+ require "torch/nn/rnn_base"
92
+ require "torch/nn/rnn"
93
+ require "torch/nn/lstm"
94
+ require "torch/nn/gru"
95
+
96
+ # nn linear layers
97
+ require "torch/nn/bilinear"
98
+ require "torch/nn/identity"
18
99
  require "torch/nn/linear"
19
- require "torch/nn/parameter"
20
- require "torch/nn/sequential"
100
+
101
+ # nn dropout layers
102
+ require "torch/nn/dropoutnd"
103
+ require "torch/nn/alpha_dropout"
104
+ require "torch/nn/dropout"
105
+ require "torch/nn/dropout2d"
106
+ require "torch/nn/dropout3d"
107
+ require "torch/nn/feature_alpha_dropout"
108
+
109
+ # nn activations
110
+ require "torch/nn/hardshrink"
111
+ require "torch/nn/leaky_relu"
112
+ require "torch/nn/log_sigmoid"
113
+ require "torch/nn/prelu"
21
114
  require "torch/nn/relu"
115
+ require "torch/nn/sigmoid"
116
+ require "torch/nn/softplus"
117
+ require "torch/nn/softshrink"
118
+ require "torch/nn/softsign"
119
+ require "torch/nn/tanh"
120
+ require "torch/nn/tanhshrink"
121
+
122
+ # nn activations other
123
+ require "torch/nn/log_softmax"
124
+ require "torch/nn/softmax"
125
+ require "torch/nn/softmax2d"
126
+ require "torch/nn/softmin"
127
+
128
+ # nn sparse layers
129
+ require "torch/nn/embedding"
130
+ require "torch/nn/embedding_bag"
131
+
132
+ # nn distance functions
133
+ require "torch/nn/cosine_similarity"
134
+ require "torch/nn/pairwise_distance"
135
+
136
+ # nn loss functions
137
+ require "torch/nn/loss"
138
+ require "torch/nn/weighted_loss"
139
+ require "torch/nn/bce_loss"
140
+ require "torch/nn/bce_with_logits_loss"
141
+ require "torch/nn/cosine_embedding_loss"
142
+ require "torch/nn/cross_entropy_loss"
143
+ require "torch/nn/ctc_loss"
144
+ require "torch/nn/hinge_embedding_loss"
145
+ require "torch/nn/kl_div_loss"
146
+ require "torch/nn/l1_loss"
147
+ require "torch/nn/margin_ranking_loss"
22
148
  require "torch/nn/mse_loss"
149
+ require "torch/nn/multi_label_margin_loss"
150
+ require "torch/nn/multi_label_soft_margin_loss"
151
+ require "torch/nn/multi_margin_loss"
152
+ require "torch/nn/nll_loss"
153
+ require "torch/nn/poisson_nll_loss"
154
+ require "torch/nn/smooth_l1_loss"
155
+ require "torch/nn/soft_margin_loss"
156
+ require "torch/nn/triplet_margin_loss"
157
+
158
+ # nn other
159
+ require "torch/nn/functional"
160
+ require "torch/nn/init"
23
161
 
24
162
  # utils
25
163
  require "torch/utils/data/data_loader"
26
164
  require "torch/utils/data/tensor_dataset"
27
165
 
166
+ # random
167
+ require "torch/random"
168
+
28
169
  module Torch
29
170
  class Error < StandardError; end
171
+ class NotImplementedYet < StandardError
172
+ def message
173
+ "This feature has not been implemented yet. Consider submitting a PR."
174
+ end
175
+ end
30
176
 
31
177
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
32
178
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
33
- # complex and quantized types not supported by PyTorch yet
34
179
  DTYPE_TO_ENUM = {
35
180
  uint8: 0,
36
181
  int8: 1,
@@ -46,17 +191,52 @@ module Torch
46
191
  float32: 6,
47
192
  double: 7,
48
193
  float64: 7,
49
- # complex_half: 8,
50
- # complex_float: 9,
51
- # complex_double: 10,
194
+ complex_half: 8,
195
+ complex_float: 9,
196
+ complex_double: 10,
52
197
  bool: 11,
53
- # qint8: 12,
54
- # quint8: 13,
55
- # qint32: 14,
56
- # bfloat16: 15
198
+ qint8: 12,
199
+ quint8: 13,
200
+ qint32: 14,
201
+ bfloat16: 15
57
202
  }
58
203
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
59
204
 
205
+ def self._make_tensor_class(dtype, cuda = false)
206
+ cls = Class.new
207
+ device = cuda ? "cuda" : "cpu"
208
+ cls.define_singleton_method("new") do |*args|
209
+ if args.size == 1 && args.first.is_a?(Tensor)
210
+ args.first.send(dtype).to(device)
211
+ elsif args.size == 1 && args.first.is_a?(Array)
212
+ Torch.tensor(args.first, dtype: dtype, device: device)
213
+ else
214
+ Torch.empty(*args, dtype: dtype, device: device)
215
+ end
216
+ end
217
+ cls
218
+ end
219
+
220
+ FloatTensor = _make_tensor_class(:float32)
221
+ DoubleTensor = _make_tensor_class(:float64)
222
+ HalfTensor = _make_tensor_class(:float16)
223
+ ByteTensor = _make_tensor_class(:uint8)
224
+ CharTensor = _make_tensor_class(:int8)
225
+ ShortTensor = _make_tensor_class(:int16)
226
+ IntTensor = _make_tensor_class(:int32)
227
+ LongTensor = _make_tensor_class(:int64)
228
+ BoolTensor = _make_tensor_class(:bool)
229
+
230
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
231
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
232
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
233
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
234
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
235
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
236
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
237
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
238
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
239
+
60
240
  class << self
61
241
  # Torch.float, Torch.long, etc
62
242
  DTYPE_TO_ENUM.each_key do |dtype|
@@ -75,17 +255,26 @@ module Torch
75
255
  obj.is_a?(Tensor)
76
256
  end
77
257
 
78
- # TODO don't copy
79
258
  def from_numo(ndarray)
80
259
  dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
81
260
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
82
- tensor(ndarray.to_a, dtype: dtype[0])
261
+ options = tensor_options(device: "cpu", dtype: dtype[0])
262
+ # TODO pass pointer to array instead of creating string
263
+ str = ndarray.to_string
264
+ tensor = _from_blob(str, ndarray.shape, options)
265
+ # from_blob does not own the data, so we need to keep
266
+ # a reference to it for duration of tensor
267
+ # can remove when passing pointer directly
268
+ tensor.instance_variable_set("@_numo_str", str)
269
+ tensor
83
270
  end
84
271
 
85
272
  # private
86
273
  # use method for cases when Numo not available
87
274
  # or available after Torch loaded
88
275
  def _dtype_to_numo
276
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
277
+
89
278
  {
90
279
  uint8: Numo::UInt8,
91
280
  int8: Numo::Int8,
@@ -97,6 +286,29 @@ module Torch
97
286
  }
98
287
  end
99
288
 
289
+ def no_grad
290
+ previous_value = grad_enabled?
291
+ begin
292
+ _set_grad_enabled(false)
293
+ yield
294
+ ensure
295
+ _set_grad_enabled(previous_value)
296
+ end
297
+ end
298
+
299
+ def device(str)
300
+ Device.new(str)
301
+ end
302
+
303
+ def save(obj, f)
304
+ raise NotImplementedYet unless obj.is_a?(Tensor)
305
+ File.binwrite(f, _save(obj))
306
+ end
307
+
308
+ def load(f)
309
+ raise NotImplementedYet
310
+ end
311
+
100
312
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
101
313
 
102
314
  def arange(start, finish = nil, step = 1, **options)
@@ -166,8 +378,12 @@ module Torch
166
378
  data = [data].compact
167
379
  end
168
380
 
169
- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
170
- options[:dtype] = :int64
381
+ if options[:dtype].nil?
382
+ if data.all? { |v| v.is_a?(Integer) }
383
+ options[:dtype] = :int64
384
+ elsif data.all? { |v| v == true || v == false }
385
+ options[:dtype] = :bool
386
+ end
171
387
  end
172
388
 
173
389
  _tensor(data, size, tensor_options(**options))
@@ -176,19 +392,19 @@ module Torch
176
392
  # --- begin like ---
177
393
 
178
394
  def ones_like(input, **options)
179
- ones(input.size, like_options(input, options))
395
+ ones(input.size, **like_options(input, options))
180
396
  end
181
397
 
182
398
  def empty_like(input, **options)
183
- empty(input.size, like_options(input, options))
399
+ empty(input.size, **like_options(input, options))
184
400
  end
185
401
 
186
402
  def full_like(input, fill_value, **options)
187
- full(input.size, fill_value, like_options(input, options))
403
+ full(input.size, fill_value, **like_options(input, options))
188
404
  end
189
405
 
190
406
  def rand_like(input, **options)
191
- rand(input.size, like_options(input, options))
407
+ rand(input.size, **like_options(input, options))
192
408
  end
193
409
 
194
410
  def randint_like(input, low, high = nil, **options)
@@ -197,126 +413,19 @@ module Torch
197
413
  high = low
198
414
  low = 0
199
415
  end
200
- rand(input.size, like_options(input, options))
416
+ randint(low, high, input.size, **like_options(input, options))
201
417
  end
202
418
 
203
419
  def randn_like(input, **options)
204
- randn(input.size, like_options(input, options))
420
+ randn(input.size, **like_options(input, options))
205
421
  end
206
422
 
207
423
  def zeros_like(input, **options)
208
- zeros(input.size, like_options(input, options))
209
- end
210
-
211
- # --- begin operations ---
212
-
213
- %w(add sub mul div remainder).each do |op|
214
- define_method(op) do |input, other, **options|
215
- execute_op(op, input, other, **options)
216
- end
217
- end
218
-
219
- def neg(input)
220
- _neg(input)
221
- end
222
-
223
- def no_grad
224
- previous_value = grad_enabled?
225
- begin
226
- _set_grad_enabled(false)
227
- yield
228
- ensure
229
- _set_grad_enabled(previous_value)
230
- end
231
- end
232
-
233
- # TODO support out
234
- def mean(input, dim = nil, keepdim: false)
235
- if dim
236
- _mean_dim(input, dim, keepdim)
237
- else
238
- _mean(input)
239
- end
240
- end
241
-
242
- # TODO support dtype
243
- def sum(input, dim = nil, keepdim: false)
244
- if dim
245
- _sum_dim(input, dim, keepdim)
246
- else
247
- _sum(input)
248
- end
249
- end
250
-
251
- def argmax(input, dim = nil, keepdim: false)
252
- if dim
253
- _argmax_dim(input, dim, keepdim)
254
- else
255
- _argmax(input)
256
- end
257
- end
258
-
259
- def eq(input, other)
260
- _eq(input, other)
261
- end
262
-
263
- def norm(input)
264
- _norm(input)
265
- end
266
-
267
- def pow(input, exponent)
268
- _pow(input, exponent)
269
- end
270
-
271
- def min(input)
272
- _min(input)
273
- end
274
-
275
- def max(input)
276
- _max(input)
277
- end
278
-
279
- def exp(input)
280
- _exp(input)
281
- end
282
-
283
- def log(input)
284
- _log(input)
285
- end
286
-
287
- def unsqueeze(input, dim)
288
- _unsqueeze(input, dim)
289
- end
290
-
291
- def dot(input, tensor)
292
- _dot(input, tensor)
293
- end
294
-
295
- def matmul(input, other)
296
- _matmul(input, other)
297
- end
298
-
299
- def reshape(input, shape)
300
- _reshape(input, shape)
424
+ zeros(input.size, **like_options(input, options))
301
425
  end
302
426
 
303
427
  private
304
428
 
305
- def execute_op(op, input, other, out: nil)
306
- scalar = other.is_a?(Numeric)
307
- if out
308
- # TODO make work with scalars
309
- raise Error, "out not supported with scalar yet" if scalar
310
- send("_#{op}_out", out, input, other)
311
- else
312
- if scalar
313
- send("_#{op}_scalar", input, other)
314
- else
315
- send("_#{op}", input, other)
316
- end
317
- end
318
- end
319
-
320
429
  def tensor_size(size)
321
430
  size.flatten
322
431
  end