torch-rb 0.1.0 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
data/lib/torch.rb
CHANGED
@@ -1,32 +1,130 @@
|
|
1
1
|
# ext
|
2
2
|
require "torch/ext"
|
3
3
|
|
4
|
+
# native functions
|
5
|
+
require "torch/native/generator"
|
6
|
+
require "torch/native/parser"
|
7
|
+
require "torch/native/dispatcher"
|
8
|
+
|
4
9
|
# modules
|
5
10
|
require "torch/inspector"
|
6
11
|
require "torch/tensor"
|
7
12
|
require "torch/version"
|
8
13
|
|
9
|
-
#
|
14
|
+
# optim
|
15
|
+
require "torch/optim/optimizer"
|
16
|
+
require "torch/optim/adadelta"
|
17
|
+
require "torch/optim/adagrad"
|
18
|
+
require "torch/optim/adam"
|
19
|
+
require "torch/optim/adamax"
|
20
|
+
require "torch/optim/adamw"
|
21
|
+
require "torch/optim/asgd"
|
22
|
+
require "torch/optim/rmsprop"
|
23
|
+
require "torch/optim/rprop"
|
24
|
+
require "torch/optim/sgd"
|
25
|
+
|
26
|
+
# optim lr_scheduler
|
27
|
+
require "torch/optim/lr_scheduler/lr_scheduler"
|
28
|
+
require "torch/optim/lr_scheduler/step_lr"
|
29
|
+
|
30
|
+
# nn parameters
|
31
|
+
require "torch/nn/parameter"
|
32
|
+
|
33
|
+
# nn containers
|
10
34
|
require "torch/nn/module"
|
11
|
-
require "torch/nn/
|
35
|
+
require "torch/nn/sequential"
|
36
|
+
|
37
|
+
# nn convolution layers
|
38
|
+
require "torch/nn/convnd"
|
12
39
|
require "torch/nn/conv2d"
|
13
|
-
|
40
|
+
|
41
|
+
# nn pooling layers
|
42
|
+
require "torch/nn/max_poolnd"
|
43
|
+
require "torch/nn/max_pool2d"
|
44
|
+
require "torch/nn/avg_poolnd"
|
45
|
+
require "torch/nn/avg_pool2d"
|
46
|
+
|
47
|
+
# nn recurrent layers
|
48
|
+
require "torch/nn/rnn_base"
|
49
|
+
require "torch/nn/rnn"
|
50
|
+
|
51
|
+
# nn linear layers
|
52
|
+
require "torch/nn/bilinear"
|
53
|
+
require "torch/nn/identity"
|
14
54
|
require "torch/nn/linear"
|
15
|
-
|
16
|
-
|
55
|
+
|
56
|
+
# nn dropout layers
|
57
|
+
require "torch/nn/dropoutnd"
|
58
|
+
require "torch/nn/alpha_dropout"
|
59
|
+
require "torch/nn/dropout"
|
60
|
+
require "torch/nn/dropout2d"
|
61
|
+
require "torch/nn/dropout3d"
|
62
|
+
require "torch/nn/feature_alpha_dropout"
|
63
|
+
|
64
|
+
# nn activations
|
65
|
+
require "torch/nn/leaky_relu"
|
66
|
+
require "torch/nn/prelu"
|
17
67
|
require "torch/nn/relu"
|
68
|
+
require "torch/nn/sigmoid"
|
69
|
+
require "torch/nn/softplus"
|
70
|
+
|
71
|
+
# nn activations other
|
72
|
+
require "torch/nn/log_softmax"
|
73
|
+
require "torch/nn/softmax"
|
74
|
+
require "torch/nn/softmax2d"
|
75
|
+
require "torch/nn/softmin"
|
76
|
+
|
77
|
+
# nn sparse layers
|
78
|
+
require "torch/nn/embedding"
|
79
|
+
require "torch/nn/embedding_bag"
|
80
|
+
|
81
|
+
# nn distance functions
|
82
|
+
require "torch/nn/cosine_similarity"
|
83
|
+
require "torch/nn/pairwise_distance"
|
84
|
+
|
85
|
+
# nn loss functions
|
86
|
+
require "torch/nn/loss"
|
87
|
+
require "torch/nn/weighted_loss"
|
88
|
+
require "torch/nn/bce_loss"
|
89
|
+
require "torch/nn/bce_with_logits_loss"
|
90
|
+
require "torch/nn/cosine_embedding_loss"
|
91
|
+
require "torch/nn/cross_entropy_loss"
|
92
|
+
require "torch/nn/ctc_loss"
|
93
|
+
require "torch/nn/hinge_embedding_loss"
|
94
|
+
require "torch/nn/kl_div_loss"
|
95
|
+
require "torch/nn/l1_loss"
|
96
|
+
require "torch/nn/margin_ranking_loss"
|
18
97
|
require "torch/nn/mse_loss"
|
98
|
+
require "torch/nn/multi_label_margin_loss"
|
99
|
+
require "torch/nn/multi_label_soft_margin_loss"
|
100
|
+
require "torch/nn/multi_margin_loss"
|
101
|
+
require "torch/nn/nll_loss"
|
102
|
+
require "torch/nn/poisson_nll_loss"
|
103
|
+
require "torch/nn/smooth_l1_loss"
|
104
|
+
require "torch/nn/soft_margin_loss"
|
105
|
+
require "torch/nn/triplet_margin_loss"
|
106
|
+
|
107
|
+
# nn other
|
108
|
+
require "torch/nn/functional"
|
109
|
+
require "torch/nn/init"
|
19
110
|
|
20
111
|
# utils
|
21
112
|
require "torch/utils/data/data_loader"
|
22
113
|
require "torch/utils/data/tensor_dataset"
|
23
114
|
|
115
|
+
# random
|
116
|
+
require "torch/random"
|
117
|
+
|
24
118
|
module Torch
|
25
119
|
class Error < StandardError; end
|
120
|
+
class NotImplementedYet < StandardError
|
121
|
+
def message
|
122
|
+
"This feature has not been implemented yet. Consider submitting a PR."
|
123
|
+
end
|
124
|
+
end
|
26
125
|
|
27
126
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
28
127
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
29
|
-
# complex and quantized types not supported by PyTorch yet
|
30
128
|
DTYPE_TO_ENUM = {
|
31
129
|
uint8: 0,
|
32
130
|
int8: 1,
|
@@ -42,22 +140,61 @@ module Torch
|
|
42
140
|
float32: 6,
|
43
141
|
double: 7,
|
44
142
|
float64: 7,
|
45
|
-
|
46
|
-
|
47
|
-
|
143
|
+
complex_half: 8,
|
144
|
+
complex_float: 9,
|
145
|
+
complex_double: 10,
|
48
146
|
bool: 11,
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
147
|
+
qint8: 12,
|
148
|
+
quint8: 13,
|
149
|
+
qint32: 14,
|
150
|
+
bfloat16: 15
|
53
151
|
}
|
54
152
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
55
153
|
|
154
|
+
def self._make_tensor_class(dtype, cuda = false)
|
155
|
+
cls = Class.new
|
156
|
+
device = cuda ? "cuda" : "cpu"
|
157
|
+
cls.define_singleton_method("new") do |*args|
|
158
|
+
if args.size == 1 && args.first.is_a?(Tensor)
|
159
|
+
args.first.send(dtype).to(device)
|
160
|
+
elsif args.size == 1 && args.first.is_a?(Array)
|
161
|
+
Torch.tensor(args.first, dtype: dtype, device: device)
|
162
|
+
else
|
163
|
+
Torch.empty(*args, dtype: dtype, device: device)
|
164
|
+
end
|
165
|
+
end
|
166
|
+
cls
|
167
|
+
end
|
168
|
+
|
169
|
+
FloatTensor = _make_tensor_class(:float32)
|
170
|
+
DoubleTensor = _make_tensor_class(:float64)
|
171
|
+
HalfTensor = _make_tensor_class(:float16)
|
172
|
+
ByteTensor = _make_tensor_class(:uint8)
|
173
|
+
CharTensor = _make_tensor_class(:int8)
|
174
|
+
ShortTensor = _make_tensor_class(:int16)
|
175
|
+
IntTensor = _make_tensor_class(:int32)
|
176
|
+
LongTensor = _make_tensor_class(:int64)
|
177
|
+
BoolTensor = _make_tensor_class(:bool)
|
178
|
+
|
179
|
+
CUDA::FloatTensor = _make_tensor_class(:float32, true)
|
180
|
+
CUDA::DoubleTensor = _make_tensor_class(:float64, true)
|
181
|
+
CUDA::HalfTensor = _make_tensor_class(:float16, true)
|
182
|
+
CUDA::ByteTensor = _make_tensor_class(:uint8, true)
|
183
|
+
CUDA::CharTensor = _make_tensor_class(:int8, true)
|
184
|
+
CUDA::ShortTensor = _make_tensor_class(:int16, true)
|
185
|
+
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
186
|
+
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
187
|
+
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
188
|
+
|
56
189
|
class << self
|
57
190
|
# Torch.float, Torch.long, etc
|
58
|
-
DTYPE_TO_ENUM.each_key do |
|
59
|
-
define_method(
|
60
|
-
|
191
|
+
DTYPE_TO_ENUM.each_key do |dtype|
|
192
|
+
define_method(dtype) do
|
193
|
+
dtype
|
194
|
+
end
|
195
|
+
|
196
|
+
Tensor.define_method(dtype) do
|
197
|
+
type(dtype)
|
61
198
|
end
|
62
199
|
end
|
63
200
|
|
@@ -67,17 +204,26 @@ module Torch
|
|
67
204
|
obj.is_a?(Tensor)
|
68
205
|
end
|
69
206
|
|
70
|
-
# TODO don't copy
|
71
207
|
def from_numo(ndarray)
|
72
208
|
dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
|
73
209
|
raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
|
74
|
-
|
210
|
+
options = tensor_options(device: "cpu", dtype: dtype[0])
|
211
|
+
# TODO pass pointer to array instead of creating string
|
212
|
+
str = ndarray.to_string
|
213
|
+
tensor = _from_blob(str, ndarray.shape, options)
|
214
|
+
# from_blob does not own the data, so we need to keep
|
215
|
+
# a reference to it for duration of tensor
|
216
|
+
# can remove when passing pointer directly
|
217
|
+
tensor.instance_variable_set("@_numo_str", str)
|
218
|
+
tensor
|
75
219
|
end
|
76
220
|
|
77
221
|
# private
|
78
222
|
# use method for cases when Numo not available
|
79
223
|
# or available after Torch loaded
|
80
224
|
def _dtype_to_numo
|
225
|
+
raise Error, "Numo not found" unless defined?(Numo::NArray)
|
226
|
+
|
81
227
|
{
|
82
228
|
uint8: Numo::UInt8,
|
83
229
|
int8: Numo::Int8,
|
@@ -89,6 +235,20 @@ module Torch
|
|
89
235
|
}
|
90
236
|
end
|
91
237
|
|
238
|
+
def no_grad
|
239
|
+
previous_value = grad_enabled?
|
240
|
+
begin
|
241
|
+
_set_grad_enabled(false)
|
242
|
+
yield
|
243
|
+
ensure
|
244
|
+
_set_grad_enabled(previous_value)
|
245
|
+
end
|
246
|
+
end
|
247
|
+
|
248
|
+
def device(str)
|
249
|
+
Device.new(str)
|
250
|
+
end
|
251
|
+
|
92
252
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
93
253
|
|
94
254
|
def arange(start, finish = nil, step = 1, **options)
|
@@ -158,8 +318,12 @@ module Torch
|
|
158
318
|
data = [data].compact
|
159
319
|
end
|
160
320
|
|
161
|
-
if options[:dtype].nil?
|
162
|
-
|
321
|
+
if options[:dtype].nil?
|
322
|
+
if data.all? { |v| v.is_a?(Integer) }
|
323
|
+
options[:dtype] = :int64
|
324
|
+
elsif data.all? { |v| v == true || v == false }
|
325
|
+
options[:dtype] = :bool
|
326
|
+
end
|
163
327
|
end
|
164
328
|
|
165
329
|
_tensor(data, size, tensor_options(**options))
|
@@ -189,7 +353,7 @@ module Torch
|
|
189
353
|
high = low
|
190
354
|
low = 0
|
191
355
|
end
|
192
|
-
|
356
|
+
randint(low, high, input.size, like_options(input, options))
|
193
357
|
end
|
194
358
|
|
195
359
|
def randn_like(input, **options)
|
@@ -202,26 +366,6 @@ module Torch
|
|
202
366
|
|
203
367
|
# --- begin operations ---
|
204
368
|
|
205
|
-
%w(add sub mul div remainder).each do |op|
|
206
|
-
define_method(op) do |input, other, **options|
|
207
|
-
execute_op(op, input, other, **options)
|
208
|
-
end
|
209
|
-
end
|
210
|
-
|
211
|
-
def neg(input)
|
212
|
-
_neg(input)
|
213
|
-
end
|
214
|
-
|
215
|
-
def no_grad
|
216
|
-
previous_value = grad_enabled?
|
217
|
-
begin
|
218
|
-
_set_grad_enabled(false)
|
219
|
-
yield
|
220
|
-
ensure
|
221
|
-
_set_grad_enabled(previous_value)
|
222
|
-
end
|
223
|
-
end
|
224
|
-
|
225
369
|
# TODO support out
|
226
370
|
def mean(input, dim = nil, keepdim: false)
|
227
371
|
if dim
|
@@ -240,59 +384,30 @@ module Torch
|
|
240
384
|
end
|
241
385
|
end
|
242
386
|
|
243
|
-
def
|
244
|
-
|
387
|
+
def topk(input, k)
|
388
|
+
_topk(input, k)
|
245
389
|
end
|
246
390
|
|
247
|
-
def
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
def max(input)
|
256
|
-
_max(input)
|
257
|
-
end
|
258
|
-
|
259
|
-
def exp(input)
|
260
|
-
_exp(input)
|
261
|
-
end
|
262
|
-
|
263
|
-
def log(input)
|
264
|
-
_log(input)
|
265
|
-
end
|
266
|
-
|
267
|
-
def unsqueeze(input, dim)
|
268
|
-
_unsqueeze(input, dim)
|
391
|
+
def max(input, dim = nil, keepdim: false, out: nil)
|
392
|
+
if dim
|
393
|
+
raise NotImplementedYet unless out
|
394
|
+
_max_out(out[0], out[1], input, dim, keepdim)
|
395
|
+
else
|
396
|
+
_max(input)
|
397
|
+
end
|
269
398
|
end
|
270
399
|
|
271
|
-
|
272
|
-
|
400
|
+
# TODO make dim keyword argument
|
401
|
+
def log_softmax(input, dim)
|
402
|
+
_log_softmax(input, dim)
|
273
403
|
end
|
274
404
|
|
275
|
-
def
|
276
|
-
|
405
|
+
def softmax(input, dim: nil)
|
406
|
+
_softmax(input, dim)
|
277
407
|
end
|
278
408
|
|
279
409
|
private
|
280
410
|
|
281
|
-
def execute_op(op, input, other, out: nil)
|
282
|
-
scalar = other.is_a?(Numeric)
|
283
|
-
if out
|
284
|
-
# TODO make work with scalars
|
285
|
-
raise Error, "out not supported with scalar yet" if scalar
|
286
|
-
send("_#{op}_out", out, input, other)
|
287
|
-
else
|
288
|
-
if scalar
|
289
|
-
send("_#{op}_scalar", input, other)
|
290
|
-
else
|
291
|
-
send("_#{op}", input, other)
|
292
|
-
end
|
293
|
-
end
|
294
|
-
end
|
295
|
-
|
296
411
|
def tensor_size(size)
|
297
412
|
size.flatten
|
298
413
|
end
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,41 +1,49 @@
|
|
1
1
|
module Torch
|
2
2
|
module Inspector
|
3
|
+
# TODO make more performance, especially when summarizing
|
4
|
+
# how? only read data that will be displayed
|
3
5
|
def inspect
|
4
6
|
data =
|
5
7
|
if numel == 0
|
6
8
|
"[]"
|
7
9
|
elsif dim == 0
|
8
|
-
|
10
|
+
item
|
9
11
|
else
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
12
|
+
summarize = numel > 1000
|
13
|
+
|
14
|
+
if dtype == :bool
|
15
|
+
fmt = "%s"
|
16
|
+
else
|
17
|
+
values = to_a.flatten
|
18
|
+
abs = values.select { |v| v != 0 }.map(&:abs)
|
19
|
+
max = abs.max || 1
|
20
|
+
min = abs.min || 1
|
19
21
|
|
20
|
-
|
21
|
-
|
22
|
+
total = 0
|
23
|
+
if values.any? { |v| v < 0 }
|
24
|
+
total += 1
|
25
|
+
end
|
26
|
+
|
27
|
+
if floating_point?
|
28
|
+
sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
|
22
29
|
|
23
|
-
|
24
|
-
|
30
|
+
all_int = values.all? { |v| v.finite? && v == v.to_i }
|
31
|
+
decimal = all_int ? 1 : 4
|
25
32
|
|
26
|
-
|
33
|
+
total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
|
27
34
|
|
28
|
-
|
29
|
-
|
35
|
+
if sci
|
36
|
+
fmt = "%#{total}.4e"
|
37
|
+
else
|
38
|
+
fmt = "%#{total}.#{decimal}f"
|
39
|
+
end
|
30
40
|
else
|
31
|
-
|
41
|
+
total += max.to_s.size
|
42
|
+
fmt = "%#{total}d"
|
32
43
|
end
|
33
|
-
else
|
34
|
-
total += max.to_s.size
|
35
|
-
fmt = "%#{total}d"
|
36
44
|
end
|
37
45
|
|
38
|
-
inspect_level(to_a, fmt, dim - 1)
|
46
|
+
inspect_level(to_a, fmt, dim - 1, 0, summarize)
|
39
47
|
end
|
40
48
|
|
41
49
|
attributes = []
|
@@ -51,11 +59,30 @@ module Torch
|
|
51
59
|
|
52
60
|
private
|
53
61
|
|
54
|
-
|
62
|
+
# TODO DRY code
|
63
|
+
def inspect_level(arr, fmt, total, level, summarize)
|
55
64
|
if level == total
|
56
|
-
|
65
|
+
cols =
|
66
|
+
if summarize && arr.size > 7
|
67
|
+
arr[0..2].map { |v| fmt % v } +
|
68
|
+
["..."] +
|
69
|
+
arr[-3..-1].map { |v| fmt % v }
|
70
|
+
else
|
71
|
+
arr.map { |v| fmt % v }
|
72
|
+
end
|
73
|
+
|
74
|
+
"[#{cols.join(", ")}]"
|
57
75
|
else
|
58
|
-
|
76
|
+
rows =
|
77
|
+
if summarize && arr.size > 7
|
78
|
+
arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
|
79
|
+
["..."] +
|
80
|
+
arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
|
81
|
+
else
|
82
|
+
arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
|
83
|
+
end
|
84
|
+
|
85
|
+
"[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
|
59
86
|
end
|
60
87
|
end
|
61
88
|
end
|