torch-rb 0.2.4 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 67c5a0cf556399dc32d73e8793e3aa794c181150f0f42dfa810c4b98a5acf6f2
- data.tar.gz: 0a23f6a42595fb9d599962e88438b964180583ead5b9cce934cc447951b4a389
+ metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
+ data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
  SHA512:
- metadata.gz: c0f8e9e3395d196d7ea6fa4b40d128284d768033e02f4ed7d2dc9adc985015fd0a80d601601dd97438b803b6a3bd7b81f5dbda353bb5dee4247503a24cd755d7
- data.tar.gz: c32a22ebbe1b4dfd77324f62a72d6a128639aac0a99d4c5255b16c606e6f961ae2c8b0dbab5012a9b21faa7409511b79a50676bc8314f181c85f90433433fa8b
+ metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
+ data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
@@ -1,3 +1,30 @@
+ ## 0.3.1 (2020-08-17)
+
+ - Added `create_graph` and `retain_graph` options to `backward` method
+ - Fixed error when `set` not required
+
+ ## 0.3.0 (2020-07-29)
+
+ - Updated LibTorch to 1.6.0
+ - Removed `state_dict` method from optimizers until `load_state_dict` is implemented
+
+ ## 0.2.7 (2020-06-29)
+
+ - Made tensors enumerable
+ - Improved performance of `inspect` method
+
+ ## 0.2.6 (2020-06-29)
+
+ - Added support for indexing with tensors
+ - Added `contiguous` methods
+ - Fixed named parameters for nested parameters
+
+ ## 0.2.5 (2020-06-07)
+
+ - Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
+ - Improved error messages
+ - Fixed tensor slicing
+
  ## 0.2.4 (2020-04-29)

  - Added `to_i` and `to_f` to tensors
@@ -26,7 +53,7 @@
  ## 0.2.0 (2020-04-22)

  - No longer experimental
- - Updated libtorch to 1.5.0
+ - Updated LibTorch to 1.5.0
  - Added support for GPUs and OpenMP
  - Added adaptive pooling layers
  - Tensor `dtype` is now based on Numo type for `Torch.tensor`
@@ -35,7 +62,7 @@

  ## 0.1.8 (2020-01-17)

- - Updated libtorch to 1.4.0
+ - Updated LibTorch to 1.4.0

  ## 0.1.7 (2020-01-10)

data/README.md CHANGED
@@ -2,6 +2,8 @@
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

+ For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)

  ## Installation
@@ -22,12 +24,26 @@ It can take a few minutes to compile the extension.

  ## Getting Started

+ Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
+
+ We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
+
+ ```text
+ ankane/ml-stack:torch-gpu
+ ```
+
+ And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
+
+ ## API
+
  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:

  - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

+ You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
+
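A quick sketch of those conventions in use (values are illustrative; `tensor?`, `add!`, and `numo` come from the list above):

```ruby
require "torch"

x = Torch.tensor([1.0, 2.0, 3.0])
Torch.tensor?(x) # => true  (? instead of is_)
x.add!(1)        # in-place (! instead of _)
x.numo           # Numo array instead of a NumPy array
```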
  ## Tutorial

  Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
@@ -192,7 +208,7 @@ end
  Define a neural network

  ```ruby
- class Net < Torch::NN::Module
+ class MyNet < Torch::NN::Module
  def initialize
  super
  @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +242,7 @@ end
  Create an instance of it

  ```ruby
- net = Net.new
+ net = MyNet.new
  input = Torch.randn(1, 1, 32, 32)
  net.call(input)
  ```
@@ -294,7 +310,7 @@ Torch.save(net.state_dict, "net.pth")
  Load a model

  ```ruby
- net = Net.new
+ net = MyNet.new
  net.load_state_dict(Torch.load("net.pth"))
  net.eval
  ```
@@ -395,7 +411,8 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.2.0 | 1.5.0
+ 0.3.0-0.3.1 | 1.6.0
+ 0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
  0.1.0-0.1.7 | 1.3.1

@@ -413,9 +430,7 @@ Then install the gem (no need for `bundle config`).

  ### Linux

- Deep learning is significantly faster on GPUs.
-
- Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
+ Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.

  Check if CUDA is available

@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
  Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
  };

- void handle_error(c10::Error const & ex)
+ void handle_error(torch::Error const & ex)
  {
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
  }
@@ -32,29 +32,35 @@ extern "C"
  void Init_ext()
  {
  Module rb_mTorch = define_module("Torch");
+ rb_mTorch.add_handler<torch::Error>(handle_error);
  add_torch_functions(rb_mTorch);

  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+ rb_cTensor.add_handler<torch::Error>(handle_error);
  add_tensor_functions(rb_cTensor);

  Module rb_mNN = define_module_under(rb_mTorch, "NN");
+ rb_mNN.add_handler<torch::Error>(handle_error);
  add_nn_functions(rb_mNN);

  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+ .add_handler<torch::Error>(handle_error)
  .define_singleton_method(
  "initial_seed",
  *[]() {
- return at::detail::getDefaultCPUGenerator()->current_seed();
+ return at::detail::getDefaultCPUGenerator().current_seed();
  })
  .define_singleton_method(
  "seed",
  *[]() {
  // TODO set for CUDA when available
- return at::detail::getDefaultCPUGenerator()->seed();
+ auto generator = at::detail::getDefaultCPUGenerator();
+ return generator.seed();
  });

  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+ .add_handler<torch::Error>(handle_error)
  .define_constructor(Constructor<torch::IValue>())
  .define_method("bool?", &torch::IValue::isBool)
  .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -317,7 +323,6 @@ void Init_ext()
  });

  rb_cTensor
- .add_handler<c10::Error>(handle_error)
  .define_method("cuda?", &torch::Tensor::is_cuda)
  .define_method("sparse?", &torch::Tensor::is_sparse)
  .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -325,6 +330,11 @@ void Init_ext()
  .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method(
+ "contiguous?",
+ *[](Tensor& self) {
+ return self.is_contiguous();
+ })
  .define_method(
  "addcmul!",
  *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
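At the Ruby level, the new binding above shows up as `Tensor#contiguous?` (a hedged sketch; the 0.2.6 changelog also lists a `contiguous` method):

```ruby
x = Torch.rand(2, 3).transpose(0, 1) # a transposed view is typically non-contiguous
x.contiguous?            # => false
x.contiguous.contiguous? # => true
```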
@@ -342,8 +352,8 @@ void Init_ext()
  })
  .define_method(
  "_backward",
- *[](Tensor& self, Object gradient) {
- return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
+ *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
+ return self.backward(gradient, create_graph, retain_graph);
  })
  .define_method(
  "grad",
@@ -374,6 +384,21 @@ void Init_ext()
  s << self.device();
  return s.str();
  })
+ .define_method(
+ "_data_str",
+ *[](Tensor& self) {
+ Tensor tensor = self;
+
+ // move to CPU to get data
+ if (tensor.device().type() != torch::kCPU) {
+ torch::Device device("cpu");
+ tensor = tensor.to(device);
+ }
+
+ auto data_ptr = (const char *) tensor.data_ptr();
+ return std::string(data_ptr, tensor.numel() * tensor.element_size());
+ })
+ // TODO figure out a better way to do this
  .define_method(
  "_flat_data",
  *[](Tensor& self) {
@@ -388,46 +413,40 @@ void Init_ext()
  Array a;
  auto dtype = tensor.dtype();

+ Tensor view = tensor.reshape({tensor.numel()});
+
  // TODO DRY if someone knows C++
  if (dtype == torch::kByte) {
- uint8_t* data = tensor.data_ptr<uint8_t>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i]);
+ a.push(view[i].item().to<uint8_t>());
  }
  } else if (dtype == torch::kChar) {
- int8_t* data = tensor.data_ptr<int8_t>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(to_ruby<int>(data[i]));
+ a.push(to_ruby<int>(view[i].item().to<int8_t>()));
  }
  } else if (dtype == torch::kShort) {
- int16_t* data = tensor.data_ptr<int16_t>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i]);
+ a.push(view[i].item().to<int16_t>());
  }
  } else if (dtype == torch::kInt) {
- int32_t* data = tensor.data_ptr<int32_t>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i]);
+ a.push(view[i].item().to<int32_t>());
  }
  } else if (dtype == torch::kLong) {
- int64_t* data = tensor.data_ptr<int64_t>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i]);
+ a.push(view[i].item().to<int64_t>());
  }
  } else if (dtype == torch::kFloat) {
- float* data = tensor.data_ptr<float>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i]);
+ a.push(view[i].item().to<float>());
  }
  } else if (dtype == torch::kDouble) {
- double* data = tensor.data_ptr<double>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i]);
+ a.push(view[i].item().to<double>());
  }
  } else if (dtype == torch::kBool) {
- bool* data = tensor.data_ptr<bool>();
  for (int i = 0; i < tensor.numel(); i++) {
- a.push(data[i] ? True : False);
+ a.push(view[i].item().to<bool>() ? True : False);
  }
  } else {
  throw std::runtime_error("Unsupported type");
@@ -442,14 +461,14 @@ void Init_ext()
  .define_singleton_method(
  "_make_subclass",
  *[](Tensor& rd, bool requires_grad) {
- auto data = torch::autograd::as_variable_ref(rd).detach();
+ auto data = rd.detach();
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  auto var = data.set_requires_grad(requires_grad);
  return Parameter(std::move(var));
  });

  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
- .add_handler<c10::Error>(handle_error)
+ .add_handler<torch::Error>(handle_error)
  .define_constructor(Constructor<torch::TensorOptions>())
  .define_method(
  "dtype",
@@ -555,6 +574,7 @@ void Init_ext()
  });

  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+ .add_handler<torch::Error>(handle_error)
  .define_method(
  "grad",
  *[](Parameter& self) {
@@ -564,6 +584,7 @@ void Init_ext()

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
  .define_constructor(Constructor<torch::Device, std::string>())
+ .add_handler<torch::Error>(handle_error)
  .define_method("index", &torch::Device::index)
  .define_method("index?", &torch::Device::has_index)
  .define_method(
@@ -575,6 +596,7 @@ void Init_ext()
  });

  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+ .add_handler<torch::Error>(handle_error)
  .define_singleton_method("available?", &torch::cuda::is_available)
  .define_singleton_method("device_count", &torch::cuda::device_count);
  }
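From Ruby, the two singleton methods registered above are callable directly:

```ruby
Torch::CUDA.available?   # => true when CUDA (and cuDNN) are set up
Torch::CUDA.device_count # => number of visible GPUs
```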
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
  # change to 0 for Linux pre-cxx11 ABI version
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"

- # TODO check compiler name
- clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i

  # check omp first
  if have_library("omp") || have_library("gomp")
  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
- $CXXFLAGS += " -Xclang" if clang
+ $CXXFLAGS += " -Xclang" if apple_clang
  $CXXFLAGS += " -fopenmp"
  end

- if clang
+ if apple_clang
  # silence ruby/intern.h warning
  $CXXFLAGS += " -Wno-deprecated-register"

@@ -1,6 +1,12 @@
  # ext
  require "torch/ext"

+ # stdlib
+ require "fileutils"
+ require "net/http"
+ require "set"
+ require "tmpdir"
+
  # native functions
  require "torch/native/generator"
  require "torch/native/parser"
@@ -465,11 +471,7 @@ module Torch
  when nil
  IValue.new
  when Array
- if obj.all? { |v| v.is_a?(Tensor) }
- IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
- else
- raise Error, "Unknown list type"
- end
+ IValue.from_list(obj.map { |v| to_ivalue(v) })
  else
  raise Error, "Unknown type: #{obj.class.name}"
  end
@@ -5,12 +5,56 @@ module Torch
  raise NotImplementedYet
  end

- def download_url_to_file(url)
- raise NotImplementedYet
+ def download_url_to_file(url, dst)
+ uri = URI(url)
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+ location = nil
+
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+ request = Net::HTTP::Get.new(uri)
+
+ puts "Downloading #{url}..."
+ File.open(tmp, "wb") do |f|
+ http.request(request) do |response|
+ case response
+ when Net::HTTPRedirection
+ location = response["location"]
+ when Net::HTTPSuccess
+ response.read_body do |chunk|
+ f.write(chunk)
+ end
+ else
+ raise Error, "Bad response"
+ end
+ end
+ end
+ end
+
+ if location
+ download_url_to_file(location, dst)
+ else
+ FileUtils.mv(tmp, dst)
+ nil
+ end
  end

- def load_state_dict_from_url(url)
- raise NotImplementedYet
+ def load_state_dict_from_url(url, model_dir: nil)
+ unless model_dir
+ torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
+ model_dir = File.join(torch_home, "checkpoints")
+ end
+
+ FileUtils.mkdir_p(model_dir)
+
+ parts = URI(url)
+ filename = File.basename(parts.path)
+ cached_file = File.join(model_dir, filename)
+ unless File.exist?(cached_file)
+ # TODO support hash_prefix
+ download_url_to_file(url, cached_file)
+ end
+
+ Torch.load(cached_file)
  end
  end
  end
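A hedged usage sketch of the two `Torch::Hub` methods implemented above (the URL is a placeholder):

```ruby
# caches under TORCH_HOME (default ~/.cache/torch/checkpoints), then loads with Torch.load
state_dict = Torch::Hub.load_state_dict_from_url("https://example.com/weights.pth")

# or download to an explicit destination
Torch::Hub.download_url_to_file("https://example.com/weights.pth", "weights.pth")
```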
@@ -1,89 +1,264 @@
+ # mirrors _tensor_str.py
  module Torch
  module Inspector
- # TODO make more performance, especially when summarizing
- # how? only read data that will be displayed
- def inspect
- data =
- if numel == 0
- "[]"
- elsif dim == 0
- item
+ PRINT_OPTS = {
+ precision: 4,
+ threshold: 1000,
+ edgeitems: 3,
+ linewidth: 80,
+ sci_mode: nil
+ }
+
+ class Formatter
+ def initialize(tensor)
+ @floating_dtype = tensor.floating_point?
+ @complex_dtype = tensor.complex?
+ @int_mode = true
+ @sci_mode = false
+ @max_width = 1
+
+ tensor_view = Torch.no_grad { tensor.reshape(-1) }
+
+ if !@floating_dtype
+ tensor_view.each do |value|
+ value_str = value.item.to_s
+ @max_width = [@max_width, value_str.length].max
+ end
  else
- summarize = numel > 1000
+ nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
+
+ # no valid number, do nothing
+ return if nonzero_finite_vals.numel == 0
+
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
+ nonzero_finite_abs = nonzero_finite_vals.abs.double
+ nonzero_finite_min = nonzero_finite_abs.min.double
+ nonzero_finite_max = nonzero_finite_abs.max.double
+
+ nonzero_finite_vals.each do |value|
+ if value.item != value.item.ceil
+ @int_mode = false
+ break
+ end
+ end

- if dtype == :bool
- fmt = "%s"
+ if @int_mode
+ # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
+ # to indicate that the tensor is of floating type. add 1 to the len to account for this.
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
+ @sci_mode = true
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
+ else
+ nonzero_finite_vals.each do |value|
+ value_str = "%.0f" % value.item
+ @max_width = [@max_width, value_str.length + 1].max
+ end
+ end
  else
- values = to_a.flatten
- abs = values.select { |v| v != 0 }.map(&:abs)
- max = abs.max || 1
- min = abs.min || 1
-
- total = 0
- if values.any? { |v| v < 0 }
- total += 1
+ # Check if scientific representation should be used.
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
+ @sci_mode = true
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
+ else
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
  end
+ end
+ end

- if floating_point?
- sci = max > 1e8 || max < 1e-4
+ @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
+ end

- all_int = values.all? { |v| v.finite? && v == v.to_i }
- decimal = all_int ? 1 : 4
+ def width
+ @max_width
+ end

- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+ def format(value)
+ value = value.item

- if sci
- fmt = "%#{total}.4e"
- else
- fmt = "%#{total}.#{decimal}f"
- end
- else
- total += max.to_s.size
- fmt = "%#{total}d"
+ if @floating_dtype
+ if @sci_mode
+ ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
+ elsif @int_mode
+ ret = String.new("%.0f" % value)
+ unless value.infinite? || value.nan?
+ ret += "."
  end
+ else
+ ret = "%.#{PRINT_OPTS[:precision]}f" % value
  end
+ elsif @complex_dtype
+ p = PRINT_OPTS[:precision]
+ raise NotImplementedYet
+ else
+ ret = value.to_s
+ end
+ # Ruby throws error when negative, Python doesn't
+ " " * [@max_width - ret.size, 0].max + ret
+ end
+ end
+
+ def inspect
+ Torch.no_grad do
+ str_intern(self)
+ end
+ rescue => e
+ # prevent stack error
+ puts e.backtrace.join("\n")
+ "Error inspecting tensor: #{e.inspect}"
+ end
+
+ private
+
+ # TODO update
+ def str_intern(slf)
+ prefix = "tensor("
+ indent = prefix.length
+ suffixes = []
+
+ has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
+
+ if slf.numel == 0 && !slf.sparse?
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
+ if slf.dim != 1
+ suffixes << "size: #{shape.inspect}"
+ end

- inspect_level(to_a, fmt, dim - 1, 0, summarize)
+ # In an empty tensor, there are no elements to infer if the dtype
+ # should be int64, so it must be shown explicitly.
+ if slf.dtype != :int64
+ suffixes << "dtype: #{slf.dtype.inspect}"
  end
+ tensor_str = "[]"
+ else
+ if !has_default_dtype
+ suffixes << "dtype: #{slf.dtype.inspect}"
+ end
+
+ if slf.layout != :strided
+ tensor_str = tensor_str(slf.to_dense, indent)
+ else
+ tensor_str = tensor_str(slf, indent)
+ end
+ end

- attributes = []
- if requires_grad
- attributes << "requires_grad: true"
+ if slf.layout != :strided
+ suffixes << "layout: #{slf.layout.inspect}"
  end
- if ![:float32, :int64, :bool].include?(dtype)
- attributes << "dtype: #{dtype.inspect}"
+
+ # TODO show grad_fn
+ if slf.requires_grad?
+ suffixes << "requires_grad: true"
  end

- "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
+ add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
  end

- private
+ def add_suffixes(tensor_str, suffixes, indent, force_newline)
+ tensor_strs = [tensor_str]
+ # rfind in Python returns -1 when not found
+ last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
+ suffixes.each do |suffix|
+ suffix_len = suffix.length
+ if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
+ tensor_strs << ",\n" + " " * indent + suffix
+ last_line_len = indent + suffix_len
+ force_newline = false
+ else
+ tensor_strs.append(", " + suffix)
+ last_line_len += suffix_len + 2
+ end
+ end
+ tensor_strs.append(")")
+ tensor_strs.join("")
+ end

- # TODO DRY code
- def inspect_level(arr, fmt, total, level, summarize)
- if level == total
- cols =
- if summarize && arr.size > 7
- arr[0..2].map { |v| fmt % v } +
- ["..."] +
- arr[-3..-1].map { |v| fmt % v }
- else
- arr.map { |v| fmt % v }
- end
+ def tensor_str(slf, indent)
+ return "[]" if slf.numel == 0
+
+ summarize = slf.numel > PRINT_OPTS[:threshold]
+
+ if slf.dtype == :float16 || slf.dtype == :bfloat16
+ slf = slf.float
+ end
+ formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
+ tensor_str_with_formatter(slf, indent, formatter, summarize)
+ end
+
+ def summarized_data(slf)
+ edgeitems = PRINT_OPTS[:edgeitems]

- "[#{cols.join(", ")}]"
+ dim = slf.dim
+ if dim == 0
+ slf
+ elsif dim == 1
+ if size(0) > 2 * edgeitems
+ Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
+ else
+ slf
+ end
+ elsif slf.size(0) > 2 * edgeitems
+ start = edgeitems.times.map { |i| slf[i] }
+ finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
+ Torch.stack((start + finish).map { |x| summarized_data(x) })
  else
- rows =
- if summarize && arr.size > 7
- arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
- ["..."] +
- arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
- else
- arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
- end
+ Torch.stack(slf.map { |x| summarized_data(x) })
+ end
+ end
+
+ def tensor_str_with_formatter(slf, indent, formatter, summarize)
+ edgeitems = PRINT_OPTS[:edgeitems]
+
+ dim = slf.dim

- "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+ return scalar_str(slf, formatter) if dim == 0
+ return vector_str(slf, indent, formatter, summarize) if dim == 1
+
+ if summarize && slf.size(0) > 2 * edgeitems
+ slices = (
+ [edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
+ ["..."] +
+ [((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
+ )
+ else
+ slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
  end
+
+ tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
+ "[" + tensor_str + "]"
+ end
+
+ def scalar_str(slf, formatter)
+ formatter.format(slf)
+ end
+
+ def vector_str(slf, indent, formatter, summarize)
+ # length includes spaces and comma between elements
+ element_length = formatter.width + 2
+ elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
+ char_per_line = element_length * elements_per_line
+
+ if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
+ data = (
+ [slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
+ [" ..."] +
+ [slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
+ )
+ else
+ data = slf.map { |val| formatter.format(val) }
+ end
+
+ data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
+ lines = data_lines.map { |line| line.join(", ") }
+ "[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
  end
  end
  end
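The rewritten inspector mirrors PyTorch's `_tensor_str.py`, so printed tensors now get per-dtype formatting, summarization past `threshold` elements, and line wrapping. A minimal way to exercise it (output shown approximately; values vary):

```ruby
x = Torch.rand(2, 3)
puts x.inspect
# => tensor([[0.1234, 0.5678, 0.9012],
#            [0.3456, 0.7890, 0.2345]])
```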