torch-rb 0.2.1 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +23 -3
- data/lib/torch/tensor.rb +3 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +8 -2
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8e1f9758c937519ca31d92f3acd35ce0372f8cf57362cd3b50bd45920e7a6763
|
4
|
+
data.tar.gz: 4d370857ee758694b974da0f5d0973a687181ef2fedcad5c583f446ceb67dda2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: b1cbb37019852bfdbfc45b28ac32924b4de313ce21112ffe8bb5ec91fe17898d3a1ceb42d77540e6b7a0d656e9443002fb72a1201453507dca4915db13879167
|
7
|
+
data.tar.gz: d1a35689c3ad6a0628485633af4f6d7f613288fbf739f9e94ccfdb72c613b0d4581a21e00aef3a309771967f65086034befc519bc89ecc7884ff1dd142a8289f
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -82,6 +82,16 @@ void Init_ext()
|
|
82
82
|
*[](torch::IValue& self) {
|
83
83
|
return self.toInt();
|
84
84
|
})
|
85
|
+
.define_method(
|
86
|
+
"to_list",
|
87
|
+
*[](torch::IValue& self) {
|
88
|
+
auto list = self.toListRef();
|
89
|
+
Array obj;
|
90
|
+
for (auto& elem : list) {
|
91
|
+
obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
|
92
|
+
}
|
93
|
+
return obj;
|
94
|
+
})
|
85
95
|
.define_method(
|
86
96
|
"to_string_ref",
|
87
97
|
*[](torch::IValue& self) {
|
@@ -96,17 +106,27 @@ void Init_ext()
|
|
96
106
|
"to_generic_dict",
|
97
107
|
*[](torch::IValue& self) {
|
98
108
|
auto dict = self.toGenericDict();
|
99
|
-
Hash
|
109
|
+
Hash obj;
|
100
110
|
for (auto& pair : dict) {
|
101
|
-
|
111
|
+
obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
|
102
112
|
}
|
103
|
-
return
|
113
|
+
return obj;
|
104
114
|
})
|
105
115
|
.define_singleton_method(
|
106
116
|
"from_tensor",
|
107
117
|
*[](torch::Tensor& v) {
|
108
118
|
return torch::IValue(v);
|
109
119
|
})
|
120
|
+
// TODO create specialized list types?
|
121
|
+
.define_singleton_method(
|
122
|
+
"from_list",
|
123
|
+
*[](Array obj) {
|
124
|
+
c10::impl::GenericList list(c10::AnyType::get());
|
125
|
+
for (auto entry : obj) {
|
126
|
+
list.push_back(from_ruby<torch::IValue>(entry));
|
127
|
+
}
|
128
|
+
return torch::IValue(list);
|
129
|
+
})
|
110
130
|
.define_singleton_method(
|
111
131
|
"from_string",
|
112
132
|
*[](String v) {
|
data/lib/torch/tensor.rb
CHANGED
@@ -4,6 +4,8 @@ module Torch
|
|
4
4
|
include Inspector
|
5
5
|
|
6
6
|
alias_method :requires_grad?, :requires_grad
|
7
|
+
alias_method :ndim, :dim
|
8
|
+
alias_method :ndimension, :dim
|
7
9
|
|
8
10
|
def self.new(*args)
|
9
11
|
FloatTensor.new(*args)
|
@@ -143,6 +145,7 @@ module Torch
|
|
143
145
|
neg
|
144
146
|
end
|
145
147
|
|
148
|
+
# TODO better compare?
|
146
149
|
def <=>(other)
|
147
150
|
item <=> other
|
148
151
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -466,6 +466,12 @@ module Torch
|
|
466
466
|
IValue.from_bool(obj)
|
467
467
|
when nil
|
468
468
|
IValue.new
|
469
|
+
when Array
|
470
|
+
if obj.all? { |v| v.is_a?(Tensor) }
|
471
|
+
IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
|
472
|
+
else
|
473
|
+
raise Error, "Unknown list type"
|
474
|
+
end
|
469
475
|
else
|
470
476
|
raise Error, "Unknown type: #{obj.class.name}"
|
471
477
|
end
|
@@ -490,6 +496,8 @@ module Torch
|
|
490
496
|
dict[to_ruby(k)] = to_ruby(v)
|
491
497
|
end
|
492
498
|
dict
|
499
|
+
elsif ivalue.list?
|
500
|
+
ivalue.to_list.map { |v| to_ruby(v) }
|
493
501
|
else
|
494
502
|
type =
|
495
503
|
if ivalue.capsule?
|
@@ -510,8 +518,6 @@ module Torch
|
|
510
518
|
"BoolList"
|
511
519
|
elsif ivalue.tensor_list?
|
512
520
|
"TensorList"
|
513
|
-
elsif ivalue.list?
|
514
|
-
"List"
|
515
521
|
elsif ivalue.object?
|
516
522
|
"Object"
|
517
523
|
elsif ivalue.module?
|