torch-rb 0.12.1 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -1
- data/codegen/generate_functions.rb +4 -2
- data/codegen/native_functions.yaml +1392 -593
- data/ext/torch/backends.cpp +4 -0
- data/ext/torch/tensor.cpp +4 -4
- data/ext/torch/utils.h +1 -1
- data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb +37 -0
- data/lib/torch/utils/data/data_pipes/iter/file_lister.rb +74 -0
- data/lib/torch/utils/data/data_pipes/iter/file_opener.rb +51 -0
- data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb +41 -0
- data/lib/torch/utils/data/iterable_dataset.rb +13 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +7 -0
- metadata +11 -4
data/ext/torch/backends.cpp
CHANGED
@@ -14,4 +14,8 @@ void init_backends(Rice::Module& m) {
|
|
14
14
|
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
|
+
|
18
|
+
Rice::define_module_under(rb_mBackends, "MPS")
|
19
|
+
.add_handler<torch::Error>(handle_error)
|
20
|
+
.define_singleton_function("available?", &torch::hasMPS);
|
17
21
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -35,17 +35,17 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
35
35
|
if (obj.is_instance_of(rb_cInteger)) {
|
36
36
|
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
37
37
|
} else if (obj.is_instance_of(rb_cRange)) {
|
38
|
-
torch::optional<
|
39
|
-
torch::optional<
|
38
|
+
torch::optional<c10::SymInt> start_index = torch::nullopt;
|
39
|
+
torch::optional<c10::SymInt> stop_index = torch::nullopt;
|
40
40
|
|
41
41
|
Object begin = obj.call("begin");
|
42
42
|
if (!begin.is_nil()) {
|
43
|
-
start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
|
43
|
+
start_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(begin.value()));
|
44
44
|
}
|
45
45
|
|
46
46
|
Object end = obj.call("end");
|
47
47
|
if (!end.is_nil()) {
|
48
|
-
stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
|
48
|
+
stop_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(end.value()));
|
49
49
|
}
|
50
50
|
|
51
51
|
Object exclude_end = obj.call("exclude_end?");
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        # Data pipe that forwards only those items of a source data pipe for
        # which the supplied block returns a truthy value. The class registers
        # itself under the functional name :filter.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # datapipe - source object; only #each is called on it.
          # block    - predicate invoked once per item.
          def initialize(datapipe, &block)
            @filter_fn = block
            @datapipe = datapipe
          end

          # Yields every accepted item. A rejected item is handed to
          # Iter::StreamWrapper.close_streams instead of being yielded.
          def each
            @datapipe.each do |item|
              kept = return_if_true(item)
              if non_empty?(kept)
                yield kept
              else
                Iter::StreamWrapper.close_streams(item)
              end
            end
          end

          # Returns data when the predicate is truthy for it, nil otherwise.
          def return_if_true(data)
            @filter_fn.call(data) ? data : nil
          end

          # True for anything except nil. Note that an item which is itself
          # nil is therefore indistinguishable from a rejected one.
          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
|
data/lib/torch/utils/data/data_pipes/iter/file_lister.rb
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Data pipe that yields the paths of the entries found under one or
          # more root directories.
          class FileLister < IterDataPipe
            # root              - a String path, a collection of paths, or an
            #                     IterDataPipe producing paths (default ".").
            # masks             - only an empty value is supported; anything
            #                     else raises NotImplementedYet while iterating.
            # recursive         - when true, descend into subdirectories.
            # abspath           - not supported yet; true raises
            #                     NotImplementedYet while iterating.
            # non_deterministic - when true, entries are yielded in whatever
            #                     order the file system returns them; when
            #                     false (default) they are yielded sorted.
            # length            - stored, but not used by this class.
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields one path (root joined with the entry name) per matching
            # entry, for every root produced by the source data pipe.
            def each(&block)
              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Globs the entries below root and yields File.join(root, entry)
            # for each entry accepted by match_masks.
            #
            # Raises NotImplementedYet when root is a regular file, or when
            # abspath is requested and at least one entry exists.
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              raise NotImplementedYet if File.file?(root)

              pattern = recursive ? "**/*" : "*"
              # Sorting is the deterministic behaviour, so it must be switched
              # off (not on) when the caller asks for non-deterministic order.
              # The previous code sorted only when non_deterministic was true.
              paths = Dir.glob(pattern, base: root, sort: !non_deterministic)
              paths.each do |f|
                raise NotImplementedYet if abspath

                yield File.join(root, f) if match_masks(f, masks)
              end
            end

            # Returns true when no masks are given; filtering by mask is not
            # implemented and raises NotImplementedYet.
            def match_masks(name, masks)
              return true if masks.empty?

              raise NotImplementedYet
            end
          end
        end
      end
    end
  end
end
|
data/lib/torch/utils/data/data_pipes/iter/file_opener.rb
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Data pipe that opens every path produced by a source data pipe and
          # yields the path together with a StreamWrapper around the open file.
          class FileOpener < IterDataPipe
            # datapipe - source of path names.
            # mode     - one of "b", "t", "rb", "rt", "r".
            # encoding - forwarded to File.open; must be nil for binary modes.
            # length   - stored, but not used by this class.
            #
            # Raises ArgumentError for an unknown mode, or when an encoding is
            # combined with a binary mode.
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              raise ArgumentError, "Invalid mode #{mode}" unless ["b", "t", "rb", "rt", "r"].include?(@mode)

              binary_with_encoding = mode.include?("b") && !encoding.nil?
              raise ArgumentError, "binary mode doesn't take an encoding argument" if binary_with_encoding

              @length = length
            end

            # Yields (pathname, StreamWrapper) for each path of the source.
            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # Opens each path name and yields it with its wrapped file handle.
            # A non-Enumerable argument is treated as a single path name.
            # The opened files are not closed here.
            #
            # Raises TypeError when a path name is not a String.
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              pathnames = [pathnames] unless pathnames.is_a?(Enumerable)

              # the short forms "b" / "t" mean read-binary / read-text
              mode = "r#{mode}" if ["b", "t"].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end

                handle = File.open(pathname, mode, encoding: encoding)
                yield pathname, StreamWrapper.new(handle)
              end
            end
          end
        end
      end
    end
  end
end
|
data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Adapts any object responding to #each into an IterDataPipe.
          class IterableWrapper < IterDataPipe
            # iterable - the wrapped collection.
            # deepcopy - when true (default), every iteration walks a fresh
            #            Marshal round-trip copy of the collection rather
            #            than the collection itself.
            def initialize(iterable, deepcopy: true)
              @deepcopy = deepcopy
              @iterable = iterable
            end

            # Yields the elements one at a time.
            def each
              source_data = @deepcopy ? Marshal.load(Marshal.dump(@iterable)) : @iterable
              source_data.each { |item| yield item }
            end

            # Delegates to the wrapped collection's #length.
            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
|
data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an IO-like object that exposes #gets and #close.
          class StreamWrapper
            # Nesting depth beyond which close_streams stops descending.
            MAX_CLOSE_DEPTH = 10

            # file_obj - object responding to #gets and #close.
            def initialize(file_obj)
              @file_obj = file_obj
            end

            # Forwards all arguments to the wrapped object's #gets.
            def gets(...)
              @file_obj.gets(...)
            end

            # Closes the wrapped object.
            def close
              @file_obj.close
            end

            # Closes every StreamWrapper reachable from cls.
            #
            # cls   - a StreamWrapper, an Array or Hash (values are searched)
            #         that may contain StreamWrappers at any nesting level, or
            #         any other object, which is ignored.
            # depth - current nesting level; callers normally omit it.
            #
            # Returns nil.
            def self.close_streams(cls, depth = 0)
              # bound the recursion so a self-referential container cannot
              # recurse without end
              return if depth > MAX_CLOSE_DEPTH

              case cls
              when StreamWrapper
                cls.close
              when Hash
                cls.each_value { |value| close_streams(value, depth + 1) }
              when Array
                cls.each { |value| close_streams(value, depth + 1) }
              end
              nil
            end
          end
        end
      end
    end
  end
end
|
data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable data pipes. It also keeps a registry of
        # "functional" data pipes: each registered name becomes an instance
        # method that wraps the receiver in the registered class.
        class IterDataPipe < IterableDataset
          class << self
            # Class-body macro: registers the calling class under name.
            def functional_datapipe(name)
              IterDataPipe.register_datapipe_as_function(name, self)
            end

            # Hash of registered function name => lambda building the pipe.
            def functions
              @functions ||= {}
            end

            # Stores a factory for cls_to_register under function_name and
            # defines an instance method of that name which calls the factory
            # with the receiver as the source data pipe.
            #
            # Raises Error when the name has already been registered.
            def register_datapipe_as_function(function_name, cls_to_register)
              if functions.key?(function_name)
                raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
              end

              functions[function_name] = lambda do |source_dp, *args, **options, &block|
                cls_to_register.new(source_dp, *args, **options, &block)
              end

              define_method function_name do |*args, **options, &block|
                IterDataPipe.functions[function_name].call(self, *args, **options, &block)
              end
            end
          end

          # Hook invoked at the start of #each; does nothing by default so
          # that subclasses can override it.
          def reset
          end

          # Calls reset, then iterates @source_datapipe with the given block.
          # NOTE(review): nothing in this class assigns @source_datapipe;
          # presumably subclasses set it or override #each. TODO confirm.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
|
|
187
187
|
require "torch/utils/data"
|
188
188
|
require "torch/utils/data/data_loader"
|
189
189
|
require "torch/utils/data/dataset"
|
190
|
+
require "torch/utils/data/iterable_dataset"
|
191
|
+
require "torch/utils/data/data_pipes/iter_data_pipe"
|
192
|
+
require "torch/utils/data/data_pipes/filter_iter_data_pipe"
|
193
|
+
require "torch/utils/data/data_pipes/iter/file_lister"
|
194
|
+
require "torch/utils/data/data_pipes/iter/file_opener"
|
195
|
+
require "torch/utils/data/data_pipes/iter/iterable_wrapper"
|
196
|
+
require "torch/utils/data/data_pipes/iter/stream_wrapper"
|
190
197
|
require "torch/utils/data/subset"
|
191
198
|
require "torch/utils/data/tensor_dataset"
|
192
199
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.12.1
|
4
|
+
version: 0.13.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-04-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -207,7 +207,14 @@ files:
|
|
207
207
|
- lib/torch/tensor.rb
|
208
208
|
- lib/torch/utils/data.rb
|
209
209
|
- lib/torch/utils/data/data_loader.rb
|
210
|
+
- lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
|
211
|
+
- lib/torch/utils/data/data_pipes/iter/file_lister.rb
|
212
|
+
- lib/torch/utils/data/data_pipes/iter/file_opener.rb
|
213
|
+
- lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
|
214
|
+
- lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
|
215
|
+
- lib/torch/utils/data/data_pipes/iter_data_pipe.rb
|
210
216
|
- lib/torch/utils/data/dataset.rb
|
217
|
+
- lib/torch/utils/data/iterable_dataset.rb
|
211
218
|
- lib/torch/utils/data/subset.rb
|
212
219
|
- lib/torch/utils/data/tensor_dataset.rb
|
213
220
|
- lib/torch/version.rb
|
@@ -223,14 +230,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
223
230
|
requirements:
|
224
231
|
- - ">="
|
225
232
|
- !ruby/object:Gem::Version
|
226
|
-
version: '
|
233
|
+
version: '3'
|
227
234
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
228
235
|
requirements:
|
229
236
|
- - ">="
|
230
237
|
- !ruby/object:Gem::Version
|
231
238
|
version: '0'
|
232
239
|
requirements: []
|
233
|
-
rubygems_version: 3.4.
|
240
|
+
rubygems_version: 3.4.10
|
234
241
|
signing_key:
|
235
242
|
specification_version: 4
|
236
243
|
summary: Deep learning for Ruby, powered by LibTorch
|