torch-rb 0.2.5 → 0.2.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +5 -0
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/module.rb +1 -1
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +14 -4
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
|
4
|
+
data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
|
7
|
+
data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -329,6 +329,11 @@ void Init_ext()
|
|
329
329
|
.define_method("numel", &torch::Tensor::numel)
|
330
330
|
.define_method("element_size", &torch::Tensor::element_size)
|
331
331
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
332
|
+
.define_method(
|
333
|
+
"contiguous?",
|
334
|
+
*[](Tensor& self) {
|
335
|
+
return self.is_contiguous();
|
336
|
+
})
|
332
337
|
.define_method(
|
333
338
|
"addcmul!",
|
334
339
|
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
@@ -18,7 +18,7 @@ module Torch
|
|
18
18
|
functions = functions()
|
19
19
|
|
20
20
|
# skip functions
|
21
|
-
skip_args = ["bool[3]", "Dimname", "
|
21
|
+
skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
|
22
22
|
|
23
23
|
# remove functions
|
24
24
|
functions.reject! do |f|
|
@@ -31,7 +31,7 @@ module Torch
|
|
31
31
|
todo_functions, functions =
|
32
32
|
functions.partition do |f|
|
33
33
|
f.args.any? do |a|
|
34
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
|
34
|
+
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
36
|
# native_functions.yaml is missing size argument for normal
|
37
37
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
|
|
112
112
|
"OptionalScalarType"
|
113
113
|
when "Tensor[]"
|
114
114
|
"TensorList"
|
115
|
+
when "Tensor?[]"
|
116
|
+
# TODO make optional
|
117
|
+
"TensorList"
|
115
118
|
when "int"
|
116
119
|
"int64_t"
|
117
120
|
when "float"
|
data/lib/torch/native/parser.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -145,7 +145,7 @@ module Torch
|
|
145
145
|
params = {}
|
146
146
|
if recurse
|
147
147
|
named_children.each do |name, mod|
|
148
|
-
params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
|
148
|
+
params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
|
149
149
|
end
|
150
150
|
end
|
151
151
|
instance_variables.each do |name|
|
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -158,7 +158,8 @@ module Torch
|
|
158
158
|
item <=> other
|
159
159
|
end
|
160
160
|
|
161
|
-
# based on python_variable_indexing.cpp
|
161
|
+
# based on python_variable_indexing.cpp and
|
162
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
162
163
|
def [](*indexes)
|
163
164
|
result = self
|
164
165
|
dim = 0
|
@@ -170,6 +171,8 @@ module Torch
|
|
170
171
|
finish += 1 unless index.exclude_end?
|
171
172
|
result = result._slice_tensor(dim, index.begin, finish, 1)
|
172
173
|
dim += 1
|
174
|
+
elsif index.is_a?(Tensor)
|
175
|
+
result = result.index([index])
|
173
176
|
elsif index.nil?
|
174
177
|
result = result.unsqueeze(dim)
|
175
178
|
dim += 1
|
@@ -183,12 +186,12 @@ module Torch
|
|
183
186
|
result
|
184
187
|
end
|
185
188
|
|
186
|
-
#
|
187
|
-
#
|
189
|
+
# based on python_variable_indexing.cpp and
|
190
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
188
191
|
def []=(index, value)
|
189
192
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
190
193
|
|
191
|
-
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
194
|
+
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
192
195
|
|
193
196
|
if index.is_a?(Numeric)
|
194
197
|
copy_to(_select_int(0, index), value)
|
@@ -196,6 +199,8 @@ module Torch
|
|
196
199
|
finish = index.end
|
197
200
|
finish += 1 unless index.exclude_end?
|
198
201
|
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
202
|
+
elsif index.is_a?(Tensor)
|
203
|
+
index_put!([index], value)
|
199
204
|
else
|
200
205
|
raise Error, "Unsupported index type: #{index.class.name}"
|
201
206
|
end
|
@@ -224,6 +229,11 @@ module Torch
|
|
224
229
|
end
|
225
230
|
end
|
226
231
|
|
232
|
+
def clamp!(min, max)
|
233
|
+
_clamp_min_(min)
|
234
|
+
_clamp_max_(max)
|
235
|
+
end
|
236
|
+
|
227
237
|
private
|
228
238
|
|
229
239
|
def copy_to(dst, src)
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-06-
|
11
|
+
date: 2020-06-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|