torch-rb 0.3.1 → 0.3.6

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
- data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
+ metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
+ data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
  SHA512:
- metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
- data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
+ metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
+ data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
@@ -1,3 +1,31 @@
+ ## 0.3.6 (2020-09-17)
+
+ - Added `inplace` option for leaky ReLU
+ - Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
+ - Fixed error with buffers on GPU
+
+ ## 0.3.5 (2020-09-04)
+
+ - Fixed error with data loader (due to `dtype` of `randperm`)
+
+ ## 0.3.4 (2020-08-26)
+
+ - Added `Torch.clamp` method
+
+ ## 0.3.3 (2020-08-25)
+
+ - Added spectral ops
+ - Fixed tensor indexing
+
+ ## 0.3.2 (2020-08-24)
+
+ - Added `enable_grad` method
+ - Added `random_split` method
+ - Added `collate_fn` option to `DataLoader`
+ - Added `grad=` method to `Tensor`
+ - Fixed error with `grad` method when empty
+ - Fixed `EmbeddingBag`
+
  ## 0.3.1 (2020-08-17)

  - Added `create_graph` and `retain_graph` options to `backward` method
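The 0.3.6 entry about tensor-list methods corresponds to the new `wrap(std::vector<torch::Tensor>)` overload and the `ret_array?` check further down in this diff. A rough sketch of what it enables (values illustrative, not from the source):

    x = Torch.tensor([1, 2, 3, 4, 5, 6])
    parts = x.chunk(3)   # previously raised; now returns an Array of Tensors
    parts.size           # => 3
    a, b = x.split(3)    # split and unbind now return arrays as well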
data/README.md CHANGED
@@ -2,7 +2,11 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+ Check out:
+
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks

  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)

@@ -411,7 +415,7 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.3.0-0.3.1 | 1.6.0
+ 0.3.0-0.3.4 | 1.6.0
  0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
  0.1.0-0.1.7 | 1.3.1
@@ -16,6 +16,7 @@
  #include "nn_functions.hpp"

  using namespace Rice;
+ using torch::indexing::TensorIndex;

  // need to make a distinction between parameters and tensors
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
  }

+ std::vector<TensorIndex> index_vector(Array a) {
+ auto indices = std::vector<TensorIndex>();
+ indices.reserve(a.size());
+ for (size_t i = 0; i < a.size(); i++) {
+ indices.push_back(from_ruby<TensorIndex>(a[i]));
+ }
+ return indices;
+ }
+
  extern "C"
  void Init_ext()
  {
@@ -58,6 +68,13 @@ void Init_ext()
  return generator.seed();
  });

+ Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
+ .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
+ .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
+ .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
+ .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
+ .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
+
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
  .add_handler<torch::Error>(handle_error)
@@ -284,11 +301,6 @@ void Init_ext()
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
  return torch::pickle_load(v);
  })
- .define_singleton_method(
- "_binary_cross_entropy_with_logits",
- *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
- return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
- })
  .define_singleton_method(
  "_from_blob",
  *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
@@ -330,6 +342,18 @@ void Init_ext()
  .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method(
+ "_index",
+ *[](Tensor& self, Array indices) {
+ auto vec = index_vector(indices);
+ return self.index(vec);
+ })
+ .define_method(
+ "_index_put_custom",
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
+ auto vec = index_vector(indices);
+ return self.index_put_(vec, value);
+ })
  .define_method(
  "contiguous?",
  *[](Tensor& self) {
@@ -350,15 +374,16 @@ void Init_ext()
  *[](Tensor& self, bool requires_grad) {
  return self.set_requires_grad(requires_grad);
  })
- .define_method(
- "_backward",
- *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
- return self.backward(gradient, create_graph, retain_graph);
- })
  .define_method(
  "grad",
  *[](Tensor& self) {
- return self.grad();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Tensor& self, torch::Tensor& grad) {
+ self.grad() = grad;
  })
  .define_method(
  "_dtype",
@@ -502,6 +527,7 @@ void Init_ext()
  });

  Module rb_mInit = define_module_under(rb_mNN, "Init")
+ .add_handler<torch::Error>(handle_error)
  .define_singleton_method(
  "_calculate_gain",
  *[](NonlinearityType nonlinearity, double param) {
@@ -580,11 +606,16 @@ void Init_ext()
  *[](Parameter& self) {
  auto grad = self.grad();
  return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Parameter& self, torch::Tensor& grad) {
+ self.grad() = grad;
  });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
- .define_constructor(Constructor<torch::Device, std::string>())
  .add_handler<torch::Error>(handle_error)
+ .define_constructor(Constructor<torch::Device, std::string>())
  .define_method("index", &torch::Device::index)
  .define_method("index?", &torch::Device::has_index)
  .define_method(
@@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
  a.push(to_ruby<int64_t>(std::get<3>(x)));
  return Object(a);
  }
+
+ Object wrap(std::vector<torch::Tensor> x) {
+ Array a;
+ for (auto& t : x) {
+ a.push(to_ruby<torch::Tensor>(t));
+ }
+ return Object(a);
+ }
@@ -9,6 +9,11 @@

  using namespace Rice;

+ using torch::Device;
+ using torch::Scalar;
+ using torch::ScalarType;
+ using torch::Tensor;
+
  // need to wrap torch::IntArrayRef() since
  // it doesn't own underlying data
  class IntArrayRef {
@@ -32,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
  return IntArrayRef(x);
  }

- // for now
- class Scalar {
- torch::Scalar value;
- public:
- Scalar(Object o) {
- // TODO cast based on Ruby type
- if (o.rb_type() == T_FIXNUM) {
- value = torch::Scalar(from_ruby<int64_t>(o));
- } else {
- value = torch::Scalar(from_ruby<float>(o));
- }
- }
- operator torch::Scalar() {
- return value;
- }
- };
-
- template<>
- inline
- Scalar from_ruby<Scalar>(Object x)
- {
- return Scalar(x);
- }
-
  class TensorList {
  std::vector<torch::Tensor> vec;
  public:
@@ -174,8 +155,6 @@ MyReduction from_ruby<MyReduction>(Object x)
  return MyReduction(x);
  }

- typedef torch::Tensor Tensor;
-
  class OptionalTensor {
  Object value;
  public:
@@ -190,6 +169,17 @@ class OptionalTensor {
  }
  };

+ template<>
+ inline
+ Scalar from_ruby<Scalar>(Object x)
+ {
+ if (x.rb_type() == T_FIXNUM) {
+ return torch::Scalar(from_ruby<int64_t>(x));
+ } else {
+ return torch::Scalar(from_ruby<double>(x));
+ }
+ }
+
  template<>
  inline
  OptionalTensor from_ruby<OptionalTensor>(Object x)
@@ -197,46 +187,60 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
  return OptionalTensor(x);
  }

- class ScalarType {
- Object value;
- public:
- ScalarType(Object o) {
- value = o;
- }
- operator at::ScalarType() {
- throw std::runtime_error("ScalarType arguments not implemented yet");
- }
- };
+ template<>
+ inline
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
+ {
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
+ }
+ }

  template<>
  inline
- ScalarType from_ruby<ScalarType>(Object x)
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
  {
- return ScalarType(x);
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<int64_t>{from_ruby<int64_t>(x)};
+ }
  }

- class OptionalScalarType {
- Object value;
- public:
- OptionalScalarType(Object o) {
- value = o;
- }
- operator c10::optional<at::ScalarType>() {
- if (value.is_nil()) {
- return c10::nullopt;
- }
- return ScalarType(value);
- }
- };
+ template<>
+ inline
+ torch::optional<double> from_ruby<torch::optional<double>>(Object x)
+ {
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<double>{from_ruby<double>(x)};
+ }
+ }

  template<>
  inline
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+ torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
  {
- return OptionalScalarType(x);
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<bool>{from_ruby<bool>(x)};
+ }
  }

- typedef torch::Device Device;
+ template<>
+ inline
+ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
+ {
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<Scalar>{from_ruby<Scalar>(x)};
+ }
+ }

  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -244,3 +248,4 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
+ Object wrap(std::vector<torch::Tensor> x);
@@ -179,8 +179,10 @@ require "torch/nn/functional"
  require "torch/nn/init"

  # utils
+ require "torch/utils/data"
  require "torch/utils/data/data_loader"
  require "torch/utils/data/dataset"
+ require "torch/utils/data/subset"
  require "torch/utils/data/tensor_dataset"

  # hub
@@ -237,25 +239,22 @@ module Torch
  cls
  end

- FloatTensor = _make_tensor_class(:float32)
- DoubleTensor = _make_tensor_class(:float64)
- HalfTensor = _make_tensor_class(:float16)
- ByteTensor = _make_tensor_class(:uint8)
- CharTensor = _make_tensor_class(:int8)
- ShortTensor = _make_tensor_class(:int16)
- IntTensor = _make_tensor_class(:int32)
- LongTensor = _make_tensor_class(:int64)
- BoolTensor = _make_tensor_class(:bool)
-
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
- CUDA::CharTensor = _make_tensor_class(:int8, true)
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
- CUDA::IntTensor = _make_tensor_class(:int32, true)
- CUDA::LongTensor = _make_tensor_class(:int64, true)
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
+ DTYPE_TO_CLASS = {
+ float32: "FloatTensor",
+ float64: "DoubleTensor",
+ float16: "HalfTensor",
+ uint8: "ByteTensor",
+ int8: "CharTensor",
+ int16: "ShortTensor",
+ int32: "IntTensor",
+ int64: "LongTensor",
+ bool: "BoolTensor"
+ }
+
+ DTYPE_TO_CLASS.each do |dtype, class_name|
+ const_set(class_name, _make_tensor_class(dtype))
+ CUDA.const_set(class_name, _make_tensor_class(dtype, true))
+ end

  class << self
  # Torch.float, Torch.long, etc
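The refactor is behavior-preserving: the same FloatTensor/.../BoolTensor constants are defined on Torch and Torch::CUDA, now driven by the DTYPE_TO_CLASS hash above. A hedged sketch (CUDA variants still require a CUDA-enabled LibTorch):

    Torch::FloatTensor            # still defined, now via const_set
    Torch::CUDA::LongTensor       # CUDA variants registered the same way
    Torch::FloatTensor.new(2, 3)  # behaves as before the refactor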
@@ -316,6 +315,16 @@ module Torch
  end
  end

+ def enable_grad
+ previous_value = grad_enabled?
+ begin
+ _set_grad_enabled(true)
+ yield
+ ensure
+ _set_grad_enabled(previous_value)
+ end
+ end
+
  def device(str)
  Device.new(str)
  end
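`enable_grad` mirrors `no_grad`: it turns gradient tracking on for the block and restores the previous setting afterwards, so the two can be nested. A small usage sketch:

    x = Torch.ones(2, requires_grad: true)

    Torch.no_grad do
      Torch.enable_grad do
        y = (x * 2).sum
        y.requires_grad   # => true, tracking is re-enabled inside the inner block
      end
    end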
@@ -376,6 +385,10 @@ module Torch
  end

  def randperm(n, **options)
+ # dtype hack in Python
+ # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+ options[:dtype] ||= :int64
+
  _randperm(n, tensor_options(**options))
  end

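Defaulting to `:int64` matters because the permutation is typically used as an index list (the 0.3.5 data loader fix relied on this). A quick sketch:

    perm = Torch.randperm(5)
    perm.dtype   # => :int64
    perm.to_a    # => e.g. [3, 0, 4, 1, 2]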
@@ -448,6 +461,22 @@ module Torch
  zeros(input.size, **like_options(input, options))
  end

+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+ if center
+ signal_dim = input.dim
+ extended_shape = [1] * (3 - signal_dim) + input.size
+ pad = n_fft.div(2).to_i
+ input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+ input = input.view(input.shape[-signal_dim..-1])
+ end
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+ end
+
+ def clamp(tensor, min, max)
+ tensor = _clamp_min(tensor, min)
+ _clamp_max(tensor, max)
+ end
+
  private

  def to_ivalue(obj)
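`Torch.clamp` simply composes the existing `_clamp_min`/`_clamp_max` bindings. A short sketch (the in-place `clamp!` shown for contrast is the Tensor method visible later in this diff):

    x = Torch.tensor([-2.0, 0.5, 3.0])
    Torch.clamp(x, 0, 1).to_a   # => [0.0, 0.5, 1.0]
    x.clamp!(0, 1)              # in-place variant on Tensor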
@@ -7,25 +7,26 @@ module Torch

  def download_url_to_file(url, dst)
  uri = URI(url)
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+ tmp = nil
  location = nil

+ puts "Downloading #{url}..."
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
  request = Net::HTTP::Get.new(uri)

- puts "Downloading #{url}..."
- File.open(tmp, "wb") do |f|
- http.request(request) do |response|
- case response
- when Net::HTTPRedirection
- location = response["location"]
- when Net::HTTPSuccess
+ http.request(request) do |response|
+ case response
+ when Net::HTTPRedirection
+ location = response["location"]
+ when Net::HTTPSuccess
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+ File.open(tmp, "wb") do |f|
  response.read_body do |chunk|
  f.write(chunk)
  end
- else
- raise Error, "Bad response"
  end
+ else
+ raise Error, "Bad response"
  end
  end
  end
@@ -1,10 +1,14 @@
  module Torch
  module Native
  class Function
- attr_reader :function
+ attr_reader :function, :tensor_options

  def initialize(function)
  @function = function
+
+ tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+ @tensor_options = @function["func"].include?(tensor_options_str)
+ @function["func"].sub!(tensor_options_str, ")")
  end

  def func
@@ -82,6 +86,10 @@ module Torch
  @ret_size ||= func.split("->").last.split(", ").size
  end

+ def ret_array?
+ @ret_array ||= func.split("->").last.include?('[]')
+ end
+
  def out?
  out_size > 0 && base_name[-1] != "_"
  end
@@ -18,12 +18,12 @@ module Torch
  functions = functions()

  # skip functions
- skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
+ skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]

  # remove functions
  functions.reject! do |f|
  f.ruby_name.start_with?("_") ||
- f.ruby_name.end_with?("_backward") ||
+ f.ruby_name.include?("_backward") ||
  f.args.any? { |a| a[:type].include?("Dimname") }
  end

@@ -31,32 +31,15 @@ module Torch
  todo_functions, functions =
  functions.partition do |f|
  f.args.any? do |a|
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
  skip_args.any? { |sa| a[:type].include?(sa) } ||
+ # call to 'range' is ambiguous
+ f.cpp_name == "_range" ||
  # native_functions.yaml is missing size argument for normal
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
  (f.base_name == "normal" && !f.out?)
  end
  end

- # generate additional functions for optional arguments
- # there may be a better way to do this
- optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
- optional_functions.each do |f|
- next if f.ruby_name == "cross"
- next if f.ruby_name.start_with?("avg_pool") && f.out?
-
- opt_args = f.args.select { |a| a[:type] == "int?" }
- if opt_args.size == 1
- sep = f.name.include?(".") ? "_" : "."
- f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
- # TODO only remove some arguments
- f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
- functions << f1
- functions << f2
- end
- end
-
  # todo_functions.each do |f|
  # puts f.func
  # puts
@@ -97,7 +80,8 @@ void add_%{type}_functions(Module m) {

  cpp_defs = []
  functions.sort_by(&:cpp_name).each do |func|
- fargs = func.args #.select { |a| a[:type] != "Generator?" }
+ fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
+ fargs << {name: "options", type: "TensorOptions"} if func.tensor_options

  cpp_args = []
  fargs.each do |a|
@@ -109,7 +93,7 @@ void add_%{type}_functions(Module m) {
  # TODO better signature
  "OptionalTensor"
  when "ScalarType?"
- "OptionalScalarType"
+ "torch::optional<ScalarType>"
  when "Tensor[]"
  "TensorList"
  when "Tensor?[]"
@@ -117,6 +101,14 @@ void add_%{type}_functions(Module m) {
  "TensorList"
  when "int"
  "int64_t"
+ when "int?"
+ "torch::optional<int64_t>"
+ when "float?"
+ "torch::optional<double>"
+ when "bool?"
+ "torch::optional<bool>"
+ when "Scalar?"
+ "torch::optional<torch::Scalar>"
  when "float"
  "double"
  when /\Aint\[/
@@ -125,6 +117,8 @@ void add_%{type}_functions(Module m) {
  "Tensor &"
  when "str"
  "std::string"
+ when "TensorOptions"
+ "const torch::TensorOptions &"
  else
  a[:type]
  end
@@ -141,8 +135,8 @@ void add_%{type}_functions(Module m) {
  prefix = def_method == :define_method ? "self." : "torch::"

  body = "#{prefix}#{dispatch}(#{args.join(", ")})"
- # TODO check type as well
- if func.ret_size > 1
+
+ if func.ret_size > 1 || func.ret_array?
  body = "wrap(#{body})"
  end

@@ -83,6 +83,12 @@ module Torch
  else
  v.is_a?(Integer)
  end
+ when "int?"
+ v.is_a?(Integer) || v.nil?
+ when "float?"
+ v.is_a?(Numeric) || v.nil?
+ when "bool?"
+ v == true || v == false || v.nil?
  when "float"
  v.is_a?(Numeric)
  when /int\[.*\]/
@@ -95,6 +101,10 @@ module Torch
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
  when "Scalar"
  v.is_a?(Numeric)
+ when "Scalar?"
+ v.is_a?(Numeric) || v.nil?
+ when "ScalarType"
+ false # not supported yet
  when "ScalarType?"
  v.nil?
  when "bool"
@@ -126,9 +136,11 @@ module Torch
  end
  end

139
+ args = func.args.map { |a| final_values[a[:name]] }
140
+ args << TensorOptions.new.dtype(6) if func.tensor_options
129
141
  {
130
142
  name: func.cpp_name,
131
- args: func.args.map { |a| final_values[a[:name]] }
143
+ args: args
132
144
  }
133
145
  end
134
146
  end
@@ -178,8 +178,12 @@ module Torch
  Torch.hardshrink(input, lambd)
  end

- def leaky_relu(input, negative_slope = 0.01)
- NN.leaky_relu(input, negative_slope)
+ def leaky_relu(input, negative_slope = 0.01, inplace: false)
+ if inplace
+ NN.leaky_relu!(input, negative_slope)
+ else
+ NN.leaky_relu(input, negative_slope)
+ end
  end

  def log_sigmoid(input)
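A small usage sketch of the new `inplace` option (input values illustrative):

    input = Torch.tensor([-1.0, 2.0])
    Torch::NN::F.leaky_relu(input)                      # returns a new tensor
    Torch::NN::F.leaky_relu(input, 0.2, inplace: true)  # mutates input, matching the LeakyReLU module change below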
@@ -373,7 +377,8 @@ module Torch
  end

  # weight and input swapped
- Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+ ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+ ret
  end

  # distance functions
@@ -426,6 +431,9 @@ module Torch
  end

  def mse_loss(input, target, reduction: "mean")
+ if target.size != input.size
+ warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
+ end
  NN.mse_loss(input, target, reduction)
  end

@@ -1,14 +1,14 @@
  module Torch
  module NN
  class LeakyReLU < Module
- def initialize(negative_slope: 1e-2) #, inplace: false)
+ def initialize(negative_slope: 1e-2, inplace: false)
  super()
  @negative_slope = negative_slope
- # @inplace = inplace
+ @inplace = inplace
  end

  def forward(input)
- F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+ F.leaky_relu(input, @negative_slope, inplace: @inplace)
  end

  def extra_inspect
@@ -55,7 +55,12 @@ module Torch
  end
  end
  end
- # TODO apply to more objects
+
+ @buffers.each_key do |k|
+ buf = @buffers[k]
+ @buffers[k] = fn.call(buf) unless buf.nil?
+ end
+
  self
  end

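This is the "buffers on GPU" fix from 0.3.6: `_apply` now runs the conversion function over registered buffers (for example BatchNorm's running stats), not just parameters and sub-modules. A hedged sketch of the effect, assuming a CUDA build:

    net = Torch::NN::BatchNorm1d.new(4)
    net.to("cuda")   # running_mean / running_var buffers now move along with the parameters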
@@ -103,11 +103,6 @@ module Torch
  Torch.empty(0, dtype: dtype)
  end

- def backward(gradient = nil, retain_graph: nil, create_graph: false)
- retain_graph = create_graph if retain_graph.nil?
- _backward(gradient, retain_graph, create_graph)
- end
-
  # TODO read directly from memory
  def numo
  cls = Torch._dtype_to_numo[dtype]
@@ -188,49 +183,15 @@ module Torch
  # based on python_variable_indexing.cpp and
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
  def [](*indexes)
- result = self
- dim = 0
- indexes.each do |index|
- if index.is_a?(Numeric)
- result = result._select_int(dim, index)
- elsif index.is_a?(Range)
- finish = index.end
- finish += 1 unless index.exclude_end?
- result = result._slice_tensor(dim, index.begin, finish, 1)
- dim += 1
- elsif index.is_a?(Tensor)
- result = result.index([index])
- elsif index.nil?
- result = result.unsqueeze(dim)
- dim += 1
- elsif index == true
- result = result.unsqueeze(dim)
- # TODO handle false
- else
- raise Error, "Unsupported index type: #{index.class.name}"
- end
- end
- result
+ _index(tensor_indexes(indexes))
  end

  # based on python_variable_indexing.cpp and
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
- def []=(index, value)
+ def []=(*indexes, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
- if index.is_a?(Numeric)
- index_put!([Torch.tensor(index)], value)
- elsif index.is_a?(Range)
- finish = index.end
- finish += 1 unless index.exclude_end?
- _slice_tensor(0, index.begin, finish, 1).copy!(value)
- elsif index.is_a?(Tensor)
- index_put!([index], value)
- else
- raise Error, "Unsupported index type: #{index.class.name}"
- end
+ _index_put_custom(tensor_indexes(indexes), value)
  end

  # native functions that need manually defined
@@ -244,13 +205,13 @@ module Torch
  end
  end

- # native functions overlap, so need to handle manually
+ # parser can't handle overlap, so need to handle manually
  def random!(*args)
  case args.size
  when 1
  _random__to(*args)
  when 2
- _random__from_to(*args)
+ _random__from(*args)
  else
  _random_(*args)
  end
@@ -260,5 +221,32 @@ module Torch
  _clamp_min_(min)
  _clamp_max_(max)
  end
+
+ private
+
+ def tensor_indexes(indexes)
+ indexes.map do |index|
+ case index
+ when Integer
+ TensorIndex.integer(index)
+ when Range
+ finish = index.end || -1
+ if finish == -1 && !index.exclude_end?
+ finish = nil
+ else
+ finish += 1 unless index.exclude_end?
+ end
+ TensorIndex.slice(index.begin, finish)
+ when Tensor
+ TensorIndex.tensor(index)
+ when nil
+ TensorIndex.none
+ when true, false
+ TensorIndex.boolean(index)
+ else
+ raise Error, "Unsupported index type: #{index.class.name}"
+ end
+ end
+ end
  end
  end
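With `[]` and `[]=` now routed through `TensorIndex` and the native `_index`/`_index_put_custom` bindings, Ruby-side indexing maps onto LibTorch's indexing in a single call. A short sketch (values illustrative):

    x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
    x[0]          # Integer -> TensorIndex.integer
    x[0, 1..2]    # Range   -> TensorIndex.slice
    x[nil]        # nil     -> TensorIndex.none (adds a dimension)
    x[0, 1] = 9   # assignment goes through _index_put_custom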
@@ -0,0 +1,23 @@
+ module Torch
+ module Utils
+ module Data
+ class << self
+ def random_split(dataset, lengths)
+ if lengths.sum != dataset.length
+ raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
+ end
+
+ indices = Torch.randperm(lengths.sum).to_a
+ _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
+ end
+
+ private
+
+ def _accumulate(iterable)
+ sum = 0
+ iterable.map { |x| sum += x }
+ end
+ end
+ end
+ end
+ end
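A usage sketch of `random_split` with a `TensorDataset` (sizes illustrative); it relies on the new `Subset` class and the `randperm` dtype fix elsewhere in this diff:

    dataset = Torch::Utils::Data::TensorDataset.new(Torch.rand(10, 3), Torch.rand(10))

    train_set, test_set = Torch::Utils::Data.random_split(dataset, [8, 2])
    train_set.length   # => 8
    test_set.length    # => 2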
@@ -6,10 +6,22 @@ module Torch

  attr_reader :dataset

- def initialize(dataset, batch_size: 1, shuffle: false)
+ def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
  @dataset = dataset
  @batch_size = batch_size
  @shuffle = shuffle
+
+ @batch_sampler = nil
+
+ if collate_fn.nil?
+ if auto_collation?
+ collate_fn = method(:default_collate)
+ else
+ collate_fn = method(:default_convert)
+ end
+ end
+
+ @collate_fn = collate_fn
  end

  def each
@@ -25,18 +37,20 @@ module Torch
  end

  indexes.each_slice(@batch_size) do |idx|
- batch = idx.map { |i| @dataset[i] }
- yield collate(batch)
+ # TODO improve performance
+ yield @collate_fn.call(idx.map { |i| @dataset[i] })
  end
  end

  def size
  (@dataset.size / @batch_size.to_f).ceil
  end
+ alias_method :length, :size
+ alias_method :count, :size

  private

- def collate(batch)
+ def default_convert(batch)
  elem = batch[0]
  case elem
  when Tensor
@@ -44,11 +58,15 @@ module Torch
  when Integer
  Torch.tensor(batch)
  when Array
- batch.transpose.map { |v| collate(v) }
+ batch.transpose.map { |v| default_convert(v) }
  else
- raise NotImpelmentYet
+ raise NotImplementedYet
  end
  end
+
+ def auto_collation?
+ !@batch_sampler.nil?
+ end
  end
  end
  end
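A sketch of supplying a custom `collate_fn` (the lambda is illustrative, not part of the gem); when none is given, the loader falls back to `default_convert` as above:

    dataset = Torch::Utils::Data::TensorDataset.new(Torch.rand(6, 2), Torch.rand(6))

    stack_first = ->(batch) { Torch.stack(batch.map { |x, _y| x }) }
    loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 3, collate_fn: stack_first)

    loader.each do |batch|
      batch.shape   # => [3, 2]
    end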
@@ -0,0 +1,25 @@
+ module Torch
+ module Utils
+ module Data
+ class Subset < Dataset
+ def initialize(dataset, indices)
+ @dataset = dataset
+ @indices = indices
+ end
+
+ def [](idx)
+ @dataset[@indices[idx]]
+ end
+
+ def length
+ @indices.length
+ end
+ alias_method :size, :length
+
+ def to_a
+ @indices.map { |i| @dataset[i] }
+ end
+ end
+ end
+ end
+ end
@@ -16,6 +16,8 @@ module Torch
  def size
  @tensors[0].size(0)
  end
+ alias_method :length, :size
+ alias_method :count, :size
  end
  end
  end
@@ -1,3 +1,3 @@
  module Torch
- VERSION = "0.3.1"
+ VERSION = "0.3.6"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.3.1
+ version: 0.3.6
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-08-17 00:00:00.000000000 Z
+ date: 2020-09-18 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -259,8 +259,10 @@ files:
  - lib/torch/optim/rprop.rb
  - lib/torch/optim/sgd.rb
  - lib/torch/tensor.rb
+ - lib/torch/utils/data.rb
  - lib/torch/utils/data/data_loader.rb
  - lib/torch/utils/data/dataset.rb
+ - lib/torch/utils/data/subset.rb
  - lib/torch/utils/data/tensor_dataset.rb
  - lib/torch/version.rb
  homepage: https://github.com/ankane/torch.rb