torch-rb 0.2.3 → 0.3.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
4
- data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
3
+ metadata.gz: 33636e58063f25c2b9f122d29332e4136bb6a4de0fd227349f75d65a9db94931
4
+ data.tar.gz: 9349dd0b050a4c9e0714d92bb451bdd916e55fb47a5c4d90a74720d53564a1d6
5
5
  SHA512:
6
- metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
7
- data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
6
+ metadata.gz: 692e8dc3531426377413fc9325c2b03dd7fcbbbce0c05cbd5d7c3182a08bbe733bb9a6b0aa62a56fb27917073e1f1f5859aa0dd3f6f40e74043bba242ede6267
7
+ data.tar.gz: f57d0411c18c7c4753f5edc82e48b666ad3154c72bf189a8ec2c4dceb08cd1f37a9421b9a6eae3d20e10c4c9236a1136d76559a31989ce22868a9f44ef3e0e66
@@ -1,3 +1,31 @@
1
+ ## 0.3.0 (2020-07-29)
2
+
3
+ - Updated LibTorch to 1.6.0
4
+ - Removed `state_dict` method from optimizers until `load_state_dict` is implemented
5
+
6
+ ## 0.2.7 (2020-06-29)
7
+
8
+ - Made tensors enumerable
9
+ - Improved performance of `inspect` method
10
+
11
+ ## 0.2.6 (2020-06-29)
12
+
13
+ - Added support for indexing with tensors
14
+ - Added `contiguous` methods
15
+ - Fixed named parameters for nested parameters
16
+
17
+ ## 0.2.5 (2020-06-07)
18
+
19
+ - Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
20
+ - Improved error messages
21
+ - Fixed tensor slicing
22
+
23
+ ## 0.2.4 (2020-04-29)
24
+
25
+ - Added `to_i` and `to_f` to tensors
26
+ - Added `shuffle` option to data loader
27
+ - Fixed `modules` and `named_modules` for nested modules
28
+
1
29
  ## 0.2.3 (2020-04-28)
2
30
 
3
31
  - Added `show_config` and `parallel_info` methods
@@ -20,7 +48,7 @@
20
48
  ## 0.2.0 (2020-04-22)
21
49
 
22
50
  - No longer experimental
23
- - Updated libtorch to 1.5.0
51
+ - Updated LibTorch to 1.5.0
24
52
  - Added support for GPUs and OpenMP
25
53
  - Added adaptive pooling layers
26
54
  - Tensor `dtype` is now based on Numo type for `Torch.tensor`
@@ -29,7 +57,7 @@
29
57
 
30
58
  ## 0.1.8 (2020-01-17)
31
59
 
32
- - Updated libtorch to 1.4.0
60
+ - Updated LibTorch to 1.4.0
33
61
 
34
62
  ## 0.1.7 (2020-01-10)
35
63
 
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
4
 
5
+ For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
6
+
5
7
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
6
8
 
7
9
  ## Installation
@@ -22,12 +24,26 @@ It can take a few minutes to compile the extension.
22
24
 
23
25
  ## Getting Started
24
26
 
27
+ Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
28
+
29
+ We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
30
+
31
+ ```text
32
+ ankane/ml-stack:torch-gpu
33
+ ```
34
+
35
+ And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
36
+
37
+ ## API
38
+
25
39
  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
26
40
 
27
41
  - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
28
42
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
29
43
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
30
44
 
45
+ You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
46
+
31
47
  ## Tutorial
32
48
 
33
49
  Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
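The first two naming conventions are mechanical, so converting method names from PyTorch code can be expressed as simple string rewrites. The sketch below uses a hypothetical helper, `ruby_method_name`, which is not part of torch-rb. It only illustrates the two rules in the README and ignores any exceptions the library itself may make.

```ruby
# Hypothetical helper, not part of torch-rb: applies the two renaming rules above.
def ruby_method_name(python_name)
  if python_name.end_with?("_") && !python_name.end_with?("__")
    # In-place methods: trailing "_" becomes "!" (add_ -> add!)
    python_name.delete_suffix("_") + "!"
  elsif python_name.start_with?("is_")
    # Boolean predicates: leading "is_" becomes a trailing "?" (is_tensor -> tensor?)
    python_name.delete_prefix("is_") + "?"
  else
    python_name
  end
end

puts ruby_method_name("add_")      # add!
puts ruby_method_name("is_tensor") # tensor?
puts ruby_method_name("matmul")    # matmul
```

Dunder names such as `__init__` are deliberately left alone, since their trailing underscores do not mark an in-place operation.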
@@ -192,7 +208,7 @@ end
192
208
  Define a neural network
193
209
 
194
210
  ```ruby
195
- class Net < Torch::NN::Module
211
+ class MyNet < Torch::NN::Module
196
212
  def initialize
197
213
  super
198
214
  @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +242,7 @@ end
226
242
  Create an instance of it
227
243
 
228
244
  ```ruby
229
- net = Net.new
245
+ net = MyNet.new
230
246
  input = Torch.randn(1, 1, 32, 32)
231
247
  net.call(input)
232
248
  ```
@@ -294,7 +310,7 @@ Torch.save(net.state_dict, "net.pth")
294
310
  Load a model
295
311
 
296
312
  ```ruby
297
- net = Net.new
313
+ net = MyNet.new
298
314
  net.load_state_dict(Torch.load("net.pth"))
299
315
  net.eval
300
316
  ```
@@ -395,7 +411,8 @@ Here’s the list of compatible versions.
395
411
 
396
412
  Torch.rb | LibTorch
397
413
  --- | ---
398
- 0.2.0 | 1.5.0
414
+ 0.3.0 | 1.6.0
415
+ 0.2.0-0.2.7 | 1.5.0-1.5.1
399
416
  0.1.8 | 1.4.0
400
417
  0.1.0-0.1.7 | 1.3.1
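If a build script needs to pick the LibTorch release that matches an installed gem, the table can be encoded as data. The following is a sketch with a hypothetical `libtorch_for` helper; only the version ranges come from the table above.

```ruby
require "rubygems"

# Rows from the compatibility table: [first Torch.rb version, last Torch.rb version, LibTorch]
COMPATIBILITY = [
  ["0.3.0", "0.3.0", "1.6.0"],
  ["0.2.0", "0.2.7", "1.5.0-1.5.1"],
  ["0.1.8", "0.1.8", "1.4.0"],
  ["0.1.0", "0.1.7", "1.3.1"]
].freeze

# Hypothetical helper: returns the LibTorch version(s) for a Torch.rb version, or nil if unlisted.
def libtorch_for(torch_rb_version)
  version = Gem::Version.new(torch_rb_version)
  row = COMPATIBILITY.find do |first, last, _|
    version.between?(Gem::Version.new(first), Gem::Version.new(last))
  end
  row&.last
end

puts libtorch_for("0.2.5") # 1.5.0-1.5.1
```

`Gem::Version` is used instead of plain string comparison so that, for example, a future `0.10.0` would not sort before `0.2.0`.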
401
418
 
@@ -413,9 +430,7 @@ Then install the gem (no need for `bundle config`).
413
430
 
414
431
  ### Linux
415
432
 
416
- Deep learning is significantly faster on GPUs.
417
-
418
- Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
433
+ Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
419
434
 
420
435
  Check if CUDA is available
421
436
 
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
23
23
  Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
24
24
  };
25
25
 
26
- void handle_error(c10::Error const & ex)
26
+ void handle_error(torch::Error const & ex)
27
27
  {
28
28
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
29
  }
@@ -32,29 +32,35 @@ extern "C"
32
32
  void Init_ext()
33
33
  {
34
34
  Module rb_mTorch = define_module("Torch");
35
+ rb_mTorch.add_handler<torch::Error>(handle_error);
35
36
  add_torch_functions(rb_mTorch);
36
37
 
37
38
  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
39
+ rb_cTensor.add_handler<torch::Error>(handle_error);
38
40
  add_tensor_functions(rb_cTensor);
39
41
 
40
42
  Module rb_mNN = define_module_under(rb_mTorch, "NN");
43
+ rb_mNN.add_handler<torch::Error>(handle_error);
41
44
  add_nn_functions(rb_mNN);
42
45
 
43
46
  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
47
+ .add_handler<torch::Error>(handle_error)
44
48
  .define_singleton_method(
45
49
  "initial_seed",
46
50
  *[]() {
47
- return at::detail::getDefaultCPUGenerator()->current_seed();
51
+ return at::detail::getDefaultCPUGenerator().current_seed();
48
52
  })
49
53
  .define_singleton_method(
50
54
  "seed",
51
55
  *[]() {
52
56
  // TODO set for CUDA when available
53
- return at::detail::getDefaultCPUGenerator()->seed();
57
+ auto generator = at::detail::getDefaultCPUGenerator();
58
+ return generator.seed();
54
59
  });
55
60
 
56
61
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
57
62
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
63
+ .add_handler<torch::Error>(handle_error)
58
64
  .define_constructor(Constructor<torch::IValue>())
59
65
  .define_method("bool?", &torch::IValue::isBool)
60
66
  .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -317,7 +323,6 @@ void Init_ext()
317
323
  });
318
324
 
319
325
  rb_cTensor
320
- .add_handler<c10::Error>(handle_error)
321
326
  .define_method("cuda?", &torch::Tensor::is_cuda)
322
327
  .define_method("sparse?", &torch::Tensor::is_sparse)
323
328
  .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -325,6 +330,11 @@ void Init_ext()
325
330
  .define_method("numel", &torch::Tensor::numel)
326
331
  .define_method("element_size", &torch::Tensor::element_size)
327
332
  .define_method("requires_grad", &torch::Tensor::requires_grad)
333
+ .define_method(
334
+ "contiguous?",
335
+ *[](Tensor& self) {
336
+ return self.is_contiguous();
337
+ })
328
338
  .define_method(
329
339
  "addcmul!",
330
340
  *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -374,6 +384,21 @@ void Init_ext()
374
384
  s << self.device();
375
385
  return s.str();
376
386
  })
387
+ .define_method(
388
+ "_data_str",
389
+ *[](Tensor& self) {
390
+ Tensor tensor = self;
391
+
392
+ // move to CPU to get data
393
+ if (tensor.device().type() != torch::kCPU) {
394
+ torch::Device device("cpu");
395
+ tensor = tensor.to(device);
396
+ }
397
+
398
+ auto data_ptr = (const char *) tensor.data_ptr();
399
+ return std::string(data_ptr, tensor.numel() * tensor.element_size());
400
+ })
401
+ // TODO figure out a better way to do this
377
402
  .define_method(
378
403
  "_flat_data",
379
404
  *[](Tensor& self) {
@@ -388,46 +413,40 @@ void Init_ext()
388
413
  Array a;
389
414
  auto dtype = tensor.dtype();
390
415
 
416
+ Tensor view = tensor.reshape({tensor.numel()});
417
+
391
418
  // TODO DRY if someone knows C++
392
419
  if (dtype == torch::kByte) {
393
- uint8_t* data = tensor.data_ptr<uint8_t>();
394
420
  for (int i = 0; i < tensor.numel(); i++) {
395
- a.push(data[i]);
421
+ a.push(view[i].item().to<uint8_t>());
396
422
  }
397
423
  } else if (dtype == torch::kChar) {
398
- int8_t* data = tensor.data_ptr<int8_t>();
399
424
  for (int i = 0; i < tensor.numel(); i++) {
400
- a.push(to_ruby<int>(data[i]));
425
+ a.push(to_ruby<int>(view[i].item().to<int8_t>()));
401
426
  }
402
427
  } else if (dtype == torch::kShort) {
403
- int16_t* data = tensor.data_ptr<int16_t>();
404
428
  for (int i = 0; i < tensor.numel(); i++) {
405
- a.push(data[i]);
429
+ a.push(view[i].item().to<int16_t>());
406
430
  }
407
431
  } else if (dtype == torch::kInt) {
408
- int32_t* data = tensor.data_ptr<int32_t>();
409
432
  for (int i = 0; i < tensor.numel(); i++) {
410
- a.push(data[i]);
433
+ a.push(view[i].item().to<int32_t>());
411
434
  }
412
435
  } else if (dtype == torch::kLong) {
413
- int64_t* data = tensor.data_ptr<int64_t>();
414
436
  for (int i = 0; i < tensor.numel(); i++) {
415
- a.push(data[i]);
437
+ a.push(view[i].item().to<int64_t>());
416
438
  }
417
439
  } else if (dtype == torch::kFloat) {
418
- float* data = tensor.data_ptr<float>();
419
440
  for (int i = 0; i < tensor.numel(); i++) {
420
- a.push(data[i]);
441
+ a.push(view[i].item().to<float>());
421
442
  }
422
443
  } else if (dtype == torch::kDouble) {
423
- double* data = tensor.data_ptr<double>();
424
444
  for (int i = 0; i < tensor.numel(); i++) {
425
- a.push(data[i]);
445
+ a.push(view[i].item().to<double>());
426
446
  }
427
447
  } else if (dtype == torch::kBool) {
428
- bool* data = tensor.data_ptr<bool>();
429
448
  for (int i = 0; i < tensor.numel(); i++) {
430
- a.push(data[i] ? True : False);
449
+ a.push(view[i].item().to<bool>() ? True : False);
431
450
  }
432
451
  } else {
433
452
  throw std::runtime_error("Unsupported type");
@@ -442,14 +461,14 @@ void Init_ext()
442
461
  .define_singleton_method(
443
462
  "_make_subclass",
444
463
  *[](Tensor& rd, bool requires_grad) {
445
- auto data = torch::autograd::as_variable_ref(rd).detach();
464
+ auto data = rd.detach();
446
465
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
447
466
  auto var = data.set_requires_grad(requires_grad);
448
467
  return Parameter(std::move(var));
449
468
  });
450
469
 
451
470
  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
452
- .add_handler<c10::Error>(handle_error)
471
+ .add_handler<torch::Error>(handle_error)
453
472
  .define_constructor(Constructor<torch::TensorOptions>())
454
473
  .define_method(
455
474
  "dtype",
@@ -555,6 +574,7 @@ void Init_ext()
555
574
  });
556
575
 
557
576
  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
577
+ .add_handler<torch::Error>(handle_error)
558
578
  .define_method(
559
579
  "grad",
560
580
  *[](Parameter& self) {
@@ -564,6 +584,7 @@ void Init_ext()
564
584
 
565
585
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
566
586
  .define_constructor(Constructor<torch::Device, std::string>())
587
+ .add_handler<torch::Error>(handle_error)
567
588
  .define_method("index", &torch::Device::index)
568
589
  .define_method("index?", &torch::Device::has_index)
569
590
  .define_method(
@@ -575,6 +596,7 @@ void Init_ext()
575
596
  });
576
597
 
577
598
  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
599
+ .add_handler<torch::Error>(handle_error)
578
600
  .define_singleton_method("available?", &torch::cuda::is_available)
579
601
  .define_singleton_method("device_count", &torch::cuda::device_count);
580
602
  }
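The new `_data_str` method copies `numel * element_size` raw bytes from a CPU copy of the tensor into a Ruby string. One way Ruby code could turn that string back into numbers is `String#unpack`. The sketch below is hypothetical (`decode_tensor_bytes` is not a torch-rb method); it assumes a contiguous tensor, a little-endian host, and that the caller already knows the dtype.

```ruby
# Hypothetical decoder for the raw bytes returned by `_data_str`.
# Assumes a little-endian host and that the caller already knows the dtype.
UNPACK_DIRECTIVES = {
  uint8: "C*",   # unsigned 8-bit
  int8: "c*",    # signed 8-bit
  int16: "s<*",  # signed 16-bit, little-endian
  int32: "l<*",  # signed 32-bit, little-endian
  int64: "q<*",  # signed 64-bit, little-endian
  float32: "e*", # single precision, little-endian
  float64: "E*", # double precision, little-endian
  bool: "C*"     # one byte per element, converted below
}.freeze

def decode_tensor_bytes(bytes, dtype)
  directive = UNPACK_DIRECTIVES.fetch(dtype) { raise ArgumentError, "Unsupported dtype: #{dtype}" }
  values = bytes.unpack(directive)
  dtype == :bool ? values.map { |v| v != 0 } : values
end

p decode_tensor_bytes([1.5, -2.0].pack("e*"), :float32) # [1.5, -2.0]
```

Decoding the whole buffer in one `unpack` call avoids a per-element round trip through C++, which is the cost the element-by-element `_flat_data` loop pays.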
@@ -1,6 +1,11 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # stdlib
5
+ require "fileutils"
6
+ require "net/http"
7
+ require "tmpdir"
8
+
4
9
  # native functions
5
10
  require "torch/native/generator"
6
11
  require "torch/native/parser"
@@ -174,6 +179,7 @@ require "torch/nn/init"
174
179
 
175
180
  # utils
176
181
  require "torch/utils/data/data_loader"
182
+ require "torch/utils/data/dataset"
177
183
  require "torch/utils/data/tensor_dataset"
178
184
 
179
185
  # hub
@@ -464,11 +470,7 @@ module Torch
464
470
  when nil
465
471
  IValue.new
466
472
  when Array
467
- if obj.all? { |v| v.is_a?(Tensor) }
468
- IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
469
- else
470
- raise Error, "Unknown list type"
471
- end
473
+ IValue.from_list(obj.map { |v| to_ivalue(v) })
472
474
  else
473
475
  raise Error, "Unknown type: #{obj.class.name}"
474
476
  end
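The `Array` branch of `to_ivalue` now converts each element recursively instead of accepting only lists of tensors, so nested and mixed lists are handled by the same dispatch. The sketch below illustrates that recursion with plain tagged arrays in place of `IValue`. `FakeTensor` and `to_tagged` are hypothetical stand-ins, and the Integer, Float, and tensor branches are assumptions for illustration, since only the `nil`, `Array`, and fallback branches appear in this hunk.

```ruby
# Stand-in for Torch::Tensor so the sketch runs on its own (hypothetical).
FakeTensor = Struct.new(:values)

# Mirrors the recursive Array branch of the new `to_ivalue`: every element is
# converted with the same function, so nested and mixed lists are accepted.
# The Integer/Float/tensor branches are illustrative assumptions.
def to_tagged(obj)
  case obj
  when nil then [:none]
  when Integer then [:int, obj]
  when Float then [:double, obj]
  when FakeTensor then [:tensor, obj.values]
  when Array then [:list, obj.map { |v| to_tagged(v) }]
  else
    raise ArgumentError, "Unknown type: #{obj.class.name}"
  end
end

p to_tagged([1, [2.5, nil]])
# [:list, [[:int, 1], [:list, [[:double, 2.5], [:none]]]]]
```

Unsupported elements still fail, but the error now names the offending element's class rather than reporting a generic "Unknown list type".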
@@ -4,6 +4,58 @@ module Torch
4
4
  def list(github, force_reload: false)
5
5
  raise NotImplementedYet
6
6
  end
7
+
8
+ def download_url_to_file(url, dst)
9
+ uri = URI(url)
10
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
11
+ location = nil
12
+
13
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
+ request = Net::HTTP::Get.new(uri)
15
+
16
+ puts "Downloading #{url}..."
17
+ File.open(tmp, "wb") do |f|
18
+ http.request(request) do |response|
19
+ case response
20
+ when Net::HTTPRedirection
21
+ location = response["location"]
22
+ when Net::HTTPSuccess
23
+ response.read_body do |chunk|
24
+ f.write(chunk)
25
+ end
26
+ else
27
+ raise Error, "Bad response"
28
+ end
29
+ end
30
+ end
31
+ end
32
+
33
+ if location
34
+ download_url_to_file(location, dst)
35
+ else
36
+ FileUtils.mv(tmp, dst)
37
+ nil
38
+ end
39
+ end
40
+
41
+ def load_state_dict_from_url(url, model_dir: nil)
42
+ unless model_dir
43
+ torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
44
+ model_dir = File.join(torch_home, "checkpoints")
45
+ end
46
+
47
+ FileUtils.mkdir_p(model_dir)
48
+
49
+ parts = URI(url)
50
+ filename = File.basename(parts.path)
51
+ cached_file = File.join(model_dir, filename)
52
+ unless File.exist?(cached_file)
53
+ # TODO support hash_prefix
54
+ download_url_to_file(url, cached_file)
55
+ end
56
+
57
+ Torch.load(cached_file)
58
+ end
7
59
  end
8
60
  end
9
61
  end
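`load_state_dict_from_url` caches downloads under `$TORCH_HOME/checkpoints`, falling back to `$XDG_CACHE_HOME/torch` and then `~/.cache/torch`, and names the file after the last segment of the URL path. The path resolution can be isolated as below. `cached_checkpoint_path` is a hypothetical helper that performs no download; it takes the environment as an argument so the lookup order is easy to check.

```ruby
require "uri"

# Hypothetical helper: computes the cache file `load_state_dict_from_url` would use,
# without downloading. `env` defaults to ENV but can be any Hash-like object.
def cached_checkpoint_path(url, env: ENV, model_dir: nil)
  unless model_dir
    # TORCH_HOME wins; otherwise $XDG_CACHE_HOME/torch, falling back to ~/.cache/torch
    torch_home = env["TORCH_HOME"] ||
      File.join(env["XDG_CACHE_HOME"] || File.join(env.fetch("HOME"), ".cache"), "torch")
    model_dir = File.join(torch_home, "checkpoints")
  end
  # Only the last path segment of the URL is used, so query strings are ignored
  File.join(model_dir, File.basename(URI(url).path))
end

puts cached_checkpoint_path("https://example.com/models/resnet18.pth", env: { "HOME" => "/home/me" })
# /home/me/.cache/torch/checkpoints/resnet18.pth
```

Because the cache key is only the file name, two different URLs that end in the same file name would share one cached file.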
@@ -1,89 +1,264 @@
1
+ # mirrors _tensor_str.py
1
2
  module Torch
2
3
  module Inspector
3
- # TODO make more performance, especially when summarizing
4
- # how? only read data that will be displayed
5
- def inspect
6
- data =
7
- if numel == 0
8
- "[]"
9
- elsif dim == 0
10
- item
4
+ PRINT_OPTS = {
5
+ precision: 4,
6
+ threshold: 1000,
7
+ edgeitems: 3,
8
+ linewidth: 80,
9
+ sci_mode: nil
10
+ }
11
+
12
+ class Formatter
13
+ def initialize(tensor)
14
+ @floating_dtype = tensor.floating_point?
15
+ @complex_dtype = tensor.complex?
16
+ @int_mode = true
17
+ @sci_mode = false
18
+ @max_width = 1
19
+
20
+ tensor_view = Torch.no_grad { tensor.reshape(-1) }
21
+
22
+ if !@floating_dtype
23
+ tensor_view.each do |value|
24
+ value_str = value.item.to_s
25
+ @max_width = [@max_width, value_str.length].max
26
+ end
11
27
  else
12
- summarize = numel > 1000
28
+ nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
29
+
30
+ # no valid number, do nothing
31
+ return if nonzero_finite_vals.numel == 0
32
+
33
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
34
+ nonzero_finite_abs = nonzero_finite_vals.abs.double
35
+ nonzero_finite_min = nonzero_finite_abs.min.double
36
+ nonzero_finite_max = nonzero_finite_abs.max.double
37
+
38
+ nonzero_finite_vals.each do |value|
39
+ if value.item != value.item.ceil
40
+ @int_mode = false
41
+ break
42
+ end
43
+ end
13
44
 
14
- if dtype == :bool
15
- fmt = "%s"
45
+ if @int_mode
46
+ # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
47
+ # to indicate that the tensor is of floating type. add 1 to the len to account for this.
48
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
49
+ @sci_mode = true
50
+ nonzero_finite_vals.each do |value|
51
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
52
+ @max_width = [@max_width, value_str.length].max
53
+ end
54
+ else
55
+ nonzero_finite_vals.each do |value|
56
+ value_str = "%.0f" % value.item
57
+ @max_width = [@max_width, value_str.length + 1].max
58
+ end
59
+ end
16
60
  else
17
- values = to_a.flatten
18
- abs = values.select { |v| v != 0 }.map(&:abs)
19
- max = abs.max || 1
20
- min = abs.min || 1
21
-
22
- total = 0
23
- if values.any? { |v| v < 0 }
24
- total += 1
61
+ # Check if scientific representation should be used.
62
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
63
+ @sci_mode = true
64
+ nonzero_finite_vals.each do |value|
65
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
66
+ @max_width = [@max_width, value_str.length].max
67
+ end
68
+ else
69
+ nonzero_finite_vals.each do |value|
70
+ value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
71
+ @max_width = [@max_width, value_str.length].max
72
+ end
25
73
  end
74
+ end
75
+ end
26
76
 
27
- if floating_point?
28
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
77
+ @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
78
+ end
29
79
 
30
- all_int = values.all? { |v| v.finite? && v == v.to_i }
31
- decimal = all_int ? 1 : 4
80
+ def width
81
+ @max_width
82
+ end
32
83
 
33
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
84
+ def format(value)
85
+ value = value.item
34
86
 
35
- if sci
36
- fmt = "%#{total}.4e"
37
- else
38
- fmt = "%#{total}.#{decimal}f"
39
- end
40
- else
41
- total += max.to_s.size
42
- fmt = "%#{total}d"
87
+ if @floating_dtype
88
+ if @sci_mode
89
+ ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
90
+ elsif @int_mode
91
+ ret = String.new("%.0f" % value)
92
+ unless value.infinite? || value.nan?
93
+ ret += "."
43
94
  end
95
+ else
96
+ ret = "%.#{PRINT_OPTS[:precision]}f" % value
44
97
  end
98
+ elsif @complex_dtype
99
+ p = PRINT_OPTS[:precision]
100
+ raise NotImplementedYet
101
+ else
102
+ ret = value.to_s
103
+ end
104
+ # Ruby throws error when negative, Python doesn't
105
+ " " * [@max_width - ret.size, 0].max + ret
106
+ end
107
+ end
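For floating-point tensors, `Formatter` decides between integer-style, fixed-point, and scientific output from the nonzero finite values only, and records the widest formatted value. The same decision can be restated over a plain Ruby array of floats. `float_format_params` and `PRECISION` are hypothetical names for this sketch; the thresholds (ratio above 1000, maximum above 1e8, and, outside integer mode, minimum below 1e-4) are the ones in the code above.

```ruby
PRECISION = 4

# Returns [int_mode, sci_mode, max_width] for a list of floats, following the
# same thresholds as Torch::Inspector::Formatter (floating-point branch only).
def float_format_params(values)
  int_mode = true
  sci_mode = false
  max_width = 1

  nonzero_finite = values.select { |v| v.finite? && v != 0 }
  return [int_mode, sci_mode, max_width] if nonzero_finite.empty?

  abs = nonzero_finite.map(&:abs)
  min, max = abs.min, abs.max
  int_mode = nonzero_finite.all? { |v| v == v.ceil }

  if int_mode
    sci_mode = max / min > 1000.0 || max > 1.0e8
  else
    sci_mode = max / min > 1000.0 || max > 1.0e8 || min < 1.0e-4
  end

  widths = nonzero_finite.map do |v|
    if sci_mode
      ("%.#{PRECISION}e" % v).length
    elsif int_mode
      ("%.0f" % v).length + 1 # room for the trailing "."
    else
      ("%.#{PRECISION}f" % v).length
    end
  end
  [int_mode, sci_mode, [max_width, *widths].max]
end

p float_format_params([1.0, 2.0, 3.0])  # [true, false, 2]
p float_format_params([0.00001, 1.0])   # [false, true, 10]
```

Zeros, NaN, and infinities are excluded before the ratio is computed, so a single `0.0` cannot force scientific notation through a division by zero.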
108
+
109
+ def inspect
110
+ Torch.no_grad do
111
+ str_intern(self)
112
+ end
113
+ rescue => e
114
+ # prevent stack error
115
+ puts e.backtrace.join("\n")
116
+ "Error inspecting tensor: #{e.inspect}"
117
+ end
118
+
119
+ private
120
+
121
+ # TODO update
122
+ def str_intern(slf)
123
+ prefix = "tensor("
124
+ indent = prefix.length
125
+ suffixes = []
126
+
127
+ has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
128
+
129
+ if slf.numel == 0 && !slf.sparse?
130
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
131
+ if slf.dim != 1
132
+ suffixes << "size: #{shape.inspect}"
133
+ end
45
134
 
46
- inspect_level(to_a, fmt, dim - 1, 0, summarize)
135
+ # In an empty tensor, there are no elements to infer if the dtype
136
+ # should be int64, so it must be shown explicitly.
137
+ if slf.dtype != :int64
138
+ suffixes << "dtype: #{slf.dtype.inspect}"
47
139
  end
140
+ tensor_str = "[]"
141
+ else
142
+ if !has_default_dtype
143
+ suffixes << "dtype: #{slf.dtype.inspect}"
144
+ end
145
+
146
+ if slf.layout != :strided
147
+ tensor_str = tensor_str(slf.to_dense, indent)
148
+ else
149
+ tensor_str = tensor_str(slf, indent)
150
+ end
151
+ end
48
152
 
49
- attributes = []
50
- if requires_grad
51
- attributes << "requires_grad: true"
153
+ if slf.layout != :strided
154
+ suffixes << "layout: #{slf.layout.inspect}"
52
155
  end
53
- if ![:float32, :int64, :bool].include?(dtype)
54
- attributes << "dtype: #{dtype.inspect}"
156
+
157
+ # TODO show grad_fn
158
+ if slf.requires_grad?
159
+ suffixes << "requires_grad: true"
55
160
  end
56
161
 
57
- "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
162
+ add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
58
163
  end
59
164
 
60
- private
165
+ def add_suffixes(tensor_str, suffixes, indent, force_newline)
166
+ tensor_strs = [tensor_str]
167
+ # rfind in Python returns -1 when not found
168
+ last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
169
+ suffixes.each do |suffix|
170
+ suffix_len = suffix.length
171
+ if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
172
+ tensor_strs << ",\n" + " " * indent + suffix
173
+ last_line_len = indent + suffix_len
174
+ force_newline = false
175
+ else
176
+ tensor_strs.append(", " + suffix)
177
+ last_line_len += suffix_len + 2
178
+ end
179
+ end
180
+ tensor_strs.append(")")
181
+ tensor_strs.join("")
182
+ end
61
183
 
62
- # TODO DRY code
63
- def inspect_level(arr, fmt, total, level, summarize)
64
- if level == total
65
- cols =
66
- if summarize && arr.size > 7
67
- arr[0..2].map { |v| fmt % v } +
68
- ["..."] +
69
- arr[-3..-1].map { |v| fmt % v }
70
- else
71
- arr.map { |v| fmt % v }
72
- end
184
+ def tensor_str(slf, indent)
185
+ return "[]" if slf.numel == 0
186
+
187
+ summarize = slf.numel > PRINT_OPTS[:threshold]
188
+
189
+ if slf.dtype == :float16 || slf.dtype == :bfloat16
190
+ slf = slf.float
191
+ end
192
+ formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
193
+ tensor_str_with_formatter(slf, indent, formatter, summarize)
194
+ end
195
+
196
+ def summarized_data(slf)
197
+ edgeitems = PRINT_OPTS[:edgeitems]
73
198
 
74
- "[#{cols.join(", ")}]"
199
+ dim = slf.dim
200
+ if dim == 0
201
+ slf
202
+ elsif dim == 1
203
+ if slf.size(0) > 2 * edgeitems
204
+ Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
205
+ else
206
+ slf
207
+ end
208
+ elsif slf.size(0) > 2 * edgeitems
209
+ start = edgeitems.times.map { |i| slf[i] }
210
+ finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
211
+ Torch.stack((start + finish).map { |x| summarized_data(x) })
75
212
  else
76
- rows =
77
- if summarize && arr.size > 7
78
- arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
79
- ["..."] +
80
- arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
81
- else
82
- arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
83
- end
213
+ Torch.stack(slf.map { |x| summarized_data(x) })
214
+ end
215
+ end
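`summarized_data` keeps only the first and last `edgeitems` entries along every dimension that is longer than `2 * edgeitems`, so the `Formatter` only measures values that will actually be printed. The sketch below applies the same rule to plain nested Arrays, with array concatenation standing in for `Torch.cat` and `Torch.stack`. `summarize_nested` is a hypothetical name and is not part of torch-rb.

```ruby
EDGEITEMS = 3

# Keeps only the first and last `edgeitems` entries along every dimension of a
# nested Array whose length exceeds 2 * edgeitems, recursing into each kept row.
def summarize_nested(arr, edgeitems = EDGEITEMS)
  return arr unless arr.is_a?(Array)

  kept = arr.length > 2 * edgeitems ? arr.first(edgeitems) + arr.last(edgeitems) : arr
  kept.map { |x| summarize_nested(x, edgeitems) }
end

p summarize_nested((1..10).to_a) # [1, 2, 3, 8, 9, 10]
```

Each level checks its own length, which is what lets a tall, narrow matrix drop rows while keeping every column.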
216
+
217
+ def tensor_str_with_formatter(slf, indent, formatter, summarize)
218
+ edgeitems = PRINT_OPTS[:edgeitems]
219
+
220
+ dim = slf.dim
84
221
 
85
- "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
222
+ return scalar_str(slf, formatter) if dim == 0
223
+ return vector_str(slf, indent, formatter, summarize) if dim == 1
224
+
225
+ if summarize && slf.size(0) > 2 * edgeitems
226
+ slices = (
227
+ [edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
228
+ ["..."] +
229
+ [((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
230
+ )
231
+ else
232
+ slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
86
233
  end
234
+
235
+ tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
236
+ "[" + tensor_str + "]"
237
+ end
238
+
239
+ def scalar_str(slf, formatter)
240
+ formatter.format(slf)
241
+ end
242
+
243
+ def vector_str(slf, indent, formatter, summarize)
244
+ # length includes spaces and comma between elements
245
+ element_length = formatter.width + 2
246
+ elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
247
+ char_per_line = element_length * elements_per_line
248
+
249
+ if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
250
+ data = (
251
+ [slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
252
+ [" ..."] +
253
+ [slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
254
+ )
255
+ else
256
+ data = slf.map { |val| formatter.format(val) }
257
+ end
258
+
259
+ data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
260
+ lines = data_lines.map { |line| line.join(", ") }
261
+ "[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
87
262
  end
88
263
  end
89
264
  end
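`vector_str` wraps a 1-D tensor by treating every element as `width + 2` characters (the value plus `", "`), fitting as many as possible into `linewidth - indent`, and indenting continuation lines one column past the opening bracket. The sketch below restates that wrapping for strings that are already formatted to a common width. `wrap_elements` is a hypothetical helper; the indent of 7 in the example assumes the `tensor(` prefix used by `str_intern`.

```ruby
# Hypothetical re-statement of vector_str's line wrapping for pre-formatted strings.
def wrap_elements(formatted, width, indent, linewidth = 80)
  element_length = width + 2                                # element plus ", "
  per_line = [1, (linewidth - indent) / element_length].max # at least one per line
  lines = formatted.each_slice(per_line).map { |chunk| chunk.join(", ") }
  "[" + lines.join(",\n" + " " * (indent + 1)) + "]"
end

out = wrap_elements(Array.new(20) { |i| "%6d" % i }, 6, 7)
puts out
```

With a width of 6 and an indent of 7, each element takes 8 characters, so `(80 - 7) / 8` gives 9 elements per line and the 20 values print on three lines of 9, 9, and 2.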