torch-rb 0.8.2 → 0.8.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +6 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +6 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 19fb3b7bb14f4b45424591ac7d5f65c5290b372c6e54d65f36188869d267410d
|
4
|
+
data.tar.gz: c2c8878412b46febea4af13881169a74a577b4ac1a7de0e2d567058c06695e1a
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6a11fd1965d00e43cfbccd658ee14c95f2a032253842a6d6ec8cb33e91fbeedae5594b2d641ccc9593d2bc967376cf4538d0d12d9ded7bbd1e816945a30d1b28
|
7
|
+
data.tar.gz: 5a6febd08b15c544fd4dfef13b56efa2aab8a80c7e171fdf434f9fa692e13381654fbe3c8ea021ac661a71d310bebe942a01161f910c2931bc8ee336fe237afe
|
data/CHANGELOG.md
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -280,6 +280,11 @@ module Torch
|
|
280
280
|
end
|
281
281
|
end
|
282
282
|
|
283
|
+
def deep_dup
|
284
|
+
memo = {}
|
285
|
+
dup_value(self, memo)
|
286
|
+
end
|
287
|
+
|
283
288
|
def method_missing(method, *args, &block)
|
284
289
|
name = method.to_s
|
285
290
|
if named_parameters.key?(name)
|
@@ -388,6 +393,29 @@ module Torch
|
|
388
393
|
destination[prefix + k] = v
|
389
394
|
end
|
390
395
|
end
|
396
|
+
|
397
|
+
# keep memo hash like Python deepcopy
|
398
|
+
# https://docs.python.org/3/library/copy.html
|
399
|
+
def dup_value(v, memo)
|
400
|
+
memo[v.object_id] ||= begin
|
401
|
+
case v
|
402
|
+
when Method, UnboundMethod
|
403
|
+
v
|
404
|
+
when Hash
|
405
|
+
v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
|
406
|
+
when Array
|
407
|
+
v.map { |v2| dup_value(v2, memo) }
|
408
|
+
when Torch::NN::Module
|
409
|
+
copy = v.dup
|
410
|
+
v.instance_variables.each do |var|
|
411
|
+
copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
|
412
|
+
end
|
413
|
+
copy
|
414
|
+
else
|
415
|
+
v.dup
|
416
|
+
end
|
417
|
+
end
|
418
|
+
end
|
391
419
|
end
|
392
420
|
end
|
393
421
|
end
|
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/nn/utils.rb
CHANGED
@@ -22,11 +22,7 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
def _clones(mod, n)
|
25
|
-
|
26
|
-
layers = n.times.map do |i|
|
27
|
-
mod.clone.tap { |l| l.load_state_dict(state) }
|
28
|
-
end
|
29
|
-
ModuleList.new(layers)
|
25
|
+
ModuleList.new(n.times.map { mod.deep_dup })
|
30
26
|
end
|
31
27
|
|
32
28
|
def _activation_fn(activation)
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.2
|
4
|
+
version: 0.8.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-10-
|
11
|
+
date: 2021-10-17 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|