torch-rb 0.3.0 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 33636e58063f25c2b9f122d29332e4136bb6a4de0fd227349f75d65a9db94931
-  data.tar.gz: 9349dd0b050a4c9e0714d92bb451bdd916e55fb47a5c4d90a74720d53564a1d6
+  metadata.gz: 93271ffd62be6e35c6ea3a2219a7bc3dccbe8489d6f4aca1a1f00f99bab1a4bb
+  data.tar.gz: df2755ac3e6221502430d780d116a1145c461f97857e7b1b2b809095afaad9e5
 SHA512:
-  metadata.gz: 692e8dc3531426377413fc9325c2b03dd7fcbbbce0c05cbd5d7c3182a08bbe733bb9a6b0aa62a56fb27917073e1f1f5859aa0dd3f6f40e74043bba242ede6267
-  data.tar.gz: f57d0411c18c7c4753f5edc82e48b666ad3154c72bf189a8ec2c4dceb08cd1f37a9421b9a6eae3d20e10c4c9236a1136d76559a31989ce22868a9f44ef3e0e66
+  metadata.gz: 9e02de90a7a83e5d4421941a0ceea69c6367f42be2e97e4229812f35c27f83475fbace86f42285128da724343dcfd85050b7846d81d43fce100749be0072ad4c
+  data.tar.gz: 884873c3c965f16b0a833087909019ebc6f228511a9b0ecbf4c436cc546e28012e63899651aeab6befddbe7b057676b170cd3eb858b997f42289c103252a2834
@@ -1,3 +1,30 @@
+## 0.3.5 (2020-09-04)
+
+- Fixed error with data loader (due to `dtype` of `randperm`)
+
+## 0.3.4 (2020-08-26)
+
+- Added `Torch.clamp` method
+
+## 0.3.3 (2020-08-25)
+
+- Added spectral ops
+- Fixed tensor indexing
+
+## 0.3.2 (2020-08-24)
+
+- Added `enable_grad` method
+- Added `random_split` method
+- Added `collate_fn` option to `DataLoader`
+- Added `grad=` method to `Tensor`
+- Fixed error with `grad` method when empty
+- Fixed `EmbeddingBag`
+
+## 0.3.1 (2020-08-17)
+
+- Added `create_graph` and `retain_graph` options to `backward` method
+- Fixed error when `set` not required
+
 ## 0.3.0 (2020-07-29)
 
 - Updated LibTorch to 1.6.0
data/README.md CHANGED
@@ -2,7 +2,11 @@
 
 :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
 
-For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+Check out:
+
+- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
+- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
+- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
 
 [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
 
@@ -411,7 +415,7 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.3.0 | 1.6.0
+0.3.0-0.3.4 | 1.6.0
 0.2.0-0.2.7 | 1.5.0-1.5.1
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
@@ -16,6 +16,7 @@
 #include "nn_functions.hpp"
 
 using namespace Rice;
+using torch::indexing::TensorIndex;
 
 // need to make a distinction between parameters and tensors
 class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
   throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
 
+std::vector<TensorIndex> index_vector(Array a) {
+  auto indices = std::vector<TensorIndex>();
+  indices.reserve(a.size());
+  for (size_t i = 0; i < a.size(); i++) {
+    indices.push_back(from_ruby<TensorIndex>(a[i]));
+  }
+  return indices;
+}
+
 extern "C"
 void Init_ext()
 {
@@ -58,6 +68,13 @@
       return generator.seed();
     });
 
+  Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
+    .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
+    .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
+    .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
+    .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
+    .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
+
   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
     .add_handler<torch::Error>(handle_error)
@@ -330,6 +347,18 @@
     .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method(
+      "_index",
+      *[](Tensor& self, Array indices) {
+        auto vec = index_vector(indices);
+        return self.index(vec);
+      })
+    .define_method(
+      "_index_put_custom",
+      *[](Tensor& self, Array indices, torch::Tensor& value) {
+        auto vec = index_vector(indices);
+        return self.index_put_(vec, value);
+      })
     .define_method(
       "contiguous?",
       *[](Tensor& self) {
@@ -352,13 +381,19 @@
       })
     .define_method(
       "_backward",
-      *[](Tensor& self, Object gradient) {
-        return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
+      *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
+        return self.backward(gradient, create_graph, retain_graph);
       })
     .define_method(
       "grad",
       *[](Tensor& self) {
-        return self.grad();
+        auto grad = self.grad();
+        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+      })
+    .define_method(
+      "grad=",
+      *[](Tensor& self, torch::Tensor& grad) {
+        self.grad() = grad;
       })
     .define_method(
       "_dtype",
@@ -502,6 +537,7 @@
     });
 
   Module rb_mInit = define_module_under(rb_mNN, "Init")
+    .add_handler<torch::Error>(handle_error)
     .define_singleton_method(
       "_calculate_gain",
       *[](NonlinearityType nonlinearity, double param) {
@@ -580,11 +616,16 @@
       *[](Parameter& self) {
         auto grad = self.grad();
         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+      })
+    .define_method(
+      "grad=",
+      *[](Parameter& self, torch::Tensor& grad) {
+        self.grad() = grad;
       });
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
-    .define_constructor(Constructor<torch::Device, std::string>())
     .add_handler<torch::Error>(handle_error)
+    .define_constructor(Constructor<torch::Device, std::string>())
     .define_method("index", &torch::Device::index)
     .define_method("index?", &torch::Device::has_index)
     .define_method(
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
 # change to 0 for Linux pre-cxx11 ABI version
 $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
 
-# TODO check compiler name
-clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
+apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
 
 # check omp first
 if have_library("omp") || have_library("gomp")
   $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
-  $CXXFLAGS += " -Xclang" if clang
+  $CXXFLAGS += " -Xclang" if apple_clang
   $CXXFLAGS += " -fopenmp"
 end
 
-if clang
+if apple_clang
   # silence ruby/intern.h warning
   $CXXFLAGS += " -Wno-deprecated-register"
 
@@ -9,6 +9,10 @@
 
 using namespace Rice;
 
+using torch::Device;
+using torch::ScalarType;
+using torch::Tensor;
+
 // need to wrap torch::IntArrayRef() since
 // it doesn't own underlying data
 class IntArrayRef {
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
   return MyReduction(x);
 }
 
-typedef torch::Tensor Tensor;
-
 class OptionalTensor {
   Object value;
   public:
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
   return OptionalTensor(x);
 }
 
-class ScalarType {
-  Object value;
-  public:
-    ScalarType(Object o) {
-      value = o;
-    }
-    operator at::ScalarType() {
-      throw std::runtime_error("ScalarType arguments not implemented yet");
-    }
-};
-
 template<>
 inline
-ScalarType from_ruby<ScalarType>(Object x)
+torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
 {
-  return ScalarType(x);
+  if (x.is_nil()) {
+    return torch::nullopt;
+  } else {
+    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
+  }
 }
 
-class OptionalScalarType {
-  Object value;
-  public:
-    OptionalScalarType(Object o) {
-      value = o;
-    }
-    operator c10::optional<at::ScalarType>() {
-      if (value.is_nil()) {
-        return c10::nullopt;
-      }
-      return ScalarType(value);
-    }
-};
-
 template<>
 inline
-OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
 {
-  return OptionalScalarType(x);
+  if (x.is_nil()) {
+    return torch::nullopt;
+  } else {
+    return torch::optional<int64_t>{from_ruby<int64_t>(x)};
+  }
 }
 
-typedef torch::Device Device;
-
 Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -4,6 +4,7 @@ require "torch/ext"
 # stdlib
 require "fileutils"
 require "net/http"
+require "set"
 require "tmpdir"
 
 # native functions
@@ -178,8 +179,10 @@ require "torch/nn/functional"
 require "torch/nn/init"
 
 # utils
+require "torch/utils/data"
 require "torch/utils/data/data_loader"
 require "torch/utils/data/dataset"
+require "torch/utils/data/subset"
 require "torch/utils/data/tensor_dataset"
 
 # hub
@@ -315,6 +318,16 @@ module Torch
       end
     end
 
+    def enable_grad
+      previous_value = grad_enabled?
+      begin
+        _set_grad_enabled(true)
+        yield
+      ensure
+        _set_grad_enabled(previous_value)
+      end
+    end
+
     def device(str)
      Device.new(str)
    end
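
The `enable_grad` helper added above is the counterpart to `Torch.no_grad`: it turns gradient tracking on for the duration of a block and restores the previous mode in an `ensure`. A minimal usage sketch (it assumes `Torch.no_grad`, which torch-rb already provides; the printed values are illustrative):

```ruby
require "torch"

x = Torch.tensor([1.0, 2.0, 3.0], requires_grad: true)

Torch.no_grad do
  # gradient tracking is disabled here...
  Torch.enable_grad do
    # ...and re-enabled only inside this block
    y = (x * 2).sum
    y.backward
  end
end

p x.grad.to_a # => [2.0, 2.0, 2.0]
```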
@@ -375,6 +388,10 @@ module Torch
     end
 
     def randperm(n, **options)
+      # dtype hack in Python
+      # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+      options[:dtype] ||= :int64
+
       _randperm(n, tensor_options(**options))
     end
 
@@ -447,6 +464,22 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end
 
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+      if center
+        signal_dim = input.dim
+        extended_shape = [1] * (3 - signal_dim) + input.size
+        pad = n_fft.div(2).to_i
+        input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+        input = input.view(input.shape[-signal_dim..-1])
+      end
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+    end
+
+    def clamp(tensor, min, max)
+      tensor = _clamp_min(tensor, min)
+      _clamp_max(tensor, max)
+    end
+
     private
 
     def to_ivalue(obj)
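
The new `Torch.clamp` simply chains the generated `_clamp_min` and `_clamp_max` bindings, and `Torch.stft` pads the signal before delegating to `_stft`. A small hedged sketch of `clamp` (values illustrative; `stft` output is omitted because its layout depends on the LibTorch build):

```ruby
require "torch"

t = Torch.tensor([-1.5, 0.2, 3.7])
p Torch.clamp(t, 0, 1).to_a # => [0.0, 0.2, 1.0]

# the in-place tensor variant from earlier releases still exists
t.clamp!(0, 1)
p t.to_a # => [0.0, 0.2, 1.0]
```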
@@ -7,25 +7,26 @@ module Torch
 
       def download_url_to_file(url, dst)
         uri = URI(url)
-        tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+        tmp = nil
         location = nil
 
+        puts "Downloading #{url}..."
         Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
           request = Net::HTTP::Get.new(uri)
 
-          puts "Downloading #{url}..."
-          File.open(tmp, "wb") do |f|
-            http.request(request) do |response|
-              case response
-              when Net::HTTPRedirection
-                location = response["location"]
-              when Net::HTTPSuccess
+          http.request(request) do |response|
+            case response
+            when Net::HTTPRedirection
+              location = response["location"]
+            when Net::HTTPSuccess
+              tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+              File.open(tmp, "wb") do |f|
                 response.read_body do |chunk|
                   f.write(chunk)
                 end
-            else
-              raise Error, "Bad response"
               end
+            else
+              raise Error, "Bad response"
             end
           end
         end
@@ -1,10 +1,14 @@
 module Torch
   module Native
     class Function
-      attr_reader :function
+      attr_reader :function, :tensor_options
 
       def initialize(function)
         @function = function
+
+        tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+        @tensor_options = @function["func"].include?(tensor_options_str)
+        @function["func"].sub!(tensor_options_str, ")")
       end
 
      def func
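
The constructor change above detects native schemas that end with the shared TensorOptions parameters and strips them, so the generator can append a single `options` argument instead. A standalone sketch of that string handling (the `arange` schema below is an illustrative example of the format, not parsed from native_functions.yaml):

```ruby
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"

func = "arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"

tensor_options = func.include?(tensor_options_str) # does the schema take TensorOptions?
func = func.sub(tensor_options_str, ")")           # strip them from the parsed signature

p tensor_options # => true
p func           # => "arange(Scalar end) -> Tensor"
```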
@@ -33,30 +33,14 @@ module Torch
         f.args.any? do |a|
           a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
             skip_args.any? { |sa| a[:type].include?(sa) } ||
+            # call to 'range' is ambiguous
+            f.cpp_name == "_range" ||
             # native_functions.yaml is missing size argument for normal
             # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
             (f.base_name == "normal" && !f.out?)
         end
       end
 
-      # generate additional functions for optional arguments
-      # there may be a better way to do this
-      optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
-      optional_functions.each do |f|
-        next if f.ruby_name == "cross"
-        next if f.ruby_name.start_with?("avg_pool") && f.out?
-
-        opt_args = f.args.select { |a| a[:type] == "int?" }
-        if opt_args.size == 1
-          sep = f.name.include?(".") ? "_" : "."
-          f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
-          # TODO only remove some arguments
-          f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
-          functions << f1
-          functions << f2
-        end
-      end
-
       # todo_functions.each do |f|
       #   puts f.func
       #   puts
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
 
     cpp_defs = []
     functions.sort_by(&:cpp_name).each do |func|
-      fargs = func.args #.select { |a| a[:type] != "Generator?" }
+      fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
+      fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
 
       cpp_args = []
       fargs.each do |a|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
           # TODO better signature
           "OptionalTensor"
         when "ScalarType?"
-          "OptionalScalarType"
+          "torch::optional<ScalarType>"
         when "Tensor[]"
           "TensorList"
         when "Tensor?[]"
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
           "TensorList"
         when "int"
           "int64_t"
+        when "int?"
+          "torch::optional<int64_t>"
         when "float"
           "double"
         when /\Aint\[/
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
           "Tensor &"
         when "str"
           "std::string"
+        when "TensorOptions"
+          "const torch::TensorOptions &"
         else
           a[:type]
         end
@@ -83,6 +83,8 @@ module Torch
           else
             v.is_a?(Integer)
           end
+        when "int?"
+          v.is_a?(Integer) || v.nil?
         when "float"
           v.is_a?(Numeric)
         when /int\[.*\]/
@@ -126,9 +128,11 @@ module Torch
       end
 
       func = candidates.first
+      args = func.args.map { |a| final_values[a[:name]] }
+      args << TensorOptions.new.dtype(6) if func.tensor_options
       {
         name: func.cpp_name,
-        args: func.args.map { |a| final_values[a[:name]] }
+        args: args
       }
     end
   end
@@ -373,7 +373,8 @@ module Torch
       end
 
       # weight and input swapped
-      Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+      ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+      ret
     end
 
     # distance functions
@@ -426,6 +427,9 @@ module Torch
     end
 
     def mse_loss(input, target, reduction: "mean")
+      if target.size != input.size
+        warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
+      end
       NN.mse_loss(input, target, reduction)
     end
 
@@ -103,8 +103,9 @@ module Torch
       Torch.empty(0, dtype: dtype)
     end
 
-    def backward(gradient = nil)
-      _backward(gradient)
+    def backward(gradient = nil, retain_graph: nil, create_graph: false)
+      retain_graph = create_graph if retain_graph.nil?
+      _backward(gradient, retain_graph, create_graph)
     end
 
     # TODO read directly from memory
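
With the wrapper above, `backward` now forwards `retain_graph` and `create_graph` to the C++ binding, and `retain_graph` defaults to `create_graph`, matching PyTorch. A hedged sketch that backpropagates twice through the same graph and then resets the gradient with the new `grad=` writer (values illustrative):

```ruby
require "torch"

x = Torch.tensor([3.0], requires_grad: true)
y = x * x

# keep the graph alive so backward can run a second time
y.backward(Torch.ones_like(y), retain_graph: true)
y.backward(Torch.ones_like(y))
p x.grad.to_a # => [12.0] (two accumulated passes of dy/dx = 2x)

# the new grad= writer can overwrite the accumulated gradient
x.grad = Torch.zeros(1)
p x.grad.to_a # => [0.0]
```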
@@ -187,49 +188,15 @@ module Torch
     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
-      result = self
-      dim = 0
-      indexes.each do |index|
-        if index.is_a?(Numeric)
-          result = result._select_int(dim, index)
-        elsif index.is_a?(Range)
-          finish = index.end
-          finish += 1 unless index.exclude_end?
-          result = result._slice_tensor(dim, index.begin, finish, 1)
-          dim += 1
-        elsif index.is_a?(Tensor)
-          result = result.index([index])
-        elsif index.nil?
-          result = result.unsqueeze(dim)
-          dim += 1
-        elsif index == true
-          result = result.unsqueeze(dim)
-          # TODO handle false
-        else
-          raise Error, "Unsupported index type: #{index.class.name}"
-        end
-      end
-      result
+      _index(tensor_indexes(indexes))
     end
 
     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
-    def []=(index, value)
+    def []=(*indexes, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
       value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
-      if index.is_a?(Numeric)
-        index_put!([Torch.tensor(index)], value)
-      elsif index.is_a?(Range)
-        finish = index.end
-        finish += 1 unless index.exclude_end?
-        _slice_tensor(0, index.begin, finish, 1).copy!(value)
-      elsif index.is_a?(Tensor)
-        index_put!([index], value)
-      else
-        raise Error, "Unsupported index type: #{index.class.name}"
-      end
+      _index_put_custom(tensor_indexes(indexes), value)
     end
 
     # native functions that need manually defined
@@ -243,13 +210,13 @@ module Torch
       end
     end
 
-    # native functions overlap, so need to handle manually
+    # parser can't handle overlap, so need to handle manually
    def random!(*args)
      case args.size
      when 1
        _random__to(*args)
      when 2
-        _random__from_to(*args)
+        _random__from(*args)
      else
        _random_(*args)
      end
@@ -259,5 +226,32 @@ module Torch
       _clamp_min_(min)
       _clamp_max_(max)
     end
+
+    private
+
+    def tensor_indexes(indexes)
+      indexes.map do |index|
+        case index
+        when Integer
+          TensorIndex.integer(index)
+        when Range
+          finish = index.end
+          if finish == -1 && !index.exclude_end?
+            finish = nil
+          else
+            finish += 1 unless index.exclude_end?
+          end
+          TensorIndex.slice(index.begin, finish)
+        when Tensor
+          TensorIndex.tensor(index)
+        when nil
+          TensorIndex.none
+        when true, false
+          TensorIndex.boolean(index)
+        else
+          raise Error, "Unsupported index type: #{index.class.name}"
+        end
+      end
+    end
   end
 end
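
The `[]` and `[]=` rewrite above now routes every index through `tensor_indexes` and the C++ `TensorIndex` API instead of chaining `_select_int`/`_slice_tensor`. A short hedged sketch of the supported index types (output values illustrative):

```ruby
require "torch"

t = Torch.arange(12).reshape([3, 4])

p t[0].to_a       # an Integer index selects a row
p t[0, 1..2].to_a # Ranges become slices
p t[nil, 0].shape # nil inserts a new axis, like None in Python
p t[0, -1].item   # negative indexes count from the end

t[0, 0] = 99      # []= now accepts the same multi-index form
p t[0, 0].item    # prints 99 (as a float for a float tensor)
```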
@@ -0,0 +1,23 @@
+module Torch
+  module Utils
+    module Data
+      class << self
+        def random_split(dataset, lengths)
+          if lengths.sum != dataset.length
+            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
+          end
+
+          indices = Torch.randperm(lengths.sum).to_a
+          _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
+        end
+
+        private
+
+        def _accumulate(iterable)
+          sum = 0
+          iterable.map { |x| sum += x }
+        end
+      end
+    end
+  end
+end
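
`random_split` shuffles indices with `Torch.randperm` (which is also why 0.3.5 pins randperm's default dtype to `:int64`) and wraps each partition in a `Subset`. A hedged sketch with a tiny hand-rolled dataset (the `ToyDataset` class is illustrative; any object responding to `[]` and `length` should work):

```ruby
require "torch"

# minimal in-memory dataset for illustration
class ToyDataset
  def initialize(tensor)
    @tensor = tensor
  end

  def [](i)
    @tensor[i]
  end

  def length
    @tensor.shape[0]
  end
end

dataset = ToyDataset.new(Torch.arange(10))
train_set, test_set = Torch::Utils::Data.random_split(dataset, [8, 2])

p train_set.length          # => 8
p test_set.to_a.map(&:item) # two randomly chosen elements
```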
@@ -6,10 +6,22 @@ module Torch
 
       attr_reader :dataset
 
-      def initialize(dataset, batch_size: 1, shuffle: false)
+      def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
         @dataset = dataset
         @batch_size = batch_size
         @shuffle = shuffle
+
+        @batch_sampler = nil
+
+        if collate_fn.nil?
+          if auto_collation?
+            collate_fn = method(:default_collate)
+          else
+            collate_fn = method(:default_convert)
+          end
+        end
+
+        @collate_fn = collate_fn
       end
 
       def each
@@ -25,8 +37,8 @@ module Torch
         end
 
         indexes.each_slice(@batch_size) do |idx|
-          batch = idx.map { |i| @dataset[i] }
-          yield collate(batch)
+          # TODO improve performance
+          yield @collate_fn.call(idx.map { |i| @dataset[i] })
         end
       end
 
@@ -36,7 +48,7 @@ module Torch
 
       private
 
-      def collate(batch)
+      def default_convert(batch)
         elem = batch[0]
         case elem
         when Tensor
@@ -44,11 +56,15 @@ module Torch
         when Integer
           Torch.tensor(batch)
         when Array
-          batch.transpose.map { |v| collate(v) }
+          batch.transpose.map { |v| default_convert(v) }
         else
-          raise NotImpelmentYet
+          raise NotImplementedYet
         end
       end
+
+      def auto_collation?
+        !@batch_sampler.nil?
+      end
     end
   end
 end
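
With the change above, `DataLoader` accepts any callable as `collate_fn` and falls back to `default_convert` when none is given. A hedged sketch that stacks only the feature tensors of each batch (shapes and values are illustrative):

```ruby
require "torch"

x = Torch.randn(6, 2)
y = Torch.tensor([0, 1, 0, 1, 0, 1])
dataset = Torch::Utils::Data::TensorDataset.new(x, y)

# the collate_fn receives the raw array of samples for one batch
stack_features = ->(batch) { Torch.stack(batch.map { |xb, _yb| xb }) }

loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2, collate_fn: stack_features)
loader.each do |batch|
  p batch.shape # => [2, 2]
end
```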
@@ -0,0 +1,25 @@
+module Torch
+  module Utils
+    module Data
+      class Subset < Dataset
+        def initialize(dataset, indices)
+          @dataset = dataset
+          @indices = indices
+        end
+
+        def [](idx)
+          @dataset[@indices[idx]]
+        end
+
+        def length
+          @indices.length
+        end
+        alias_method :size, :length
+
+        def to_a
+          @indices.map { |i| @dataset[i] }
+        end
+      end
+    end
+  end
+end
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.3.0"
+  VERSION = "0.3.5"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.3.0
+  version: 0.3.5
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-07-29 00:00:00.000000000 Z
+date: 2020-09-04 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -259,8 +259,10 @@ files:
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
 - lib/torch/tensor.rb
+- lib/torch/utils/data.rb
 - lib/torch/utils/data/data_loader.rb
 - lib/torch/utils/data/dataset.rb
+- lib/torch/utils/data/subset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
 homepage: https://github.com/ankane/torch.rb