torch-rb 0.3.2 → 0.3.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
-   data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
+   metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
+   data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
  SHA512:
-   metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
-   data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
+   metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
+   data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea
CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 0.3.3 (2020-08-25)
+
+ - Added spectral ops
+ - Fixed tensor indexing
+
  ## 0.3.2 (2020-08-24)

  - Added `enable_grad` method
data/README.md CHANGED
@@ -2,7 +2,11 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+ Check out:
+
+ - [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
+ - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
+ - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks

  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)

ext/torch/ext.cpp CHANGED
@@ -16,6 +16,7 @@
  #include "nn_functions.hpp"

  using namespace Rice;
+ using torch::indexing::TensorIndex;

  // need to make a distinction between parameters and tensors
  class Parameter: public torch::autograd::Variable {
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
    throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
  }

+ std::vector<TensorIndex> index_vector(Array a) {
+   auto indices = std::vector<TensorIndex>();
+   indices.reserve(a.size());
+   for (size_t i = 0; i < a.size(); i++) {
+     indices.push_back(from_ruby<TensorIndex>(a[i]));
+   }
+   return indices;
+ }
+
  extern "C"
  void Init_ext()
  {
@@ -58,6 +68,13 @@ void Init_ext()
      return generator.seed();
    });

+ Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
+   .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
+   .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
+   .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
+   .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
+   .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
+
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
    .add_handler<torch::Error>(handle_error)
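These singleton constructors mirror libtorch's `torch::indexing` values one-to-one. A minimal sketch of what each produces (illustrative only; `TensorIndex` objects are normally built internally by `Tensor#[]` rather than by hand):

```ruby
Torch::TensorIndex.integer(2)      # a plain integer index
Torch::TensorIndex.slice(1, 4)     # torch::indexing::Slice(1, 4), i.e. Python's 1:4
Torch::TensorIndex.slice(nil, nil) # open slice, i.e. Python's ":"
Torch::TensorIndex.none            # like Python's None, inserts a new axis
Torch::TensorIndex.boolean(true)   # like Python's True, inserts a size-1 axis
```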
@@ -330,6 +347,18 @@ void Init_ext()
    .define_method("numel", &torch::Tensor::numel)
    .define_method("element_size", &torch::Tensor::element_size)
    .define_method("requires_grad", &torch::Tensor::requires_grad)
+   .define_method(
+     "_index",
+     *[](Tensor& self, Array indices) {
+       auto vec = index_vector(indices);
+       return self.index(vec);
+     })
+   .define_method(
+     "_index_put_custom",
+     *[](Tensor& self, Array indices, torch::Tensor& value) {
+       auto vec = index_vector(indices);
+       return self.index_put_(vec, value);
+     })
    .define_method(
      "contiguous?",
      *[](Tensor& self) {
@@ -508,6 +537,7 @@ void Init_ext()
    });

  Module rb_mInit = define_module_under(rb_mNN, "Init")
+   .add_handler<torch::Error>(handle_error)
    .define_singleton_method(
      "_calculate_gain",
      *[](NonlinearityType nonlinearity, double param) {
@@ -594,8 +624,8 @@ void Init_ext()
    });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
-   .define_constructor(Constructor<torch::Device, std::string>())
    .add_handler<torch::Error>(handle_error)
+   .define_constructor(Constructor<torch::Device, std::string>())
    .define_method("index", &torch::Device::index)
    .define_method("index?", &torch::Device::has_index)
    .define_method(
ext/torch/templates.hpp CHANGED
@@ -9,6 +9,10 @@

  using namespace Rice;

+ using torch::Device;
+ using torch::ScalarType;
+ using torch::Tensor;
+
  // need to wrap torch::IntArrayRef() since
  // it doesn't own underlying data
  class IntArrayRef {
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
    return MyReduction(x);
  }

- typedef torch::Tensor Tensor;
-
  class OptionalTensor {
    Object value;
    public:
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
    return OptionalTensor(x);
  }

- class ScalarType {
-   Object value;
-   public:
-     ScalarType(Object o) {
-       value = o;
-     }
-     operator at::ScalarType() {
-       throw std::runtime_error("ScalarType arguments not implemented yet");
-     }
- };
-
  template<>
  inline
- ScalarType from_ruby<ScalarType>(Object x)
+ torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
  {
-   return ScalarType(x);
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
+   }
  }

- class OptionalScalarType {
-   Object value;
-   public:
-     OptionalScalarType(Object o) {
-       value = o;
-     }
-     operator c10::optional<at::ScalarType>() {
-       if (value.is_nil()) {
-         return c10::nullopt;
-       }
-       return ScalarType(value);
-     }
- };
-
  template<>
  inline
- OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
  {
-   return OptionalScalarType(x);
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<int64_t>{from_ruby<int64_t>(x)};
+   }
  }

- typedef torch::Device Device;
-
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
lib/torch.rb CHANGED
@@ -460,6 +460,17 @@ module Torch
      zeros(input.size, **like_options(input, options))
    end

+   def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+     if center
+       signal_dim = input.dim
+       extended_shape = [1] * (3 - signal_dim) + input.size
+       pad = n_fft.div(2).to_i
+       input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+       input = input.view(input.shape[-signal_dim..-1])
+     end
+     _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+   end
+
    private

    def to_ivalue(obj)
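A quick usage sketch for the new `Torch.stft` (the signal length, window, and resulting shape are illustrative assumptions, and `hann_window` is assumed to be among the generated native functions):

```ruby
signal = Torch.randn(1000)
window = Torch.hann_window(256)
spec = Torch.stft(signal, 256, hop_length: 64, window: window)
# with center: true padding this comes out to roughly [129, 16, 2]:
# frequency bins x frames x (real, imaginary)
spec.shape
```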
lib/torch/native/function.rb CHANGED
@@ -1,10 +1,14 @@
  module Torch
    module Native
      class Function
-       attr_reader :function
+       attr_reader :function, :tensor_options

        def initialize(function)
          @function = function
+
+         tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+         @tensor_options = @function["func"].include?(tensor_options_str)
+         @function["func"].sub!(tensor_options_str, ")")
        end

        def func
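The effect: a function whose schema ends with the standard tensor-options keyword block gets flagged, and the block is stripped so the rest of the parser never sees it. A hedged sketch (the `arange` schema below is paraphrased from PyTorch's `native_functions.yaml`, so the exact entry may differ):

```ruby
f = Torch::Native::Function.new(
  "func" => "arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"
)
f.tensor_options   # => true
f.function["func"] # => "arange(Scalar end) -> Tensor"
```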
lib/torch/native/generator.rb CHANGED
@@ -33,30 +33,14 @@ module Torch
        f.args.any? do |a|
          a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
            skip_args.any? { |sa| a[:type].include?(sa) } ||
+           # call to 'range' is ambiguous
+           f.cpp_name == "_range" ||
            # native_functions.yaml is missing size argument for normal
            # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
            (f.base_name == "normal" && !f.out?)
        end
      end

-     # generate additional functions for optional arguments
-     # there may be a better way to do this
-     optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
-     optional_functions.each do |f|
-       next if f.ruby_name == "cross"
-       next if f.ruby_name.start_with?("avg_pool") && f.out?
-
-       opt_args = f.args.select { |a| a[:type] == "int?" }
-       if opt_args.size == 1
-         sep = f.name.include?(".") ? "_" : "."
-         f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
-         # TODO only remove some arguments
-         f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
-         functions << f1
-         functions << f2
-       end
-     end
-
      # todo_functions.each do |f|
      #   puts f.func
      #   puts
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {

      cpp_defs = []
      functions.sort_by(&:cpp_name).each do |func|
-       fargs = func.args #.select { |a| a[:type] != "Generator?" }
+       fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
+       fargs << {name: "options", type: "TensorOptions"} if func.tensor_options

        cpp_args = []
        fargs.each do |a|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
            # TODO better signature
            "OptionalTensor"
          when "ScalarType?"
-           "OptionalScalarType"
+           "torch::optional<ScalarType>"
          when "Tensor[]"
            "TensorList"
          when "Tensor?[]"
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
            "TensorList"
          when "int"
            "int64_t"
+         when "int?"
+           "torch::optional<int64_t>"
          when "float"
            "double"
          when /\Aint\[/
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
            "Tensor &"
          when "str"
            "std::string"
+         when "TensorOptions"
+           "const torch::TensorOptions &"
          else
            a[:type]
          end
lib/torch/native/parser.rb CHANGED
@@ -83,6 +83,8 @@ module Torch
          else
            v.is_a?(Integer)
          end
+       when "int?"
+         v.is_a?(Integer) || v.nil?
        when "float"
          v.is_a?(Numeric)
        when /int\[.*\]/
@@ -126,9 +128,11 @@ module Torch
      end

      func = candidates.first
+     args = func.args.map { |a| final_values[a[:name]] }
+     args << TensorOptions.new.dtype(6) if func.tensor_options
      {
        name: func.cpp_name,
-       args: func.args.map { |a| final_values[a[:name]] }
+       args: args
      }
    end
  end
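A note on the hard-coded `6`: it is ATen's `ScalarType` enum value for 32-bit floats (`at::kFloat`), so calls that reach a tensor-options overload default to float32:

```ruby
Torch::TensorOptions.new.dtype(6) # 6 == at::kFloat, i.e. float32
```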
lib/torch/tensor.rb CHANGED
@@ -188,49 +188,15 @@ module Torch
    # based on python_variable_indexing.cpp and
    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
    def [](*indexes)
-     result = self
-     dim = 0
-     indexes.each do |index|
-       if index.is_a?(Numeric)
-         result = result._select_int(dim, index)
-       elsif index.is_a?(Range)
-         finish = index.end
-         finish += 1 unless index.exclude_end?
-         result = result._slice_tensor(dim, index.begin, finish, 1)
-         dim += 1
-       elsif index.is_a?(Tensor)
-         result = result.index([index])
-       elsif index.nil?
-         result = result.unsqueeze(dim)
-         dim += 1
-       elsif index == true
-         result = result.unsqueeze(dim)
-         # TODO handle false
-       else
-         raise Error, "Unsupported index type: #{index.class.name}"
-       end
-     end
-     result
+     _index(tensor_indexes(indexes))
    end

    # based on python_variable_indexing.cpp and
    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
-   def []=(index, value)
+   def []=(*indexes, value)
      raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
-     if index.is_a?(Numeric)
-       index_put!([Torch.tensor(index)], value)
-     elsif index.is_a?(Range)
-       finish = index.end
-       finish += 1 unless index.exclude_end?
-       _slice_tensor(0, index.begin, finish, 1).copy!(value)
-     elsif index.is_a?(Tensor)
-       index_put!([index], value)
-     else
-       raise Error, "Unsupported index type: #{index.class.name}"
-     end
+     _index_put_custom(tensor_indexes(indexes), value)
    end

    # native functions that need manually defined
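With both paths now delegating to libtorch's indexing, the user-facing behavior looks roughly like this (a sketch; values and shapes are illustrative):

```ruby
x = Torch.arange(12).reshape([3, 4])

x[0]          # row 0 (Integer -> TensorIndex.integer)
x[0, 1..2]    # row 0, columns 1..2 (inclusive range)
x[0...2]      # rows 0 and 1 (exclusive range)
x[0..-1, 0]   # every row, column 0 (-1 end becomes an open slice)
x[nil]        # adds a leading axis, like Python's None
x[0, 1] = 99  # assignment routed through _index_put_custom
```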
@@ -244,13 +210,13 @@ module Torch
        end
      end

-     # native functions overlap, so need to handle manually
+     # parser can't handle overlap, so need to handle manually
      def random!(*args)
        case args.size
        when 1
          _random__to(*args)
        when 2
-         _random__from_to(*args)
+         _random__from(*args)
        else
          _random_(*args)
        end
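Usage is unchanged; only the internal binding name moved (sketch):

```ruby
x = Torch.zeros(5)
x.random!(10)    # integers drawn from [0, 10)
x.random!(5, 10) # integers drawn from [5, 10)
```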
@@ -260,5 +226,32 @@ module Torch
        _clamp_min_(min)
        _clamp_max_(max)
      end
+
+     private
+
+     def tensor_indexes(indexes)
+       indexes.map do |index|
+         case index
+         when Integer
+           TensorIndex.integer(index)
+         when Range
+           finish = index.end
+           if finish == -1 && !index.exclude_end?
+             finish = nil
+           else
+             finish += 1 unless index.exclude_end?
+           end
+           TensorIndex.slice(index.begin, finish)
+         when Tensor
+           TensorIndex.tensor(index)
+         when nil
+           TensorIndex.none
+         when true, false
+           TensorIndex.boolean(index)
+         else
+           raise Error, "Unsupported index type: #{index.class.name}"
+         end
+       end
+     end
    end
  end
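The `Range` branch carries the indexing fix: an inclusive `-1` end now becomes an open slice to the last element instead of an empty one. How Ruby ranges translate (sketch):

```ruby
(1..3)   # => TensorIndex.slice(1, 4)   inclusive end, so stop = end + 1
(1...3)  # => TensorIndex.slice(1, 3)   exclusive end used as-is
(0..-1)  # => TensorIndex.slice(0, nil) runs through the last element
```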
lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.3.2"
+   VERSION = "0.3.3"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
-   version: 0.3.2
+   version: 0.3.3
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-08-24 00:00:00.000000000 Z
+ date: 2020-08-26 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rice