torch-rb 0.2.6 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
4
- data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
3
+ metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
4
+ data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
5
5
  SHA512:
6
- metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
7
- data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
6
+ metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
7
+ data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea
@@ -1,3 +1,32 @@
1
+ ## 0.3.3 (2020-08-25)
2
+
3
+ - Added spectral ops
4
+ - Fixed tensor indexing
5
+
6
+ ## 0.3.2 (2020-08-24)
7
+
8
+ - Added `enable_grad` method
9
+ - Added `random_split` method
10
+ - Added `collate_fn` option to `DataLoader`
11
+ - Added `grad=` method to `Tensor`
12
+ - Fixed error with `grad` method when empty
13
+ - Fixed `EmbeddingBag`
14
+
15
+ ## 0.3.1 (2020-08-17)
16
+
17
+ - Added `create_graph` and `retain_graph` options to `backward` method
18
+ - Fixed error when `set` not required
19
+
20
+ ## 0.3.0 (2020-07-29)
21
+
22
+ - Updated LibTorch to 1.6.0
23
+ - Removed `state_dict` method from optimizers until `load_state_dict` is implemented
24
+
25
+ ## 0.2.7 (2020-06-29)
26
+
27
+ - Made tensors enumerable
28
+ - Improved performance of `inspect` method
29
+
1
30
  ## 0.2.6 (2020-06-29)
2
31
 
3
32
  - Added support for indexing with tensors
@@ -38,7 +67,7 @@
38
67
  ## 0.2.0 (2020-04-22)
39
68
 
40
69
  - No longer experimental
41
- - Updated libtorch to 1.5.0
70
+ - Updated LibTorch to 1.5.0
42
71
  - Added support for GPUs and OpenMP
43
72
  - Added adaptive pooling layers
44
73
  - Tensor `dtype` is now based on Numo type for `Torch.tensor`
@@ -47,7 +76,7 @@
47
76
 
48
77
  ## 0.1.8 (2020-01-17)
49
78
 
50
- - Updated libtorch to 1.4.0
79
+ - Updated LibTorch to 1.4.0
51
80
 
52
81
  ## 0.1.7 (2020-01-10)
53
82
 
data/README.md CHANGED
@@ -2,7 +2,11 @@
2
2
 
3
3
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
4
 
5
- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
5
+ Check out:
6
+
7
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
8
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
9
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
6
10
 
7
11
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
8
12
 
@@ -42,6 +46,8 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
42
46
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
43
47
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
44
48
 
49
+ You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
50
+
45
51
  ## Tutorial
46
52
 
47
53
  Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
@@ -409,7 +415,8 @@ Here’s the list of compatible versions.
409
415
 
410
416
  Torch.rb | LibTorch
411
417
  --- | ---
412
- 0.2.0+ | 1.5.0+
418
+ 0.3.0-0.3.1 | 1.6.0
419
+ 0.2.0-0.2.7 | 1.5.0-1.5.1
413
420
  0.1.8 | 1.4.0
414
421
  0.1.0-0.1.7 | 1.3.1
415
422
 
@@ -16,6 +16,7 @@
16
16
  #include "nn_functions.hpp"
17
17
 
18
18
  using namespace Rice;
19
+ using torch::indexing::TensorIndex;
19
20
 
20
21
  // need to make a distinction between parameters and tensors
21
22
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
28
29
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
30
  }
30
31
 
32
+ std::vector<TensorIndex> index_vector(Array a) {
33
+ auto indices = std::vector<TensorIndex>();
34
+ indices.reserve(a.size());
35
+ for (size_t i = 0; i < a.size(); i++) {
36
+ indices.push_back(from_ruby<TensorIndex>(a[i]));
37
+ }
38
+ return indices;
39
+ }
40
+
31
41
  extern "C"
32
42
  void Init_ext()
