torch-rb 0.12.1 → 0.12.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb +37 -0
- data/lib/torch/utils/data/data_pipes/iter/file_lister.rb +74 -0
- data/lib/torch/utils/data/data_pipes/iter/file_opener.rb +51 -0
- data/lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb +30 -0
- data/lib/torch/utils/data/data_pipes/iter_data_pipe.rb +41 -0
- data/lib/torch/utils/data/iterable_dataset.rb +13 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +7 -0
- metadata +9 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8160298f6299869f699201dd9217785da5ac7d2562f05b12a2df2d4537143886
|
4
|
+
data.tar.gz: cb902f45cb72adaaa2008e181af2af0844e926ef62f1d2ccaa67251cb70c5799
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c5077c40cb32414ac7d230430fdcf1cd93f199a81100179a1aa26185d6216c049316bbf35d79968442fd4db5e18414991946a366e6e7e6260a904f001774faa2
|
7
|
+
data.tar.gz: b4f1ff33b47596953c0654b07f1916a88ba247af72c1d4693d9b737802a790256d3232c0419c3a603fdebd51f47360833265326a1f27f8242604ff05766ea690
|
data/CHANGELOG.md
CHANGED
@@ -0,0 +1,37 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        # Iterable data pipe that yields only the items for which the given
        # block returns a truthy value. Registered under the functional name
        # `filter`, so any IterDataPipe can call `pipe.filter { |item| ... }`.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # @param datapipe [#each] the source pipe to filter
          # @yield [data] predicate evaluated once per item
          def initialize(datapipe, &block)
            @datapipe = datapipe
            @filter_fn = block
          end

          # Yields every item of the source pipe that the predicate accepts.
          #
          # A source pipe may yield several values per item (FileOpener yields
          # `pathname, stream`). All yielded values are captured and re-yielded
          # unchanged; the predicate then receives them as one Array (a
          # non-lambda block with several parameters auto-splats it). For a
          # rejected item, any StreamWrapper among its values is closed.
          def each
            @datapipe.each do |*values|
              # A plain |data| block parameter would keep only the first value
              # of a multi-value yield and silently drop e.g. the stream.
              data = values.length <= 1 ? values.first : values
              if non_empty?(return_if_true(data))
                yield(*values)
              else
                values.each { |value| Iter::StreamWrapper.close_streams(value) }
              end
            end
          end

          # Returns +data+ when the predicate is truthy for it, otherwise nil.
          def return_if_true(data)
            condition = @filter_fn.call(data)

            data if condition
          end

          # True unless +data+ is nil.
          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,74 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Yields the pathnames of the files found under one or more root
          # directories. Each pathname is the root joined with the path
          # relative to it.
          #
          # Only the default mask ("") and relative paths are supported so far.
          # A file given as root, a non-empty mask, or `abspath: true` raises
          # NotImplementedYet.
          class FileLister < IterDataPipe
            # @param root [String, Array<String>, IterDataPipe] root directory or directories
            # @param masks [String] filename mask; only "" (match everything) is supported
            # @param recursive [Boolean] also list files in subdirectories when true
            # @param abspath [Boolean] yield absolute paths (not implemented yet)
            # @param non_deterministic [Boolean] when false (the default) the
            #   entries of each root are sorted so the output order is stable
            # @param length [Integer] stored in @length; not used by this class
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields each matching file pathname for every root in the source pipe.
            def each(&block)
              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Lists the files under +root+ and yields `File.join(root, relative)`
            # for each one that matches +masks+.
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              raise NotImplementedYet if File.file?(root)

              pattern = recursive ? "**/*" : "*"
              paths = Dir.glob(pattern, base: root)
              # Deterministic (sorted) order unless the caller explicitly opts
              # out; the previous code had this condition inverted.
              paths.sort! unless non_deterministic
              paths.each do |f|
                raise NotImplementedYet if abspath

                path = File.join(root, f)
                # Dir.glob also returns directories; only regular files are listed.
                next unless File.file?(path)

                yield path if match_masks(f, masks)
              end
            end

            # Returns true when +masks+ is empty (match everything); any other
            # mask is not supported yet.
            def match_masks(name, masks)
              return true if masks.empty?

              raise NotImplementedYet
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,51 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Opens every pathname produced by a source pipe and yields the pair
          # `pathname, StreamWrapper` for each one. Accepted modes are "r",
          # "rb", "rt", and the shorthands "b" and "t", which are expanded to
          # "rb"/"rt" before opening.
          class FileOpener < IterDataPipe
            # @param datapipe [Enumerable, String] pathnames to open (a
            #   non-Enumerable value is treated as a single pathname)
            # @param mode [String] one of "b", "t", "rb", "rt", "r"
            # @param encoding [String, nil] encoding passed to File.open; must be nil in binary mode
            # @param length [Integer] stored in @length; not used by this class
            # @raise [ArgumentError] for an unsupported mode or an encoding given with a binary mode
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              supported = ["b", "t", "rb", "rt", "r"]
              raise ArgumentError, "Invalid mode #{mode}" unless supported.include?(@mode)

              binary = mode.include?("b")
              raise ArgumentError, "binary mode doesn't take an encoding argument" if binary && !encoding.nil?

              @length = length
            end

            # Yields `pathname, StreamWrapper` for every source pathname.
            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # Opens each pathname in +pathnames+ with +mode+/+encoding+ and
            # yields it together with the wrapped stream. The streams are left
            # open; the consumer is responsible for closing them.
            # @raise [TypeError] when a pathname is not a String
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              pathnames = [pathnames] unless pathnames.is_a?(Enumerable)

              # Expand the "b"/"t" shorthands into read modes.
              mode = "r#{mode}" if ["b", "t"].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end

                stream = File.open(pathname, mode, encoding: encoding)
                yield pathname, StreamWrapper.new(stream)
              end
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Wraps a plain Ruby collection (Array, Range, ...) as an IterDataPipe.
          class IterableWrapper < IterDataPipe
            # @param iterable [#each, #length] the source collection
            # @param deepcopy [Boolean] iterate over a deep copy (made with
            #   Marshal) so in-place changes made downstream do not modify
            #   +iterable+ itself
            def initialize(iterable, deepcopy: true)
              @iterable = iterable
              @deepcopy = deepcopy
            end

            # Yields every element of the iterable, or of its deep copy when
            # deepcopy is enabled and the iterable can be marshaled.
            def each
              source_data = @iterable
              if @deepcopy
                begin
                  source_data = Marshal.load(Marshal.dump(@iterable))
                rescue TypeError
                  # Marshal raises TypeError for objects it cannot dump (Procs,
                  # IO handles, objects with singleton methods). Fall back to
                  # the original iterable instead of failing the whole pipe.
                  warn "The input iterable can not be deepcopied, please be aware of in-place modification would affect source data."
                end
              end
              source_data.each do |data|
                yield data
              end
            end

            # Number of elements in the wrapped iterable.
            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an IO-like object (e.g. a File opened by a
          # data pipe). Data pipes can then recognize streams they own and
          # close the ones they discard.
          class StreamWrapper
            # @param file_obj [IO] the stream to wrap
            def initialize(file_obj)
              @file_obj = file_obj
            end

            # Delegates to #gets on the wrapped stream.
            def gets(...)
              @file_obj.gets(...)
            end

            # Delegates to #read on the wrapped stream.
            def read(...)
              @file_obj.read(...)
            end

            # Closes the wrapped stream.
            def close
              @file_obj.close
            end

            # Whether the wrapped stream is closed.
            def closed?
              @file_obj.closed?
            end

            # Closes every StreamWrapper contained in +obj+: +obj+ itself, or
            # any wrapper nested inside Arrays or Hash values. A data item such
            # as `[pathname, stream]` then releases its stream. Other values
            # are ignored.
            def self.close_streams(obj)
              case obj
              when StreamWrapper
                obj.close
              when Array
                obj.each { |item| close_streams(item) }
              when Hash
                obj.each_value { |item| close_streams(item) }
              end
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable data pipes.
        #
        # A subclass declares `functional_datapipe :name`. That defines an
        # instance method `name` on IterDataPipe which builds a new instance of
        # the subclass with the receiver as its first argument, forwarding any
        # extra positional/keyword arguments and block.
        class IterDataPipe < IterableDataset
          class << self
            # Registers the calling subclass under the functional name +name+.
            def functional_datapipe(name)
              IterDataPipe.register_datapipe_as_function(name, self)
            end

            # Registry mapping functional names to constructor lambdas.
            def functions
              @functions ||= {}
            end

            # Stores a constructor lambda for +cls_to_register+ under
            # +function_name+ and defines the matching instance method.
            # @raise [Error] when +function_name+ is already registered
            def register_datapipe_as_function(function_name, cls_to_register)
              raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken" if functions.key?(function_name)

              functions[function_name] = lambda do |source_dp, *args, **options, &block|
                cls_to_register.new(source_dp, *args, **options, &block)
              end

              define_method(function_name) do |*args, **options, &block|
                IterDataPipe.functions[function_name].call(self, *args, **options, &block)
              end
            end
          end

          # Hook called at the start of every iteration; does nothing here,
          # subclasses may override it.
          def reset
          end

          # Resets, then iterates the source pipe held in @source_datapipe.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
|
|
187
187
|
require "torch/utils/data"
|
188
188
|
require "torch/utils/data/data_loader"
|
189
189
|
require "torch/utils/data/dataset"
|
190
|
+
require "torch/utils/data/iterable_dataset"
|
191
|
+
require "torch/utils/data/data_pipes/iter_data_pipe"
|
192
|
+
require "torch/utils/data/data_pipes/filter_iter_data_pipe"
|
193
|
+
require "torch/utils/data/data_pipes/iter/file_lister"
|
194
|
+
require "torch/utils/data/data_pipes/iter/file_opener"
|
195
|
+
require "torch/utils/data/data_pipes/iter/iterable_wrapper"
|
196
|
+
require "torch/utils/data/data_pipes/iter/stream_wrapper"
|
190
197
|
require "torch/utils/data/subset"
|
191
198
|
require "torch/utils/data/tensor_dataset"
|
192
199
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.12.1
|
4
|
+
version: 0.12.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-01-
|
11
|
+
date: 2023-01-31 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -207,7 +207,14 @@ files:
|
|
207
207
|
- lib/torch/tensor.rb
|
208
208
|
- lib/torch/utils/data.rb
|
209
209
|
- lib/torch/utils/data/data_loader.rb
|
210
|
+
- lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
|
211
|
+
- lib/torch/utils/data/data_pipes/iter/file_lister.rb
|
212
|
+
- lib/torch/utils/data/data_pipes/iter/file_opener.rb
|
213
|
+
- lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
|
214
|
+
- lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
|
215
|
+
- lib/torch/utils/data/data_pipes/iter_data_pipe.rb
|
210
216
|
- lib/torch/utils/data/dataset.rb
|
217
|
+
- lib/torch/utils/data/iterable_dataset.rb
|
211
218
|
- lib/torch/utils/data/subset.rb
|
212
219
|
- lib/torch/utils/data/tensor_dataset.rb
|
213
220
|
- lib/torch/version.rb
|