torch-rb 0.12.1 → 0.12.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
4
- data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
3
+ metadata.gz: 8160298f6299869f699201dd9217785da5ac7d2562f05b12a2df2d4537143886
4
+ data.tar.gz: cb902f45cb72adaaa2008e181af2af0844e926ef62f1d2ccaa67251cb70c5799
5
5
  SHA512:
6
- metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
7
- data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
6
+ metadata.gz: c5077c40cb32414ac7d230430fdcf1cd93f199a81100179a1aa26185d6216c049316bbf35d79968442fd4db5e18414991946a366e6e7e6260a904f001774faa2
7
+ data.tar.gz: b4f1ff33b47596953c0654b07f1916a88ba247af72c1d4693d9b737802a790256d3232c0419c3a603fdebd51f47360833265326a1f27f8242604ff05766ea690
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.12.2 (2023-01-30)
2
+
3
+ - Added experimental support for DataPipes
4
+
1
5
  ## 0.12.1 (2023-01-29)
2
6
 
3
7
  - Added `Generator` class
@@ -0,0 +1,37 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # DataPipe that passes through only those items of a source datapipe
        # for which the supplied block returns a truthy value.
        #
        # The class registers itself under the functional name +:filter+.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # datapipe - source object; must respond to +each+.
          # block    - predicate called once per item.
          def initialize(datapipe, &block)
            @filter_fn = block
            @datapipe = datapipe
          end

          # Yields every item accepted by the predicate. Rejected items are
          # handed to StreamWrapper.close_streams so that a rejected stream
          # does not stay open.
          def each
            @datapipe.each do |data|
              kept = return_if_true(data)
              unless non_empty?(kept)
                Iter::StreamWrapper.close_streams(data)
                next
              end
              yield kept
            end
          end

          # Returns +data+ when the predicate accepts it, otherwise nil.
          def return_if_true(data)
            @filter_fn.call(data) ? data : nil
          end

          # True for anything except nil (so +false+ counts as non-empty).
          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # DataPipe that yields the paths of the entries found under one or
          # more root directories.
          #
          # Entries come from Dir.glob, so sub-directories are yielded as well
          # as regular files.
          class FileLister < IterDataPipe
            # root              - a directory path (String), a collection of
            #                     paths, or an IterDataPipe producing paths.
            # masks             - name filter; only the empty value is
            #                     supported, anything else raises
            #                     NotImplementedYet while iterating.
            # recursive         - when true, descend into sub-directories.
            # abspath           - not supported yet; a truthy value raises
            #                     NotImplementedYet while iterating.
            # non_deterministic - when false (default) the entries of each
            #                     root are yielded in sorted order; when true
            #                     the order is whatever Dir.glob returns.
            # length            - stored as given; not read by this class.
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              # Normalise every accepted form of +root+ to an IterDataPipe.
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields one path (root joined with the relative entry name) per
            # entry, for every root produced by the source datapipe.
            def each(&block)
              @datapipe.each do |path|
                get_file_pathnames_from_root(
                  path,
                  @masks,
                  recursive: @recursive,
                  abspath: @abspath,
                  non_deterministic: @non_deterministic,
                  &block
                )
              end
            end

            private

            # Lists the entries below +root+ and yields each matching path.
            #
            # Raises NotImplementedYet when +root+ is a regular file, when
            # +abspath+ is requested, or when +masks+ is non-empty.
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              raise NotImplementedYet if File.file?(root)

              pattern = recursive ? "**/*" : "*"
              paths = Dir.glob(pattern, base: root)
              # Sort unless the caller explicitly opted out of a stable order.
              # (The condition used to be inverted, sorting only when
              # non_deterministic was requested.)
              paths = paths.sort unless non_deterministic
              paths.each do |f|
                raise NotImplementedYet if abspath

                yield File.join(root, f) if match_masks(f, masks)
              end
            end

            # Returns true when no mask is set; mask matching itself is not
            # implemented yet.
            def match_masks(name, masks)
              return true if masks.empty?

              raise NotImplementedYet
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # DataPipe that opens every pathname produced by a source datapipe
          # and yields the pathname together with a StreamWrapper around the
          # opened file.
          #
          # NOTE(review): the two values are yielded separately, not as one
          # pair, so a plain block declaring a single parameter receives only
          # the pathname - confirm this is intended for downstream pipes.
          class FileOpener < IterDataPipe
            # datapipe - source of pathnames (an Enumerable, or a single
            #            object that is wrapped in an Array when iterating).
            # mode     - one of "b", "t", "rb", "rt", "r".
            # encoding - passed to File.open; must be nil for binary modes.
            # length   - stored as given; not read by this class.
            #
            # Raises ArgumentError for an unknown mode or for an encoding
            # combined with a binary mode.
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              unless ["b", "t", "rb", "rt", "r"].include?(@mode)
                raise ArgumentError, "Invalid mode #{mode}"
              end

              binary_with_encoding = mode.include?("b") && !encoding.nil?
              if binary_with_encoding
                raise ArgumentError, "binary mode doesn't take an encoding argument"
              end

              @length = length
            end

            # Yields +pathname, stream+ for every pathname of the source.
            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # Opens each pathname with +mode+ and yields it with its wrapped
            # file handle. The handle is not closed here; that is left to the
            # consumer.
            #
            # Raises TypeError when a pathname is not a String.
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              pathnames = [pathnames] unless pathnames.is_a?(Enumerable)

              # The short forms "b" / "t" mean read-binary / read-text.
              mode = "r#{mode}" if ["b", "t"].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end

                yield pathname, StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Adapts any object responding to +each+ so that it can be used as
          # an IterDataPipe.
          class IterableWrapper < IterDataPipe
            # iterable - the wrapped collection.
            # deepcopy - when true (default), every iteration runs over a
            #            Marshal round-trip copy of the collection instead of
            #            the collection itself.
            def initialize(iterable, deepcopy: true)
              @deepcopy = deepcopy
              @iterable = iterable
            end

            # Yields each element of the wrapped collection (or of its copy).
            def each
              items = @deepcopy ? Marshal.load(Marshal.dump(@iterable)) : @iterable
              items.each { |item| yield item }
            end

            # Delegates to the wrapped collection's +length+.
            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an IO-like object. It lets datapipes recognise
          # open streams inside the data they pass along and close them.
          class StreamWrapper
            # Deepest container nesting close_streams will descend into.
            MAX_CLOSE_DEPTH = 10

            # file_obj - the wrapped object; must respond to +gets+ and
            #            +close+.
            def initialize(file_obj)
              @file_obj = file_obj
            end

            # Forwards all arguments to the wrapped object's +gets+.
            def gets(...)
              @file_obj.gets(...)
            end

            # Closes the wrapped object.
            def close
              @file_obj.close
            end

            # Closes every StreamWrapper reachable from +cls+ on a best-effort
            # basis.
            #
            # cls   - a StreamWrapper, or an Array / Hash (values only) that
            #         may contain StreamWrappers at any nesting level. Any
            #         other object is ignored.
            # depth - current nesting level; callers normally omit it.
            #
            # Returns the result of +close+ for a bare StreamWrapper, nil
            # otherwise.
            def self.close_streams(cls, depth = 0)
              # Guards against very deep or self-referencing structures.
              return if depth > MAX_CLOSE_DEPTH
              return cls.close if cls.is_a?(StreamWrapper)

              # Only simple containers are traversed.
              children =
                case cls
                when Hash then cls.values
                when Array then cls
                end
              children&.each { |child| close_streams(child, depth + 1) }
              nil
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable datapipes. It also keeps the registry that
        # turns datapipe classes into chainable instance methods.
        class IterDataPipe < IterableDataset
          # Class-level macro: registers the calling class under +name+.
          def self.functional_datapipe(name)
            IterDataPipe.register_datapipe_as_function(name, self)
          end

          # Registry mapping a function name to the lambda that builds the
          # corresponding datapipe.
          def self.functions
            @functions ||= {}
          end

          # Stores a factory for +cls_to_register+ under +function_name+ and
          # defines an instance method of that name which calls the factory
          # with the receiver as the source datapipe.
          #
          # Raises Error when the name is already registered.
          def self.register_datapipe_as_function(function_name, cls_to_register)
            if functions.key?(function_name)
              raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
            end

            functions[function_name] = ->(source_dp, *args, **options, &block) {
              cls_to_register.new(source_dp, *args, **options, &block)
            }

            define_method(function_name) do |*args, **options, &block|
              factory = IterDataPipe.functions[function_name]
              factory.call(self, *args, **options, &block)
            end
          end

          # Hook invoked at the start of every iteration; does nothing here,
          # subclasses may override it.
          def reset
          end

          # Calls +reset+, then iterates the source datapipe.
          #
          # NOTE(review): @source_datapipe is not assigned anywhere in this
          # class - presumably subclasses set it or override +each+; confirm.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module Utils
    module Data
      # Abstract dataset that is consumed by iteration. Subclasses supply
      # +each+; Enumerable then provides the usual collection helpers on top
      # of it.
      class IterableDataset < Dataset
        include Enumerable

        # Abstract: must be overridden to yield the dataset's items.
        #
        # Raises NotImplementedError unless overridden.
        def each
          raise(NotImplementedError)
        end
      end
    end
  end