33
43
  {
@@ -48,15 +58,23 @@ void Init_ext()
48
58
  .define_singleton_method(
49
59
  "initial_seed",
50
60
  *[]() {
51
- return at::detail::getDefaultCPUGenerator()->current_seed();
61
+ return at::detail::getDefaultCPUGenerator().current_seed();
52
62
  })
53
63
  .define_singleton_method(
54
64
  "seed",
55
65
  *[]() {
56
66
  // TODO set for CUDA when available
57
- return at::detail::getDefaultCPUGenerator()->seed();
67
+ auto generator = at::detail::getDefaultCPUGenerator();
68
+ return generator.seed();
58
69
  });
59
70
 
71
+ Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
72
+ .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
73
+ .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
74
+ .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
75
+ .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
76
+ .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
77
+
60
78
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
61
79
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
62
80
  .add_handler<torch::Error>(handle_error)
@@ -329,6 +347,18 @@ void Init_ext()
329
347
  .define_method("numel", &torch::Tensor::numel)
330
348
  .define_method("element_size", &torch::Tensor::element_size)
331
349
  .define_method("requires_grad", &torch::Tensor::requires_grad)
350
+ .define_method(
351
+ "_index",
352
+ *[](Tensor& self, Array indices) {
353
+ auto vec = index_vector(indices);
354
+ return self.index(vec);
355
+ })
356
+ .define_method(
357
+ "_index_put_custom",
358
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
359
+ auto vec = index_vector(indices);
360
+ return self.index_put_(vec, value);
361
+ })
332
362
  .define_method(
333
363
  "contiguous?",
334
364
  *[](Tensor& self) {
@@ -351,13 +381,19 @@ void Init_ext()
351
381
  })
352
382
  .define_method(
353
383
  "_backward",
354
- *[](Tensor& self, Object gradient) {
355
- return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
384
+ *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
385
+ return self.backward(gradient, create_graph, retain_graph);
356
386
  })
357
387
  .define_method(
358
388
  "grad",
359
389
  *[](Tensor& self) {
360
- return self.grad();
390
+ auto grad = self.grad();
391
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
392
+ })
393
+ .define_method(
394
+ "grad=",
395
+ *[](Tensor& self, torch::Tensor& grad) {
396
+ self.grad() = grad;
361
397
  })
