torch-rb 0.5.2 → 0.5.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: af82b518434a574d47ba1e9d4c785e4d295622b0092f50e6d638a88396601ef1
- data.tar.gz: dbc5b4ad9e5fd1a907688f12feb9e1a01d8bab852b68a1712cd70d655ed4833f
+ metadata.gz: fa1967c6330c6ad818e5452d38c5f83b2dde901c33d03f2a442bc003e4e662be
+ data.tar.gz: 00d57cb1de4e6cec72986cab301b04b66ce53aa1d7a0e84eadbc25ac924d28eb
  SHA512:
- metadata.gz: 8e372efecb07c2120e03c04e96e9b9d9fd8e2242b35527e101f545397cbd9649b1fd3ccaddaf42b0cdfb6c6f8aa70d529c828a9a9a8fa248bed1ff1bc83d7d5a
- data.tar.gz: a8ddbd40b7f0b70ee904098285d1c1c02ca51e2d5be46aec7d423985b04f085e9f92aa5d7f1e9dd5785b22898277ff0739b01ab6bee32235273f70c5fd76c74c
+ metadata.gz: ae015b2ea53033df631d914c7e6fc70bdff216742900d6a9a76d5e4479cc8faf05becf2ab8ee9edc874092e36355afecbe6c08f5205678be6ec01f0f0783a469
+ data.tar.gz: c751eaa8253682bc95653aa45cc61a05c9f120fa6df88c3d69873a43f47ca7da07e45c3a1e1c5f36bc63da69fae0dc9626f9db2f20af9d6ef190dbcce60bfb75
@@ -1,3 +1,9 @@
+ ## 0.5.3 (2021-01-14)
+
+ - Added `manual_seed` and `manual_seed_all` for CUDA
+ - Improved saving and loading models
+ - Fixed error with tensor indexing with beginless ranges in Ruby 3.0
+
  ## 0.5.2 (2020-10-29)

  - Fixed `undefined symbol` error with CUDA
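For context on the changelog entries above: the beginless-range fix affects `Tensor#[]` on Ruby 3.0. A minimal sketch, assuming the `torch` gem and a LibTorch install:

```ruby
require "torch"

x = Torch.arange(10)
x[..2]  # beginless range; raised an error on Ruby 3.0 before 0.5.3
x[3..]  # endless range; selects elements 3 through 9
```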
data/README.md CHANGED
@@ -8,7 +8,7 @@ Check out:
  - [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
  - [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks

- [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
+ [![Build Status](https://github.com/ankane/torch.rb/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torch.rb/actions)

  ## Installation

@@ -319,6 +319,13 @@ net.load_state_dict(Torch.load("net.pth"))
  net.eval
  ```

+ When saving a model in Python to load in Ruby, convert parameters to tensors (due to outstanding bugs in LibTorch)
+
+ ```python
+ state_dict = {k: v.data if isinstance(v, torch.nn.Parameter) else v for k, v in state_dict.items()}
+ torch.save(state_dict, "net.pth")
+ ```
+
  ### Tensor Creation

  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
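For reference, the Ruby side of that Python-to-Ruby round trip matches the README example quoted in the hunk above. A minimal sketch, assuming a `Net` class that mirrors the Python model's architecture:

```ruby
require "torch"

net = Net.new  # assumed to define the same layers as the Python model
net.load_state_dict(Torch.load("net.pth"))
net.eval
```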
@@ -416,7 +423,7 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.5.0+ | 1.7.0
+ 0.5.0+ | 1.7.0-1.7.1
  0.3.0+ | 1.6.0
  0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
@@ -44,8 +44,13 @@ std::vector<TensorIndex> index_vector(Array a) {
  if (obj.is_instance_of(rb_cInteger)) {
  indices.push_back(from_ruby<int64_t>(obj));
  } else if (obj.is_instance_of(rb_cRange)) {
- torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
- torch::optional<int64_t> stop_index = -1;
+ torch::optional<int64_t> start_index = torch::nullopt;
+ torch::optional<int64_t> stop_index = torch::nullopt;
+
+ Object begin = obj.call("begin");
+ if (!begin.is_nil()) {
+ start_index = from_ruby<int64_t>(begin);
+ }

  Object end = obj.call("end");
  if (!end.is_nil()) {
@@ -53,12 +58,14 @@ std::vector<TensorIndex> index_vector(Array a) {
  }

  Object exclude_end = obj.call("exclude_end?");
- if (!exclude_end) {
+ if (stop_index.has_value() && !exclude_end) {
  if (stop_index.value() == -1) {
  stop_index = torch::nullopt;
  } else {
  stop_index = stop_index.value() + 1;
  }
+ } else if (!stop_index.has_value() && exclude_end) {
+ stop_index = -1;
  }

  indices.push_back(torch::indexing::Slice(start_index, stop_index));
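In Ruby terms, the updated range handling above maps `Tensor#[]` range arguments to LibTorch slices roughly as follows. A sketch of the expected behavior, derived from the code in this hunk rather than from the gem's docs:

```ruby
require "torch"

x = Torch.arange(6)

x[1..]    # Slice(1, nullopt)  -> elements 1..5 (endless range)
x[..2]    # Slice(nullopt, 3)  -> elements 0..2 (beginless range, inclusive end)
x[1...3]  # Slice(1, 3)        -> elements 1..2 (exclusive end)
x[1..-2]  # Slice(1, -1)       -> elements 1..4 (negative inclusive end)
x[1..-1]  # Slice(1, nullopt)  -> elements 1..5 (an inclusive -1 means "to the end")
```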
@@ -618,5 +625,7 @@ void Init_ext()
  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
  .add_handler<torch::Error>(handle_error)
  .define_singleton_method("available?", &torch::cuda::is_available)
- .define_singleton_method("device_count", &torch::cuda::device_count);
+ .define_singleton_method("device_count", &torch::cuda::device_count)
+ .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+ .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
  }
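These bindings expose LibTorch's CUDA seeding functions on the Ruby side. A minimal usage sketch (the seed value is arbitrary):

```ruby
require "torch"

if Torch::CUDA.available?
  Torch::CUDA.manual_seed(42)      # seeds the current CUDA device
  Torch::CUDA.manual_seed_all(42)  # seeds every visible CUDA device
end
```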
@@ -113,35 +113,53 @@ module Torch
  forward(*input, **kwargs)
  end

- def state_dict(destination: nil)
+ def state_dict(destination: nil, prefix: "")
  destination ||= {}
- named_parameters.each do |k, v|
- destination[k] = v
+ save_to_state_dict(destination, prefix: prefix)
+
+ named_children.each do |name, mod|
+ next unless mod
+ mod.state_dict(destination: destination, prefix: prefix + name + ".")
  end
  destination
  end

- # TODO add strict option
- # TODO match PyTorch behavior
- def load_state_dict(state_dict)
- state_dict.each do |k, input_param|
- k1, k2 = k.split(".", 2)
- mod = named_modules[k1]
- if mod.is_a?(Module)
- param = mod.named_parameters[k2]
- if param.is_a?(Parameter)
- Torch.no_grad do
- param.copy!(input_param)
- end
- else
- raise Error, "Unknown parameter: #{k1}"
- end
- else
- raise Error, "Unknown module: #{k1}"
+ def load_state_dict(state_dict, strict: true)
+ # TODO support strict: false
+ raise "strict: false not implemented yet" unless strict
+
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+
+ # TODO handle metadata
+
+ _load = lambda do |mod, prefix = ""|
+ # TODO handle metadata
+ local_metadata = {}
+ mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+ mod.named_children.each do |name, child|
+ _load.call(child, prefix + name + ".") unless child.nil?
  end
  end

- # TODO return missing keys and unexpected keys
+ _load.call(self)
+
+ if strict
+ if unexpected_keys.any?
+ error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+ end
+
+ if missing_keys.any?
+ error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+ end
+ end
+
+ if error_msgs.any?
+ # just show first error
+ raise Error, error_msgs[0]
+ end
+
  nil
  end

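The recursive `state_dict` above now prefixes keys with child module names, matching PyTorch's naming. A minimal sketch with a hypothetical one-layer module (the class and layer names are illustrative, not from the gem):

```ruby
require "torch"

class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(4, 2)
  end

  def forward(x)
    @fc1.call(x)
  end
end

net = TinyNet.new
net.state_dict.keys  # e.g. ["fc1.weight", "fc1.bias"]; the child name becomes the key prefix
```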
@@ -300,6 +318,68 @@ module Torch
  def dict
  instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
  end
+
+ def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ # TODO add hooks
+
+ # TODO handle non-persistent buffers
+ persistent_buffers = named_buffers
+ local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+ local_state = local_name_params.select { |_, v| !v.nil? }
+
+ local_state.each do |name, param|
+ key = prefix + name
+ if state_dict.key?(key)
+ input_param = state_dict[key]
+
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if param.shape.length == 0 && input_param.shape.length == 1
+ input_param = input_param[0]
+ end
+
+ if input_param.shape != param.shape
+ # local shape should match the one in checkpoint
+ error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+ "the shape in current model is #{param.shape}."
+ next
+ end
+
+ begin
+ Torch.no_grad do
+ param.copy!(input_param)
+ end
+ rescue => e
+ error_msgs << "While copying the parameter named #{key.inspect}, " +
+ "whose dimensions in the model are #{param.size} and " +
+ "whose dimensions in the checkpoint are #{input_param.size}, " +
+ "an exception occurred: #{e.inspect}"
+ end
+ elsif strict
+ missing_keys << key
+ end
+ end
+
+ if strict
+ state_dict.each_key do |key|
+ if key.start_with?(prefix)
+ input_name = key[prefix.length..-1]
+ input_name = input_name.split(".", 2)[0]
+ if !named_children.key?(input_name) && !local_state.key?(input_name)
+ unexpected_keys << key
+ end
+ end
+ end
+ end
+ end
+
+ def save_to_state_dict(destination, prefix: "")
+ named_parameters(recurse: false).each do |k, v|
+ destination[prefix + k] = v
+ end
+ named_buffers.each do |k, v|
+ destination[prefix + k] = v
+ end
+ end
  end
  end
  end
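With the helpers above, strict loading reports missing and unexpected keys instead of silently ignoring them. A rough sketch of the failure mode (the key names are illustrative):

```ruby
require "torch"

linear = Torch::NN::Linear.new(4, 2)
sd = linear.state_dict        # {"weight" => ..., "bias" => ...}

sd.delete("bias")             # simulate a missing key
sd["bogus"] = Torch.zeros(1)  # and an unexpected one

begin
  linear.load_state_dict(sd)  # strict: true is the default
rescue Torch::Error => e
  puts e.message              # first error, e.g. the unexpected key "bogus"
end
```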
@@ -1,3 +1,3 @@
  module Torch
- VERSION = "0.5.2"
+ VERSION = "0.5.3"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.5.2
+ version: 0.5.3
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-10-29 00:00:00.000000000 Z
+ date: 2021-01-14 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -289,7 +289,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.1.4
+ rubygems_version: 3.2.3
  signing_key:
  specification_version: 4
  summary: Deep learning for Ruby, powered by LibTorch