torch-rb 0.1.1 → 0.1.6

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3
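
Most of this release is new surface area: Torch::Optim optimizers with an LR scheduler, a large set of Torch::NN layers and loss functions, Torch::Native code generation from native_functions.yaml, and dtype-specific tensor classes. As a rough orientation only (this sketch is not part of the diff; constructor arguments such as lr: follow the PyTorch-style API and should be treated as assumptions), a small training loop built from the newly added classes might look like:

    require "torch"

    # hypothetical two-layer regression network using classes touched in this release
    class Net < Torch::NN::Module
      def initialize
        super()
        @fc1 = Torch::NN::Linear.new(10, 32)
        @relu = Torch::NN::ReLU.new        # activation modules are new in this release
        @fc2 = Torch::NN::Linear.new(32, 1)
      end

      def forward(x)
        @fc2.call(@relu.call(@fc1.call(x)))
      end
    end

    net = Net.new
    criterion = Torch::NN::MSELoss.new
    optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)  # optimizers are new

    x = Torch.rand([8, 10])
    y = Torch.rand([8, 1])

    5.times do
      optimizer.zero_grad
      loss = criterion.call(net.call(x), y)
      loss.backward
      optimizer.step
    end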

data/ext/torch/torch_functions.hpp
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
data/lib/torch.rb
@@ -1,32 +1,181 @@
  # ext
  require "torch/ext"

+ # native functions
+ require "torch/native/generator"
+ require "torch/native/parser"
+ require "torch/native/dispatcher"
+
  # modules
  require "torch/inspector"
  require "torch/tensor"
  require "torch/version"

- # nn
+ # optim
+ require "torch/optim/optimizer"
+ require "torch/optim/adadelta"
+ require "torch/optim/adagrad"
+ require "torch/optim/adam"
+ require "torch/optim/adamax"
+ require "torch/optim/adamw"
+ require "torch/optim/asgd"
+ require "torch/optim/rmsprop"
+ require "torch/optim/rprop"
+ require "torch/optim/sgd"
+
+ # optim lr_scheduler
+ require "torch/optim/lr_scheduler/lr_scheduler"
+ require "torch/optim/lr_scheduler/step_lr"
+
+ # nn parameters
+ require "torch/nn/parameter"
+ require "torch/nn/utils"
+
+ # nn containers
  require "torch/nn/module"
- require "torch/nn/init"
+ require "torch/nn/sequential"
+
+ # nn convolution layers
+ require "torch/nn/convnd"
+ require "torch/nn/conv1d"
  require "torch/nn/conv2d"
- require "torch/nn/functional"
+ require "torch/nn/conv3d"
+ require "torch/nn/unfold"
+ require "torch/nn/fold"
+
+ # nn pooling layers
+ require "torch/nn/max_poolnd"
+ require "torch/nn/max_pool1d"
+ require "torch/nn/max_pool2d"
+ require "torch/nn/max_pool3d"
+ require "torch/nn/max_unpoolnd"
+ require "torch/nn/max_unpool1d"
+ require "torch/nn/max_unpool2d"
+ require "torch/nn/max_unpool3d"
+ require "torch/nn/avg_poolnd"
+ require "torch/nn/avg_pool1d"
+ require "torch/nn/avg_pool2d"
+ require "torch/nn/avg_pool3d"
+ require "torch/nn/lp_poolnd"
+ require "torch/nn/lp_pool1d"
+ require "torch/nn/lp_pool2d"
+
+ # nn padding layers
+ require "torch/nn/reflection_padnd"
+ require "torch/nn/reflection_pad1d"
+ require "torch/nn/reflection_pad2d"
+ require "torch/nn/replication_padnd"
+ require "torch/nn/replication_pad1d"
+ require "torch/nn/replication_pad2d"
+ require "torch/nn/replication_pad3d"
+ require "torch/nn/constant_padnd"
+ require "torch/nn/constant_pad1d"
+ require "torch/nn/constant_pad2d"
+ require "torch/nn/constant_pad3d"
+ require "torch/nn/zero_pad2d"
+
+ # nn normalization layers
+ require "torch/nn/batch_norm"
+ require "torch/nn/batch_norm1d"
+ require "torch/nn/batch_norm2d"
+ require "torch/nn/batch_norm3d"
+ require "torch/nn/group_norm"
+ require "torch/nn/instance_norm"
+ require "torch/nn/instance_norm1d"
+ require "torch/nn/instance_norm2d"
+ require "torch/nn/instance_norm3d"
+ require "torch/nn/layer_norm"
+ require "torch/nn/local_response_norm"
+
+ # nn recurrent layers
+ require "torch/nn/rnn_base"
+ require "torch/nn/rnn"
+ require "torch/nn/lstm"
+ require "torch/nn/gru"
+
+ # nn linear layers
+ require "torch/nn/bilinear"
+ require "torch/nn/identity"
  require "torch/nn/linear"
- require "torch/nn/parameter"
- require "torch/nn/sequential"
+
+ # nn dropout layers
+ require "torch/nn/dropoutnd"
+ require "torch/nn/alpha_dropout"
+ require "torch/nn/dropout"
+ require "torch/nn/dropout2d"
+ require "torch/nn/dropout3d"
+ require "torch/nn/feature_alpha_dropout"
+
+ # nn activations
+ require "torch/nn/hardshrink"
+ require "torch/nn/leaky_relu"
+ require "torch/nn/log_sigmoid"
+ require "torch/nn/prelu"
  require "torch/nn/relu"
+ require "torch/nn/sigmoid"
+ require "torch/nn/softplus"
+ require "torch/nn/softshrink"
+ require "torch/nn/softsign"
+ require "torch/nn/tanh"
+ require "torch/nn/tanhshrink"
+
+ # nn activations other
+ require "torch/nn/log_softmax"
+ require "torch/nn/softmax"
+ require "torch/nn/softmax2d"
+ require "torch/nn/softmin"
+
+ # nn sparse layers
+ require "torch/nn/embedding"
+ require "torch/nn/embedding_bag"
+
+ # nn distance functions
+ require "torch/nn/cosine_similarity"
+ require "torch/nn/pairwise_distance"
+
+ # nn loss functions
+ require "torch/nn/loss"
+ require "torch/nn/weighted_loss"
+ require "torch/nn/bce_loss"
+ require "torch/nn/bce_with_logits_loss"
+ require "torch/nn/cosine_embedding_loss"
+ require "torch/nn/cross_entropy_loss"
+ require "torch/nn/ctc_loss"
+ require "torch/nn/hinge_embedding_loss"
+ require "torch/nn/kl_div_loss"
+ require "torch/nn/l1_loss"
+ require "torch/nn/margin_ranking_loss"
  require "torch/nn/mse_loss"
+ require "torch/nn/multi_label_margin_loss"
+ require "torch/nn/multi_label_soft_margin_loss"
+ require "torch/nn/multi_margin_loss"
+ require "torch/nn/nll_loss"
+ require "torch/nn/poisson_nll_loss"
+ require "torch/nn/smooth_l1_loss"
+ require "torch/nn/soft_margin_loss"
+ require "torch/nn/triplet_margin_loss"
+
+ # nn other
+ require "torch/nn/functional"
+ require "torch/nn/init"

  # utils
  require "torch/utils/data/data_loader"
  require "torch/utils/data/tensor_dataset"

+ # random
+ require "torch/random"
+
  module Torch
  class Error < StandardError; end
+ class NotImplementedYet < StandardError
+ def message
+ "This feature has not been implemented yet. Consider submitting a PR."
+ end
+ end

  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
- # complex and quantized types not supported by PyTorch yet
  DTYPE_TO_ENUM = {
  uint8: 0,
  int8: 1,
@@ -42,22 +191,61 @@ module Torch
  float32: 6,
  double: 7,
  float64: 7,
- # complex_half: 8,
- # complex_float: 9,
- # complex_double: 10,
+ complex_half: 8,
+ complex_float: 9,
+ complex_double: 10,
  bool: 11,
- # qint8: 12,
- # quint8: 13,
- # qint32: 14,
- # bfloat16: 15
+ qint8: 12,
+ quint8: 13,
+ qint32: 14,
+ bfloat16: 15
  }
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

+ def self._make_tensor_class(dtype, cuda = false)
+ cls = Class.new
+ device = cuda ? "cuda" : "cpu"
+ cls.define_singleton_method("new") do |*args|
+ if args.size == 1 && args.first.is_a?(Tensor)
+ args.first.send(dtype).to(device)
+ elsif args.size == 1 && args.first.is_a?(Array)
+ Torch.tensor(args.first, dtype: dtype, device: device)
+ else
+ Torch.empty(*args, dtype: dtype, device: device)
+ end
+ end
+ cls
+ end
+
+ FloatTensor = _make_tensor_class(:float32)
+ DoubleTensor = _make_tensor_class(:float64)
+ HalfTensor = _make_tensor_class(:float16)
+ ByteTensor = _make_tensor_class(:uint8)
+ CharTensor = _make_tensor_class(:int8)
+ ShortTensor = _make_tensor_class(:int16)
+ IntTensor = _make_tensor_class(:int32)
+ LongTensor = _make_tensor_class(:int64)
+ BoolTensor = _make_tensor_class(:bool)
+
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
+
  class << self
  # Torch.float, Torch.long, etc
- DTYPE_TO_ENUM.each_key do |type|
- define_method(type) do
- type
+ DTYPE_TO_ENUM.each_key do |dtype|
+ define_method(dtype) do
+ dtype
+ end
+
+ Tensor.define_method(dtype) do
+ type(dtype)
  end
  end
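
The _make_tensor_class factory above gives each dtype-specific class (Torch::FloatTensor, Torch::CUDA::LongTensor, and so on) three construction paths: cast an existing tensor, build from a Ruby array, or allocate an uninitialized tensor from sizes. A minimal sketch of what that implies for callers (not taken from the diff):

    x = Torch.tensor([1, 2, 3])               # int64 tensor
    Torch::FloatTensor.new(x)                  # casts the existing tensor to float32
    Torch::FloatTensor.new([1.0, 2.0, 3.0])    # wraps a Ruby array via Torch.tensor
    Torch::FloatTensor.new(2, 3)               # uninitialized 2x3 tensor via Torch.empty
    # Torch::CUDA::FloatTensor.new(...) does the same on the "cuda" device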
@@ -67,17 +255,26 @@ module Torch
  obj.is_a?(Tensor)
  end

- # TODO don't copy
  def from_numo(ndarray)
  dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
- tensor(ndarray.to_a, dtype: dtype[0])
+ options = tensor_options(device: "cpu", dtype: dtype[0])
+ # TODO pass pointer to array instead of creating string
+ str = ndarray.to_string
+ tensor = _from_blob(str, ndarray.shape, options)
+ # from_blob does not own the data, so we need to keep
+ # a reference to it for duration of tensor
+ # can remove when passing pointer directly
+ tensor.instance_variable_set("@_numo_str", str)
+ tensor
  end

  # private
  # use method for cases when Numo not available
  # or available after Torch loaded
  def _dtype_to_numo
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
+
  {
  uint8: Numo::UInt8,
  int8: Numo::Int8,
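
With this change, from_numo builds the tensor with _from_blob from the Numo array's packed bytes instead of round-tripping through to_a, and it stashes the intermediate string on the tensor so the blob outlives the call. Usage is unchanged; a small sketch (the Numo::SFloat to :float32 mapping is assumed from the rest of the dtype table, which is truncated in this hunk):

    require "torch"
    require "numo/narray"

    ndarray = Numo::SFloat.new(2, 3).seq   # 2x3 array with values 0.0..5.0
    t = Torch.from_numo(ndarray)           # float32 tensor backed by a copy of the bytes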
@@ -89,6 +286,20 @@ module Torch
  }
  end

+ def no_grad
+ previous_value = grad_enabled?
+ begin
+ _set_grad_enabled(false)
+ yield
+ ensure
+ _set_grad_enabled(previous_value)
+ end
+ end
+
+ def device(str)
+ Device.new(str)
+ end
+
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

  def arange(start, finish = nil, step = 1, **options)
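
no_grad and device are now standalone helpers (the old copy of no_grad inside the hand-written operations block is deleted further down). no_grad disables gradient tracking for the duration of the block and restores the previous setting even if the block raises. A minimal sketch:

    w = Torch.ones(2, 2, requires_grad: true)

    Torch.no_grad do
      # work done here is not recorded by autograd
      w2 = w * 2
    end

    cpu = Torch.device("cpu")   # simply wraps the string in a Device object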
@@ -158,8 +369,12 @@
  data = [data].compact
  end

- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
- options[:dtype] = :int64
+ if options[:dtype].nil?
+ if data.all? { |v| v.is_a?(Integer) }
+ options[:dtype] = :int64
+ elsif data.all? { |v| v == true || v == false }
+ options[:dtype] = :bool
+ end
  end

  _tensor(data, size, tensor_options(**options))
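
The branch above extends dtype inference in Torch.tensor: an all-Integer array still defaults to :int64, and an all-boolean array now defaults to :bool. Assuming the dtype reader returns the symbol (that accessor is not shown in this diff), this plays out as:

    Torch.tensor([1, 2, 3]).dtype       # => :int64
    Torch.tensor([true, false]).dtype   # => :bool
    Torch.tensor([1.5, 2.5]).dtype      # => :float32 (the existing default)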
@@ -189,7 +404,7 @@
  high = low
  low = 0
  end
- rand(input.size, like_options(input, options))
+ randint(low, high, input.size, like_options(input, options))
  end

  def randn_like(input, **options)
@@ -200,99 +415,8 @@ module Torch
  zeros(input.size, like_options(input, options))
  end

- # --- begin operations ---
-
- %w(add sub mul div remainder).each do |op|
- define_method(op) do |input, other, **options|
- execute_op(op, input, other, **options)
- end
- end
-
- def neg(input)
- _neg(input)
- end
-
- def no_grad
- previous_value = grad_enabled?
- begin
- _set_grad_enabled(false)
- yield
- ensure
- _set_grad_enabled(previous_value)
- end
- end
-
- # TODO support out
- def mean(input, dim = nil, keepdim: false)
- if dim
- _mean_dim(input, dim, keepdim)
- else
- _mean(input)
- end
- end
-
- # TODO support dtype
- def sum(input, dim = nil, keepdim: false)
- if dim
- _sum_dim(input, dim, keepdim)
- else
- _sum(input)
- end
- end
-
- def norm(input)
- _norm(input)
- end
-
- def pow(input, exponent)
- _pow(input, exponent)
- end
-
- def min(input)
- _min(input)
- end
-
- def max(input)
- _max(input)
- end
-
- def exp(input)
- _exp(input)
- end
-
- def log(input)
- _log(input)
- end
-
- def unsqueeze(input, dim)
- _unsqueeze(input, dim)
- end
-
- def dot(input, tensor)
- _dot(input, tensor)
- end
-
- def matmul(input, other)
- _matmul(input, other)
- end
-
  private

- def execute_op(op, input, other, out: nil)
- scalar = other.is_a?(Numeric)
- if out
- # TODO make work with scalars
- raise Error, "out not supported with scalar yet" if scalar
- send("_#{op}_out", out, input, other)
- else
- if scalar
- send("_#{op}_scalar", input, other)
- else
- send("_#{op}", input, other)
- end
- end
- end
-
  def tensor_size(size)
  size.flatten
  end
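
These hand-written wrappers (add, sum, mean, matmul, execute_op, and the rest) are removed because the release now generates the Ruby-facing functions from native_functions.yaml via the new torch/native generator, parser, and dispatcher plus the generated *_functions.cpp sources ("generated by rake generate:functions"). Call sites are expected to keep working unchanged, just routed through the generated bindings, for example:

    a = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    b = Torch.tensor([[5.0, 6.0], [7.0, 8.0]])

    Torch.add(a, b)      # elementwise add
    Torch.matmul(a, b)   # matrix product
    Torch.sum(a)         # sum of all elements
    Torch.mean(a, 0)     # mean along dimension 0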
data/lib/torch/ext.bundle: Binary file (no diff shown)
data/lib/torch/inspector.rb
@@ -1,41 +1,49 @@
  module Torch
  module Inspector
+ # TODO make more performance, especially when summarizing
+ # how? only read data that will be displayed
  def inspect
  data =
  if numel == 0
  "[]"
  elsif dim == 0
- to_a.first
+ item
  else
- values = to_a.flatten
- abs = values.select { |v| v != 0 }.map(&:abs)
- max = abs.max || 1
- min = abs.min || 1
-
- total = 0
- if values.any? { |v| v < 0 }
- total += 1
- end
+ summarize = numel > 1000
+
+ if dtype == :bool
+ fmt = "%s"
+ else
+ values = to_a.flatten
+ abs = values.select { |v| v != 0 }.map(&:abs)
+ max = abs.max || 1
+ min = abs.min || 1

- if floating_point?
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
+ total = 0
+ if values.any? { |v| v < 0 }
+ total += 1
+ end
+
+ if floating_point?
+ sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4

- all_int = values.all? { |v| v == v.to_i }
- decimal = all_int ? 1 : 4
+ all_int = values.all? { |v| v.finite? && v == v.to_i }
+ decimal = all_int ? 1 : 4

- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+ total += sci ? 10 : decimal + 1 + max.to_i.to_s.size

- if sci
- fmt = "%#{total}.4e"
+ if sci
+ fmt = "%#{total}.4e"
+ else
+ fmt = "%#{total}.#{decimal}f"
+ end
  else
- fmt = "%#{total}.#{decimal}f"
+ total += max.to_s.size
+ fmt = "%#{total}d"
  end
- else
- total += max.to_s.size
- fmt = "%#{total}d"
  end

- inspect_level(to_a, fmt, dim - 1)
+ inspect_level(to_a, fmt, dim - 1, 0, summarize)
  end

  attributes = []
@@ -51,11 +59,30 @@ module Torch

  private

- def inspect_level(arr, fmt, total, level = 0)
+ # TODO DRY code
+ def inspect_level(arr, fmt, total, level, summarize)
  if level == total
- "[#{arr.map { |v| fmt % v }.join(", ")}]"
+ cols =
+ if summarize && arr.size > 7
+ arr[0..2].map { |v| fmt % v } +
+ ["..."] +
+ arr[-3..-1].map { |v| fmt % v }
+ else
+ arr.map { |v| fmt % v }
+ end
+
+ "[#{cols.join(", ")}]"
  else
- "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+ rows =
+ if summarize && arr.size > 7
+ arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
+ ["..."] +
+ arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
+ else
+ arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
+ end
+
+ "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
  end
  end
  end
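
Net effect of the inspector changes: tensors with more than 1,000 elements are summarized, printing only the first and last three entries per dimension with "..." in between, and boolean tensors use a plain %s format. A rough illustration (the exact printed wrapper and column widths depend on the formatting code above and are not shown in this hunk):

    t = Torch.arange(0, 2000)   # 2000 elements, so numel > 1000 triggers summarizing
    puts t.inspect              # prints the first and last three values with "..." between them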