torch-rb 0.3.5 → 0.3.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 93271ffd62be6e35c6ea3a2219a7bc3dccbe8489d6f4aca1a1f00f99bab1a4bb
- data.tar.gz: df2755ac3e6221502430d780d116a1145c461f97857e7b1b2b809095afaad9e5
+ metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
+ data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
  SHA512:
- metadata.gz: 9e02de90a7a83e5d4421941a0ceea69c6367f42be2e97e4229812f35c27f83475fbace86f42285128da724343dcfd85050b7846d81d43fce100749be0072ad4c
- data.tar.gz: 884873c3c965f16b0a833087909019ebc6f228511a9b0ecbf4c436cc546e28012e63899651aeab6befddbe7b057676b170cd3eb858b997f42289c103252a2834
+ metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
+ data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
@@ -1,3 +1,9 @@
+ ## 0.3.6 (2020-09-17)
+
+ - Added `inplace` option for leaky ReLU
+ - Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
+ - Fixed error with buffers on GPU
+
  ## 0.3.5 (2020-09-04)

  - Fixed error with data loader (due to `dtype` of `randperm`)
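The 0.3.6 entries above correspond to small, user-visible changes detailed in the hunks below. A minimal sketch of the new behavior (values are illustrative; assumes the gem is installed):

```ruby
require "torch"

x = Torch.randn([6, 4])

# Methods that return a tensor list now come back as a Ruby Array of tensors
pieces = x.chunk(3)   # => array of 3 tensors, each 2x4
rows   = x.unbind     # => array of 6 tensors, each of length 4
pieces.size           # => 3
```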
@@ -301,11 +301,6 @@ void Init_ext()
      // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
      return torch::pickle_load(v);
    })
-   .define_singleton_method(
-     "_binary_cross_entropy_with_logits",
-     *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
-       return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
-     })
    .define_singleton_method(
      "_from_blob",
      *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
@@ -379,11 +374,6 @@ void Init_ext()
      *[](Tensor& self, bool requires_grad) {
        return self.set_requires_grad(requires_grad);
      })
-   .define_method(
-     "_backward",
-     *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
-       return self.backward(gradient, create_graph, retain_graph);
-     })
    .define_method(
      "grad",
      *[](Tensor& self) {
@@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
    a.push(to_ruby<int64_t>(std::get<3>(x)));
    return Object(a);
  }
+
+ Object wrap(std::vector<torch::Tensor> x) {
+   Array a;
+   for (auto& t : x) {
+     a.push(to_ruby<torch::Tensor>(t));
+   }
+   return Object(a);
+ }
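This `wrap` overload converts a `std::vector<torch::Tensor>` result into a Ruby Array of tensors, which is what the tensor-list fix in the changelog relies on. The observable effect on the Ruby side is simply (a sketch; sizes are illustrative):

```ruby
x = Torch.arange(10)
parts = x.split(4)            # returns a Ruby Array now that Tensor[] results are wrapped
parts.map { |t| t.size(0) }   # => [4, 4, 2]
```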
@@ -10,6 +10,7 @@
  using namespace Rice;

  using torch::Device;
+ using torch::Scalar;
  using torch::ScalarType;
  using torch::Tensor;

@@ -36,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
    return IntArrayRef(x);
  }

- // for now
- class Scalar {
-   torch::Scalar value;
-   public:
-     Scalar(Object o) {
-       // TODO cast based on Ruby type
-       if (o.rb_type() == T_FIXNUM) {
-         value = torch::Scalar(from_ruby<int64_t>(o));
-       } else {
-         value = torch::Scalar(from_ruby<float>(o));
-       }
-     }
-     operator torch::Scalar() {
-       return value;
-     }
- };
-
- template<>
- inline
- Scalar from_ruby<Scalar>(Object x)
- {
-   return Scalar(x);
- }
-
  class TensorList {
    std::vector<torch::Tensor> vec;
    public:
@@ -192,6 +169,17 @@ class OptionalTensor {
    }
  };

+ template<>
+ inline
+ Scalar from_ruby<Scalar>(Object x)
+ {
+   if (x.rb_type() == T_FIXNUM) {
+     return torch::Scalar(from_ruby<int64_t>(x));
+   } else {
+     return torch::Scalar(from_ruby<double>(x));
+   }
+ }
+
  template<>
  inline
  OptionalTensor from_ruby<OptionalTensor>(Object x)
@@ -221,9 +209,43 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
    }
  }

+ template<>
+ inline
+ torch::optional<double> from_ruby<torch::optional<double>>(Object x)
+ {
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<double>{from_ruby<double>(x)};
+   }
+ }
+
+ template<>
+ inline
+ torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
+ {
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<bool>{from_ruby<bool>(x)};
+   }
+ }
+
+ template<>
+ inline
+ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
+ {
+   if (x.is_nil()) {
+     return torch::nullopt;
+   } else {
+     return torch::optional<Scalar>{from_ruby<Scalar>(x)};
+   }
+ }
+
  Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
  Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
+ Object wrap(std::vector<torch::Tensor> x);
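These specializations translate Ruby `nil` into `torch::nullopt`, so natively optional `float?`, `bool?`, and `Scalar?` arguments can now be passed as `nil` from Ruby instead of being skipped by the generator. A hedged sketch, assuming `clamp` is declared upstream with `Scalar? min` and `Scalar? max` (as it is in PyTorch 1.6's `native_functions.yaml`):

```ruby
x = Torch.tensor([-2.0, 0.5, 3.0])

Torch.clamp(x, -1, 1)    # integer scalars go through from_ruby<Scalar>
Torch.clamp(x, nil, 1)   # nil maps to torch::nullopt, so only the max bound applies
```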
@@ -239,25 +239,22 @@ module Torch
    cls
  end

- FloatTensor = _make_tensor_class(:float32)
- DoubleTensor = _make_tensor_class(:float64)
- HalfTensor = _make_tensor_class(:float16)
- ByteTensor = _make_tensor_class(:uint8)
- CharTensor = _make_tensor_class(:int8)
- ShortTensor = _make_tensor_class(:int16)
- IntTensor = _make_tensor_class(:int32)
- LongTensor = _make_tensor_class(:int64)
- BoolTensor = _make_tensor_class(:bool)
-
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
- CUDA::CharTensor = _make_tensor_class(:int8, true)
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
- CUDA::IntTensor = _make_tensor_class(:int32, true)
- CUDA::LongTensor = _make_tensor_class(:int64, true)
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
+ DTYPE_TO_CLASS = {
+   float32: "FloatTensor",
+   float64: "DoubleTensor",
+   float16: "HalfTensor",
+   uint8: "ByteTensor",
+   int8: "CharTensor",
+   int16: "ShortTensor",
+   int32: "IntTensor",
+   int64: "LongTensor",
+   bool: "BoolTensor"
+ }
+
+ DTYPE_TO_CLASS.each do |dtype, class_name|
+   const_set(class_name, _make_tensor_class(dtype))
+   CUDA.const_set(class_name, _make_tensor_class(dtype, true))
+ end

  class << self
    # Torch.float, Torch.long, etc
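The table-driven loop defines exactly the same constants as the removed assignments, one CPU class and one `CUDA::` class per dtype, so existing code keeps working. A quick sketch (the CUDA constants need a CUDA-enabled libtorch to be useful):

```ruby
a = Torch::FloatTensor.new([1, 2, 3])   # float32 tensor, as before
b = Torch::LongTensor.new([1, 2, 3])    # int64 tensor
Torch::CUDA::FloatTensor                # still defined, for tensors created on the GPU
```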
@@ -86,6 +86,10 @@ module Torch
    def ret_size
      @ret_size ||= func.split("->").last.split(", ").size
    end

+   def ret_array?
+     @ret_array ||= func.split("->").last.include?('[]')
+   end
+
    def out?
      out_size > 0 && base_name[-1] != "_"
    end
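`ret_array?` inspects the return clause of the upstream function signature; combined with the generator change further down, any `Tensor[]` return is routed through the new `wrap` overload. A minimal sketch of the check itself (the signature string is illustrative, in the style of `native_functions.yaml`):

```ruby
func = "chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor[]"
func.split("->").last.include?('[]')   # => true, so the generated C++ body gets wrapped
```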
@@ -18,12 +18,12 @@ module Torch
      functions = functions()

      # skip functions
-     skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
+     skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]

      # remove functions
      functions.reject! do |f|
        f.ruby_name.start_with?("_") ||
-         f.ruby_name.end_with?("_backward") ||
+         f.ruby_name.include?("_backward") ||
          f.args.any? { |a| a[:type].include?("Dimname") }
      end

@@ -31,7 +31,6 @@ module Torch
      todo_functions, functions =
        functions.partition do |f|
          f.args.any? do |a|
-           a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
            skip_args.any? { |sa| a[:type].include?(sa) } ||
              # call to 'range' is ambiguous
              f.cpp_name == "_range" ||
@@ -104,6 +103,12 @@ void add_%{type}_functions(Module m) {
        "int64_t"
      when "int?"
        "torch::optional<int64_t>"
+     when "float?"
+       "torch::optional<double>"
+     when "bool?"
+       "torch::optional<bool>"
+     when "Scalar?"
+       "torch::optional<torch::Scalar>"
      when "float"
        "double"
      when /\Aint\[/
@@ -130,8 +135,8 @@ void add_%{type}_functions(Module m) {
      prefix = def_method == :define_method ? "self." : "torch::"

      body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-     # TODO check type as well
-     if func.ret_size > 1
+
+     if func.ret_size > 1 || func.ret_array?
        body = "wrap(#{body})"
      end

@@ -85,6 +85,10 @@ module Torch
            end
          when "int?"
            v.is_a?(Integer) || v.nil?
+         when "float?"
+           v.is_a?(Numeric) || v.nil?
+         when "bool?"
+           v == true || v == false || v.nil?
          when "float"
            v.is_a?(Numeric)
          when /int\[.*\]/
@@ -97,6 +101,10 @@ module Torch
            v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
          when "Scalar"
            v.is_a?(Numeric)
+         when "Scalar?"
+           v.is_a?(Numeric) || v.nil?
+         when "ScalarType"
+           false # not supported yet
          when "ScalarType?"
            v.nil?
          when "bool"
@@ -178,8 +178,12 @@ module Torch
      Torch.hardshrink(input, lambd)
    end

-   def leaky_relu(input, negative_slope = 0.01)
-     NN.leaky_relu(input, negative_slope)
+   def leaky_relu(input, negative_slope = 0.01, inplace: false)
+     if inplace
+       NN.leaky_relu!(input, negative_slope)
+     else
+       NN.leaky_relu(input, negative_slope)
+     end
    end

    def log_sigmoid(input)
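With `inplace: true` the functional call dispatches to the in-place native variant (`NN.leaky_relu!`) and overwrites its input rather than allocating a new tensor. A short sketch, assuming the functional module is reachable as `Torch::NN::Functional` and that mutating `x` is acceptable:

```ruby
x = Torch.randn([3])

y = Torch::NN::Functional.leaky_relu(x)                    # new tensor; x unchanged
Torch::NN::Functional.leaky_relu(x, 0.2, inplace: true)    # writes the result back into x
```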
@@ -1,14 +1,14 @@
  module Torch
    module NN
      class LeakyReLU < Module
-       def initialize(negative_slope: 1e-2) #, inplace: false)
+       def initialize(negative_slope: 1e-2, inplace: false)
          super()
          @negative_slope = negative_slope
-         # @inplace = inplace
+         @inplace = inplace
        end

        def forward(input)
-         F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+         F.leaky_relu(input, @negative_slope, inplace: @inplace)
        end

        def extra_inspect
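The module form now just forwards its stored `inplace` flag to the functional call, so an in-place activation can be used in a model definition. A small usage sketch:

```ruby
act = Torch::NN::LeakyReLU.new(negative_slope: 0.2, inplace: true)
x = Torch.randn([4, 8])
act.call(x)   # equivalent to F.leaky_relu(x, 0.2, inplace: true)
```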
@@ -55,7 +55,12 @@ module Torch
        end
      end
    end
-   # TODO apply to more objects
+
+   @buffers.each_key do |k|
+     buf = @buffers[k]
+     @buffers[k] = fn.call(buf) unless buf.nil?
+   end
+
    self
  end

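Registered buffers hold non-parameter state (for example the running statistics of batch norm layers); applying `fn` to them here is what fixes the "buffers on GPU" error from the changelog, since device and dtype conversions funnel through this method. A hedged sketch of the now-working pattern, assuming a CUDA-enabled libtorch and an available GPU:

```ruby
model = Torch::NN::BatchNorm1d.new(8)
model.to("cuda")   # running_mean / running_var buffers now move along with the parameters

x = Torch.randn([16, 8], device: "cuda")
model.call(x)      # previously could raise because the buffers stayed on the CPU
```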
@@ -103,11 +103,6 @@ module Torch
      Torch.empty(0, dtype: dtype)
    end

-   def backward(gradient = nil, retain_graph: nil, create_graph: false)
-     retain_graph = create_graph if retain_graph.nil?
-     _backward(gradient, retain_graph, create_graph)
-   end
-
    # TODO read directly from memory
    def numo
      cls = Torch._dtype_to_numo[dtype]
@@ -235,7 +230,7 @@ module Torch
      when Integer
        TensorIndex.integer(index)
      when Range
-       finish = index.end
+       finish = index.end || -1
        if finish == -1 && !index.exclude_end?
          finish = nil
        else
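`Range#end` is `nil` for an endless range, so the `|| -1` fallback folds `2..` into the existing "through the last element" path. A quick sketch:

```ruby
x = Torch.tensor([10, 20, 30, 40, 50])

x[1..3]   # unchanged: inclusive slice of elements 1 through 3
x[2..]    # endless range: end is nil, now treated as running to the last element
```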
@@ -45,6 +45,8 @@ module Torch
      def size
        (@dataset.size / @batch_size.to_f).ceil
      end
+     alias_method :length, :size
+     alias_method :count, :size

      private

@@ -16,6 +16,8 @@ module Torch
      def size
        @tensors[0].size(0)
      end
+     alias_method :length, :size
+     alias_method :count, :size
    end
  end
 end
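Both the data loader and `TensorDataset` now answer `length` and `count` in addition to `size`, matching plain Ruby collections. A usage sketch (shapes are illustrative):

```ruby
x = Torch.randn([10, 4])
y = Torch.randn([10])

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader  = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 3)

dataset.length   # => 10, same as dataset.size
loader.count     # => 4 batches, i.e. (10 / 3.0).ceil
```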
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.3.5"
+   VERSION = "0.3.6"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
-   version: 0.3.5
+   version: 0.3.6
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-09-04 00:00:00.000000000 Z
+ date: 2020-09-18 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rice