torch-rb 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
4
- data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
3
+ metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
4
+ data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
5
5
  SHA512:
6
- metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
7
- data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
6
+ metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
7
+ data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea
@@ -1,3 +1,8 @@
1
+ ## 0.3.3 (2020-08-25)
2
+
3
+ - Added spectral ops
4
+ - Fixed tensor indexing
5
+
1
6
  ## 0.3.2 (2020-08-24)
2
7
 
3
8
  - Added `enable_grad` method
data/README.md CHANGED
@@ -2,7 +2,11 @@
2
2
 
3
3
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
4
 
5
- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
5
+ Check out:
6
+
7
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
8
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
9
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
6
10
 
7
11
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
8
12
 
@@ -16,6 +16,7 @@
16
16
  #include "nn_functions.hpp"
17
17
 
18
18
  using namespace Rice;
19
+ using torch::indexing::TensorIndex;
19
20
 
20
21
  // need to make a distinction between parameters and tensors
21
22
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
28
29
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
30
  }
30
31
 
32
+ std::vector<TensorIndex> index_vector(Array a) {
33
+ auto indices = std::vector<TensorIndex>();
34
+ indices.reserve(a.size());
35
+ for (size_t i = 0; i < a.size(); i++) {
36
+ indices.push_back(from_ruby<TensorIndex>(a[i]));
37
+ }
38
+ return indices;
39
+ }
40
+
31
41
  extern "C"
32
42
  void Init_ext()
33
43
  {
@@ -58,6 +68,13 @@ void Init_ext()
58
68
  return generator.seed();
59
69
  });
60
70
 
71
+ Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
72
+ .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
73
+ .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
74
+ .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
75
+ .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
76
+ .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
77
+
61
78
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
62
79
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
63
80
  .add_handler<torch::Error>(handle_error)
@@ -330,6 +347,18 @@ void Init_ext()
330
347
  .define_method("numel", &torch::Tensor::numel)
331
348
  .define_method("element_size", &torch::Tensor::element_size)
332
349
  .define_method("requires_grad", &torch::Tensor::requires_grad)
350
+ .define_method(
351
+ "_index",
352
+ *[](Tensor& self, Array indices) {
353
+ auto vec = index_vector(indices);
354
+ return self.index(vec);
355
+ })
356
+ .define_method(
357
+ "_index_put_custom",
358
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
359
+ auto vec = index_vector(indices);
360
+ return self.index_put_(vec, value);
361
+ })
333
362
  .define_method(
334
363
  "contiguous?",
335
364
  *[](Tensor& self) {
@@ -508,6 +537,7 @@ void Init_ext()
508
537
  });
509
538
 
510
539
  Module rb_mInit = define_module_under(rb_mNN, "Init")
540
+ .add_handler<torch::Error>(handle_error)
511
541
  .define_singleton_method(
512
542
  "_calculate_gain",
513
543
  *[](NonlinearityType nonlinearity, double param) {
@@ -594,8 +624,8 @@ void Init_ext()
594
624
  });
595
625
 
596
626
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
597
- .define_constructor(Constructor<torch::Device, std::string>())
598
627
  .add_handler<torch::Error>(handle_error)
628
+ .define_constructor(Constructor<torch::Device, std::string>())
599
629
  .define_method("index", &torch::Device::index)
600
630
  .define_method("index?", &torch::Device::has_index)
