torch-rb 0.5.0 → 0.7.0
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -4
- data/codegen/generate_functions.rb +13 -14
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +28 -0
- data/ext/torch/ext.cpp +26 -613
- data/ext/torch/extconf.rb +1 -4
- data/ext/torch/ivalue.cpp +132 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +3 -3
- data/ext/torch/ruby_arg_parser.h +37 -16
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +320 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +95 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -2
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +19 -17
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +14 -91
data/ext/torch/torch.cpp
ADDED
@@ -0,0 +1,95 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "torch_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_torch(Rice::Module& m) {
+  m.add_handler<torch::Error>(handle_error);
+  add_torch_functions(m);
+  m.define_singleton_function(
+      "grad_enabled?",
+      []() {
+        return torch::GradMode::is_enabled();
+      })
+    .define_singleton_function(
+      "_set_grad_enabled",
+      [](bool enabled) {
+        torch::GradMode::set_enabled(enabled);
+      })
+    .define_singleton_function(
+      "manual_seed",
+      [](uint64_t seed) {
+        return torch::manual_seed(seed);
+      })
+    // config
+    .define_singleton_function(
+      "show_config",
+      [] {
+        return torch::show_config();
+      })
+    .define_singleton_function(
+      "parallel_info",
+      [] {
+        return torch::get_parallel_info();
+      })
+    // begin operations
+    .define_singleton_function(
+      "_save",
+      [](const torch::IValue &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
+      })
+    .define_singleton_function(
+      "_load",
+      [](const std::string &s) {
+        std::vector<char> v;
+        std::copy(s.begin(), s.end(), std::back_inserter(v));
+        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+        return torch::pickle_load(v);
+      })
+    .define_singleton_function(
+      "_from_blob",
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        void *data = const_cast<char *>(s.c_str());
+        return torch::from_blob(data, size, options);
+      })
+    .define_singleton_function(
+      "_tensor",
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          std::vector<uint8_t> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+          }
+          t = torch::tensor(vec, options);
+        } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+          // TODO use template
+          std::vector<c10::complex<double>> vec;
+          Object obj;
+          for (long i = 0; i < a.size(); i++) {
+            obj = a[i];
+            vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
+          }
+          t = torch::tensor(vec, options);
+        } else {
+          std::vector<float> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
+          }
+          // hack for requires_grad error
+          if (options.requires_grad()) {
+            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+            t.set_requires_grad(true);
+          } else {
+            t = torch::tensor(vec, options);
+          }
+        }
+        return t.reshape(size);
+      });
+}
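Note: the singleton functions registered above back methods on the Torch module in Ruby (assuming `init_torch` receives the top-level `Torch` module, as in the rest of the extension; the underscore-prefixed functions are internal helpers wrapped by Ruby code). A minimal usage sketch:

    require "torch"

    Torch.manual_seed(42)    # seeds the global RNG via torch::manual_seed
    Torch.grad_enabled?      # => true by default (torch::GradMode::is_enabled)
    puts Torch.show_config   # build configuration from torch::show_config
    puts Torch.parallel_info # threading details from torch::get_parallel_info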
data/ext/torch/torch_functions.h
CHANGED
data/ext/torch/utils.h
CHANGED
@@ -1,13 +1,19 @@
 #pragma once

-#include <rice/
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
+
+// TODO find better place
+inline void handle_error(torch::Error const & ex) {
+  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+}

 // keep THP prefix for now to make it easier to compare code

 extern VALUE THPVariableClass;

 inline VALUE THPUtils_internSymbol(const std::string& str) {
-  return Symbol(str);
+  return Rice::Symbol(str);
 }

 inline std::string THPUtils_unpackSymbol(VALUE obj) {
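Note: with `handle_error` registered through `add_handler`, libtorch exceptions reach Ruby as RuntimeError carrying the `what_without_backtrace` message. A hedged sketch (assuming the operation below fails inside libtorch):

    require "torch"

    begin
      Torch.tensor([1, 2, 3]) + Torch.tensor([1, 2])  # incompatible shapes
    rescue RuntimeError => e
      puts e.message  # libtorch error text, without the C++ backtrace
    end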
data/ext/torch/wrap_outputs.h
CHANGED
@@ -1,99 +1,106 @@
 #pragma once

 #include <torch/torch.h>
-#include <rice/
+#include <rice/rice.hpp>

-inline
-  return
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }

-inline
-  return
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }

-inline
-  return
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }

-inline
-  return
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }

-inline
-  return
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }

-inline
-  return
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }

-inline
-  return
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }

-inline
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }

-inline
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }

-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }

-inline
-
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }

-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }

-inline
-
-
-
-
-
-
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }

-inline
-
-  for (auto
-    a
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return
+  return a;
 }

-inline
-
-
-
-
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
 }
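Note: these `wrap` overloads turn multi-output operations into plain Ruby arrays via `rb_ary_new3`, so tuple results can be destructured directly. A sketch, assuming `topk` is among the generated tensor functions:

    require "torch"

    values, indices = Torch.tensor([3.0, 1.0, 2.0]).topk(2)
    # values  => tensor of [3.0, 2.0], indices => tensor of [0, 2]
    # each element is converted with To_Ruby<torch::Tensor>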
data/lib/torch.rb
CHANGED
@@ -238,8 +238,11 @@ module Torch
     double: 7,
     float64: 7,
     complex_half: 8,
+    complex32: 8,
     complex_float: 9,
+    complex64: 9,
     complex_double: 10,
+    complex128: 10,
     bool: 11,
     qint8: 12,
     quint8: 13,
@@ -261,6 +264,8 @@ module Torch
           Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
         elsif args.size == 1 && args.first.is_a?(Array)
           Torch.tensor(args.first, dtype: dtype, device: device)
+        elsif args.size == 0
+          Torch.empty(0, dtype: dtype, device: device)
         else
           Torch.empty(*args, dtype: dtype, device: device)
         end
@@ -335,25 +340,24 @@ module Torch
      }
    end

-    def no_grad
-
-
-
-
-
-      _set_grad_enabled(previous_value)
-    end
+    def no_grad(&block)
+      grad_enabled(false, &block)
+    end
+
+    def enable_grad(&block)
+      grad_enabled(true, &block)
     end

-    def
+    def grad_enabled(value)
       previous_value = grad_enabled?
       begin
-        _set_grad_enabled(
+        _set_grad_enabled(value)
         yield
       ensure
         _set_grad_enabled(previous_value)
       end
     end
+    alias_method :set_grad_enabled, :grad_enabled

     def device(str)
       Device.new(str)
@@ -393,6 +397,8 @@ module Torch
           options[:dtype] = :int64
         elsif data.all? { |v| v == true || v == false }
           options[:dtype] = :bool
+        elsif data.any? { |v| v.is_a?(Complex) }
+          options[:dtype] = :complex64
         end
       end

@@ -434,7 +440,8 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end

-
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
         signal_dim = input.dim
         extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +449,7 @@ module Torch
         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
         input = input.view(input.shape[-signal_dim..-1])
       end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
     end

     private
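Note: together these Ruby changes add complex dtype aliases, complex dtype inference in Torch.tensor, an empty-tensor fallback, and block-based gradient toggles. A short sketch of the resulting API (method names as they appear in the diff):

    require "torch"

    t = Torch.tensor([Complex(1, 2), Complex(3, 4)])  # inferred as :complex64

    Torch.no_grad do
      Torch.grad_enabled?  # => false inside the block
    end

    Torch.enable_grad { Torch.grad_enabled? }          # => true
    Torch.set_grad_enabled(false) { Torch.grad_enabled? }  # alias of grad_enabled, => false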
data/lib/torch/inspector.rb
CHANGED
@@ -96,8 +96,11 @@ module Torch
           ret = "%.#{PRINT_OPTS[:precision]}f" % value
         end
       elsif @complex_dtype
-
-
+        # TODO use float formatter for each part
+        precision = PRINT_OPTS[:precision]
+        imag = value.imag
+        sign = imag >= 0 ? "+" : "-"
+        ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
       else
         ret = value.to_s
       end
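Note: the inspector now renders each complex entry as the real part, a sign, and the absolute imaginary part with an "i" suffix, using the configured precision. The formatting logic in isolation (hypothetical value, precision of 4):

    precision = 4
    value = Complex(1.5, -2.25)
    sign = value.imag >= 0 ? "+" : "-"
    "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
    # => "1.5000-2.2500i"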
data/lib/torch/nn/linear.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -113,35 +113,53 @@ module Torch
         forward(*input, **kwargs)
       end

-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-
-
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
+          end
+        end
+
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
           end
         end

-
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
         nil
       end

@@ -268,6 +286,12 @@ module Torch
           named_buffers[name]
         elsif named_modules.key?(name)
           named_modules[name]
+        elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+          if instance_variable_defined?("@#{method[0..-2]}")
+            instance_variable_set("@#{method[0..-2]}", *args)
+          else
+            raise NotImplementedYet
+          end
         else
           super
         end
@@ -300,6 +324,68 @@ module Torch
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
     end
   end
 end