torch-rb 0.5.0 → 0.7.0

data/ext/torch/tensor_functions.h CHANGED
@@ -3,4 +3,4 @@
 
 #pragma once
 
-void add_tensor_functions(Module m);
+void add_tensor_functions(Rice::Module& m);

data/ext/torch/torch.cpp ADDED
@@ -0,0 +1,95 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "torch_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_torch(Rice::Module& m) {
+  m.add_handler<torch::Error>(handle_error);
+  add_torch_functions(m);
+  m.define_singleton_function(
+      "grad_enabled?",
+      []() {
+        return torch::GradMode::is_enabled();
+      })
+    .define_singleton_function(
+      "_set_grad_enabled",
+      [](bool enabled) {
+        torch::GradMode::set_enabled(enabled);
+      })
+    .define_singleton_function(
+      "manual_seed",
+      [](uint64_t seed) {
+        return torch::manual_seed(seed);
+      })
+    // config
+    .define_singleton_function(
+      "show_config",
+      [] {
+        return torch::show_config();
+      })
+    .define_singleton_function(
+      "parallel_info",
+      [] {
+        return torch::get_parallel_info();
+      })
+    // begin operations
+    .define_singleton_function(
+      "_save",
+      [](const torch::IValue &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
+      })
+    .define_singleton_function(
+      "_load",
+      [](const std::string &s) {
+        std::vector<char> v;
+        std::copy(s.begin(), s.end(), std::back_inserter(v));
+        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+        return torch::pickle_load(v);
+      })
+    .define_singleton_function(
+      "_from_blob",
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        void *data = const_cast<char *>(s.c_str());
+        return torch::from_blob(data, size, options);
+      })
+    .define_singleton_function(
+      "_tensor",
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          std::vector<uint8_t> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+          }
+          t = torch::tensor(vec, options);
+        } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+          // TODO use template
+          std::vector<c10::complex<double>> vec;
+          Object obj;
+          for (long i = 0; i < a.size(); i++) {
+            obj = a[i];
+            vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
+          }
+          t = torch::tensor(vec, options);
+        } else {
+          std::vector<float> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
+          }
+          // hack for requires_grad error
+          if (options.requires_grad()) {
+            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+            t.set_requires_grad(true);
+          } else {
+            t = torch::tensor(vec, options);
+          }
+        }
+        return t.reshape(size);
+      });
+}
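
In Ruby these bindings surface on the Torch module: grad_enabled?, manual_seed, show_config, and parallel_info are called directly, while the underscore-prefixed functions are private plumbing behind higher-level wrappers (for instance, _save and _load back Torch.save and Torch.load). A minimal sketch; output values are illustrative:

```ruby
require "torch"

Torch.manual_seed(1)
Torch.grad_enabled?      # => true (gradients are on by default)
puts Torch.parallel_info # threading details from libtorch

# serialization goes through the pickle-based _save/_load bindings
Torch.save(Torch.ones(2, 3), "weights.pth")
t = Torch.load("weights.pth")
```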

data/ext/torch/torch_functions.h CHANGED
@@ -3,4 +3,4 @@
 
 #pragma once
 
-void add_torch_functions(Module m);
+void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,13 +1,19 @@
 #pragma once
 
-#include <rice/Symbol.hpp>
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
+
+// TODO find better place
+inline void handle_error(torch::Error const & ex) {
+  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+}
 
 // keep THP prefix for now to make it easier to compare code
 
 extern VALUE THPVariableClass;
 
 inline VALUE THPUtils_internSymbol(const std::string& str) {
-  return Symbol(str);
+  return Rice::Symbol(str);
 }
 
 inline std::string THPUtils_unpackSymbol(VALUE obj) {
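
With this handler registered via add_handler in init_torch, a torch::Error thrown inside libtorch reaches Ruby as a RuntimeError carrying the message without the C++ backtrace. A sketch of the observable behavior (the exact message text comes from libtorch):

```ruby
begin
  Torch.ones(3).reshape([2, 2]) # 3 elements cannot fill a 2x2 tensor
rescue RuntimeError => e
  e.message # => "shape '[2, 2]' is invalid for input of size 3" (approximate)
end
```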

@@ -1,99 +1,106 @@
 #pragma once
 
 #include <torch/torch.h>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
-inline Object wrap(bool x) {
-  return to_ruby<bool>(x);
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }
 
-inline Object wrap(int64_t x) {
-  return to_ruby<int64_t>(x);
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }
 
-inline Object wrap(double x) {
-  return to_ruby<double>(x);
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }
 
-inline Object wrap(torch::Tensor x) {
-  return to_ruby<torch::Tensor>(x);
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }
 
-inline Object wrap(torch::Scalar x) {
-  return to_ruby<torch::Scalar>(x);
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }
 
-inline Object wrap(torch::ScalarType x) {
-  return to_ruby<torch::ScalarType>(x);
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }
 
-inline Object wrap(torch::QScheme x) {
-  return to_ruby<torch::QScheme>(x);
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<double>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(torch::TensorList x) {
-  Array a;
-  for (auto& t : x) {
-    a.push(to_ruby<torch::Tensor>(t));
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return Object(a);
+  return a;
 }
 
-inline Object wrap(std::tuple<double, double> x) {
-  Array a;
-  a.push(to_ruby<double>(std::get<0>(x)));
-  a.push(to_ruby<double>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
}
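
The switch from Rice::Object to raw VALUE built with rb_ary_new3 changes nothing for callers: operations with multiple outputs still come back as plain Ruby arrays. A hedged sketch, assuming Torch.max with a dimension returns a values/indices pair as in PyTorch:

```ruby
t = Torch.tensor([[1.0, 5.0], [4.0, 2.0]])
values, indices = Torch.max(t, 1) # std::tuple<Tensor, Tensor> -> Ruby array
values.to_a  # => [5.0, 4.0]
indices.to_a # => [1, 0]
```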
data/lib/torch.rb CHANGED
@@ -238,8 +238,11 @@ module Torch
     double: 7,
     float64: 7,
     complex_half: 8,
+    complex32: 8,
     complex_float: 9,
+    complex64: 9,
     complex_double: 10,
+    complex128: 10,
     bool: 11,
     qint8: 12,
     quint8: 13,
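
The new entries are aliases, so the PyTorch-style dtype names map to the same enum values as the existing ones:

```ruby
# :complex64 and :complex_float are now interchangeable
Torch.tensor([Complex(1, 2)], dtype: :complex64)
Torch.tensor([Complex(1, 2)], dtype: :complex_float)
```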

@@ -261,6 +264,8 @@ module Torch
         Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
+      elsif args.size == 0
+        Torch.empty(0, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
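
The new branch makes zero-argument construction return an empty tensor instead of failing. A sketch, assuming Torch::FloatTensor is one of the generated tensor classes this constructor backs:

```ruby
Torch::FloatTensor.new       # => empty tensor of shape [0]
Torch::FloatTensor.new(2, 3) # => uninitialized 2x3 tensor, as before
```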

@@ -335,25 +340,24 @@ module Torch
       }
     end
 
-    def no_grad
-      previous_value = grad_enabled?
-      begin
-        _set_grad_enabled(false)
-        yield
-      ensure
-        _set_grad_enabled(previous_value)
-      end
+    def no_grad(&block)
+      grad_enabled(false, &block)
+    end
+
+    def enable_grad(&block)
+      grad_enabled(true, &block)
     end
 
-    def enable_grad
+    def grad_enabled(value)
       previous_value = grad_enabled?
       begin
-        _set_grad_enabled(true)
+        _set_grad_enabled(value)
         yield
       ensure
         _set_grad_enabled(previous_value)
       end
     end
+    alias_method :set_grad_enabled, :grad_enabled
 
     def device(str)
       Device.new(str)
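
no_grad and enable_grad are now thin wrappers around a single grad_enabled helper (also aliased as set_grad_enabled) that takes the desired state:

```ruby
x = Torch.ones(2, 2, requires_grad: true)

Torch.no_grad do
  (x * 2).requires_grad # => false
end

# equivalent, with the state passed explicitly
Torch.grad_enabled(false) do
  (x * 2).requires_grad # => false
end
```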

@@ -393,6 +397,8 @@ module Torch
         options[:dtype] = :int64
       elsif data.all? { |v| v == true || v == false }
         options[:dtype] = :bool
+      elsif data.any? { |v| v.is_a?(Complex) }
+        options[:dtype] = :complex64
       end
     end
 
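With the new branch, arrays containing Complex values infer a complex dtype instead of falling through to the float default. A sketch (dtype shown as a symbol, as torch-rb reports it):

```ruby
Torch.tensor([Complex(1, 2), Complex(3, -1)]).dtype # => :complex64
Torch.tensor([1.0, 2.0]).dtype                      # => :float32 (unchanged)
```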

@@ -434,7 +440,8 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end
 
-    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
         signal_dim = input.dim
         extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +449,7 @@ module Torch
         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
         input = input.view(input.shape[-signal_dim..-1])
       end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
     end
 
     private
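
stft now forwards the new return_complex flag to the native _stft binding, and the Ruby-level clamp shim is gone (presumably superseded by the generated native binding). A hedged usage sketch, assuming hann_window is among the generated functions:

```ruby
signal = Torch.randn(1024)
window = Torch.hann_window(400)
spec = Torch.stft(signal, 400, window: window, return_complex: true)
spec.dtype # => :complex64
```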

data/lib/torch/inspector.rb CHANGED
@@ -96,8 +96,11 @@ module Torch
           ret = "%.#{PRINT_OPTS[:precision]}f" % value
         end
       elsif @complex_dtype
-        p = PRINT_OPTS[:precision]
-        raise NotImplementedYet
+        # TODO use float formatter for each part
+        precision = PRINT_OPTS[:precision]
+        imag = value.imag
+        sign = imag >= 0 ? "+" : "-"
+        ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
       else
         ret = value.to_s
       end
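
Complex tensors now print each element as real±imag followed by "i" at the configured precision, instead of raising NotImplementedYet:

```ruby
p Torch.tensor([Complex(1, 2), Complex(3, -0.5)])
# => tensor([1.0000+2.0000i, 3.0000-0.5000i]) (approximate output)
```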

data/lib/torch/nn/linear.rb CHANGED
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class Linear < Module
+      attr_reader :in_features, :out_features
+
       def initialize(in_features, out_features, bias: true)
         super()
         @in_features = in_features
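
The readers make a layer's dimensions introspectable after construction:

```ruby
layer = Torch::NN::Linear.new(10, 2)
layer.in_features  # => 10
layer.out_features # => 2
```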

data/lib/torch/nn/module.rb CHANGED
@@ -113,35 +113,53 @@ module Torch
        forward(*input, **kwargs)
      end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
        destination ||= {}
-        named_parameters.each do |k, v|
-          destination[k] = v
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
        end
        destination
      end
 
-      # TODO add strict option
-      # TODO match PyTorch behavior
-      def load_state_dict(state_dict)
-        state_dict.each do |k, input_param|
-          k1, k2 = k.split(".", 2)
-          mod = named_modules[k1]
-          if mod.is_a?(Module)
-            param = mod.named_parameters[k2]
-            if param.is_a?(Parameter)
-              Torch.no_grad do
-                param.copy!(input_param)
-              end
-            else
-              raise Error, "Unknown parameter: #{k1}"
-            end
-          else
-            raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
+          end
+        end
+
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
          end
        end
 
-        # TODO return missing keys and unexpected keys
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
        nil
      end
 
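state_dict now recurses through named_children and prefixes each key with the module path, matching PyTorch's flat key format, while load_state_dict walks the same tree and raises one aggregated error on mismatch. A sketch:

```ruby
net = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(4, 8),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(8, 1)
)

sd = net.state_dict
sd.keys # => ["0.weight", "0.bias", "2.weight", "2.bias"]

net.load_state_dict(sd) # => nil on success
```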

@@ -268,6 +286,12 @@ module Torch
          named_buffers[name]
        elsif named_modules.key?(name)
          named_modules[name]
+        elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+          if instance_variable_defined?("@#{method[0..-2]}")
+            instance_variable_set("@#{method[0..-2]}", *args)
+          else
+            raise NotImplementedYet
+          end
        else
          super
        end
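
The new method_missing branch routes writer calls (name=) for registered submodules to the backing instance variable, so a child module can be swapped out after construction:

```ruby
class Net < Torch::NN::Module
  def initialize
    super()
    @fc = Torch::NN::Linear.new(4, 2)
  end

  def forward(x)
    @fc.call(x)
  end
end

net = Net.new
net.fc = Torch::NN::Linear.new(4, 3) # handled by the new writer branch
```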

@@ -300,6 +324,68 @@ module Torch
      def dict
        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
      end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                            "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                            "whose dimensions in the model are #{param.size} and " +
+                            "whose dimensions in the checkpoint are #{input_param.size}, " +
+                            "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
    end
  end
end
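
Under strict loading, shape mismatches are collected into error_msgs and surfaced instead of failing silently. A hedged sketch of the resulting behavior:

```ruby
a = Torch::NN::Linear.new(4, 2)
b = Torch::NN::Linear.new(4, 3)

begin
  b.load_state_dict(a.state_dict)
rescue Torch::Error => e
  e.message # => "size mismatch for weight: ..." (first collected error)
end
```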