torch-rb 0.12.0 → 0.12.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/torch/ext.cpp +3 -0
- data/ext/torch/generator.cpp +50 -0
- data/ext/torch/ruby_arg_parser.cpp +2 -1
- data/ext/torch/ruby_arg_parser.h +1 -1
- data/ext/torch/utils.h +5 -0
- data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb +37 -0
- data/lib/torch/utils/data/data_pipes/iter/file_lister.rb +74 -0
- data/lib/torch/utils/data/data_pipes/iter/file_opener.rb +51 -0
- data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb +41 -0
- data/lib/torch/utils/data/iterable_dataset.rb +13 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +7 -0
- metadata +11 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8160298f6299869f699201dd9217785da5ac7d2562f05b12a2df2d4537143886
|
4
|
+
data.tar.gz: cb902f45cb72adaaa2008e181af2af0844e926ef62f1d2ccaa67251cb70c5799
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c5077c40cb32414ac7d230430fdcf1cd93f199a81100179a1aa26185d6216c049316bbf35d79968442fd4db5e18414991946a366e6e7e6260a904f001774faa2
|
7
|
+
data.tar.gz: b4f1ff33b47596953c0654b07f1916a88ba247af72c1d4693d9b737802a790256d3232c0419c3a603fdebd51f47360833265326a1f27f8242604ff05766ea690
|
data/CHANGELOG.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
|
|
12
12
|
void init_backends(Rice::Module& m);
|
13
13
|
void init_cuda(Rice::Module& m);
|
14
14
|
void init_device(Rice::Module& m);
|
15
|
+
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
|
15
16
|
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
16
17
|
void init_random(Rice::Module& m);
|
17
18
|
|
@@ -23,6 +24,7 @@ void Init_ext()
|
|
23
24
|
// need to define certain classes up front to keep Rice happy
|
24
25
|
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
25
26
|
.define_constructor(Rice::Constructor<torch::IValue>());
|
27
|
+
auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
|
26
28
|
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
27
29
|
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
28
30
|
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
@@ -38,6 +40,7 @@ void Init_ext()
|
|
38
40
|
init_backends(m);
|
39
41
|
init_cuda(m);
|
40
42
|
init_device(m);
|
43
|
+
init_generator(m, rb_cGenerator);
|
41
44
|
init_ivalue(m, rb_cIValue);
|
42
45
|
init_random(m);
|
43
46
|
}
|
@@ -0,0 +1,50 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
|
8
|
+
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
|
9
|
+
rb_cGenerator
|
10
|
+
.add_handler<torch::Error>(handle_error)
|
11
|
+
.define_singleton_function(
|
12
|
+
"new",
|
13
|
+
[]() {
|
14
|
+
// TODO support more devices
|
15
|
+
return torch::make_generator<torch::CPUGeneratorImpl>();
|
16
|
+
})
|
17
|
+
.define_method(
|
18
|
+
"device",
|
19
|
+
[](torch::Generator& self) {
|
20
|
+
return self.device();
|
21
|
+
})
|
22
|
+
.define_method(
|
23
|
+
"initial_seed",
|
24
|
+
[](torch::Generator& self) {
|
25
|
+
return self.current_seed();
|
26
|
+
})
|
27
|
+
.define_method(
|
28
|
+
"manual_seed",
|
29
|
+
[](torch::Generator& self, uint64_t seed) {
|
30
|
+
self.set_current_seed(seed);
|
31
|
+
return self;
|
32
|
+
})
|
33
|
+
.define_method(
|
34
|
+
"seed",
|
35
|
+
[](torch::Generator& self) {
|
36
|
+
return self.seed();
|
37
|
+
})
|
38
|
+
.define_method(
|
39
|
+
"state",
|
40
|
+
[](torch::Generator& self) {
|
41
|
+
return self.get_state();
|
42
|
+
})
|
43
|
+
.define_method(
|
44
|
+
"state=",
|
45
|
+
[](torch::Generator& self, const torch::Tensor& state) {
|
46
|
+
self.set_state(state);
|
47
|
+
});
|
48
|
+
|
49
|
+
THPGeneratorClass = rb_cGenerator.value();
|
50
|
+
}
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
#include "ruby_arg_parser.h"
|
4
4
|
|
5
|
+
VALUE THPGeneratorClass = Qnil;
|
5
6
|
VALUE THPVariableClass = Qnil;
|
6
7
|
|
7
8
|
static std::unordered_map<std::string, ParameterType> type_map = {
|
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
244
245
|
return size > 0 && FIXNUM_P(obj);
|
245
246
|
}
|
246
247
|
case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
|
247
|
-
case ParameterType::GENERATOR: return
|
248
|
+
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
|
248
249
|
case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
|
249
250
|
case ParameterType::STORAGE: return false; // return isStorage(obj);
|
250
251
|
// case ParameterType::PYOBJECT: return true;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
|
|
223
223
|
|
224
224
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
225
225
|
if (NIL_P(args[i])) return c10::nullopt;
|
226
|
-
|
226
|
+
return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
|
227
227
|
}
|
228
228
|
|
229
229
|
inline at::Storage RubyArgs::storage(int i) {
|
data/ext/torch/utils.h
CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
|
|
17
17
|
|
18
18
|
// keep THP prefix for now to make it easier to compare code
|
19
19
|
|
20
|
+
extern VALUE THPGeneratorClass;
|
20
21
|
extern VALUE THPVariableClass;
|
21
22
|
|
22
23
|
inline VALUE THPUtils_internSymbol(const std::string& str) {
|
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
44
45
|
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
45
46
|
}
|
46
47
|
|
48
|
+
inline bool THPGenerator_Check(VALUE obj) {
|
49
|
+
return rb_obj_is_kind_of(obj, THPGeneratorClass);
|
50
|
+
}
|
51
|
+
|
47
52
|
inline bool THPVariable_Check(VALUE obj) {
|
48
53
|
return rb_obj_is_kind_of(obj, THPVariableClass);
|
49
54
|
}
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
module DataPipes
|
5
|
+
class FilterIterDataPipe < IterDataPipe
|
6
|
+
functional_datapipe :filter
|
7
|
+
|
8
|
+
def initialize(datapipe, &block)
|
9
|
+
@datapipe = datapipe
|
10
|
+
@filter_fn = block
|
11
|
+
end
|
12
|
+
|
13
|
+
def each
|
14
|
+
@datapipe.each do |data|
|
15
|
+
filtered = return_if_true(data)
|
16
|
+
if non_empty?(filtered)
|
17
|
+
yield filtered
|
18
|
+
else
|
19
|
+
Iter::StreamWrapper.close_streams(data)
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
def return_if_true(data)
|
25
|
+
condition = @filter_fn.call(data)
|
26
|
+
|
27
|
+
data if condition
|
28
|
+
end
|
29
|
+
|
30
|
+
def non_empty?(data)
|
31
|
+
!data.nil?
|
32
|
+
end
|
33
|
+
end
|
34
|
+
end
|
35
|
+
end
|
36
|
+
end
|
37
|
+
end
|
@@ -0,0 +1,74 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
module DataPipes
|
5
|
+
module Iter
|
6
|
+
class FileLister < IterDataPipe
|
7
|
+
def initialize(
|
8
|
+
root = ".",
|
9
|
+
masks = "",
|
10
|
+
recursive: false,
|
11
|
+
abspath: false,
|
12
|
+
non_deterministic: false,
|
13
|
+
length: -1
|
14
|
+
)
|
15
|
+
super()
|
16
|
+
if root.is_a?(String)
|
17
|
+
root = [root]
|
18
|
+
end
|
19
|
+
if !root.is_a?(IterDataPipe)
|
20
|
+
root = IterableWrapper.new(root)
|
21
|
+
end
|
22
|
+
@datapipe = root
|
23
|
+
@masks = masks
|
24
|
+
@recursive = recursive
|
25
|
+
@abspath = abspath
|
26
|
+
@non_deterministic = non_deterministic
|
27
|
+
@length = length
|
28
|
+
end
|
29
|
+
|
30
|
+
def each(&block)
|
31
|
+
@datapipe.each do |path|
|
32
|
+
get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
|
33
|
+
end
|
34
|
+
end
|
35
|
+
|
36
|
+
private
|
37
|
+
|
38
|
+
def get_file_pathnames_from_root(
|
39
|
+
root,
|
40
|
+
masks,
|
41
|
+
recursive: false,
|
42
|
+
abspath: false,
|
43
|
+
non_deterministic: false
|
44
|
+
)
|
45
|
+
if File.file?(root)
|
46
|
+
raise NotImplementedYet
|
47
|
+
else
|
48
|
+
pattern = recursive ? "**/*" : "*"
|
49
|
+
paths = Dir.glob(pattern, base: root)
|
50
|
+
paths = paths.sort if non_deterministic
|
51
|
+
paths.each do |f|
|
52
|
+
if abspath
|
53
|
+
raise NotImplementedYet
|
54
|
+
end
|
55
|
+
if match_masks(f, masks)
|
56
|
+
yield File.join(root, f)
|
57
|
+
end
|
58
|
+
end
|
59
|
+
end
|
60
|
+
end
|
61
|
+
|
62
|
+
def match_masks(name, masks)
|
63
|
+
if masks.empty?
|
64
|
+
return true
|
65
|
+
end
|
66
|
+
|
67
|
+
raise NotImplementedYet
|
68
|
+
end
|
69
|
+
end
|
70
|
+
end
|
71
|
+
end
|
72
|
+
end
|
73
|
+
end
|
74
|
+
end
|
@@ -0,0 +1,51 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
module DataPipes
|
5
|
+
module Iter
|
6
|
+
class FileOpener < IterDataPipe
|
7
|
+
def initialize(datapipe, mode: "r", encoding: nil, length: -1)
|
8
|
+
super()
|
9
|
+
@datapipe = datapipe
|
10
|
+
@mode = mode
|
11
|
+
@encoding = encoding
|
12
|
+
|
13
|
+
if !["b", "t", "rb", "rt", "r"].include?(@mode)
|
14
|
+
raise ArgumentError, "Invalid mode #{mode}"
|
15
|
+
end
|
16
|
+
|
17
|
+
if mode.include?("b") && !encoding.nil?
|
18
|
+
raise ArgumentError, "binary mode doesn't take an encoding argument"
|
19
|
+
end
|
20
|
+
|
21
|
+
@length = length
|
22
|
+
end
|
23
|
+
|
24
|
+
def each(&block)
|
25
|
+
get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
|
26
|
+
end
|
27
|
+
|
28
|
+
private
|
29
|
+
|
30
|
+
def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
|
31
|
+
if !pathnames.is_a?(Enumerable)
|
32
|
+
pathnames = [pathnames]
|
33
|
+
end
|
34
|
+
|
35
|
+
if ["b", "t"].include?(mode)
|
36
|
+
mode = "r#{mode}"
|
37
|
+
end
|
38
|
+
|
39
|
+
pathnames.each do |pathname|
|
40
|
+
if !pathname.is_a?(String)
|
41
|
+
raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
|
42
|
+
end
|
43
|
+
yield pathname, StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
|
44
|
+
end
|
45
|
+
end
|
46
|
+
end
|
47
|
+
end
|
48
|
+
end
|
49
|
+
end
|
50
|
+
end
|
51
|
+
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
module DataPipes
|
5
|
+
module Iter
|
6
|
+
class IterableWrapper < IterDataPipe
|
7
|
+
def initialize(iterable, deepcopy: true)
|
8
|
+
@iterable = iterable
|
9
|
+
@deepcopy = deepcopy
|
10
|
+
end
|
11
|
+
|
12
|
+
def each
|
13
|
+
source_data = @iterable
|
14
|
+
if @deepcopy
|
15
|
+
source_data = Marshal.load(Marshal.dump(@iterable))
|
16
|
+
end
|
17
|
+
source_data.each do |data|
|
18
|
+
yield data
|
19
|
+
end
|
20
|
+
end
|
21
|
+
|
22
|
+
def length
|
23
|
+
@iterable.length
|
24
|
+
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
module DataPipes
|
5
|
+
module Iter
|
6
|
+
class StreamWrapper
|
7
|
+
def initialize(file_obj)
|
8
|
+
@file_obj = file_obj
|
9
|
+
end
|
10
|
+
|
11
|
+
def gets(...)
|
12
|
+
@file_obj.gets(...)
|
13
|
+
end
|
14
|
+
|
15
|
+
def close
|
16
|
+
@file_obj.close
|
17
|
+
end
|
18
|
+
|
19
|
+
# TODO improve
|
20
|
+
def self.close_streams(cls)
|
21
|
+
if cls.is_a?(StreamWrapper)
|
22
|
+
cls.close
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
module DataPipes
|
5
|
+
class IterDataPipe < IterableDataset
|
6
|
+
def self.functional_datapipe(name)
|
7
|
+
IterDataPipe.register_datapipe_as_function(name, self)
|
8
|
+
end
|
9
|
+
|
10
|
+
def self.functions
|
11
|
+
@functions ||= {}
|
12
|
+
end
|
13
|
+
|
14
|
+
def self.register_datapipe_as_function(function_name, cls_to_register)
|
15
|
+
if functions.include?(function_name)
|
16
|
+
raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
|
17
|
+
end
|
18
|
+
|
19
|
+
function = lambda do |source_dp, *args, **options, &block|
|
20
|
+
cls_to_register.new(source_dp, *args, **options, &block)
|
21
|
+
end
|
22
|
+
functions[function_name] = function
|
23
|
+
|
24
|
+
define_method function_name do |*args, **options, &block|
|
25
|
+
IterDataPipe.functions[function_name].call(self, *args, **options, &block)
|
26
|
+
end
|
27
|
+
end
|
28
|
+
|
29
|
+
def reset
|
30
|
+
# no-op, but subclasses can override
|
31
|
+
end
|
32
|
+
|
33
|
+
def each(&block)
|
34
|
+
reset
|
35
|
+
@source_datapipe.each(&block)
|
36
|
+
end
|
37
|
+
end
|
38
|
+
end
|
39
|
+
end
|
40
|
+
end
|
41
|
+
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
|
|
187
187
|
require "torch/utils/data"
|
188
188
|
require "torch/utils/data/data_loader"
|
189
189
|
require "torch/utils/data/dataset"
|
190
|
+
require "torch/utils/data/iterable_dataset"
|
191
|
+
require "torch/utils/data/data_pipes/iter_data_pipe"
|
192
|
+
require "torch/utils/data/data_pipes/filter_iter_data_pipe"
|
193
|
+
require "torch/utils/data/data_pipes/iter/file_lister"
|
194
|
+
require "torch/utils/data/data_pipes/iter/file_opener"
|
195
|
+
require "torch/utils/data/data_pipes/iter/iterable_wrapper"
|
196
|
+
require "torch/utils/data/data_pipes/iter/stream_wrapper"
|
190
197
|
require "torch/utils/data/subset"
|
191
198
|
require "torch/utils/data/tensor_dataset"
|
192
199
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.12.
|
4
|
+
version: 0.12.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-01-31 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -44,6 +44,7 @@ files:
|
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
46
|
- ext/torch/fft_functions.h
|
47
|
+
- ext/torch/generator.cpp
|
47
48
|
- ext/torch/ivalue.cpp
|
48
49
|
- ext/torch/linalg.cpp
|
49
50
|
- ext/torch/linalg_functions.h
|
@@ -206,7 +207,14 @@ files:
|
|
206
207
|
- lib/torch/tensor.rb
|
207
208
|
- lib/torch/utils/data.rb
|
208
209
|
- lib/torch/utils/data/data_loader.rb
|
210
|
+
- lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
|
211
|
+
- lib/torch/utils/data/data_pipes/iter/file_lister.rb
|
212
|
+
- lib/torch/utils/data/data_pipes/iter/file_opener.rb
|
213
|
+
- lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
|
214
|
+
- lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
|
215
|
+
- lib/torch/utils/data/data_pipes/iter_data_pipe.rb
|
209
216
|
- lib/torch/utils/data/dataset.rb
|
217
|
+
- lib/torch/utils/data/iterable_dataset.rb
|
210
218
|
- lib/torch/utils/data/subset.rb
|
211
219
|
- lib/torch/utils/data/tensor_dataset.rb
|
212
220
|
- lib/torch/version.rb
|
@@ -229,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
229
237
|
- !ruby/object:Gem::Version
|
230
238
|
version: '0'
|
231
239
|
requirements: []
|
232
|
-
rubygems_version: 3.
|
240
|
+
rubygems_version: 3.4.1
|
233
241
|
signing_key:
|
234
242
|
specification_version: 4
|
235
243
|
summary: Deep learning for Ruby, powered by LibTorch
|