torch-rb 0.8.3 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -75,7 +75,7 @@ struct RubyArgs {
75
75
  int idx;
76
76
 
77
77
  inline at::Tensor tensor(int i);
78
- inline OptionalTensor optionalTensor(int i);
78
+ inline c10::optional<at::Tensor> optionalTensor(int i);
79
79
  inline at::Scalar scalar(int i);
80
80
  // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
81
81
  inline std::vector<at::Scalar> scalarlist(int i);
@@ -109,6 +109,9 @@ struct RubyArgs {
109
109
  // inline at::QScheme toQScheme(int i);
110
110
  inline std::string string(int i);
111
111
  inline c10::optional<std::string> stringOptional(int i);
112
+ inline c10::string_view stringView(int i);
113
+ // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
114
+ inline c10::optional<c10::string_view> stringViewOptional(int i);
112
115
  // inline PyObject* pyobject(int i);
113
116
  inline int64_t toInt64(int i);
114
117
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
125
128
  return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
126
129
  }
127
130
 
128
- inline OptionalTensor RubyArgs::optionalTensor(int i) {
129
- if (NIL_P(args[i])) return OptionalTensor(Nil);
131
+ inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
132
+ if (NIL_P(args[i])) return c10::nullopt;
130
133
  return tensor(i);
131
134
  }
132
135
 
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
322
325
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
323
326
  }
324
327
 
328
+ inline c10::string_view RubyArgs::stringView(int i) {
329
+ auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
330
+ return c10::string_view(str.data(), str.size());
331
+ }
332
+
333
+ inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
334
+ if (NIL_P(args[i])) return c10::nullopt;
335
+ auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
336
+ return c10::string_view(str.data(), str.size());
337
+ }
338
+
325
339
  inline int64_t RubyArgs::toInt64(int i) {
326
340
  if (NIL_P(args[i])) return signature.params[i].default_int;
327
341
  return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
41
41
  #define RETURN_NIL \
42
42
  return Qnil;
43
43
 
44
- class OptionalTensor {
45
- torch::Tensor value;
46
- public:
47
- OptionalTensor(Object o) {
48
- if (o.is_nil()) {
49
- value = {};
50
- } else {
51
- value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
52
- }
53
- }
54
- OptionalTensor(torch::Tensor o) {
55
- value = o;
56
- }
57
- operator torch::Tensor() const {
58
- return value;
59
- }
60
- };
61
-
62
44
  namespace Rice::detail
63
45
  {
64
46
  template<>
@@ -131,25 +113,6 @@ namespace Rice::detail
131
113
  }
132
114
  };
133
115
 
134
- template<>
135
- struct Type<OptionalTensor>
136
- {
137
- static bool verify()
138
- {
139
- return true;
140
- }
141
- };
142
-
143
- template<>
144
- class From_Ruby<OptionalTensor>
145
- {
146
- public:
147
- OptionalTensor convert(VALUE x)
148
- {
149
- return OptionalTensor(x);
150
- }
151
- };
152
-
153
116
  template<>
154
117
  struct Type<Scalar>
155
118
  {
data/ext/torch/tensor.cpp CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
107
107
  ParsedArgs<4> parsed_args;
108
108
  auto _r = parser.parse(self_, argc, argv, parsed_args);
109
109
  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
110
- auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
110
+ auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
111
111
  // in future, release GVL
112
112
  self._backward(inputs, gradient, retain_graph, create_graph);
113
113
  };
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
125
125
  rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
126
126
 
127
127
  rb_cTensor
128
- .define_method("cuda?", &torch::Tensor::is_cuda)
129
- .define_method("sparse?", &torch::Tensor::is_sparse)
130
- .define_method("quantized?", &torch::Tensor::is_quantized)
131
- .define_method("dim", &torch::Tensor::dim)
132
- .define_method("numel", &torch::Tensor::numel)
133
- .define_method("element_size", &torch::Tensor::element_size)
134
- .define_method("requires_grad", &torch::Tensor::requires_grad)
128
+ .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
129
+ .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
130
+ .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
131
+ .define_method("dim", [](Tensor& self) { return self.dim(); })
132
+ .define_method("numel", [](Tensor& self) { return self.numel(); })
133
+ .define_method("element_size", [](Tensor& self) { return self.element_size(); })
134
+ .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
135
135
  .define_method(
136
136
  "_size",
137
137
  [](Tensor& self, int64_t dim) {
data/lib/torch/tensor.rb CHANGED
@@ -191,5 +191,17 @@ module Torch
191
191
  clone
192
192
  end
193
193
  end
194
+
195
+ # not a method in native_functions.yaml
196
+ # attribute in Python rather than method
197
+ def imag
198
+ Torch.imag(self)
199
+ end
200
+
201
+ # not a method in native_functions.yaml
202
+ # attribute in Python rather than method
203
+ def real
204
+ Torch.real(self)
205
+ end
194
206
  end
195
207
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.8.3"
2
+ VERSION = "0.9.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.3
4
+ version: 0.9.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-10-17 00:00:00.000000000 Z
11
+ date: 2021-10-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice