torch-rb 0.12.0 → 0.12.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
4
- data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
3
+ metadata.gz: 8160298f6299869f699201dd9217785da5ac7d2562f05b12a2df2d4537143886
4
+ data.tar.gz: cb902f45cb72adaaa2008e181af2af0844e926ef62f1d2ccaa67251cb70c5799
5
5
  SHA512:
6
- metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
7
- data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
6
+ metadata.gz: c5077c40cb32414ac7d230430fdcf1cd93f199a81100179a1aa26185d6216c049316bbf35d79968442fd4db5e18414991946a366e6e7e6260a904f001774faa2
7
+ data.tar.gz: b4f1ff33b47596953c0654b07f1916a88ba247af72c1d4693d9b737802a790256d3232c0419c3a603fdebd51f47360833265326a1f27f8242604ff05766ea690
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ ## 0.12.2 (2023-01-30)
2
+
3
+ - Added experimental support for DataPipes
4
+
5
+ ## 0.12.1 (2023-01-29)
6
+
7
+ - Added `Generator` class
8
+
1
9
  ## 0.12.0 (2022-11-05)
2
10
 
3
11
  - Updated LibTorch to 1.13.0
data/ext/torch/ext.cpp CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
12
12
  void init_backends(Rice::Module& m);
13
13
  void init_cuda(Rice::Module& m);
14
14
  void init_device(Rice::Module& m);
15
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
15
16
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
16
17
  void init_random(Rice::Module& m);
17
18
 
@@ -23,6 +24,7 @@ void Init_ext()
23
24
  // need to define certain classes up front to keep Rice happy
24
25
  auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
25
26
  .define_constructor(Rice::Constructor<torch::IValue>());
27
+ auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
26
28
  auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
27
29
  auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
28
30
  .define_constructor(Rice::Constructor<torch::TensorOptions>());
@@ -38,6 +40,7 @@ void Init_ext()
38
40
  init_backends(m);
39
41
  init_cuda(m);
40
42
  init_device(m);
43
+ init_generator(m, rb_cGenerator);
41
44
  init_ivalue(m, rb_cIValue);
42
45
  init_random(m);
43
46
  }
@@ -0,0 +1,50 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
8
+ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
9
+ rb_cGenerator
10
+ .add_handler<torch::Error>(handle_error)
11
+ .define_singleton_function(
12
+ "new",
13
+ []() {
14
+ // TODO support more devices
15
+ return torch::make_generator<torch::CPUGeneratorImpl>();
16
+ })
17
+ .define_method(
18
+ "device",
19
+ [](torch::Generator& self) {
20
+ return self.device();
21
+ })
22
+ .define_method(
23
+ "initial_seed",
24
+ [](torch::Generator& self) {
25
+ return self.current_seed();
26
+ })
27
+ .define_method(
28
+ "manual_seed",
29
+ [](torch::Generator& self, uint64_t seed) {
30
+ self.set_current_seed(seed);
31
+ return self;
32
+ })
33
+ .define_method(
34
+ "seed",
35
+ [](torch::Generator& self) {
36
+ return self.seed();
37
+ })
38
+ .define_method(
39
+ "state",
40
+ [](torch::Generator& self) {
41
+ return self.get_state();
42
+ })
43
+ .define_method(
44
+ "state=",
45
+ [](torch::Generator& self, const torch::Tensor& state) {
46
+ self.set_state(state);
47
+ });
48
+
49
+ THPGeneratorClass = rb_cGenerator.value();
50
+ }
@@ -2,6 +2,7 @@
2
2
 
3
3
  #include "ruby_arg_parser.h"
4
4
 
5
+ VALUE THPGeneratorClass = Qnil;
5
6
  VALUE THPVariableClass = Qnil;
6
7
 
7
8
  static std::unordered_map<std::string, ParameterType> type_map = {
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
244
245
  return size > 0 && FIXNUM_P(obj);
245
246
  }
246
247
  case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
247
- case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
248
+ case ParameterType::GENERATOR: return THPGenerator_Check(obj);
248
249
  case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
249
250
  case ParameterType::STORAGE: return false; // return isStorage(obj);
250
251
  // case ParameterType::PYOBJECT: return true;
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
223
223
 
224
224
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
225
225
  if (NIL_P(args[i])) return c10::nullopt;
226
- throw std::runtime_error("generator not supported yet");
226
+ return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
227
227
  }
228
228
 
229
229
  inline at::Storage RubyArgs::storage(int i) {
data/ext/torch/utils.h CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
17
17
 
18
18
  // keep THP prefix for now to make it easier to compare code
19
19
 
20
+ extern VALUE THPGeneratorClass;
20
21
  extern VALUE THPVariableClass;
21
22
 
22
23
  inline VALUE THPUtils_internSymbol(const std::string& str) {
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
44
45
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
45
46
  }
46
47
 
48
+ inline bool THPGenerator_Check(VALUE obj) {
49
+ return rb_obj_is_kind_of(obj, THPGeneratorClass);
50
+ }
51
+
47
52
  inline bool THPVariable_Check(VALUE obj) {
48
53
  return rb_obj_is_kind_of(obj, THPVariableClass);
49
54
  }
@@ -0,0 +1,37 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Keeps only the elements of +datapipe+ for which the block returns a
        # truthy value. Registered as IterDataPipe#filter, so
        # `datapipe.filter { |x| ... }` builds one of these.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # datapipe - source pipe (anything responding to #each)
          # block    - predicate called with each element
          def initialize(datapipe, &block)
            @datapipe = datapipe
            @filter_fn = block
          end

          # Yields every element the predicate accepts (including nil/false
          # elements). Rejected elements are handed to
          # Iter::StreamWrapper.close_streams so any wrapped streams can be
          # released. Returns an Enumerator when called without a block.
          def each
            return enum_for(:each) unless block_given?

            @datapipe.each do |*values|
              # A source may yield several values at once (FileOpener yields a
              # pathname and a stream). A plain |data| block would keep only the
              # first one, so pack multiple values into one Array instead.
              data = values.size <= 1 ? values.first : values
              if @filter_fn.call(data)
                yield data
              else
                Iter::StreamWrapper.close_streams(data)
              end
            end
          end

          # Returns +data+ when the predicate accepts it, nil otherwise.
          # Kept for compatibility. #each tests the predicate directly, because
          # this return value cannot distinguish an accepted nil.
          def return_if_true(data)
            condition = @filter_fn.call(data)

            data if condition
          end

          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Yields the paths of regular files found under each root path,
          # optionally filtered by shell-style masks (port of PyTorch's
          # FileLister).
          class FileLister < IterDataPipe
            # root               - a path String, an Array of paths, or an
            #                      IterDataPipe yielding paths
            # masks              - a glob String or Array of globs matched against
            #                      each file's basename; "" / [] / nil match all
            # recursive:         - also list files inside subdirectories
            # abspath:           - yield absolute paths
            # non_deterministic: - when true, the listing is not sorted
            # length:            - stored in @length; not read by this class
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            # Yields one path String per matching file.
            # Returns an Enumerator when called without a block.
            def each(&block)
              return enum_for(:each) unless block

              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Yields matching file paths for a single root. A root that is itself
            # a file is yielded directly (if it matches the masks).
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              if File.file?(root)
                path = abspath ? File.expand_path(root) : root
                yield path if match_masks(File.basename(path), masks)
                return
              end

              base = abspath ? File.expand_path(root) : root
              pattern = recursive ? "**/*" : "*"
              # Only regular files are listed. Directories would break consumers
              # such as FileOpener, which opens every path it receives.
              paths = Dir.glob(pattern, base: root).select { |f| File.file?(File.join(root, f)) }
              # Deterministic (sorted) order unless the caller opted out.
              paths = paths.sort unless non_deterministic
              paths.each do |f|
                yield File.join(base, f) if match_masks(File.basename(f), masks)
              end
            end

            # True when +name+ matches +masks+ (a glob String or an Array of
            # globs). An empty or nil mask matches everything.
            def match_masks(name, masks)
              return true if masks.nil? || masks.empty?

              Array(masks).any? { |mask| File.fnmatch(mask, name) }
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Opens every pathname produced by +datapipe+ and yields that pathname
          # together with a StreamWrapper around the opened File. Nothing here
          # closes the streams.
          class FileOpener < IterDataPipe
            # datapipe  - source of pathname Strings (or a single pathname)
            # mode:     - one of "r", "rb", "rt", "b", "t"
            # encoding: - forwarded to File.open; rejected for binary modes
            # length:   - stored in @length; not read by this class
            # Raises ArgumentError for an unknown mode, or for a binary mode
            # combined with an encoding.
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              raise ArgumentError, "Invalid mode #{mode}" unless %w[b t rb rt r].include?(@mode)

              if !encoding.nil? && mode.include?("b")
                raise ArgumentError, "binary mode doesn't take an encoding argument"
              end

              @length = length
            end

            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            # Yields (pathname, StreamWrapper) for each entry; the shorthand
            # modes "b" / "t" are expanded to "rb" / "rt".
            # Raises TypeError when an entry is not a String.
            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              pathnames = [pathnames] unless pathnames.is_a?(Enumerable)
              mode = "r#{mode}" if %w[b t].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end

                stream = StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
                yield pathname, stream
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Wraps any object responding to #each (Array, Range, ...) as an
          # IterDataPipe. By default every pass iterates over a Marshal deep
          # copy, so in-place changes made by consumers don't reach +iterable+.
          class IterableWrapper < IterDataPipe
            # iterable  - the source collection
            # deepcopy: - copy +iterable+ before each pass
            def initialize(iterable, deepcopy: true)
              super()
              @iterable = iterable
              @deepcopy = deepcopy
            end

            # Yields each element of the (possibly copied) iterable.
            # Returns an Enumerator when called without a block.
            def each
              return enum_for(:each) unless block_given?

              source_data = @iterable
              if @deepcopy
                begin
                  source_data = Marshal.load(Marshal.dump(@iterable))
                rescue TypeError
                  # Marshal cannot copy procs, IO objects, singletons, etc.
                  # Like PyTorch, fall back to the original instead of failing.
                  warn "IterableWrapper: the input iterable can not be deep-copied; " \
                       "in-place modifications will affect the source data"
                end
              end
              source_data.each do |data|
                yield data
              end
            end

            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an open IO-like object. #gets and #close are
          # defined explicitly; any other public method the wrapped object
          # responds to (read, each_line, eof?, closed?, ...) is forwarded.
          class StreamWrapper
            def initialize(file_obj)
              @file_obj = file_obj
            end

            def gets(...)
              @file_obj.gets(...)
            end

            def close
              @file_obj.close
            end

            # Forward unknown methods to the wrapped object when it supports them.
            def method_missing(name, *args, &block)
              if @file_obj.respond_to?(name)
                @file_obj.public_send(name, *args, &block)
              else
                super
              end
            end
            # Pass keyword arguments through unchanged on Ruby 2.7+.
            ruby2_keywords(:method_missing) if respond_to?(:ruby2_keywords, true)

            def respond_to_missing?(name, include_private = false)
              @file_obj.respond_to?(name) || super
            end

            # Closes every StreamWrapper found in +obj+, descending into Array
            # elements and Hash values. Any other object is ignored. Recursion
            # stops once +depth+ exceeds 10, guarding against cyclic structures.
            def self.close_streams(obj, depth = 0)
              return if depth > 10

              case obj
              when StreamWrapper
                obj.close
              when Hash
                obj.each_value { |value| close_streams(value, depth + 1) }
              when Array
                obj.each { |value| close_streams(value, depth + 1) }
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable-style DataPipes. Subclasses can expose
        # themselves as chainable instance methods through functional_datapipe.
        class IterDataPipe < IterableDataset
          # Makes `IterDataPipe#name(*args, &block)` build an instance of the
          # calling subclass, with the receiver as its first argument.
          def self.functional_datapipe(name)
            IterDataPipe.register_datapipe_as_function(name, self)
          end

          # Registry of function name => constructor lambda, memoized per class.
          def self.functions
            @functions ||= {}
          end

          # Stores a constructor lambda for +cls_to_register+ and defines an
          # instance method +function_name+ that calls it. Raises Error
          # (resolved lexically, presumably Torch::Error) on a duplicate name.
          def self.register_datapipe_as_function(function_name, cls_to_register)
            if functions.key?(function_name)
              raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
            end

            functions[function_name] = lambda do |source_dp, *args, **options, &block|
              cls_to_register.new(source_dp, *args, **options, &block)
            end

            define_method(function_name) do |*args, **options, &block|
              IterDataPipe.functions[function_name].call(self, *args, **options, &block)
            end
          end

          # Called at the start of every #each; intentionally empty here.
          def reset
            # no-op, but subclasses can override
          end

          # Default iteration: reset, then delegate to @source_datapipe. This
          # class never assigns that variable, so subclasses either set it or
          # override #each.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module Utils
    module Data
      # Abstract dataset that is consumed by iteration instead of indexing.
      # Including Enumerable gives subclasses map/select/to_a/... built on #each.
      class IterableDataset < Dataset
        include Enumerable

        # Abstract: subclasses must override this to yield their samples.
        def each
          raise NotImplementedError
        end
      end
    end
  end
end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.0"
2
+ VERSION = "0.12.2"
3
3
  end
data/lib/torch.rb CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
187
187
  require "torch/utils/data"
188
188
  require "torch/utils/data/data_loader"
189
189
  require "torch/utils/data/dataset"
190
+ require "torch/utils/data/iterable_dataset"
191
+ require "torch/utils/data/data_pipes/iter_data_pipe"
192
+ require "torch/utils/data/data_pipes/filter_iter_data_pipe"
193
+ require "torch/utils/data/data_pipes/iter/file_lister"
194
+ require "torch/utils/data/data_pipes/iter/file_opener"
195
+ require "torch/utils/data/data_pipes/iter/iterable_wrapper"
196
+ require "torch/utils/data/data_pipes/iter/stream_wrapper"
190
197
  require "torch/utils/data/subset"
191
198
  require "torch/utils/data/tensor_dataset"
192
199
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.0
4
+ version: 0.12.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-11-05 00:00:00.000000000 Z
11
+ date: 2023-01-31 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -44,6 +44,7 @@ files:
44
44
  - ext/torch/extconf.rb
45
45
  - ext/torch/fft.cpp
46
46
  - ext/torch/fft_functions.h
47
+ - ext/torch/generator.cpp
47
48
  - ext/torch/ivalue.cpp
48
49
  - ext/torch/linalg.cpp
49
50
  - ext/torch/linalg_functions.h
@@ -206,7 +207,14 @@ files:
206
207
  - lib/torch/tensor.rb
207
208
  - lib/torch/utils/data.rb
208
209
  - lib/torch/utils/data/data_loader.rb
210
+ - lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
211
+ - lib/torch/utils/data/data_pipes/iter/file_lister.rb
212
+ - lib/torch/utils/data/data_pipes/iter/file_opener.rb
213
+ - lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
214
+ - lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
215
+ - lib/torch/utils/data/data_pipes/iter_data_pipe.rb
209
216
  - lib/torch/utils/data/dataset.rb
217
+ - lib/torch/utils/data/iterable_dataset.rb
210
218
  - lib/torch/utils/data/subset.rb
211
219
  - lib/torch/utils/data/tensor_dataset.rb
212
220
  - lib/torch/version.rb
@@ -229,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
229
237
  - !ruby/object:Gem::Version
230
238
  version: '0'
231
239
  requirements: []
232
- rubygems_version: 3.3.7
240
+ rubygems_version: 3.4.1
233
241
  signing_key:
234
242
  specification_version: 4
235
243
  summary: Deep learning for Ruby, powered by LibTorch