torch-rb 0.12.1 → 0.13.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -14,4 +14,8 @@ void init_backends(Rice::Module& m) {
14
14
  Rice::define_module_under(rb_mBackends, "MKL")
15
15
  .add_handler<torch::Error>(handle_error)
16
16
  .define_singleton_function("available?", &torch::hasMKL);
17
+
18
+ Rice::define_module_under(rb_mBackends, "MPS")
19
+ .add_handler<torch::Error>(handle_error)
20
+ .define_singleton_function("available?", &torch::hasMPS);
17
21
  }
data/ext/torch/tensor.cpp CHANGED
@@ -35,17 +35,17 @@ std::vector<TensorIndex> index_vector(Array a) {
35
35
  if (obj.is_instance_of(rb_cInteger)) {
36
36
  indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
37
37
  } else if (obj.is_instance_of(rb_cRange)) {
38
- torch::optional<int64_t> start_index = torch::nullopt;
39
- torch::optional<int64_t> stop_index = torch::nullopt;
38
+ torch::optional<c10::SymInt> start_index = torch::nullopt;
39
+ torch::optional<c10::SymInt> stop_index = torch::nullopt;
40
40
 
41
41
  Object begin = obj.call("begin");
42
42
  if (!begin.is_nil()) {
43
- start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
43
+ start_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(begin.value()));
44
44
  }
45
45
 
46
46
  Object end = obj.call("end");
47
47
  if (!end.is_nil()) {
48
- stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
48
+ stop_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(end.value()));
49
49
  }
50
50
 
51
51
  Object exclude_end = obj.call("exclude_end?");
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -0,0 +1,37 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Yields only the elements of an upstream datapipe for which the filter
        # block returns a truthy value. Registered as the functional form
        # `datapipe.filter { |data| ... }`.
        #
        # Rejected elements are handed to Iter::StreamWrapper.close_streams so
        # that any stream they carry can be released.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # @param datapipe [#each] upstream source of elements
          # @yieldparam data [Object] an element from +datapipe+
          # @yieldreturn [Object] truthy to keep the element
          def initialize(datapipe, &block)
            @datapipe = datapipe
            @filter_fn = block
          end

          # Yields every element accepted by the filter. Returns an Enumerator
          # when no block is given.
          def each
            return enum_for(:each) unless block_given?

            # Capture every value the upstream yields. FileOpener, for example,
            # yields (pathname, stream); a single block parameter would keep only
            # the pathname and the stream could never be closed.
            @datapipe.each do |*values|
              data = values.length > 1 ? values : values.first
              # Decide on the predicate itself, not on the element being non-nil,
              # so nil/false elements that pass the filter are still yielded.
              if @filter_fn.call(data)
                yield data
              else
                Iter::StreamWrapper.close_streams(data)
              end
            end
          end

          # Returns +data+ when the filter accepts it, otherwise nil.
          # Kept for compatibility; #each evaluates the filter directly.
          def return_if_true(data)
            condition = @filter_fn.call(data)

            data if condition
          end

          # Returns true unless +data+ is nil.
          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Yields the paths of regular files found under each root path.
          #
          # A root may be a directory, whose entries are listed (optionally
          # recursively), or the path of a single file, which is yielded as-is.
          class FileLister < IterDataPipe
            # @param root [String, Enumerable<String>, IterDataPipe] one root path
            #   or a collection of root paths
            # @param masks [String, Array<String>] filename masks; only an empty
            #   value (no filtering) is implemented, anything else raises
            #   NotImplementedYet once a file is examined
            # @param recursive [Boolean] also list files in subdirectories
            # @param abspath [Boolean] yield absolute paths (File.expand_path)
            # @param non_deterministic [Boolean] when true, directory listings are
            #   not sorted
            # @param length [Integer] stored as given; not used by this class
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              # Normalize plain collections into a datapipe so #each iterates uniformly.
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields each listed file path. Returns an Enumerator when no block
            # is given.
            def each(&block)
              return enum_for(:each) unless block

              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Yields the file paths for a single root (a file or a directory).
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              if File.file?(root)
                path = abspath ? File.expand_path(root) : root
                yield path if match_masks(File.basename(path), masks)
              else
                pattern = recursive ? "**/*" : "*"
                # Dir.glob sorts by default (Ruby 3.0+); give up the stable order
                # only when the caller explicitly asked for non-deterministic output.
                Dir.glob(pattern, base: root, sort: !non_deterministic).each do |f|
                  path = File.join(root, f)
                  # Glob results include directories; only files are listed.
                  next unless File.file?(path)
                  next unless match_masks(File.basename(f), masks)

                  yield abspath ? File.expand_path(path) : path
                end
              end
            end

            # Returns true when +name+ matches +masks+. Only empty masks are
            # supported so far.
            def match_masks(name, masks)
              return true if masks.empty?

              raise NotImplementedYet
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Opens each incoming pathname and yields two values: the pathname
          # and the opened file wrapped in a StreamWrapper.
          #
          # Files are not closed here; the consumer is responsible for closing
          # the yielded StreamWrapper.
          class FileOpener < IterDataPipe
            # @param datapipe [Enumerable<String>, String] pathname(s) to open
            # @param mode [String] one of "b", "t", "rb", "rt", "r"
            # @param encoding [String, nil] text encoding; rejected in binary modes
            # @param length [Integer] stored as given; not used by this class
            # @raise [ArgumentError] for an unsupported mode, or an encoding
            #   combined with a binary mode
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              unless ["b", "t", "rb", "rt", "r"].include?(@mode)
                raise ArgumentError, "Invalid mode #{mode}"
              end

              if mode.include?("b") && !encoding.nil?
                raise ArgumentError, "binary mode doesn't take an encoding argument"
              end

              @length = length
            end

            # Yields (pathname, StreamWrapper) for every pathname. Returns an
            # Enumerator when no block is given.
            def each(&block)
              return enum_for(:each) unless block

              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # Opens each pathname in +mode+ and yields it with its wrapped stream.
            #
            # @raise [TypeError] when a pathname is not a String
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              # A bare String is not Enumerable; treat it as a single pathname.
              pathnames = [pathnames] unless pathnames.is_a?(Enumerable)

              # "b"/"t" alone are shorthand for read mode.
              mode = "r#{mode}" if ["b", "t"].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end
                yield pathname, StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Wraps a plain Ruby collection so it can be used as an IterDataPipe.
          class IterableWrapper < IterDataPipe
            # @param iterable [#each] the collection to iterate
            # @param deepcopy [Boolean] iterate over a Marshal round-trip copy on
            #   every #each, so in-place changes to yielded elements do not leak
            #   back into +iterable+
            def initialize(iterable, deepcopy: true)
              super()
              @iterable = iterable
              @deepcopy = deepcopy
            end

            # Yields each element of the (possibly copied) iterable. Returns an
            # Enumerator when no block is given.
            def each
              return enum_for(:each) unless block_given?

              source_data = @deepcopy ? deep_copy_iterable : @iterable
              source_data.each do |data|
                yield data
              end
            end

            # Length of the wrapped iterable.
            def length
              @iterable.length
            end

            private

            # Returns a Marshal round-trip copy of the iterable. Falls back to the
            # iterable itself (with a warning) when it holds objects Marshal
            # cannot dump, such as procs or IO objects.
            def deep_copy_iterable
              Marshal.load(Marshal.dump(@iterable))
            rescue TypeError
              warn "IterableWrapper: the input iterable cannot be deep-copied; in-place modifications will affect the source data"
              @iterable
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an opened stream so that datapipes can recognize
          # streams inside their elements and close them.
          class StreamWrapper
            # @param file_obj [#gets, #close] the stream to wrap
            def initialize(file_obj)
              @file_obj = file_obj
            end

            # Delegates to the wrapped stream's +gets+.
            def gets(...)
              @file_obj.gets(...)
            end

            # Closes the wrapped stream.
            def close
              @file_obj.close
            end

            # Closes every StreamWrapper found in +value+. +value+ may be a single
            # wrapper, or an Array/Hash (nested at most 10 levels deep) that
            # contains wrappers. All other values are ignored.
            #
            # Datapipes such as FileOpener produce (pathname, stream) pairs, so
            # checking only the top-level value would miss the stream.
            #
            # @param depth [Integer] current nesting level; used for recursion
            def self.close_streams(value, depth: 0)
              return if depth > 10

              case value
              when StreamWrapper
                value.close
              when Hash
                value.each_value { |item| close_streams(item, depth: depth + 1) }
              when Array
                value.each { |item| close_streams(item, depth: depth + 1) }
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable datapipes.
        #
        # A subclass can call ::functional_datapipe to expose itself as a
        # chainable method on every IterDataPipe. For example,
        # `functional_datapipe :filter` makes `dp.filter { ... }` construct that
        # subclass with +dp+ as its first argument.
        class IterDataPipe < IterableDataset
          # Registers the calling class under +name+ on IterDataPipe.
          def self.functional_datapipe(name)
            IterDataPipe.register_datapipe_as_function(name, self)
          end

          # Registry of function name (Symbol) => constructor lambda, held on the
          # receiving class (::functional_datapipe always uses IterDataPipe).
          def self.functions
            @functions ||= {}
          end

          # Defines an instance method +function_name+ on the receiver. The method
          # builds +cls_to_register+ with the current datapipe as the first
          # argument and forwards positional arguments, keywords and the block.
          #
          # @raise [Error] if +function_name+ is already registered
          def self.register_datapipe_as_function(function_name, cls_to_register)
            # Normalize so "filter" and :filter count as the same name instead of
            # silently redefining the method.
            function_name = function_name.to_sym
            if functions.include?(function_name)
              raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
            end

            functions[function_name] = lambda do |source_dp, *args, **options, &block|
              cls_to_register.new(source_dp, *args, **options, &block)
            end

            define_method function_name do |*args, **options, &block|
              IterDataPipe.functions[function_name].call(self, *args, **options, &block)
            end
          end

          # Hook run at the start of every #each.
          def reset
            # no-op, but subclasses can override
          end

          # Calls #reset, then yields every element of @source_datapipe. Returns
          # an Enumerator when no block is given.
          #
          # @raise [NotImplementedError] when no @source_datapipe is set and the
          #   subclass does not override #each
          def each(&block)
            return enum_for(:each) unless block

            reset
            # Without a source there is nothing to delegate to; fall back to the
            # abstract IterableDataset#each instead of calling #each on nil.
            return super if @source_datapipe.nil?

            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module Utils
    module Data
      # Abstract dataset whose samples are produced by iteration rather than
      # by index. Subclasses supply #each; the mixed-in Enumerable derives
      # map, select, first, to_a, etc. from it.
      class IterableDataset < Dataset
        include Enumerable

        # Abstract iterator; concrete datasets override this to yield samples.
        #
        # @raise [NotImplementedError] always, in this base class
        def each
          raise(NotImplementedError)
        end
      end
    end
  end
