torch-rb 0.3.0 → 0.3.5

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 33636e58063f25c2b9f122d29332e4136bb6a4de0fd227349f75d65a9db94931
- data.tar.gz: 9349dd0b050a4c9e0714d92bb451bdd916e55fb47a5c4d90a74720d53564a1d6
+ metadata.gz: 93271ffd62be6e35c6ea3a2219a7bc3dccbe8489d6f4aca1a1f00f99bab1a4bb
+ data.tar.gz: df2755ac3e6221502430d780d116a1145c461f97857e7b1b2b809095afaad9e5
  SHA512:
- metadata.gz: 692e8dc3531426377413fc9325c2b03dd7fcbbbce0c05cbd5d7c3182a08bbe733bb9a6b0aa62a56fb27917073e1f1f5859aa0dd3f6f40e74043bba242ede6267
- data.tar.gz: f57d0411c18c7c4753f5edc82e48b666ad3154c72bf189a8ec2c4dceb08cd1f37a9421b9a6eae3d20e10c4c9236a1136d76559a31989ce22868a9f44ef3e0e66
+ metadata.gz: 9e02de90a7a83e5d4421941a0ceea69c6367f42be2e97e4229812f35c27f83475fbace86f42285128da724343dcfd85050b7846d81d43fce100749be0072ad4c
+ data.tar.gz: 884873c3c965f16b0a833087909019ebc6f228511a9b0ecbf4c436cc546e28012e63899651aeab6befddbe7b057676b170cd3eb858b997f42289c103252a2834
data/CHANGELOG.md CHANGED
@@ -1,3 +1,30 @@
+ ## 0.3.5 (2020-09-04)
+
+ - Fixed error with data loader (due to `dtype` of `randperm`)
+
+ ## 0.3.4 (2020-08-26)
+
+ - Added `Torch.clamp` method
+
+ ## 0.3.3 (2020-08-25)
+
+ - Added spectral ops
+ - Fixed tensor indexing
+
+ ## 0.3.2 (2020-08-24)
+
+ - Added `enable_grad` method
+ - Added `random_split` method
+ - Added `collate_fn` option to `DataLoader`
+ - Added `grad=` method to `Tensor`
+ - Fixed error with `grad` method when empty
+ - Fixed `EmbeddingBag`
+
+ ## 0.3.1 (2020-08-17)
+
+ - Added `create_graph` and `retain_graph` options to `backward` method
+ - Fixed error when `set` not required
+
  ## 0.3.0 (2020-07-29)

  - Updated LibTorch to 1.6.0
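
A couple of the 0.3.2 additions are easiest to see in use. Here is a minimal, hypothetical sketch of the `grad` reader and the new `grad=` writer (illustrative only, not taken from the gem's docs):

```ruby
require "torch"

x = Torch.tensor([2.0, 3.0], requires_grad: true)
y = (x * x).sum
y.backward

x.grad                   # dy/dx = 2x, so a tensor of [4.0, 6.0]
x.grad = Torch.zeros(2)  # 0.3.2: gradients can now be assigned directly
```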
data/README.md CHANGED
@@ -2,7 +2,11 @@
 
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
 
- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+ Check out:
+
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
 
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
 
@@ -411,7 +415,7 @@ Here’s the list of compatible versions.
 
  Torch.rb | LibTorch
  --- | ---
- 0.3.0 | 1.6.0
+ 0.3.0-0.3.4 | 1.6.0
  0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
  0.1.0-0.1.7 | 1.3.1
data/ext/torch/ext.cpp CHANGED
@@ -16,6 +16,7 @@
  #include "nn_functions.hpp"
 
  using namespace Rice;
+ using torch::indexing::TensorIndex;
 
  // need to make a distinction between parameters and tensors
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
    throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
  }
 
+ std::vector<TensorIndex> index_vector(Array a) {
+   auto indices = std::vector<TensorIndex>();
+   indices.reserve(a.size());
+   for (size_t i = 0; i < a.size(); i++) {
+     indices.push_back(from_ruby<TensorIndex>(a[i]));
+   }
+   return indices;
+ }
+
  extern "C"
  void Init_ext()
  {
@@ -58,6 +68,13 @@ void Init_ext()
        return generator.seed();
      });
 
+   Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
+     .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
+     .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
+     .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
+     .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
+     .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
+
    // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
    Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
      .add_handler<torch::Error>(handle_error)
@@ -330,6 +347,18 @@ void Init_ext()
      .define_method("numel", &torch::Tensor::numel)
      .define_method("element_size", &torch::Tensor::element_size)
      .define_method("requires_grad", &torch::Tensor::requires_grad)
+     .define_method(
+       "_index",
+       *[](Tensor& self, Array indices) {
+         auto vec = index_vector(indices);
+         return self.index(vec);
+       })
+     .define_method(
+       "_index_put_custom",
+       *[](Tensor& self, Array indices, torch::Tensor& value) {
+         auto vec = index_vector(indices);
+         return self.index_put_(vec, value);
+       })
      .define_method(
        "contiguous?",
        *[](Tensor& self) {
@@ -352,13 +381,19 @@ void Init_ext()
        })
      .define_method(
        "_backward",
-       *[](Tensor& self, Object gradient) {
-         return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
+       *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
+         return self.backward(gradient, create_graph, retain_graph);
        })
      .define_method(
        "grad",
        *[](Tensor& self) {
-         return self.grad();
+         auto grad = self.grad();
+         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+       })
+     .define_method(
+       "grad=",
+       *[](Tensor& self, torch::Tensor& grad) {
+         self.grad() = grad;
        })
      .define_method(
        "_dtype",
@@ -502,6 +537,7 @@ void Init_ext()
      });
 
    Module rb_mInit = define_module_under(rb_mNN, "Init")
+     .add_handler<torch::Error>(handle_error)
      .define_singleton_method(
        "_calculate_gain",
        *[](NonlinearityType nonlinearity, double param) {
@@ -580,11 +616,16 @@ void Init_ext()
        *[](Parameter& self) {
          auto grad = self.grad();
          return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+       })
+     .define_method(
+       "grad=",
+       *[](Parameter& self, torch::Tensor& grad) {
+         self.grad() = grad;
        });
 
    Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
-     .define_constructor(Constructor<torch::Device, std::string>())
      .add_handler<torch::Error>(handle_error)
+     .define_constructor(Constructor<torch::Device, std::string>())
      .define_method("index", &torch::Device::index)
      .define_method("index?", &torch::Device::has_index)
      .define_method(
data/ext/torch/extconf.rb CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
  # change to 0 for Linux pre-cxx11 ABI version
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
 
- # TODO check compiler name
- clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
 
  # check omp first
  if have_library("omp") || have_library("gomp")
    $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
-   $CXXFLAGS += " -Xclang" if clang
+   $CXXFLAGS += " -Xclang" if apple_clang
    $CXXFLAGS += " -fopenmp"
  end
 
- if clang
+ if apple_clang
    # silence ruby/intern.h warning
    $CXXFLAGS += " -Wno-deprecated-register"
 
data/ext/torch/templates.hpp CHANGED
@@ -9,6 +9,10 @@
 
  using namespace Rice;
 
+ using torch::Device;
+ using torch::ScalarType;
+ using torch::Tensor;
+
  // need to wrap torch::IntArrayRef() since
  // it doesn't own underlying data
  class IntArrayRef {
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
    return MyReduction(x);
  }
 
- typedef torch::Tensor Tensor;
-
  class OptionalTensor {
    Object value;
    public:
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
    return OptionalTensor(x);
  }
 
- class ScalarType {
-   Object value;
-   public:
-     ScalarType(Object o) {
-       value = o;
-     }
-     operator at::ScalarType() {
-       throw std::runtime_error("ScalarType arguments not implemented yet");
-     }
- };
-
  template<>
  inline
- ScalarType from_ruby<ScalarType>(Object x)
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
  {
-   return ScalarType(x);
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
+   }
  }
 
- class OptionalScalarType {
-   Object value;
-   public:
-     OptionalScalarType(Object o) {
-       value = o;
-     }
-     operator c10::optional<at::ScalarType>() {
-       if (value.is_nil()) {
-         return c10::nullopt;
-       }
-       return ScalarType(value);
-     }
- };
-
  template<>
  inline
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
  {
-   return OptionalScalarType(x);
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<int64_t>{from_ruby<int64_t>(x)};
+   }
  }
 
- typedef torch::Device Device;
-
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
data/lib/torch.rb CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
  # stdlib
  require "fileutils"
  require "net/http"
+ require "set"
  require "tmpdir"
 
  # native functions
@@ -178,8 +179,10 @@ require "torch/nn/functional"
  require "torch/nn/init"
 
  # utils
+ require "torch/utils/data"
  require "torch/utils/data/data_loader"
  require "torch/utils/data/dataset"
+ require "torch/utils/data/subset"
  require "torch/utils/data/tensor_dataset"
 
  # hub
@@ -315,6 +318,16 @@ module Torch
        end
      end
 
+     def enable_grad
+       previous_value = grad_enabled?
+       begin
+         _set_grad_enabled(true)
+         yield
+       ensure
+         _set_grad_enabled(previous_value)
+       end
+     end
+
      def device(str)
        Device.new(str)
      end
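
For context, `enable_grad` is the counterpart of the existing `Torch.no_grad`: it turns gradient tracking back on inside a block and restores the previous setting on exit (including on exceptions, via `ensure`). A small sketch, assuming the usual autograd behavior:

```ruby
x = Torch.tensor([1.0], requires_grad: true)

Torch.no_grad do
  # gradient tracking disabled here
  Torch.enable_grad do
    y = x * 2  # y participates in autograd again inside this block
  end
end
```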
@@ -375,6 +388,10 @@ module Torch
      end
 
      def randperm(n, **options)
+       # dtype hack in Python
+       # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+       options[:dtype] ||= :int64
+
        _randperm(n, tensor_options(**options))
      end
 
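The `:int64` default is the 0.3.5 data-loader fix from the changelog: LibTorch's C++ `randperm` falls back to the global (float) dtype, unlike PyTorch's Python binding, so shuffled indices came back as floats. A quick check (behavior inferred from the code above):

```ruby
Torch.randperm(5).dtype                   # => :int64, usable as indices
Torch.randperm(5, dtype: :float32).dtype  # => :float32, still overridable
```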
@@ -447,6 +464,22 @@ module Torch
        zeros(input.size, **like_options(input, options))
      end
 
+     def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+       if center
+         signal_dim = input.dim
+         extended_shape = [1] * (3 - signal_dim) + input.size
+         pad = n_fft.div(2).to_i
+         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+         input = input.view(input.shape[-signal_dim..-1])
+       end
+       _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+     end
+
+     def clamp(tensor, min, max)
+       tensor = _clamp_min(tensor, min)
+       _clamp_max(tensor, max)
+     end
+
      private
 
      def to_ivalue(obj)
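
Rough usage for the two new methods, with shapes following PyTorch 1.6 semantics rather than anything verified against this gem's docs:

```ruby
# spectral op added in 0.3.3: STFT of a random signal;
# the output carries real/imaginary pairs in its last dimension
signal = Torch.randn(400)
spec = Torch.stft(signal, 64)

# Torch.clamp (0.3.4) is built from _clamp_min/_clamp_max, so both bounds are required
Torch.clamp(Torch.tensor([-0.5, 0.2, 1.7]), 0, 1)  # => tensor of [0.0, 0.2, 1.0]
```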
data/lib/torch/hub.rb CHANGED
@@ -7,25 +7,26 @@ module Torch
 
    def download_url_to_file(url, dst)
      uri = URI(url)
-     tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+     tmp = nil
      location = nil
 
+     puts "Downloading #{url}..."
      Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
        request = Net::HTTP::Get.new(uri)
 
-       puts "Downloading #{url}..."
-       File.open(tmp, "wb") do |f|
-         http.request(request) do |response|
-           case response
-           when Net::HTTPRedirection
-             location = response["location"]
-           when Net::HTTPSuccess
+       http.request(request) do |response|
+         case response
+         when Net::HTTPRedirection
+           location = response["location"]
+         when Net::HTTPSuccess
+           tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+           File.open(tmp, "wb") do |f|
              response.read_body do |chunk|
                f.write(chunk)
              end
-           else
-             raise Error, "Bad response"
            end
+         else
+           raise Error, "Bad response"
          end
        end
      end
data/lib/torch/native/function.rb CHANGED
@@ -1,10 +1,14 @@
  module Torch
    module Native
      class Function
-       attr_reader :function
+       attr_reader :function, :tensor_options
 
        def initialize(function)
          @function = function
+
+         tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+         @tensor_options = @function["func"].include?(tensor_options_str)
+         @function["func"].sub!(tensor_options_str, ")")
        end
 
        def func
data/lib/torch/native/generator.rb CHANGED
@@ -33,30 +33,14 @@ module Torch
        f.args.any? do |a|
          a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
            skip_args.any? { |sa| a[:type].include?(sa) } ||
+           # call to 'range' is ambiguous
+           f.cpp_name == "_range" ||
            # native_functions.yaml is missing size argument for normal
            # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
            (f.base_name == "normal" && !f.out?)
        end
      end
 
-     # generate additional functions for optional arguments
-     # there may be a better way to do this
-     optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
-     optional_functions.each do |f|
-       next if f.ruby_name == "cross"
-       next if f.ruby_name.start_with?("avg_pool") && f.out?
-
-       opt_args = f.args.select { |a| a[:type] == "int?" }
-       if opt_args.size == 1
-         sep = f.name.include?(".") ? "_" : "."
-         f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
-         # TODO only remove some arguments
-         f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
-         functions << f1
-         functions << f2
-       end
-     end
-
      # todo_functions.each do |f|
      #   puts f.func
      #   puts
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
 
      cpp_defs = []
      functions.sort_by(&:cpp_name).each do |func|
-       fargs = func.args #.select { |a| a[:type] != "Generator?" }
+       fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
+       fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
 
        cpp_args = []
        fargs.each do |a|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
            # TODO better signature
            "OptionalTensor"
          when "ScalarType?"
-           "OptionalScalarType"
+           "torch::optional<ScalarType>"
          when "Tensor[]"
            "TensorList"
          when "Tensor?[]"
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
            "TensorList"
          when "int"
            "int64_t"
+         when "int?"
+           "torch::optional<int64_t>"
          when "float"
            "double"
          when /\Aint\[/
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
            "Tensor &"
          when "str"
            "std::string"
+         when "TensorOptions"
+           "const torch::TensorOptions &"
          else
            a[:type]
          end
data/lib/torch/native/parser.rb CHANGED
@@ -83,6 +83,8 @@ module Torch
          else
            v.is_a?(Integer)
          end
+       when "int?"
+         v.is_a?(Integer) || v.nil?
        when "float"
          v.is_a?(Numeric)
        when /int\[.*\]/
@@ -126,9 +128,11 @@ module Torch
        end
 
        func = candidates.first
+       args = func.args.map { |a| final_values[a[:name]] }
+       args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
-         args: func.args.map { |a| final_values[a[:name]] }
+         args: args
        }
      end
    end
data/lib/torch/nn/functional.rb CHANGED
@@ -373,7 +373,8 @@ module Torch
        end
 
        # weight and input swapped
-       Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+       ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+       ret
      end
 
      # distance functions
@@ -426,6 +427,9 @@ module Torch
      end
 
      def mse_loss(input, target, reduction: "mean")
+       if target.size != input.size
+         warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
+       end
        NN.mse_loss(input, target, reduction)
      end
 
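The new check only warns; broadcasting still happens. A hedged example of what now triggers it:

```ruby
input = Torch.randn(4, 1)
target = Torch.randn(4)

# sizes [4, 1] vs [4] broadcast to [4, 4], which is rarely intended;
# the warning above flags the mismatch instead of failing
Torch::NN::F.mse_loss(input, target)
```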
data/lib/torch/tensor.rb CHANGED
@@ -103,8 +103,9 @@ module Torch
      Torch.empty(0, dtype: dtype)
    end
 
-     def backward(gradient = nil)
-       _backward(gradient)
+     def backward(gradient = nil, retain_graph: nil, create_graph: false)
+       retain_graph = create_graph if retain_graph.nil?
+       _backward(gradient, retain_graph, create_graph)
      end
 
      # TODO read directly from memory
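
With the new keywords, the defaults mirror PyTorch: `retain_graph` falls back to `create_graph`. A minimal sketch (illustrative values, not from the gem's docs):

```ruby
x = Torch.tensor([3.0], requires_grad: true)
y = x * x

# keep the graph around so higher-order gradients could be taken
y.backward(Torch.ones(1), create_graph: true)
x.grad  # dy/dx = 2x, so a tensor of [6.0]
```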
@@ -187,49 +188,15 @@ module Torch
      # based on python_variable_indexing.cpp and
      # https://pytorch.org/cppdocs/notes/tensor_indexing.html
      def [](*indexes)
-       result = self
-       dim = 0
-       indexes.each do |index|
-         if index.is_a?(Numeric)
-           result = result._select_int(dim, index)
-         elsif index.is_a?(Range)
-           finish = index.end
-           finish += 1 unless index.exclude_end?
-           result = result._slice_tensor(dim, index.begin, finish, 1)
-           dim += 1
-         elsif index.is_a?(Tensor)
-           result = result.index([index])
-         elsif index.nil?
-           result = result.unsqueeze(dim)
-           dim += 1
-         elsif index == true
-           result = result.unsqueeze(dim)
-           # TODO handle false
-         else
-           raise Error, "Unsupported index type: #{index.class.name}"
-         end
-       end
-       result
+       _index(tensor_indexes(indexes))
      end
 
      # based on python_variable_indexing.cpp and
      # https://pytorch.org/cppdocs/notes/tensor_indexing.html
-     def []=(index, value)
+     def []=(*indexes, value)
        raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
        value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
-       if index.is_a?(Numeric)
-         index_put!([Torch.tensor(index)], value)
-       elsif index.is_a?(Range)
-         finish = index.end
-         finish += 1 unless index.exclude_end?
-         _slice_tensor(0, index.begin, finish, 1).copy!(value)
-       elsif index.is_a?(Tensor)
-         index_put!([index], value)
-       else
-         raise Error, "Unsupported index type: #{index.class.name}"
-       end
+       _index_put_custom(tensor_indexes(indexes), value)
      end
 
      # native functions that need manually defined
@@ -243,13 +210,13 @@ module Torch
      end
    end
 
-     # native functions overlap, so need to handle manually
+     # parser can't handle overlap, so need to handle manually
      def random!(*args)
        case args.size
        when 1
          _random__to(*args)
        when 2
-         _random__from_to(*args)
+         _random__from(*args)
        else
          _random_(*args)
        end
@@ -259,5 +226,32 @@ module Torch
        _clamp_min_(min)
        _clamp_max_(max)
      end
+
+     private
+
+     def tensor_indexes(indexes)
+       indexes.map do |index|
+         case index
+         when Integer
+           TensorIndex.integer(index)
+         when Range
+           finish = index.end
+           if finish == -1 && !index.exclude_end?
+             finish = nil
+           else
+             finish += 1 unless index.exclude_end?
+           end
+           TensorIndex.slice(index.begin, finish)
+         when Tensor
+           TensorIndex.tensor(index)
+         when nil
+           TensorIndex.none
+         when true, false
+           TensorIndex.boolean(index)
+         else
+           raise Error, "Unsupported index type: #{index.class.name}"
+         end
+       end
+     end
    end
  end
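
Routing everything through `TensorIndex` means the Ruby indexers now compose like PyTorch's: integers, ranges (including a `..-1` range, mapped to an open slice), `nil` for a new axis, booleans, and tensor indexes, in any combination. An illustrative sketch:

```ruby
x = Torch.arange(12).reshape(3, 4)

x[0]         # first row
x[0, 1..2]   # row 0, columns 1 through 2
x[0..-1, 0]  # full slice of dim 0, then column 0
x[nil]       # adds a leading axis, like None in Python

x[0, 0] = 99 # multi-axis assignment now goes through _index_put_custom
```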
data/lib/torch/utils/data.rb ADDED
@@ -0,0 +1,23 @@
+ module Torch
+   module Utils
+     module Data
+       class << self
+         def random_split(dataset, lengths)
+           if lengths.sum != dataset.length
+             raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
+           end
+
+           indices = Torch.randperm(lengths.sum).to_a
+           _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
+         end
+
+         private
+
+         def _accumulate(iterable)
+           sum = 0
+           iterable.map { |x| sum += x }
+         end
+       end
+     end
+   end
+ end
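
Judging from the code above, `random_split` only needs `#length` on the dataset (and `Subset` only `#[]`), so any collection responding to both appears to work; a plain Array makes for the smallest demonstration:

```ruby
train, test = Torch::Utils::Data.random_split((1..10).to_a, [8, 2])
train.length  # => 8
test.length   # => 2
train.to_a    # the 8 shuffled elements, via Subset#to_a
```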
data/lib/torch/utils/data/data_loader.rb CHANGED
@@ -6,10 +6,22 @@ module Torch
 
        attr_reader :dataset
 
-       def initialize(dataset, batch_size: 1, shuffle: false)
+       def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
          @dataset = dataset
          @batch_size = batch_size
          @shuffle = shuffle
+
+         @batch_sampler = nil
+
+         if collate_fn.nil?
+           if auto_collation?
+             collate_fn = method(:default_collate)
+           else
+             collate_fn = method(:default_convert)
+           end
+         end
+
+         @collate_fn = collate_fn
        end
 
        def each
@@ -25,8 +37,8 @@ module Torch
          end
 
          indexes.each_slice(@batch_size) do |idx|
-           batch = idx.map { |i| @dataset[i] }
-           yield collate(batch)
+           # TODO improve performance
+           yield @collate_fn.call(idx.map { |i| @dataset[i] })
          end
        end
 
@@ -36,7 +48,7 @@ module Torch
 
        private
 
-       def collate(batch)
+       def default_convert(batch)
          elem = batch[0]
          case elem
          when Tensor
@@ -44,11 +56,15 @@ module Torch
          when Integer
            Torch.tensor(batch)
          when Array
-           batch.transpose.map { |v| collate(v) }
+           batch.transpose.map { |v| default_convert(v) }
          else
-           raise NotImpelmentYet
+           raise NotImplementedYet
          end
        end
+
+       def auto_collation?
+         !@batch_sampler.nil?
+       end
      end
    end
  end
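
Anything callable can be passed as `collate_fn:`; whatever it returns is what `each` yields per batch. A hedged sketch with a lambda that skips tensor conversion entirely:

```ruby
data = Torch::Utils::Data::TensorDataset.new(Torch.arange(10))

loader = Torch::Utils::Data::DataLoader.new(
  data,
  batch_size: 4,
  collate_fn: ->(batch) { batch }  # yield the raw array of samples
)

loader.each { |batch| p batch.size }  # => 4, 4, 2
```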
data/lib/torch/utils/data/subset.rb ADDED
@@ -0,0 +1,25 @@
+ module Torch
+   module Utils
+     module Data
+       class Subset < Dataset
+         def initialize(dataset, indices)
+           @dataset = dataset
+           @indices = indices
+         end
+
+         def [](idx)
+           @dataset[@indices[idx]]
+         end
+
+         def length
+           @indices.length
+         end
+         alias_method :size, :length
+
+         def to_a
+           @indices.map { |i| @dataset[i] }
+         end
+       end
+     end
+   end
+ end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.3.0"
+   VERSION = "0.3.5"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-07-29 00:00:00.000000000 Z
11
+ date: 2020-09-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -259,8 +259,10 @@ files:
  - lib/torch/optim/rprop.rb
  - lib/torch/optim/sgd.rb
  - lib/torch/tensor.rb
+ - lib/torch/utils/data.rb
  - lib/torch/utils/data/data_loader.rb
  - lib/torch/utils/data/dataset.rb
+ - lib/torch/utils/data/subset.rb
  - lib/torch/utils/data/tensor_dataset.rb
  - lib/torch/version.rb
  homepage: https://github.com/ankane/torch.rb