torch-rb 0.2.5 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +32 -2
- data/README.md +4 -1
- data/ext/torch/ext.cpp +23 -6
- data/ext/torch/extconf.rb +3 -4
- data/lib/torch.rb +14 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/nn/module.rb +5 -2
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +46 -15
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
|
4
|
+
data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
|
7
|
+
data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,33 @@
|
|
1
|
+
## 0.3.2 (2020-08-24)
|
2
|
+
|
3
|
+
- Added `enable_grad` method
|
4
|
+
- Added `random_split` method
|
5
|
+
- Added `collate_fn` option to `DataLoader`
|
6
|
+
- Added `grad=` method to `Tensor`
|
7
|
+
- Fixed error with `grad` method when empty
|
8
|
+
- Fixed `EmbeddingBag`
|
9
|
+
|
10
|
+
## 0.3.1 (2020-08-17)
|
11
|
+
|
12
|
+
- Added `create_graph` and `retain_graph` options to `backward` method
|
13
|
+
- Fixed error when `set` not required
|
14
|
+
|
15
|
+
## 0.3.0 (2020-07-29)
|
16
|
+
|
17
|
+
- Updated LibTorch to 1.6.0
|
18
|
+
- Removed `state_dict` method from optimizers until `load_state_dict` is implemented
|
19
|
+
|
20
|
+
## 0.2.7 (2020-06-29)
|
21
|
+
|
22
|
+
- Made tensors enumerable
|
23
|
+
- Improved performance of `inspect` method
|
24
|
+
|
25
|
+
## 0.2.6 (2020-06-29)
|
26
|
+
|
27
|
+
- Added support for indexing with tensors
|
28
|
+
- Added `contiguous` methods
|
29
|
+
- Fixed named parameters for nested parameters
|
30
|
+
|
1
31
|
## 0.2.5 (2020-06-07)
|
2
32
|
|
3
33
|
- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
|
@@ -32,7 +62,7 @@
|
|
32
62
|
## 0.2.0 (2020-04-22)
|
33
63
|
|
34
64
|
- No longer experimental
|
35
|
-
- Updated
|
65
|
+
- Updated LibTorch to 1.5.0
|
36
66
|
- Added support for GPUs and OpenMP
|
37
67
|
- Added adaptive pooling layers
|
38
68
|
- Tensor `dtype` is now based on Numo type for `Torch.tensor`
|
@@ -41,7 +71,7 @@
|
|
41
71
|
|
42
72
|
## 0.1.8 (2020-01-17)
|
43
73
|
|
44
|
-
- Updated
|
74
|
+
- Updated LibTorch to 1.4.0
|
45
75
|
|
46
76
|
## 0.1.7 (2020-01-10)
|
47
77
|
|
data/README.md
CHANGED
@@ -42,6 +42,8 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
42
42
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
43
43
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
44
44
|
|
45
|
+
You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
|
46
|
+
|
45
47
|
## Tutorial
|
46
48
|
|
47
49
|
Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
|
@@ -409,7 +411,8 @@ Here’s the list of compatible versions.
|
|
409
411
|
|
410
412
|
Torch.rb | LibTorch
|
411
413
|
--- | ---
|
412
|
-
0.
|
414
|
+
0.3.0-0.3.1 | 1.6.0
|
415
|
+
0.2.0-0.2.7 | 1.5.0-1.5.1
|
413
416
|
0.1.8 | 1.4.0
|
414
417
|
0.1.0-0.1.7 | 1.3.1
|
415
418
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -48,13 +48,14 @@ void Init_ext()
|
|
48
48
|
.define_singleton_method(
|
49
49
|
"initial_seed",
|
50
50
|
*[]() {
|
51
|
-
return at::detail::getDefaultCPUGenerator()
|
51
|
+
return at::detail::getDefaultCPUGenerator().current_seed();
|
52
52
|
})
|
53
53
|
.define_singleton_method(
|
54
54
|
"seed",
|
55
55
|
*[]() {
|
56
56
|
// TODO set for CUDA when available
|
57
|
-
|
57
|
+
auto generator = at::detail::getDefaultCPUGenerator();
|
58
|
+
return generator.seed();
|
58
59
|
});
|
59
60
|
|
60
61
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
@@ -329,6 +330,11 @@ void Init_ext()
|
|
329
330
|
.define_method("numel", &torch::Tensor::numel)
|
330
331
|
.define_method("element_size", &torch::Tensor::element_size)
|
331
332
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
333
|
+
.define_method(
|
334
|
+
"contiguous?",
|
335
|
+
*[](Tensor& self) {
|
336
|
+
return self.is_contiguous();
|
337
|
+
})
|
332
338
|
.define_method(
|
333
339
|
"addcmul!",
|
334
340
|
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
@@ -346,13 +352,19 @@ void Init_ext()
|
|
346
352
|
})
|
347
353
|
.define_method(
|
348
354
|
"_backward",
|
349
|
-
*[](Tensor& self,
|
350
|
-
return
|
355
|
+
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
356
|
+
return self.backward(gradient, create_graph, retain_graph);
|
351
357
|
})
|
352
358
|
.define_method(
|
353
359
|
"grad",
|
354
360
|
*[](Tensor& self) {
|
355
|
-
|
361
|
+
auto grad = self.grad();
|
362
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
363
|
+
})
|
364
|
+
.define_method(
|
365
|
+
"grad=",
|
366
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
367
|
+
self.grad() = grad;
|
356
368
|
})
|
357
369
|
.define_method(
|
358
370
|
"_dtype",
|
@@ -455,7 +467,7 @@ void Init_ext()
|
|
455
467
|
.define_singleton_method(
|
456
468
|
"_make_subclass",
|
457
469
|
*[](Tensor& rd, bool requires_grad) {
|
458
|
-
auto data =
|
470
|
+
auto data = rd.detach();
|
459
471
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
460
472
|
auto var = data.set_requires_grad(requires_grad);
|
461
473
|
return Parameter(std::move(var));
|
@@ -574,6 +586,11 @@ void Init_ext()
|
|
574
586
|
*[](Parameter& self) {
|
575
587
|
auto grad = self.grad();
|
576
588
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
589
|
+
})
|
590
|
+
.define_method(
|
591
|
+
"grad=",
|
592
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
593
|
+
self.grad() = grad;
|
577
594
|
});
|
578
595
|
|
579
596
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
data/ext/torch/extconf.rb
CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
8
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
|
-
|
11
|
-
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
10
|
+
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
12
11
|
|
13
12
|
# check omp first
|
14
13
|
if have_library("omp") || have_library("gomp")
|
15
14
|
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
-
$CXXFLAGS += " -Xclang" if
|
15
|
+
$CXXFLAGS += " -Xclang" if apple_clang
|
17
16
|
$CXXFLAGS += " -fopenmp"
|
18
17
|
end
|
19
18
|
|
20
|
-
if
|
19
|
+
if apple_clang
|
21
20
|
# silence ruby/intern.h warning
|
22
21
|
$CXXFLAGS += " -Wno-deprecated-register"
|
23
22
|
|
data/lib/torch.rb
CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
|
|
4
4
|
# stdlib
|
5
5
|
require "fileutils"
|
6
6
|
require "net/http"
|
7
|
+
require "set"
|
7
8
|
require "tmpdir"
|
8
9
|
|
9
10
|
# native functions
|
@@ -178,8 +179,10 @@ require "torch/nn/functional"
|
|
178
179
|
require "torch/nn/init"
|
179
180
|
|
180
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
181
183
|
require "torch/utils/data/data_loader"
|
182
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
183
186
|
require "torch/utils/data/tensor_dataset"
|
184
187
|
|
185
188
|
# hub
|
@@ -315,6 +318,16 @@ module Torch
|
|
315
318
|
end
|
316
319
|
end
|
317
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
318
331
|
def device(str)
|
319
332
|
Device.new(str)
|
320
333
|
end
|
@@ -470,11 +483,7 @@ module Torch
|
|
470
483
|
when nil
|
471
484
|
IValue.new
|
472
485
|
when Array
|
473
|
-
|
474
|
-
IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
|
475
|
-
else
|
476
|
-
raise Error, "Unknown list type"
|
477
|
-
end
|
486
|
+
IValue.from_list(obj.map { |v| to_ivalue(v) })
|
478
487
|
else
|
479
488
|
raise Error, "Unknown type: #{obj.class.name}"
|
480
489
|
end
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,89 +1,264 @@
|
|
1
|
+
# mirrors _tensor_str.py
|
1
2
|
module Torch
|
2
3
|
module Inspector
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
4
|
+
PRINT_OPTS = {
|
5
|
+
precision: 4,
|
6
|
+
threshold: 1000,
|
7
|
+
edgeitems: 3,
|
8
|
+
linewidth: 80,
|
9
|
+
sci_mode: nil
|
10
|
+
}
|
11
|
+
|
12
|
+
class Formatter
|
13
|
+
def initialize(tensor)
|
14
|
+
@floating_dtype = tensor.floating_point?
|
15
|
+
@complex_dtype = tensor.complex?
|
16
|
+
@int_mode = true
|
17
|
+
@sci_mode = false
|
18
|
+
@max_width = 1
|
19
|
+
|
20
|
+
tensor_view = Torch.no_grad { tensor.reshape(-1) }
|
21
|
+
|
22
|
+
if !@floating_dtype
|
23
|
+
tensor_view.each do |value|
|
24
|
+
value_str = value.item.to_s
|
25
|
+
@max_width = [@max_width, value_str.length].max
|
26
|
+
end
|
11
27
|
else
|
12
|
-
|
28
|
+
nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
|
29
|
+
|
30
|
+
# no valid number, do nothing
|
31
|
+
return if nonzero_finite_vals.numel == 0
|
32
|
+
|
33
|
+
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
|
34
|
+
nonzero_finite_abs = nonzero_finite_vals.abs.double
|
35
|
+
nonzero_finite_min = nonzero_finite_abs.min.double
|
36
|
+
nonzero_finite_max = nonzero_finite_abs.max.double
|
37
|
+
|
38
|
+
nonzero_finite_vals.each do |value|
|
39
|
+
if value.item != value.item.ceil
|
40
|
+
@int_mode = false
|
41
|
+
break
|
42
|
+
end
|
43
|
+
end
|
13
44
|
|
14
|
-
if
|
15
|
-
|
45
|
+
if @int_mode
|
46
|
+
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
|
47
|
+
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
|
48
|
+
if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
|
49
|
+
@sci_mode = true
|
50
|
+
nonzero_finite_vals.each do |value|
|
51
|
+
value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
|
52
|
+
@max_width = [@max_width, value_str.length].max
|
53
|
+
end
|
54
|
+
else
|
55
|
+
nonzero_finite_vals.each do |value|
|
56
|
+
value_str = "%.0f" % value.item
|
57
|
+
@max_width = [@max_width, value_str.length + 1].max
|
58
|
+
end
|
59
|
+
end
|
16
60
|
else
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
61
|
+
# Check if scientific representation should be used.
|
62
|
+
if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
|
63
|
+
@sci_mode = true
|
64
|
+
nonzero_finite_vals.each do |value|
|
65
|
+
value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
|
66
|
+
@max_width = [@max_width, value_str.length].max
|
67
|
+
end
|
68
|
+
else
|
69
|
+
nonzero_finite_vals.each do |value|
|
70
|
+
value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
|
71
|
+
@max_width = [@max_width, value_str.length].max
|
72
|
+
end
|
25
73
|
end
|
74
|
+
end
|
75
|
+
end
|
26
76
|
|
27
|
-
|
28
|
-
|
77
|
+
@sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
|
78
|
+
end
|
29
79
|
|
30
|
-
|
31
|
-
|
80
|
+
def width
|
81
|
+
@max_width
|
82
|
+
end
|
32
83
|
|
33
|
-
|
84
|
+
def format(value)
|
85
|
+
value = value.item
|
34
86
|
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
fmt = "%#{total}d"
|
87
|
+
if @floating_dtype
|
88
|
+
if @sci_mode
|
89
|
+
ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
|
90
|
+
elsif @int_mode
|
91
|
+
ret = String.new("%.0f" % value)
|
92
|
+
unless value.infinite? || value.nan?
|
93
|
+
ret += "."
|
43
94
|
end
|
95
|
+
else
|
96
|
+
ret = "%.#{PRINT_OPTS[:precision]}f" % value
|
44
97
|
end
|
98
|
+
elsif @complex_dtype
|
99
|
+
p = PRINT_OPTS[:precision]
|
100
|
+
raise NotImplementedYet
|
101
|
+
else
|
102
|
+
ret = value.to_s
|
103
|
+
end
|
104
|
+
# Ruby throws error when negative, Python doesn't
|
105
|
+
" " * [@max_width - ret.size, 0].max + ret
|
106
|
+
end
|
107
|
+
end
|
108
|
+
|
109
|
+
def inspect
|
110
|
+
Torch.no_grad do
|
111
|
+
str_intern(self)
|
112
|
+
end
|
113
|
+
rescue => e
|
114
|
+
# prevent stack error
|
115
|
+
puts e.backtrace.join("\n")
|
116
|
+
"Error inspecting tensor: #{e.inspect}"
|
117
|
+
end
|
118
|
+
|
119
|
+
private
|
120
|
+
|
121
|
+
# TODO update
|
122
|
+
def str_intern(slf)
|
123
|
+
prefix = "tensor("
|
124
|
+
indent = prefix.length
|
125
|
+
suffixes = []
|
126
|
+
|
127
|
+
has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
|
128
|
+
|
129
|
+
if slf.numel == 0 && !slf.sparse?
|
130
|
+
# Explicitly print the shape if it is not (0,), to match NumPy behavior
|
131
|
+
if slf.dim != 1
|
132
|
+
suffixes << "size: #{shape.inspect}"
|
133
|
+
end
|
45
134
|
|
46
|
-
|
135
|
+
# In an empty tensor, there are no elements to infer if the dtype
|
136
|
+
# should be int64, so it must be shown explicitly.
|
137
|
+
if slf.dtype != :int64
|
138
|
+
suffixes << "dtype: #{slf.dtype.inspect}"
|
47
139
|
end
|
140
|
+
tensor_str = "[]"
|
141
|
+
else
|
142
|
+
if !has_default_dtype
|
143
|
+
suffixes << "dtype: #{slf.dtype.inspect}"
|
144
|
+
end
|
145
|
+
|
146
|
+
if slf.layout != :strided
|
147
|
+
tensor_str = tensor_str(slf.to_dense, indent)
|
148
|
+
else
|
149
|
+
tensor_str = tensor_str(slf, indent)
|
150
|
+
end
|
151
|
+
end
|
48
152
|
|
49
|
-
|
50
|
-
|
51
|
-
attributes << "requires_grad: true"
|
153
|
+
if slf.layout != :strided
|
154
|
+
suffixes << "layout: #{slf.layout.inspect}"
|
52
155
|
end
|
53
|
-
|
54
|
-
|
156
|
+
|
157
|
+
# TODO show grad_fn
|
158
|
+
if slf.requires_grad?
|
159
|
+
suffixes << "requires_grad: true"
|
55
160
|
end
|
56
161
|
|
57
|
-
|
162
|
+
add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
|
58
163
|
end
|
59
164
|
|
60
|
-
|
165
|
+
def add_suffixes(tensor_str, suffixes, indent, force_newline)
|
166
|
+
tensor_strs = [tensor_str]
|
167
|
+
# rfind in Python returns -1 when not found
|
168
|
+
last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
|
169
|
+
suffixes.each do |suffix|
|
170
|
+
suffix_len = suffix.length
|
171
|
+
if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
|
172
|
+
tensor_strs << ",\n" + " " * indent + suffix
|
173
|
+
last_line_len = indent + suffix_len
|
174
|
+
force_newline = false
|
175
|
+
else
|
176
|
+
tensor_strs.append(", " + suffix)
|
177
|
+
last_line_len += suffix_len + 2
|
178
|
+
end
|
179
|
+
end
|
180
|
+
tensor_strs.append(")")
|
181
|
+
tensor_strs.join("")
|
182
|
+
end
|
61
183
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
184
|
+
def tensor_str(slf, indent)
|
185
|
+
return "[]" if slf.numel == 0
|
186
|
+
|
187
|
+
summarize = slf.numel > PRINT_OPTS[:threshold]
|
188
|
+
|
189
|
+
if slf.dtype == :float16 || slf.dtype == :bfloat16
|
190
|
+
slf = slf.float
|
191
|
+
end
|
192
|
+
formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
|
193
|
+
tensor_str_with_formatter(slf, indent, formatter, summarize)
|
194
|
+
end
|
195
|
+
|
196
|
+
def summarized_data(slf)
|
197
|
+
edgeitems = PRINT_OPTS[:edgeitems]
|
73
198
|
|
74
|
-
|
199
|
+
dim = slf.dim
|
200
|
+
if dim == 0
|
201
|
+
slf
|
202
|
+
elsif dim == 1
|
203
|
+
if size(0) > 2 * edgeitems
|
204
|
+
Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
|
205
|
+
else
|
206
|
+
slf
|
207
|
+
end
|
208
|
+
elsif slf.size(0) > 2 * edgeitems
|
209
|
+
start = edgeitems.times.map { |i| slf[i] }
|
210
|
+
finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
|
211
|
+
Torch.stack((start + finish).map { |x| summarized_data(x) })
|
75
212
|
else
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
213
|
+
Torch.stack(slf.map { |x| summarized_data(x) })
|
214
|
+
end
|
215
|
+
end
|
216
|
+
|
217
|
+
def tensor_str_with_formatter(slf, indent, formatter, summarize)
|
218
|
+
edgeitems = PRINT_OPTS[:edgeitems]
|
219
|
+
|
220
|
+
dim = slf.dim
|
84
221
|
|
85
|
-
|
222
|
+
return scalar_str(slf, formatter) if dim == 0
|
223
|
+
return vector_str(slf, indent, formatter, summarize) if dim == 1
|
224
|
+
|
225
|
+
if summarize && slf.size(0) > 2 * edgeitems
|
226
|
+
slices = (
|
227
|
+
[edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
|
228
|
+
["..."] +
|
229
|
+
[((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
|
230
|
+
)
|
231
|
+
else
|
232
|
+
slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
|
86
233
|
end
|
234
|
+
|
235
|
+
tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
|
236
|
+
"[" + tensor_str + "]"
|
237
|
+
end
|
238
|
+
|
239
|
+
def scalar_str(slf, formatter)
|
240
|
+
formatter.format(slf)
|
241
|
+
end
|
242
|
+
|
243
|
+
def vector_str(slf, indent, formatter, summarize)
|
244
|
+
# length includes spaces and comma between elements
|
245
|
+
element_length = formatter.width + 2
|
246
|
+
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
247
|
+
char_per_line = element_length * elements_per_line
|
248
|
+
|
249
|
+
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
250
|
+
data = (
|
251
|
+
[slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
|
252
|
+
[" ..."] +
|
253
|
+
[slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
|
254
|
+
)
|
255
|
+
else
|
256
|
+
data = slf.map { |val| formatter.format(val) }
|
257
|
+
end
|
258
|
+
|
259
|
+
data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
|
260
|
+
lines = data_lines.map { |line| line.join(", ") }
|
261
|
+
"[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
|
87
262
|
end
|
88
263
|
end
|
89
264
|
end
|