torch-rb 0.8.2 → 0.8.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '05811faa93ab089485bfa213362bfed0462227e6964e726bf1c3f9fc0cdba0c3'
4
- data.tar.gz: 8d25304063db51850e535e71c2b4fe643e038a18d046b8f27ee50269e5e9695d
3
+ metadata.gz: 19fb3b7bb14f4b45424591ac7d5f65c5290b372c6e54d65f36188869d267410d
4
+ data.tar.gz: c2c8878412b46febea4af13881169a74a577b4ac1a7de0e2d567058c06695e1a
5
5
  SHA512:
6
- metadata.gz: 8cea906b03b37ec848be7b1c7cfa6bfb0fde4ef7ed384818bcc85826d04611621835bbeecad0fd31c94b497bb527a21c107901a9d991692b6fffa6bb24c23c38
7
- data.tar.gz: 2ded65d614d274afe61e061898268172e6dec85dc28e3094481c2650713a129c0574b7ba17dbafccd8b8dc7ee064602fd261ce5034cec4a184fa32f2965eb476
6
+ metadata.gz: 6a11fd1965d00e43cfbccd658ee14c95f2a032253842a6d6ec8cb33e91fbeedae5594b2d641ccc9593d2bc967376cf4538d0d12d9ded7bbd1e816945a30d1b28
7
+ data.tar.gz: 5a6febd08b15c544fd4dfef13b56efa2aab8a80c7e171fdf434f9fa692e13381654fbe3c8ea021ac661a71d310bebe942a01161f910c2931bc8ee336fe237afe
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.8.3 (2021-10-17)
2
+
3
+ - Fixed `dup` method for tensors and parameters
4
+ - Fixed issues with transformers
5
+
1
6
  ## 0.8.2 (2021-10-03)
2
7
 
3
8
  - Added transformers
@@ -280,6 +280,11 @@ module Torch
280
280
  end
281
281
  end
282
282
 
283
# Returns a recursive copy of this module, built by dup_value with a fresh
# memo table so objects shared inside this module stay shared in the copy.
def deep_dup
  dup_value(self, {})
end
287
+
283
288
  def method_missing(method, *args, &block)
284
289
  name = method.to_s
285
290
  if named_parameters.key?(name)
@@ -388,6 +393,29 @@ module Torch
388
393
  destination[prefix + k] = v
389
394
  end
390
395
  end
396
+
397
# Recursively copies +v+, keeping a memo table like Python's copy.deepcopy
# (https://docs.python.org/3/library/copy.html).
#
# The memo maps each original's object_id to its copy. The copy is recorded
# BEFORE its contents are copied. This keeps objects that are referenced more
# than once shared in the result, and it lets cycles (for example a module ivar
# that points back at the module) terminate instead of recursing forever.
#
# Method/UnboundMethod objects are returned as-is. Hashes and Arrays are
# rebuilt entry by entry. Torch::NN::Module instances are shallow-duplicated
# and then every instance variable is deep-copied. Anything else falls back
# to its own #dup.
#
# @param v [Object] value to copy
# @param memo [Hash{Integer => Object}] object_id of an original => its copy
# @return [Object] the copy of +v+
def dup_value(v, memo)
  # key? instead of ||= so that nil/false copies are reused as well
  return memo[v.object_id] if memo.key?(v.object_id)

  case v
  when Method, UnboundMethod
    memo[v.object_id] = v
  when Hash
    copy = memo[v.object_id] = {}
    v.each { |k, v2| copy[dup_value(k, memo)] = dup_value(v2, memo) }
    copy
  when Array
    copy = memo[v.object_id] = []
    v.each { |v2| copy << dup_value(v2, memo) }
    copy
  when Torch::NN::Module
    # register the shallow copy first so back-references resolve to it
    copy = memo[v.object_id] = v.dup
    v.instance_variables.each do |var|
      copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
    end
    copy
  else
    memo[v.object_id] = v.dup
  end
end
391
419
  end
392
420
  end
393
421
  end
@@ -9,6 +9,12 @@ module Torch
9
9
  def inspect
10
10
  "Parameter containing:\n#{super}"
11
11
  end
12
+
13
# Builds a new Parameter from a clone of this parameter's data, keeping the
# same requires_grad flag. The clone is made inside Torch.no_grad,
# presumably so the copy is not recorded by autograd.
def dup
  Torch.no_grad { Parameter.new(clone, requires_grad: requires_grad) }
end
12
18
  end
13
19
  end
14
20
  end
@@ -35,7 +35,7 @@ module Torch
35
35
  tgt += @dropout2.(tgt2)
36
36
  tgt = @norm2.(tgt)
37
37
  tgt2 = @linear2.(@dropout.(@activation.(@linear1.(tgt))))
38
- tgt += @dropout3.(tgt)
38
+ tgt += @dropout3.(tgt2)
39
39
  @norm3.(tgt)
40
40
  end
41
41
  end
@@ -22,11 +22,7 @@ module Torch
22
22
  end
23
23
 
24
24
  def _clones(mod, n)
25
- state = mod.state_dict
26
- layers = n.times.map do |i|
27
- mod.clone.tap { |l| l.load_state_dict(state) }
28
- end
29
- ModuleList.new(layers)
25
+ ModuleList.new(n.times.map { mod.deep_dup })
30
26
  end
31
27
 
32
28
  def _activation_fn(activation)
data/lib/torch/tensor.rb CHANGED
@@ -185,5 +185,11 @@ module Torch
185
185
  def stft(*args)
186
186
  Torch.stft(*args)
187
187
  end
188
+
189
# Ruby-level copy of the tensor: returns #clone evaluated inside Torch.no_grad.
# NOTE(review): because the clone is taken under no_grad, the copy's
# requires_grad is likely false even when this tensor requires grad —
# confirm that is intended (Parameter#dup passes requires_grad through).
def dup
  Torch.no_grad { clone }
end
188
194
  end
189
195
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.8.2"
2
+ VERSION = "0.8.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.2
4
+ version: 0.8.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-10-04 00:00:00.000000000 Z
11
+ date: 2021-10-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice