torch-rb 0.5.0 → 0.7.0

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -3,4 +3,4 @@
 
  #pragma once
 
- void add_tensor_functions(Module m);
+ void add_tensor_functions(Rice::Module& m);
@@ -0,0 +1,95 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "torch_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_torch(Rice::Module& m) {
+   m.add_handler<torch::Error>(handle_error);
+   add_torch_functions(m);
+   m.define_singleton_function(
+       "grad_enabled?",
+       []() {
+         return torch::GradMode::is_enabled();
+       })
+     .define_singleton_function(
+       "_set_grad_enabled",
+       [](bool enabled) {
+         torch::GradMode::set_enabled(enabled);
+       })
+     .define_singleton_function(
+       "manual_seed",
+       [](uint64_t seed) {
+         return torch::manual_seed(seed);
+       })
+     // config
+     .define_singleton_function(
+       "show_config",
+       [] {
+         return torch::show_config();
+       })
+     .define_singleton_function(
+       "parallel_info",
+       [] {
+         return torch::get_parallel_info();
+       })
+     // begin operations
+     .define_singleton_function(
+       "_save",
+       [](const torch::IValue &value) {
+         auto v = torch::pickle_save(value);
+         std::string str(v.begin(), v.end());
+         return str;
+       })
+     .define_singleton_function(
+       "_load",
+       [](const std::string &s) {
+         std::vector<char> v;
+         std::copy(s.begin(), s.end(), std::back_inserter(v));
+         // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+         return torch::pickle_load(v);
+       })
+     .define_singleton_function(
+       "_from_blob",
+       [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+         void *data = const_cast<char *>(s.c_str());
+         return torch::from_blob(data, size, options);
+       })
+     .define_singleton_function(
+       "_tensor",
+       [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+         auto dtype = options.dtype();
+         torch::Tensor t;
+         if (dtype == torch::kBool) {
+           std::vector<uint8_t> vec;
+           for (long i = 0; i < a.size(); i++) {
+             vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+           }
+           t = torch::tensor(vec, options);
+         } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+           // TODO use template
+           std::vector<c10::complex<double>> vec;
+           Object obj;
+           for (long i = 0; i < a.size(); i++) {
+             obj = a[i];
+             vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
+           }
+           t = torch::tensor(vec, options);
+         } else {
+           std::vector<float> vec;
+           for (long i = 0; i < a.size(); i++) {
+             vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
+           }
+           // hack for requires_grad error
+           if (options.requires_grad()) {
+             t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+             t.set_requires_grad(true);
+           } else {
+             t = torch::tensor(vec, options);
+           }
+         }
+         return t.reshape(size);
+       });
+ }
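
Note: the singleton functions bound above surface directly on the Torch module in Ruby. A minimal usage sketch (not part of the diff; assumes the gem is installed and the extension is built):

    require "torch"

    Torch.manual_seed(1)     # torch::manual_seed
    puts Torch.show_config   # torch::show_config
    puts Torch.parallel_info # torch::get_parallel_info
    Torch.grad_enabled?      # => true by default (torch::GradMode::is_enabled)
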
@@ -3,4 +3,4 @@
 
  #pragma once
 
- void add_torch_functions(Module m);
+ void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,13 +1,19 @@
  #pragma once
 
- #include <rice/Symbol.hpp>
+ #include <rice/rice.hpp>
+ #include <rice/stl.hpp>
+
+ // TODO find better place
+ inline void handle_error(torch::Error const & ex) {
+   throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+ }
 
  // keep THP prefix for now to make it easier to compare code
 
  extern VALUE THPVariableClass;
 
  inline VALUE THPUtils_internSymbol(const std::string& str) {
-   return Symbol(str);
+   return Rice::Symbol(str);
  }
 
  inline std::string THPUtils_unpackSymbol(VALUE obj) {
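
Note: handle_error, registered via m.add_handler in the new init_torch above, converts torch::Error into a plain RuntimeError on the Ruby side. A hedged sketch of the effect, assuming the shape-mismatched matmul below raises a torch::Error:

    require "torch"

    begin
      # 2x3 times 2x3 is invalid; the C++ exception is re-raised as a Ruby RuntimeError
      Torch.matmul(Torch.ones(2, 3), Torch.ones(2, 3))
    rescue RuntimeError => e
      puts e.message
    end
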
@@ -1,99 +1,106 @@
  #pragma once
 
  #include <torch/torch.h>
- #include <rice/Object.hpp>
+ #include <rice/rice.hpp>
 
- inline Object wrap(bool x) {
-   return to_ruby<bool>(x);
+ inline VALUE wrap(bool x) {
+   return Rice::detail::To_Ruby<bool>().convert(x);
  }
 
- inline Object wrap(int64_t x) {
-   return to_ruby<int64_t>(x);
+ inline VALUE wrap(int64_t x) {
+   return Rice::detail::To_Ruby<int64_t>().convert(x);
  }
 
- inline Object wrap(double x) {
-   return to_ruby<double>(x);
+ inline VALUE wrap(double x) {
+   return Rice::detail::To_Ruby<double>().convert(x);
  }
 
- inline Object wrap(torch::Tensor x) {
-   return to_ruby<torch::Tensor>(x);
+ inline VALUE wrap(torch::Tensor x) {
+   return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
  }
 
- inline Object wrap(torch::Scalar x) {
-   return to_ruby<torch::Scalar>(x);
+ inline VALUE wrap(torch::Scalar x) {
+   return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
  }
 
- inline Object wrap(torch::ScalarType x) {
-   return to_ruby<torch::ScalarType>(x);
+ inline VALUE wrap(torch::ScalarType x) {
+   return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
  }
 
- inline Object wrap(torch::QScheme x) {
-   return to_ruby<torch::QScheme>(x);
+ inline VALUE wrap(torch::QScheme x) {
+   return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
  }
 
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
-   Array a;
-   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+   return rb_ary_new3(
+     2,
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+   );
  }
 
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
-   Array a;
-   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+   return rb_ary_new3(
+     3,
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+   );
  }
 
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-   Array a;
-   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+   return rb_ary_new3(
+     4,
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+   );
  }
 
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-   Array a;
-   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+   return rb_ary_new3(
+     5,
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+   );
  }
 
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
-   Array a;
-   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-   a.push(to_ruby<int64_t>(std::get<3>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+   return rb_ary_new3(
+     4,
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+     Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+   );
  }
 
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
-   Array a;
-   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-   a.push(to_ruby<double>(std::get<2>(x)));
-   a.push(to_ruby<int64_t>(std::get<3>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+   return rb_ary_new3(
+     4,
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+     Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+     Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+   );
  }
 
- inline Object wrap(torch::TensorList x) {
-   Array a;
-   for (auto& t : x) {
-     a.push(to_ruby<torch::Tensor>(t));
+ inline VALUE wrap(torch::TensorList x) {
+   auto a = rb_ary_new2(x.size());
+   for (auto t : x) {
+     rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
    }
-   return Object(a);
+   return a;
  }
 
- inline Object wrap(std::tuple<double, double> x) {
-   Array a;
-   a.push(to_ruby<double>(std::get<0>(x)));
-   a.push(to_ruby<double>(std::get<1>(x)));
-   return Object(a);
+ inline VALUE wrap(std::tuple<double, double> x) {
+   return rb_ary_new3(
+     2,
+     Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+     Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+   );
  }
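
Note: with the tuple overloads above, operations that return several tensors come back to Ruby as plain arrays. An illustrative sketch (not part of the diff; assumes Torch.topk is among the generated bindings and returns a value/index pair):

    require "torch"

    x = Torch.rand(3, 5)
    values, indices = Torch.topk(x, 2) # two-element Ruby array built by rb_ary_new3
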
data/lib/torch.rb CHANGED
@@ -238,8 +238,11 @@ module Torch
      double: 7,
      float64: 7,
      complex_half: 8,
+     complex32: 8,
      complex_float: 9,
+     complex64: 9,
      complex_double: 10,
+     complex128: 10,
      bool: 11,
      qint8: 12,
      quint8: 13,
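
Note: complex32, complex64, and complex128 are aliases for the existing complex_half, complex_float, and complex_double entries, so either spelling maps to the same enum value. A small sketch (not part of the diff):

    require "torch"

    t = Torch.tensor([Complex(1, 2), Complex(3, 4)], dtype: :complex64)
    puts t.dtype
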
@@ -261,6 +264,8 @@ module Torch
            Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
          elsif args.size == 1 && args.first.is_a?(Array)
            Torch.tensor(args.first, dtype: dtype, device: device)
+         elsif args.size == 0
+           Torch.empty(0, dtype: dtype, device: device)
          else
            Torch.empty(*args, dtype: dtype, device: device)
          end
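
Note: the new args.size == 0 branch lets the dtype-specific tensor classes be constructed with no arguments. A sketch (not part of the diff; assumes Torch::FloatTensor is one of the classes defined through this factory):

    require "torch"

    t = Torch::FloatTensor.new
    t.shape # => [0]
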
@@ -335,25 +340,24 @@ module Torch
        }
      end
 
-     def no_grad
-       previous_value = grad_enabled?
-       begin
-         _set_grad_enabled(false)
-         yield
-       ensure
-         _set_grad_enabled(previous_value)
-       end
+     def no_grad(&block)
+       grad_enabled(false, &block)
+     end
+
+     def enable_grad(&block)
+       grad_enabled(true, &block)
      end
 
-     def enable_grad
+     def grad_enabled(value)
        previous_value = grad_enabled?
        begin
-         _set_grad_enabled(true)
+         _set_grad_enabled(value)
          yield
        ensure
          _set_grad_enabled(previous_value)
        end
      end
+     alias_method :set_grad_enabled, :grad_enabled
 
      def device(str)
        Device.new(str)
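
Note: no_grad and enable_grad are now thin wrappers around the new grad_enabled(value) helper, and set_grad_enabled is an alias for it. Usage sketch (not part of the diff):

    require "torch"

    x = Torch.tensor([1.0, 2.0, 3.0], requires_grad: true)

    Torch.no_grad do
      (x * 2).requires_grad # => false
    end

    Torch.no_grad do
      Torch.enable_grad do
        (x * 2).requires_grad # => true
      end
    end

    # set_grad_enabled is the aliased form of grad_enabled
    Torch.set_grad_enabled(false) do
      (x * 2).requires_grad # => false
    end
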
@@ -393,6 +397,8 @@ module Torch
          options[:dtype] = :int64
        elsif data.all? { |v| v == true || v == false }
          options[:dtype] = :bool
+       elsif data.any? { |v| v.is_a?(Complex) }
+         options[:dtype] = :complex64
        end
      end
 
@@ -434,7 +440,8 @@ module Torch
        zeros(input.size, **like_options(input, options))
      end
 
-     def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+     # center option
+     def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
        if center
          signal_dim = input.dim
          extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +449,7 @@ module Torch
          input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
          input = input.view(input.shape[-signal_dim..-1])
        end
-       _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-     end
-
-     def clamp(tensor, min, max)
-       tensor = _clamp_min(tensor, min)
-       _clamp_max(tensor, max)
+       _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
      end
 
      private
@@ -96,8 +96,11 @@ module Torch
          ret = "%.#{PRINT_OPTS[:precision]}f" % value
        end
      elsif @complex_dtype
-       p = PRINT_OPTS[:precision]
-       raise NotImplementedYet
+       # TODO use float formatter for each part
+       precision = PRINT_OPTS[:precision]
+       imag = value.imag
+       sign = imag >= 0 ? "+" : "-"
+       ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
      else
        ret = value.to_s
      end
@@ -1,6 +1,8 @@
  module Torch
    module NN
      class Linear < Module
+       attr_reader :in_features, :out_features
+
        def initialize(in_features, out_features, bias: true)
          super()
          @in_features = in_features
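
Note: the new readers expose the layer dimensions that were previously only stored in instance variables. Sketch (not part of the diff):

    require "torch"

    layer = Torch::NN::Linear.new(10, 2)
    layer.in_features  # => 10
    layer.out_features # => 2
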
@@ -113,35 +113,53 @@ module Torch
        forward(*input, **kwargs)
      end
 
-     def state_dict(destination: nil)
+     def state_dict(destination: nil, prefix: "")
        destination ||= {}
-       named_parameters.each do |k, v|
-         destination[k] = v
+       save_to_state_dict(destination, prefix: prefix)
+
+       named_children.each do |name, mod|
+         next unless mod
+         mod.state_dict(destination: destination, prefix: prefix + name + ".")
        end
        destination
      end
 
-     # TODO add strict option
-     # TODO match PyTorch behavior
-     def load_state_dict(state_dict)
-       state_dict.each do |k, input_param|
-         k1, k2 = k.split(".", 2)
-         mod = named_modules[k1]
-         if mod.is_a?(Module)
-           param = mod.named_parameters[k2]
-           if param.is_a?(Parameter)
-             Torch.no_grad do
-               param.copy!(input_param)
-             end
-           else
-             raise Error, "Unknown parameter: #{k1}"
-           end
-         else
-           raise Error, "Unknown module: #{k1}"
+     def load_state_dict(state_dict, strict: true)
+       # TODO support strict: false
+       raise "strict: false not implemented yet" unless strict
+
+       missing_keys = []
+       unexpected_keys = []
+       error_msgs = []
+
+       # TODO handle metadata
+
+       _load = lambda do |mod, prefix = ""|
+         # TODO handle metadata
+         local_metadata = {}
+         mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+         mod.named_children.each do |name, child|
+           _load.call(child, prefix + name + ".") unless child.nil?
+         end
+       end
+
+       _load.call(self)
+
+       if strict
+         if unexpected_keys.any?
+           error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+         end
+
+         if missing_keys.any?
+           error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
          end
        end
 
-       # TODO return missing keys and unexpected keys
+       if error_msgs.any?
+         # just show first error
+         raise Error, error_msgs[0]
+       end
+
        nil
      end
 
@@ -268,6 +286,12 @@ module Torch
          named_buffers[name]
        elsif named_modules.key?(name)
          named_modules[name]
+       elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+         if instance_variable_defined?("@#{method[0..-2]}")
+           instance_variable_set("@#{method[0..-2]}", *args)
+         else
+           raise NotImplementedYet
+         end
        else
          super
        end
@@ -300,6 +324,68 @@ module Torch
      def dict
        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
      end
+
+     def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+       # TODO add hooks
+
+       # TODO handle non-persistent buffers
+       persistent_buffers = named_buffers
+       local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+       local_state = local_name_params.select { |_, v| !v.nil? }
+
+       local_state.each do |name, param|
+         key = prefix + name
+         if state_dict.key?(key)
+           input_param = state_dict[key]
+
+           # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+           if param.shape.length == 0 && input_param.shape.length == 1
+             input_param = input_param[0]
+           end
+
+           if input_param.shape != param.shape
+             # local shape should match the one in checkpoint
+             error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+               "the shape in current model is #{param.shape}."
+             next
+           end
+
+           begin
+             Torch.no_grad do
+               param.copy!(input_param)
+             end
+           rescue => e
+             error_msgs << "While copying the parameter named #{key.inspect}, " +
+               "whose dimensions in the model are #{param.size} and " +
+               "whose dimensions in the checkpoint are #{input_param.size}, " +
+               "an exception occurred: #{e.inspect}"
+           end
+         elsif strict
+           missing_keys << key
+         end
+       end
+
+       if strict
+         state_dict.each_key do |key|
+           if key.start_with?(prefix)
+             input_name = key[prefix.length..-1]
+             input_name = input_name.split(".", 2)[0]
+             if !named_children.key?(input_name) && !local_state.key?(input_name)
+               unexpected_keys << key
+             end
+           end
+         end
+       end
+     end
+
+     def save_to_state_dict(destination, prefix: "")
+       named_parameters(recurse: false).each do |k, v|
+         destination[prefix + k] = v
+       end
+       named_buffers.each do |k, v|
+         destination[prefix + k] = v
+       end
+     end
    end
  end
 end
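
Note: state_dict now recurses through named_children and prefixes keys with the child name (e.g. "fc1.weight"), and load_state_dict validates the keys it is given instead of silently ignoring mismatches. A round-trip sketch (not part of the diff; the TinyNet class and its layer names are hypothetical):

    require "torch"

    class TinyNet < Torch::NN::Module
      def initialize
        super()
        @fc1 = Torch::NN::Linear.new(4, 8)
        @fc2 = Torch::NN::Linear.new(8, 2)
      end

      def forward(x)
        @fc2.call(Torch::NN::F.relu(@fc1.call(x)))
      end
    end

    net = TinyNet.new
    sd = net.state_dict
    sd.keys # expected: ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

    # keys are now checked; missing or unexpected keys raise Torch::Error
    TinyNet.new.load_state_dict(sd)
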