torch-rb 0.8.2 → 0.8.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +6 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +6 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 19fb3b7bb14f4b45424591ac7d5f65c5290b372c6e54d65f36188869d267410d
|
4
|
+
data.tar.gz: c2c8878412b46febea4af13881169a74a577b4ac1a7de0e2d567058c06695e1a
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6a11fd1965d00e43cfbccd658ee14c95f2a032253842a6d6ec8cb33e91fbeedae5594b2d641ccc9593d2bc967376cf4538d0d12d9ded7bbd1e816945a30d1b28
|
7
|
+
data.tar.gz: 5a6febd08b15c544fd4dfef13b56efa2aab8a80c7e171fdf434f9fa692e13381654fbe3c8ea021ac661a71d310bebe942a01161f910c2931bc8ee336fe237afe
|
data/CHANGELOG.md
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -280,6 +280,11 @@ module Torch
|
|
280
280
|
end
|
281
281
|
end
|
282
282
|
|
283
|
+
# Returns a deep copy of this object by walking it with #dup_value.
# A fresh, empty memo is used for every call, so objects referenced more
# than once inside the original map to a single shared copy.
#
# @return [Object] whatever #dup_value produces for +self+
def deep_dup
  dup_value(self, {})
end
|
287
|
+
|
283
288
|
def method_missing(method, *args, &block)
|
284
289
|
name = method.to_s
|
285
290
|
if named_parameters.key?(name)
|
@@ -388,6 +393,29 @@ module Torch
|
|
388
393
|
destination[prefix + k] = v
|
389
394
|
end
|
390
395
|
end
|
396
|
+
|
397
|
+
# Recursively copies +v+ on behalf of #deep_dup, keeping a memo hash like
# Python's copy.deepcopy: https://docs.python.org/3/library/copy.html
#
# +memo+ maps the object_id of every original already visited to its copy.
# As a result, an object referenced from several places is copied once, and
# the copies keep sharing it. Containers register their (still empty) copy in
# +memo+ BEFORE recursing into their contents, so cyclic references resolve to
# that copy instead of recursing forever. Examples are an Array/Hash containing
# itself, or a module whose instance variable points back to it.
#
# - Method / UnboundMethod: returned as-is (not copied).
# - Hash: new plain Hash with every key and value copied.
# - Array: new plain Array with every element copied.
# - Torch::NN::Module: +dup+ of the module, then every instance variable
#   replaced by a copy of its value.
# - anything else: +v.dup+.
#
# @param v [Object] value to copy
# @param memo [Hash{Integer => Object}] object_id of original => its copy
# @return [Object] the copy of +v+
def dup_value(v, memo)
  key = v.object_id
  return memo[key] if memo.key?(key)

  case v
  when Method, UnboundMethod
    memo[key] = v
  when Hash
    copy = memo[key] = {}
    v.each { |k, v2| copy[dup_value(k, memo)] = dup_value(v2, memo) }
    copy
  when Array
    copy = memo[key] = []
    v.each { |v2| copy << dup_value(v2, memo) }
    copy
  when Torch::NN::Module
    copy = memo[key] = v.dup
    v.instance_variables.each do |var|
      copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
    end
    copy
  else
    memo[key] = v.dup
  end
end
|
391
419
|
end
|
392
420
|
end
|
393
421
|
end
|
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/nn/utils.rb
CHANGED
@@ -22,11 +22,7 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
# Builds a ModuleList holding +n+ copies of +mod+, each one produced by a
# separate call to mod.deep_dup (the original +mod+ itself is not included).
#
# @param mod [Object] module to replicate; must respond to #deep_dup
# @param n [Integer] number of copies
# @return [ModuleList] list of the copies, in creation order
def _clones(mod, n)
  copies = n.times.map { |_i| mod.deep_dup }
  ModuleList.new(copies)
end
|
31
27
|
|
32
28
|
def _activation_fn(activation)
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.2
|
4
|
+
version: 0.8.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-10-
|
11
|
+
date: 2021-10-17 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|