torch-rb 0.5.2 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,4 +3,4 @@
 
 #pragma once
 
-void add_tensor_functions(Module m);
+void add_tensor_functions(Rice::Module& m);
@@ -0,0 +1,95 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "torch_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_torch(Rice::Module& m) {
+  m.add_handler<torch::Error>(handle_error);
+  add_torch_functions(m);
+  m.define_singleton_function(
+      "grad_enabled?",
+      []() {
+        return torch::GradMode::is_enabled();
+      })
+    .define_singleton_function(
+      "_set_grad_enabled",
+      [](bool enabled) {
+        torch::GradMode::set_enabled(enabled);
+      })
+    .define_singleton_function(
+      "manual_seed",
+      [](uint64_t seed) {
+        return torch::manual_seed(seed);
+      })
+    // config
+    .define_singleton_function(
+      "show_config",
+      [] {
+        return torch::show_config();
+      })
+    .define_singleton_function(
+      "parallel_info",
+      [] {
+        return torch::get_parallel_info();
+      })
+    // begin operations
+    .define_singleton_function(
+      "_save",
+      [](const torch::IValue &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
+      })
+    .define_singleton_function(
+      "_load",
+      [](const std::string &s) {
+        std::vector<char> v;
+        std::copy(s.begin(), s.end(), std::back_inserter(v));
+        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+        return torch::pickle_load(v);
+      })
+    .define_singleton_function(
+      "_from_blob",
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        void *data = const_cast<char *>(s.c_str());
+        return torch::from_blob(data, size, options);
+      })
+    .define_singleton_function(
+      "_tensor",
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          std::vector<uint8_t> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+          }
+          t = torch::tensor(vec, options);
+        } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+          // TODO use template
+          std::vector<c10::complex<double>> vec;
+          Object obj;
+          for (long i = 0; i < a.size(); i++) {
+            obj = a[i];
+            vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
+          }
+          t = torch::tensor(vec, options);
+        } else {
+          std::vector<float> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
+          }
+          // hack for requires_grad error
+          if (options.requires_grad()) {
+            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+            t.set_requires_grad(true);
+          } else {
+            t = torch::tensor(vec, options);
+          }
+        }
+        return t.reshape(size);
+      });
+}
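
The new file above wires up the extension's module-level bindings. Assuming init_torch is passed the top-level Torch module (as in the rest of the extension), the public singleton functions it registers are callable directly from Ruby; the underscore-prefixed ones (_save, _load, _from_blob, _tensor) are internal helpers wrapped by the pure-Ruby layer. A minimal sketch:

    require "torch"

    Torch.grad_enabled?       # => true (autograd is on by default)
    Torch.manual_seed(42)     # seed the global libtorch RNG
    puts Torch.show_config    # build configuration reported by libtorch
    puts Torch.parallel_info  # threading information reported by libtorch
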
@@ -3,4 +3,4 @@
 
 #pragma once
 
-void add_torch_functions(Module m);
+void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,13 +1,19 @@
 #pragma once
 
-#include <rice/Symbol.hpp>
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
+
+// TODO find better place
+inline void handle_error(torch::Error const & ex) {
+  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+}
 
 // keep THP prefix for now to make it easier to compare code
 
 extern VALUE THPVariableClass;
 
 inline VALUE THPUtils_internSymbol(const std::string& str) {
-  return Symbol(str);
+  return Rice::Symbol(str);
 }
 
 inline std::string THPUtils_unpackSymbol(VALUE obj) {
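
Because torch.cpp registers this handler with m.add_handler<torch::Error>(handle_error), libtorch errors now surface in Ruby as RuntimeError rather than terminating the process. A hedged illustration (the shape mismatch is just one arbitrary way to trigger a torch::Error):

    begin
      Torch.ones(2, 3) + Torch.ones(4, 5)  # shape mismatch raises torch::Error in C++
    rescue => e
      e.class    # => RuntimeError
      e.message  # libtorch's message, without the C++ backtrace
    end
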
@@ -1,99 +1,106 @@
 #pragma once
 
 #include <torch/torch.h>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
-inline Object wrap(bool x) {
-  return to_ruby<bool>(x);
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }
 
-inline Object wrap(int64_t x) {
-  return to_ruby<int64_t>(x);
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }
 
-inline Object wrap(double x) {
-  return to_ruby<double>(x);
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }
 
-inline Object wrap(torch::Tensor x) {
-  return to_ruby<torch::Tensor>(x);
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }
 
-inline Object wrap(torch::Scalar x) {
-  return to_ruby<torch::Scalar>(x);
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }
 
-inline Object wrap(torch::ScalarType x) {
-  return to_ruby<torch::ScalarType>(x);
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }
 
-inline Object wrap(torch::QScheme x) {
-  return to_ruby<torch::QScheme>(x);
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<double>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(torch::TensorList x) {
-  Array a;
-  for (auto& t : x) {
-    a.push(to_ruby<torch::Tensor>(t));
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return Object(a);
+  return a;
 }
 
-inline Object wrap(std::tuple<double, double> x) {
-  Array a;
-  a.push(to_ruby<double>(std::get<0>(x)));
-  a.push(to_ruby<double>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
 }
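
These wrap overloads are what convert multi-output native ops into plain Ruby arrays, so tuple results can be destructured directly. A small sketch, assuming the generated sort binding (which returns a values/indices pair):

    x = Torch.tensor([3.0, 1.0, 2.0])
    values, indices = x.sort  # the (Tensor, Tensor) tuple arrives as a 2-element Array
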
data/lib/torch.rb CHANGED
@@ -238,8 +238,11 @@ module Torch
     double: 7,
     float64: 7,
     complex_half: 8,
+    complex32: 8,
     complex_float: 9,
+    complex64: 9,
     complex_double: 10,
+    complex128: 10,
     bool: 11,
     qint8: 12,
     quint8: 13,
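
The added entries make the PyTorch-style names complex32, complex64, and complex128 valid dtype symbols alongside the existing complex_half/complex_float/complex_double aliases. A hedged example:

    t = Torch.tensor([1, 2, 3], dtype: :complex64)
    t.dtype  # complex dtype, same enum value (9) as :complex_float
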
@@ -337,25 +340,24 @@ module Torch
       }
     end
 
-    def no_grad
-      previous_value = grad_enabled?
-      begin
-        _set_grad_enabled(false)
-        yield
-      ensure
-        _set_grad_enabled(previous_value)
-      end
+    def no_grad(&block)
+      grad_enabled(false, &block)
+    end
+
+    def enable_grad(&block)
+      grad_enabled(true, &block)
     end
 
-    def enable_grad
+    def grad_enabled(value)
       previous_value = grad_enabled?
       begin
-        _set_grad_enabled(true)
+        _set_grad_enabled(value)
         yield
       ensure
         _set_grad_enabled(previous_value)
       end
     end
+    alias_method :set_grad_enabled, :grad_enabled
 
     def device(str)
       Device.new(str)
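
no_grad and enable_grad are now thin wrappers around the shared grad_enabled helper, and set_grad_enabled is an alias for it; all three take a block and restore the previous autograd mode on exit. A hedged sketch of the resulting API:

    x = Torch.ones(2, 2, requires_grad: true)

    Torch.no_grad do
      (x * 2).requires_grad  # => false inside the block
    end

    Torch.set_grad_enabled(false) do
      Torch.enable_grad do
        (x * 2).requires_grad  # => true, re-enabled for the inner block
      end
    end

    Torch.grad_enabled?  # => true, previous mode restored after each block
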
@@ -395,6 +397,8 @@ module Torch
          options[:dtype] = :int64
        elsif data.all? { |v| v == true || v == false }
          options[:dtype] = :bool
+        elsif data.any? { |v| v.is_a?(Complex) }
+          options[:dtype] = :complex64
        end
      end
 
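
With this branch, Torch.tensor's dtype inference recognizes Ruby Complex elements and defaults them to the new :complex64 dtype instead of falling through to the float path:

    t = Torch.tensor([Complex(1, 2), Complex(3, -4)])
    t.dtype  # complex64, inferred from the Complex elements
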
@@ -96,8 +96,11 @@ module Torch
        ret = "%.#{PRINT_OPTS[:precision]}f" % value
      end
    elsif @complex_dtype
-      p = PRINT_OPTS[:precision]
-      raise NotImplementedYet
+      # TODO use float formatter for each part
+      precision = PRINT_OPTS[:precision]
+      imag = value.imag
+      sign = imag >= 0 ? "+" : "-"
+      ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
    else
      ret = value.to_s
    end
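
Inspecting complex tensors no longer raises NotImplementedYet; each element is rendered as real±imag with the configured precision (4 digits by default). The exact output below is illustrative:

    Torch.tensor([Complex(1, 2), Complex(3, -4)]).inspect
    # => roughly "tensor([1.0000+2.0000i, 3.0000-4.0000i])"
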
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class Linear < Module
+      attr_reader :in_features, :out_features
+
       def initialize(in_features, out_features, bias: true)
         super()
         @in_features = in_features
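
Linear layers now expose their dimensions through readers, so code that inspects a model's architecture no longer has to reach into instance variables:

    layer = Torch::NN::Linear.new(10, 5)
    layer.in_features   # => 10
    layer.out_features  # => 5
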
@@ -113,35 +113,53 @@ module Torch
         forward(*input, **kwargs)
       end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-        named_parameters.each do |k, v|
-          destination[k] = v
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end
 
-      # TODO add strict option
-      # TODO match PyTorch behavior
-      def load_state_dict(state_dict)
-        state_dict.each do |k, input_param|
-          k1, k2 = k.split(".", 2)
-          mod = named_modules[k1]
-          if mod.is_a?(Module)
-            param = mod.named_parameters[k2]
-            if param.is_a?(Parameter)
-              Torch.no_grad do
-                param.copy!(input_param)
-              end
-            else
-              raise Error, "Unknown parameter: #{k1}"
-            end
-          else
-            raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
+          end
+        end
+
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
          end
        end
 
-        # TODO return missing keys and unexpected keys
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
        nil
      end
 
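
state_dict now recurses into named children with dotted key prefixes (matching PyTorch's layout), and load_state_dict collects missing and unexpected keys, raising on the first error when strict. A hedged round-trip sketch (the Sequential child names are assumed to be their indices):

    model = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(4, 8),
      Torch::NN::ReLU.new,
      Torch::NN::Linear.new(8, 2)
    )

    sd = model.state_dict
    sd.keys  # e.g. ["0.weight", "0.bias", "2.weight", "2.bias"] - child name + "." + param name

    model.load_state_dict(sd)               # copies parameters in place
    model.load_state_dict({}, strict: true) # raises Torch::Error listing the missing keys
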
@@ -268,6 +286,12 @@ module Torch
           named_buffers[name]
         elsif named_modules.key?(name)
           named_modules[name]
+        elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+          if instance_variable_defined?("@#{method[0..-2]}")
+            instance_variable_set("@#{method[0..-2]}", *args)
+          else
+            raise NotImplementedYet
+          end
         else
           super
         end
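
The new method_missing branch lets a registered submodule be reassigned after construction, mirroring the existing reader behavior. A hedged sketch with a hypothetical Net module whose initialize sets @fc:

    class Net < Torch::NN::Module
      def initialize
        super()
        @fc = Torch::NN::Linear.new(8, 4)
      end

      def forward(x)
        @fc.call(x)
      end
    end

    net = Net.new
    net.fc                                # reading a child already worked via method_missing
    net.fc = Torch::NN::Linear.new(8, 2)  # now supported: replaces the @fc submodule
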
@@ -300,6 +324,68 @@ module Torch
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
     end
   end
 end