torch-rb 0.2.1 → 0.2.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +23 -3
- data/lib/torch/tensor.rb +3 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +8 -2
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8e1f9758c937519ca31d92f3acd35ce0372f8cf57362cd3b50bd45920e7a6763
|
4
|
+
data.tar.gz: 4d370857ee758694b974da0f5d0973a687181ef2fedcad5c583f446ceb67dda2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: b1cbb37019852bfdbfc45b28ac32924b4de313ce21112ffe8bb5ec91fe17898d3a1ceb42d77540e6b7a0d656e9443002fb72a1201453507dca4915db13879167
|
7
|
+
data.tar.gz: d1a35689c3ad6a0628485633af4f6d7f613288fbf739f9e94ccfdb72c613b0d4581a21e00aef3a309771967f65086034befc519bc89ecc7884ff1dd142a8289f
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -82,6 +82,16 @@ void Init_ext()
|
|
82
82
|
*[](torch::IValue& self) {
|
83
83
|
return self.toInt();
|
84
84
|
})
|
85
|
+
.define_method(
|
86
|
+
"to_list",
|
87
|
+
*[](torch::IValue& self) {
|
88
|
+
auto list = self.toListRef();
|
89
|
+
Array obj;
|
90
|
+
for (auto& elem : list) {
|
91
|
+
obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
|
92
|
+
}
|
93
|
+
return obj;
|
94
|
+
})
|
85
95
|
.define_method(
|
86
96
|
"to_string_ref",
|
87
97
|
*[](torch::IValue& self) {
|
@@ -96,17 +106,27 @@ void Init_ext()
|
|
96
106
|
"to_generic_dict",
|
97
107
|
*[](torch::IValue& self) {
|
98
108
|
auto dict = self.toGenericDict();
|
99
|
-
Hash
|
109
|
+
Hash obj;
|
100
110
|
for (auto& pair : dict) {
|
101
|
-
|
111
|
+
obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
|
102
112
|
}
|
103
|
-
return
|
113
|
+
return obj;
|
104
114
|
})
|
105
115
|
.define_singleton_method(
|
106
116
|
"from_tensor",
|
107
117
|
*[](torch::Tensor& v) {
|
108
118
|
return torch::IValue(v);
|
109
119
|
})
|
120
|
+
// TODO create specialized list types?
|
121
|
+
.define_singleton_method(
|
122
|
+
"from_list",
|
123
|
+
*[](Array obj) {
|
124
|
+
c10::impl::GenericList list(c10::AnyType::get());
|
125
|
+
for (auto entry : obj) {
|
126
|
+
list.push_back(from_ruby<torch::IValue>(entry));
|
127
|
+
}
|
128
|
+
return torch::IValue(list);
|
129
|
+
})
|
110
130
|
.define_singleton_method(
|
111
131
|
"from_string",
|
112
132
|
*[](String v) {
|
data/lib/torch/tensor.rb
CHANGED
@@ -4,6 +4,8 @@ module Torch
|
|
4
4
|
include Inspector
|
5
5
|
|
6
6
|
alias_method :requires_grad?, :requires_grad
|
7
|
+
alias_method :ndim, :dim
|
8
|
+
alias_method :ndimension, :dim
|
7
9
|
|
8
10
|
def self.new(*args)
|
9
11
|
FloatTensor.new(*args)
|
@@ -143,6 +145,7 @@ module Torch
|
|
143
145
|
neg
|
144
146
|
end
|
145
147
|
|
148
|
+
# TODO better compare?
|
146
149
|
def <=>(other)
|
147
150
|
item <=> other
|
148
151
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -466,6 +466,12 @@ module Torch
|
|
466
466
|
IValue.from_bool(obj)
|
467
467
|
when nil
|
468
468
|
IValue.new
|
469
|
+
when Array
|
470
|
+
if obj.all? { |v| v.is_a?(Tensor) }
|
471
|
+
IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
|
472
|
+
else
|
473
|
+
raise Error, "Unknown list type"
|
474
|
+
end
|
469
475
|
else
|
470
476
|
raise Error, "Unknown type: #{obj.class.name}"
|
471
477
|
end
|
@@ -490,6 +496,8 @@ module Torch
|
|
490
496
|
dict[to_ruby(k)] = to_ruby(v)
|
491
497
|
end
|
492
498
|
dict
|
499
|
+
elsif ivalue.list?
|
500
|
+
ivalue.to_list.map { |v| to_ruby(v) }
|
493
501
|
else
|
494
502
|
type =
|
495
503
|
if ivalue.capsule?
|
@@ -510,8 +518,6 @@ module Torch
|
|
510
518
|
"BoolList"
|
511
519
|
elsif ivalue.tensor_list?
|
512
520
|
"TensorList"
|
513
|
-
elsif ivalue.list?
|
514
|
-
"List"
|
515
521
|
elsif ivalue.object?
|
516
522
|
"Object"
|
517
523
|
elsif ivalue.module?
|