torch-rb 0.9.2 → 0.10.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -1
- data/codegen/function.rb +2 -2
- data/codegen/generate_functions.rb +5 -1
- data/codegen/native_functions.yaml +951 -362
- data/ext/torch/sparse_functions.h +6 -0
- data/lib/torch/nn/parameter_list.rb +48 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +5 -3
@@ -0,0 +1,48 @@
|
|
1
|
+
module Torch
  module NN
    # Holds an ordered list of parameters and registers each one with the
    # enclosing module under its stringified position ("0", "1", ...).
    #
    # NOTE(review): relies on the superclass (Torch::NN::Module) to provide
    # `register_parameter` and a Hash-like `@parameters` keyed by name --
    # neither is defined in this file; confirm against Module.
    class ParameterList < Module
      include Enumerable

      # @param parameters [Enumerable, nil] initial parameters to register;
      #   nil (the default) creates an empty list
      # @raise [TypeError] if +parameters+ is neither nil nor an Enumerable
      def initialize(parameters = nil)
        super()
        @initialized = true
        concat(parameters) unless parameters.nil?
      end

      # @return [Integer] number of registered parameters
      def length
        @parameters.length
      end
      alias_method :size, :length

      # With no arguments, returns the number of parameters. With an item or
      # a block, defers to Enumerable#count so that the mixin's contract is
      # not shadowed by a zero-arity alias.
      #
      # @return [Integer]
      def count(*args, &block)
        return length if args.empty? && block.nil?

        super
      end

      # Appends every element of +parameters+, registering each under the
      # next free positional name.
      #
      # @param parameters [Enumerable] parameters to append
      # @return [self]
      # @raise [TypeError] if +parameters+ is not an Enumerable
      def concat(parameters)
        unless parameters.is_a?(Enumerable)
          raise TypeError, "ParameterList#concat should be called with an enumerable, but got #{parameters.class.name}"
        end

        # Capture the starting position once; length grows as we register.
        offset = length
        parameters.each_with_index do |param, i|
          register_parameter((offset + i).to_s, param)
        end
        self
      end

      # Yields each parameter in registration order; returns an Enumerator
      # when no block is given.
      def each(&block)
        return to_enum(:each) unless block_given?

        @parameters.values.each(&block)
      end

      # Element reference.
      #
      # @param idx [Range, Integer, #to_s] a Range returns a new list built
      #   from that slice; a negative Integer counts back from the end;
      #   anything else is looked up by its string form
      # @return [ParameterList, Object, nil] nil when no parameter matches
      def [](idx)
        return self.class.new(@parameters.values[idx]) if idx.is_a?(Range)

        # Names are non-negative positions, so map -1 to the last entry, etc.
        idx += length if idx.is_a?(Integer) && idx < 0
        @parameters[idx.to_s]
      end
    end
  end
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.9.2
|
4
|
+
version: 0.10.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-03-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -52,6 +52,7 @@ files:
|
|
52
52
|
- ext/torch/random.cpp
|
53
53
|
- ext/torch/ruby_arg_parser.cpp
|
54
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/sparse_functions.h
|
55
56
|
- ext/torch/special.cpp
|
56
57
|
- ext/torch/special_functions.h
|
57
58
|
- ext/torch/templates.h
|
@@ -149,6 +150,7 @@ files:
|
|
149
150
|
- lib/torch/nn/nll_loss.rb
|
150
151
|
- lib/torch/nn/pairwise_distance.rb
|
151
152
|
- lib/torch/nn/parameter.rb
|
153
|
+
- lib/torch/nn/parameter_list.rb
|
152
154
|
- lib/torch/nn/poisson_nll_loss.rb
|
153
155
|
- lib/torch/nn/prelu.rb
|
154
156
|
- lib/torch/nn/reflection_pad1d.rb
|
@@ -227,7 +229,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
229
|
- !ruby/object:Gem::Version
|
228
230
|
version: '0'
|
229
231
|
requirements: []
|
230
|
-
rubygems_version: 3.3.
|
232
|
+
rubygems_version: 3.3.7
|
231
233
|
signing_key:
|
232
234
|
specification_version: 4
|
233
235
|
summary: Deep learning for Ruby, powered by LibTorch
|