torch-rb 0.12.0 → 0.12.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/torch/ext.cpp +3 -0
- data/ext/torch/generator.cpp +50 -0
- data/ext/torch/ruby_arg_parser.cpp +2 -1
- data/ext/torch/ruby_arg_parser.h +1 -1
- data/ext/torch/utils.h +5 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
|
4
|
+
data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
|
7
|
+
data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
|
data/CHANGELOG.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
|
|
12
12
|
void init_backends(Rice::Module& m);
|
13
13
|
void init_cuda(Rice::Module& m);
|
14
14
|
void init_device(Rice::Module& m);
|
15
|
+
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
|
15
16
|
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
16
17
|
void init_random(Rice::Module& m);
|
17
18
|
|
@@ -23,6 +24,7 @@ void Init_ext()
|
|
23
24
|
// need to define certain classes up front to keep Rice happy
|
24
25
|
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
25
26
|
.define_constructor(Rice::Constructor<torch::IValue>());
|
27
|
+
auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
|
26
28
|
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
27
29
|
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
28
30
|
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
@@ -38,6 +40,7 @@ void Init_ext()
|
|
38
40
|
init_backends(m);
|
39
41
|
init_cuda(m);
|
40
42
|
init_device(m);
|
43
|
+
init_generator(m, rb_cGenerator);
|
41
44
|
init_ivalue(m, rb_cIValue);
|
42
45
|
init_random(m);
|
43
46
|
}
|
@@ -0,0 +1,50 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Registers the Ruby bindings for torch::Generator on the class object that
// Init_ext created up front (rb_cGenerator), then publishes that class object
// through THPGeneratorClass so the argument parser's THPGenerator_Check can
// recognize generator arguments.
// Bindings follow https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
//
// m is not used here; the (module, class) parameter shape matches init_ivalue.
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
  // torch::Error thrown by any binding below is routed through handle_error (utils.h).
  rb_cGenerator.add_handler<torch::Error>(handle_error);

  // Generator.new: always builds a CPU generator.
  rb_cGenerator.define_singleton_function(
    "new",
    []() -> torch::Generator {
      // TODO support more devices
      return torch::make_generator<torch::CPUGeneratorImpl>();
    });

  rb_cGenerator.define_method(
    "device",
    [](torch::Generator& self) -> torch::Device {
      return self.device();
    });

  // initial_seed is answered with the generator's current seed.
  rb_cGenerator.define_method(
    "initial_seed",
    [](torch::Generator& self) -> uint64_t {
      return self.current_seed();
    });

  // Sets the seed in place and hands the generator back to the caller.
  rb_cGenerator.define_method(
    "manual_seed",
    [](torch::Generator& self, uint64_t seed) -> torch::Generator {
      self.set_current_seed(seed);
      return self;
    });

  // Re-seeds via the generator's own seed() and returns the seed it chose.
  rb_cGenerator.define_method(
    "seed",
    [](torch::Generator& self) -> uint64_t {
      return self.seed();
    });

  // state / state= expose the RNG state as a tensor.
  rb_cGenerator.define_method(
    "state",
    [](torch::Generator& self) -> torch::Tensor {
      return self.get_state();
    });

  rb_cGenerator.define_method(
    "state=",
    [](torch::Generator& self, const torch::Tensor& state) -> void {
      self.set_state(state);
    });

  // Make the class visible to THPGenerator_Check.
  THPGeneratorClass = rb_cGenerator.value();
}
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
#include "ruby_arg_parser.h"
|
4
4
|
|
5
|
+
// Ruby class objects consulted by the THP*_Check helpers. Both start out as nil.
// THPGeneratorClass is assigned in init_generator; THPVariableClass is presumably
// assigned when the Tensor class is set up (not shown here — verify).
VALUE THPGeneratorClass{Qnil};
VALUE THPVariableClass{Qnil};
|
6
7
|
|
7
8
|
static std::unordered_map<std::string, ParameterType> type_map = {
|
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
244
245
|
return size > 0 && FIXNUM_P(obj);
|
245
246
|
}
|
246
247
|
case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
|
247
|
-
case ParameterType::GENERATOR: return
|
248
|
+
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
|
248
249
|
case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
|
249
250
|
case ParameterType::STORAGE: return false; // return isStorage(obj);
|
250
251
|
// case ParameterType::PYOBJECT: return true;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
|
|
223
223
|
|
224
224
|
// Returns argument i as a generator, or c10::nullopt when the Ruby value is nil.
// Type validation presumably already happened in FunctionParameter::check
// (ParameterType::GENERATOR -> THPGenerator_Check); From_Ruby does the unwrap.
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
  const auto value = args[i];
  if (!NIL_P(value)) {
    return Rice::detail::From_Ruby<torch::Generator>().convert(value);
  }
  return c10::nullopt;
}
|
228
228
|
|
229
229
|
inline at::Storage RubyArgs::storage(int i) {
|
data/ext/torch/utils.h
CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
|
|
17
17
|
|
18
18
|
// keep THP prefix for now to make it easier to compare code
|
19
19
|
|
20
|
+
extern VALUE THPGeneratorClass;
|
20
21
|
extern VALUE THPVariableClass;
|
21
22
|
|
22
23
|
inline VALUE THPUtils_internSymbol(const std::string& str) {
|
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
44
45
|
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
45
46
|
}
|
46
47
|
|
48
|
+
// True when obj is an instance of the registered Generator class (or a subclass).
inline bool THPGenerator_Check(VALUE obj) {
  const VALUE kind_of = rb_obj_is_kind_of(obj, THPGeneratorClass);
  // rb_obj_is_kind_of yields Qtrue or Qfalse; Qfalse is 0 in CRuby, so only it casts to false.
  return static_cast<bool>(kind_of);
}
|
51
|
+
|
47
52
|
// True when obj is an instance of the class stored in THPVariableClass
// (presumably the Tensor class — confirm where it is assigned).
inline bool THPVariable_Check(VALUE obj) {
  const VALUE kind_of = rb_obj_is_kind_of(obj, THPVariableClass);
  // rb_obj_is_kind_of yields Qtrue or Qfalse; Qfalse is 0 in CRuby, so only it casts to false.
  return static_cast<bool>(kind_of);
}
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.12.0
|
4
|
+
version: 0.12.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-01-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -44,6 +44,7 @@ files:
|
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
46
|
- ext/torch/fft_functions.h
|
47
|
+
- ext/torch/generator.cpp
|
47
48
|
- ext/torch/ivalue.cpp
|
48
49
|
- ext/torch/linalg.cpp
|
49
50
|
- ext/torch/linalg_functions.h
|
@@ -229,7 +230,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
229
230
|
- !ruby/object:Gem::Version
|
230
231
|
version: '0'
|
231
232
|
requirements: []
|
232
|
-
rubygems_version: 3.
|
233
|
+
rubygems_version: 3.4.1
|
233
234
|
signing_key:
|
234
235
|
specification_version: 4
|
235
236
|
summary: Deep learning for Ruby, powered by LibTorch
|