torch-rb 0.2.5 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +5 -0
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/module.rb +1 -1
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +14 -4
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
|
4
|
+
data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
|
7
|
+
data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -329,6 +329,11 @@ void Init_ext()
|
|
329
329
|
.define_method("numel", &torch::Tensor::numel)
|
330
330
|
.define_method("element_size", &torch::Tensor::element_size)
|
331
331
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
332
|
+
.define_method(
|
333
|
+
"contiguous?",
|
334
|
+
*[](Tensor& self) {
|
335
|
+
return self.is_contiguous();
|
336
|
+
})
|
332
337
|
.define_method(
|
333
338
|
"addcmul!",
|
334
339
|
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
@@ -18,7 +18,7 @@ module Torch
|
|
18
18
|
functions = functions()
|
19
19
|
|
20
20
|
# skip functions
|
21
|
-
skip_args = ["bool[3]", "Dimname", "
|
21
|
+
skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
|
22
22
|
|
23
23
|
# remove functions
|
24
24
|
functions.reject! do |f|
|
@@ -31,7 +31,7 @@ module Torch
|
|
31
31
|
todo_functions, functions =
|
32
32
|
functions.partition do |f|
|
33
33
|
f.args.any? do |a|
|
34
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
|
34
|
+
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
36
|
# native_functions.yaml is missing size argument for normal
|
37
37
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
|
|
112
112
|
"OptionalScalarType"
|
113
113
|
when "Tensor[]"
|
114
114
|
"TensorList"
|
115
|
+
when "Tensor?[]"
|
116
|
+
# TODO make optional
|
117
|
+
"TensorList"
|
115
118
|
when "int"
|
116
119
|
"int64_t"
|
117
120
|
when "float"
|
data/lib/torch/native/parser.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -145,7 +145,7 @@ module Torch
|
|
145
145
|
params = {}
|
146
146
|
if recurse
|
147
147
|
named_children.each do |name, mod|
|
148
|
-
params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
|
148
|
+
params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
|
149
149
|
end
|
150
150
|
end
|
151
151
|
instance_variables.each do |name|
|
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -158,7 +158,8 @@ module Torch
|
|
158
158
|
item <=> other
|
159
159
|
end
|
160
160
|
|
161
|
-
# based on python_variable_indexing.cpp
|
161
|
+
# based on python_variable_indexing.cpp and
|
162
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
162
163
|
def [](*indexes)
|
163
164
|
result = self
|
164
165
|
dim = 0
|
@@ -170,6 +171,8 @@ module Torch
|
|
170
171
|
finish += 1 unless index.exclude_end?
|
171
172
|
result = result._slice_tensor(dim, index.begin, finish, 1)
|
172
173
|
dim += 1
|
174
|
+
elsif index.is_a?(Tensor)
|
175
|
+
result = result.index([index])
|
173
176
|
elsif index.nil?
|
174
177
|
result = result.unsqueeze(dim)
|
175
178
|
dim += 1
|
@@ -183,12 +186,12 @@ module Torch
|
|
183
186
|
result
|
184
187
|
end
|
185
188
|
|
186
|
-
#
|
187
|
-
#
|
189
|
+
# based on python_variable_indexing.cpp and
|
190
|
+
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
188
191
|
def []=(index, value)
|
189
192
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
190
193
|
|
191
|
-
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
194
|
+
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
192
195
|
|
193
196
|
if index.is_a?(Numeric)
|
194
197
|
copy_to(_select_int(0, index), value)
|
@@ -196,6 +199,8 @@ module Torch
|
|
196
199
|
finish = index.end
|
197
200
|
finish += 1 unless index.exclude_end?
|
198
201
|
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
202
|
+
elsif index.is_a?(Tensor)
|
203
|
+
index_put!([index], value)
|
199
204
|
else
|
200
205
|
raise Error, "Unsupported index type: #{index.class.name}"
|
201
206
|
end
|
@@ -224,6 +229,11 @@ module Torch
|
|
224
229
|
end
|
225
230
|
end
|
226
231
|
|
232
|
+
def clamp!(min, max)
|
233
|
+
_clamp_min_(min)
|
234
|
+
_clamp_max_(max)
|
235
|
+
end
|
236
|
+
|
227
237
|
private
|
228
238
|
|
229
239
|
def copy_to(dst, src)
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-06-
|
11
|
+
date: 2020-06-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|