torch-rb 0.3.1 → 0.3.6

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
4
- data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
3
+ metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
4
+ data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
5
5
  SHA512:
6
- metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
7
- data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
6
+ metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
7
+ data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
@@ -1,3 +1,31 @@
1
+ ## 0.3.6 (2020-09-17)
2
+
3
+ - Added `inplace` option for leaky ReLU
4
+ - Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
5
+ - Fixed error with buffers on GPU
6
+
7
+ ## 0.3.5 (2020-09-04)
8
+
9
+ - Fixed error with data loader (due to `dtype` of `randperm`)
10
+
11
+ ## 0.3.4 (2020-08-26)
12
+
13
+ - Added `Torch.clamp` method
14
+
15
+ ## 0.3.3 (2020-08-25)
16
+
17
+ - Added spectral ops
18
+ - Fixed tensor indexing
19
+
20
+ ## 0.3.2 (2020-08-24)
21
+
22
+ - Added `enable_grad` method
23
+ - Added `random_split` method
24
+ - Added `collate_fn` option to `DataLoader`
25
+ - Added `grad=` method to `Tensor`
26
+ - Fixed error with `grad` method when empty
27
+ - Fixed `EmbeddingBag`
28
+
1
29
  ## 0.3.1 (2020-08-17)
2
30
 
3
31
  - Added `create_graph` and `retain_graph` options to `backward` method
data/README.md CHANGED
@@ -2,7 +2,11 @@
2
2
 
3
3
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
4
 
5
- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
5
+ Check out:
6
+
7
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
8
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
9
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
6
10
 
7
11
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
8
12
 
@@ -411,7 +415,7 @@ Here’s the list of compatible versions.
411
415
 
412
416
  Torch.rb | LibTorch
413
417
  --- | ---
414
- 0.3.0-0.3.1 | 1.6.0
418
+ 0.3.0-0.3.6 | 1.6.0
415
419
  0.2.0-0.2.7 | 1.5.0-1.5.1
416
420
  0.1.8 | 1.4.0
417
421
  0.1.0-0.1.7 | 1.3.1
@@ -16,6 +16,7 @@
16
16
  #include "nn_functions.hpp"
17
17
 
18
18
  using namespace Rice;
19
+ using torch::indexing::TensorIndex;
19
20
 
20
21
  // need to make a distinction between parameters and tensors
21
22
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
28
29
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
30
  }
30
31
 
32
+ std::vector<TensorIndex> index_vector(Array a) {
33
+ auto indices = std::vector<TensorIndex>();
34
+ indices.reserve(a.size());
35
+ for (size_t i = 0; i < a.size(); i++) {
36
+ indices.push_back(from_ruby<TensorIndex>(a[i]));
37
+ }
38
+ return indices;
39
+ }
40
+
31
41
  extern "C"
32
42
  void Init_ext()
33
43
  {
@@ -58,6 +68,13 @@ void Init_ext()
58
68
  return generator.seed();
59
69
  });
60
70
 
71
+ Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
72
+ .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
73
+ .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
74
+ .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
75
+ .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
76
+ .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
77
+
61
78
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
62
79
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
63
80
  .add_handler<torch::Error>(handle_error)
@@ -284,11 +301,6 @@ void Init_ext()
284
301
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
285
302
  return torch::pickle_load(v);
286
303
  })
287
- .define_singleton_method(
288
- "_binary_cross_entropy_with_logits",
289
- *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
290
- return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
291
- })
292
304
  .define_singleton_method(
293
305
  "_from_blob",
294
306
  *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
@@ -330,6 +342,18 @@ void Init_ext()
330
342
  .define_method("numel", &torch::Tensor::numel)
331
343
  .define_method("element_size", &torch::Tensor::element_size)
332
344
  .define_method("requires_grad", &torch::Tensor::requires_grad)
345
+ .define_method(
346
+ "_index",
347
+ *[](Tensor& self, Array indices) {
348
+ auto vec = index_vector(indices);
349
+ return self.index(vec);
350
+ })
351
+ .define_method(
352
+ "_index_put_custom",
353
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
354
+ auto vec = index_vector(indices);
355
+ return self.index_put_(vec, value);
356
+ })
333
357
  .define_method(
334
358
  "contiguous?",
335
359
  *[](Tensor& self) {
@@ -350,15 +374,16 @@ void Init_ext()
350
374
  *[](Tensor& self, bool requires_grad) {
351
375
  return self.set_requires_grad(requires_grad);
352
376
  })
353
- .define_method(
354
- "_backward",
355
- *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
356
- return self.backward(gradient, create_graph, retain_graph);
357
- })
358
377
  .define_method(
359
378
  "grad",
360
379
  *[](Tensor& self) {
361
- return self.grad();
380
+ auto grad = self.grad();
381
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
382
+ })
383
+ .define_method(
384
+ "grad=",
385
+ *[](Tensor& self, torch::Tensor& grad) {
386
+ self.grad() = grad;
362
387
  })