362
398
  .define_method(
363
399
  "_dtype",
@@ -460,7 +496,7 @@ void Init_ext()
460
496
  .define_singleton_method(
461
497
  "_make_subclass",
462
498
  *[](Tensor& rd, bool requires_grad) {
463
- auto data = torch::autograd::as_variable_ref(rd).detach();
499
+ auto data = rd.detach();
464
500
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
465
501
  auto var = data.set_requires_grad(requires_grad);
466
502
  return Parameter(std::move(var));
@@ -501,6 +537,7 @@ void Init_ext()
501
537
  });
502
538
 
503
539
  Module rb_mInit = define_module_under(rb_mNN, "Init")
540
+ .add_handler<torch::Error>(handle_error)
504
541
  .define_singleton_method(
505
542
  "_calculate_gain",
506
543
  *[](NonlinearityType nonlinearity, double param) {
@@ -579,11 +616,16 @@ void Init_ext()
579
616
  *[](Parameter& self) {
580
617
  auto grad = self.grad();
581
618
  return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
619
+ })
620
+ .define_method(
621
+ "grad=",
622
+ *[](Parameter& self, torch::Tensor& grad) {
623
+ self.grad() = grad;
582
624
  });
583
625
 
584
626
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
585
- .define_constructor(Constructor<torch::Device, std::string>())
586
627
  .add_handler<torch::Error>(handle_error)
628
+ .define_constructor(Constructor<torch::Device, std::string>())
587
629
  .define_method("index", &torch::Device::index)
588
630
  .define_method("index?", &torch::Device::has_index)
589
631
  .define_method(
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
8
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
- # TODO check compiler name
11
- clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
10
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
12
11
 
13
12
  # check omp first
14
13
  if have_library("omp") || have_library("gomp")
15
14
  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
- $CXXFLAGS += " -Xclang" if clang
15
+ $CXXFLAGS += " -Xclang" if apple_clang
17
16
  $CXXFLAGS += " -fopenmp"
18
17
  end
19
18
 
20
- if clang
19
+ if apple_clang
21
20
  # silence ruby/intern.h warning
22
21
  $CXXFLAGS += " -Wno-deprecated-register"
23
22
 
@@ -9,6 +9,10 @@
9
9
 
10
10
  using namespace Rice;
11
11
 
12
+ using torch::Device;
13
+ using torch::ScalarType;
14
+ using torch::Tensor;
15
+
12
16
  // need to wrap torch::IntArrayRef() since
13
17
  // it doesn't own underlying data
14
18
  class IntArrayRef {
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
174
178
  return MyReduction(x);
175
179
  }
176
180
 
177
- typedef torch::Tensor Tensor;
178
-
179
181
  class OptionalTensor {
180
182
  Object value;
181
183
  public:
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
197
199
  return OptionalTensor(x);
198
200
  }
199
201
 
200
- class ScalarType {
201
- Object value;
202
- public:
203
- ScalarType(Object o) {
204
- value = o;
205
- }
206
- operator at::ScalarType() {
207
- throw std::runtime_error("ScalarType arguments not implemented yet");
208
- }
209
- };
210
-
211
202
  template<>
212
203
  inline
213
- ScalarType from_ruby<ScalarType>(Object x)
204
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
214
205
  {
215
- return ScalarType(x);
206
+ if (x.is_nil()) {
207
+ return torch::nullopt;
208
+ } else {
209
+ return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
210
+ }
216
211
  }
217
212
 
218
- class OptionalScalarType {
219
- Object value;
220
- public:
221
- OptionalScalarType(Object o) {
222
- value = o;
223
- }
224
- operator c10::optional<at::ScalarType>() {
225
- if (value.is_nil()) {
226
- return c10::nullopt;
227
- }
228
- return ScalarType(value);
229
- }
230
- };
231
-
232
213
  template<>
233
214
  inline
234
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
215
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
235
216
  {
236
- return OptionalScalarType(x);
217
+ if (x.is_nil()) {
218
+ return torch::nullopt;
219
+ } else {
220
+ return torch::optional<int64_t>{from_ruby<int64_t>(x)};
221
+ }
237
222
  }
238
223
 
239
- typedef torch::Device Device;
240
-
241
224
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
242
225
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
243
226
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -4,6 +4,7 @@ require "torch/ext"
4
4
  # stdlib
5
5
  require "fileutils"
6
6
  require "net/http"
7
+ require "set"
7
8
  require "tmpdir"
8
9
 
9
10
  # native functions
@@ -178,8 +179,10 @@ require "torch/nn/functional"
178
179
  require "torch/nn/init"
179
180
 
180
181
  # utils
182
+ require "torch/utils/data"
181
183
  require "torch/utils/data/data_loader"
182
184
  require "torch/utils/data/dataset"
185
+ require "torch/utils/data/subset"
183
186
  require "torch/utils/data/tensor_dataset"
184
187
 
185
188
  # hub
@@ -315,6 +318,16 @@ module Torch
315
318
  end
316
319
  end
317
320
 
321
+ def enable_grad
322
+ previous_value = grad_enabled?
323
+ begin
324
+ _set_grad_enabled(true)
325
+ yield
326
+ ensure
327
+ _set_grad_enabled(previous_value)
328
+ end
329
+ end
330
+
318
331
  def device(str)
319
332
  Device.new(str)
320
333
  end
@@ -447,6 +460,17 @@ module Torch
447
460
  zeros(input.size, **like_options(input, options))
448
461
  end
449
462
 
463
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
464
+ if center
465
+ signal_dim = input.dim
466
+ extended_shape = [1] * (3 - signal_dim) + input.size
467
+ pad = n_fft.div(2).to_i
468
+ input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
469
+ input = input.view(input.shape[-signal_dim..-1])
470
+ end
471
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
472
+ end
473
+
450
474
  private
451
475
 
452
476
  def to_ivalue(obj)
@@ -470,11 +494,7 @@ module Torch
470
494
  when nil
471
495
  IValue.new
472
496
  when Array
473
- if obj.all? { |v| v.is_a?(Tensor) }
474
- IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
475
- else
476
- raise Error, "Unknown list type"
477
- end
497
+ IValue.from_list(obj.map { |v| to_ivalue(v) })
478
498
  else
479
499
  raise Error, "Unknown type: #{obj.class.name}"
480
500
  end
@@ -7,25 +7,26 @@ module Torch
7
7
 
8
8
  def download_url_to_file(url, dst)
9
9
  uri = URI(url)
10
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
10
+ tmp = nil
11
11
  location = nil
12
12
 
13
+ puts "Downloading #{url}..."
13
14
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
15
  request = Net::HTTP::Get.new(uri)
15
16
 
16
- puts "Downloading #{url}..."
17
- File.open(tmp, "wb") do |f|
18
- http.request(request) do |response|
19
- case response
20
- when Net::HTTPRedirection
21
- location = response["location"]
22
- when Net::HTTPSuccess
17
+ http.request(request) do |response|
18
+ case response
19
+ when Net::HTTPRedirection
20
+ location = response["location"]
21
+ when Net::HTTPSuccess
22
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
23
+ File.open(tmp, "wb") do |f|
23
24
  response.read_body do |chunk|
24
25
  f.write(chunk)
25
26
  end
26
- else
27
- raise Error, "Bad response"
28
27
  end
28
+ else
29
+ raise Error, "Bad response"
29
30
  end
30
31
  end
31
32
  end
@@ -1,89 +1,264 @@
1
+ # mirrors _tensor_str.py
1
2
  module Torch
2
3
  module Inspector
3
- # TODO make more performant, especially when summarizing
4
- # how? only read data that will be displayed
5
- def inspect
6
- data =
7
- if numel == 0
8
- "[]"
9
- elsif dim == 0
10
- item
4
+ PRINT_OPTS = {
5
+ precision: 4,
6
+ threshold: 1000,
7
+ edgeitems: 3,
8
+ linewidth: 80,
9
+ sci_mode: nil
10
+ }
11
+
12
+ class Formatter
13
+ def initialize(tensor)
14
+ @floating_dtype = tensor.floating_point?
15
+ @complex_dtype = tensor.complex?
16
+ @int_mode = true
17
+ @sci_mode = false
18
+ @max_width = 1
19
+
20
+ tensor_view = Torch.no_grad { tensor.reshape(-1) }
21
+
22
+ if !@floating_dtype
23
+ tensor_view.each do |value|
24
+ value_str = value.item.to_s
25
+ @max_width = [@max_width, value_str.length].max
26
+ end
11
27
  else
12
- summarize = numel > 1000
28
+ nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
29
+
30
+ # no valid number, do nothing
31
+ return if nonzero_finite_vals.numel == 0
32
+
33
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
34
+ nonzero_finite_abs = nonzero_finite_vals.abs.double
35
+ nonzero_finite_min = nonzero_finite_abs.min.double
36
+ nonzero_finite_max = nonzero_finite_abs.max.double
37
+
38
+ nonzero_finite_vals.each do |value|
39
+ if value.item != value.item.ceil
40
+ @int_mode = false
41
+ break
42
+ end
43
+ end
13
44
 
14
- if dtype == :bool
15
- fmt = "%s"
45
+ if @int_mode
46
+ # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
47
+ # to indicate that the tensor is of floating type. add 1 to the len to account for this.
48
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
49
+ @sci_mode = true
50
+ nonzero_finite_vals.each do |value|
51
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
52
+ @max_width = [@max_width, value_str.length].max
53
+ end
54
+ else
55
+ nonzero_finite_vals.each do |value|
56
+ value_str = "%.0f" % value.item
57
+ @max_width = [@max_width, value_str.length + 1].max
58
+ end
59
+ end
16
60
  else
17
- values = _flat_data
18
- abs = values.select { |v| v != 0 }.map(&:abs)
19
- max = abs.max || 1
20
- min = abs.min || 1
21
-
22
- total = 0
23
- if values.any? { |v| v < 0 }
24
- total += 1
61
+ # Check if scientific representation should be used.
62
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
63
+ @sci_mode = true
64
+ nonzero_finite_vals.each do |value|
65
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
66
+ @max_width = [@max_width, value_str.length].max
67
+ end
68
+ else
69
+ nonzero_finite_vals.each do |value|
70
+ value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
71
+ @max_width = [@max_width, value_str.length].max
72
+ end
25
73
  end
74
+ end
75
+ end
26
76
 
27
- if floating_point?
28
- sci = max > 1e8 || max < 1e-4
77
+ @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
78
+ end
29
79
 
30
- all_int = values.all? { |v| v.finite? && v == v.to_i }
31
- decimal = all_int ? 1 : 4
80
+ def width
81
+ @max_width
82
+ end
32
83
 
33
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
84
+ def format(value)
85
+ value = value.item
34
86
 
35
- if sci
36
- fmt = "%#{total}.4e"
37
- else
38
- fmt = "%#{total}.#{decimal}f"
39
- end
40
- else
41
- total += max.to_s.size
42
- fmt = "%#{total}d"
87
+ if @floating_dtype
88
+ if @sci_mode
89
+ ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
90
+ elsif @int_mode
91
+ ret = String.new("%.0f" % value)
92
+ unless value.infinite? || value.nan?
93
+ ret += "."
43
94
  end
95
+ else
96
+ ret = "%.#{PRINT_OPTS[:precision]}f" % value
44
97
  end
98
+ elsif @complex_dtype
99
+ p = PRINT_OPTS[:precision]
100
+ raise NotImplementedYet
101
+ else
102
+ ret = value.to_s
103
+ end
104
+ # Ruby throws error when negative, Python doesn't
105
+ " " * [@max_width - ret.size, 0].max + ret
106
+ end
107
+ end
108
+
109
+ def inspect
110
+ Torch.no_grad do
111
+ str_intern(self)
112
+ end
113
+ rescue => e
114
+ # prevent stack error
115
+ puts e.backtrace.join("\n")
116
+ "Error inspecting tensor: #{e.inspect}"
117
+ end
118
+
119
+ private
120
+
121
+ # TODO update
122
+ def str_intern(slf)
123
+ prefix = "tensor("
124
+ indent = prefix.length
125
+ suffixes = []
126
+
127
+ has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
128
+
129
+ if slf.numel == 0 && !slf.sparse?
130
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
131
+ if slf.dim != 1
132
+ suffixes << "size: #{shape.inspect}"
133
+ end
45
134
 
46
- inspect_level(to_a, fmt, dim - 1, 0, summarize)
135
+ # In an empty tensor, there are no elements to infer if the dtype
136
+ # should be int64, so it must be shown explicitly.
137
+ if slf.dtype != :int64
138
+ suffixes << "dtype: #{slf.dtype.inspect}"
47
139
  end
140
+ tensor_str = "[]"
141
+ else
142
+ if !has_default_dtype
143
+ suffixes << "dtype: #{slf.dtype.inspect}"
144
+ end
145
+
146
+ if slf.layout != :strided
147
+ tensor_str = tensor_str(slf.to_dense, indent)
148
+ else
149
+ tensor_str = tensor_str(slf, indent)
150
+ end
151
+ end
48
152
 
49
- attributes = []
50
- if requires_grad
51
- attributes << "requires_grad: true"
153
+ if slf.layout != :strided
154
+ suffixes << "layout: #{slf.layout.inspect}"
52
155
  end
53
- if ![:float32, :int64, :bool].include?(dtype)
54
- attributes << "dtype: #{dtype.inspect}"
156
+
157
+ # TODO show grad_fn
158
+ if slf.requires_grad?
159
+ suffixes << "requires_grad: true"
55
160
  end
56
161
 
57
- "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
162
+ add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
58
163
  end
59
164
 
60
- private
165
+ def add_suffixes(tensor_str, suffixes, indent, force_newline)
166
+ tensor_strs = [tensor_str]
167
+ # rfind in Python returns -1 when not found
168
+ last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
169
+ suffixes.each do |suffix|
170
+ suffix_len = suffix.length
171
+ if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
172
+ tensor_strs << ",\n" + " " * indent + suffix
173
+ last_line_len = indent + suffix_len
174
+ force_newline = false
175
+ else
176
+ tensor_strs.append(", " + suffix)
177
+ last_line_len += suffix_len + 2
178
+ end
179
+ end
180
+ tensor_strs.append(")")
181
+ tensor_strs.join("")
182
+ end
61
183
 
62
- # TODO DRY code
63
- def inspect_level(arr, fmt, total, level, summarize)
64
- if level == total
65
- cols =
66
- if summarize && arr.size > 7
67
- arr[0..2].map { |v| fmt % v } +
68
- ["..."] +
69
- arr[-3..-1].map { |v| fmt % v }
70
- else
71
- arr.map { |v| fmt % v }
72
- end
184
+ def tensor_str(slf, indent)
185
+ return "[]" if slf.numel == 0
186
+
187
+ summarize = slf.numel > PRINT_OPTS[:threshold]
188
+
189
+ if slf.dtype == :float16 || slf.dtype == :bfloat16
190
+ slf = slf.float
191
+ end
192
+ formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
193
+ tensor_str_with_formatter(slf, indent, formatter, summarize)
194
+ end
195
+
196
+ def summarized_data(slf)
197
+ edgeitems = PRINT_OPTS[:edgeitems]
73
198
 
74
- "[#{cols.join(", ")}]"
199
+ dim = slf.dim
200
+ if dim == 0
201
+ slf
202
+ elsif dim == 1
203
+ if size(0) > 2 * edgeitems
204
+ Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
205
+ else
206
+ slf
207
+ end
208
+ elsif slf.size(0) > 2 * edgeitems
209
+ start = edgeitems.times.map { |i| slf[i] }
210
+ finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
211
+ Torch.stack((start + finish).map { |x| summarized_data(x) })
75
212
  else
76
- rows =
77
- if summarize && arr.size > 7
78
- arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
79
- ["..."] +
80
- arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
81
- else
82
- arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
83
- end
213
+ Torch.stack(slf.map { |x| summarized_data(x) })
214
+ end
215
+ end
216
+
217
+ def tensor_str_with_formatter(slf, indent, formatter, summarize)
218
+ edgeitems = PRINT_OPTS[:edgeitems]
219
+
220
+ dim = slf.dim
84
221
 
85
- "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
222
+ return scalar_str(slf, formatter) if dim == 0
223
+ return vector_str(slf, indent, formatter, summarize) if dim == 1
224
+
225
+ if summarize && slf.size(0) > 2 * edgeitems
226
+ slices = (
227
+ [edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
228
+ ["..."] +
229
+ [((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
230
+ )
231
+ else
232
+ slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
86
233
  end
234
+
235
+ tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
236
+ "[" + tensor_str + "]"
237
+ end
238
+
239
+ def scalar_str(slf, formatter)
240
+ formatter.format(slf)
241
+ end
242
+
243
+ def vector_str(slf, indent, formatter, summarize)
244
+ # length includes spaces and comma between elements
245
+ element_length = formatter.width + 2
246
+ elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
247
+ char_per_line = element_length * elements_per_line
248
+
249
+ if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
250
+ data = (
251
+ [slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
252
+ [" ..."] +
253
+ [slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
254
+ )
255
+ else
256
+ data = slf.map { |val| formatter.format(val) }
257
+ end
258
+
259
+ data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
260
+ lines = data_lines.map { |line| line.join(", ") }
261
+ "[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
87
262
  end
88
263
  end
89
264
  end