torch-rb 0.1.2 → 0.1.7

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
data/ext/torch/torch_functions.hpp ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
data/lib/torch.rb CHANGED
@@ -1,6 +1,11 @@
  # ext
  require "torch/ext"

+ # native functions
+ require "torch/native/generator"
+ require "torch/native/parser"
+ require "torch/native/dispatcher"
+
  # modules
  require "torch/inspector"
  require "torch/tensor"
@@ -8,29 +13,169 @@ require "torch/version"

  # optim
  require "torch/optim/optimizer"
+ require "torch/optim/adadelta"
+ require "torch/optim/adagrad"
+ require "torch/optim/adam"
+ require "torch/optim/adamax"
+ require "torch/optim/adamw"
+ require "torch/optim/asgd"
+ require "torch/optim/rmsprop"
+ require "torch/optim/rprop"
  require "torch/optim/sgd"

- # nn
+ # optim lr_scheduler
+ require "torch/optim/lr_scheduler/lr_scheduler"
+ require "torch/optim/lr_scheduler/step_lr"
+
+ # nn parameters
+ require "torch/nn/parameter"
+ require "torch/nn/utils"
+
+ # nn containers
  require "torch/nn/module"
- require "torch/nn/init"
+ require "torch/nn/sequential"
+
+ # nn convolution layers
+ require "torch/nn/convnd"
+ require "torch/nn/conv1d"
  require "torch/nn/conv2d"
- require "torch/nn/functional"
+ require "torch/nn/conv3d"
+ require "torch/nn/unfold"
+ require "torch/nn/fold"
+
+ # nn pooling layers
+ require "torch/nn/max_poolnd"
+ require "torch/nn/max_pool1d"
+ require "torch/nn/max_pool2d"
+ require "torch/nn/max_pool3d"
+ require "torch/nn/max_unpoolnd"
+ require "torch/nn/max_unpool1d"
+ require "torch/nn/max_unpool2d"
+ require "torch/nn/max_unpool3d"
+ require "torch/nn/avg_poolnd"
+ require "torch/nn/avg_pool1d"
+ require "torch/nn/avg_pool2d"
+ require "torch/nn/avg_pool3d"
+ require "torch/nn/lp_poolnd"
+ require "torch/nn/lp_pool1d"
+ require "torch/nn/lp_pool2d"
+
+ # nn padding layers
+ require "torch/nn/reflection_padnd"
+ require "torch/nn/reflection_pad1d"
+ require "torch/nn/reflection_pad2d"
+ require "torch/nn/replication_padnd"
+ require "torch/nn/replication_pad1d"
+ require "torch/nn/replication_pad2d"
+ require "torch/nn/replication_pad3d"
+ require "torch/nn/constant_padnd"
+ require "torch/nn/constant_pad1d"
+ require "torch/nn/constant_pad2d"
+ require "torch/nn/constant_pad3d"
+ require "torch/nn/zero_pad2d"
+
+ # nn normalization layers
+ require "torch/nn/batch_norm"
+ require "torch/nn/batch_norm1d"
+ require "torch/nn/batch_norm2d"
+ require "torch/nn/batch_norm3d"
+ require "torch/nn/group_norm"
+ require "torch/nn/instance_norm"
+ require "torch/nn/instance_norm1d"
+ require "torch/nn/instance_norm2d"
+ require "torch/nn/instance_norm3d"
+ require "torch/nn/layer_norm"
+ require "torch/nn/local_response_norm"
+
+ # nn recurrent layers
+ require "torch/nn/rnn_base"
+ require "torch/nn/rnn"
+ require "torch/nn/lstm"
+ require "torch/nn/gru"
+
+ # nn linear layers
+ require "torch/nn/bilinear"
+ require "torch/nn/identity"
  require "torch/nn/linear"
- require "torch/nn/parameter"
- require "torch/nn/sequential"
+
+ # nn dropout layers
+ require "torch/nn/dropoutnd"
+ require "torch/nn/alpha_dropout"
+ require "torch/nn/dropout"
+ require "torch/nn/dropout2d"
+ require "torch/nn/dropout3d"
+ require "torch/nn/feature_alpha_dropout"
+
+ # nn activations
+ require "torch/nn/hardshrink"
+ require "torch/nn/leaky_relu"
+ require "torch/nn/log_sigmoid"
+ require "torch/nn/prelu"
  require "torch/nn/relu"
+ require "torch/nn/sigmoid"
+ require "torch/nn/softplus"
+ require "torch/nn/softshrink"
+ require "torch/nn/softsign"
+ require "torch/nn/tanh"
+ require "torch/nn/tanhshrink"
+
+ # nn activations other
+ require "torch/nn/log_softmax"
+ require "torch/nn/softmax"
+ require "torch/nn/softmax2d"
+ require "torch/nn/softmin"
+
+ # nn sparse layers
+ require "torch/nn/embedding"
+ require "torch/nn/embedding_bag"
+
+ # nn distance functions
+ require "torch/nn/cosine_similarity"
+ require "torch/nn/pairwise_distance"
+
+ # nn loss functions
+ require "torch/nn/loss"
+ require "torch/nn/weighted_loss"
+ require "torch/nn/bce_loss"
+ require "torch/nn/bce_with_logits_loss"
+ require "torch/nn/cosine_embedding_loss"
+ require "torch/nn/cross_entropy_loss"
+ require "torch/nn/ctc_loss"
+ require "torch/nn/hinge_embedding_loss"
+ require "torch/nn/kl_div_loss"
+ require "torch/nn/l1_loss"
+ require "torch/nn/margin_ranking_loss"
  require "torch/nn/mse_loss"
+ require "torch/nn/multi_label_margin_loss"
+ require "torch/nn/multi_label_soft_margin_loss"
+ require "torch/nn/multi_margin_loss"
+ require "torch/nn/nll_loss"
+ require "torch/nn/poisson_nll_loss"
+ require "torch/nn/smooth_l1_loss"
+ require "torch/nn/soft_margin_loss"
+ require "torch/nn/triplet_margin_loss"
+
+ # nn other
+ require "torch/nn/functional"
+ require "torch/nn/init"

  # utils
  require "torch/utils/data/data_loader"
  require "torch/utils/data/tensor_dataset"

+ # random
+ require "torch/random"
+
  module Torch
  class Error < StandardError; end
+ class NotImplementedYet < StandardError
+ def message
+ "This feature has not been implemented yet. Consider submitting a PR."
+ end
+ end

  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
- # complex and quantized types not supported by PyTorch yet
  DTYPE_TO_ENUM = {
  uint8: 0,
  int8: 1,
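The requires added in the hunk above pull in the new optimizers, LR schedulers, and the expanded set of nn layers. A rough usage sketch of the classes they expose (constructor arguments here are assumptions, not taken from this diff):

```ruby
require "torch"

# Hypothetical example: classes made loadable by the new requires above.
model = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(10, 32),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(32, 1)
)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01)
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)
```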
@@ -46,17 +191,52 @@ module Torch
  float32: 6,
  double: 7,
  float64: 7,
- # complex_half: 8,
- # complex_float: 9,
- # complex_double: 10,
+ complex_half: 8,
+ complex_float: 9,
+ complex_double: 10,
  bool: 11,
- # qint8: 12,
- # quint8: 13,
- # qint32: 14,
- # bfloat16: 15
+ qint8: 12,
+ quint8: 13,
+ qint32: 14,
+ bfloat16: 15
  }
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

+ def self._make_tensor_class(dtype, cuda = false)
+ cls = Class.new
+ device = cuda ? "cuda" : "cpu"
+ cls.define_singleton_method("new") do |*args|
+ if args.size == 1 && args.first.is_a?(Tensor)
+ args.first.send(dtype).to(device)
+ elsif args.size == 1 && args.first.is_a?(Array)
+ Torch.tensor(args.first, dtype: dtype, device: device)
+ else
+ Torch.empty(*args, dtype: dtype, device: device)
+ end
+ end
+ cls
+ end
+
+ FloatTensor = _make_tensor_class(:float32)
+ DoubleTensor = _make_tensor_class(:float64)
+ HalfTensor = _make_tensor_class(:float16)
+ ByteTensor = _make_tensor_class(:uint8)
+ CharTensor = _make_tensor_class(:int8)
+ ShortTensor = _make_tensor_class(:int16)
+ IntTensor = _make_tensor_class(:int32)
+ LongTensor = _make_tensor_class(:int64)
+ BoolTensor = _make_tensor_class(:bool)
+
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
+
  class << self
  # Torch.float, Torch.long, etc
  DTYPE_TO_ENUM.each_key do |dtype|
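Based on the `_make_tensor_class` definition in the hunk above, each new tensor class constant dispatches on its argument; a small sketch of the three branches:

```ruby
# Each line follows the define_singleton_method("new") logic shown above.
a = Torch::FloatTensor.new([1, 2, 3])  # array  -> Torch.tensor([1, 2, 3], dtype: :float32, device: "cpu")
b = Torch::LongTensor.new(2, 3)        # sizes  -> Torch.empty(2, 3, dtype: :int64, device: "cpu")
c = Torch::DoubleTensor.new(a)         # tensor -> a.float64.to("cpu"), per the send(dtype) branch
# CUDA variants behave the same way with device: "cuda", e.g. Torch::CUDA::FloatTensor.new([1, 2, 3])
```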
@@ -75,17 +255,26 @@ module Torch
  obj.is_a?(Tensor)
  end

- # TODO don't copy
  def from_numo(ndarray)
  dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
- tensor(ndarray.to_a, dtype: dtype[0])
+ options = tensor_options(device: "cpu", dtype: dtype[0])
+ # TODO pass pointer to array instead of creating string
+ str = ndarray.to_string
+ tensor = _from_blob(str, ndarray.shape, options)
+ # from_blob does not own the data, so we need to keep
+ # a reference to it for duration of tensor
+ # can remove when passing pointer directly
+ tensor.instance_variable_set("@_numo_str", str)
+ tensor
  end

  # private
  # use method for cases when Numo not available
  # or available after Torch loaded
  def _dtype_to_numo
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
+
  {
  uint8: Numo::UInt8,
  int8: Numo::Int8,
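The reworked `from_numo` above builds the tensor from the Numo array's raw bytes via `_from_blob` instead of round-tripping through `to_a`. An illustrative call (assumes the numo-narray gem is installed and that `Numo::SFloat` maps to `:float32` in the full dtype table, only the first entries of which appear in this hunk):

```ruby
require "numo/narray"
require "torch"

a = Numo::SFloat.new(2, 3).seq  # 2x3 array filled with 0.0..5.0
t = Torch.from_numo(a)          # tensor backed by a's bytes via _from_blob
# The string returned by ndarray.to_string is stashed in @_numo_str so the
# buffer outlives the call, as the comments in the diff above explain.
```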
@@ -97,6 +286,29 @@ module Torch
  }
  end

+ def no_grad
+ previous_value = grad_enabled?
+ begin
+ _set_grad_enabled(false)
+ yield
+ ensure
+ _set_grad_enabled(previous_value)
+ end
+ end
+
+ def device(str)
+ Device.new(str)
+ end
+
+ def save(obj, f)
+ raise NotImplementedYet unless obj.is_a?(Tensor)
+ File.binwrite(f, _save(obj))
+ end
+
+ def load(f)
+ raise NotImplementedYet
+ end
+
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

  def arange(start, finish = nil, step = 1, **options)
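Usage sketch for the helpers added in the hunk above; note that `save` currently handles tensors only and `load` simply raises the new `NotImplementedYet` error:

```ruby
x = Torch.ones(2, 2)

Torch.no_grad do
  # gradient tracking is switched off here and restored by the ensure block
end

cpu = Torch.device("cpu")  # wraps the string in a Device

Torch.save(x, "x.pt")      # tensors only; other objects raise Torch::NotImplementedYet
begin
  Torch.load("x.pt")
rescue Torch::NotImplementedYet => e
  puts e.message           # "This feature has not been implemented yet. Consider submitting a PR."
end
```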
@@ -166,8 +378,12 @@
  data = [data].compact
  end

- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
- options[:dtype] = :int64
+ if options[:dtype].nil?
+ if data.all? { |v| v.is_a?(Integer) }
+ options[:dtype] = :int64
+ elsif data.all? { |v| v == true || v == false }
+ options[:dtype] = :bool
+ end
  end

  _tensor(data, size, tensor_options(**options))
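With the change above, `Torch.tensor` now infers `:bool` for all-boolean data in addition to `:int64` for all-integer data:

```ruby
Torch.tensor([1, 2, 3])       # inferred dtype :int64, as before
Torch.tensor([true, false])   # inferred dtype :bool (new in this change)
Torch.tensor([1.5, 2.5])      # no inference; falls through to the default float dtype
```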
@@ -176,19 +392,19 @@
  # --- begin like ---

  def ones_like(input, **options)
- ones(input.size, like_options(input, options))
+ ones(input.size, **like_options(input, options))
  end

  def empty_like(input, **options)
- empty(input.size, like_options(input, options))
+ empty(input.size, **like_options(input, options))
  end

  def full_like(input, fill_value, **options)
- full(input.size, fill_value, like_options(input, options))
+ full(input.size, fill_value, **like_options(input, options))
  end

  def rand_like(input, **options)
- rand(input.size, like_options(input, options))
+ rand(input.size, **like_options(input, options))
  end

  def randint_like(input, low, high = nil, **options)
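The `**` added above matters because `like_options` returns a hash of keyword options; previously the hash was passed as a positional argument instead of being forwarded as keywords. Illustrative calls:

```ruby
a = Torch.rand(2, 3)
Torch.zeros_like(a)  # a new tensor with a's size, with like_options forwarded as keywords
Torch.ones_like(a)   # same shape as a; options now reach ones correctly
```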
@@ -197,126 +413,19 @@
  high = low
  low = 0
  end
- rand(input.size, like_options(input, options))
+ randint(low, high, input.size, **like_options(input, options))
  end

  def randn_like(input, **options)
- randn(input.size, like_options(input, options))
+ randn(input.size, **like_options(input, options))
  end

  def zeros_like(input, **options)
- zeros(input.size, like_options(input, options))
- end
-
- # --- begin operations ---
-
- %w(add sub mul div remainder).each do |op|
- define_method(op) do |input, other, **options|
- execute_op(op, input, other, **options)
- end
- end
-
- def neg(input)
- _neg(input)
- end
-
- def no_grad
- previous_value = grad_enabled?
- begin
- _set_grad_enabled(false)
- yield
- ensure
- _set_grad_enabled(previous_value)
- end
- end
-
- # TODO support out
- def mean(input, dim = nil, keepdim: false)
- if dim
- _mean_dim(input, dim, keepdim)
- else
- _mean(input)
- end
- end
-
- # TODO support dtype
- def sum(input, dim = nil, keepdim: false)
- if dim
- _sum_dim(input, dim, keepdim)
- else
- _sum(input)
- end
- end
-
- def argmax(input, dim = nil, keepdim: false)
- if dim
- _argmax_dim(input, dim, keepdim)
- else
- _argmax(input)
- end
- end
-
- def eq(input, other)
- _eq(input, other)
- end
-
- def norm(input)
- _norm(input)
- end
-
- def pow(input, exponent)
- _pow(input, exponent)
- end
-
- def min(input)
- _min(input)
- end
-
- def max(input)
- _max(input)
- end
-
- def exp(input)
- _exp(input)
- end
-
- def log(input)
- _log(input)
- end
-
- def unsqueeze(input, dim)
- _unsqueeze(input, dim)
- end
-
- def dot(input, tensor)
- _dot(input, tensor)
- end
-
- def matmul(input, other)
- _matmul(input, other)
- end
-
- def reshape(input, shape)
- _reshape(input, shape)
+ zeros(input.size, **like_options(input, options))
  end

  private

- def execute_op(op, input, other, out: nil)
- scalar = other.is_a?(Numeric)
- if out
- # TODO make work with scalars
- raise Error, "out not supported with scalar yet" if scalar
- send("_#{op}_out", out, input, other)
- else
- if scalar
- send("_#{op}_scalar", input, other)
- else
- send("_#{op}", input, other)
- end
- end
- end
-
  def tensor_size(size)
  size.flatten
  end
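The hand-written wrappers removed in the hunk above (`mean`, `sum`, `argmax`, `matmul`, `execute_op`, and friends) are superseded by the bindings generated from `native/native_functions.yaml` and wired up by the `torch/native/dispatcher` required at the top of this file, so the public API keeps working. Illustrative calls (exact argument handling of the generated methods is an assumption):

```ruby
x = Torch.rand(3, 3)
Torch.sum(x)         # now routed through the generated torch/tensor functions
Torch.matmul(x, x)   # rather than the removed hand-written _matmul wrapper
x.argmax             # tensor-level methods are generated as well (see tensor_functions.cpp in the file list)
```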