363
388
  .define_method(
364
389
  "_dtype",
@@ -502,6 +527,7 @@ void Init_ext()
502
527
  });
503
528
 
504
529
  Module rb_mInit = define_module_under(rb_mNN, "Init")
530
+ .add_handler<torch::Error>(handle_error)
505
531
  .define_singleton_method(
506
532
  "_calculate_gain",
507
533
  *[](NonlinearityType nonlinearity, double param) {
@@ -580,11 +606,16 @@ void Init_ext()
580
606
  *[](Parameter& self) {
581
607
  auto grad = self.grad();
582
608
  return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
609
+ })
610
+ .define_method(
611
+ "grad=",
612
+ *[](Parameter& self, torch::Tensor& grad) {
613
+ self.grad() = grad;
583
614
  });
584
615
 
585
616
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
586
- .define_constructor(Constructor<torch::Device, std::string>())
587
617
  .add_handler<torch::Error>(handle_error)
618
+ .define_constructor(Constructor<torch::Device, std::string>())
588
619
  .define_method("index", &torch::Device::index)
589
620
  .define_method("index?", &torch::Device::has_index)
590
621
  .define_method(
@@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
53
53
  a.push(to_ruby<int64_t>(std::get<3>(x)));
54
54
  return Object(a);
55
55
  }
56
+
57
+ Object wrap(std::vector<torch::Tensor> x) {
58
+ Array a;
59
+ for (auto& t : x) {
60
+ a.push(to_ruby<torch::Tensor>(t));
61
+ }
62
+ return Object(a);
63
+ }
@@ -9,6 +9,11 @@
9
9
 
10
10
  using namespace Rice;
11
11
 
12
+ using torch::Device;
13
+ using torch::Scalar;
14
+ using torch::ScalarType;
15
+ using torch::Tensor;
16
+
12
17
  // need to wrap torch::IntArrayRef() since
13
18
  // it doesn't own underlying data
14
19
  class IntArrayRef {
@@ -32,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
32
37
  return IntArrayRef(x);
33
38
  }
34
39
 
35
- // for now
36
- class Scalar {
37
- torch::Scalar value;
38
- public:
39
- Scalar(Object o) {
40
- // TODO cast based on Ruby type
41
- if (o.rb_type() == T_FIXNUM) {
42
- value = torch::Scalar(from_ruby<int64_t>(o));
43
- } else {
44
- value = torch::Scalar(from_ruby<float>(o));
45
- }
46
- }
47
- operator torch::Scalar() {
48
- return value;
49
- }
50
- };
51
-
52
- template<>
53
- inline
54
- Scalar from_ruby<Scalar>(Object x)
55
- {
56
- return Scalar(x);
57
- }
58
-
59
40
  class TensorList {
60
41
  std::vector<torch::Tensor> vec;
61
42
  public:
@@ -174,8 +155,6 @@ MyReduction from_ruby<MyReduction>(Object x)
174
155
  return MyReduction(x);
175
156
  }
176
157
 
177
- typedef torch::Tensor Tensor;
178
-
179
158
  class OptionalTensor {
180
159
  Object value;
181
160
  public:
@@ -190,6 +169,17 @@ class OptionalTensor {
190
169
  }
191
170
  };
192
171
 
172
+ template<>
173
+ inline
174
+ Scalar from_ruby<Scalar>(Object x)
175
+ {
176
+ if (x.rb_type() == T_FIXNUM) {
177
+ return torch::Scalar(from_ruby<int64_t>(x));
178
+ } else {
179
+ return torch::Scalar(from_ruby<double>(x));
180
+ }
181
+ }
182
+
193
183
  template<>
194
184
  inline
195
185
  OptionalTensor from_ruby<OptionalTensor>(Object x)
@@ -197,46 +187,60 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
197
187
  return OptionalTensor(x);
198
188
  }