end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.1"
2
+ VERSION = "0.12.2"
3
3
  end
data/lib/torch.rb CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
187
187
  require "torch/utils/data"
188
188
  require "torch/utils/data/data_loader"
189
189
  require "torch/utils/data/dataset"
190
+ require "torch/utils/data/iterable_dataset"
191
+ require "torch/utils/data/data_pipes/iter_data_pipe"
192
+ require "torch/utils/data/data_pipes/filter_iter_data_pipe"
193
+ require "torch/utils/data/data_pipes/iter/file_lister"
194
+ require "torch/utils/data/data_pipes/iter/file_opener"
195
+ require "torch/utils/data/data_pipes/iter/iterable_wrapper"
196
+ require "torch/utils/data/data_pipes/iter/stream_wrapper"
190
197
  require "torch/utils/data/subset"
191
198
  require "torch/utils/data/tensor_dataset"
192
199
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.1
4
+ version: 0.12.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-01-30 00:00:00.000000000 Z
11
+ date: 2023-01-31 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -207,7 +207,14 @@ files:
207
207
  - lib/torch/tensor.rb
208
208
  - lib/torch/utils/data.rb
209
209
  - lib/torch/utils/data/data_loader.rb
210
+ - lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
211
+ - lib/torch/utils/data/data_pipes/iter/file_lister.rb
212
+ - lib/torch/utils/data/data_pipes/iter/file_opener.rb
213
+ - lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
214
+ - lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
215
+ - lib/torch/utils/data/data_pipes/iter_data_pipe.rb
210
216
  - lib/torch/utils/data/dataset.rb
217
+ - lib/torch/utils/data/iterable_dataset.rb
211
218
  - lib/torch/utils/data/subset.rb
212
219
  - lib/torch/utils/data/tensor_dataset.rb
213
220
  - lib/torch/version.rb