torch-rb 0.3.6 → 0.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
- data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
+ metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
+ data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
  SHA512:
- metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
- data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
+ metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
+ data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
@@ -1,3 +1,12 @@
+ ## 0.3.7 (2020-09-22)
+
+ - Improved performance
+ - Added `Upsample`
+ - Added support for passing tensor class to `type` method
+ - Fixed error with buffers on GPU
+ - Fixed error with `new_full`
+ - Fixed issue with `numo` method and non-contiguous tensors
+
  ## 0.3.6 (2020-09-17)

  - Added `inplace` option for leaky ReLU
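For orientation, here is a minimal, unofficial sketch of the headline `Upsample` addition from the changelog above in use. It assumes a working torch-rb 0.3.7 install; the tensor values and expected shape are illustrative only.

```ruby
require "torch"

# Upsample a 1x1x2x2 tensor to 1x1x4x4 with nearest-neighbor interpolation
x = Torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
y = up.call(x)
p y.shape # expect [1, 1, 4, 4]
```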
data/README.md CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
  - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
  - [Collaborative filtering with MovieLens](examples/movielens)
  - [Sequence models and word embeddings](examples/nlp)
+ - [Generative adversarial networks](examples/gan)

  ## LibTorch Installation

@@ -232,7 +232,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_empty",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
  return torch::empty(size, options);
  })
  .define_singleton_method(
@@ -242,7 +242,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_full",
- *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
+ *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
  return torch::full(size, fill_value, options);
  })
  .define_singleton_method(
@@ -257,22 +257,22 @@ void Init_ext()
  })
  .define_singleton_method(
  "_ones",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
  return torch::ones(size, options);
  })
  .define_singleton_method(
  "_rand",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
  return torch::rand(size, options);
  })
  .define_singleton_method(
  "_randint",
- *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
+ *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
  return torch::randint(low, high, size, options);
  })
  .define_singleton_method(
  "_randn",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
  return torch::randn(size, options);
  })
  .define_singleton_method(
@@ -282,7 +282,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_zeros",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
  return torch::zeros(size, options);
  })
  // begin operations
@@ -303,13 +303,13 @@ void Init_ext()
  })
  .define_singleton_method(
  "_from_blob",
- *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+ *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
  void *data = const_cast<char *>(s.c_str());
  return torch::from_blob(data, size, options);
  })
  .define_singleton_method(
  "_tensor",
- *[](Array a, IntArrayRef size, const torch::TensorOptions &options) {
+ *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
  auto dtype = options.dtype();
  torch::Tensor t;
  if (dtype == torch::kBool) {
@@ -342,6 +342,16 @@ void Init_ext()
  .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
+ // in C++ for performance
+ .define_method(
+ "shape",
+ *[](Tensor& self) {
+ Array a;
+ for (auto &size : self.sizes()) {
+ a.push(size);
+ }
+ return a;
+ })
  .define_method(
  "_index",
  *[](Tensor& self, Array indices) {
@@ -420,9 +430,19 @@ void Init_ext()
  tensor = tensor.to(device);
  }

+ if (!tensor.is_contiguous()) {
+ tensor = tensor.contiguous();
+ }
+
  auto data_ptr = (const char *) tensor.data_ptr();
  return std::string(data_ptr, tensor.numel() * tensor.element_size());
  })
+ // for TorchVision
+ .define_method(
+ "_data_ptr",
+ *[](Tensor& self) {
+ return reinterpret_cast<uintptr_t>(self.data_ptr());
+ })
  // TODO figure out a better way to do this
  .define_method(
  "_flat_data",
@@ -17,6 +17,9 @@ if have_library("omp") || have_library("gomp")
  end

  if apple_clang
+ # silence rice warnings
+ $CXXFLAGS += " -Wno-deprecated-declarations"
+
  # silence ruby/intern.h warning
  $CXXFLAGS += " -Wno-deprecated-register"

@@ -2,6 +2,34 @@
  #include <rice/Object.hpp>
  #include "templates.hpp"

+ Object wrap(bool x) {
+ return to_ruby<bool>(x);
+ }
+
+ Object wrap(int64_t x) {
+ return to_ruby<int64_t>(x);
+ }
+
+ Object wrap(double x) {
+ return to_ruby<double>(x);
+ }
+
+ Object wrap(torch::Tensor x) {
+ return to_ruby<torch::Tensor>(x);
+ }
+
+ Object wrap(torch::Scalar x) {
+ return to_ruby<torch::Scalar>(x);
+ }
+
+ Object wrap(torch::ScalarType x) {
+ return to_ruby<torch::ScalarType>(x);
+ }
+
+ Object wrap(torch::QScheme x) {
+ return to_ruby<torch::QScheme>(x);
+ }
+
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
@@ -13,49 +13,31 @@ using torch::Device;
  using torch::Scalar;
  using torch::ScalarType;
  using torch::Tensor;
-
- // need to wrap torch::IntArrayRef() since
- // it doesn't own underlying data
- class IntArrayRef {
- std::vector<int64_t> vec;
- public:
- IntArrayRef(Object o) {
- Array a = Array(o);
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<int64_t>(a[i]));
- }
- }
- operator torch::IntArrayRef() {
- return torch::IntArrayRef(vec);
- }
- };
+ using torch::IntArrayRef;
+ using torch::TensorList;

  template<>
  inline
- IntArrayRef from_ruby<IntArrayRef>(Object x)
+ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
  {
- return IntArrayRef(x);
+ Array a = Array(x);
+ std::vector<int64_t> vec(a.size());
+ for (size_t i = 0; i < a.size(); i++) {
+ vec[i] = from_ruby<int64_t>(a[i]);
+ }
+ return vec;
  }

- class TensorList {
- std::vector<torch::Tensor> vec;
- public:
- TensorList(Object o) {
- Array a = Array(o);
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
- }
- }
- operator torch::TensorList() {
- return torch::TensorList(vec);
- }
- };
-
  template<>
  inline
- TensorList from_ruby<TensorList>(Object x)
+ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
  {
- return TensorList(x);
+ Array a = Array(x);
+ std::vector<Tensor> vec(a.size());
+ for (size_t i = 0; i < a.size(); i++) {
+ vec[i] = from_ruby<Tensor>(a[i]);
+ }
+ return vec;
  }

  class FanModeType {
@@ -242,6 +224,13 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
  }
  }

+ Object wrap(bool x);
+ Object wrap(int64_t x);
+ Object wrap(double x);
+ Object wrap(torch::Tensor x);
+ Object wrap(torch::Scalar x);
+ Object wrap(torch::ScalarType x);
+ Object wrap(torch::QScheme x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
  require "torch/nn/soft_margin_loss"
  require "torch/nn/triplet_margin_loss"

+ # nn vision
+ require "torch/nn/upsample"
+
  # nn other
  require "torch/nn/functional"
  require "torch/nn/init"
@@ -196,6 +199,32 @@ module Torch
  end
  end

+ # legacy
+ # but may make it easier to port tutorials
+ module Autograd
+ class Variable
+ def self.new(x)
+ raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
+ warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
+ x
+ end
+ end
+ end
+
+ # TODO move to C++
+ class ByteStorage
+ # private
+ attr_reader :bytes
+
+ def initialize(bytes)
+ @bytes = bytes
+ end
+
+ def self.from_buffer(bytes)
+ new(bytes)
+ end
+ end
+
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
  DTYPE_TO_ENUM = {
@@ -224,18 +253,24 @@ module Torch
  }
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h

+ TENSOR_TYPE_CLASSES = []
+
  def self._make_tensor_class(dtype, cuda = false)
  cls = Class.new
  device = cuda ? "cuda" : "cpu"
  cls.define_singleton_method("new") do |*args|
  if args.size == 1 && args.first.is_a?(Tensor)
  args.first.send(dtype).to(device)
+ elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
+ bytes = args.first.bytes
+ Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
  elsif args.size == 1 && args.first.is_a?(Array)
  Torch.tensor(args.first, dtype: dtype, device: device)
  else
  Torch.empty(*args, dtype: dtype, device: device)
  end
  end
+ TENSOR_TYPE_CLASSES << cls
  cls
  end

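A rough, hand-written sketch of how the new `ByteStorage` path (added mainly for TorchVision interop) might be exercised. It assumes the gem's usual uint8 tensor type constant, `Torch::ByteTensor`, which is not shown in this diff.

```ruby
# Build a uint8 tensor directly from a raw binary string
storage = Torch::ByteStorage.from_buffer("\x01\x02\x03".b)
t = Torch::ByteTensor.new(storage) # assumes ByteTensor is the registered :uint8 class
p t.shape # expect [3]
```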
@@ -22,21 +22,43 @@ module Torch
  end

  def bind_functions(context, def_method, functions)
+ instance_method = def_method == :define_method
  functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
- if def_method == :define_method
+ if instance_method
  funcs.map! { |f| Function.new(f.function) }
- funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
+ funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
  end

- defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
+ defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
  next if defined && name != "clone"

- parser = Parser.new(funcs)
+ # skip parser when possible for performance
+ if funcs.size == 1 && funcs.first.args.size == 0
+ # functions with no arguments
+ if instance_method
+ context.send(:alias_method, name, funcs.first.cpp_name)
+ else
+ context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
+ end
+ elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
+ # functions that take a tensor or scalar
+ scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
+ context.send(def_method, name) do |other|
+ case other
+ when Tensor
+ send(tensor_name, other)
+ else
+ send(scalar_name, other)
+ end
+ end
+ else
+ parser = Parser.new(funcs)

- context.send(def_method, name) do |*args, **options|
- result = parser.parse(args, options)
- raise ArgumentError, result[:error] if result[:error]
- send(result[:name], *result[:args])
+ context.send(def_method, name) do |*args, **options|
+ result = parser.parse(args, options)
+ raise ArgumentError, result[:error] if result[:error]
+ send(result[:name], *result[:args])
+ end
  end
  end
  end
@@ -6,9 +6,10 @@ module Torch
  def initialize(function)
  @function = function

- tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
- @tensor_options = @function["func"].include?(tensor_options_str)
- @function["func"].sub!(tensor_options_str, ")")
+ # note: don't modify function in-place
+ @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+ @tensor_options = @function["func"].include?(@tensor_options_str)
+ @out = out_size > 0 && base_name[-1] != "_"
  end

  def func
@@ -31,7 +32,7 @@ module Torch
  @args ||= begin
  args = []
  pos = true
- args_str = func.split("(", 2).last.split(") ->").first
+ args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
  args_str.split(", ").each do |a|
  if a == "*"
  pos = false
@@ -72,12 +73,88 @@ module Torch
  next if t == "Generator?"
  next if t == "MemoryFormat"
  next if t == "MemoryFormat?"
- args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
+ args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
  end
  args
  end
  end

+ def arg_checkers
+ @arg_checkers ||= begin
+ checkers = {}
+ arg_types.each do |k, t|
+ checker =
+ case t
+ when "Tensor"
+ ->(v) { v.is_a?(Tensor) }
+ when "Tensor?"
+ ->(v) { v.nil? || v.is_a?(Tensor) }
+ when "Tensor[]", "Tensor?[]"
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
+ when "int"
+ if k == :reduction
+ ->(v) { v.is_a?(String) }
+ else
+ ->(v) { v.is_a?(Integer) }
+ end
+ when "int?"
+ ->(v) { v.is_a?(Integer) || v.nil? }
+ when "float?"
+ ->(v) { v.is_a?(Numeric) || v.nil? }
+ when "bool?"
+ ->(v) { v == true || v == false || v.nil? }
+ when "float"
+ ->(v) { v.is_a?(Numeric) }
+ when /int\[.*\]/
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
+ when "Scalar"
+ ->(v) { v.is_a?(Numeric) }
+ when "Scalar?"
+ ->(v) { v.is_a?(Numeric) || v.nil? }
+ when "ScalarType"
+ ->(v) { false } # not supported yet
+ when "ScalarType?"
+ ->(v) { v.nil? }
+ when "bool"
+ ->(v) { v == true || v == false }
+ when "str"
+ ->(v) { v.is_a?(String) }
+ else
+ raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
+ end
+ checkers[k] = checker
+ end
+ checkers
+ end
+ end
+
+ def int_array_lengths
+ @int_array_lengths ||= begin
+ ret = {}
+ arg_types.each do |k, t|
+ if t.match?(/\Aint\[.+\]\z/)
+ size = t[4..-2]
+ raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
+ ret[k] = size.to_i
+ end
+ end
+ ret
+ end
+ end
+
+ def arg_names
+ @arg_names ||= args.map { |a| a[:name] }
+ end
+
+ def arg_types
+ @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
+ end
+
+ def arg_defaults
+ # TODO find out why can't use select here
+ @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
+ end
+
  def out_size
  @out_size ||= func.split("->").last.count("!")
  end
@@ -90,8 +167,12 @@ module Torch
  @ret_array ||= func.split("->").last.include?('[]')
  end

+ def ret_void?
+ func.split("->").last.strip == "()"
+ end
+
  def out?
- out_size > 0 && base_name[-1] != "_"
+ @out
  end

  def ruby_name
@@ -72,16 +72,18 @@ void add_%{type}_functions(Module m);
  #include <rice/Module.hpp>
  #include "templates.hpp"

+ %{functions}
+
  void add_%{type}_functions(Module m) {
- m
- %{functions};
+ %{add_functions}
  }
  TEMPLATE

  cpp_defs = []
+ add_defs = []
  functions.sort_by(&:cpp_name).each do |func|
  fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
- fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
+ fargs << {name: :options, type: "TensorOptions"} if func.tensor_options

  cpp_args = []
  fargs.each do |a|
@@ -94,11 +96,9 @@ void add_%{type}_functions(Module m) {
  "OptionalTensor"
  when "ScalarType?"
  "torch::optional<ScalarType>"
- when "Tensor[]"
- "TensorList"
- when "Tensor?[]"
+ when "Tensor[]", "Tensor?[]"
  # TODO make optional
- "TensorList"
+ "std::vector<Tensor>"
  when "int"
  "int64_t"
  when "int?"
@@ -112,43 +112,53 @@ void add_%{type}_functions(Module m) {
  when "float"
  "double"
  when /\Aint\[/
- "IntArrayRef"
+ "std::vector<int64_t>"
  when /Tensor\(\S!?\)/
  "Tensor &"
  when "str"
  "std::string"
  when "TensorOptions"
  "const torch::TensorOptions &"
- else
+ when "Layout?"
+ "torch::optional<Layout>"
+ when "Device?"
+ "torch::optional<Device>"
+ when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
  a[:type]
+ else
+ raise "Unknown type: #{a[:type]}"
  end

- t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
+ t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
  cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
  end

  dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
  args = fargs.map { |a| a[:name] }
  args.unshift(*args.pop(func.out_size)) if func.out?
- args.delete("self") if def_method == :define_method
+ args.delete(:self) if def_method == :define_method

  prefix = def_method == :define_method ? "self." : "torch::"

  body = "#{prefix}#{dispatch}(#{args.join(", ")})"

- if func.ret_size > 1 || func.ret_array?
+ if func.cpp_name == "_fill_diagonal_"
+ body = "to_ruby<torch::Tensor>(#{body})"
+ elsif !func.ret_void?
  body = "wrap(#{body})"
  end

- cpp_defs << ".#{def_method}(
- \"#{func.cpp_name}\",
- *[](#{cpp_args.join(", ")}) {
- return #{body};
- })"
+ cpp_defs << "// #{func.func}
+ static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
+ {
+ return #{body};
+ }"
+
+ add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
  end

  hpp_contents = hpp_template % {type: type}
- cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
+ cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}

  path = File.expand_path("../../../ext/torch", __dir__)
  File.write("#{path}/#{type}_functions.hpp", hpp_contents)
@@ -6,14 +6,24 @@ module Torch
  @name = @functions.first.ruby_name
  @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
  @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
+ @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
  end

+ # TODO improve performance
+ # possibly move to C++ (see python_arg_parser.cpp)
  def parse(args, options)
  candidates = @functions.dup

- # remove nil
- while args.any? && args.last.nil?
- args.pop
+ # TODO check candidates individually to see if they match
+ if @int_array_first
+ int_args = []
+ while args.first.is_a?(Integer)
+ int_args << args.shift
+ end
+ if int_args.any?
+ raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
+ args.unshift(int_args)
+ end
  end

  # TODO account for args passed as options here
@@ -25,99 +35,60 @@ module Torch

  candidates.reject! { |f| args.size > f.args.size }

- # exclude functions missing required options
- candidates.reject! do |func|
- # TODO make more generic
- func.out? && !options[:out]
- end
-
  # handle out with multiple
  # there should only be one match, so safe to modify all
- out_func = candidates.find { |f| f.out? }
- if out_func && out_func.out_size > 1 && options[:out]
- out_args = out_func.args.last(2).map { |a| a[:name] }
- out_args.zip(options.delete(:out)).each do |k, v|
- options[k.to_sym] = v
- end
- candidates = [out_func]
- end
-
- # exclude functions where options don't match
- options.each do |k, v|
- candidates.select! do |func|
- func.args.any? { |a| a[:name] == k.to_s }
+ if options[:out]
+ if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
+ out_args = out_func.args.last(2).map { |a| a[:name] }
+ out_args.zip(options.delete(:out)).each do |k, v|
+ options[k] = v
+ end
+ candidates = [out_func]
  end
- # TODO show all bad keywords at once like Ruby?
- return {error: "unknown keyword: #{k}"} if candidates.empty?
+ else
+ # exclude functions missing required options
+ candidates.reject!(&:out?)
  end

- final_values = {}
+ final_values = nil

  # check args
- candidates.select! do |func|
+ while (func = candidates.shift)
  good = true

- values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
- values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
- func.args.each do |fa|
- values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
+ # set values
+ # TODO use array instead of hash?
+ values = {}
+ args.each_with_index do |a, i|
+ values[func.arg_names[i]] = a
+ end
+ options.each do |k, v|
+ values[k] = v
+ end
+ func.arg_defaults.each do |k, v|
+ values[k] = v unless values.key?(k)
+ end
+ func.int_array_lengths.each do |k, len|
+ values[k] = [values[k]] * len if values[k].is_a?(Integer)
  end

- arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
+ arg_checkers = func.arg_checkers

  values.each_key do |k|
- v = values[k]
- t = arg_types[k].split("(").first
-
- good =
- case t
- when "Tensor"
- v.is_a?(Tensor)
- when "Tensor?"
- v.nil? || v.is_a?(Tensor)
- when "Tensor[]", "Tensor?[]"
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
- when "int"
- if k == "reduction"
- v.is_a?(String)
- else
- v.is_a?(Integer)
- end
- when "int?"
- v.is_a?(Integer) || v.nil?
- when "float?"
- v.is_a?(Numeric) || v.nil?
- when "bool?"
- v == true || v == false || v.nil?
- when "float"
- v.is_a?(Numeric)
- when /int\[.*\]/
- if v.is_a?(Integer)
- size = t[4..-2]
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
- v = [v] * size.to_i
- values[k] = v
- end
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
- when "Scalar"
- v.is_a?(Numeric)
- when "Scalar?"
- v.is_a?(Numeric) || v.nil?
- when "ScalarType"
- false # not supported yet
- when "ScalarType?"
- v.nil?
- when "bool"
- v == true || v == false
- when "str"
- v.is_a?(String)
- else
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
+ unless arg_checkers.key?(k)
+ good = false
+ if candidates.empty?
+ # TODO show all bad keywords at once like Ruby?
+ return {error: "unknown keyword: #{k}"}
  end
+ break
+ end

- if !good
- if candidates.size == 1
- k = "input" if k == "self"
+ unless arg_checkers[k].call(values[k])
+ good = false
+ if candidates.empty?
+ t = func.arg_types[k]
+ k = :input if k == :self
  return {error: "#{@name}(): argument '#{k}' must be #{t}"}
  end
  break
@@ -126,17 +97,15 @@ module Torch

  if good
  final_values = values
+ break
  end
-
- good
  end

- if candidates.size != 1
+ unless final_values
  raise Error, "This should never happen. Please report a bug with #{@name}."
  end

- func = candidates.first
- args = func.args.map { |a| final_values[a[:name]] }
+ args = func.arg_names.map { |k| final_values[k] }
  args << TensorOptions.new.dtype(6) if func.tensor_options
  {
  name: func.cpp_name,
@@ -469,6 +469,77 @@ module Torch
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
  end

+ # vision
+
+ def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
+ if ["nearest", "area"].include?(mode)
+ unless align_corners.nil?
+ raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
+ end
+ else
+ if align_corners.nil?
+ align_corners = false
+ end
+ end
+
+ scale_factor_len = input.dim - 2
+ scale_factor_list = [nil] * scale_factor_len
+ # default value of recompute_scale_factor is False
+ if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
+ if scale_factor.is_a?(Array)
+ _scale_factor_repeated = scale_factor
+ else
+ _scale_factor_repeated = [scale_factor] * scale_factor_len
+ end
+ scale_factor_list = _scale_factor_repeated
+ end
+
+ # Give this variable a short name because it has to be repeated multiple times below.
+ sfl = scale_factor_list
+
+ closed_over_args = [input, size, scale_factor, recompute_scale_factor]
+ output_size = _interp_output_size(closed_over_args)
+ if input.dim == 3 && mode == "nearest"
+ NN.upsample_nearest1d(input, output_size, sfl[0])
+ elsif input.dim == 4 && mode == "nearest"
+ NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
+ elsif input.dim == 5 && mode == "nearest"
+ NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
+ elsif input.dim == 3 && mode == "area"
+ adaptive_avg_pool1d(input, output_size)
+ elsif input.dim == 4 && mode == "area"
+ adaptive_avg_pool2d(input, output_size)
+ elsif input.dim == 5 && mode == "area"
+ adaptive_avg_pool3d(input, output_size)
+ elsif input.dim == 3 && mode == "linear"
+ # assert align_corners is not None
+ NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
+ elsif input.dim == 3 && mode == "bilinear"
+ raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
+ elsif input.dim == 3 && mode == "trilinear"
+ raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
+ elsif input.dim == 4 && mode == "linear"
+ raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
+ elsif input.dim == 4 && mode == "bilinear"
+ # assert align_corners is not None
+ NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
+ elsif input.dim == 4 && mode == "trilinear"
+ raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
+ elsif input.dim == 5 && mode == "linear"
+ raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
+ elsif input.dim == 5 && mode == "bilinear"
+ raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
+ elsif input.dim == 5 && mode == "trilinear"
+ # assert align_corners is not None
+ NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
+ elsif input.dim == 4 && mode == "bicubic"
+ # assert align_corners is not None
+ NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
+ else
+ raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
+ end
+ end
+
  private

  def softmax_dim(ndim)
@@ -484,6 +555,41 @@ module Torch
  out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
  end
  end
+
+ def _interp_output_size(closed_over_args)
+ input, size, scale_factor, recompute_scale_factor = closed_over_args
+ dim = input.dim - 2
+ if size.nil? && scale_factor.nil?
+ raise ArgumentError, "either size or scale_factor should be defined"
+ end
+ if !size.nil? && !scale_factor.nil?
+ raise ArgumentError, "only one of size or scale_factor should be defined"
+ end
+ if !scale_factor.nil?
+ if scale_factor.is_a?(Array)
+ if scale_factor.length != dim
+ raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
+ end
+ end
+ end
+
+ if !size.nil?
+ if size.is_a?(Array)
+ return size
+ else
+ return [size] * dim
+ end
+ end
+
+ raise "Failed assertion" if scale_factor.nil?
+ if scale_factor.is_a?(Array)
+ scale_factors = scale_factor
+ else
+ scale_factors = [scale_factor] * dim
+ end
+
+ dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
+ end
  end
  end

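A brief, unofficial usage sketch for the functional `interpolate` added above, covering the nearest and bilinear paths; the sizes are illustrative only.

```ruby
x = Torch.rand(1, 3, 8, 8)

# either an explicit output size...
y1 = Torch::NN::Functional.interpolate(x, size: [16, 16], mode: "nearest")

# ...or a scale factor (align_corners falls back to false for bilinear)
y2 = Torch::NN::Functional.interpolate(x, scale_factor: 2, mode: "bilinear")

p y1.shape # expect [1, 3, 16, 16]
p y2.shape # expect [1, 3, 16, 16]
```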
@@ -58,7 +58,10 @@ module Torch

  @buffers.each_key do |k|
  buf = @buffers[k]
- @buffers[k] = fn.call(buf) unless buf.nil?
+ unless buf.nil?
+ @buffers[k] = fn.call(buf)
+ instance_variable_set("@#{k}", @buffers[k])
+ end
  end

  self
@@ -0,0 +1,31 @@
+ module Torch
+ module NN
+ class Upsample < Module
+ def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
+ super()
+ @size = size
+ if scale_factor.is_a?(Array)
+ @scale_factor = scale_factor.map(&:to_f)
+ else
+ @scale_factor = scale_factor ? scale_factor.to_f : nil
+ end
+ @mode = mode
+ @align_corners = align_corners
+ end
+
+ def forward(input)
+ F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
+ end
+
+ def extra_inspect
+ if !@scale_factor.nil?
+ info = "scale_factor: #{@scale_factor.inspect}"
+ else
+ info = "size: #{@size.inspect}"
+ end
+ info += ", mode: #{@mode.inspect}"
+ info
+ end
+ end
+ end
+ end
@@ -48,6 +48,11 @@ module Torch
  end

  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+ if device.is_a?(Symbol) && !dtype
+ dtype = device
+ device = nil
+ end
+
  device ||= self.device
  device = Device.new(device) if device.is_a?(String)

@@ -74,10 +79,6 @@ module Torch
  end
  end

- def shape
- dim.times.map { |i| size(i) }
- end
-
  # mirror Python len()
  def length
  size(0)
@@ -119,9 +120,14 @@ module Torch
  end

  def type(dtype)
- enum = DTYPE_TO_ENUM[dtype]
- raise Error, "Unknown type: #{dtype}" unless enum
- _type(enum)
+ if dtype.is_a?(Class)
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+ dtype.new(self)
+ else
+ enum = DTYPE_TO_ENUM[dtype]
+ raise Error, "Invalid type: #{dtype}" unless enum
+ _type(enum)
+ end
  end

  def reshape(*size)
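To illustrate the `Tensor#type` and `Tensor#to` changes above, here is a small, unofficial sketch; `Torch::FloatTensor` is assumed to be one of the tensor type classes the gem registers in `TENSOR_TYPE_CLASSES`.

```ruby
t = Torch.tensor([1, 2, 3])

t.type(:float64)           # dtype symbol, as before
t.type(Torch::FloatTensor) # new in 0.3.7: a tensor class is also accepted
t.to(:float32)             # `to` now treats a lone symbol as a dtype
```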
@@ -1,3 +1,3 @@
  module Torch
- VERSION = "0.3.6"
+ VERSION = "0.3.7"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.3.6
+ version: 0.3.7
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-09-18 00:00:00.000000000 Z
+ date: 2020-09-23 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -238,6 +238,7 @@ files:
  - lib/torch/nn/tanhshrink.rb
  - lib/torch/nn/triplet_margin_loss.rb
  - lib/torch/nn/unfold.rb
+ - lib/torch/nn/upsample.rb
  - lib/torch/nn/utils.rb
  - lib/torch/nn/weighted_loss.rb
  - lib/torch/nn/zero_pad2d.rb