199
189
 
200
- class ScalarType {
201
- Object value;
202
- public:
203
- ScalarType(Object o) {
204
- value = o;
205
- }
206
- operator at::ScalarType() {
207
- throw std::runtime_error("ScalarType arguments not implemented yet");
208
- }
209
- };
190
+ template<>
191
+ inline
192
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
193
+ {
194
+ if (x.is_nil()) {
195
+ return torch::nullopt;
196
+ } else {
197
+ return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
198
+ }
199
+ }
210
200
 
211
201
  template<>
212
202
  inline
213
- ScalarType from_ruby<ScalarType>(Object x)
203
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
214
204
  {
215
- return ScalarType(x);
205
+ if (x.is_nil()) {
206
+ return torch::nullopt;
207
+ } else {
208
+ return torch::optional<int64_t>{from_ruby<int64_t>(x)};
209
+ }
216
210
  }
217
211
 
218
- class OptionalScalarType {
219
- Object value;
220
- public:
221
- OptionalScalarType(Object o) {
222
- value = o;
223
- }
224
- operator c10::optional<at::ScalarType>() {
225
- if (value.is_nil()) {
226
- return c10::nullopt;
227
- }
228
- return ScalarType(value);
229
- }
230
- };
212
+ template<>
213
+ inline
214
+ torch::optional<double> from_ruby<torch::optional<double>>(Object x)
215
+ {
216
+ if (x.is_nil()) {
217
+ return torch::nullopt;
218
+ } else {
219
+ return torch::optional<double>{from_ruby<double>(x)};
220
+ }
221
+ }
231
222
 
232
223
  template<>
233
224
  inline
234
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
225
+ torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
235
226
  {
236
- return OptionalScalarType(x);
227
+ if (x.is_nil()) {
228
+ return torch::nullopt;
229
+ } else {
230
+ return torch::optional<bool>{from_ruby<bool>(x)};
231
+ }
237
232
  }
238
233
 
239
- typedef torch::Device Device;
234
+ template<>
235
+ inline
236
+ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
237
+ {
238
+ if (x.is_nil()) {
239
+ return torch::nullopt;
240
+ } else {
241
+ return torch::optional<Scalar>{from_ruby<Scalar>(x)};
242
+ }
243
+ }
240
244
 
241
245
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
242
246
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -244,3 +248,4 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
244
248
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
245
249
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
246
250
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
251
+ Object wrap(std::vector<torch::Tensor> x);
@@ -179,8 +179,10 @@ require "torch/nn/functional"
179
179
  require "torch/nn/init"
180
180
 
181
181
  # utils
182
+ require "torch/utils/data"
182
183
  require "torch/utils/data/data_loader"
183
184
  require "torch/utils/data/dataset"
185
+ require "torch/utils/data/subset"
184
186
  require "torch/utils/data/tensor_dataset"
185
187
 
186
188
  # hub
@@ -237,25 +239,22 @@ module Torch
237
239
  cls
238
240
  end
239
241
 
240
- FloatTensor = _make_tensor_class(:float32)
241
- DoubleTensor = _make_tensor_class(:float64)
242
- HalfTensor = _make_tensor_class(:float16)
243
- ByteTensor = _make_tensor_class(:uint8)
244
- CharTensor = _make_tensor_class(:int8)
245
- ShortTensor = _make_tensor_class(:int16)
246
- IntTensor = _make_tensor_class(:int32)
247
- LongTensor = _make_tensor_class(:int64)
248
- BoolTensor = _make_tensor_class(:bool)
249
-
250
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
251
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
252
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
253
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
254
- CUDA::CharTensor = _make_tensor_class(:int8, true)
255
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
256
- CUDA::IntTensor = _make_tensor_class(:int32, true)
257
- CUDA::LongTensor = _make_tensor_class(:int64, true)
258
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
242
# dtype symbol => legacy tensor class name. Each entry gets a CPU class
# defined under this module and a CUDA class under CUDA, both built by
# _make_tensor_class.
DTYPE_TO_CLASS = {
  float32: "FloatTensor",
  float64: "DoubleTensor",
  float16: "HalfTensor",
  uint8: "ByteTensor",
  int8: "CharTensor",
  int16: "ShortTensor",
  int32: "IntTensor",
  int64: "LongTensor",
  bool: "BoolTensor"
}

