torch-rb 0.9.2 → 0.10.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -88,18 +88,18 @@ struct RubyArgs {
   inline c10::optional<at::Generator> generator(int i);
   inline at::Storage storage(int i);
   inline at::ScalarType scalartype(int i);
-  // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
+  inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
   inline c10::optional<at::ScalarType> scalartypeOptional(int i);
   inline c10::optional<at::Scalar> scalarOptional(int i);
   inline c10::optional<int64_t> toInt64Optional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
   inline c10::OptionalArray<double> doublelistOptional(int i);
-  // inline at::Layout layout(int i);
-  // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
+  inline at::Layout layout(int i);
+  inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
   inline c10::optional<at::Layout> layoutOptional(int i);
   inline at::Device device(int i);
-  // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
+  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
   // inline c10::optional<at::Device> deviceOptional(int i);
   // inline at::Dimname dimname(int i);
   // inline std::vector<at::Dimname> dimnamelist(int i);
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
   return it->second;
 }
 
+inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
+  if (NIL_P(args[i])) return default_scalartype;
+  return scalartype(i);
+}
+
 inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
   return scalartype(i);
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
   return res;
 }
 
-inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
-  if (NIL_P(args[i])) return c10::nullopt;
+inline at::Layout RubyArgs::layout(int i) {
+  if (NIL_P(args[i])) return signature.params[i].default_layout;
 
   static std::unordered_map<VALUE, Layout> layout_map = {
     {ID2SYM(rb_intern("strided")), Layout::Strided},
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
   return it->second;
 }
 
+inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
+  if (NIL_P(args[i])) return default_layout;
+  return layout(i);
+}
+
+inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
+  return layout(i);
+}
+
 inline at::Device RubyArgs::device(int i) {
   if (NIL_P(args[i])) {
     return at::Device("cpu");
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
   return at::Device(device_str);
 }
 
+inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
+  if (NIL_P(args[i])) return default_device;
+  return device(i);
+}
+
 inline at::MemoryFormat RubyArgs::memoryformat(int i) {
   if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
   throw std::runtime_error("memoryformat not supported yet");
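
Taken together, the newly enabled helpers give the generated bindings a uniform nil-means-default rule: each *WithDefault method returns its caller-supplied default when the Ruby argument is nil, and otherwise delegates to the strict parser. At the Ruby level this is what backs optional dtype:, layout:, and device: keywords on factory functions. A minimal usage sketch (the defaults shown are assumptions, not taken from the diff):

    require "torch"

    # Omitted keywords reach the C++ parser as nil, so the *WithDefault
    # helpers above substitute the default rather than raising.
    a = Torch.ones(2, 3)                                  # default dtype, layout, device
    b = Torch.ones(2, 3, dtype: :float64, device: "cpu")  # explicit overrides
    a.dtype  # => :float32 (assumed default)
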
data/ext/torch/sparse_functions.h ADDED
@@ -0,0 +1,6 @@
+// generated by rake generate:functions
+// do not edit by hand
+
+#pragma once
+
+void add_sparse_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,8 +1,15 @@
 #pragma once
 
+#include <torch/torch.h>
+
 #include <rice/rice.hpp>
 #include <rice/stl.hpp>
 
+static_assert(
+  TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
+  "Incompatible LibTorch version"
+);
+
 // TODO find better place
 inline void handle_error(torch::Error const & ex) {
   throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
}
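
This compile-time guard pins the extension to the LibTorch 1.11 series: building against headers that report any other TORCH_VERSION_MAJOR/TORCH_VERSION_MINOR now stops with "Incompatible LibTorch version" instead of producing a binary that misbehaves at runtime.
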
data/lib/torch/nn/parameter_list.rb ADDED
@@ -0,0 +1,48 @@
+module Torch
+  module NN
+    class ParameterList < Module
+      include Enumerable
+
+      def initialize(parameters)
+        super()
+        @initialized = true
+        unless parameters.nil?
+          concat(parameters)
+        end
+      end
+
+      def length
+        @parameters.length
+      end
+      alias_method :count, :length
+      alias_method :size, :length
+
+      def concat(parameters)
+        unless parameters.is_a?(Enumerable)
+          raise TypeError, "ParameterList#concat should be called with an enumerable, but got #{parameters.class.name}"
+        end
+        offset = length
+        parameters.each_with_index do |param, i|
+          register_parameter((offset + i).to_s, param)
+        end
+        self
+      end
+
+      def each(&block)
+        if block_given?
+          @parameters.values.each(&block)
+        else
+          to_enum(:each)
+        end
+      end
+
+      def [](idx)
+        if idx.is_a?(Range)
+          self.class.new(@parameters.values[idx])
+        else
+          @parameters[idx.to_s]
+        end
+      end
+    end
+  end
+end
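
For orientation, a minimal usage sketch of the new container; the shapes and counts are illustrative only:

    require "torch"

    params = Torch::NN::ParameterList.new(
      3.times.map { Torch::NN::Parameter.new(Torch.randn(4, 4)) }
    )

    params.length        # => 3
    params[0]            # parameter registered under the name "0"
    params[0..1]         # Range indexing returns a new ParameterList
    params.map(&:shape)  # Enumerable methods come via #each
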
data/lib/torch/tensor.rb CHANGED
@@ -199,5 +199,13 @@ module Torch
     def real
       Torch.real(self)
     end
+
+    def coerce(other)
+      if other.is_a?(Numeric)
+        [Torch.tensor(other), self]
+      else
+        raise TypeError, "#{self.class} can't be coerced into #{other.class}"
+      end
+    end
   end
 end
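
The new coerce hook plugs Tensor into Ruby's numeric coercion protocol, so a plain Numeric can now sit on the left-hand side of an arithmetic expression:

    x = Torch.tensor([1, 2, 3])

    x * 2  # worked before this release
    2 * x  # now works: Integer#* calls x.coerce(2), which returns
           # [Torch.tensor(2), x], and the two tensors are multiplied
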
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.9.2"
+  VERSION = "0.10.2"
 end
data/lib/torch.rb CHANGED
@@ -40,6 +40,7 @@ require "torch/nn/utils"
 # nn containers
 require "torch/nn/module"
 require "torch/nn/module_list"
+require "torch/nn/parameter_list"
 require "torch/nn/sequential"
 
 # nn convolution layers
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.9.2
+  version: 0.10.2
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2022-02-03 00:00:00.000000000 Z
+date: 2022-06-14 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -52,6 +52,7 @@ files:
 - ext/torch/random.cpp
 - ext/torch/ruby_arg_parser.cpp
 - ext/torch/ruby_arg_parser.h
+- ext/torch/sparse_functions.h
 - ext/torch/special.cpp
 - ext/torch/special_functions.h
 - ext/torch/templates.h
@@ -149,6 +150,7 @@ files:
 - lib/torch/nn/nll_loss.rb
 - lib/torch/nn/pairwise_distance.rb
 - lib/torch/nn/parameter.rb
+- lib/torch/nn/parameter_list.rb
 - lib/torch/nn/poisson_nll_loss.rb
 - lib/torch/nn/prelu.rb
 - lib/torch/nn/reflection_pad1d.rb
@@ -227,7 +229,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.3.3
+rubygems_version: 3.3.7
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch