torch-rb 0.12.0 → 0.12.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
4
- data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
3
+ metadata.gz: 8160298f6299869f699201dd9217785da5ac7d2562f05b12a2df2d4537143886
4
+ data.tar.gz: cb902f45cb72adaaa2008e181af2af0844e926ef62f1d2ccaa67251cb70c5799
5
5
  SHA512:
6
- metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
7
- data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
6
+ metadata.gz: c5077c40cb32414ac7d230430fdcf1cd93f199a81100179a1aa26185d6216c049316bbf35d79968442fd4db5e18414991946a366e6e7e6260a904f001774faa2
7
+ data.tar.gz: b4f1ff33b47596953c0654b07f1916a88ba247af72c1d4693d9b737802a790256d3232c0419c3a603fdebd51f47360833265326a1f27f8242604ff05766ea690
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ ## 0.12.2 (2023-01-30)
2
+
3
+ - Added experimental support for DataPipes
4
+
5
+ ## 0.12.1 (2023-01-29)
6
+
7
+ - Added `Generator` class
8
+
1
9
  ## 0.12.0 (2022-11-05)
2
10
 
3
11
  - Updated LibTorch to 1.13.0
data/ext/torch/ext.cpp CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
12
12
  void init_backends(Rice::Module& m);
13
13
  void init_cuda(Rice::Module& m);
14
14
  void init_device(Rice::Module& m);
15
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
15
16
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
16
17
  void init_random(Rice::Module& m);
17
18
 
@@ -23,6 +24,7 @@ void Init_ext()
23
24
  // need to define certain classes up front to keep Rice happy
24
25
  auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
25
26
  .define_constructor(Rice::Constructor<torch::IValue>());
27
+ auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
26
28
  auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
27
29
  auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
28
30
  .define_constructor(Rice::Constructor<torch::TensorOptions>());
@@ -38,6 +40,7 @@ void Init_ext()
38
40
  init_backends(m);
39
41
  init_cuda(m);
40
42
  init_device(m);
43
+ init_generator(m, rb_cGenerator);
41
44
  init_ivalue(m, rb_cIValue);
42
45
  init_random(m);
43
46
  }
@@ -0,0 +1,50 @@
1
#include <mutex>

#include <torch/torch.h>

#include <rice/rice.hpp>

#include "utils.h"
6
+
7
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
8
+ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
9
+ rb_cGenerator
10
+ .add_handler<torch::Error>(handle_error)
11
+ .define_singleton_function(
12
+ "new",
13
+ []() {
14
+ // TODO support more devices
15
+ return torch::make_generator<torch::CPUGeneratorImpl>();
16
+ })
17
+ .define_method(
18
+ "device",
19
+ [](torch::Generator& self) {
20
+ return self.device();
21
+ })
22
+ .define_method(
23
+ "initial_seed",
24
+ [](torch::Generator& self) {
25
+ return self.current_seed();
26
+ })
27
+ .define_method(
28
+ "manual_seed",
29
+ [](torch::Generator& self, uint64_t seed) {
30
+ self.set_current_seed(seed);
31
+ return self;
32
+ })
33
+ .define_method(
34
+ "seed",
35
+ [](torch::Generator& self) {
36
+ return self.seed();
37
+ })
38
+ .define_method(
39
+ "state",
40
+ [](torch::Generator& self) {
41
+ return self.get_state();
42
+ })
43
+ .define_method(
44
+ "state=",
45
+ [](torch::Generator& self, const torch::Tensor& state) {
46
+ self.set_state(state);
47
+ });
48
+
49
+ THPGeneratorClass = rb_cGenerator.value();
50
+ }
@@ -2,6 +2,7 @@
2
2
 
3
3
  #include "ruby_arg_parser.h"
4
4
 
5
+ VALUE THPGeneratorClass = Qnil;
5
6
  VALUE THPVariableClass = Qnil;
6
7
 
7
8
  static std::unordered_map<std::string, ParameterType> type_map = {
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
244
245
  return size > 0 && FIXNUM_P(obj);
245
246
  }
246
247
  case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
247
- case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
248
+ case ParameterType::GENERATOR: return THPGenerator_Check(obj);
248
249
  case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
249
250
  case ParameterType::STORAGE: return false; // return isStorage(obj);
250
251
  // case ParameterType::PYOBJECT: return true;
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
223
223
 
224
224
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
225
225
  if (NIL_P(args[i])) return c10::nullopt;
226
- throw std::runtime_error("generator not supported yet");
226
+ return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
227
227
  }
228
228
 
229
229
  inline at::Storage RubyArgs::storage(int i) {
data/ext/torch/utils.h CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
17
17
 
18
18
  // keep THP prefix for now to make it easier to compare code
19
19
 
20
+ extern VALUE THPGeneratorClass;
20
21
  extern VALUE THPVariableClass;
21
22
 
22
23
  inline VALUE THPUtils_internSymbol(const std::string& str) {
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
44
45
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
45
46
  }
46
47
 
48
+ inline bool THPGenerator_Check(VALUE obj) {
49
+ return rb_obj_is_kind_of(obj, THPGeneratorClass);
50
+ }
51
+
47
52
  inline bool THPVariable_Check(VALUE obj) {
48
53
  return rb_obj_is_kind_of(obj, THPVariableClass);
49
54
  }
@@ -0,0 +1,37 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Lazily keeps only the upstream elements for which the given block is
        # truthy. Registered as IterDataPipe#filter.
        #
        # Elements whose first value is nil are always dropped. Any StreamWrapper
        # carried by a rejected element is closed so upstream file handles are
        # not leaked.
        class FilterIterDataPipe < IterDataPipe
          functional_datapipe :filter

          # datapipe - upstream datapipe to read from
          # block    - predicate; receives exactly the values the upstream yields
          def initialize(datapipe, &block)
            @datapipe = datapipe
            @filter_fn = block
          end

          def each
            # Capture every yielded value. Some upstreams (e.g. FileOpener) yield
            # several values per element, and a single-parameter block would keep
            # only the first, losing the rest downstream.
            @datapipe.each do |*values|
              if @filter_fn.call(*values) && non_empty?(values.first)
                yield(*values)
              else
                values.each { |value| Iter::StreamWrapper.close_streams(value) }
              end
            end
          end

          # Returns data when the predicate accepts it, nil otherwise.
          def return_if_true(data)
            condition = @filter_fn.call(data)

            data if condition
          end

          def non_empty?(data)
            !data.nil?
          end
        end
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Yields the paths of files found under one or more roots.
          #
          # root              - a path String, an IterDataPipe of paths, or any
          #                     other Enumerable of paths (wrapped in IterableWrapper)
          # masks             - glob pattern String, or Array of patterns, matched
          #                     against each file's basename; "" or [] matches all
          # recursive         - also list files inside subdirectories
          # abspath           - yield absolute paths (File.expand_path)
          # non_deterministic - skip the explicit sort of each directory listing
          # length            - stored as-is; not used by this class
          class FileLister < IterDataPipe
            def initialize(
              root = ".",
              masks = "",
              recursive: false,
              abspath: false,
              non_deterministic: false,
              length: -1
            )
              super()
              root = [root] if root.is_a?(String)
              root = IterableWrapper.new(root) unless root.is_a?(IterDataPipe)
              @datapipe = root
              @masks = masks
              @recursive = recursive
              @abspath = abspath
              @non_deterministic = non_deterministic
              @length = length
            end

            def each(&block)
              @datapipe.each do |path|
                get_file_pathnames_from_root(path, @masks, recursive: @recursive, abspath: @abspath, non_deterministic: @non_deterministic, &block)
              end
            end

            private

            # Yields matching file paths for a single root. A root that is itself a
            # file is yielded on its own (when it matches the masks). Otherwise the
            # directory is globbed and only regular files are kept.
            def get_file_pathnames_from_root(
              root,
              masks,
              recursive: false,
              abspath: false,
              non_deterministic: false
            )
              if File.file?(root)
                path = abspath ? File.expand_path(root) : root
                yield path if match_masks(File.basename(path), masks)
                return
              end

              pattern = recursive ? "**/*" : "*"
              # Dir.glob also returns directory entries; only files are listed
              paths = Dir.glob(pattern, base: root).select { |f| File.file?(File.join(root, f)) }
              # sort for a reproducible order unless the caller opted out
              paths = paths.sort unless non_deterministic
              paths.each do |f|
                next unless match_masks(File.basename(f), masks)

                path = File.join(root, f)
                yield(abspath ? File.expand_path(path) : path)
              end
            end

            # True when masks is empty, or when name matches any of the masks.
            # FNM_DOTMATCH lets wildcards match a leading ".".
            def match_masks(name, masks)
              return true if masks.empty?

              Array(masks).any? { |mask| File.fnmatch?(mask, name, File::FNM_DOTMATCH) }
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,51 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Opens every pathname produced by an upstream datapipe and yields two
          # values per file: the pathname and a StreamWrapper around the open File.
          # Streams are not closed here; whoever consumes them owns them.
          #
          # datapipe - source of pathname Strings (a lone non-Enumerable is wrapped)
          # mode     - one of "b", "t", "rb", "rt", "r" ("b"/"t" get an "r" prefix)
          # encoding - passed to File.open; not allowed with binary modes
          # length   - stored as-is
          class FileOpener < IterDataPipe
            def initialize(datapipe, mode: "r", encoding: nil, length: -1)
              super()
              @datapipe = datapipe
              @mode = mode
              @encoding = encoding

              unless %w[b t rb rt r].include?(@mode)
                raise ArgumentError, "Invalid mode #{mode}"
              end

              if !encoding.nil? && mode.include?("b")
                raise ArgumentError, "binary mode doesn't take an encoding argument"
              end

              @length = length
            end

            def each(&block)
              get_file_binaries_from_pathnames(@datapipe, @mode, encoding: @encoding, &block)
            end

            private

            def get_file_binaries_from_pathnames(pathnames, mode, encoding: nil)
              pathnames = [pathnames] unless pathnames.is_a?(Enumerable)
              mode = "r" + mode if %w[b t].include?(mode)

              pathnames.each do |pathname|
                unless pathname.is_a?(String)
                  raise TypeError, "Expected string type for pathname, but got #{pathname.class.name}"
                end

                stream = StreamWrapper.new(File.open(pathname, mode, encoding: encoding))
                yield pathname, stream
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Wraps any object responding to #each as an IterDataPipe.
          #
          # iterable - source of elements
          # deepcopy - when true, each pass iterates a Marshal round-trip copy so
          #            in-place changes made downstream don't reach the source.
          #            Objects Marshal cannot dump are iterated directly, with a warning.
          class IterableWrapper < IterDataPipe
            def initialize(iterable, deepcopy: true)
              @iterable = iterable
              @deepcopy = deepcopy
            end

            def each
              source_data = @iterable
              if @deepcopy
                begin
                  source_data = Marshal.load(Marshal.dump(@iterable))
                rescue TypeError
                  # Marshal raises TypeError for procs, IO, Enumerators, etc.;
                  # iterate the original rather than failing outright
                  warn "The input iterable can not be deepcopied, so in-place modification would affect source data"
                end
              end
              source_data.each do |data|
                yield data
              end
            end

            # Number of elements; requires the wrapped iterable to respond to #length.
            def length
              @iterable.length
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        module Iter
          # Thin wrapper around an open IO object so datapipes can recognize the
          # streams they pass around and close them when they are discarded.
          class StreamWrapper
            def initialize(file_obj)
              @file_obj = file_obj
            end

            # Forwards to the wrapped object's #gets with the same arguments.
            def gets(...)
              @file_obj.gets(...)
            end

            def close
              @file_obj.close
            end

            # Closes every StreamWrapper found in obj. obj may be a wrapper itself
            # or an Array/Hash (Hash values only) nesting wrappers, e.g. a
            # [pathname, stream] pair. Other objects are left untouched. Recursion
            # stops past depth 10 as a guard against cyclic or pathological
            # structures.
            def self.close_streams(obj, depth = 0)
              return if depth > 10

              case obj
              when StreamWrapper
                obj.close
              when Hash
                obj.each_value { |value| close_streams(value, depth + 1) }
              when Array
                obj.each { |value| close_streams(value, depth + 1) }
              end
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module Torch
  module Utils
    module Data
      module DataPipes
        # Base class for iterable-style datapipes.
        #
        # A subclass that calls `functional_datapipe :name` becomes reachable as a
        # chainable instance method on every IterDataPipe, e.g. `dp.filter { ... }`.
        class IterDataPipe < IterableDataset
          class << self
            # Registers the calling class under +name+ on IterDataPipe.
            def functional_datapipe(name)
              IterDataPipe.register_datapipe_as_function(name, self)
            end

            # Registry of function name => constructor lambda, kept on the class
            # it is called on.
            def functions
              @functions ||= {}
            end

            # Defines instance method +function_name+, which builds
            # cls_to_register.new(receiver, *args, **options, &block).
            # Raises Error when the name is already registered.
            def register_datapipe_as_function(function_name, cls_to_register)
              if functions.key?(function_name)
                raise Error, "Unable to add DataPipe function name #{function_name} as it is already taken"
              end

              functions[function_name] = ->(source_dp, *args, **options, &block) do
                cls_to_register.new(source_dp, *args, **options, &block)
              end

              define_method(function_name) do |*args, **options, &block|
                IterDataPipe.functions[function_name].call(self, *args, **options, &block)
              end
            end
          end

          # Hook invoked at the start of each #each; subclasses may override.
          def reset
            # no-op, but subclasses can override
          end

          # Default iteration: reset, then delegate to @source_datapipe.
          def each(&block)
            reset
            @source_datapipe.each(&block)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module Utils
    module Data
      # Abstract dataset consumed by iteration rather than by index.
      # Subclasses implement #each; Enumerable builds map/select/first/...
      # on top of it.
      class IterableDataset < Dataset
        include Enumerable

        # Must be overridden to yield the dataset's elements.
        def each
          fail NotImplementedError
        end
      end
    end
  end
end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.0"
2
+ VERSION = "0.12.2"
3
3
  end
data/lib/torch.rb CHANGED
@@ -187,6 +187,13 @@ require "torch/nn/init"
187
187
  require "torch/utils/data"
188
188
  require "torch/utils/data/data_loader"
189
189
  require "torch/utils/data/dataset"
190
+ require "torch/utils/data/iterable_dataset"
191
+ require "torch/utils/data/data_pipes/iter_data_pipe"
192
+ require "torch/utils/data/data_pipes/filter_iter_data_pipe"
193
+ require "torch/utils/data/data_pipes/iter/file_lister"
194
+ require "torch/utils/data/data_pipes/iter/file_opener"
195
+ require "torch/utils/data/data_pipes/iter/iterable_wrapper"
196
+ require "torch/utils/data/data_pipes/iter/stream_wrapper"
190
197
  require "torch/utils/data/subset"
191
198
  require "torch/utils/data/tensor_dataset"
192
199
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.0
4
+ version: 0.12.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-11-05 00:00:00.000000000 Z
11
+ date: 2023-01-31 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -44,6 +44,7 @@ files:
44
44
  - ext/torch/extconf.rb
45
45
  - ext/torch/fft.cpp
46
46
  - ext/torch/fft_functions.h
47
+ - ext/torch/generator.cpp
47
48
  - ext/torch/ivalue.cpp
48
49
  - ext/torch/linalg.cpp
49
50
  - ext/torch/linalg_functions.h
@@ -206,7 +207,14 @@ files:
206
207
  - lib/torch/tensor.rb
207
208
  - lib/torch/utils/data.rb
208
209
  - lib/torch/utils/data/data_loader.rb
210
+ - lib/torch/utils/data/data_pipes/filter_iter_data_pipe.rb
211
+ - lib/torch/utils/data/data_pipes/iter/file_lister.rb
212
+ - lib/torch/utils/data/data_pipes/iter/file_opener.rb
213
+ - lib/torch/utils/data/data_pipes/iter/iterable_wrapper.rb
214
+ - lib/torch/utils/data/data_pipes/iter/stream_wrapper.rb
215
+ - lib/torch/utils/data/data_pipes/iter_data_pipe.rb
209
216
  - lib/torch/utils/data/dataset.rb
217
+ - lib/torch/utils/data/iterable_dataset.rb
210
218
  - lib/torch/utils/data/subset.rb
211
219
  - lib/torch/utils/data/tensor_dataset.rb
212
220
  - lib/torch/version.rb
@@ -229,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
229
237
  - !ruby/object:Gem::Version
230
238
  version: '0'
231
239
  requirements: []
232
- rubygems_version: 3.3.7
240
+ rubygems_version: 3.4.1
233
241
  signing_key:
234
242
  specification_version: 4
235
243
  summary: Deep learning for Ruby, powered by LibTorch