torch-rb 0.2.6 → 0.3.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
- data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
+ metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
+ data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
  SHA512:
- metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
- data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
+ metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
+ data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea

@@ -1,3 +1,32 @@
+ ## 0.3.3 (2020-08-25)
+
+ - Added spectral ops
+ - Fixed tensor indexing
+
+ ## 0.3.2 (2020-08-24)
+
+ - Added `enable_grad` method
+ - Added `random_split` method
+ - Added `collate_fn` option to `DataLoader`
+ - Added `grad=` method to `Tensor`
+ - Fixed error with `grad` method when empty
+ - Fixed `EmbeddingBag`
+
+ ## 0.3.1 (2020-08-17)
+
+ - Added `create_graph` and `retain_graph` options to `backward` method
+ - Fixed error when `set` not required
+
+ ## 0.3.0 (2020-07-29)
+
+ - Updated LibTorch to 1.6.0
+ - Removed `state_dict` method from optimizers until `load_state_dict` is implemented
+
+ ## 0.2.7 (2020-06-29)
+
+ - Made tensors enumerable
+ - Improved performance of `inspect` method
+
  ## 0.2.6 (2020-06-29)

  - Added support for indexing with tensors
@@ -38,7 +67,7 @@
  ## 0.2.0 (2020-04-22)

  - No longer experimental
- - Updated libtorch to 1.5.0
+ - Updated LibTorch to 1.5.0
  - Added support for GPUs and OpenMP
  - Added adaptive pooling layers
  - Tensor `dtype` is now based on Numo type for `Torch.tensor`
@@ -47,7 +76,7 @@

  ## 0.1.8 (2020-01-17)

- - Updated libtorch to 1.4.0
+ - Updated LibTorch to 1.4.0

  ## 0.1.7 (2020-01-10)

data/README.md CHANGED
@@ -2,7 +2,11 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+ Check out:
+
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks

  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)

@@ -42,6 +46,8 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

+ You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
+
  ## Tutorial

  Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
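These conventions are usually enough to port small PyTorch snippets line by line. A rough sketch using only calls documented in this README (values are illustrative):

```ruby
require "torch"

# torch.ones(2, 3) in Python becomes:
x = Torch.ones(2, 3)

# predicate methods use "?" instead of "is_"
Torch.tensor?(x) # => true

# convert to a Numo array instead of a NumPy array
x.numo
```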
@@ -409,7 +415,8 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.2.0+ | 1.5.0+
+ 0.3.0-0.3.1 | 1.6.0
+ 0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
  0.1.0-0.1.7 | 1.3.1


@@ -16,6 +16,7 @@
  #include "nn_functions.hpp"

  using namespace Rice;
+ using torch::indexing::TensorIndex;

  // need to make a distinction between parameters and tensors
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
  }

+ std::vector<TensorIndex> index_vector(Array a) {
+ auto indices = std::vector<TensorIndex>();
+ indices.reserve(a.size());
+ for (size_t i = 0; i < a.size(); i++) {
+ indices.push_back(from_ruby<TensorIndex>(a[i]));
+ }
+ return indices;
+ }
+
  extern "C"
  void Init_ext()
  {
@@ -48,15 +58,23 @@ void Init_ext()
  .define_singleton_method(
  "initial_seed",
  *[]() {
- return at::detail::getDefaultCPUGenerator()->current_seed();
+ return at::detail::getDefaultCPUGenerator().current_seed();
  })
  .define_singleton_method(
  "seed",
  *[]() {
  // TODO set for CUDA when available
- return at::detail::getDefaultCPUGenerator()->seed();
+ auto generator = at::detail::getDefaultCPUGenerator();
+ return generator.seed();
  });

+ Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
+ .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
+ .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
+ .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
+ .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
+ .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
+
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
  .add_handler<torch::Error>(handle_error)
@@ -329,6 +347,18 @@ void Init_ext()
  .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method(
+ "_index",
+ *[](Tensor& self, Array indices) {
+ auto vec = index_vector(indices);
+ return self.index(vec);
+ })
+ .define_method(
+ "_index_put_custom",
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
+ auto vec = index_vector(indices);
+ return self.index_put_(vec, value);
+ })
  .define_method(
  "contiguous?",
  *[](Tensor& self) {
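The `_index` and `_index_put_custom` bindings above back the tensor-indexing support and fixes noted in the changelog (0.2.6 and 0.3.3). A rough Ruby-side sketch, assuming the public `[]` and `[]=` operators on `Tensor` convert integers, ranges, and tensors into `TensorIndex` values before calling these methods:

```ruby
x = Torch.tensor([[1, 2, 3], [4, 5, 6]])

x[0]                      # first row
x[0, 2]                   # single element
x[0..1, 1]                # rows 0-1, column 1 (a range becomes a slice)
x[x > 3]                  # boolean tensor as a mask (indexing with tensors)
x[0, 0] = Torch.tensor(9) # assignment goes through _index_put_custom
```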
@@ -351,13 +381,19 @@ void Init_ext()
  })
  .define_method(
  "_backward",
- *[](Tensor& self, Object gradient) {
- return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
+ *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
+ return self.backward(gradient, create_graph, retain_graph);
  })
  .define_method(
  "grad",
  *[](Tensor& self) {
- return self.grad();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Tensor& self, torch::Tensor& grad) {
+ self.grad() = grad;
  })
  .define_method(
  "_dtype",
@@ -460,7 +496,7 @@ void Init_ext()
  .define_singleton_method(
  "_make_subclass",
  *[](Tensor& rd, bool requires_grad) {
- auto data = torch::autograd::as_variable_ref(rd).detach();
+ auto data = rd.detach();
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  auto var = data.set_requires_grad(requires_grad);
  return Parameter(std::move(var));
@@ -501,6 +537,7 @@ void Init_ext()
  });

  Module rb_mInit = define_module_under(rb_mNN, "Init")
+ .add_handler<torch::Error>(handle_error)
  .define_singleton_method(
  "_calculate_gain",
  *[](NonlinearityType nonlinearity, double param) {
@@ -579,11 +616,16 @@ void Init_ext()
  *[](Parameter& self) {
  auto grad = self.grad();
  return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Parameter& self, torch::Tensor& grad) {
+ self.grad() = grad;
  });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
- .define_constructor(Constructor<torch::Device, std::string>())
  .add_handler<torch::Error>(handle_error)
+ .define_constructor(Constructor<torch::Device, std::string>())
  .define_method("index", &torch::Device::index)
  .define_method("index?", &torch::Device::has_index)
  .define_method(
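Two of the Tensor bindings above map directly onto changelog entries: `_backward` now takes the `create_graph`/`retain_graph` flags added in 0.3.1, and `grad` returns `nil` when no gradient is defined while the new `grad=` writer allows resetting it. A small sketch, assuming the Ruby `backward` wrapper forwards these as keyword options:

```ruby
x = Torch.ones(2, 2, requires_grad: true)
y = (x * x).sum

y.backward(retain_graph: true) # keep the graph so backward can run again
y.backward
x.grad                         # gradient tensor, accumulated over both passes

Torch.ones(2, 2).grad          # => nil when no gradient has been computed
x.grad = Torch.zeros(2, 2)     # reset via the new grad= writer
```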

@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
  # change to 0 for Linux pre-cxx11 ABI version
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"

- # TODO check compiler name
- clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i

  # check omp first
  if have_library("omp") || have_library("gomp")
  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
- $CXXFLAGS += " -Xclang" if clang
+ $CXXFLAGS += " -Xclang" if apple_clang
  $CXXFLAGS += " -fopenmp"
  end

- if clang
+ if apple_clang
  # silence ruby/intern.h warning
  $CXXFLAGS += " -Wno-deprecated-register"


@@ -9,6 +9,10 @@

  using namespace Rice;

+ using torch::Device;
+ using torch::ScalarType;
+ using torch::Tensor;
+
  // need to wrap torch::IntArrayRef() since
  // it doesn't own underlying data
  class IntArrayRef {
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
  return MyReduction(x);
  }

- typedef torch::Tensor Tensor;
-
  class OptionalTensor {
  Object value;
  public:
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
  return OptionalTensor(x);
  }

- class ScalarType {
- Object value;
- public:
- ScalarType(Object o) {
- value = o;
- }
- operator at::ScalarType() {
- throw std::runtime_error("ScalarType arguments not implemented yet");
- }
- };
-
  template<>
  inline
- ScalarType from_ruby<ScalarType>(Object x)
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
  {
- return ScalarType(x);
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
+ }
  }

- class OptionalScalarType {
- Object value;
- public:
- OptionalScalarType(Object o) {
- value = o;
- }
- operator c10::optional<at::ScalarType>() {
- if (value.is_nil()) {
- return c10::nullopt;
- }
- return ScalarType(value);
- }
- };
-
  template<>
  inline
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
  {
- return OptionalScalarType(x);
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<int64_t>{from_ruby<int64_t>(x)};
+ }
  }

- typedef torch::Device Device;
-
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);

@@ -4,6 +4,7 @@ require "torch/ext"
  # stdlib
  require "fileutils"
  require "net/http"
+ require "set"
  require "tmpdir"

  # native functions
@@ -178,8 +179,10 @@ require "torch/nn/functional"
  require "torch/nn/init"

  # utils
+ require "torch/utils/data"
  require "torch/utils/data/data_loader"
  require "torch/utils/data/dataset"
+ require "torch/utils/data/subset"
  require "torch/utils/data/tensor_dataset"

  # hub
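The new `torch/utils/data` and `torch/utils/data/subset` requires back the `random_split` and `DataLoader` `collate_fn` additions from the 0.3.2 changelog. A rough usage sketch, assuming `random_split` is exposed as `Torch::Utils::Data.random_split` and returns subsets that `DataLoader` can iterate, mirroring PyTorch:

```ruby
x = Torch.randn(10, 3)
y = Torch.randn(10)
dataset = Torch::Utils::Data::TensorDataset.new(x, y)

# split into 8 training and 2 test samples
train_set, test_set = Torch::Utils::Data.random_split(dataset, [8, 2])

loader = Torch::Utils::Data::DataLoader.new(train_set, batch_size: 4)
loader.each do |xb, yb|
  # xb is a [batch, 3] tensor, yb a [batch] tensor
end
```

Per the changelog, a `collate_fn:` option can also be passed to `DataLoader` to customize how samples are combined into a batch.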
@@ -315,6 +318,16 @@ module Torch
  end
  end

+ def enable_grad
+ previous_value = grad_enabled?
+ begin
+ _set_grad_enabled(true)
+ yield
+ ensure
+ _set_grad_enabled(previous_value)
+ end
+ end
+
  def device(str)
  Device.new(str)
  end
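`Torch.enable_grad` above is the 0.3.2 addition: it turns gradient tracking back on inside a `no_grad` block and restores the previous setting when the block exits. A small sketch:

```ruby
x = Torch.ones(2, 2, requires_grad: true)

Torch.no_grad do
  (x * 2).requires_grad?   # => false

  Torch.enable_grad do
    (x * 2).requires_grad? # => true, tracking is re-enabled inside this block
  end
end
```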
@@ -447,6 +460,17 @@ module Torch
  zeros(input.size, **like_options(input, options))
  end

+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+ if center
+ signal_dim = input.dim
+ extended_shape = [1] * (3 - signal_dim) + input.size
+ pad = n_fft.div(2).to_i
+ input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+ input = input.view(input.shape[-signal_dim..-1])
+ end
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+ end
+
  private

  def to_ivalue(obj)
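`Torch.stft` above is part of the spectral ops added in 0.3.3: with `center: true` it reflection-pads the signal by `n_fft / 2` on each side, then delegates to the native `_stft`. A minimal, illustrative call (sizes are arbitrary, and the optional window is omitted):

```ruby
signal = Torch.randn(1000)
spec = Torch.stft(signal, 256, hop_length: 64)
# onesided by default, so the result has n_fft / 2 + 1 frequency bins per
# frame, with real and imaginary components in the last dimension
spec.shape
```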
@@ -470,11 +494,7 @@ module Torch
  when nil
  IValue.new
  when Array
- if obj.all? { |v| v.is_a?(Tensor) }
- IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
- else
- raise Error, "Unknown list type"
- end
+ IValue.from_list(obj.map { |v| to_ivalue(v) })
  else
  raise Error, "Unknown type: #{obj.class.name}"
  end

@@ -7,25 +7,26 @@ module Torch

  def download_url_to_file(url, dst)
  uri = URI(url)
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+ tmp = nil
  location = nil

+ puts "Downloading #{url}..."
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
  request = Net::HTTP::Get.new(uri)

- puts "Downloading #{url}..."
- File.open(tmp, "wb") do |f|
- http.request(request) do |response|
- case response
- when Net::HTTPRedirection
- location = response["location"]
- when Net::HTTPSuccess
+ http.request(request) do |response|
+ case response
+ when Net::HTTPRedirection
+ location = response["location"]
+ when Net::HTTPSuccess
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+ File.open(tmp, "wb") do |f|
  response.read_body do |chunk|
  f.write(chunk)
  end
- else
- raise Error, "Bad response"
  end
+ else
+ raise Error, "Bad response"
  end
  end
  end
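This helper streams the response body to a temp file and records a redirect `location` when one is returned. Assuming it stays exposed as `Torch::Hub.download_url_to_file`, usage is simply:

```ruby
Torch::Hub.download_url_to_file("https://example.com/model.pt", "model.pt")
```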

@@ -1,89 +1,264 @@
+ # mirrors _tensor_str.py
  module Torch
  module Inspector
- # TODO make more performant, especially when summarizing
- # how? only read data that will be displayed
- def inspect
- data =
- if numel == 0
- "[]"
- elsif dim == 0
- item
+ PRINT_OPTS = {
+ precision: 4,
+ threshold: 1000,
+ edgeitems: 3,
+ linewidth: 80,
+ sci_mode: nil
+ }
+
+ class Formatter
+ def initialize(tensor)
+ @floating_dtype = tensor.floating_point?
+ @complex_dtype = tensor.complex?
+ @int_mode = true
+ @sci_mode = false
+ @max_width = 1
+
+ tensor_view = Torch.no_grad { tensor.reshape(-1) }
+
+ if !@floating_dtype
+ tensor_view.each do |value|
+ value_str = value.item.to_s
+ @max_width = [@max_width, value_str.length].max
+ end
  else
- summarize = numel > 1000
+ nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
+
+ # no valid number, do nothing
+ return if nonzero_finite_vals.numel == 0
+
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
+ nonzero_finite_abs = nonzero_finite_vals.abs.double
+ nonzero_finite_min = nonzero_finite_abs.min.double
+ nonzero_finite_max = nonzero_finite_abs.max.double
+
+ nonzero_finite_vals.each do |value|
+ if value.item != value.item.ceil
+ @int_mode = false
+ break
+ end
+ end

- if dtype == :bool
- fmt = "%s"
+ if @int_mode
+ # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
+ # to indicate that the tensor is of floating type. add 1 to the len to account for this.
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
+ @sci_mode = true
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
+ else
+ nonzero_finite_vals.each do |value|
+ value_str = "%.0f" % value.item
+ @max_width = [@max_width, value_str.length + 1].max
+ end
+ end
  else
- values = _flat_data
- abs = values.select { |v| v != 0 }.map(&:abs)
- max = abs.max || 1
- min = abs.min || 1
-
- total = 0
- if values.any? { |v| v < 0 }
- total += 1
+ # Check if scientific representation should be used.
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
+ @sci_mode = true
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
+ else
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
  end
+ end
+ end

- if floating_point?
- sci = max > 1e8 || max < 1e-4
+ @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
+ end

- all_int = values.all? { |v| v.finite? && v == v.to_i }
- decimal = all_int ? 1 : 4
+ def width
+ @max_width
+ end

- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+ def format(value)
+ value = value.item

- if sci
- fmt = "%#{total}.4e"
- else
- fmt = "%#{total}.#{decimal}f"
- end
- else
- total += max.to_s.size
- fmt = "%#{total}d"
+ if @floating_dtype
+ if @sci_mode
+ ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
+ elsif @int_mode
+ ret = String.new("%.0f" % value)
+ unless value.infinite? || value.nan?
+ ret += "."
  end
+ else
+ ret = "%.#{PRINT_OPTS[:precision]}f" % value
  end
+ elsif @complex_dtype
+ p = PRINT_OPTS[:precision]
+ raise NotImplementedYet
+ else
+ ret = value.to_s
+ end
+ # Ruby throws error when negative, Python doesn't
+ " " * [@max_width - ret.size, 0].max + ret
+ end
+ end
+
+ def inspect
+ Torch.no_grad do
+ str_intern(self)
+ end
+ rescue => e
+ # prevent stack error
+ puts e.backtrace.join("\n")
+ "Error inspecting tensor: #{e.inspect}"
+ end
+
+ private
+
+ # TODO update
+ def str_intern(slf)
+ prefix = "tensor("
+ indent = prefix.length
+ suffixes = []
+
+ has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
+
+ if slf.numel == 0 && !slf.sparse?
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
+ if slf.dim != 1
+ suffixes << "size: #{shape.inspect}"
+ end

- inspect_level(to_a, fmt, dim - 1, 0, summarize)
+ # In an empty tensor, there are no elements to infer if the dtype
+ # should be int64, so it must be shown explicitly.
+ if slf.dtype != :int64
+ suffixes << "dtype: #{slf.dtype.inspect}"
  end
+ tensor_str = "[]"
+ else
+ if !has_default_dtype
+ suffixes << "dtype: #{slf.dtype.inspect}"
+ end
+
+ if slf.layout != :strided
+ tensor_str = tensor_str(slf.to_dense, indent)
+ else
+ tensor_str = tensor_str(slf, indent)
+ end
+ end

- attributes = []
- if requires_grad
- attributes << "requires_grad: true"
+ if slf.layout != :strided
+ suffixes << "layout: #{slf.layout.inspect}"
  end
- if ![:float32, :int64, :bool].include?(dtype)
- attributes << "dtype: #{dtype.inspect}"
+
+ # TODO show grad_fn
+ if slf.requires_grad?
+ suffixes << "requires_grad: true"
  end

- "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
+ add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
  end

- private
+ def add_suffixes(tensor_str, suffixes, indent, force_newline)
+ tensor_strs = [tensor_str]
+ # rfind in Python returns -1 when not found
+ last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
+ suffixes.each do |suffix|
+ suffix_len = suffix.length
+ if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
+ tensor_strs << ",\n" + " " * indent + suffix
+ last_line_len = indent + suffix_len
+ force_newline = false
+ else
+ tensor_strs.append(", " + suffix)
+ last_line_len += suffix_len + 2
+ end
+ end
+ tensor_strs.append(")")
+ tensor_strs.join("")
+ end

- # TODO DRY code
- def inspect_level(arr, fmt, total, level, summarize)
- if level == total
- cols =
- if summarize && arr.size > 7
- arr[0..2].map { |v| fmt % v } +
- ["..."] +
- arr[-3..-1].map { |v| fmt % v }
- else
- arr.map { |v| fmt % v }
- end
+ def tensor_str(slf, indent)
+ return "[]" if slf.numel == 0
+
+ summarize = slf.numel > PRINT_OPTS[:threshold]
+
+ if slf.dtype == :float16 || slf.dtype == :bfloat16
+ slf = slf.float
+ end
+ formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
+ tensor_str_with_formatter(slf, indent, formatter, summarize)
+ end
+
+ def summarized_data(slf)
+ edgeitems = PRINT_OPTS[:edgeitems]

- "[#{cols.join(", ")}]"
+ dim = slf.dim
+ if dim == 0
+ slf
+ elsif dim == 1
+ if size(0) > 2 * edgeitems
+ Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
+ else
+ slf
+ end
+ elsif slf.size(0) > 2 * edgeitems
+ start = edgeitems.times.map { |i| slf[i] }
+ finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
+ Torch.stack((start + finish).map { |x| summarized_data(x) })
  else
- rows =
- if summarize && arr.size > 7
- arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
- ["..."] +
- arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
- else
- arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
- end
+ Torch.stack(slf.map { |x| summarized_data(x) })
+ end
+ end
+
+ def tensor_str_with_formatter(slf, indent, formatter, summarize)
+ edgeitems = PRINT_OPTS[:edgeitems]
+
+ dim = slf.dim

- "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+ return scalar_str(slf, formatter) if dim == 0
+ return vector_str(slf, indent, formatter, summarize) if dim == 1
+
+ if summarize && slf.size(0) > 2 * edgeitems
+ slices = (
+ [edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }] +
+ ["..."] +
+ [((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }]
+ )
+ else
+ slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
  end
+
+ tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
+ "[" + tensor_str + "]"
+ end
+
+ def scalar_str(slf, formatter)
+ formatter.format(slf)
+ end
+
+ def vector_str(slf, indent, formatter, summarize)
+ # length includes spaces and comma between elements
+ element_length = formatter.width + 2
+ elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
+ char_per_line = element_length * elements_per_line
+
+ if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
+ data = (
+ [slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) }] +
+ [" ..."] +
+ [slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }]
+ )
+ else
+ data = slf.map { |val| formatter.format(val) }
+ end
+
+ data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
+ lines = data_lines.map { |line| line.join(", ") }
+ "[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
  end
  end
  end
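The rewritten inspector mirrors PyTorch's `_tensor_str.py`: `Formatter` scans the (possibly summarized) data once to pick a column width and an integer/scientific mode, and `str_intern` assembles the `tensor(...)` string, appending `dtype`, `layout`, and `requires_grad` suffixes only when they differ from the defaults. Roughly, output now looks like this (illustrative):

```ruby
Torch.tensor([[1.0, 2.5], [3.0, 4.0]])
# => tensor([[1.0000, 2.5000],
#            [3.0000, 4.0000]])
```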