601
631
  .define_method(
@@ -9,6 +9,10 @@
9
9
 
10
10
  using namespace Rice;
11
11
 
12
+ using torch::Device;
13
+ using torch::ScalarType;
14
+ using torch::Tensor;
15
+
12
16
  // need to wrap torch::IntArrayRef() since
13
17
  // it doesn't own underlying data
14
18
  class IntArrayRef {
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
174
178
  return MyReduction(x);
175
179
  }
176
180
 
177
- typedef torch::Tensor Tensor;
178
-
179
181
  class OptionalTensor {
180
182
  Object value;
181
183
  public:
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
197
199
  return OptionalTensor(x);
198
200
  }
199
201
 
200
- class ScalarType {
201
- Object value;
202
- public:
203
- ScalarType(Object o) {
204
- value = o;
205
- }
206
- operator at::ScalarType() {
207
- throw std::runtime_error("ScalarType arguments not implemented yet");
208
- }
209
- };
210
-
211
202
  template<>
212
203
  inline
213
- ScalarType from_ruby<ScalarType>(Object x)
204
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
214
205
  {
215
- return ScalarType(x);
206
+ if (x.is_nil()) {
207
+ return torch::nullopt;
208
+ } else {
209
+ return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
210
+ }
216
211
  }
217
212
 
218
- class OptionalScalarType {
219
- Object value;
220
- public:
221
- OptionalScalarType(Object o) {
222
- value = o;
223
- }
224
- operator c10::optional<at::ScalarType>() {
225
- if (value.is_nil()) {
226
- return c10::nullopt;
227
- }
228
- return ScalarType(value);
229
- }
230
- };
231
-
232
213
  template<>
233
214
  inline
234
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
215
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
235
216
  {
236
- return OptionalScalarType(x);
217
+ if (x.is_nil()) {
218
+ return torch::nullopt;
219
+ } else {
220
+ return torch::optional<int64_t>{from_ruby<int64_t>(x)};
221
+ }
237
222
  }
238
223
 
239
- typedef torch::Device Device;
240
-
241
224
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
242
225
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
243
226
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -460,6 +460,17 @@ module Torch
460
460
  zeros(input.size, **like_options(input, options))
461
461
  end
462
462
 
463
# Short-time Fourier transform; follows the centering logic of PyTorch's torch.stft.
#
# When +center+ is true, the signal is padded by n_fft / 2 on both sides using
# +pad_mode+ before the native _stft is called. The input is temporarily viewed
# with leading singleton dimensions up to rank 3 (presumably because the
# non-constant pad modes expect a batched 3-D input, as in PyTorch — TODO
# confirm), then viewed back to its original rank after padding.
#
# The remaining keyword arguments are passed straight through to _stft.
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
  if center
    rank = input.dim
    padding = n_fft.div(2).to_i
    batched = input.view([1] * (3 - rank) + input.size)
    padded = NN::F.pad(batched, [padding, padding], mode: pad_mode)
    # drop the leading singleton dims that were added above
    input = padded.view(padded.shape[-rank..-1])
  end
  _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
end
473
+
463
474
  private
464
475
 
465
476
  def to_ivalue(obj)
@@ -1,10 +1,14 @@
1
1
  module Torch
2
2
  module Native
3
3
  class Function
4
- attr_reader :function
4
+ attr_reader :function, :tensor_options
5
5
 
6
6
  def initialize(function)
7
7
  @function = function
8
+
9
+ tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
10
+ @tensor_options = @function["func"].include?(tensor_options_str)
11
+ @function["func"].sub!(tensor_options_str, ")")
8
12
  end
9
13
 
10
14
  def func
@@ -33,30 +33,14 @@ module Torch
33
33
  f.args.any? do |a|
34
34
  a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
35
  skip_args.any? { |sa| a[:type].include?(sa) } ||
36
+ # call to 'range' is ambiguous
37
+ f.cpp_name == "_range" ||
36
38
  # native_functions.yaml is missing size argument for normal
37
39
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
38
40
  (f.base_name == "normal" && !f.out?)
39
41
  end
40
42
  end
41
43
 
42
- # generate additional functions for optional arguments
43
- # there may be a better way to do this
44
- optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
45
- optional_functions.each do |f|
46
- next if f.ruby_name == "cross"
47
- next if f.ruby_name.start_with?("avg_pool") && f.out?
48
-
49
- opt_args = f.args.select { |a| a[:type] == "int?" }
50
- if opt_args.size == 1
51
- sep = f.name.include?(".") ? "_" : "."
52
- f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
53
- # TODO only remove some arguments
54
- f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
55
- functions << f1
56
- functions << f2
57
- end
58
- end
59
-
60
44
  # todo_functions.each do |f|
61
45
  # puts f.func
62
46
  # puts
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
97
81
 
98
82
  cpp_defs = []
99
83
  functions.sort_by(&:cpp_name).each do |func|
100
- fargs = func.args #.select { |a| a[:type] != "Generator?" }
84
+ fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
85
+ fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
101
86
 
102
87
  cpp_args = []
103
88
  fargs.each do |a|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
109
94
  # TODO better signature
110
95
  "OptionalTensor"
111
96
  when "ScalarType?"
112
- "OptionalScalarType"
97
+ "torch::optional<ScalarType>"
113
98
  when "Tensor[]"
114
99
  "TensorList"
115
100
  when "Tensor?[]"
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
117
102
  "TensorList"
118
103
  when "int"
119
104
  "int64_t"
105
+ when "int?"
106
+ "torch::optional<int64_t>"
120
107
  when "float"
121
108
  "double"
122
109
  when /\Aint\[/
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
125
112
  "Tensor &"
126
113
  when "str"
127
114
  "std::string"
115
+ when "TensorOptions"
116
+ "const torch::TensorOptions &"
128
117
  else
129
118
  a[:type]
130
119
  end
@@ -83,6 +83,8 @@ module Torch
83
83
  else
84
84
  v.is_a?(Integer)
85
85
  end
86
+ when "int?"
87
+ v.is_a?(Integer) || v.nil?
86
88
  when "float"
87
89
  v.is_a?(Numeric)
88
90
  when /int\[.*\]/
@@ -126,9 +128,11 @@ module Torch
126
128
  end
127
129
 
128
130
  func = candidates.first
131
+ args = func.args.map { |a| final_values[a[:name]] }
132
+ args << TensorOptions.new.dtype(6) if func.tensor_options
129
133
  {
130
134
  name: func.cpp_name,
131
- args: func.args.map { |a| final_values[a[:name]] }
135
+ args: args
132
136
  }
133
137
  end
134
138
  end
@@ -188,49 +188,15 @@ module Torch
188
188
  # based on python_variable_indexing.cpp and
189
189
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
190
190
  def [](*indexes)
191
- result = self
192
- dim = 0
193
- indexes.each do |index|
194
- if index.is_a?(Numeric)
195
- result = result._select_int(dim, index)
196
- elsif index.is_a?(Range)
197
- finish = index.end
198
- finish += 1 unless index.exclude_end?
199
- result = result._slice_tensor(dim, index.begin, finish, 1)
200
- dim += 1
201
- elsif index.is_a?(Tensor)
202
- result = result.index([index])
203
- elsif index.nil?
204
- result = result.unsqueeze(dim)
205
- dim += 1
206
- elsif index == true
207
- result = result.unsqueeze(dim)
208
- # TODO handle false
209
- else
210
- raise Error, "Unsupported index type: #{index.class.name}"
211
- end
212
- end
213
- result
191
+ _index(tensor_indexes(indexes))
214
192
  end
215
193
 
216
194
  # based on python_variable_indexing.cpp and
217
195
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
218
- def []=(index, value)
196
+ def []=(*indexes, value)
219
197
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
220
-
221
198
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
222
-
223
- if index.is_a?(Numeric)
224
- index_put!([Torch.tensor(index)], value)
225
- elsif index.is_a?(Range)
226
- finish = index.end
227
- finish += 1 unless index.exclude_end?
228
- _slice_tensor(0, index.begin, finish, 1).copy!(value)
229
- elsif index.is_a?(Tensor)
230
- index_put!([index], value)
231
- else
232
- raise Error, "Unsupported index type: #{index.class.name}"
233
- end
199
+ _index_put_custom(tensor_indexes(indexes), value)
234
200
  end
235
201
 
236
202
  # native functions that need manually defined
@@ -244,13 +210,13 @@ module Torch
244
210
  end
245
211
  end
246
212
 
247
- # native functions overlap, so need to handle manually
213
+ # parser can't handle overlap, so need to handle manually
248
214
  def random!(*args)
249
215
  case args.size
250
216
  when 1
251
217
  _random__to(*args)
252
218
  when 2
253
- _random__from_to(*args)
219
+ _random__from(*args)
254
220
  else
255
221
  _random_(*args)
256
222
  end
@@ -260,5 +226,32 @@ module Torch
260
226
  _clamp_min_(min)
261
227
  _clamp_max_(max)
262
228
  end
229
+
230
+ private
231
+
232
# Converts the index arguments given to Tensor#[] / Tensor#[]= into
# TensorIndex objects for the native _index / _index_put_custom methods.
#
# Supported index types:
# - Integer      -> TensorIndex.integer
# - Range        -> TensorIndex.slice(begin, stop). An inclusive range gets an
#                   exclusive stop (end + 1). The stop is left open (nil) for an
#                   endless range or for an inclusive range ending at -1, since
#                   "through the last element" has no exclusive integer form.
#                   A nil begin (beginless range) is passed through as nil.
# - Tensor       -> TensorIndex.tensor
# - nil          -> TensorIndex.none
# - true / false -> TensorIndex.boolean
#
# Raises Error for any other index type.
def tensor_indexes(indexes)
  indexes.map do |index|
    case index
    when Integer
      TensorIndex.integer(index)
    when Range
      stop = index.end
      # an endless range (e.g. 1..) has a nil end; keep it open instead of
      # calling nil + 1
      unless stop.nil? || index.exclude_end?
        stop = stop == -1 ? nil : stop + 1
      end
      TensorIndex.slice(index.begin, stop)
    when Tensor
      TensorIndex.tensor(index)
    when nil
      TensorIndex.none
    when true, false
      TensorIndex.boolean(index)
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
end
263
256
  end
264
257
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.2"
2
+ VERSION = "0.3.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.2
4
+ version: 0.3.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-24 00:00:00.000000000 Z
11
+ date: 2020-08-26 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice