torch-rb 0.12.1 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,4 +14,8 @@ void init_backends(Rice::Module& m) {
14
14
  Rice::define_module_under(rb_mBackends, "MKL")
15
15
  .add_handler<torch::Error>(handle_error)
16
16
  .define_singleton_function("available?", &torch::hasMKL);
17
+
18
+ Rice::define_module_under(rb_mBackends, "MPS")
19
+ .add_handler<torch::Error>(handle_error)
20
+ .define_singleton_function("available?", &torch::hasMPS);
17
21
  }
data/ext/torch/tensor.cpp CHANGED
@@ -35,17 +35,17 @@ std::vector<TensorIndex> index_vector(Array a) {
35
35
  if (obj.is_instance_of(rb_cInteger)) {
36
36
  indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
37
37
  } else if (obj.is_instance_of(rb_cRange)) {
38
- torch::optional<int64_t> start_index = torch::nullopt;
39
- torch::optional<int64_t> stop_index = torch::nullopt;
38
+ torch::optional<c10::SymInt> start_index = torch::nullopt;
39
+ torch::optional<c10::SymInt> stop_index = torch::nullopt;
40
40
 
41
41
  Object begin = obj.call("begin");
42
42
  if (!begin.is_nil()) {
43
- start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
43
+ start_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(begin.value()));
44
44
  }
45
45
 
46
46
  Object end = obj.call("end");
47
47
  if (!end.is_nil()) {
48
- stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
48
+ stop_index = c10::SymInt(Rice::detail::From_Ruby<int64_t>().convert(end.value()));
49
49
  }
50
50
 
51
51
  Object exclude_end = obj.call("exclude_end?");
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -0,0 +1,37 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Data pipe that forwards only the items of an upstream pipe for which
        # the block given at construction returns a truthy value. It is
        # registered under the functional name :filter.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # datapipe - upstream object whose #each supplies the items
          # block    - predicate called once per item
          def initialize(datapipe, &block)
            @filter_fn = block
            @datapipe = datapipe
          end

          # Yields every item accepted by the predicate. Rejected items are
          # handed to Iter::StreamWrapper.close_streams instead.
          def each
            @datapipe.each do |item|
              kept = return_if_true(item)
              if non_empty?(kept)
                yield kept
              else
                Iter::StreamWrapper.close_streams(item)
              end
            end
          end

          # Returns data when the predicate is truthy for it, nil otherwise.
          def return_if_true(data)
            @filter_fn.call(data) ? data : nil
          end

          # True for anything except nil.
          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Data pipe that yields the path of every entry found under one or
          # more root directories.
          #
          # Several options are accepted but not implemented: a root that is a
          # regular file, abspath: true, and a non-empty masks value all raise
          # NotImplementedYet (a constant resolved from the enclosing
          # namespace; presumably Torch::NotImplementedYet - confirm).
          class FileLister < IterDataPipe
            # root              - a String path, any object responding to
            #                     #each, or an IterDataPipe producing paths
            # masks             - must respond to #empty?; only an empty value
            #                     is supported
            # recursive         - descend into subdirectories when true
            # abspath           - not implemented; raises when true and at
            #                     least one entry is found
            # non_deterministic - when true, entries are not sorted
            # length            - stored in @length; not used by this class
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields each matching path (joined with its root) for every root
            # produced by the upstream pipe.
            def each(&block)
              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Globs root and yields File.join(root, entry) for every entry
            # accepted by match_masks.
            #
            # Raises NotImplementedYet when root is a regular file, or when
            # abspath is true and the glob returns at least one entry.
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              raise NotImplementedYet if File.file?(root)

              pattern = recursive ? "**/*" : "*"
              # Dir.glob sorts its results by default (Ruby 3.0+). Sorting is
              # skipped only when the caller opts into non-deterministic
              # order, so the default call still returns the sorted listing.
              entries = Dir.glob(pattern, base: root, sort: !non_deterministic)
              entries.each do |f|
                # Checked per entry so an empty directory does not raise.
                raise NotImplementedYet if abspath

                yield File.join(root, f) if match_masks(f, masks)
              end
            end

            # True when no mask is set; mask matching itself is not
            # implemented and raises NotImplementedYet.
            def match_masks(name, masks)
              return true if masks.empty?

              raise NotImplementedYet
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Data pipe that opens every pathname produced by an upstream source
          # and yields the pathname together with a StreamWrapper around the
          # opened file.
          class FileOpener < IterDataPipe
            # datapipe - a single pathname or an Enumerable of pathnames
            # mode     - one of "b", "t", "rb", "rt", "r"
            # encoding - passed to File.open; must be nil for binary modes
            # length   - stored in @length; not used by this class
            #
            # Raises ArgumentError for an unsupported mode, or when an
            # encoding is combined with a binary mode.
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              unless ["b", "t", "rb", "rt", "r"].include?(@mode)
                raise ArgumentError, "Invalid mode #{mode}"
              end

              binary_with_encoding = mode.include?("b") && !encoding.nil?
              raise ArgumentError, "binary mode doesn't take an encoding argument" if binary_with_encoding

              @length = length
            end

            # Yields (pathname, StreamWrapper) for every upstream pathname.
            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # Opens each pathname and yields it with its wrapped file handle.
            # The handle is not closed here; that is left to the consumer.
            #
            # Raises TypeError when a pathname is not a String.
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              # A lone pathname (for example a String) is treated as a
              # one-element list.
              sources = pathnames.is_a?(Enumerable) ? pathnames : [pathnames]

              # The short forms "b" and "t" become "rb" and "rt".
              open_mode = ["b", "t"].include?(mode) ? "r#{mode}" : mode

              sources.each do |path|
                unless path.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{path.class.name}"
                end

                yield path, StreamWrapper.new(File.open(path, open_mode, encoding: encoding))
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Adapts any object that responds to #each into an IterDataPipe.
          class IterableWrapper < IterDataPipe
            # iterable - the object to iterate over
            # deepcopy - when truthy, each pass iterates over a fresh copy
            #            made with a Marshal round trip rather than the
            #            original object
            def initialize(iterable, deepcopy: true)
              @deepcopy = deepcopy
              @iterable = iterable
            end

            # Yields every element of the wrapped iterable (or of its copy).
            def each
              items = @deepcopy ? Marshal.load(Marshal.dump(@iterable)) : @iterable
              items.each { |element| yield element }
            end

            # Delegates to the wrapped iterable's #length.
            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around a file-like object. It exposes #gets and
          # #close, and gives data pipes a marker type they can look for when
          # releasing streams they no longer need.
          class StreamWrapper
            # Nesting depth beyond which close_streams stops descending, so
            # self-referential containers cannot recurse forever.
            MAX_CLOSE_DEPTH = 10

            # file_obj - object responding to #gets and #close
            def initialize(file_obj)
              @file_obj = file_obj
            end

            # Forwards all arguments to the wrapped object's #gets.
            def gets(...)
              @file_obj.gets(...)
            end

            # Closes the wrapped object.
            def close
              @file_obj.close
            end

            # Closes every StreamWrapper reachable from cls on a best-effort
            # basis.
            #
            # cls   - a StreamWrapper, an Array or Hash that may contain
            #         wrappers (Hash values are searched, keys are not), or
            #         any other object, which is ignored
            # depth - current nesting level; callers normally omit it
            #
            # Returns nil.
            def self.close_streams(cls, depth = 0)
              return if depth > MAX_CLOSE_DEPTH

              case cls
              when StreamWrapper
                cls.close
              when Hash
                cls.each_value { |value| close_streams(value, depth + 1) }
              when Array
                cls.each { |value| close_streams(value, depth + 1) }
              end
              nil
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable data pipes. It also keeps a registry of
        # "functional" pipes: a subclass that calls functional_datapipe :name
        # becomes available as an instance method #name on every IterDataPipe.
        class IterDataPipe < IterableDataset
          # Registry mapping a function name to the lambda that builds the
          # registered pipe. Memoized on the receiving class.
          def self.functions
            @functions ||= {}
          end

          # Registers the calling class under name. The registration is
          # always made on IterDataPipe itself.
          def self.functional_datapipe(name)
            IterDataPipe.register_datapipe_as_function(name, self)
          end

          # Stores a builder for cls_to_register under function_name and
          # defines an instance method of that name. The method constructs
          # the pipe with the receiver as its first argument.
          #
          # Raises Error when function_name is already registered.
          def self.register_datapipe_as_function(function_name, cls_to_register)
            if functions.key?(function_name)
              raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
            end

            functions[function_name] = ->(source_dp, *args, **options, &block) do
              cls_to_register.new(source_dp, *args, **options, &block)
            end

            # The builder is looked up at call time rather than captured, so
            # the method always uses the current registry entry.
            define_method(function_name) do |*args, **options, &block|
              IterDataPipe.functions[function_name].call(self, *args, **options, &block)
            end
          end

          # Hook invoked at the start of #each; does nothing unless a
          # subclass overrides it.
          def reset
          end

          # Calls reset, then iterates @source_datapipe with the given block.
          # NOTE(review): nothing in this class assigns @source_datapipe;
          # presumably subclasses either set it or override #each - confirm.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module Utils
    module Data
      # Dataset that is consumed by iteration. Enumerable is mixed in, so a
      # subclass that implements #each gets map, select and the other
      # Enumerable methods as well.
      class IterableDataset < Dataset
        include Enumerable

        # Abstract: subclasses must override. The base implementation always
        # raises NotImplementedError.
        def each
          raise NotImplementedError
        end
      end
    end
  end
end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.1"
2
+ VERSION = "0.13.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
187
187
  require "torch/utils/data"
188
188
  require "torch/utils/data/data_loader"
189
189
  require "torch/utils/data/dataset"
190
+ require "torch/utils/data/iterable_dataset"
191
+ require "torch/utils/data/data_pipes/iter_data_pipe"
192
+ require "torch/utils/data/data_pipes/filter_iter_data_pipe"
193
+ require "torch/utils/data/data_pipes/iter/file_lister"
194
+ require "torch/utils/data/data_pipes/iter/file_opener"
195
+ require "torch/utils/data/data_pipes/iter/iterable_wrapper"
196
+ require "torch/utils/data/data_pipes/iter/stream_wrapper"
190
197
  require "torch/utils/data/subset"
191
198
  require "torch/utils/data/tensor_dataset"
192
199
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.1
4
+ version: 0.13.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-01-30 00:00:00.000000000 Z
11
+ date: 2023-04-13 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -207,7 +207,14 @@ files:
207
207
  - lib/torch/tensor.rb
208
208
  - lib/torch/utils/data.rb
209
209
  - lib/torch/utils/data/data_loader.rb
210
+ - lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
211
+ - lib/torch/utils/data/data_pipes/iter/file_lister.rb
212
+ - lib/torch/utils/data/data_pipes/iter/file_opener.rb
213
+ - lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
214
+ - lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
215
+ - lib/torch/utils/data/data_pipes/iter_data_pipe.rb
210
216
  - lib/torch/utils/data/dataset.rb
217
+ - lib/torch/utils/data/iterable_dataset.rb
211
218
  - lib/torch/utils/data/subset.rb
212
219
  - lib/torch/utils/data/tensor_dataset.rb
213
220
  - lib/torch/version.rb
@@ -223,14 +230,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
223
230
  requirements:
224
231
  - - ">="
225
232
  - !ruby/object:Gem::Version
226
- version: '2.7'
233
+ version: '3'
227
234
  required_rubygems_version: !ruby/object:Gem::Requirement
228
235
  requirements:
229
236
  - - ">="
230
237
  - !ruby/object:Gem::Version
231
238
  version: '0'
232
239
  requirements: []
233
- rubygems_version: 3.4.1
240
+ rubygems_version: 3.4.10
234
241
  signing_key:
235
242
  specification_version: 4
236
243
  summary: Deep learning for Ruby, powered by LibTorch