torch-rb 0.12.0 → 0.12.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/torch/ext.cpp +3 -0
- data/ext/torch/generator.cpp +50 -0
- data/ext/torch/ruby_arg_parser.cpp +2 -1
- data/ext/torch/ruby_arg_parser.h +1 -1
- data/ext/torch/utils.h +5 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
|
4
|
+
data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
|
7
|
+
data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
|
data/CHANGELOG.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
|
|
12
12
|
void init_backends(Rice::Module& m);
|
13
13
|
void init_cuda(Rice::Module& m);
|
14
14
|
void init_device(Rice::Module& m);
|
15
|
+
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
|
15
16
|
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
16
17
|
void init_random(Rice::Module& m);
|
17
18
|
|
@@ -23,6 +24,7 @@ void Init_ext()
|
|
23
24
|
// need to define certain classes up front to keep Rice happy
|
24
25
|
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
25
26
|
.define_constructor(Rice::Constructor<torch::IValue>());
|
27
|
+
auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
|
26
28
|
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
27
29
|
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
28
30
|
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
@@ -38,6 +40,7 @@ void Init_ext()
|
|
38
40
|
init_backends(m);
|
39
41
|
init_cuda(m);
|
40
42
|
init_device(m);
|
43
|
+
init_generator(m, rb_cGenerator);
|
41
44
|
init_ivalue(m, rb_cIValue);
|
42
45
|
init_random(m);
|
43
46
|
}
|
@@ -0,0 +1,50 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
|
8
|
+
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
|
9
|
+
rb_cGenerator
|
10
|
+
.add_handler<torch::Error>(handle_error)
|
11
|
+
.define_singleton_function(
|
12
|
+
"new",
|
13
|
+
[]() {
|
14
|
+
// TODO support more devices
|
15
|
+
return torch::make_generator<torch::CPUGeneratorImpl>();
|
16
|
+
})
|
17
|
+
.define_method(
|
18
|
+
"device",
|
19
|
+
[](torch::Generator& self) {
|
20
|
+
return self.device();
|
21
|
+
})
|
22
|
+
.define_method(
|
23
|
+
"initial_seed",
|
24
|
+
[](torch::Generator& self) {
|
25
|
+
return self.current_seed();
|
26
|
+
})
|
27
|
+
.define_method(
|
28
|
+
"manual_seed",
|
29
|
+
[](torch::Generator& self, uint64_t seed) {
|
30
|
+
self.set_current_seed(seed);
|
31
|
+
return self;
|
32
|
+
})
|
33
|
+
.define_method(
|
34
|
+
"seed",
|
35
|
+
[](torch::Generator& self) {
|
36
|
+
return self.seed();
|
37
|
+
})
|
38
|
+
.define_method(
|
39
|
+
"state",
|
40
|
+
[](torch::Generator& self) {
|
41
|
+
return self.get_state();
|
42
|
+
})
|
43
|
+
.define_method(
|
44
|
+
"state=",
|
45
|
+
[](torch::Generator& self, const torch::Tensor& state) {
|
46
|
+
self.set_state(state);
|
47
|
+
});
|
48
|
+
|
49
|
+
THPGeneratorClass = rb_cGenerator.value();
|
50
|
+
}
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
#include "ruby_arg_parser.h"
|
4
4
|
|
5
|
+
VALUE THPGeneratorClass = Qnil;
|
5
6
|
VALUE THPVariableClass = Qnil;
|
6
7
|
|
7
8
|
static std::unordered_map<std::string, ParameterType> type_map = {
|
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
244
245
|
return size > 0 && FIXNUM_P(obj);
|
245
246
|
}
|
246
247
|
case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
|
247
|
-
case ParameterType::GENERATOR: return
|
248
|
+
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
|
248
249
|
case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
|
249
250
|
case ParameterType::STORAGE: return false; // return isStorage(obj);
|
250
251
|
// case ParameterType::PYOBJECT: return true;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
|
|
223
223
|
|
224
224
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
225
225
|
if (NIL_P(args[i])) return c10::nullopt;
|
226
|
-
|
226
|
+
return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
|
227
227
|
}
|
228
228
|
|
229
229
|
inline at::Storage RubyArgs::storage(int i) {
|
data/ext/torch/utils.h
CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
|
|
17
17
|
|
18
18
|
// keep THP prefix for now to make it easier to compare code
|
19
19
|
|
20
|
+
extern VALUE THPGeneratorClass;
|
20
21
|
extern VALUE THPVariableClass;
|
21
22
|
|
22
23
|
inline VALUE THPUtils_internSymbol(const std::string& str) {
|
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
44
45
|
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
45
46
|
}
|
46
47
|
|
48
|
+
inline bool THPGenerator_Check(VALUE obj) {
|
49
|
+
return rb_obj_is_kind_of(obj, THPGeneratorClass);
|
50
|
+
}
|
51
|
+
|
47
52
|
inline bool THPVariable_Check(VALUE obj) {
|
48
53
|
return rb_obj_is_kind_of(obj, THPVariableClass);
|
49
54
|
}
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.12.0
|
4
|
+
version: 0.12.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-01-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -44,6 +44,7 @@ files:
|
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
46
|
- ext/torch/fft_functions.h
|
47
|
+
- ext/torch/generator.cpp
|
47
48
|
- ext/torch/ivalue.cpp
|
48
49
|
- ext/torch/linalg.cpp
|
49
50
|
- ext/torch/linalg_functions.h
|
@@ -229,7 +230,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
229
230
|
- !ruby/object:Gem::Version
|
230
231
|
version: '0'
|
231
232
|
requirements: []
|
232
|
-
rubygems_version: 3.
|
233
|
+
rubygems_version: 3.4.1
|
233
234
|
signing_key:
|
234
235
|
specification_version: 4
|
235
236
|
summary: Deep learning for Ruby, powered by LibTorch
|