torch-rb 0.8.3 → 0.9.0

@@ -75,7 +75,7 @@ struct RubyArgs {
   int idx;
 
   inline at::Tensor tensor(int i);
-  inline OptionalTensor optionalTensor(int i);
+  inline c10::optional<at::Tensor> optionalTensor(int i);
   inline at::Scalar scalar(int i);
   // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
   inline std::vector<at::Scalar> scalarlist(int i);
@@ -109,6 +109,9 @@ struct RubyArgs {
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
   inline c10::optional<std::string> stringOptional(int i);
+  inline c10::string_view stringView(int i);
+  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
+  inline c10::optional<c10::string_view> stringViewOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
   return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
 }
 
-inline OptionalTensor RubyArgs::optionalTensor(int i) {
-  if (NIL_P(args[i])) return OptionalTensor(Nil);
+inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
   return tensor(i);
 }
 
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }
 
+inline c10::string_view RubyArgs::stringView(int i) {
+  auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
+  return c10::string_view(str.data(), str.size());
+}
+
+inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
+  auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
+  return c10::string_view(str.data(), str.size());
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
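
For context, the new stringView accessors hand ATen a c10::string_view, which, like std::string_view, is a non-owning pointer-plus-length view over an existing character buffer, so it is only valid while that buffer is. A minimal standalone sketch (assuming libtorch is on the include path; the reduction variable is just an example value):

    #include <c10/util/string_view.h>
    #include <string>
    #include <cassert>

    int main() {
      std::string reduction = "mean";  // owning buffer
      // Build a view over the string's storage; no characters are copied.
      c10::string_view view(reduction.data(), reduction.size());
      assert(view.size() == 4);
      // The view is only usable while `reduction` remains alive and unmodified.
      return 0;
    }
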
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
 #define RETURN_NIL \
   return Qnil;
 
-class OptionalTensor {
-  torch::Tensor value;
-  public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
-      } else {
-        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
-      }
-    }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
-    }
-    operator torch::Tensor() const {
-      return value;
-    }
-};
-
 namespace Rice::detail
 {
   template<>
@@ -131,25 +113,6 @@ namespace Rice::detail
     }
   };
 
-  template<>
-  struct Type<OptionalTensor>
-  {
-    static bool verify()
-    {
-      return true;
-    }
-  };
-
-  template<>
-  class From_Ruby<OptionalTensor>
-  {
-  public:
-    OptionalTensor convert(VALUE x)
-    {
-      return OptionalTensor(x);
-    }
-  };
-
   template<>
   struct Type<Scalar>
   {
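
With the OptionalTensor wrapper and its Rice Type/From_Ruby specializations gone, a nil Ruby argument reaches the dispatch lambdas as a plain c10::optional<at::Tensor> (see the optionalTensor change above). A minimal standalone sketch of how that optional behaves (assumes only libtorch; describe is an illustrative helper, not part of torch-rb):

    #include <torch/torch.h>
    #include <iostream>

    // Mirrors the shape of the updated dispatch lambdas: nil from Ruby
    // arrives here as c10::nullopt rather than a wrapped undefined Tensor.
    void describe(const c10::optional<torch::Tensor>& gradient) {
      if (gradient.has_value()) {
        std::cout << "gradient with " << gradient->numel() << " elements\n";
      } else {
        std::cout << "no gradient (c10::nullopt)\n";
      }
    }

    int main() {
      describe(c10::nullopt);         // what a nil argument is parsed into
      describe(torch::ones({2, 2}));  // a Tensor converts implicitly to the optional
      return 0;
    }
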
data/ext/torch/tensor.cpp CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
   ParsedArgs<4> parsed_args;
   auto _r = parser.parse(self_, argc, argv, parsed_args);
   // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
-  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
     // in future, release GVL
     self._backward(inputs, gradient, retain_graph, create_graph);
   };
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
   rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
 
   rb_cTensor
-    .define_method("cuda?", &torch::Tensor::is_cuda)
-    .define_method("sparse?", &torch::Tensor::is_sparse)
-    .define_method("quantized?", &torch::Tensor::is_quantized)
-    .define_method("dim", &torch::Tensor::dim)
-    .define_method("numel", &torch::Tensor::numel)
-    .define_method("element_size", &torch::Tensor::element_size)
-    .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
+    .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
+    .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
+    .define_method("dim", [](Tensor& self) { return self.dim(); })
+    .define_method("numel", [](Tensor& self) { return self.numel(); })
+    .define_method("element_size", [](Tensor& self) { return self.element_size(); })
+    .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
     .define_method(
       "_size",
      [](Tensor& self, int64_t dim) {
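
The Tensor predicates and size accessors are now bound through lambdas with an explicit receiver parameter instead of member-function pointers such as &torch::Tensor::is_cuda, presumably to give Rice a plain callable to deduce from. A sketch of the same binding pattern in isolation (assumes Rice 4 and libtorch; the class name MiniTensor and Init_mini_tensor are illustrative, not torch-rb's real setup):

    #include <rice/rice.hpp>
    #include <torch/torch.h>

    extern "C" void Init_mini_tensor() {
      Rice::define_class<torch::Tensor>("MiniTensor")
        // Plain callables with an explicit receiver; Rice converts the
        // returned bool / int64_t back to Ruby objects.
        .define_method("cuda?", [](torch::Tensor& self) { return self.is_cuda(); })
        .define_method("dim", [](torch::Tensor& self) { return self.dim(); });
    }
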
data/lib/torch/tensor.rb CHANGED
@@ -191,5 +191,17 @@ module Torch
         clone
       end
     end
+
+    # not a method in native_functions.yaml
+    # attribute in Python rather than method
+    def imag
+      Torch.imag(self)
+    end
+
+    # not a method in native_functions.yaml
+    # attribute in Python rather than method
+    def real
+      Torch.real(self)
+    end
   end
 end
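
Tensor#imag and Tensor#real delegate to the module-level Torch.imag and Torch.real, mirroring Python, where real and imag are attributes rather than Tensor methods. Underneath they map to ATen's free functions; a C++ sketch of the equivalent calls (assumes libtorch with complex dtype support; the values are only for illustration):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Build a complex tensor from separate real and imaginary parts.
      auto z = torch::complex(torch::tensor({1.0f, 3.0f}),
                              torch::tensor({2.0f, 4.0f}));
      std::cout << torch::real(z) << "\n";  // real parts: 1, 3
      std::cout << torch::imag(z) << "\n";  // imaginary parts: 2, 4
      return 0;
    }
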
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.8.3"
+  VERSION = "0.9.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.8.3
+  version: 0.9.0
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2021-10-17 00:00:00.000000000 Z
+date: 2021-10-23 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice