torch-rb 0.1.1 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
@@ -1,32 +1,181 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # native functions
5
+ require "torch/native/generator"
6
+ require "torch/native/parser"
7
+ require "torch/native/dispatcher"
8
+
4
9
  # modules
5
10
  require "torch/inspector"
6
11
  require "torch/tensor"
7
12
  require "torch/version"
8
13
 
9
- # nn
14
+ # optim
15
+ require "torch/optim/optimizer"
16
+ require "torch/optim/adadelta"
17
+ require "torch/optim/adagrad"
18
+ require "torch/optim/adam"
19
+ require "torch/optim/adamax"
20
+ require "torch/optim/adamw"
21
+ require "torch/optim/asgd"
22
+ require "torch/optim/rmsprop"
23
+ require "torch/optim/rprop"
24
+ require "torch/optim/sgd"
25
+
26
+ # optim lr_scheduler
27
+ require "torch/optim/lr_scheduler/lr_scheduler"
28
+ require "torch/optim/lr_scheduler/step_lr"
29
+
30
+ # nn parameters
31
+ require "torch/nn/parameter"
32
+ require "torch/nn/utils"
33
+
34
+ # nn containers
10
35
  require "torch/nn/module"
11
- require "torch/nn/init"
36
+ require "torch/nn/sequential"
37
+
38
+ # nn convolution layers
39
+ require "torch/nn/convnd"
40
+ require "torch/nn/conv1d"
12
41
  require "torch/nn/conv2d"
13
- require "torch/nn/functional"
42
+ require "torch/nn/conv3d"
43
+ require "torch/nn/unfold"
44
+ require "torch/nn/fold"
45
+
46
+ # nn pooling layers
47
+ require "torch/nn/max_poolnd"
48
+ require "torch/nn/max_pool1d"
49
+ require "torch/nn/max_pool2d"
50
+ require "torch/nn/max_pool3d"
51
+ require "torch/nn/max_unpoolnd"
52
+ require "torch/nn/max_unpool1d"
53
+ require "torch/nn/max_unpool2d"
54
+ require "torch/nn/max_unpool3d"
55
+ require "torch/nn/avg_poolnd"
56
+ require "torch/nn/avg_pool1d"
57
+ require "torch/nn/avg_pool2d"
58
+ require "torch/nn/avg_pool3d"
59
+ require "torch/nn/lp_poolnd"
60
+ require "torch/nn/lp_pool1d"
61
+ require "torch/nn/lp_pool2d"
62
+
63
+ # nn padding layers
64
+ require "torch/nn/reflection_padnd"
65
+ require "torch/nn/reflection_pad1d"
66
+ require "torch/nn/reflection_pad2d"
67
+ require "torch/nn/replication_padnd"
68
+ require "torch/nn/replication_pad1d"
69
+ require "torch/nn/replication_pad2d"
70
+ require "torch/nn/replication_pad3d"
71
+ require "torch/nn/constant_padnd"
72
+ require "torch/nn/constant_pad1d"
73
+ require "torch/nn/constant_pad2d"
74
+ require "torch/nn/constant_pad3d"
75
+ require "torch/nn/zero_pad2d"
76
+
77
+ # nn normalization layers
78
+ require "torch/nn/batch_norm"
79
+ require "torch/nn/batch_norm1d"
80
+ require "torch/nn/batch_norm2d"
81
+ require "torch/nn/batch_norm3d"
82
+ require "torch/nn/group_norm"
83
+ require "torch/nn/instance_norm"
84
+ require "torch/nn/instance_norm1d"
85
+ require "torch/nn/instance_norm2d"
86
+ require "torch/nn/instance_norm3d"
87
+ require "torch/nn/layer_norm"
88
+ require "torch/nn/local_response_norm"
89
+
90
+ # nn recurrent layers
91
+ require "torch/nn/rnn_base"
92
+ require "torch/nn/rnn"
93
+ require "torch/nn/lstm"
94
+ require "torch/nn/gru"
95
+
96
+ # nn linear layers
97
+ require "torch/nn/bilinear"
98
+ require "torch/nn/identity"
14
99
  require "torch/nn/linear"
15
- require "torch/nn/parameter"
16
- require "torch/nn/sequential"
100
+
101
+ # nn dropout layers
102
+ require "torch/nn/dropoutnd"
103
+ require "torch/nn/alpha_dropout"
104
+ require "torch/nn/dropout"
105
+ require "torch/nn/dropout2d"
106
+ require "torch/nn/dropout3d"
107
+ require "torch/nn/feature_alpha_dropout"
108
+
109
+ # nn activations
110
+ require "torch/nn/hardshrink"
111
+ require "torch/nn/leaky_relu"
112
+ require "torch/nn/log_sigmoid"
113
+ require "torch/nn/prelu"
17
114
  require "torch/nn/relu"
115
+ require "torch/nn/sigmoid"
116
+ require "torch/nn/softplus"
117
+ require "torch/nn/softshrink"
118
+ require "torch/nn/softsign"
119
+ require "torch/nn/tanh"
120
+ require "torch/nn/tanhshrink"
121
+
122
+ # nn activations other
123
+ require "torch/nn/log_softmax"
124
+ require "torch/nn/softmax"
125
+ require "torch/nn/softmax2d"
126
+ require "torch/nn/softmin"
127
+
128
+ # nn sparse layers
129
+ require "torch/nn/embedding"
130
+ require "torch/nn/embedding_bag"
131
+
132
+ # nn distance functions
133
+ require "torch/nn/cosine_similarity"
134
+ require "torch/nn/pairwise_distance"
135
+
136
+ # nn loss functions
137
+ require "torch/nn/loss"
138
+ require "torch/nn/weighted_loss"
139
+ require "torch/nn/bce_loss"
140
+ require "torch/nn/bce_with_logits_loss"
141
+ require "torch/nn/cosine_embedding_loss"
142
+ require "torch/nn/cross_entropy_loss"
143
+ require "torch/nn/ctc_loss"
144
+ require "torch/nn/hinge_embedding_loss"
145
+ require "torch/nn/kl_div_loss"
146
+ require "torch/nn/l1_loss"
147
+ require "torch/nn/margin_ranking_loss"
18
148
  require "torch/nn/mse_loss"
149
+ require "torch/nn/multi_label_margin_loss"
150
+ require "torch/nn/multi_label_soft_margin_loss"
151
+ require "torch/nn/multi_margin_loss"
152
+ require "torch/nn/nll_loss"
153
+ require "torch/nn/poisson_nll_loss"
154
+ require "torch/nn/smooth_l1_loss"
155
+ require "torch/nn/soft_margin_loss"
156
+ require "torch/nn/triplet_margin_loss"
157
+
158
+ # nn other
159
+ require "torch/nn/functional"
160
+ require "torch/nn/init"
19
161
 
20
162
  # utils
21
163
  require "torch/utils/data/data_loader"
22
164
  require "torch/utils/data/tensor_dataset"
23
165
 
166
+ # random
167
+ require "torch/random"
168
+
24
169
  module Torch
25
170
  class Error < StandardError; end
171
+ class NotImplementedYet < StandardError
172
+ def message
173
+ "This feature has not been implemented yet. Consider submitting a PR."
174
+ end
175
+ end
26
176
 
27
177
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
28
178
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
29
- # complex and quantized types not supported by PyTorch yet
30
179
  DTYPE_TO_ENUM = {
31
180
  uint8: 0,
32
181
  int8: 1,
@@ -42,22 +191,61 @@ module Torch
42
191
  float32: 6,
43
192
  double: 7,
44
193
  float64: 7,
45
- # complex_half: 8,
46
- # complex_float: 9,
47
- # complex_double: 10,
194
+ complex_half: 8,
195
+ complex_float: 9,
196
+ complex_double: 10,
48
197
  bool: 11,
49
- # qint8: 12,
50
- # quint8: 13,
51
- # qint32: 14,
52
- # bfloat16: 15
198
+ qint8: 12,
199
+ quint8: 13,
200
+ qint32: 14,
201
+ bfloat16: 15
53
202
  }
54
203
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
55
204
 
205
+ def self._make_tensor_class(dtype, cuda = false)
206
+ cls = Class.new
207
+ device = cuda ? "cuda" : "cpu"
208
+ cls.define_singleton_method("new") do |*args|
209
+ if args.size == 1 && args.first.is_a?(Tensor)
210
+ args.first.send(dtype).to(device)
211
+ elsif args.size == 1 && args.first.is_a?(Array)
212
+ Torch.tensor(args.first, dtype: dtype, device: device)
213
+ else
214
+ Torch.empty(*args, dtype: dtype, device: device)
215
+ end
216
+ end
217
+ cls
218
+ end
219
+
220
+ FloatTensor = _make_tensor_class(:float32)
221
+ DoubleTensor = _make_tensor_class(:float64)
222
+ HalfTensor = _make_tensor_class(:float16)
223
+ ByteTensor = _make_tensor_class(:uint8)
224
+ CharTensor = _make_tensor_class(:int8)
225
+ ShortTensor = _make_tensor_class(:int16)
226
+ IntTensor = _make_tensor_class(:int32)
227
+ LongTensor = _make_tensor_class(:int64)
228
+ BoolTensor = _make_tensor_class(:bool)
229
+
230
+ CUDA::FloatTensor = _make_tensor_class(:float32, true)
231
+ CUDA::DoubleTensor = _make_tensor_class(:float64, true)
232
+ CUDA::HalfTensor = _make_tensor_class(:float16, true)
233
+ CUDA::ByteTensor = _make_tensor_class(:uint8, true)
234
+ CUDA::CharTensor = _make_tensor_class(:int8, true)
235
+ CUDA::ShortTensor = _make_tensor_class(:int16, true)
236
+ CUDA::IntTensor = _make_tensor_class(:int32, true)
237
+ CUDA::LongTensor = _make_tensor_class(:int64, true)
238
+ CUDA::BoolTensor = _make_tensor_class(:bool, true)
239
+
56
240
  class << self
57
241
  # Torch.float, Torch.long, etc
58
- DTYPE_TO_ENUM.each_key do |type|
59
- define_method(type) do
60
- type
242
+ DTYPE_TO_ENUM.each_key do |dtype|
243
+ define_method(dtype) do
244
+ dtype
245
+ end
246
+
247
+ Tensor.define_method(dtype) do
248
+ type(dtype)
61
249
  end
62
250
  end
63
251
 
@@ -67,17 +255,26 @@ module Torch
67
255
  obj.is_a?(Tensor)
68
256
  end
69
257
 
70
- # TODO don't copy
71
258
  def from_numo(ndarray)
72
259
  dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
73
260
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
74
- tensor(ndarray.to_a, dtype: dtype[0])
261
+ options = tensor_options(device: "cpu", dtype: dtype[0])
262
+ # TODO pass pointer to array instead of creating string
263
+ str = ndarray.to_string
264
+ tensor = _from_blob(str, ndarray.shape, options)
265
+ # from_blob does not own the data, so we need to keep
266
+ # a reference to it for duration of tensor
267
+ # can remove when passing pointer directly
268
+ tensor.instance_variable_set("@_numo_str", str)
269
+ tensor
75
270
  end
76
271
 
77
272
  # private
78
273
  # use method for cases when Numo not available
79
274
  # or available after Torch loaded
80
275
  def _dtype_to_numo
276
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
277
+
81
278
  {
82
279
  uint8: Numo::UInt8,
83
280
  int8: Numo::Int8,
@@ -89,6 +286,20 @@ module Torch
89
286
  }
90
287
  end
91
288
 
289
+ def no_grad
290
+ previous_value = grad_enabled?
291
+ begin
292
+ _set_grad_enabled(false)
293
+ yield
294
+ ensure
295
+ _set_grad_enabled(previous_value)
296
+ end
297
+ end
298
+
299
+ def device(str)
300
+ Device.new(str)
301
+ end
302
+
92
303
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
93
304
 
94
305
  def arange(start, finish = nil, step = 1, **options)
@@ -158,8 +369,12 @@ module Torch
158
369
  data = [data].compact
159
370
  end
160
371
 
161
- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
162
- options[:dtype] = :int64
372
+ if options[:dtype].nil?
373
+ if data.all? { |v| v.is_a?(Integer) }
374
+ options[:dtype] = :int64
375
+ elsif data.all? { |v| v == true || v == false }
376
+ options[:dtype] = :bool
377
+ end
163
378
  end
164
379
 
165
380
  _tensor(data, size, tensor_options(**options))
@@ -189,7 +404,7 @@ module Torch
189
404
  high = low
190
405
  low = 0
191
406
  end
192
- rand(input.size, like_options(input, options))
407
+ randint(low, high, input.size, like_options(input, options))
193
408
  end
194
409
 
195
410
  def randn_like(input, **options)
@@ -200,99 +415,8 @@ module Torch
200
415
  zeros(input.size, like_options(input, options))
201
416
  end
202
417
 
203
- # --- begin operations ---
204
-
205
- %w(add sub mul div remainder).each do |op|
206
- define_method(op) do |input, other, **options|
207
- execute_op(op, input, other, **options)
208
- end
209
- end
210
-
211
- def neg(input)
212
- _neg(input)
213
- end
214
-
215
- def no_grad
216
- previous_value = grad_enabled?
217
- begin
218
- _set_grad_enabled(false)
219
- yield
220
- ensure
221
- _set_grad_enabled(previous_value)
222
- end
223
- end
224
-
225
- # TODO support out
226
- def mean(input, dim = nil, keepdim: false)
227
- if dim
228
- _mean_dim(input, dim, keepdim)
229
- else
230
- _mean(input)
231
- end
232
- end
233
-
234
- # TODO support dtype
235
- def sum(input, dim = nil, keepdim: false)
236
- if dim
237
- _sum_dim(input, dim, keepdim)
238
- else
239
- _sum(input)
240
- end
241
- end
242
-
243
- def norm(input)
244
- _norm(input)
245
- end
246
-
247
- def pow(input, exponent)
248
- _pow(input, exponent)
249
- end
250
-
251
- def min(input)
252
- _min(input)
253
- end
254
-
255
- def max(input)
256
- _max(input)
257
- end
258
-
259
- def exp(input)
260
- _exp(input)
261
- end
262
-
263
- def log(input)
264
- _log(input)
265
- end
266
-
267
- def unsqueeze(input, dim)
268
- _unsqueeze(input, dim)
269
- end
270
-
271
- def dot(input, tensor)
272
- _dot(input, tensor)
273
- end
274
-
275
- def matmul(input, other)
276
- _matmul(input, other)
277
- end
278
-
279
418
  private
280
419
 
281
- def execute_op(op, input, other, out: nil)
282
- scalar = other.is_a?(Numeric)
283
- if out
284
- # TODO make work with scalars
285
- raise Error, "out not supported with scalar yet" if scalar
286
- send("_#{op}_out", out, input, other)
287
- else
288
- if scalar
289
- send("_#{op}_scalar", input, other)
290
- else
291
- send("_#{op}", input, other)
292
- end
293
- end
294
- end
295
-
296
420
  def tensor_size(size)
297
421
  size.flatten
298
422
  end
Binary file
@@ -1,41 +1,49 @@
1
1
  module Torch
2
2
  module Inspector
3
+ # TODO make more performance, especially when summarizing
4
+ # how? only read data that will be displayed
3
5
  def inspect
4
6
  data =
5
7
  if numel == 0
6
8
  "[]"
7
9
  elsif dim == 0
8
- to_a.first
10
+ item
9
11
  else
10
- values = to_a.flatten
11
- abs = values.select { |v| v != 0 }.map(&:abs)
12
- max = abs.max || 1
13
- min = abs.min || 1
14
-
15
- total = 0
16
- if values.any? { |v| v < 0 }
17
- total += 1
18
- end
12
+ summarize = numel > 1000
13
+
14
+ if dtype == :bool
15
+ fmt = "%s"
16
+ else
17
+ values = to_a.flatten
18
+ abs = values.select { |v| v != 0 }.map(&:abs)
19
+ max = abs.max || 1
20
+ min = abs.min || 1
19
21
 
20
- if floating_point?
21
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
22
+ total = 0
23
+ if values.any? { |v| v < 0 }
24
+ total += 1
25
+ end
26
+
27
+ if floating_point?
28
+ sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
22
29
 
23
- all_int = values.all? { |v| v == v.to_i }
24
- decimal = all_int ? 1 : 4
30
+ all_int = values.all? { |v| v.finite? && v == v.to_i }
31
+ decimal = all_int ? 1 : 4
25
32
 
26
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
33
+ total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
27
34
 
28
- if sci
29
- fmt = "%#{total}.4e"
35
+ if sci
36
+ fmt = "%#{total}.4e"
37
+ else
38
+ fmt = "%#{total}.#{decimal}f"
39
+ end
30
40
  else
31
- fmt = "%#{total}.#{decimal}f"
41
+ total += max.to_s.size
42
+ fmt = "%#{total}d"
32
43
  end
33
- else
34
- total += max.to_s.size
35
- fmt = "%#{total}d"
36
44
  end
37
45
 
38
- inspect_level(to_a, fmt, dim - 1)
46
+ inspect_level(to_a, fmt, dim - 1, 0, summarize)
39
47
  end
40
48
 
41
49
  attributes = []
@@ -51,11 +59,30 @@ module Torch
51
59
 
52
60
  private
53
61
 
54
- def inspect_level(arr, fmt, total, level = 0)
62
+ # TODO DRY code
63
+ def inspect_level(arr, fmt, total, level, summarize)
55
64
  if level == total
56
- "[#{arr.map { |v| fmt % v }.join(", ")}]"
65
+ cols =
66
+ if summarize && arr.size > 7
67
+ arr[0..2].map { |v| fmt % v } +
68
+ ["..."] +
69
+ arr[-3..-1].map { |v| fmt % v }
70
+ else
71
+ arr.map { |v| fmt % v }
72
+ end
73
+
74
+ "[#{cols.join(", ")}]"
57
75
  else
58
- "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
76
+ rows =
77
+ if summarize && arr.size > 7
78
+ arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
79
+ ["..."] +
80
+ arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
81
+ else
82
+ arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
83
+ end
84
+
85
+ "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
59
86
  end
60
87
  end
61
88
  end