torch-rb 0.8.2 → 0.8.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '05811faa93ab089485bfa213362bfed0462227e6964e726bf1c3f9fc0cdba0c3'
4
- data.tar.gz: 8d25304063db51850e535e71c2b4fe643e038a18d046b8f27ee50269e5e9695d
3
+ metadata.gz: 19fb3b7bb14f4b45424591ac7d5f65c5290b372c6e54d65f36188869d267410d
4
+ data.tar.gz: c2c8878412b46febea4af13881169a74a577b4ac1a7de0e2d567058c06695e1a
5
5
  SHA512:
6
- metadata.gz: 8cea906b03b37ec848be7b1c7cfa6bfb0fde4ef7ed384818bcc85826d04611621835bbeecad0fd31c94b497bb527a21c107901a9d991692b6fffa6bb24c23c38
7
- data.tar.gz: 2ded65d614d274afe61e061898268172e6dec85dc28e3094481c2650713a129c0574b7ba17dbafccd8b8dc7ee064602fd261ce5034cec4a184fa32f2965eb476
6
+ metadata.gz: 6a11fd1965d00e43cfbccd658ee14c95f2a032253842a6d6ec8cb33e91fbeedae5594b2d641ccc9593d2bc967376cf4538d0d12d9ded7bbd1e816945a30d1b28
7
+ data.tar.gz: 5a6febd08b15c544fd4dfef13b56efa2aab8a80c7e171fdf434f9fa692e13381654fbe3c8ea021ac661a71d310bebe942a01161f910c2931bc8ee336fe237afe
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.8.3 (2021-10-17)
2
+
3
+ - Fixed `dup` method for tensors and parameters
4
+ - Fixed issues with transformers (decoder layer residual connection and layer cloning)
5
+
1
6
  ## 0.8.2 (2021-10-03)
2
7
 
3
8
  - Added transformers
@@ -280,6 +280,11 @@ module Torch
280
280
  end
281
281
  end
282
282
 
283
# Returns a deep copy of this module. The module and every value reachable
# through its instance variables are copied by dup_value, which is given one
# fresh memo hash for the whole walk so each original object is copied once.
def deep_dup
  dup_value(self, {})
end
287
+
283
288
  def method_missing(method, *args, &block)
284
289
  name = method.to_s
285
290
  if named_parameters.key?(name)
@@ -388,6 +393,29 @@ module Torch
388
393
  destination[prefix + k] = v
389
394
  end
390
395
  end
396
+
397
+ # keep memo hash like Python deepcopy
398
+ # https://docs.python.org/3/library/copy.html
399
+ def dup_value(v, memo)
400
+ memo[v.object_id] ||= begin
401
+ case v
402
+ when Method, UnboundMethod
403
+ v
404
+ when Hash
405
+ v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
406
+ when Array
407
+ v.map { |v2| dup_value(v2, memo) }
408
+ when Torch::NN::Module
409
+ copy = v.dup
410
+ v.instance_variables.each do |var|
411
+ copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
412
+ end
413
+ copy
414
+ else
415
+ v.dup
416
+ end
417
+ end
418
+ end
391
419
  end
392
420
  end
393
421
  end
@@ -9,6 +9,12 @@ module Torch
9
9
  def inspect
10
10
  "Parameter containing:\n#{super}"
11
11
  end
12
+
13
+ def dup
14
+ Torch.no_grad do
15
+ Parameter.new(clone, requires_grad: requires_grad)
16
+ end
17
+ end
12
18
  end
13
19
  end
14
20
  end
@@ -35,7 +35,7 @@ module Torch
35
35
  tgt += @dropout2.(tgt2)
36
36
  tgt = @norm2.(tgt)
37
37
  tgt2 = @linear2.(@dropout.(@activation.(@linear1.(tgt))))
38
- tgt += @dropout3.(tgt)
38
+ tgt += @dropout3.(tgt2)
39
39
  @norm3.(tgt)
40
40
  end
41
41
  end
@@ -22,11 +22,7 @@ module Torch
22
22
  end
23
23
 
24
24
  def _clones(mod, n)
25
- state = mod.state_dict
26
- layers = n.times.map do |i|
27
- mod.clone.tap { |l| l.load_state_dict(state) }
28
- end
29
- ModuleList.new(layers)
25
+ ModuleList.new(n.times.map { mod.deep_dup })
30
26
  end
31
27
 
32
28
  def _activation_fn(activation)
data/lib/torch/tensor.rb CHANGED
@@ -185,5 +185,11 @@ module Torch
185
185
  def stft(*args)
186
186
  Torch.stft(*args)
187
187
  end
188
+
189
+ def dup
190
+ Torch.no_grad do
191
+ clone
192
+ end
193
+ end
188
194
  end
189
195
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.8.2"
2
+ VERSION = "0.8.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.2
4
+ version: 0.8.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-10-04 00:00:00.000000000 Z
11
+ date: 2021-10-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice