torch-rb 0.5.0 → 0.7.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -4
- data/codegen/generate_functions.rb +13 -14
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +28 -0
- data/ext/torch/ext.cpp +26 -613
- data/ext/torch/extconf.rb +1 -4
- data/ext/torch/ivalue.cpp +132 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +3 -3
- data/ext/torch/ruby_arg_parser.h +37 -16
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +320 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +95 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -2
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +19 -17
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +14 -91
data/ext/torch/torch.cpp
ADDED
@@ -0,0 +1,95 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "torch_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_torch(Rice::Module& m) {
+  m.add_handler<torch::Error>(handle_error);
+  add_torch_functions(m);
+  m.define_singleton_function(
+      "grad_enabled?",
+      []() {
+        return torch::GradMode::is_enabled();
+      })
+    .define_singleton_function(
+      "_set_grad_enabled",
+      [](bool enabled) {
+        torch::GradMode::set_enabled(enabled);
+      })
+    .define_singleton_function(
+      "manual_seed",
+      [](uint64_t seed) {
+        return torch::manual_seed(seed);
+      })
+    // config
+    .define_singleton_function(
+      "show_config",
+      [] {
+        return torch::show_config();
+      })
+    .define_singleton_function(
+      "parallel_info",
+      [] {
+        return torch::get_parallel_info();
+      })
+    // begin operations
+    .define_singleton_function(
+      "_save",
+      [](const torch::IValue &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
+      })
+    .define_singleton_function(
+      "_load",
+      [](const std::string &s) {
+        std::vector<char> v;
+        std::copy(s.begin(), s.end(), std::back_inserter(v));
+        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+        return torch::pickle_load(v);
+      })
+    .define_singleton_function(
+      "_from_blob",
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        void *data = const_cast<char *>(s.c_str());
+        return torch::from_blob(data, size, options);
+      })
+    .define_singleton_function(
+      "_tensor",
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          std::vector<uint8_t> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+          }
+          t = torch::tensor(vec, options);
+        } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+          // TODO use template
+          std::vector<c10::complex<double>> vec;
+          Object obj;
+          for (long i = 0; i < a.size(); i++) {
+            obj = a[i];
+            vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
+          }
+          t = torch::tensor(vec, options);
+        } else {
+          std::vector<float> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
+          }
+          // hack for requires_grad error
+          if (options.requires_grad()) {
+            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+            t.set_requires_grad(true);
+          } else {
+            t = torch::tensor(vec, options);
+          }
+        }
+        return t.reshape(size);
+      });
+}
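These singleton functions back the public Ruby API in lib/torch.rb. A minimal sketch of calls that route through them (illustrative only; it assumes a working torch-rb 0.7.0 install with libtorch available):

  require "torch"

  Torch.manual_seed(42)                    # torch::manual_seed
  Torch.grad_enabled?                      # => true by default

  b = Torch.tensor([true, false, true])    # kBool branch of _tensor
  z = Torch.tensor([Complex(1, 2)])        # complex branch of _tensor (0.7.0 dtype inference)

  Torch.save(b, "b.pth")                   # torch::pickle_save via _save
  b2 = Torch.load("b.pth")                 # torch::pickle_load via _load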
data/ext/torch/torch_functions.h
CHANGED
data/ext/torch/utils.h
CHANGED
@@ -1,13 +1,19 @@
 #pragma once
 
-#include <rice/
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
+
+// TODO find better place
+inline void handle_error(torch::Error const & ex) {
+  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+}
 
 // keep THP prefix for now to make it easier to compare code
 
 extern VALUE THPVariableClass;
 
 inline VALUE THPUtils_internSymbol(const std::string& str) {
-  return Symbol(str);
+  return Rice::Symbol(str);
 }
 
 inline std::string THPUtils_unpackSymbol(VALUE obj) {
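Because init_torch registers handle_error with add_handler<torch::Error>, errors raised inside libtorch reach Ruby as RuntimeError with the backtrace-free message. A rough sketch of what that looks like from Ruby (the exact message text comes from libtorch, so treat it as illustrative):

  require "torch"

  begin
    Torch.tensor([1, 2]) + Torch.tensor([1, 2, 3])  # shape mismatch raised inside libtorch
  rescue => e
    puts e.class    # RuntimeError, raised via handle_error
    puts e.message  # libtorch's message, without the C++ backtrace
  end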
data/ext/torch/wrap_outputs.h
CHANGED
@@ -1,99 +1,106 @@
 #pragma once
 
 #include <torch/torch.h>
-#include <rice/
+#include <rice/rice.hpp>
 
-inline
-  return
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }
 
-inline
-  return
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }
 
-inline
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }
 
-inline
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline
-
-  for (auto
-    a
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return
+  return a;
 }
 
-inline
-
-
-
-
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
 }
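On the Ruby side these tuple returns arrive as plain arrays, so multi-output ops destructure directly. A hedged example using Torch.sort, which returns values and indices (exact availability depends on the generated bindings):

  require "torch"

  x = Torch.tensor([3.0, 1.0, 2.0])
  values, indices = Torch.sort(x)  # std::tuple<Tensor, Tensor> wrapped as a 2-element Array
  p values   # sorted values
  p indices  # original positions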
data/lib/torch.rb
CHANGED
@@ -238,8 +238,11 @@ module Torch
     double: 7,
     float64: 7,
     complex_half: 8,
+    complex32: 8,
     complex_float: 9,
+    complex64: 9,
     complex_double: 10,
+    complex128: 10,
     bool: 11,
     qint8: 12,
     quint8: 13,
@@ -261,6 +264,8 @@ module Torch
         Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
+      elsif args.size == 0
+        Torch.empty(0, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
@@ -335,25 +340,24 @@ module Torch
      }
    end
 
-    def no_grad
-
-
-
-
-
-      _set_grad_enabled(previous_value)
-    end
+    def no_grad(&block)
+      grad_enabled(false, &block)
+    end
+
+    def enable_grad(&block)
+      grad_enabled(true, &block)
    end
 
-    def
+    def grad_enabled(value)
      previous_value = grad_enabled?
      begin
-        _set_grad_enabled(
+        _set_grad_enabled(value)
        yield
      ensure
        _set_grad_enabled(previous_value)
      end
    end
+    alias_method :set_grad_enabled, :grad_enabled
 
    def device(str)
      Device.new(str)
@@ -393,6 +397,8 @@ module Torch
        options[:dtype] = :int64
      elsif data.all? { |v| v == true || v == false }
        options[:dtype] = :bool
+      elsif data.any? { |v| v.is_a?(Complex) }
+        options[:dtype] = :complex64
      end
    end
 
@@ -434,7 +440,8 @@ module Torch
      zeros(input.size, **like_options(input, options))
    end
 
-
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
      if center
        signal_dim = input.dim
        extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +449,7 @@ module Torch
      input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
      input = input.view(input.shape[-signal_dim..-1])
    end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
    end
 
    private
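The reworked grad-mode helpers compose, and tensor creation now infers complex dtypes. A small sketch of the 0.7.0 behavior, using only methods that appear in the diff above:

  require "torch"

  Torch.no_grad do
    # gradient tracking is off inside this block
    Torch.enable_grad do
      # and can be re-enabled for a nested region
    end
  end

  Torch.grad_enabled(false) { }  # same mechanism with an explicit value
  # Torch.set_grad_enabled is an alias for grad_enabled

  c = Torch.tensor([Complex(1, 2), Complex(3, 4)])
  c.dtype  # complex64, inferred from the Ruby Complex values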
data/lib/torch/inspector.rb
CHANGED
@@ -96,8 +96,11 @@ module Torch
          ret = "%.#{PRINT_OPTS[:precision]}f" % value
        end
      elsif @complex_dtype
-
-
+        # TODO use float formatter for each part
+        precision = PRINT_OPTS[:precision]
+        imag = value.imag
+        sign = imag >= 0 ? "+" : "-"
+        ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
      else
        ret = value.to_s
      end
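With the default print precision of 4, a complex entry renders each part as a float with an explicit sign on the imaginary part, e.g. 1.0000-2.0000i. An approximate example (the exact wrapping of the output depends on the rest of the inspector):

  require "torch"

  t = Torch.tensor([Complex(1, -2)])
  puts t.inspect  # roughly: tensor([1.0000-2.0000i])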
data/lib/torch/nn/linear.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -113,35 +113,53 @@ module Torch
        forward(*input, **kwargs)
      end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
        destination ||= {}
-
-
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
        end
        destination
      end
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
+          end
+        end
+
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
          end
        end
 
-
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
        nil
      end
 
@@ -268,6 +286,12 @@ module Torch
          named_buffers[name]
        elsif named_modules.key?(name)
          named_modules[name]
+        elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+          if instance_variable_defined?("@#{method[0..-2]}")
+            instance_variable_set("@#{method[0..-2]}", *args)
+          else
+            raise NotImplementedYet
+          end
        else
          super
        end
@@ -300,6 +324,68 @@ module Torch
      def dict
        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
      end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
    end
  end
 end
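Together with Torch.save and Torch.load, the new state_dict / load_state_dict pair gives the usual checkpoint round trip. A sketch assuming a standard torch-rb setup (the Net class here is only an example):

  require "torch"

  class Net < Torch::NN::Module
    def initialize
      super()
      @fc = Torch::NN::Linear.new(4, 2)
    end

    def forward(x)
      @fc.call(x)
    end
  end

  net = Net.new
  Torch.save(net.state_dict, "net.pth")  # keys like "fc.weight" and "fc.bias"

  net2 = Net.new
  net2.load_state_dict(Torch.load("net.pth"))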