torch-rb 0.12.0 → 0.12.2
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/torch/ext.cpp +3 -0
- data/ext/torch/generator.cpp +50 -0
- data/ext/torch/ruby_arg_parser.cpp +2 -1
- data/ext/torch/ruby_arg_parser.h +1 -1
- data/ext/torch/utils.h +5 -0
- data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb +37 -0
- data/lib/torch/utils/data/data_pipes/iter/file_lister.rb +74 -0
- data/lib/torch/utils/data/data_pipes/iter/file_opener.rb +51 -0
- data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb +41 -0
- data/lib/torch/utils/data/iterable_dataset.rb +13 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +7 -0
- metadata +11 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 8160298f6299869f699201dd9217785da5ac7d2562f05b12a2df2d4537143886
+  data.tar.gz: cb902f45cb72adaaa2008e181af2af0844e926ef62f1d2ccaa67251cb70c5799
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: c5077c40cb32414ac7d230430fdcf1cd93f199a81100179a1aa26185d6216c049316bbf35d79968442fd4db5e18414991946a366e6e7e6260a904f001774faa2
+  data.tar.gz: b4f1ff33b47596953c0654b07f1916a88ba247af72c1d4693d9b737802a790256d3232c0419c3a603fdebd51f47360833265326a1f27f8242604ff05766ea690
data/CHANGELOG.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
 void init_backends(Rice::Module& m);
 void init_cuda(Rice::Module& m);
 void init_device(Rice::Module& m);
+void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
 void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
 void init_random(Rice::Module& m);
@@ -23,6 +24,7 @@ void Init_ext()
   // need to define certain classes up front to keep Rice happy
   auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
     .define_constructor(Rice::Constructor<torch::IValue>());
+  auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
   auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
   auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
     .define_constructor(Rice::Constructor<torch::TensorOptions>());
@@ -38,6 +40,7 @@ void Init_ext()
   init_backends(m);
   init_cuda(m);
   init_device(m);
+  init_generator(m, rb_cGenerator);
   init_ivalue(m, rb_cIValue);
   init_random(m);
 }
data/ext/torch/generator.cpp
ADDED
@@ -0,0 +1,50 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "utils.h"
+
+void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
+  // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
+  rb_cGenerator
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function(
+      "new",
+      []() {
+        // TODO support more devices
+        return torch::make_generator<torch::CPUGeneratorImpl>();
+      })
+    .define_method(
+      "device",
+      [](torch::Generator& self) {
+        return self.device();
+      })
+    .define_method(
+      "initial_seed",
+      [](torch::Generator& self) {
+        return self.current_seed();
+      })
+    .define_method(
+      "manual_seed",
+      [](torch::Generator& self, uint64_t seed) {
+        self.set_current_seed(seed);
+        return self;
+      })
+    .define_method(
+      "seed",
+      [](torch::Generator& self) {
+        return self.seed();
+      })
+    .define_method(
+      "state",
+      [](torch::Generator& self) {
+        return self.get_state();
+      })
+    .define_method(
+      "state=",
+      [](torch::Generator& self, const torch::Tensor& state) {
+        self.set_state(state);
+      });
+
+  THPGeneratorClass = rb_cGenerator.value();
+}
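The new extension file above binds torch::Generator as Torch::Generator. A minimal Ruby-side sketch of how those bindings could be used (hypothetical usage; it assumes torch-rb 0.12.2 is installed and that the methods behave exactly as bound above):

require "torch"

gen = Torch::Generator.new   # CPU generator; the binding only supports CPU for now
gen.manual_seed(42)          # returns the generator itself
gen.initial_seed             # => 42 (the currently set seed)
gen.device                   # => the generator's device
state = gen.state            # RNG state snapshot as a tensor
gen.state = state            # restore the snapshot later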
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -2,6 +2,7 @@
 
 #include "ruby_arg_parser.h"
 
+VALUE THPGeneratorClass = Qnil;
 VALUE THPVariableClass = Qnil;
 
 static std::unordered_map<std::string, ParameterType> type_map = {
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
       return size > 0 && FIXNUM_P(obj);
     }
     case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
-    case ParameterType::GENERATOR: return
+    case ParameterType::GENERATOR: return THPGenerator_Check(obj);
     case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
     case ParameterType::STORAGE: return false; // return isStorage(obj);
     // case ParameterType::PYOBJECT: return true;
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
 
 inline c10::optional<at::Generator> RubyArgs::generator(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
-
+  return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
 }
 
 inline at::Storage RubyArgs::storage(int i) {
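With the argument parser now able to check and convert Torch::Generator values, native functions whose signatures declare a generator parameter should be able to receive one from Ruby. A hedged sketch only: the generator: keyword on Torch.randperm is an assumption about the generated bindings, not something verified here.

# Assumes ops whose native signatures declare a generator argument
# accept a Torch::Generator keyword from Ruby.
gen = Torch::Generator.new
gen.manual_seed(0)
a = Torch.randperm(5, generator: gen)

gen.manual_seed(0)
b = Torch.randperm(5, generator: gen)

# a and b should match: both draws used the same seeded generator.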
data/ext/torch/utils.h
CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
 
 // keep THP prefix for now to make it easier to compare code
 
+extern VALUE THPGeneratorClass;
 extern VALUE THPVariableClass;
 
 inline VALUE THPUtils_internSymbol(const std::string& str) {
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
   return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
 }
 
+inline bool THPGenerator_Check(VALUE obj) {
+  return rb_obj_is_kind_of(obj, THPGeneratorClass);
+}
+
 inline bool THPVariable_Check(VALUE obj) {
   return rb_obj_is_kind_of(obj, THPVariableClass);
 }
data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
ADDED
@@ -0,0 +1,37 @@
+module Torch
+  module Utils
+    module Data
+      module DataPipes
+        class FilterIterDataPipe < IterDataPipe
+          functional_datapipe :filter
+
+          def initialize(datapipe, &block)
+            @datapipe = datapipe
+            @filter_fn = block
+          end
+
+          def each
+            @datapipe.each do |data|
+              filtered = return_if_true(data)
+              if non_empty?(filtered)
+                yield filtered
+              else
+                Iter::StreamWrapper.close_streams(data)
+              end
+            end
+          end
+
+          def return_if_true(data)
+            condition = @filter_fn.call(data)
+
+            data if condition
+          end
+
+          def non_empty?(data)
+            !data.nil?
+          end
+        end
+      end
+    end
+  end
+end
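Because FilterIterDataPipe registers itself under the functional name filter, every IterDataPipe gains a chainable filter method. A small sketch of that behaviour (the input data here is made up):

dp = Torch::Utils::Data::DataPipes::Iter::IterableWrapper.new([1, 2, 3, 4, 5])
evens = dp.filter { |x| x.even? }  # wraps dp in a FilterIterDataPipe
evens.each { |x| p x }             # prints 2 and 4; rejected items are passed to close_streams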
data/lib/torch/utils/data/data_pipes/iter/file_lister.rb
ADDED
@@ -0,0 +1,74 @@
+module Torch
+  module Utils
+    module Data
+      module DataPipes
+        module Iter
+          class FileLister < IterDataPipe
+            def initialize(
+              root = ".",
+              masks = "",
+              recursive: false,
+              abspath: false,
+              non_deterministic: false,
+              length: -1
+            )
+              super()
+              if root.is_a?(String)
+                root = [root]
+              end
+              if !root.is_a?(IterDataPipe)
+                root = IterableWrapper.new(root)
+              end
+              @datapipe = root
+              @masks = masks
+              @recursive = recursive
+              @abspath = abspath
+              @non_deterministic = non_deterministic
+              @length = length
+            end
+
+            def each(&block)
+              @datapipe.each do |path|
+                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
+              end
+            end
+
+            private
+
+            def get_file_pathnames_from_root(
+              root,
+              masks,
+              recursive: false,
+              abspath: false,
+              non_deterministic: false
+            )
+              if File.file?(root)
+                raise NotImplementedYet
+              else
+                pattern = recursive ? "**/*" : "*"
+                paths = Dir.glob(pattern, base: root)
+                paths = paths.sort if non_deterministic
+                paths.each do |f|
+                  if abspath
+                    raise NotImplementedYet
+                  end
+                  if match_masks(f, masks)
+                    yield File.join(root, f)
+                  end
+                end
+              end
+            end
+
+            def match_masks(name, masks)
+              if masks.empty?
+                return true
+              end
+
+              raise NotImplementedYet
+            end
+          end
+        end
+      end
+    end
+  end
+end
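A short sketch of listing files with the new datapipe (the "data" directory is hypothetical; masks and abspath are not implemented yet in this release):

lister = Torch::Utils::Data::DataPipes::Iter::FileLister.new("data", recursive: true)
lister.each { |path| puts path }  # prints each entry found under data/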
data/lib/torch/utils/data/data_pipes/iter/file_opener.rb
ADDED
@@ -0,0 +1,51 @@
+module Torch
+  module Utils
+    module Data
+      module DataPipes
+        module Iter
+          class FileOpener < IterDataPipe
+            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
+              super()
+              @datapipe = datapipe
+              @mode = mode
+              @encoding = encoding
+
+              if !["b", "t", "rb", "rt", "r"].include?(@mode)
+                raise ArgumentError, "Invalid mode #{mode}"
+              end
+
+              if mode.include?("b") && !encoding.nil?
+                raise ArgumentError, "binary mode doesn't take an encoding argument"
+              end
+
+              @length = length
+            end
+
+            def each(&block)
+              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
+            end
+
+            private
+
+            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
+              if !pathnames.is_a?(Enumerable)
+                pathnames = [pathnames]
+              end
+
+              if ["b", "t"].include?(mode)
+                mode = "r#{mode}"
+              end
+
+              pathnames.each do |pathname|
+                if !pathname.is_a?(String)
+                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
+                end
+                yield pathname, StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
+              end
+            end
+          end
+        end
+      end
+    end
+  end
+end
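A sketch of opening files through the datapipe; the pathnames below are hypothetical, and any Enumerable of path strings works (a FileLister datapipe is the intended upstream source):

opener = Torch::Utils::Data::DataPipes::Iter::FileOpener.new(
  ["data/a.txt", "data/b.txt"],  # hypothetical files
  mode: "r"
)
opener.each do |path, stream|
  puts "#{path}: #{stream.gets&.strip}"  # stream is a StreamWrapper around the opened File
  stream.close
end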
data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
ADDED
@@ -0,0 +1,30 @@
+module Torch
+  module Utils
+    module Data
+      module DataPipes
+        module Iter
+          class IterableWrapper < IterDataPipe
+            def initialize(iterable, deepcopy: true)
+              @iterable = iterable
+              @deepcopy = deepcopy
+            end
+
+            def each
+              source_data = @iterable
+              if @deepcopy
+                source_data = Marshal.load(Marshal.dump(@iterable))
+              end
+              source_data.each do |data|
+                yield data
+              end
+            end
+
+            def length
+              @iterable.length
+            end
+          end
+        end
+      end
+    end
+  end
+end
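A quick sketch of the wrapper; note that by default each iteration walks a Marshal-based deep copy of the wrapped object:

raw = [1, 2, 3]
wrapper = Torch::Utils::Data::DataPipes::Iter::IterableWrapper.new(raw)
wrapper.each { |x| p x }  # iterates a deep copy, so raw stays untouched
wrapper.length            # => 3

# deepcopy: false skips the Marshal round-trip and iterates raw directly
Torch::Utils::Data::DataPipes::Iter::IterableWrapper.new(raw, deepcopy: false)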
data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
ADDED
@@ -0,0 +1,30 @@
+module Torch
+  module Utils
+    module Data
+      module DataPipes
+        module Iter
+          class StreamWrapper
+            def initialize(file_obj)
+              @file_obj = file_obj
+            end
+
+            def gets(...)
+              @file_obj.gets(...)
+            end
+
+            def close
+              @file_obj.close
+            end
+
+            # TODO improve
+            def self.close_streams(cls)
+              if cls.is_a?(StreamWrapper)
+                cls.close
+              end
+            end
+          end
+        end
+      end
+    end
+  end
+end
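A brief sketch of the wrapper (the file path is hypothetical). close_streams is what FilterIterDataPipe calls on rejected items, and it is a no-op for anything that is not a StreamWrapper:

stream = Torch::Utils::Data::DataPipes::Iter::StreamWrapper.new(File.open("data/sample.txt"))
line = stream.gets  # delegates to the underlying File

wrappers = Torch::Utils::Data::DataPipes::Iter::StreamWrapper
wrappers.close_streams(stream)  # closes the wrapped File
wrappers.close_streams(line)    # no-op: line is a String, not a StreamWrapper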
data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb
ADDED
@@ -0,0 +1,41 @@
+module Torch
+  module Utils
+    module Data
+      module DataPipes
+        class IterDataPipe < IterableDataset
+          def self.functional_datapipe(name)
+            IterDataPipe.register_datapipe_as_function(name, self)
+          end
+
+          def self.functions
+            @functions ||= {}
+          end
+
+          def self.register_datapipe_as_function(function_name, cls_to_register)
+            if functions.include?(function_name)
+              raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
+            end
+
+            function = lambda do |source_dp, *args, **options, &block|
+              cls_to_register.new(source_dp, *args, **options, &block)
+            end
+            functions[function_name] = function
+
+            define_method function_name do |*args, **options, &block|
+              IterDataPipe.functions[function_name].call(self, *args, **options, &block)
+            end
+          end
+
+          def reset
+            # no-op, but subclasses can override
+          end
+
+          def each(&block)
+            reset
+            @source_datapipe.each(&block)
+          end
+        end
+      end
+    end
+  end
+end
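To illustrate the registration mechanism above, here is a hedged sketch of a custom datapipe hooking into the same functional API. MapperIterDataPipe and map_each are hypothetical names, not part of this release:

module Torch
  module Utils
    module Data
      module DataPipes
        # Hypothetical example datapipe, shown only to demonstrate functional_datapipe
        class MapperIterDataPipe < IterDataPipe
          functional_datapipe :map_each  # adds a chainable #map_each to every IterDataPipe

          def initialize(datapipe, &block)
            @datapipe = datapipe
            @map_fn = block
          end

          def each
            @datapipe.each { |data| yield @map_fn.call(data) }
          end
        end
      end
    end
  end
end

dp = Torch::Utils::Data::DataPipes::Iter::IterableWrapper.new([1, 2, 3])
dp.map_each { |x| x * 10 }.filter { |x| x > 10 }.each { |x| p x }  # prints 20 and 30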
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
 require "torch/utils/data"
 require "torch/utils/data/data_loader"
 require "torch/utils/data/dataset"
+require "torch/utils/data/iterable_dataset"
+require "torch/utils/data/data_pipes/iter_data_pipe"
+require "torch/utils/data/data_pipes/filter_iter_data_pipe"
+require "torch/utils/data/data_pipes/iter/file_lister"
+require "torch/utils/data/data_pipes/iter/file_opener"
+require "torch/utils/data/data_pipes/iter/iterable_wrapper"
+require "torch/utils/data/data_pipes/iter/stream_wrapper"
 require "torch/utils/data/subset"
 require "torch/utils/data/tensor_dataset"
 
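With the new requires in place, the datapipes compose into a small pipeline. A hedged end-to-end sketch (the "lib" directory is hypothetical; listed paths are collected into an array before being handed to FileOpener):

require "torch"

pipes = Torch::Utils::Data::DataPipes::Iter

# List candidate files and keep only Ruby sources
lister = pipes::FileLister.new("lib", recursive: true)
ruby_files = lister.filter { |path| path.end_with?(".rb") }

paths = []
ruby_files.each { |path| paths << path }

# Open each file and print its first line
pipes::FileOpener.new(paths, mode: "r").each do |path, stream|
  puts "#{path}: #{stream.gets&.strip}"
  stream.close
end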
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.12.0
+  version: 0.12.2
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date:
+date: 2023-01-31 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -44,6 +44,7 @@ files:
 - ext/torch/extconf.rb
 - ext/torch/fft.cpp
 - ext/torch/fft_functions.h
+- ext/torch/generator.cpp
 - ext/torch/ivalue.cpp
 - ext/torch/linalg.cpp
 - ext/torch/linalg_functions.h
@@ -206,7 +207,14 @@ files:
 - lib/torch/tensor.rb
 - lib/torch/utils/data.rb
 - lib/torch/utils/data/data_loader.rb
+- lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
+- lib/torch/utils/data/data_pipes/iter/file_lister.rb
+- lib/torch/utils/data/data_pipes/iter/file_opener.rb
+- lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
+- lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
+- lib/torch/utils/data/data_pipes/iter_data_pipe.rb
 - lib/torch/utils/data/dataset.rb
+- lib/torch/utils/data/iterable_dataset.rb
 - lib/torch/utils/data/subset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
@@ -229,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.
+rubygems_version: 3.4.1
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch