torch-rb 0.3.6 → 0.3.7

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
4
- data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
3
+ metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
4
+ data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
5
5
  SHA512:
6
- metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
7
- data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
6
+ metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
7
+ data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
@@ -1,3 +1,12 @@
1
+ ## 0.3.7 (2020-09-22)
2
+
3
+ - Improved performance
4
+ - Added `Upsample`
5
+ - Added support for passing tensor class to `type` method
6
+ - Fixed error with buffers on GPU
7
+ - Fixed error with `new_full`
8
+ - Fixed issue with `numo` method and non-contiguous tensors
9
+
1
10
  ## 0.3.6 (2020-09-17)
2
11
 
3
12
  - Added `inplace` option for leaky ReLU
data/README.md CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
402
402
  - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
403
403
  - [Collaborative filtering with MovieLens](examples/movielens)
404
404
  - [Sequence models and word embeddings](examples/nlp)
405
+ - [Generative adversarial networks](examples/gan)
405
406
 
406
407
  ## LibTorch Installation
407
408
 
@@ -232,7 +232,7 @@ void Init_ext()
232
232
  })
233
233
  .define_singleton_method(
234
234
  "_empty",
235
- *[](IntArrayRef size, const torch::TensorOptions &options) {
235
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
236
236
  return torch::empty(size, options);
237
237
  })
238
238
  .define_singleton_method(
@@ -242,7 +242,7 @@ void Init_ext()
242
242
  })
243
243
  .define_singleton_method(
244
244
  "_full",
245
- *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
245
+ *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
246
246
  return torch::full(size, fill_value, options);
247
247
  })
248
248
  .define_singleton_method(
@@ -257,22 +257,22 @@ void Init_ext()
257
257
  })
258
258
  .define_singleton_method(
259
259
  "_ones",
260
- *[](IntArrayRef size, const torch::TensorOptions &options) {
260
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
261
261
  return torch::ones(size, options);
262
262
  })
263
263
  .define_singleton_method(
264
264
  "_rand",
265
- *[](IntArrayRef size, const torch::TensorOptions &options) {
265
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
266
266
  return torch::rand(size, options);
267
267
  })
268
268
  .define_singleton_method(
269
269
  "_randint",
270
- *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
270
+ *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
271
271
  return torch::randint(low, high, size, options);
272
272
  })
273
273
  .define_singleton_method(
274
274
  "_randn",
275
- *[](IntArrayRef size, const torch::TensorOptions &options) {
275
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
276
276
  return torch::randn(size, options);
277
277
  })
278
278
  .define_singleton_method(
@@ -282,7 +282,7 @@ void Init_ext()
282
282
  })
283
283
  .define_singleton_method(
284
284
  "_zeros",
285
- *[](IntArrayRef size, const torch::TensorOptions &options) {
285
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
286
286
  return torch::zeros(size, options);
287
287
  })
288
288
  // begin operations
@@ -303,13 +303,13 @@ void Init_ext()
303
303
  })
304
304
  .define_singleton_method(
305
305
  "_from_blob",
306
- *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
306
+ *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
307
307
  void *data = const_cast<char *>(s.c_str());
308
308
  return torch::from_blob(data, size, options);
309
309
  })