end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.1"
2
+ VERSION = "0.13.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
187
187
  require "torch/utils/data"
188
188
  require "torch/utils/data/data_loader"
189
189
  require "torch/utils/data/dataset"
190
+ require "torch/utils/data/iterable_dataset"
191
+ require "torch/utils/data/data_pipes/iter_data_pipe"
192
+ require "torch/utils/data/data_pipes/filter_iter_data_pipe"
193
+ require "torch/utils/data/data_pipes/iter/file_lister"
194
+ require "torch/utils/data/data_pipes/iter/file_opener"
195
+ require "torch/utils/data/data_pipes/iter/iterable_wrapper"
196
+ require "torch/utils/data/data_pipes/iter/stream_wrapper"
190
197
  require "torch/utils/data/subset"
191
198
  require "torch/utils/data/tensor_dataset"
192
199
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.1
4
+ version: 0.13.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-01-30 00:00:00.000000000 Z
11
+ date: 2023-04-13 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -207,7 +207,14 @@ files:
207
207
  - lib/torch/tensor.rb
208
208
  - lib/torch/utils/data.rb
209
209
  - lib/torch/utils/data/data_loader.rb
210
+ - lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
211
+ - lib/torch/utils/data/data_pipes/iter/file_lister.rb
212
+ - lib/torch/utils/data/data_pipes/iter/file_opener.rb
213
+ - lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
214
+ - lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
215
+ - lib/torch/utils/data/data_pipes/iter_data_pipe.rb
210
216
  - lib/torch/utils/data/dataset.rb
217
+ - lib/torch/utils/data/iterable_dataset.rb
211
218
  - lib/torch/utils/data/subset.rb
212
219
  - lib/torch/utils/data/tensor_dataset.rb
213
220
  - lib/torch/version.rb
@@ -223,14 +230,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
223
230
  requirements:
224
231
  - - ">="
225
232
  - !ruby/object:Gem::Version
226
- version: '2.7'
233
+ version: '3'
227
234
  required_rubygems_version: !ruby/object:Gem::Requirement
228
235
  requirements:
229
236
  - - ">="
230
237
  - !ruby/object:Gem::Version
231
238
  version: '0'
232
239
  requirements: []
233
- rubygems_version: 3.4.1
240
+ rubygems_version: 3.4.10
234
241
  signing_key:
235
242
  specification_version: 4
236
243
  summary: Deep learning for Ruby, powered by LibTorch