torch-rb 0.5.2 → 0.5.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +9 -2
- data/ext/torch/ext.cpp +13 -4
- data/lib/torch/nn/module.rb +101 -21
- data/lib/torch/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: fa1967c6330c6ad818e5452d38c5f83b2dde901c33d03f2a442bc003e4e662be
|
4
|
+
data.tar.gz: 00d57cb1de4e6cec72986cab301b04b66ce53aa1d7a0e84eadbc25ac924d28eb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ae015b2ea53033df631d914c7e6fc70bdff216742900d6a9a76d5e4479cc8faf05becf2ab8ee9edc874092e36355afecbe6c08f5205678be6ec01f0f0783a469
|
7
|
+
data.tar.gz: c751eaa8253682bc95653aa45cc61a05c9f120fa6df88c3d69873a43f47ca7da07e45c3a1e1c5f36bc63da69fae0dc9626f9db2f20af9d6ef190dbcce60bfb75
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -8,7 +8,7 @@ Check out:
|
|
8
8
|
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
9
|
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
10
10
|
|
11
|
-
[![Build Status](https://
|
11
|
+
[![Build Status](https://github.com/ankane/torch.rb/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torch.rb/actions)
|
12
12
|
|
13
13
|
## Installation
|
14
14
|
|
@@ -319,6 +319,13 @@ net.load_state_dict(Torch.load("net.pth"))
|
|
319
319
|
net.eval
|
320
320
|
```
|
321
321
|
|
322
|
+
When saving a model in Python to load in Ruby, convert parameters to tensors (due to outstanding bugs in LibTorch)
|
323
|
+
|
324
|
+
```python
|
325
|
+
state_dict = {k: v.data if isinstance(v, torch.nn.Parameter) else v for k, v in state_dict.items()}
|
326
|
+
torch.save(state_dict, "net.pth")
|
327
|
+
```
|
328
|
+
|
322
329
|
### Tensor Creation
|
323
330
|
|
324
331
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -416,7 +423,7 @@ Here’s the list of compatible versions.
|
|
416
423
|
|
417
424
|
Torch.rb | LibTorch
|
418
425
|
--- | ---
|
419
|
-
0.5.0+ | 1.7.0
|
426
|
+
0.5.0+ | 1.7.0-1.7.1
|
420
427
|
0.3.0+ | 1.6.0
|
421
428
|
0.2.0-0.2.7 | 1.5.0-1.5.1
|
422
429
|
0.1.8 | 1.4.0
|
data/ext/torch/ext.cpp
CHANGED
@@ -44,8 +44,13 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
44
44
|
if (obj.is_instance_of(rb_cInteger)) {
|
45
45
|
indices.push_back(from_ruby<int64_t>(obj));
|
46
46
|
} else if (obj.is_instance_of(rb_cRange)) {
|
47
|
-
torch::optional<int64_t> start_index =
|
48
|
-
torch::optional<int64_t> stop_index =
|
47
|
+
torch::optional<int64_t> start_index = torch::nullopt;
|
48
|
+
torch::optional<int64_t> stop_index = torch::nullopt;
|
49
|
+
|
50
|
+
Object begin = obj.call("begin");
|
51
|
+
if (!begin.is_nil()) {
|
52
|
+
start_index = from_ruby<int64_t>(begin);
|
53
|
+
}
|
49
54
|
|
50
55
|
Object end = obj.call("end");
|
51
56
|
if (!end.is_nil()) {
|
@@ -53,12 +58,14 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
53
58
|
}
|
54
59
|
|
55
60
|
Object exclude_end = obj.call("exclude_end?");
|
56
|
-
if (!exclude_end) {
|
61
|
+
if (stop_index.has_value() && !exclude_end) {
|
57
62
|
if (stop_index.value() == -1) {
|
58
63
|
stop_index = torch::nullopt;
|
59
64
|
} else {
|
60
65
|
stop_index = stop_index.value() + 1;
|
61
66
|
}
|
67
|
+
} else if (!stop_index.has_value() && exclude_end) {
|
68
|
+
stop_index = -1;
|
62
69
|
}
|
63
70
|
|
64
71
|
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
@@ -618,5 +625,7 @@ void Init_ext()
|
|
618
625
|
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
619
626
|
.add_handler<torch::Error>(handle_error)
|
620
627
|
.define_singleton_method("available?", &torch::cuda::is_available)
|
621
|
-
.define_singleton_method("device_count", &torch::cuda::device_count)
|
628
|
+
.define_singleton_method("device_count", &torch::cuda::device_count)
|
629
|
+
.define_singleton_method("manual_seed", &torch::cuda::manual_seed)
|
630
|
+
.define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
|
622
631
|
}
|
data/lib/torch/nn/module.rb
CHANGED
@@ -113,35 +113,53 @@ module Torch
|
|
113
113
|
forward(*input, **kwargs)
|
114
114
|
end
|
115
115
|
|
116
|
-
def state_dict(destination: nil)
|
116
|
+
def state_dict(destination: nil, prefix: "")
|
117
117
|
destination ||= {}
|
118
|
-
|
119
|
-
|
118
|
+
save_to_state_dict(destination, prefix: prefix)
|
119
|
+
|
120
|
+
named_children.each do |name, mod|
|
121
|
+
next unless mod
|
122
|
+
mod.state_dict(destination: destination, prefix: prefix + name + ".")
|
120
123
|
end
|
121
124
|
destination
|
122
125
|
end
|
123
126
|
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
raise Error, "Unknown module: #{k1}"
|
127
|
+
def load_state_dict(state_dict, strict: true)
|
128
|
+
# TODO support strict: false
|
129
|
+
raise "strict: false not implemented yet" unless strict
|
130
|
+
|
131
|
+
missing_keys = []
|
132
|
+
unexpected_keys = []
|
133
|
+
error_msgs = []
|
134
|
+
|
135
|
+
# TODO handle metadata
|
136
|
+
|
137
|
+
_load = lambda do |mod, prefix = ""|
|
138
|
+
# TODO handle metadata
|
139
|
+
local_metadata = {}
|
140
|
+
mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
|
141
|
+
mod.named_children.each do |name, child|
|
142
|
+
_load.call(child, prefix + name + ".") unless child.nil?
|
141
143
|
end
|
142
144
|
end
|
143
145
|
|
144
|
-
|
146
|
+
_load.call(self)
|
147
|
+
|
148
|
+
if strict
|
149
|
+
if unexpected_keys.any?
|
150
|
+
error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
|
151
|
+
end
|
152
|
+
|
153
|
+
if missing_keys.any?
|
154
|
+
error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
|
155
|
+
end
|
156
|
+
end
|
157
|
+
|
158
|
+
if error_msgs.any?
|
159
|
+
# just show first error
|
160
|
+
raise Error, error_msgs[0]
|
161
|
+
end
|
162
|
+
|
145
163
|
nil
|
146
164
|
end
|
147
165
|
|
@@ -300,6 +318,68 @@ module Torch
|
|
300
318
|
def dict
|
301
319
|
instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
|
302
320
|
end
|
321
|
+
|
322
|
+
def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
323
|
+
# TODO add hooks
|
324
|
+
|
325
|
+
# TODO handle non-persistent buffers
|
326
|
+
persistent_buffers = named_buffers
|
327
|
+
local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
|
328
|
+
local_state = local_name_params.select { |_, v| !v.nil? }
|
329
|
+
|
330
|
+
local_state.each do |name, param|
|
331
|
+
key = prefix + name
|
332
|
+
if state_dict.key?(key)
|
333
|
+
input_param = state_dict[key]
|
334
|
+
|
335
|
+
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
336
|
+
if param.shape.length == 0 && input_param.shape.length == 1
|
337
|
+
input_param = input_param[0]
|
338
|
+
end
|
339
|
+
|
340
|
+
if input_param.shape != param.shape
|
341
|
+
# local shape should match the one in checkpoint
|
342
|
+
error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
|
343
|
+
"the shape in current model is #{param.shape}."
|
344
|
+
next
|
345
|
+
end
|
346
|
+
|
347
|
+
begin
|
348
|
+
Torch.no_grad do
|
349
|
+
param.copy!(input_param)
|
350
|
+
end
|
351
|
+
rescue => e
|
352
|
+
error_msgs << "While copying the parameter named #{key.inspect}, " +
|
353
|
+
"whose dimensions in the model are #{param.size} and " +
|
354
|
+
"whose dimensions in the checkpoint are #{input_param.size}, " +
|
355
|
+
"an exception occurred: #{e.inspect}"
|
356
|
+
end
|
357
|
+
elsif strict
|
358
|
+
missing_keys << key
|
359
|
+
end
|
360
|
+
end
|
361
|
+
|
362
|
+
if strict
|
363
|
+
state_dict.each_key do |key|
|
364
|
+
if key.start_with?(prefix)
|
365
|
+
input_name = key[prefix.length..-1]
|
366
|
+
input_name = input_name.split(".", 2)[0]
|
367
|
+
if !named_children.key?(input_name) && !local_state.key?(input_name)
|
368
|
+
unexpected_keys << key
|
369
|
+
end
|
370
|
+
end
|
371
|
+
end
|
372
|
+
end
|
373
|
+
end
|
374
|
+
|
375
|
+
def save_to_state_dict(destination, prefix: "")
|
376
|
+
named_parameters(recurse: false).each do |k, v|
|
377
|
+
destination[prefix + k] = v
|
378
|
+
end
|
379
|
+
named_buffers.each do |k, v|
|
380
|
+
destination[prefix + k] = v
|
381
|
+
end
|
382
|
+
end
|
303
383
|
end
|
304
384
|
end
|
305
385
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.5.
|
4
|
+
version: 0.5.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2021-01-14 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -289,7 +289,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
289
289
|
- !ruby/object:Gem::Version
|
290
290
|
version: '0'
|
291
291
|
requirements: []
|
292
|
-
rubygems_version: 3.
|
292
|
+
rubygems_version: 3.2.3
|
293
293
|
signing_key:
|
294
294
|
specification_version: 4
|
295
295
|
summary: Deep learning for Ruby, powered by LibTorch
|