DTYPE_TO_CLASS.each_pair do |dtype_name, tensor_class|
  const_set(tensor_class, _make_tensor_class(dtype_name))
  CUDA.const_set(tensor_class, _make_tensor_class(dtype_name, true))
end
259
258
 
260
259
  class << self
261
260
  # Torch.float, Torch.long, etc
@@ -316,6 +315,16 @@ module Torch
316
315
  end
317
316
  end
318
317
 
318
# Runs the given block with gradient tracking switched on, then restores
# the grad mode that was active before (also when the block raises).
# Returns the block's value.
def enable_grad
  saved_mode = grad_enabled?
  begin
    _set_grad_enabled(true)
    yield
  ensure
    _set_grad_enabled(saved_mode)
  end
end
327
+
319
328
  def device(str)
320
329
  Device.new(str)
321
330
  end
@@ -376,6 +385,10 @@ module Torch
376
385
  end
377
386
 
378
387
def randperm(n, **options)
  # Python's binding forces int64 as the default dtype for randperm
  # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
  options[:dtype] = :int64 unless options[:dtype]
  _randperm(n, tensor_options(**options))
end
381
394
 
@@ -448,6 +461,22 @@ module Torch
448
461
  zeros(input.size, **like_options(input, options))
449
462
  end
450
463
 
464
# Short-time Fourier transform. When +center+ is true, the signal is padded
# by n_fft / 2 on both sides of its last dimension (using +pad_mode+) before
# the native _stft call.
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
  if center
    ndim = input.dim
    # pad a 3-D view (leading singleton dims prepended), then view back to
    # the original number of dims; presumably the 3-D shape is what the
    # padding modes expect
    padded_shape = [1] * (3 - ndim) + input.size
    half = n_fft.div(2).to_i
    padded = NN::F.pad(input.view(padded_shape), [half, half], mode: pad_mode)
    input = padded.view(padded.shape[-ndim..-1])
  end
  _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
end
474
+
475
# Clamps every element of +tensor+ into [min, max].
# As with PyTorch's torch.clamp, either bound may be nil to leave that side
# unbounded; passing nil for both raises ArgumentError.
def clamp(tensor, min, max)
  if min.nil? && max.nil?
    raise ArgumentError, "At least one of 'min' or 'max' must not be nil"
  end
  tensor = _clamp_min(tensor, min) unless min.nil?
  tensor = _clamp_max(tensor, max) unless max.nil?
  tensor
end
479
+
451
480
  private
452
481
 
453
482
  def to_ivalue(obj)
@@ -7,25 +7,26 @@ module Torch
7
7
 
8
8
  def download_url_to_file(url, dst)
9
9
  uri = URI(url)
10
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
10
+ tmp = nil
11
11
  location = nil
12
12
 
13
+ puts "Downloading #{url}..."
13
14
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
15
  request = Net::HTTP::Get.new(uri)
15
16
 
16
- puts "Downloading #{url}..."
17
- File.open(tmp, "wb") do |f|
18
- http.request(request) do |response|
19
- case response
20
- when Net::HTTPRedirection
21
- location = response["location"]
22
- when Net::HTTPSuccess
17
+ http.request(request) do |response|
18
+ case response
19
+ when Net::HTTPRedirection
20
+ location = response["location"]
21
+ when Net::HTTPSuccess
22
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
23
+ File.open(tmp, "wb") do |f|
23
24
  response.read_body do |chunk|
24
25
  f.write(chunk)
25
26
  end
26
- else
27
- raise Error, "Bad response"
28
27
  end
28
+ else
29
+ raise Error, "Bad response"
29
30
  end
30
31
  end
31
32
  end
@@ -1,10 +1,14 @@
1
1
  module Torch
2
2
  module Native
3
3
  class Function
4
- attr_reader :function
4
+ attr_reader :function, :tensor_options
5
5
 
6
6
# +function+ is one entry of the native function spec (a Hash with a
# "func" signature string). If the signature ends with the tensor-options
# suffix, that suffix is stripped from the stored signature and
# @tensor_options is set so callers can supply a TensorOptions argument.
def initialize(function)
  tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
  func = function["func"]
  @tensor_options = func.include?(tensor_options_str)
  # build a new hash/string rather than sub! so the caller's (possibly
  # shared or frozen) spec data is left untouched
  @function = function.merge("func" => func.sub(tensor_options_str, ")"))
end
9
13
 
10
14
  def func
@@ -82,6 +86,10 @@ module Torch
82
86
  @ret_size ||= func.split("->").last.split(", ").size
83
87
  end
84
88
 
89
# True when the signature's return type contains "[]" (a tensor list such
# as "Tensor[]"); generated bindings wrap such results into a Ruby Array.
def ret_array?
  # `@ret_array ||= ...` never caches a false result, so memoize explicitly
  return @ret_array if defined?(@ret_array)
  @ret_array = func.split("->").last.include?("[]")
end
92
+
85
93
  def out?
86
94
  out_size > 0 && base_name[-1] != "_"
87
95
  end
@@ -18,12 +18,12 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
21
+ skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
22
22
 
23
23
  # remove functions
24
24
  functions.reject! do |f|
25
25
  f.ruby_name.start_with?("_") ||
26
- f.ruby_name.end_with?("_backward") ||
26
+ f.ruby_name.include?("_backward") ||
27
27
  f.args.any? { |a| a[:type].include?("Dimname") }
28
28
  end
29
29
 
@@ -31,32 +31,15 @@ module Torch
31
31
  todo_functions, functions =
32
32
  functions.partition do |f|
33
33
  f.args.any? do |a|
34
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
34
  skip_args.any? { |sa| a[:type].include?(sa) } ||
35
+ # call to 'range' is ambiguous
36
+ f.cpp_name == "_range" ||
36
37
  # native_functions.yaml is missing size argument for normal
37
38
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
38
39
  (f.base_name == "normal" && !f.out?)
39
40
  end
40
41
  end
41
42
 
42
- # generate additional functions for optional arguments
43
- # there may be a better way to do this
44
- optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
45
- optional_functions.each do |f|
46
- next if f.ruby_name == "cross"
47
- next if f.ruby_name.start_with?("avg_pool") && f.out?
48
-
49
- opt_args = f.args.select { |a| a[:type] == "int?" }
50
- if opt_args.size == 1
51
- sep = f.name.include?(".") ? "_" : "."
52
- f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
53
- # TODO only remove some arguments
54
- f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
55
- functions << f1
56
- functions << f2
57
- end
58
- end
59
-
60
43
  # todo_functions.each do |f|
61
44
  # puts f.func
62
45
  # puts
@@ -97,7 +80,8 @@ void add_%{type}_functions(Module m) {
97
80
 
98
81
  cpp_defs = []
99
82
  functions.sort_by(&:cpp_name).each do |func|
100
- fargs = func.args #.select { |a| a[:type] != "Generator?" }
83
+ fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
84
+ fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
101
85
 
102
86
  cpp_args = []
103
87
  fargs.each do |a|
@@ -109,7 +93,7 @@ void add_%{type}_functions(Module m) {
109
93
  # TODO better signature
110
94
  "OptionalTensor"
111
95
  when "ScalarType?"
112
- "OptionalScalarType"
96
+ "torch::optional<ScalarType>"
113
97
  when "Tensor[]"
114
98
  "TensorList"
115
99
  when "Tensor?[]"
@@ -117,6 +101,14 @@ void add_%{type}_functions(Module m) {
117
101
  "TensorList"
118
102
  when "int"
119
103
  "int64_t"
104
+ when "int?"
105
+ "torch::optional<int64_t>"
106
+ when "float?"
107
+ "torch::optional<double>"
108
+ when "bool?"
109
+ "torch::optional<bool>"
110
+ when "Scalar?"
111
+ "torch::optional<torch::Scalar>"
120
112
  when "float"
121
113
  "double"
122
114
  when /\Aint\[/
@@ -125,6 +117,8 @@ void add_%{type}_functions(Module m) {
125
117
  "Tensor &"
126
118
  when "str"
127
119
  "std::string"
120
+ when "TensorOptions"
121
+ "const torch::TensorOptions &"
128
122
  else
129
123
  a[:type]
130
124
  end
@@ -141,8 +135,8 @@ void add_%{type}_functions(Module m) {
141
135
  prefix = def_method == :define_method ? "self." : "torch::"
142
136
 
143
137
  body = "#{prefix}#{dispatch}(#{args.join(", ")})"
144
- # TODO check type as well
145
- if func.ret_size > 1
138
+
139
+ if func.ret_size > 1 || func.ret_array?
146
140
  body = "wrap(#{body})"
147
141
  end
148
142
 
@@ -83,6 +83,12 @@ module Torch
83
83
  else
84
84
  v.is_a?(Integer)
85
85
  end
86
+ when "int?"
87
+ v.is_a?(Integer) || v.nil?
88
+ when "float?"
89
+ v.is_a?(Numeric) || v.nil?
90
+ when "bool?"
91
+ v == true || v == false || v.nil?
86
92
  when "float"
87
93
  v.is_a?(Numeric)
88
94
  when /int\[.*\]/
@@ -95,6 +101,10 @@ module Torch
95
101
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
96
102
  when "Scalar"
97
103
  v.is_a?(Numeric)
104
+ when "Scalar?"
105
+ v.is_a?(Numeric) || v.nil?
106
+ when "ScalarType"
107
+ false # not supported yet
98
108
  when "ScalarType?"
99
109
  v.nil?
100
110
  when "bool"
@@ -126,9 +136,11 @@ module Torch
126
136
  end
127
137
 
128
138
  func = candidates.first
139
+ args = func.args.map { |a| final_values[a[:name]] }
140
+ args << TensorOptions.new.dtype(6) if func.tensor_options
129
141
  {
130
142
  name: func.cpp_name,
131
- args: func.args.map { |a| final_values[a[:name]] }
143
+ args: args
132
144
  }
133
145
  end
134
146
  end
@@ -178,8 +178,12 @@ module Torch
178
178
  Torch.hardshrink(input, lambd)
179
179
  end
180
180
 
181
- def leaky_relu(input, negative_slope = 0.01)
182
- NN.leaky_relu(input, negative_slope)
181
# Leaky ReLU activation. With +inplace: true+ the bang (in-place) variant
# NN.leaky_relu! is used; otherwise NN.leaky_relu.
def leaky_relu(input, negative_slope = 0.01, inplace: false)
  return NN.leaky_relu!(input, negative_slope) if inplace

  NN.leaky_relu(input, negative_slope)
end
184
188
 
185
189
  def log_sigmoid(input)
@@ -373,7 +377,8 @@ module Torch
373
377
  end
374
378
 
375
379
  # weight and input swapped
376
- Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
380
+ ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
381
+ ret
377
382
  end
378
383
 
379
384
  # distance functions
@@ -426,6 +431,9 @@ module Torch
426
431
  end
427
432
 
428
433
# Mean squared error loss. Emits a warning (but still computes the loss)
# when the target and input sizes differ, since broadcasting would kick in.
def mse_loss(input, target, reduction: "mean")
  unless target.size == input.size
    warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
  end
  NN.mse_loss(input, target, reduction)
end
431
439
 
@@ -1,14 +1,14 @@
1
1
  module Torch
2
2
  module NN
3
3
  class LeakyReLU < Module
4
- def initialize(negative_slope: 1e-2) #, inplace: false)
4
+ def initialize(negative_slope: 1e-2, inplace: false)
5
5
  super()
6
6
  @negative_slope = negative_slope
7
- # @inplace = inplace
7
+ @inplace = inplace
8
8
  end
9
9
 
10
10
  def forward(input)
11
- F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
11
+ F.leaky_relu(input, @negative_slope, inplace: @inplace)
12
12
  end
13
13
 
14
14
  def extra_inspect
@@ -55,7 +55,12 @@ module Torch
55
55
  end
56
56
  end
57
57
  end
58
- # TODO apply to more objects
58
+
59
+ @buffers.each_key do |k|
60
+ buf = @buffers[k]
61
+ @buffers[k] = fn.call(buf) unless buf.nil?
62
+ end
63
+
59
64
  self
60
65
  end
61
66
 
@@ -103,11 +103,6 @@ module Torch
103
103
  Torch.empty(0, dtype: dtype)
104
104
  end
105
105
 
106
- def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
- retain_graph = create_graph if retain_graph.nil?
108
- _backward(gradient, retain_graph, create_graph)
109
- end
110
-
111
106
  # TODO read directly from memory
112
107
  def numo
113
108
  cls = Torch._dtype_to_numo[dtype]
@@ -188,49 +183,15 @@ module Torch
188
183
  # based on python_variable_indexing.cpp and
189
184
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
190
185
  def [](*indexes)
191
- result = self
192
- dim = 0
193
- indexes.each do |index|
194
- if index.is_a?(Numeric)
195
- result = result._select_int(dim, index)
196
- elsif index.is_a?(Range)
197
- finish = index.end
198
- finish += 1 unless index.exclude_end?
199
- result = result._slice_tensor(dim, index.begin, finish, 1)
200
- dim += 1
201
- elsif index.is_a?(Tensor)
202
- result = result.index([index])
203
- elsif index.nil?
204
- result = result.unsqueeze(dim)
205
- dim += 1
206
- elsif index == true
207
- result = result.unsqueeze(dim)
208
- # TODO handle false
209
- else
210
- raise Error, "Unsupported index type: #{index.class.name}"
211
- end
212
- end
213
- result
186
+ _index(tensor_indexes(indexes))
214
187
  end
215
188
 
216
189
  # based on python_variable_indexing.cpp and
217
190
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
218
- def []=(index, value)
191
+ def []=(*indexes, value)
219
192
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
220
-
221
193
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
222
-
223
- if index.is_a?(Numeric)
224
- index_put!([Torch.tensor(index)], value)
225
- elsif index.is_a?(Range)
226
- finish = index.end
227
- finish += 1 unless index.exclude_end?
228
- _slice_tensor(0, index.begin, finish, 1).copy!(value)
229
- elsif index.is_a?(Tensor)
230
- index_put!([index], value)
231
- else
232
- raise Error, "Unsupported index type: #{index.class.name}"
233
- end
194
+ _index_put_custom(tensor_indexes(indexes), value)
234
195
  end
235
196
 
236
197
  # native functions that need manually defined
@@ -244,13 +205,13 @@ module Torch
244
205
  end
245
206
  end
246
207
 
247
- # native functions overlap, so need to handle manually
208
+ # parser can't handle overlap, so need to handle manually
248
209
  def random!(*args)
249
210
  case args.size
250
211
  when 1
251
212
  _random__to(*args)
252
213
  when 2
253
- _random__from_to(*args)
214
+ _random__from(*args)
254
215
  else
255
216
  _random_(*args)
256
217
  end
@@ -260,5 +221,32 @@ module Torch
260
221
  _clamp_min_(min)
261
222
  _clamp_max_(max)
262
223
  end
224
+
225
+ private
226
+
227
# Translates the arguments given to Tensor#[] / Tensor#[]= into TensorIndex
# objects for _index / _index_put_custom:
#   Integer      -> TensorIndex.integer
#   Range        -> TensorIndex.slice(start, stop), stop exclusive;
#                   a nil stop means "through the last element"
#   Tensor       -> TensorIndex.tensor
#   nil          -> TensorIndex.none
#   true / false -> TensorIndex.boolean
# Any other type raises Error.
def tensor_indexes(indexes)
  indexes.map do |index|
    case index
    when Integer
      TensorIndex.integer(index)
    when Range
      finish = index.end
      if finish.nil? || (finish == -1 && !index.exclude_end?)
        # endless ranges (1.. and 1...) and an inclusive ..-1 all run to the
        # end; for ..-1, converting to exclusive (-1 + 1 == 0) would give an
        # empty slice, so use an open stop instead
        finish = nil
      elsif !index.exclude_end?
        # inclusive end -> exclusive stop
        finish += 1
      end
      TensorIndex.slice(index.begin, finish)
    when Tensor
      TensorIndex.tensor(index)
    when nil
      TensorIndex.none
    when true, false
      TensorIndex.boolean(index)
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
end
263
251
  end
264
252
  end
@@ -0,0 +1,23 @@
1
module Torch
  module Utils
    module Data
      class << self
        # Randomly partitions +dataset+ into non-overlapping Subsets whose
        # sizes are given by +lengths+; the lengths must sum to the dataset
        # length, otherwise ArgumentError is raised.
        def random_split(dataset, lengths)
          total = lengths.sum
          unless total == dataset.length
            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
          end

          shuffled = Torch.randperm(total).to_a
          ends = _accumulate(lengths)
          ends.zip(lengths).map do |stop, len|
            Subset.new(dataset, shuffled[(stop - len)...stop])
          end
        end

        private

        # Running totals, e.g. [2, 3, 5] -> [2, 5, 10].
        def _accumulate(iterable)
          running = 0
          iterable.map { |x| running += x }
        end
      end
    end
  end
end
@@ -6,10 +6,22 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1, shuffle: false)
9
# dataset    - source of samples, read with dataset[i]
# batch_size - number of samples per batch (default 1)
# shuffle    - stored as @shuffle
# collate_fn - callable turning an Array of samples into a batch; when nil a
#              default bound method of this loader is used
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
  @dataset = dataset
  @batch_size = batch_size
  @shuffle = shuffle

  # no batch sampler is configured, so auto_collation? reports false here
  @batch_sampler = nil

  if collate_fn.nil?
    # NOTE(review): default_collate does not appear to be defined in this
    # class; that branch is unreachable while @batch_sampler stays nil
    default_name = auto_collation? ? :default_collate : :default_convert
    collate_fn = method(default_name)
  end

  @collate_fn = collate_fn
end
14
26
 
15
27
  def each
@@ -25,18 +37,20 @@ module Torch
25
37
  end
26
38
 
27
39
  indexes.each_slice(@batch_size) do |idx|
28
- batch = idx.map { |i| @dataset[i] }
29
- yield collate(batch)
40
+ # TODO improve performance
41
+ yield @collate_fn.call(idx.map { |i| @dataset[i] })
30
42
  end
31
43
  end
32
44
 
33
45
  def size
34
46
  (@dataset.size / @batch_size.to_f).ceil
35
47
  end
48
+ alias_method :length, :size
49
+ alias_method :count, :size
36
50
 
37
51
  private
38
52
 
39
- def collate(batch)
53
+ def default_convert(batch)
40
54
  elem = batch[0]
41
55
  case elem
42
56
  when Tensor
@@ -44,11 +58,15 @@ module Torch
44
58
  when Integer
45
59
  Torch.tensor(batch)
46
60
  when Array
47
- batch.transpose.map { |v| collate(v) }
61
+ batch.transpose.map { |v| default_convert(v) }
48
62
  else
49
- raise NotImpelmentYet
63
+ raise NotImplementedYet
50
64
  end
51
65
  end
66
+
67
# Whether automatic batch collation applies: true only when a batch
# sampler has been assigned to @batch_sampler.
def auto_collation?
  return false if @batch_sampler.nil?

  true
end
52
70
  end
53
71
  end
54
72
  end
@@ -0,0 +1,25 @@
1
module Torch
  module Utils
    module Data
      # Dataset restricted to the positions listed in +indices+ of an
      # underlying +dataset+; element i is dataset[indices[i]].
      class Subset < Dataset
        def initialize(dataset, indices)
          @dataset = dataset
          @indices = indices
        end

        def [](idx)
          source_index = @indices[idx]
          @dataset[source_index]
        end

        def length
          @indices.length
        end
        alias_method :size, :length

        # Materializes the subset as an Array of dataset elements.
        def to_a
          @indices.map { |source_index| @dataset[source_index] }
        end
      end
    end
  end
end
@@ -16,6 +16,8 @@ module Torch
16
16
  def size
17
17
  @tensors[0].size(0)
18
18
  end
19
+ alias_method :length, :size
20
+ alias_method :count, :size
19
21
  end
20
22
  end
21
23
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.1"
2
+ VERSION = "0.3.6"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.1
4
+ version: 0.3.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-17 00:00:00.000000000 Z
11
+ date: 2020-09-18 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -259,8 +259,10 @@ files:
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
261
  - lib/torch/tensor.rb
262
+ - lib/torch/utils/data.rb
262
263
  - lib/torch/utils/data/data_loader.rb
263
264
  - lib/torch/utils/data/dataset.rb
265
+ - lib/torch/utils/data/subset.rb
264
266
  - lib/torch/utils/data/tensor_dataset.rb
265
267
  - lib/torch/version.rb
266
268
  homepage: https://github.com/ankane/torch.rb