torch-rb 0.12.1 → 0.13.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -1
- data/codegen/generate_functions.rb +4 -2
- data/codegen/native_functions.yaml +1392 -593
- data/ext/torch/backends.cpp +4 -0
- data/ext/torch/tensor.cpp +4 -4
- data/ext/torch/utils.h +1 -1
- data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb +37 -0
- data/lib/torch/utils/data/data_pipes/iter/file_lister.rb +74 -0
- data/lib/torch/utils/data/data_pipes/iter/file_opener.rb +51 -0
- data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb +41 -0
- data/lib/torch/utils/data/iterable_dataset.rb +13 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +7 -0
- metadata +11 -4
data/ext/torch/backends.cpp
CHANGED
@@ -14,4 +14,8 @@ void init_backends(Rice::Module& m) {
|
|
14
14
|
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
|
+
|
18
|
+
Rice::define_module_under(rb_mBackends, "MPS")
|
19
|
+
.add_handler<torch::Error>(handle_error)
|
20
|
+
.define_singleton_function("available?", &torch::hasMPS);
|
17
21
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -35,17 +35,17 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
35
35
|
if (obj.is_instance_of(rb_cInteger)) {
|
36
36
|
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
37
37
|
} else if (obj.is_instance_of(rb_cRange)) {
|
38
|
-
torch::optional<
|
39
|
-
torch::optional<
|
38
|
+
torch::optional<c10::SymInt> start_index = torch::nullopt;
|
39
|
+
torch::optional<c10::SymInt> stop_index = torch::nullopt;
|
40
40
|
|
41
41
|
Object begin = obj.call("begin");
|
42
42
|
if (!begin.is_nil()) {
|
43
|
-
start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
|
43
|
+
start_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(begin.value()));
|
44
44
|
}
|
45
45
|
|
46
46
|
Object end = obj.call("end");
|
47
47
|
if (!end.is_nil()) {
|
48
|
-
stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
|
48
|
+
stop_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(end.value()));
|
49
49
|
}
|
50
50
|
|
51
51
|
Object exclude_end = obj.call("exclude_end?");
|
data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
CHANGED
@@ -0,0 +1,37 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        # Yields only the elements of a source datapipe for which the filter
        # block returns a truthy value. Registered as the `filter` functional
        # form, so `datapipe.filter { |x| ... }` builds one of these.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # @param datapipe [#each] source of elements
          # @param block [Proc] predicate; receives every value the source yields
          #   for one step (e.g. `|pathname, stream|` for a FileOpener source)
          def initialize(datapipe, &block)
            super()
            @datapipe = datapipe
            @filter_fn = block
          end

          # Iterates the source and yields each element the predicate accepts,
          # with the same arity the source used. For rejected elements, any
          # StreamWrapper among the yielded values is closed so filtering a
          # FileOpener does not leak file handles.
          def each
            # Capture *all* yielded values: a single-parameter block would keep
            # only the first one (e.g. the pathname), dropping FileOpener's stream.
            @datapipe.each do |*values|
              # Use the predicate result directly so nil/false elements that pass
              # the filter are still yielded instead of being treated as rejected.
              if @filter_fn.call(*values)
                yield(*values)
              else
                values.each { |value| Iter::StreamWrapper.close_streams(value) }
              end
            end
          end

          # Returns +data+ when the predicate accepts it, otherwise nil.
          # Kept for backward compatibility; #each no longer relies on it because
          # a passing nil/false element would be indistinguishable from a rejection.
          def return_if_true(data)
            data if @filter_fn.call(data)
          end

          # True unless +data+ is nil. Kept for backward compatibility.
          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,74 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Lists regular files under one or more roots. A root may be a
          # directory (its files are listed) or a file (passed through if it
          # matches the masks).
          class FileLister < IterDataPipe
            # @param root [String, Array<String>, IterDataPipe] root path(s)
            # @param masks [String, Array<String>, nil] fnmatch-style pattern(s)
            #   matched against each file's basename; empty/nil matches everything
            # @param recursive [Boolean] also list files in subdirectories
            # @param abspath [Boolean] yield absolute paths
            # @param non_deterministic [Boolean] when false (the default), paths
            #   are yielded in sorted order
            # @param length [Integer] nominal length; stored, not used here
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields each matching file path for every root in the source.
            def each(&block)
              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Yields the matching file paths found at +root+.
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              base = abspath ? File.expand_path(root) : root

              if File.file?(base)
                yield base if match_masks(File.basename(base), masks)
                return
              end

              pattern = recursive ? "**/*" : "*"
              # Sorted (deterministic) order unless the caller opted out.
              Dir.glob(pattern, base: base, sort: !non_deterministic).each do |relative|
                full_path = File.join(base, relative)
                # Glob also returns directories; only regular files are listed.
                next unless File.file?(full_path)

                yield full_path if match_masks(File.basename(relative), masks)
              end
            end

            # True when +name+ matches any mask, or when no masks are given.
            def match_masks(name, masks)
              return true if masks.nil? || masks.empty?

              Array(masks).any? { |mask| File.fnmatch?(mask, name, File::FNM_DOTMATCH) }
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,51 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Opens each pathname from a source and yields `pathname, stream`
          # pairs, where stream is a StreamWrapper around the opened File.
          # Closing the streams is left to the consumer.
          class FileOpener < IterDataPipe
            # @param datapipe [#each, String] source of pathnames; a lone value
            #   without #each is treated as a single pathname
            # @param mode [String] one of "b", "t", "rb", "rt", "r"
            # @param encoding [String, nil] text encoding; rejected for binary modes
            # @param length [Integer] nominal length; stored, not used here
            # @raise [ArgumentError] on an invalid mode or a binary mode with encoding
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              unless ["b", "t", "rb", "rt", "r"].include?(@mode)
                raise ArgumentError, "Invalid mode #{mode}"
              end

              if mode.include?("b") && !encoding.nil?
                raise ArgumentError, "binary mode doesn't take an encoding argument"
              end

              @length = length
            end

            # Yields `pathname, StreamWrapper` for each pathname in the source.
            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # @raise [TypeError] if the source yields a non-String pathname
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              # Duck-type the source: a datapipe need not include Enumerable, and
              # an is_a?(Enumerable) check would wrap it as a single "pathname".
              pathnames = [pathnames] unless pathnames.respond_to?(:each)

              # Bare "b"/"t" are normalized to read modes for File.open.
              mode = "r#{mode}" if ["b", "t"].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end

                yield pathname, StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
              end
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Adapts any #each-able collection into an IterDataPipe.
          class IterableWrapper < IterDataPipe
            # @param iterable [#each] source collection
            # @param deepcopy [Boolean] iterate over a deep copy so in-place
            #   changes to yielded items do not affect the source collection
            def initialize(iterable, deepcopy: true)
              super()
              @iterable = iterable
              @deepcopy = deepcopy
            end

            # Yields each element of the (possibly deep-copied) iterable.
            def each
              source_data = @iterable
              if @deepcopy
                begin
                  # Marshal round-trip as a deep copy; the data being loaded was
                  # just dumped from this object, so it is not untrusted input.
                  source_data = Marshal.load(Marshal.dump(@iterable))
                rescue TypeError
                  # Procs, IO objects, singletons, etc. cannot be marshaled; fall
                  # back to the original data rather than failing the iteration.
                  warn "The input iterable can not be deepcopied, please be aware of in-place modification would affect source data."
                end
              end
              source_data.each { |data| yield data }
            end

            # Length of the underlying iterable.
            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an opened stream so datapipes can recognize
          # (and close) the streams they hand out.
          class StreamWrapper
            # Maximum nesting depth close_streams descends into.
            MAX_CLOSE_DEPTH = 10

            # @param file_obj [IO] the underlying stream (File, StringIO, ...)
            def initialize(file_obj)
              @file_obj = file_obj
            end

            def gets(...)
              @file_obj.gets(...)
            end

            def read(...)
              @file_obj.read(...)
            end

            def each_line(...)
              @file_obj.each_line(...)
            end

            def closed?
              @file_obj.closed?
            end

            def close
              @file_obj.close
            end

            # Closes every StreamWrapper found in +v+: either +v+ itself or
            # wrappers nested inside Arrays / Hash values (up to MAX_CLOSE_DEPTH
            # levels). Any other value is ignored.
            #
            # @param v [Object] a wrapper, or a container that may hold wrappers
            # @param depth [Integer] current recursion depth (internal)
            def self.close_streams(v, depth = 0)
              return if depth > MAX_CLOSE_DEPTH

              case v
              when StreamWrapper
                v.close
              when Hash
                v.each_value { |value| close_streams(value, depth + 1) }
              when Array
                v.each { |value| close_streams(value, depth + 1) }
              end
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable-style datapipes. Subclasses may register a
        # functional form so `pipe.some_name(*args, &block)` constructs that
        # subclass around `pipe`.
        class IterDataPipe < IterableDataset
          class << self
            # Registers the calling class on IterDataPipe under +name+.
            def functional_datapipe(name)
              IterDataPipe.register_datapipe_as_function(name, self)
            end

            # Registry of functional-form name => constructor lambda, stored in
            # a class-level instance variable of the receiving class.
            def functions
              @functions ||= {}
            end

            # Stores a constructor for +cls_to_register+ under +function_name+
            # and defines an instance method of that name that builds it with
            # the receiver as the source datapipe.
            #
            # @raise [Error] if +function_name+ is already registered
            def register_datapipe_as_function(function_name, cls_to_register)
              if functions.include?(function_name)
                raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
              end

              functions[function_name] = lambda do |source_dp, *args, **options, &block|
                cls_to_register.new(source_dp, *args, **options, &block)
              end

              define_method(function_name) do |*args, **options, &block|
                IterDataPipe.functions[function_name].call(self, *args, **options, &block)
              end
            end
          end

          # Hook invoked at the start of #each; does nothing by default.
          def reset
            # no-op, but subclasses can override
          end

          # Default iteration: run #reset, then delegate to @source_datapipe.
          # NOTE(review): none of the subclasses visible here set
          # @source_datapipe (they use @datapipe and override #each) — confirm.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
|
|
187
187
|
require "torch/utils/data"
|
188
188
|
require "torch/utils/data/data_loader"
|
189
189
|
require "torch/utils/data/dataset"
|
190
|
+
require "torch/utils/data/iterable_dataset"
|
191
|
+
require "torch/utils/data/data_pipes/iter_data_pipe"
|
192
|
+
require "torch/utils/data/data_pipes/filter_iter_data_pipe"
|
193
|
+
require "torch/utils/data/data_pipes/iter/file_lister"
|
194
|
+
require "torch/utils/data/data_pipes/iter/file_opener"
|
195
|
+
require "torch/utils/data/data_pipes/iter/iterable_wrapper"
|
196
|
+
require "torch/utils/data/data_pipes/iter/stream_wrapper"
|
190
197
|
require "torch/utils/data/subset"
|
191
198
|
require "torch/utils/data/tensor_dataset"
|
192
199
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.13.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-04-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -207,7 +207,14 @@ files:
|
|
207
207
|
- lib/torch/tensor.rb
|
208
208
|
- lib/torch/utils/data.rb
|
209
209
|
- lib/torch/utils/data/data_loader.rb
|
210
|
+
- lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
|
211
|
+
- lib/torch/utils/data/data_pipes/iter/file_lister.rb
|
212
|
+
- lib/torch/utils/data/data_pipes/iter/file_opener.rb
|
213
|
+
- lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
|
214
|
+
- lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
|
215
|
+
- lib/torch/utils/data/data_pipes/iter_data_pipe.rb
|
210
216
|
- lib/torch/utils/data/dataset.rb
|
217
|
+
- lib/torch/utils/data/iterable_dataset.rb
|
211
218
|
- lib/torch/utils/data/subset.rb
|
212
219
|
- lib/torch/utils/data/tensor_dataset.rb
|
213
220
|
- lib/torch/version.rb
|
@@ -223,14 +230,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
223
230
|
requirements:
|
224
231
|
- - ">="
|
225
232
|
- !ruby/object:Gem::Version
|
226
|
-
version: '
|
233
|
+
version: '3'
|
227
234
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
228
235
|
requirements:
|
229
236
|
- - ">="
|
230
237
|
- !ruby/object:Gem::Version
|
231
238
|
version: '0'
|
232
239
|
requirements: []
|
233
|
-
rubygems_version: 3.4.
|
240
|
+
rubygems_version: 3.4.10
|
234
241
|
signing_key:
|
235
242
|
specification_version: 4
|
236
243
|
summary: Deep learning for Ruby, powered by LibTorch
|