310
310
  .define_singleton_method(
311
311
  "_tensor",
312
- *[](Array a, IntArrayRef size, const torch::TensorOptions &options) {
312
+ *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
313
313
  auto dtype = options.dtype();
314
314
  torch::Tensor t;
315
315
  if (dtype == torch::kBool) {
@@ -342,6 +342,16 @@ void Init_ext()
342
342
  .define_method("numel", &torch::Tensor::numel)
343
343
  .define_method("element_size", &torch::Tensor::element_size)
344
344
  .define_method("requires_grad", &torch::Tensor::requires_grad)
345
+ // in C++ for performance
346
+ .define_method(
347
+ "shape",
348
+ *[](Tensor& self) {
349
+ Array a;
350
+ for (auto &size : self.sizes()) {
351
+ a.push(size);
352
+ }
353
+ return a;
354
+ })
345
355
  .define_method(
346
356
  "_index",
347
357
  *[](Tensor& self, Array indices) {
@@ -420,9 +430,19 @@ void Init_ext()
420
430
  tensor = tensor.to(device);
421
431
  }
422
432
 
433
+ if (!tensor.is_contiguous()) {
434
+ tensor = tensor.contiguous();
435
+ }
436
+
423
437
  auto data_ptr = (const char *) tensor.data_ptr();
424
438
  return std::string(data_ptr, tensor.numel() * tensor.element_size());
425
439
  })
440
+ // for TorchVision
441
+ .define_method(
442
+ "_data_ptr",
443
+ *[](Tensor& self) {
444
+ return reinterpret_cast<uintptr_t>(self.data_ptr());
445
+ })
426
446
  // TODO figure out a better way to do this
427
447
  .define_method(
428
448
  "_flat_data",
@@ -17,6 +17,9 @@ if have_library("omp") || have_library("gomp")
17
17
  end
18
18
 
19
19
  if apple_clang
20
+ # silence rice warnings
21
+ $CXXFLAGS += " -Wno-deprecated-declarations"
22
+
20
23
  # silence ruby/intern.h warning
21
24
  $CXXFLAGS += " -Wno-deprecated-register"
22
25
 
@@ -2,6 +2,34 @@
2
2
  #include <rice/Object.hpp>
3
3
  #include "templates.hpp"
4
4
 
5
+ Object wrap(bool x) {
6
+ return to_ruby<bool>(x);
7
+ }
8
+
9
+ Object wrap(int64_t x) {
10
+ return to_ruby<int64_t>(x);
11
+ }
12
+
13
+ Object wrap(double x) {
14
+ return to_ruby<double>(x);
15
+ }
16
+
17
+ Object wrap(torch::Tensor x) {
18
+ return to_ruby<torch::Tensor>(x);
19
+ }
20
+
21
+ Object wrap(torch::Scalar x) {
22
+ return to_ruby<torch::Scalar>(x);
23
+ }
24
+
25
+ Object wrap(torch::ScalarType x) {
26
+ return to_ruby<torch::ScalarType>(x);
27
+ }
28
+
29
+ Object wrap(torch::QScheme x) {
30
+ return to_ruby<torch::QScheme>(x);
31
+ }
32
+
5
33
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
6
34
  Array a;
7
35
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
@@ -13,49 +13,31 @@ using torch::Device;
13
13
  using torch::Scalar;
14
14
  using torch::ScalarType;
15
15
  using torch::Tensor;
16
-
17
- // need to wrap torch::IntArrayRef() since
18
- // it doesn't own underlying data
19
- class IntArrayRef {
20
- std::vector<int64_t> vec;
21
- public:
22
- IntArrayRef(Object o) {
23
- Array a = Array(o);
24
- for (size_t i = 0; i < a.size(); i++) {
25
- vec.push_back(from_ruby<int64_t>(a[i]));
26
- }
27
- }
28
- operator torch::IntArrayRef() {
29
- return torch::IntArrayRef(vec);
30
- }
31
- };
16
+ using torch::IntArrayRef;
17
+ using torch::TensorList;
32
18
 
33
19
  template<>
34
20
  inline
35
- IntArrayRef from_ruby<IntArrayRef>(Object x)
21
+ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
36
22
  {
37
- return IntArrayRef(x);
23
+ Array a = Array(x);
24
+ std::vector<int64_t> vec(a.size());
25
+ for (size_t i = 0; i < a.size(); i++) {
26
+ vec[i] = from_ruby<int64_t>(a[i]);
27
+ }
28
+ return vec;
38
29
  }
39
30
 
40
- class TensorList {
41
- std::vector<torch::Tensor> vec;
42
- public:
43
- TensorList(Object o) {
44
- Array a = Array(o);
45
- for (size_t i = 0; i < a.size(); i++) {
46
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
47
- }
48
- }
49
- operator torch::TensorList() {
50
- return torch::TensorList(vec);
51
- }
52
- };
53
-
54
31
  template<>
55
32
  inline
56
- TensorList from_ruby<TensorList>(Object x)
33
+ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
57
34
  {
58
- return TensorList(x);
35
+ Array a = Array(x);
36
+ std::vector<Tensor> vec(a.size());
37
+ for (size_t i = 0; i < a.size(); i++) {
38
+ vec[i] = from_ruby<Tensor>(a[i]);
39
+ }
40
+ return vec;
59
41
  }
60
42
 
61
43
  class FanModeType {
@@ -242,6 +224,13 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
242
224
  }
243
225
  }
244
226
 
227
+ Object wrap(bool x);
228
+ Object wrap(int64_t x);
229
+ Object wrap(double x);
230
+ Object wrap(torch::Tensor x);
231
+ Object wrap(torch::Scalar x);
232
+ Object wrap(torch::ScalarType x);
233
+ Object wrap(torch::QScheme x);
245
234
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
246
235
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
247
236
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
174
174
  require "torch/nn/soft_margin_loss"
175
175
  require "torch/nn/triplet_margin_loss"
176
176
 
177
+ # nn vision
178
+ require "torch/nn/upsample"
179
+
177
180
  # nn other
178
181
  require "torch/nn/functional"
179
182
  require "torch/nn/init"
@@ -196,6 +199,32 @@ module Torch
196
199
  end
197
200
  end
198
201
 
202
+ # legacy
203
+ # but may make it easier to port tutorials
204
+ module Autograd
205
+ class Variable
206
+ def self.new(x)
207
+ raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
208
+ warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
209
+ x
210
+ end
211
+ end
212
+ end
213
+
214
+ # TODO move to C++
215
+ class ByteStorage
216
+ # private
217
+ attr_reader :bytes
218
+
219
+ def initialize(bytes)
220
+ @bytes = bytes
221
+ end
222
+
223
+ def self.from_buffer(bytes)
224
+ new(bytes)
225
+ end
226
+ end
227
+
199
228
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
200
229
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
201
230
  DTYPE_TO_ENUM = {
@@ -224,18 +253,24 @@ module Torch
224
253
  }
225
254
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
226
255
 
256
+ TENSOR_TYPE_CLASSES = []
257
+
227
258
  def self._make_tensor_class(dtype, cuda = false)
228
259
  cls = Class.new
229
260
  device = cuda ? "cuda" : "cpu"
230
261
  cls.define_singleton_method("new") do |*args|
231
262
  if args.size == 1 && args.first.is_a?(Tensor)
232
263
  args.first.send(dtype).to(device)
264
+ elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
265
+ bytes = args.first.bytes
266
+ Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
233
267
  elsif args.size == 1 && args.first.is_a?(Array)
234
268
  Torch.tensor(args.first, dtype: dtype, device: device)
235
269
  else
236
270
  Torch.empty(*args, dtype: dtype, device: device)
237
271
  end
238
272
  end
273
+ TENSOR_TYPE_CLASSES << cls
239
274
  cls
240
275
  end
241
276
 
@@ -22,21 +22,43 @@ module Torch
22
22
  end
23
23
 
24
24
  def bind_functions(context, def_method, functions)
25
+ instance_method = def_method == :define_method
25
26
  functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
26
- if def_method == :define_method
27
+ if instance_method
27
28
  funcs.map! { |f| Function.new(f.function) }
28
- funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
29
+ funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
29
30
  end
30
31
 
31
- defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
32
+ defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
32
33
  next if defined && name != "clone"
33
34
 
34
- parser = Parser.new(funcs)
35
+ # skip parser when possible for performance
36
+ if funcs.size == 1 && funcs.first.args.size == 0
37
+ # functions with no arguments
38
+ if instance_method
39
+ context.send(:alias_method, name, funcs.first.cpp_name)
40
+ else
41
+ context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
42
+ end
43
+ elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
44
+ # functions that take a tensor or scalar
45
+ scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
46
+ context.send(def_method, name) do |other|
47
+ case other
48
+ when Tensor
49
+ send(tensor_name, other)
50
+ else
51
+ send(scalar_name, other)
52
+ end
53
+ end
54
+ else
55
+ parser = Parser.new(funcs)
35
56
 
36
- context.send(def_method, name) do |*args, **options|
37
- result = parser.parse(args, options)
38
- raise ArgumentError, result[:error] if result[:error]
39
- send(result[:name], *result[:args])
57
+ context.send(def_method, name) do |*args, **options|
58
+ result = parser.parse(args, options)
59
+ raise ArgumentError, result[:error] if result[:error]
60
+ send(result[:name], *result[:args])
61
+ end
40
62
  end
41
63
  end
42
64
  end
@@ -6,9 +6,10 @@ module Torch
6
6
  def initialize(function)
7
7
  @function = function
8
8
 
9
- tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
10
- @tensor_options = @function["func"].include?(tensor_options_str)
11
- @function["func"].sub!(tensor_options_str, ")")
9
+ # note: don't modify function in-place
10
+ @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
11
+ @tensor_options = @function["func"].include?(@tensor_options_str)
12
+ @out = out_size > 0 && base_name[-1] != "_"
12
13
  end
13
14
 
14
15
  def func
@@ -31,7 +32,7 @@ module Torch
31
32
  @args ||= begin
32
33
  args = []
33
34
  pos = true
34
- args_str = func.split("(", 2).last.split(") ->").first
35
+ args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
35
36
  args_str.split(", ").each do |a|
36
37
  if a == "*"
37
38
  pos = false
@@ -72,12 +73,88 @@ module Torch
72
73
  next if t == "Generator?"
73
74
  next if t == "MemoryFormat"
74
75
  next if t == "MemoryFormat?"
75
- args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
76
+ args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
76
77
  end
77
78
  args
78
79
  end
79
80
  end
80
81
 
82
+ def arg_checkers
83
+ @arg_checkers ||= begin
84
+ checkers = {}
85
+ arg_types.each do |k, t|
86
+ checker =
87
+ case t
88
+ when "Tensor"
89
+ ->(v) { v.is_a?(Tensor) }
90
+ when "Tensor?"
91
+ ->(v) { v.nil? || v.is_a?(Tensor) }
92
+ when "Tensor[]", "Tensor?[]"
93
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
94
+ when "int"
95
+ if k == :reduction
96
+ ->(v) { v.is_a?(String) }
97
+ else
98
+ ->(v) { v.is_a?(Integer) }
99
+ end
100
+ when "int?"
101
+ ->(v) { v.is_a?(Integer) || v.nil? }
102
+ when "float?"
103
+ ->(v) { v.is_a?(Numeric) || v.nil? }
104
+ when "bool?"
105
+ ->(v) { v == true || v == false || v.nil? }
106
+ when "float"
107
+ ->(v) { v.is_a?(Numeric) }
108
+ when /int\[.*\]/
109
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
110
+ when "Scalar"
111
+ ->(v) { v.is_a?(Numeric) }
112
+ when "Scalar?"
113
+ ->(v) { v.is_a?(Numeric) || v.nil? }
114
+ when "ScalarType"
115
+ ->(v) { false } # not supported yet
116
+ when "ScalarType?"
117
+ ->(v) { v.nil? }
118
+ when "bool"
119
+ ->(v) { v == true || v == false }
120
+ when "str"
121
+ ->(v) { v.is_a?(String) }
122
+ else
123
+ raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
124
+ end
125
+ checkers[k] = checker
126
+ end
127
+ checkers
128
+ end
129
+ end
130
+
131
+ def int_array_lengths
132
+ @int_array_lengths ||= begin
133
+ ret = {}
134
+ arg_types.each do |k, t|
135
+ if t.match?(/\Aint\[.+\]\z/)
136
+ size = t[4..-2]
137
+ raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
138
+ ret[k] = size.to_i
139
+ end
140
+ end
141
+ ret
142
+ end
143
+ end
144
+
145
+ def arg_names
146
+ @arg_names ||= args.map { |a| a[:name] }
147
+ end
148
+
149
+ def arg_types
150
+ @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
151
+ end
152
+
153
+ def arg_defaults
154
+ # TODO find out why can't use select here
155
+ @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
156
+ end
157
+
81
158
  def out_size
82
159
  @out_size ||= func.split("->").last.count("!")
83
160
  end
@@ -90,8 +167,12 @@ module Torch
90
167
  @ret_array ||= func.split("->").last.include?('[]')
91
168
  end
92
169
 
170
+ def ret_void?
171
+ func.split("->").last.strip == "()"
172
+ end
173
+
93
174
  def out?
94
- out_size > 0 && base_name[-1] != "_"
175
+ @out
95
176
  end
96
177
 
97
178
  def ruby_name
@@ -72,16 +72,18 @@ void add_%{type}_functions(Module m);
72
72
  #include <rice/Module.hpp>
73
73
  #include "templates.hpp"
74
74
 
75
+ %{functions}
76
+
75
77
  void add_%{type}_functions(Module m) {
76
- m
77
- %{functions};
78
+ %{add_functions}
78
79
  }
79
80
  TEMPLATE
80
81
 
81
82
  cpp_defs = []
83
+ add_defs = []
82
84
  functions.sort_by(&:cpp_name).each do |func|
83
85
  fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
84
- fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
86
+ fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
85
87
 
86
88
  cpp_args = []
87
89
  fargs.each do |a|
@@ -94,11 +96,9 @@ void add_%{type}_functions(Module m) {
94
96
  "OptionalTensor"
95
97
  when "ScalarType?"
96
98
  "torch::optional<ScalarType>"
97
- when "Tensor[]"
98
- "TensorList"
99
- when "Tensor?[]"
99
+ when "Tensor[]", "Tensor?[]"
100
100
  # TODO make optional
101
- "TensorList"
101
+ "std::vector<Tensor>"
102
102
  when "int"
103
103
  "int64_t"
104
104
  when "int?"
@@ -112,43 +112,53 @@ void add_%{type}_functions(Module m) {
112
112
  when "float"
113
113
  "double"
114
114
  when /\Aint\[/
115
- "IntArrayRef"
115
+ "std::vector<int64_t>"
116
116
  when /Tensor\(\S!?\)/
117
117
  "Tensor &"
118
118
  when "str"
119
119
  "std::string"
120
120
  when "TensorOptions"
121
121
  "const torch::TensorOptions &"
122
- else
122
+ when "Layout?"
123
+ "torch::optional<Layout>"
124
+ when "Device?"
125
+ "torch::optional<Device>"
126
+ when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
123
127
  a[:type]
128
+ else
129
+ raise "Unknown type: #{a[:type]}"
124
130
  end
125
131
 
126
- t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
132
+ t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
127
133
  cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
128
134
  end
129
135
 
130
136
  dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
131
137
  args = fargs.map { |a| a[:name] }
132
138
  args.unshift(*args.pop(func.out_size)) if func.out?
133
- args.delete("self") if def_method == :define_method
139
+ args.delete(:self) if def_method == :define_method
134
140
 
135
141
  prefix = def_method == :define_method ? "self." : "torch::"
136
142
 
137
143
  body = "#{prefix}#{dispatch}(#{args.join(", ")})"
138
144
 
139
- if func.ret_size > 1 || func.ret_array?
145
+ if func.cpp_name == "_fill_diagonal_"
146
+ body = "to_ruby<torch::Tensor>(#{body})"
147
+ elsif !func.ret_void?
140
148
  body = "wrap(#{body})"
141
149
  end
142
150
 
143
- cpp_defs << ".#{def_method}(
144
- \"#{func.cpp_name}\",
145
- *[](#{cpp_args.join(", ")}) {
146
- return #{body};
147
- })"
151
+ cpp_defs << "// #{func.func}
152
+ static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
153
+ {
154
+ return #{body};
155
+ }"
156
+
157
+ add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
148
158
  end
149
159
 
150
160
  hpp_contents = hpp_template % {type: type}
151
- cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
161
+ cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
152
162
 
153
163
  path = File.expand_path("../../../ext/torch", __dir__)
154
164
  File.write("#{path}/#{type}_functions.hpp", hpp_contents)
@@ -6,14 +6,24 @@ module Torch
6
6
  @name = @functions.first.ruby_name
7
7
  @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
8
8
  @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
9
+ @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
9
10
  end
10
11
 
12
+ # TODO improve performance
13
+ # possibly move to C++ (see python_arg_parser.cpp)
11
14
  def parse(args, options)
12
15
  candidates = @functions.dup
13
16
 
14
- # remove nil
15
- while args.any? && args.last.nil?
16
- args.pop
17
+ # TODO check candidates individually to see if they match
18
+ if @int_array_first
19
+ int_args = []
20
+ while args.first.is_a?(Integer)
21
+ int_args << args.shift
22
+ end
23
+ if int_args.any?
24
+ raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
25
+ args.unshift(int_args)
26
+ end
17
27
  end
18
28
 
19
29
  # TODO account for args passed as options here
@@ -25,99 +35,60 @@ module Torch
25
35
 
26
36
  candidates.reject! { |f| args.size > f.args.size }
27
37
 
28
- # exclude functions missing required options
29
- candidates.reject! do |func|
30
- # TODO make more generic
31
- func.out? && !options[:out]
32
- end
33
-
34
38
  # handle out with multiple
35
39
  # there should only be one match, so safe to modify all
36
- out_func = candidates.find { |f| f.out? }
37
- if out_func && out_func.out_size > 1 && options[:out]
38
- out_args = out_func.args.last(2).map { |a| a[:name] }
39
- out_args.zip(options.delete(:out)).each do |k, v|
40
- options[k.to_sym] = v
41
- end
42
- candidates = [out_func]
43
- end
44
-
45
- # exclude functions where options don't match
46
- options.each do |k, v|
47
- candidates.select! do |func|
48
- func.args.any? { |a| a[:name] == k.to_s }
40
+ if options[:out]
41
+ if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
42
+ out_args = out_func.args.last(2).map { |a| a[:name] }
43
+ out_args.zip(options.delete(:out)).each do |k, v|
44
+ options[k] = v
45
+ end
46
+ candidates = [out_func]
49
47
  end
50
- # TODO show all bad keywords at once like Ruby?
51
- return {error: "unknown keyword: #{k}"} if candidates.empty?
48
+ else
49
+ # exclude functions missing required options
50
+ candidates.reject!(&:out?)
52
51
  end
53
52
 
54
- final_values = {}
53
+ final_values = nil
55
54
 
56
55
  # check args
57
- candidates.select! do |func|
56
+ while (func = candidates.shift)
58
57
  good = true
59
58
 
60
- values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
61
- values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
62
- func.args.each do |fa|
63
- values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
59
+ # set values
60
+ # TODO use array instead of hash?
61
+ values = {}
62
+ args.each_with_index do |a, i|
63
+ values[func.arg_names[i]] = a
64
+ end
65
+ options.each do |k, v|
66
+ values[k] = v
67
+ end
68
+ func.arg_defaults.each do |k, v|
69
+ values[k] = v unless values.key?(k)
70
+ end
71
+ func.int_array_lengths.each do |k, len|
72
+ values[k] = [values[k]] * len if values[k].is_a?(Integer)
64
73
  end
65
74
 
66
- arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
75
+ arg_checkers = func.arg_checkers
67
76
 
68
77
  values.each_key do |k|
69
- v = values[k]
70
- t = arg_types[k].split("(").first
71
-
72
- good =
73
- case t
74
- when "Tensor"
75
- v.is_a?(Tensor)
76
- when "Tensor?"
77
- v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]", "Tensor?[]"
79
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
- when "int"
81
- if k == "reduction"
82
- v.is_a?(String)
83
- else
84
- v.is_a?(Integer)
85
- end
86
- when "int?"
87
- v.is_a?(Integer) || v.nil?
88
- when "float?"
89
- v.is_a?(Numeric) || v.nil?
90
- when "bool?"
91
- v == true || v == false || v.nil?
92
- when "float"
93
- v.is_a?(Numeric)
94
- when /int\[.*\]/
95
- if v.is_a?(Integer)
96
- size = t[4..-2]
97
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
98
- v = [v] * size.to_i
99
- values[k] = v
100
- end
101
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
102
- when "Scalar"
103
- v.is_a?(Numeric)
104
- when "Scalar?"
105
- v.is_a?(Numeric) || v.nil?
106
- when "ScalarType"
107
- false # not supported yet
108
- when "ScalarType?"
109
- v.nil?
110
- when "bool"
111
- v == true || v == false
112
- when "str"
113
- v.is_a?(String)
114
- else
115
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
78
+ unless arg_checkers.key?(k)
79
+ good = false
80
+ if candidates.empty?
81
+ # TODO show all bad keywords at once like Ruby?
82
+ return {error: "unknown keyword: #{k}"}
116
83
  end
84
+ break
85
+ end
117
86
 
118
- if !good
119
- if candidates.size == 1
120
- k = "input" if k == "self"
87
+ unless arg_checkers[k].call(values[k])
88
+ good = false
89
+ if candidates.empty?
90
+ t = func.arg_types[k]
91
+ k = :input if k == :self
121
92
  return {error: "#{@name}(): argument '#{k}' must be #{t}"}
122
93
  end
123
94
  break
@@ -126,17 +97,15 @@ module Torch
126
97
 
127
98
  if good
128
99
  final_values = values
100
+ break
129
101
  end
130
-
131
- good
132
102
  end
133
103
 
134
- if candidates.size != 1
104
+ unless final_values
135
105
  raise Error, "This should never happen. Please report a bug with #{@name}."
136
106
  end
137
107
 
138
- func = candidates.first
139
- args = func.args.map { |a| final_values[a[:name]] }
108
+ args = func.arg_names.map { |k| final_values[k] }
140
109
  args << TensorOptions.new.dtype(6) if func.tensor_options
141
110
  {
142
111
  name: func.cpp_name,
@@ -469,6 +469,77 @@ module Torch
469
469
  Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
470
470
  end
471
471
 
472
+ # vision
473
+
474
+ def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
475
+ if ["nearest", "area"].include?(mode)
476
+ unless align_corners.nil?
477
+ raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
478
+ end
479
+ else
480
+ if align_corners.nil?
481
+ align_corners = false
482
+ end
483
+ end
484
+
485
+ scale_factor_len = input.dim - 2
486
+ scale_factor_list = [nil] * scale_factor_len
487
+ # default value of recompute_scale_factor is False
488
+ if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
489
+ if scale_factor.is_a?(Array)
490
+ _scale_factor_repeated = scale_factor
491
+ else
492
+ _scale_factor_repeated = [scale_factor] * scale_factor_len
493
+ end
494
+ scale_factor_list = _scale_factor_repeated
495
+ end
496
+
497
+ # Give this variable a short name because it has to be repeated multiple times below.
498
+ sfl = scale_factor_list
499
+
500
+ closed_over_args = [input, size, scale_factor, recompute_scale_factor]
501
+ output_size = _interp_output_size(closed_over_args)
502
+ if input.dim == 3 && mode == "nearest"
503
+ NN.upsample_nearest1d(input, output_size, sfl[0])
504
+ elsif input.dim == 4 && mode == "nearest"
505
+ NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
506
+ elsif input.dim == 5 && mode == "nearest"
507
+ NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
508
+ elsif input.dim == 3 && mode == "area"
509
+ adaptive_avg_pool1d(input, output_size)
510
+ elsif input.dim == 4 && mode == "area"
511
+ adaptive_avg_pool2d(input, output_size)
512
+ elsif input.dim == 5 && mode == "area"
513
+ adaptive_avg_pool3d(input, output_size)
514
+ elsif input.dim == 3 && mode == "linear"
515
+ # assert align_corners is not None
516
+ NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
517
+ elsif input.dim == 3 && mode == "bilinear"
518
+ raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
519
+ elsif input.dim == 3 && mode == "trilinear"
520
+ raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
521
+ elsif input.dim == 4 && mode == "linear"
522
+ raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
523
+ elsif input.dim == 4 && mode == "bilinear"
524
+ # assert align_corners is not None
525
+ NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
526
+ elsif input.dim == 4 && mode == "trilinear"
527
+ raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
528
+ elsif input.dim == 5 && mode == "linear"
529
+ raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
530
+ elsif input.dim == 5 && mode == "bilinear"
531
+ raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
532
+ elsif input.dim == 5 && mode == "trilinear"
533
+ # assert align_corners is not None
534
+ NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
535
+ elsif input.dim == 4 && mode == "bicubic"
536
+ # assert align_corners is not None
537
+ NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
538
+ else
539
+ raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
540
+ end
541
+ end
542
+
472
543
  private
473
544
 
474
545
  def softmax_dim(ndim)
@@ -484,6 +555,41 @@ module Torch
484
555
  out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
485
556
  end
486
557
  end
558
+
559
+ def _interp_output_size(closed_over_args)
560
+ input, size, scale_factor, recompute_scale_factor = closed_over_args
561
+ dim = input.dim - 2
562
+ if size.nil? && scale_factor.nil?
563
+ raise ArgumentError, "either size or scale_factor should be defined"
564
+ end
565
+ if !size.nil? && !scale_factor.nil?
566
+ raise ArgumentError, "only one of size or scale_factor should be defined"
567
+ end
568
+ if !scale_factor.nil?
569
+ if scale_factor.is_a?(Array)
570
+ if scale_factor.length != dim
571
+ raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
572
+ end
573
+ end
574
+ end
575
+
576
+ if !size.nil?
577
+ if size.is_a?(Array)
578
+ return size
579
+ else
580
+ return [size] * dim
581
+ end
582
+ end
583
+
584
+ raise "Failed assertion" if scale_factor.nil?
585
+ if scale_factor.is_a?(Array)
586
+ scale_factors = scale_factor
587
+ else
588
+ scale_factors = [scale_factor] * dim
589
+ end
590
+
591
+ dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
592
+ end
487
593
  end
488
594
  end
489
595
 
@@ -58,7 +58,10 @@ module Torch
58
58
 
59
59
  @buffers.each_key do |k|
60
60
  buf = @buffers[k]
61
- @buffers[k] = fn.call(buf) unless buf.nil?
61
+ unless buf.nil?
62
+ @buffers[k] = fn.call(buf)
63
+ instance_variable_set("@#{k}", @buffers[k])
64
+ end
62
65
  end
63
66
 
64
67
  self
@@ -0,0 +1,31 @@
1
+ module Torch
2
+ module NN
3
+ class Upsample < Module
4
+ def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
5
+ super()
6
+ @size = size
7
+ if scale_factor.is_a?(Array)
8
+ @scale_factor = scale_factor.map(&:to_f)
9
+ else
10
+ @scale_factor = scale_factor ? scale_factor.to_f : nil
11
+ end
12
+ @mode = mode
13
+ @align_corners = align_corners
14
+ end
15
+
16
+ def forward(input)
17
+ F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
18
+ end
19
+
20
+ def extra_inspect
21
+ if !@scale_factor.nil?
22
+ info = "scale_factor: #{@scale_factor.inspect}"
23
+ else
24
+ info = "size: #{@size.inspect}"
25
+ end
26
+ info += ", mode: #{@mode.inspect}"
27
+ info
28
+ end
29
+ end
30
+ end
31
+ end
@@ -48,6 +48,11 @@ module Torch
48
48
  end
49
49
 
50
50
  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ if device.is_a?(Symbol) && !dtype
52
+ dtype = device
53
+ device = nil
54
+ end
55
+
51
56
  device ||= self.device
52
57
  device = Device.new(device) if device.is_a?(String)
53
58
 
@@ -74,10 +79,6 @@ module Torch
74
79
  end
75
80
  end
76
81
 
77
- def shape
78
- dim.times.map { |i| size(i) }
79
- end
80
-
81
82
  # mirror Python len()
82
83
  def length
83
84
  size(0)
@@ -119,9 +120,14 @@ module Torch
119
120
  end
120
121
 
121
122
  def type(dtype)
122
- enum = DTYPE_TO_ENUM[dtype]
123
- raise Error, "Unknown type: #{dtype}" unless enum
124
- _type(enum)
123
+ if dtype.is_a?(Class)
124
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
125
+ dtype.new(self)
126
+ else
127
+ enum = DTYPE_TO_ENUM[dtype]
128
+ raise Error, "Invalid type: #{dtype}" unless enum
129
+ _type(enum)
130
+ end
125
131
  end
126
132
 
127
133
  def reshape(*size)
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.6"
2
+ VERSION = "0.3.7"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.6
4
+ version: 0.3.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-09-18 00:00:00.000000000 Z
11
+ date: 2020-09-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -238,6 +238,7 @@ files:
238
238
  - lib/torch/nn/tanhshrink.rb
239
239
  - lib/torch/nn/triplet_margin_loss.rb
240
240
  - lib/torch/nn/unfold.rb
241
+ - lib/torch/nn/upsample.rb
241
242
  - lib/torch/nn/utils.rb
242
243
  - lib/torch/nn/weighted_loss.rb
243
244
  - lib/torch/nn/zero_pad2d.rb