torch-rb 0.2.6 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -2
- data/README.md +9 -2
- data/ext/torch/ext.cpp +49 -7
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +25 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +60 -46
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
|
4
|
+
data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
|
7
|
+
data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,32 @@
|
|
1
|
+
## 0.3.3 (2020-08-25)
|
2
|
+
|
3
|
+
- Added spectral ops
|
4
|
+
- Fixed tensor indexing
|
5
|
+
|
6
|
+
## 0.3.2 (2020-08-24)
|
7
|
+
|
8
|
+
- Added `enable_grad` method
|
9
|
+
- Added `random_split` method
|
10
|
+
- Added `collate_fn` option to `DataLoader`
|
11
|
+
- Added `grad=` method to `Tensor`
|
12
|
+
- Fixed error with `grad` method when empty
|
13
|
+
- Fixed `EmbeddingBag`
|
14
|
+
|
15
|
+
## 0.3.1 (2020-08-17)
|
16
|
+
|
17
|
+
- Added `create_graph` and `retain_graph` options to `backward` method
|
18
|
+
- Fixed error when `set` not required
|
19
|
+
|
20
|
+
## 0.3.0 (2020-07-29)
|
21
|
+
|
22
|
+
- Updated LibTorch to 1.6.0
|
23
|
+
- Removed `state_dict` method from optimizers until `load_state_dict` is implemented
|
24
|
+
|
25
|
+
## 0.2.7 (2020-06-29)
|
26
|
+
|
27
|
+
- Made tensors enumerable
|
28
|
+
- Improved performance of `inspect` method
|
29
|
+
|
1
30
|
## 0.2.6 (2020-06-29)
|
2
31
|
|
3
32
|
- Added support for indexing with tensors
|
@@ -38,7 +67,7 @@
|
|
38
67
|
## 0.2.0 (2020-04-22)
|
39
68
|
|
40
69
|
- No longer experimental
|
41
|
-
- Updated
|
70
|
+
- Updated LibTorch to 1.5.0
|
42
71
|
- Added support for GPUs and OpenMP
|
43
72
|
- Added adaptive pooling layers
|
44
73
|
- Tensor `dtype` is now based on Numo type for `Torch.tensor`
|
@@ -47,7 +76,7 @@
|
|
47
76
|
|
48
77
|
## 0.1.8 (2020-01-17)
|
49
78
|
|
50
|
-
- Updated
|
79
|
+
- Updated LibTorch to 1.4.0
|
51
80
|
|
52
81
|
## 0.1.7 (2020-01-10)
|
53
82
|
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -42,6 +46,8 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
42
46
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
43
47
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
44
48
|
|
49
|
+
You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
|
50
|
+
|
45
51
|
## Tutorial
|
46
52
|
|
47
53
|
Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
|
@@ -409,7 +415,8 @@ Here’s the list of compatible versions.
|
|
409
415
|
|
410
416
|
Torch.rb | LibTorch
|
411
417
|
--- | ---
|
412
|
-
0.
|
418
|
+
0.3.0-0.3.1 | 1.6.0
|
419
|
+
0.2.0-0.2.7 | 1.5.0-1.5.1
|
413
420
|
0.1.8 | 1.4.0
|
414
421
|
0.1.0-0.1.7 | 1.3.1
|
415
422
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -48,15 +58,23 @@ void Init_ext()
|
|
48
58
|
.define_singleton_method(
|
49
59
|
"initial_seed",
|
50
60
|
*[]() {
|
51
|
-
return at::detail::getDefaultCPUGenerator()
|
61
|
+
return at::detail::getDefaultCPUGenerator().current_seed();
|
52
62
|
})
|
53
63
|
.define_singleton_method(
|
54
64
|
"seed",
|
55
65
|
*[]() {
|
56
66
|
// TODO set for CUDA when available
|
57
|
-
|
67
|
+
auto generator = at::detail::getDefaultCPUGenerator();
|
68
|
+
return generator.seed();
|
58
69
|
});
|
59
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
60
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
61
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -329,6 +347,18 @@ void Init_ext()
|
|
329
347
|
.define_method("numel", &torch::Tensor::numel)
|
330
348
|
.define_method("element_size", &torch::Tensor::element_size)
|
331
349
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
350
|
+
.define_method(
|
351
|
+
"_index",
|
352
|
+
*[](Tensor& self, Array indices) {
|
353
|
+
auto vec = index_vector(indices);
|
354
|
+
return self.index(vec);
|
355
|
+
})
|
356
|
+
.define_method(
|
357
|
+
"_index_put_custom",
|
358
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
359
|
+
auto vec = index_vector(indices);
|
360
|
+
return self.index_put_(vec, value);
|
361
|
+
})
|
332
362
|
.define_method(
|
333
363
|
"contiguous?",
|
334
364
|
*[](Tensor& self) {
|
@@ -351,13 +381,19 @@ void Init_ext()
|
|
351
381
|
})
|
352
382
|
.define_method(
|
353
383
|
"_backward",
|
354
|
-
*[](Tensor& self,
|
355
|
-
return
|
384
|
+
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
385
|
+
return self.backward(gradient, create_graph, retain_graph);
|
356
386
|
})
|
357
387
|
.define_method(
|
358
388
|
"grad",
|
359
389
|
*[](Tensor& self) {
|
360
|
-
|
390
|
+
auto grad = self.grad();
|
391
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
392
|
+
})
|
393
|
+
.define_method(
|
394
|
+
"grad=",
|
395
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
396
|
+
self.grad() = grad;
|
361
397
|
})
|
362
398
|
.define_method(
|
363
399
|
"_dtype",
|
@@ -460,7 +496,7 @@ void Init_ext()
|
|
460
496
|
.define_singleton_method(
|
461
497
|
"_make_subclass",
|
462
498
|
*[](Tensor& rd, bool requires_grad) {
|
463
|
-
auto data =
|
499
|
+
auto data = rd.detach();
|
464
500
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
465
501
|
auto var = data.set_requires_grad(requires_grad);
|
466
502
|
return Parameter(std::move(var));
|
@@ -501,6 +537,7 @@ void Init_ext()
|
|
501
537
|
});
|
502
538
|
|
503
539
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
540
|
+
.add_handler<torch::Error>(handle_error)
|
504
541
|
.define_singleton_method(
|
505
542
|
"_calculate_gain",
|
506
543
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -579,11 +616,16 @@ void Init_ext()
|
|
579
616
|
*[](Parameter& self) {
|
580
617
|
auto grad = self.grad();
|
581
618
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
619
|
+
})
|
620
|
+
.define_method(
|
621
|
+
"grad=",
|
622
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
623
|
+
self.grad() = grad;
|
582
624
|
});
|
583
625
|
|
584
626
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
585
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
586
627
|
.add_handler<torch::Error>(handle_error)
|
628
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
587
629
|
.define_method("index", &torch::Device::index)
|
588
630
|
.define_method("index?", &torch::Device::has_index)
|
589
631
|
.define_method(
|
data/ext/torch/extconf.rb
CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
8
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
|
-
|
11
|
-
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
10
|
+
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
12
11
|
|
13
12
|
# check omp first
|
14
13
|
if have_library("omp") || have_library("gomp")
|
15
14
|
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
-
$CXXFLAGS += " -Xclang" if
|
15
|
+
$CXXFLAGS += " -Xclang" if apple_clang
|
17
16
|
$CXXFLAGS += " -fopenmp"
|
18
17
|
end
|
19
18
|
|
20
|
-
if
|
19
|
+
if apple_clang
|
21
20
|
# silence ruby/intern.h warning
|
22
21
|
$CXXFLAGS += " -Wno-deprecated-register"
|
23
22
|
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,10 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::ScalarType;
|
14
|
+
using torch::Tensor;
|
15
|
+
|
12
16
|
// need to wrap torch::IntArrayRef() since
|
13
17
|
// it doesn't own underlying data
|
14
18
|
class IntArrayRef {
|
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
178
|
return MyReduction(x);
|
175
179
|
}
|
176
180
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
181
|
class OptionalTensor {
|
180
182
|
Object value;
|
181
183
|
public:
|
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
199
|
return OptionalTensor(x);
|
198
200
|
}
|
199
201
|
|
200
|
-
class ScalarType {
|
201
|
-
Object value;
|
202
|
-
public:
|
203
|
-
ScalarType(Object o) {
|
204
|
-
value = o;
|
205
|
-
}
|
206
|
-
operator at::ScalarType() {
|
207
|
-
throw std::runtime_error("ScalarType arguments not implemented yet");
|
208
|
-
}
|
209
|
-
};
|
210
|
-
|
211
202
|
template<>
|
212
203
|
inline
|
213
|
-
ScalarType from_ruby<ScalarType
|
204
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
214
205
|
{
|
215
|
-
|
206
|
+
if (x.is_nil()) {
|
207
|
+
return torch::nullopt;
|
208
|
+
} else {
|
209
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
210
|
+
}
|
216
211
|
}
|
217
212
|
|
218
|
-
class OptionalScalarType {
|
219
|
-
Object value;
|
220
|
-
public:
|
221
|
-
OptionalScalarType(Object o) {
|
222
|
-
value = o;
|
223
|
-
}
|
224
|
-
operator c10::optional<at::ScalarType>() {
|
225
|
-
if (value.is_nil()) {
|
226
|
-
return c10::nullopt;
|
227
|
-
}
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
231
|
-
|
232
213
|
template<>
|
233
214
|
inline
|
234
|
-
|
215
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
235
216
|
{
|
236
|
-
|
217
|
+
if (x.is_nil()) {
|
218
|
+
return torch::nullopt;
|
219
|
+
} else {
|
220
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
221
|
+
}
|
237
222
|
}
|
238
223
|
|
239
|
-
typedef torch::Device Device;
|
240
|
-
|
241
224
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
225
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
226
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
|
|
4
4
|
# stdlib
|
5
5
|
require "fileutils"
|
6
6
|
require "net/http"
|
7
|
+
require "set"
|
7
8
|
require "tmpdir"
|
8
9
|
|
9
10
|
# native functions
|
@@ -178,8 +179,10 @@ require "torch/nn/functional"
|
|
178
179
|
require "torch/nn/init"
|
179
180
|
|
180
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
181
183
|
require "torch/utils/data/data_loader"
|
182
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
183
186
|
require "torch/utils/data/tensor_dataset"
|
184
187
|
|
185
188
|
# hub
|
@@ -315,6 +318,16 @@ module Torch
|
|
315
318
|
end
|
316
319
|
end
|
317
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
318
331
|
def device(str)
|
319
332
|
Device.new(str)
|
320
333
|
end
|
@@ -447,6 +460,17 @@ module Torch
|
|
447
460
|
zeros(input.size, **like_options(input, options))
|
448
461
|
end
|
449
462
|
|
463
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
464
|
+
if center
|
465
|
+
signal_dim = input.dim
|
466
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
467
|
+
pad = n_fft.div(2).to_i
|
468
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
469
|
+
input = input.view(input.shape[-signal_dim..-1])
|
470
|
+
end
|
471
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
472
|
+
end
|
473
|
+
|
450
474
|
private
|
451
475
|
|
452
476
|
def to_ivalue(obj)
|
@@ -470,11 +494,7 @@ module Torch
|
|
470
494
|
when nil
|
471
495
|
IValue.new
|
472
496
|
when Array
|
473
|
-
|
474
|
-
IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
|
475
|
-
else
|
476
|
-
raise Error, "Unknown list type"
|
477
|
-
end
|
497
|
+
IValue.from_list(obj.map { |v| to_ivalue(v) })
|
478
498
|
else
|
479
499
|
raise Error, "Unknown type: #{obj.class.name}"
|
480
500
|
end
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,89 +1,264 @@
|
|
1
|
+
# mirrors _tensor_str.py
|
1
2
|
module Torch
|
2
3
|
module Inspector
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
4
|
+
PRINT_OPTS = {
|
5
|
+
precision: 4,
|
6
|
+
threshold: 1000,
|
7
|
+
edgeitems: 3,
|
8
|
+
linewidth: 80,
|
9
|
+
sci_mode: nil
|
10
|
+
}
|
11
|
+
|
12
|
+
class Formatter
|
13
|
+
def initialize(tensor)
|
14
|
+
@floating_dtype = tensor.floating_point?
|
15
|
+
@complex_dtype = tensor.complex?
|
16
|
+
@int_mode = true
|
17
|
+
@sci_mode = false
|
18
|
+
@max_width = 1
|
19
|
+
|
20
|
+
tensor_view = Torch.no_grad { tensor.reshape(-1) }
|
21
|
+
|
22
|
+
if !@floating_dtype
|
23
|
+
tensor_view.each do |value|
|
24
|
+
value_str = value.item.to_s
|
25
|
+
@max_width = [@max_width, value_str.length].max
|
26
|
+
end
|
11
27
|
else
|
12
|
-
|
28
|
+
nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
|
29
|
+
|
30
|
+
# no valid number, do nothing
|
31
|
+
return if nonzero_finite_vals.numel == 0
|
32
|
+
|
33
|
+
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
|
34
|
+
nonzero_finite_abs = nonzero_finite_vals.abs.double
|
35
|
+
nonzero_finite_min = nonzero_finite_abs.min.double
|
36
|
+
nonzero_finite_max = nonzero_finite_abs.max.double
|
37
|
+
|
38
|
+
nonzero_finite_vals.each do |value|
|
39
|
+
if value.item != value.item.ceil
|
40
|
+
@int_mode = false
|
41
|
+
break
|
42
|
+
end
|
43
|
+
end
|
13
44
|
|
14
|
-
if
|
15
|
-
|
45
|
+
if @int_mode
|
46
|
+
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
|
47
|
+
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
|
48
|
+
if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
|
49
|
+
@sci_mode = true
|
50
|
+
nonzero_finite_vals.each do |value|
|
51
|
+
value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
|
52
|
+
@max_width = [@max_width, value_str.length].max
|
53
|
+
end
|
54
|
+
else
|
55
|
+
nonzero_finite_vals.each do |value|
|
56
|
+
value_str = "%.0f" % value.item
|
57
|
+
@max_width = [@max_width, value_str.length + 1].max
|
58
|
+
end
|
59
|
+
end
|
16
60
|
else
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
61
|
+
# Check if scientific representation should be used.
|
62
|
+
if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
|
63
|
+
@sci_mode = true
|
64
|
+
nonzero_finite_vals.each do |value|
|
65
|
+
value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
|
66
|
+
@max_width = [@max_width, value_str.length].max
|
67
|
+
end
|
68
|
+
else
|
69
|
+
nonzero_finite_vals.each do |value|
|
70
|
+
value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
|
71
|
+
@max_width = [@max_width, value_str.length].max
|
72
|
+
end
|
25
73
|
end
|
74
|
+
end
|
75
|
+
end
|
26
76
|
|
27
|
-
|
28
|
-
|
77
|
+
@sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
|
78
|
+
end
|
29
79
|
|
30
|
-
|
31
|
-
|
80
|
+
def width
|
81
|
+
@max_width
|
82
|
+
end
|
32
83
|
|
33
|
-
|
84
|
+
def format(value)
|
85
|
+
value = value.item
|
34
86
|
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
fmt = "%#{total}d"
|
87
|
+
if @floating_dtype
|
88
|
+
if @sci_mode
|
89
|
+
ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
|
90
|
+
elsif @int_mode
|
91
|
+
ret = String.new("%.0f" % value)
|
92
|
+
unless value.infinite? || value.nan?
|
93
|
+
ret += "."
|
43
94
|
end
|
95
|
+
else
|
96
|
+
ret = "%.#{PRINT_OPTS[:precision]}f" % value
|
44
97
|
end
|
98
|
+
elsif @complex_dtype
|
99
|
+
p = PRINT_OPTS[:precision]
|
100
|
+
raise NotImplementedYet
|
101
|
+
else
|
102
|
+
ret = value.to_s
|
103
|
+
end
|
104
|
+
# Ruby throws error when negative, Python doesn't
|
105
|
+
" " * [@max_width - ret.size, 0].max + ret
|
106
|
+
end
|
107
|
+
end
|
108
|
+
|
109
|
+
def inspect
|
110
|
+
Torch.no_grad do
|
111
|
+
str_intern(self)
|
112
|
+
end
|
113
|
+
rescue => e
|
114
|
+
# prevent stack error
|
115
|
+
puts e.backtrace.join("\n")
|
116
|
+
"Error inspecting tensor: #{e.inspect}"
|
117
|
+
end
|
118
|
+
|
119
|
+
private
|
120
|
+
|
121
|
+
# TODO update
|
122
|
+
def str_intern(slf)
|
123
|
+
prefix = "tensor("
|
124
|
+
indent = prefix.length
|
125
|
+
suffixes = []
|
126
|
+
|
127
|
+
has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
|
128
|
+
|
129
|
+
if slf.numel == 0 && !slf.sparse?
|
130
|
+
# Explicitly print the shape if it is not (0,), to match NumPy behavior
|
131
|
+
if slf.dim != 1
|
132
|
+
suffixes << "size: #{shape.inspect}"
|
133
|
+
end
|
45
134
|
|
46
|
-
|
135
|
+
# In an empty tensor, there are no elements to infer if the dtype
|
136
|
+
# should be int64, so it must be shown explicitly.
|
137
|
+
if slf.dtype != :int64
|
138
|
+
suffixes << "dtype: #{slf.dtype.inspect}"
|
47
139
|
end
|
140
|
+
tensor_str = "[]"
|
141
|
+
else
|
142
|
+
if !has_default_dtype
|
143
|
+
suffixes << "dtype: #{slf.dtype.inspect}"
|
144
|
+
end
|
145
|
+
|
146
|
+
if slf.layout != :strided
|
147
|
+
tensor_str = tensor_str(slf.to_dense, indent)
|
148
|
+
else
|
149
|
+
tensor_str = tensor_str(slf, indent)
|
150
|
+
end
|
151
|
+
end
|
48
152
|
|
49
|
-
|
50
|
-
|
51
|
-
attributes << "requires_grad: true"
|
153
|
+
if slf.layout != :strided
|
154
|
+
suffixes << "layout: #{slf.layout.inspect}"
|
52
155
|
end
|
53
|
-
|
54
|
-
|
156
|
+
|
157
|
+
# TODO show grad_fn
|
158
|
+
if slf.requires_grad?
|
159
|
+
suffixes << "requires_grad: true"
|
55
160
|
end
|
56
161
|
|
57
|
-
|
162
|
+
add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
|
58
163
|
end
|
59
164
|
|
60
|
-
|
165
|
+
def add_suffixes(tensor_str, suffixes, indent, force_newline)
|
166
|
+
tensor_strs = [tensor_str]
|
167
|
+
# rfind in Python returns -1 when not found
|
168
|
+
last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
|
169
|
+
suffixes.each do |suffix|
|
170
|
+
suffix_len = suffix.length
|
171
|
+
if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
|
172
|
+
tensor_strs << ",\n" + " " * indent + suffix
|
173
|
+
last_line_len = indent + suffix_len
|
174
|
+
force_newline = false
|
175
|
+
else
|
176
|
+
tensor_strs.append(", " + suffix)
|
177
|
+
last_line_len += suffix_len + 2
|
178
|
+
end
|
179
|
+
end
|
180
|
+
tensor_strs.append(")")
|
181
|
+
tensor_strs.join("")
|
182
|
+
end
|
61
183
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
184
|
+
def tensor_str(slf, indent)
|
185
|
+
return "[]" if slf.numel == 0
|
186
|
+
|
187
|
+
summarize = slf.numel > PRINT_OPTS[:threshold]
|
188
|
+
|
189
|
+
if slf.dtype == :float16 || slf.dtype == :bfloat16
|
190
|
+
slf = slf.float
|
191
|
+
end
|
192
|
+
formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
|
193
|
+
tensor_str_with_formatter(slf, indent, formatter, summarize)
|
194
|
+
end
|
195
|
+
|
196
|
+
def summarized_data(slf)
|
197
|
+
edgeitems = PRINT_OPTS[:edgeitems]
|
73
198
|
|
74
|
-
|
199
|
+
dim = slf.dim
|
200
|
+
if dim == 0
|
201
|
+
slf
|
202
|
+
elsif dim == 1
|
203
|
+
if size(0) > 2 * edgeitems
|
204
|
+
Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
|
205
|
+
else
|
206
|
+
slf
|
207
|
+
end
|
208
|
+
elsif slf.size(0) > 2 * edgeitems
|
209
|
+
start = edgeitems.times.map { |i| slf[i] }
|
210
|
+
finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
|
211
|
+
Torch.stack((start + finish).map { |x| summarized_data(x) })
|
75
212
|
else
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
213
|
+
Torch.stack(slf.map { |x| summarized_data(x) })
|
214
|
+
end
|
215
|
+
end
|
216
|
+
|
217
|
+
def tensor_str_with_formatter(slf, indent, formatter, summarize)
|
218
|
+
edgeitems = PRINT_OPTS[:edgeitems]
|
219
|
+
|
220
|
+
dim = slf.dim
|
84
221
|
|
85
|
-
|
222
|
+
return scalar_str(slf, formatter) if dim == 0
|
223
|
+
return vector_str(slf, indent, formatter, summarize) if dim == 1
|
224
|
+
|
225
|
+
if summarize && slf.size(0) > 2 * edgeitems
|
226
|
+
slices = (
|
227
|
+
[edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
|
228
|
+
["..."] +
|
229
|
+
[((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
|
230
|
+
)
|
231
|
+
else
|
232
|
+
slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
|
86
233
|
end
|
234
|
+
|
235
|
+
tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
|
236
|
+
"[" + tensor_str + "]"
|
237
|
+
end
|
238
|
+
|
239
|
+
def scalar_str(slf, formatter)
|
240
|
+
formatter.format(slf)
|
241
|
+
end
|
242
|
+
|
243
|
+
def vector_str(slf, indent, formatter, summarize)
|
244
|
+
# length includes spaces and comma between elements
|
245
|
+
element_length = formatter.width + 2
|
246
|
+
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
247
|
+
char_per_line = element_length * elements_per_line
|
248
|
+
|
249
|
+
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
250
|
+
data = (
|
251
|
+
[slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
|
252
|
+
[" ..."] +
|
253
|
+
[slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
|
254
|
+
)
|
255
|
+
else
|
256
|
+
data = slf.map { |val| formatter.format(val) }
|
257
|
+
end
|
258
|
+
|
259
|
+
data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
|
260
|
+
lines = data_lines.map { |line| line.join(", ") }
|
261
|
+
"[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
|
87
262
|
end
|
88
263
|
end
|
89
264
|
end
|