torch-rb 0.2.4 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +29 -2
- data/README.md +22 -7
- data/ext/torch/ext.cpp +46 -24
- data/ext/torch/extconf.rb +3 -4
- data/lib/torch.rb +7 -5
- data/lib/torch/hub.rb +48 -4
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/module.rb +5 -2
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +69 -39
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-metadata.gz:
-data.tar.gz:
+metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
+data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
 SHA512:
-metadata.gz:
-data.tar.gz:
+metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
+data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,30 @@
+## 0.3.1 (2020-08-17)
+
+- Added `create_graph` and `retain_graph` options to `backward` method
+- Fixed error when `set` not required
+
+## 0.3.0 (2020-07-29)
+
+- Updated LibTorch to 1.6.0
+- Removed `state_dict` method from optimizers until `load_state_dict` is implemented
+
+## 0.2.7 (2020-06-29)
+
+- Made tensors enumerable
+- Improved performance of `inspect` method
+
+## 0.2.6 (2020-06-29)
+
+- Added support for indexing with tensors
+- Added `contiguous` methods
+- Fixed named parameters for nested parameters
+
+## 0.2.5 (2020-06-07)
+
+- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
+- Improved error messages
+- Fixed tensor slicing
+
 ## 0.2.4 (2020-04-29)

 - Added `to_i` and `to_f` to tensors
@@ -26,7 +53,7 @@
 ## 0.2.0 (2020-04-22)

 - No longer experimental
-- Updated
+- Updated LibTorch to 1.5.0
 - Added support for GPUs and OpenMP
 - Added adaptive pooling layers
 - Tensor `dtype` is now based on Numo type for `Torch.tensor`
@@ -35,7 +62,7 @@

 ## 0.1.8 (2020-01-17)

-- Updated
+- Updated LibTorch to 1.4.0

 ## 0.1.7 (2020-01-10)
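The 0.2.7 entry makes tensors enumerable. The new inspector code further down relies on this: it calls `each` and `map` directly on tensors. The class below is a minimal pure-Ruby sketch of that pattern, not torch-rb's implementation. `ToyTensor` is hypothetical and wraps nested Arrays instead of LibTorch memory. It assumes iteration runs over the first dimension, which is how the inspector code appears to use it.

```ruby
# ToyTensor is a hypothetical stand-in, not torch-rb's Tensor class.
# It mixes in Enumerable and yields one element per index of the first
# dimension: nested Arrays become sub-"tensors", leaves are returned as-is.
class ToyTensor
  include Enumerable

  def initialize(data)
    @data = data
  end

  # Enumerable only needs #each; map, select, count, sum, etc. come for free.
  def each
    return enum_for(:each) unless block_given?

    @data.each do |row|
      yield(row.is_a?(Array) ? ToyTensor.new(row) : row)
    end
    self
  end
end

matrix = ToyTensor.new([[1, 2], [3, 4]])
# each row is itself a ToyTensor, so Enumerable#sum works on it too
row_sums = matrix.map(&:sum)
```

Returning an Enumerator when no block is given keeps external iteration (`matrix.each.next`) working, as it does for Ruby's built-in collections.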
data/README.md
CHANGED
@@ -2,6 +2,8 @@

 :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

+For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+
 [](https://travis-ci.org/ankane/torch.rb)

 ## Installation
@@ -22,12 +24,26 @@ It can take a few minutes to compile the extension.

 ## Getting Started

+Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
+
+We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
+
+```text
+ankane/ml-stack:torch-gpu
+```
+
+And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
+
+## API
+
 This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:

 - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
 - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
 - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

+You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
+
 ## Tutorial

 Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
@@ -192,7 +208,7 @@ end
 Define a neural network

 ```ruby
-class
+class MyNet < Torch::NN::Module
 def initialize
 super
 @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +242,7 @@ end
 Create an instance of it

 ```ruby
-net =
+net = MyNet.new
 input = Torch.randn(1, 1, 32, 32)
 net.call(input)
 ```
@@ -294,7 +310,7 @@ Torch.save(net.state_dict, "net.pth")
 Load a model

 ```ruby
-net =
+net = MyNet.new
 net.load_state_dict(Torch.load("net.pth"))
 net.eval
 ```
@@ -395,7 +411,8 @@ Here’s the list of compatible versions.

 Torch.rb | LibTorch
 --- | ---
-0.
+0.3.0-0.3.1 | 1.6.0
+0.2.0-0.2.7 | 1.5.0-1.5.1
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1

@@ -413,9 +430,7 @@ Then install the gem (no need for `bundle config`).

 ### Linux

-Deep learning is significantly faster on
-
-Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
+Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.

 Check if CUDA is available

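The README's new API section lists two mechanical renaming rules. The hypothetical `rubyize_method_name` helper below is a sketch of how the two rules combine when applied to PyTorch method names. It is illustrative only and not part of torch-rb. Its `is_floating_point` case matches the `floating_point?` call in the new inspector code.

```ruby
# Hypothetical helper (not part of torch-rb): maps a PyTorch method name
# to the Ruby-style name described in the README's API section.
def rubyize_method_name(name)
  if name.start_with?("is_")
    # boolean predicates: is_tensor -> tensor?
    "#{name.delete_prefix("is_")}?"
  elsif name.end_with?("_") && !name.end_with?("__")
    # in-place methods: add_ -> add! (dunder names are left alone)
    "#{name.chomp("_")}!"
  else
    name
  end
end
```

Names that match neither rule, such as `matmul`, pass through unchanged.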
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
 Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
 };

-void handle_error(
+void handle_error(torch::Error const & ex)
 {
 throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
@@ -32,29 +32,35 @@ extern "C"
 void Init_ext()
 {
 Module rb_mTorch = define_module("Torch");
+rb_mTorch.add_handler<torch::Error>(handle_error);
 add_torch_functions(rb_mTorch);

 Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+rb_cTensor.add_handler<torch::Error>(handle_error);
 add_tensor_functions(rb_cTensor);

 Module rb_mNN = define_module_under(rb_mTorch, "NN");
+rb_mNN.add_handler<torch::Error>(handle_error);
 add_nn_functions(rb_mNN);

 Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+.add_handler<torch::Error>(handle_error)
 .define_singleton_method(
 "initial_seed",
 *[]() {
-return at::detail::getDefaultCPUGenerator()
+return at::detail::getDefaultCPUGenerator().current_seed();
 })
 .define_singleton_method(
 "seed",
 *[]() {
 // TODO set for CUDA when available
-
+auto generator = at::detail::getDefaultCPUGenerator();
+return generator.seed();
 });

 // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
 Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+.add_handler<torch::Error>(handle_error)
 .define_constructor(Constructor<torch::IValue>())
 .define_method("bool?", &torch::IValue::isBool)
 .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -317,7 +323,6 @@ void Init_ext()
 });

 rb_cTensor
-.add_handler<c10::Error>(handle_error)
 .define_method("cuda?", &torch::Tensor::is_cuda)
 .define_method("sparse?", &torch::Tensor::is_sparse)
 .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -325,6 +330,11 @@ void Init_ext()
 .define_method("numel", &torch::Tensor::numel)
 .define_method("element_size", &torch::Tensor::element_size)
 .define_method("requires_grad", &torch::Tensor::requires_grad)
+.define_method(
+"contiguous?",
+*[](Tensor& self) {
+return self.is_contiguous();
+})
 .define_method(
 "addcmul!",
 *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -342,8 +352,8 @@ void Init_ext()
 })
 .define_method(
 "_backward",
-*[](Tensor& self,
-return
+*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
+return self.backward(gradient, create_graph, retain_graph);
 })
 .define_method(
 "grad",
@@ -374,6 +384,21 @@ void Init_ext()
 s << self.device();
 return s.str();
 })
+.define_method(
+"_data_str",
+*[](Tensor& self) {
+Tensor tensor = self;
+
+// move to CPU to get data
+if (tensor.device().type() != torch::kCPU) {
+torch::Device device("cpu");
+tensor = tensor.to(device);
+}
+
+auto data_ptr = (const char *) tensor.data_ptr();
+return std::string(data_ptr, tensor.numel() * tensor.element_size());
+})
+// TODO figure out a better way to do this
 .define_method(
 "_flat_data",
 *[](Tensor& self) {
@@ -388,46 +413,40 @@ void Init_ext()
 Array a;
 auto dtype = tensor.dtype();

+Tensor view = tensor.reshape({tensor.numel()});
+
 // TODO DRY if someone knows C++
 if (dtype == torch::kByte) {
-uint8_t* data = tensor.data_ptr<uint8_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<uint8_t>());
 }
 } else if (dtype == torch::kChar) {
-int8_t* data = tensor.data_ptr<int8_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(to_ruby<int>(
+a.push(to_ruby<int>(view[i].item().to<int8_t>()));
 }
 } else if (dtype == torch::kShort) {
-int16_t* data = tensor.data_ptr<int16_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<int16_t>());
 }
 } else if (dtype == torch::kInt) {
-int32_t* data = tensor.data_ptr<int32_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<int32_t>());
 }
 } else if (dtype == torch::kLong) {
-int64_t* data = tensor.data_ptr<int64_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<int64_t>());
 }
 } else if (dtype == torch::kFloat) {
-float* data = tensor.data_ptr<float>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<float>());
 }
 } else if (dtype == torch::kDouble) {
-double* data = tensor.data_ptr<double>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<double>());
 }
 } else if (dtype == torch::kBool) {
-bool* data = tensor.data_ptr<bool>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<bool>() ? True : False);
 }
 } else {
 throw std::runtime_error("Unsupported type");
@@ -442,14 +461,14 @@ void Init_ext()
 .define_singleton_method(
 "_make_subclass",
 *[](Tensor& rd, bool requires_grad) {
-auto data =
+auto data = rd.detach();
 data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
 auto var = data.set_requires_grad(requires_grad);
 return Parameter(std::move(var));
 });

 Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-.add_handler<
+.add_handler<torch::Error>(handle_error)
 .define_constructor(Constructor<torch::TensorOptions>())
 .define_method(
 "dtype",
@@ -555,6 +574,7 @@ void Init_ext()
 });

 Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+.add_handler<torch::Error>(handle_error)
 .define_method(
 "grad",
 *[](Parameter& self) {
@@ -564,6 +584,7 @@ void Init_ext()

 Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
 .define_constructor(Constructor<torch::Device, std::string>())
+.add_handler<torch::Error>(handle_error)
 .define_method("index", &torch::Device::index)
 .define_method("index?", &torch::Device::has_index)
 .define_method(
@@ -575,6 +596,7 @@ void Init_ext()
 });

 Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+.add_handler<torch::Error>(handle_error)
 .define_singleton_method("available?", &torch::cuda::is_available)
 .define_singleton_method("device_count", &torch::cuda::device_count);
 }
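The new `_data_str` binding moves a tensor to the CPU if needed and returns its raw bytes (`numel * element_size`) as a Ruby String. This diff does not show how `tensor.rb` consumes that string. One plausible approach is `String#unpack` with one directive per dtype, sketched here as an assumption. The `decode_data_str` helper is hypothetical, and so are its dtype symbols and the one-byte-per-element bool layout.

```ruby
# Hypothetical mapping from dtype symbols to String#unpack directives.
# Native-endian directives are used because the bytes are copied straight
# from tensor memory in the same process.
UNPACK_DIRECTIVES = {
  uint8: "C*",
  int8: "c*",
  int16: "s*",
  int32: "l*",
  int64: "q*",
  float32: "f*",
  float64: "d*"
}.freeze

# Decodes a flat byte string into a Ruby Array for the given dtype.
def decode_data_str(bytes, dtype)
  # assumed layout for bool: one byte per element, nonzero means true
  return bytes.unpack("C*").map { |b| b != 0 } if dtype == :bool

  directive = UNPACK_DIRECTIVES.fetch(dtype) do
    raise ArgumentError, "Unsupported type: #{dtype}"
  end
  bytes.unpack(directive)
end
```

A single bulk `unpack` avoids one `item()` call per element, which is the per-element pattern `_flat_data` uses after this change.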
data/ext/torch/extconf.rb
CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
 # change to 0 for Linux pre-cxx11 ABI version
 $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"

-
-clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
+apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i

 # check omp first
 if have_library("omp") || have_library("gomp")
 $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
-$CXXFLAGS += " -Xclang" if
+$CXXFLAGS += " -Xclang" if apple_clang
 $CXXFLAGS += " -fopenmp"
 end

-if
+if apple_clang
 # silence ruby/intern.h warning
 $CXXFLAGS += " -Wno-deprecated-register"

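The old check treated any macOS host as clang, so a GCC toolchain on macOS would still have received `-Xclang`. The new check reads the compiler version message stored in `RbConfig` instead. Below is a standalone sketch of that decision. The helper names `apple_clang?` and `openmp_cxxflags` are hypothetical. The sketch assumes `CC_VERSION_MESSAGE` may be absent on some Ruby builds and treats that case as "not Apple clang".

```ruby
require "rbconfig"

# Mirrors the new extconf.rb check: decide from the compiler's version
# message (not the host OS) whether the compiler is Apple clang.
# A missing CC_VERSION_MESSAGE (nil) counts as "not Apple clang".
def apple_clang?(message = RbConfig::CONFIG["CC_VERSION_MESSAGE"])
  message.to_s.match?(/apple clang/i)
end

# Hypothetical helper: the OpenMP flags extconf.rb appends once omp/gomp is found.
def openmp_cxxflags(apple_clang)
  flags = " -DAT_PARALLEL_OPENMP=1"
  flags += " -Xclang" if apple_clang
  flags + " -fopenmp"
end
```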
data/lib/torch.rb
CHANGED
@@ -1,6 +1,12 @@
 # ext
 require "torch/ext"

+# stdlib
+require "fileutils"
+require "net/http"
+require "set"
+require "tmpdir"
+
 # native functions
 require "torch/native/generator"
 require "torch/native/parser"
@@ -465,11 +471,7 @@ module Torch
 when nil
 IValue.new
 when Array
-
-IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
-else
-raise Error, "Unknown list type"
-end
+IValue.from_list(obj.map { |v| to_ivalue(v) })
 else
 raise Error, "Unknown type: #{obj.class.name}"
 end
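In `to_ivalue`, the `Array` branch now recurses instead of converting each element with `IValue.from_tensor`. As a result, lists of any type `to_ivalue` supports, including nested lists, go through the same conversion. The sketch below shows that recursion with tagged Arrays in place of `Torch::IValue`. `to_tagged` is hypothetical, and so are its `Integer` and `Float` branches: this hunk only shows the `nil`, `Array`, and fallback branches.

```ruby
# Stand-in for the recursive to_ivalue change. Instead of building
# Torch::IValue objects it returns tagged Arrays so the recursion is visible.
# The Integer and Float branches are hypothetical.
def to_tagged(obj)
  case obj
  when nil
    [:none]
  when Integer
    [:int, obj]
  when Float
    [:double, obj]
  when Array
    # every element goes through the same conversion, so nested lists work
    [:list, obj.map { |v| to_tagged(v) }]
  else
    raise ArgumentError, "Unknown type: #{obj.class.name}"
  end
end
```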
data/lib/torch/hub.rb
CHANGED
@@ -5,12 +5,56 @@ module Torch
 raise NotImplementedYet
 end

-def download_url_to_file(url)
-
+def download_url_to_file(url, dst)
+uri = URI(url)
+tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+location = nil
+
+Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+request = Net::HTTP::Get.new(uri)
+
+puts "Downloading #{url}..."
+File.open(tmp, "wb") do |f|
+http.request(request) do |response|
+case response
+when Net::HTTPRedirection
+location = response["location"]
+when Net::HTTPSuccess
+response.read_body do |chunk|
+f.write(chunk)
+end
+else
+raise Error, "Bad response"
+end
+end
+end
+end
+
+if location
+download_url_to_file(location, dst)
+else
+FileUtils.mv(tmp, dst)
+nil
+end
 end

-def load_state_dict_from_url(url)
-
+def load_state_dict_from_url(url, model_dir: nil)
+unless model_dir
+torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
+model_dir = File.join(torch_home, "checkpoints")
+end
+
+FileUtils.mkdir_p(model_dir)
+
+parts = URI(url)
+filename = File.basename(parts.path)
+cached_file = File.join(model_dir, filename)
+unless File.exist?(cached_file)
+# TODO support hash_prefix
+download_url_to_file(url, cached_file)
+end
+
+Torch.load(cached_file)
 end
 end
 end
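When `model_dir:` is omitted, the new `load_state_dict_from_url` chooses a cache directory in this order:

- `$TORCH_HOME/checkpoints`
- `$XDG_CACHE_HOME/torch/checkpoints`
- `$HOME/.cache/torch/checkpoints`

It names the cached file after the URL's last path segment. The existence check uses only that name, so two URLs ending in the same file name would share one cache entry. Below is a standalone sketch of the path logic. The helper names are hypothetical, and the environment is passed in as a Hash so the precedence can be checked without touching the real `ENV`.

```ruby
require "uri"

# Mirrors how the new load_state_dict_from_url picks a cache directory
# when model_dir: is omitted. `env` defaults to ENV but can be any Hash.
def default_model_dir(env = ENV)
  cache_home = env["XDG_CACHE_HOME"] || "#{env["HOME"]}/.cache"
  torch_home = env["TORCH_HOME"] || "#{cache_home}/torch"
  File.join(torch_home, "checkpoints")
end

# The cached file name is the last segment of the URL path (query ignored).
def cached_file_for(url, model_dir)
  File.join(model_dir, File.basename(URI(url).path))
end
```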
data/lib/torch/inspector.rb
CHANGED
@@ -1,89 +1,264 @@
+# mirrors _tensor_str.py
 module Torch
 module Inspector
-
-
-
-
-
-
-
-
+PRINT_OPTS = {
+precision: 4,
+threshold: 1000,
+edgeitems: 3,
+linewidth: 80,
+sci_mode: nil
+}
+
+class Formatter
+def initialize(tensor)
+@floating_dtype = tensor.floating_point?
+@complex_dtype = tensor.complex?
+@int_mode = true
+@sci_mode = false
+@max_width = 1
+
+tensor_view = Torch.no_grad { tensor.reshape(-1) }
+
+if !@floating_dtype
+tensor_view.each do |value|
+value_str = value.item.to_s
+@max_width = [@max_width, value_str.length].max
+end
 else
-
+nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
+
+# no valid number, do nothing
+return if nonzero_finite_vals.numel == 0
+
+# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
+nonzero_finite_abs = nonzero_finite_vals.abs.double
+nonzero_finite_min = nonzero_finite_abs.min.double
+nonzero_finite_max = nonzero_finite_abs.max.double
+
+nonzero_finite_vals.each do |value|
+if value.item != value.item.ceil
+@int_mode = false
+break
+end
+end

-if
-
+if @int_mode
+# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
+# to indicate that the tensor is of floating type. add 1 to the len to account for this.
+if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
+@sci_mode = true
+nonzero_finite_vals.each do |value|
+value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+@max_width = [@max_width, value_str.length].max
+end
+else
+nonzero_finite_vals.each do |value|
+value_str = "%.0f" % value.item
+@max_width = [@max_width, value_str.length + 1].max
+end
+end
 else
-
-
-
-
-
-
-
-
+# Check if scientific representation should be used.
+if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
+@sci_mode = true
+nonzero_finite_vals.each do |value|
+value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+@max_width = [@max_width, value_str.length].max
+end
+else
+nonzero_finite_vals.each do |value|
+value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
+@max_width = [@max_width, value_str.length].max
+end
 end
+end
+end

-
-
+@sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
+end

-
-
+def width
+@max_width
+end

-
+def format(value)
+value = value.item

-
-
-
-
-
-
-
-fmt = "%#{total}d"
+if @floating_dtype
+if @sci_mode
+ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
+elsif @int_mode
+ret = String.new("%.0f" % value)
+unless value.infinite? || value.nan?
+ret += "."
 end
+else
+ret = "%.#{PRINT_OPTS[:precision]}f" % value
 end
+elsif @complex_dtype
+p = PRINT_OPTS[:precision]
+raise NotImplementedYet
+else
+ret = value.to_s
+end
+# Ruby throws error when negative, Python doesn't
+" " * [@max_width - ret.size, 0].max + ret
+end
+end
+
+def inspect
+Torch.no_grad do
+str_intern(self)
+end
+rescue => e
+# prevent stack error
+puts e.backtrace.join("\n")
+"Error inspecting tensor: #{e.inspect}"
+end
+
+private
+
+# TODO update
+def str_intern(slf)
+prefix = "tensor("
+indent = prefix.length
+suffixes = []
+
+has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
+
+if slf.numel == 0 && !slf.sparse?
+# Explicitly print the shape if it is not (0,), to match NumPy behavior
+if slf.dim != 1
+suffixes << "size: #{shape.inspect}"
+end

-
+# In an empty tensor, there are no elements to infer if the dtype
+# should be int64, so it must be shown explicitly.
+if slf.dtype != :int64
+suffixes << "dtype: #{slf.dtype.inspect}"
 end
+tensor_str = "[]"
+else
+if !has_default_dtype
+suffixes << "dtype: #{slf.dtype.inspect}"
+end
+
+if slf.layout != :strided
+tensor_str = tensor_str(slf.to_dense, indent)
+else
+tensor_str = tensor_str(slf, indent)
+end
+end

-
-
-attributes << "requires_grad: true"
+if slf.layout != :strided
+suffixes << "layout: #{slf.layout.inspect}"
 end
-
-
+
+# TODO show grad_fn
+if slf.requires_grad?
+suffixes << "requires_grad: true"
 end

-
+add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
 end

-
+def add_suffixes(tensor_str, suffixes, indent, force_newline)
+tensor_strs = [tensor_str]
+# rfind in Python returns -1 when not found
+last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
+suffixes.each do |suffix|
+suffix_len = suffix.length
+if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
+tensor_strs << ",\n" + " " * indent + suffix
+last_line_len = indent + suffix_len
+force_newline = false
+else
+tensor_strs.append(", " + suffix)
+last_line_len += suffix_len + 2
+end
+end
+tensor_strs.append(")")
+tensor_strs.join("")
+end

-
-
-
-
-
-
-
-
-
-
-
+def tensor_str(slf, indent)
+return "[]" if slf.numel == 0
+
+summarize = slf.numel > PRINT_OPTS[:threshold]
+
+if slf.dtype == :float16 || slf.dtype == :bfloat16
+slf = slf.float
+end
+formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
+tensor_str_with_formatter(slf, indent, formatter, summarize)
+end
+
+def summarized_data(slf)
+edgeitems = PRINT_OPTS[:edgeitems]

-
+dim = slf.dim
+if dim == 0
+slf
+elsif dim == 1
+if size(0) > 2 * edgeitems
+Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
+else
+slf
+end
+elsif slf.size(0) > 2 * edgeitems
+start = edgeitems.times.map { |i| slf[i] }
+finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
+Torch.stack((start + finish).map { |x| summarized_data(x) })
 else
-
-
-
-
-
-
-
-
+Torch.stack(slf.map { |x| summarized_data(x) })
+end
+end
+
+def tensor_str_with_formatter(slf, indent, formatter, summarize)
+edgeitems = PRINT_OPTS[:edgeitems]
+
+dim = slf.dim

-
+return scalar_str(slf, formatter) if dim == 0
+return vector_str(slf, indent, formatter, summarize) if dim == 1
+
+if summarize && slf.size(0) > 2 * edgeitems
+slices = (
+[edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
+["..."] +
+[((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
+)
+else
+slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
 end
+
+tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
+"[" + tensor_str + "]"
+end
+
+def scalar_str(slf, formatter)
+formatter.format(slf)
+end
+
+def vector_str(slf, indent, formatter, summarize)
+# length includes spaces and comma between elements
+element_length = formatter.width + 2
+elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
+char_per_line = element_length * elements_per_line
+
+if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
+data = (
+[slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
+[" ..."] +
+[slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
+)
+else
+data = slf.map { |val| formatter.format(val) }
+end
+
+data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
+lines = data_lines.map { |line| line.join(", ") }
+"[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
 end
 end
 end
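The new `Formatter` chooses integer, fixed-point, or scientific output by examining only the nonzero finite values. Its thresholds are ported from PyTorch's `_tensor_str.py`, per the file's header comment. Below is a pure-Ruby sketch of that decision and of the per-line element count used by `vector_str`. Arrays of Floats stand in for tensors, the default `linewidth` of 80 is hard-coded, and the `PRINT_OPTS[:sci_mode]` override is left out. Both helper names are hypothetical.

```ruby
# Sketch of the mode selection in Formatter#initialize for floating data.
def float_print_mode(values)
  nonzero_finite = values.select { |v| v.finite? && v != 0 }
  # nothing to measure: Formatter keeps its defaults
  return { int_mode: true, sci_mode: false } if nonzero_finite.empty?

  abs = nonzero_finite.map(&:abs)
  min = abs.min
  max = abs.max
  int_mode = nonzero_finite.all? { |v| v == v.ceil }

  # a wide dynamic range or huge values force scientific notation;
  # non-integral data also switches for tiny values
  sci_mode = max / min > 1000.0 || max > 1.0e8 || (!int_mode && min < 1.0e-4)
  { int_mode: int_mode, sci_mode: sci_mode }
end

# Sketch of vector_str's line packing: each element takes width + 2
# characters (value plus ", "), and at least one element fits per line.
def elements_per_line(width, indent, linewidth = 80)
  element_length = width + 2
  [1, (linewidth - indent) / element_length].max
end
```

For example, with a formatter width of 6 and the 7-character `tensor(` indent, (80 - 7) / 8 rounds down to 9 elements per line.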