torch-rb 0.6.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/torch.cpp CHANGED
@@ -1,6 +1,6 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "torch_functions.h"
6
6
  #include "templates.h"
@@ -9,69 +9,78 @@
9
9
  void init_torch(Rice::Module& m) {
10
10
  m.add_handler<torch::Error>(handle_error);
11
11
  add_torch_functions(m);
12
- m.define_singleton_method(
12
+ m.define_singleton_function(
13
13
  "grad_enabled?",
14
- *[]() {
14
+ []() {
15
15
  return torch::GradMode::is_enabled();
16
16
  })
17
- .define_singleton_method(
17
+ .define_singleton_function(
18
18
  "_set_grad_enabled",
19
- *[](bool enabled) {
19
+ [](bool enabled) {
20
20
  torch::GradMode::set_enabled(enabled);
21
21
  })
22
- .define_singleton_method(
22
+ .define_singleton_function(
23
23
  "manual_seed",
24
- *[](uint64_t seed) {
24
+ [](uint64_t seed) {
25
25
  return torch::manual_seed(seed);
26
26
  })
27
27
  // config
28
- .define_singleton_method(
28
+ .define_singleton_function(
29
29
  "show_config",
30
- *[] {
30
+ [] {
31
31
  return torch::show_config();
32
32
  })
33
- .define_singleton_method(
33
+ .define_singleton_function(
34
34
  "parallel_info",
35
- *[] {
35
+ [] {
36
36
  return torch::get_parallel_info();
37
37
  })
38
38
  // begin operations
39
- .define_singleton_method(
39
+ .define_singleton_function(
40
40
  "_save",
41
- *[](const torch::IValue &value) {
41
+ [](const torch::IValue &value) {
42
42
  auto v = torch::pickle_save(value);
43
43
  std::string str(v.begin(), v.end());
44
44
  return str;
45
45
  })
46
- .define_singleton_method(
46
+ .define_singleton_function(
47
47
  "_load",
48
- *[](const std::string &s) {
48
+ [](const std::string &s) {
49
49
  std::vector<char> v;
50
50
  std::copy(s.begin(), s.end(), std::back_inserter(v));
51
51
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
52
52
  return torch::pickle_load(v);
53
53
  })
54
- .define_singleton_method(
54
+ .define_singleton_function(
55
55
  "_from_blob",
56
- *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
56
+ [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
57
57
  void *data = const_cast<char *>(s.c_str());
58
58
  return torch::from_blob(data, size, options);
59
59
  })
60
- .define_singleton_method(
60
+ .define_singleton_function(
61
61
  "_tensor",
62
- *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
62
+ [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
63
63
  auto dtype = options.dtype();
64
64
  torch::Tensor t;
65
65
  if (dtype == torch::kBool) {
66
66
  std::vector<uint8_t> vec;
67
67
  for (long i = 0; i < a.size(); i++) {
68
- vec.push_back(from_ruby<bool>(a[i]));
68
+ vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
69
+ }
70
+ t = torch::tensor(vec, options);
71
+ } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
72
+ // TODO use template
73
+ std::vector<c10::complex<double>> vec;
74
+ Object obj;
75
+ for (long i = 0; i < a.size(); i++) {
76
+ obj = a[i];
77
+ vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
69
78
  }
70
79
  t = torch::tensor(vec, options);
71
80
  } else {
72
81
  std::vector<float> vec;
73
82
  for (long i = 0; i < a.size(); i++) {
74
- vec.push_back(from_ruby<float>(a[i]));
83
+ vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
75
84
  }
76
85
  // hack for requires_grad error
77
86
  if (options.requires_grad()) {
data/ext/torch/utils.h CHANGED
@@ -1,11 +1,10 @@
1
1
  #pragma once
2
2
 
3
- #include <rice/Exception.hpp>
4
- #include <rice/Symbol.hpp>
3
+ #include <rice/rice.hpp>
4
+ #include <rice/stl.hpp>
5
5
 
6
6
  // TODO find better place
7
- inline void handle_error(torch::Error const & ex)
8
- {
7
+ inline void handle_error(torch::Error const & ex) {
9
8
  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
10
9
  }
11
10
 
@@ -1,99 +1,106 @@
1
1
  #pragma once
2
2
 
3
3
  #include <torch/torch.h>
4
- #include <rice/Object.hpp>
4
+ #include <rice/rice.hpp>
5
5
 
6
- inline Object wrap(bool x) {
7
- return to_ruby<bool>(x);
6
+ inline VALUE wrap(bool x) {
7
+ return Rice::detail::To_Ruby<bool>().convert(x);
8
8
  }
9
9
 
10
- inline Object wrap(int64_t x) {
11
- return to_ruby<int64_t>(x);
10
+ inline VALUE wrap(int64_t x) {
11
+ return Rice::detail::To_Ruby<int64_t>().convert(x);
12
12
  }
13
13
 
14
- inline Object wrap(double x) {
15
- return to_ruby<double>(x);
14
+ inline VALUE wrap(double x) {
15
+ return Rice::detail::To_Ruby<double>().convert(x);
16
16
  }
17
17
 
18
- inline Object wrap(torch::Tensor x) {
19
- return to_ruby<torch::Tensor>(x);
18
+ inline VALUE wrap(torch::Tensor x) {
19
+ return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
20
20
  }
21
21
 
22
- inline Object wrap(torch::Scalar x) {
23
- return to_ruby<torch::Scalar>(x);
22
+ inline VALUE wrap(torch::Scalar x) {
23
+ return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
24
24
  }
25
25
 
26
- inline Object wrap(torch::ScalarType x) {
27
- return to_ruby<torch::ScalarType>(x);
26
+ inline VALUE wrap(torch::ScalarType x) {
27
+ return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
28
28
  }
29
29
 
30
- inline Object wrap(torch::QScheme x) {
31
- return to_ruby<torch::QScheme>(x);
30
+ inline VALUE wrap(torch::QScheme x) {
31
+ return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
32
32
  }
33
33
 
34
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
35
- Array a;
36
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
37
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
38
- return Object(a);
34
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
35
+ return rb_ary_new3(
36
+ 2,
37
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
38
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
39
+ );
39
40
  }
40
41
 
41
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
42
- Array a;
43
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
44
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
45
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
46
- return Object(a);
42
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
43
+ return rb_ary_new3(
44
+ 3,
45
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
46
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
47
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
48
+ );
47
49
  }
48
50
 
49
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
50
- Array a;
51
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
52
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
53
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
54
- a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
55
- return Object(a);
51
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
52
+ return rb_ary_new3(
53
+ 4,
54
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
55
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
56
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
57
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
58
+ );
56
59
  }
57
60
 
58
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
59
- Array a;
60
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
61
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
62
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
63
- a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
64
- a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
65
- return Object(a);
61
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
62
+ return rb_ary_new3(
63
+ 5,
64
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
65
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
66
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
67
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
68
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
69
+ );
66
70
  }
67
71
 
68
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
69
- Array a;
70
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
71
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
72
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
73
- a.push(to_ruby<int64_t>(std::get<3>(x)));
74
- return Object(a);
72
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
73
+ return rb_ary_new3(
74
+ 4,
75
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
76
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
77
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
78
+ Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
79
+ );
75
80
  }
76
81
 
77
- inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
78
- Array a;
79
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
80
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
81
- a.push(to_ruby<double>(std::get<2>(x)));
82
- a.push(to_ruby<int64_t>(std::get<3>(x)));
83
- return Object(a);
82
+ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
83
+ return rb_ary_new3(
84
+ 4,
85
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
86
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
87
+ Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
88
+ Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
89
+ );
84
90
  }
85
91
 
86
- inline Object wrap(torch::TensorList x) {
87
- Array a;
88
- for (auto& t : x) {
89
- a.push(to_ruby<torch::Tensor>(t));
92
+ inline VALUE wrap(torch::TensorList x) {
93
+ auto a = rb_ary_new2(x.size());
94
+ for (auto t : x) {
95
+ rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
90
96
  }
91
- return Object(a);
97
+ return a;
92
98
  }
93
99
 
94
- inline Object wrap(std::tuple<double, double> x) {
95
- Array a;
96
- a.push(to_ruby<double>(std::get<0>(x)));
97
- a.push(to_ruby<double>(std::get<1>(x)));
98
- return Object(a);
100
+ inline VALUE wrap(std::tuple<double, double> x) {
101
+ return rb_ary_new3(
102
+ 2,
103
+ Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
104
+ Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
105
+ );
99
106
  }
data/lib/torch.rb CHANGED
@@ -238,8 +238,11 @@ module Torch
238
238
  double: 7,
239
239
  float64: 7,
240
240
  complex_half: 8,
241
+ complex32: 8,
241
242
  complex_float: 9,
243
+ complex64: 9,
242
244
  complex_double: 10,
245
+ complex128: 10,
243
246
  bool: 11,
244
247
  qint8: 12,
245
248
  quint8: 13,
@@ -394,6 +397,8 @@ module Torch
394
397
  options[:dtype] = :int64
395
398
  elsif data.all? { |v| v == true || v == false }
396
399
  options[:dtype] = :bool
400
+ elsif data.any? { |v| v.is_a?(Complex) }
401
+ options[:dtype] = :complex64
397
402
  end
398
403
  end
399
404
 
@@ -96,8 +96,11 @@ module Torch
96
96
  ret = "%.#{PRINT_OPTS[:precision]}f" % value
97
97
  end
98
98
  elsif @complex_dtype
99
- p = PRINT_OPTS[:precision]
100
- raise NotImplementedYet
99
+ # TODO use float formatter for each part
100
+ precision = PRINT_OPTS[:precision]
101
+ imag = value.imag
102
+ sign = imag >= 0 ? "+" : "-"
103
+ ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
101
104
  else
102
105
  ret = value.to_s
103
106
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.6.0"
2
+ VERSION = "0.7.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.6.0
4
+ version: 0.7.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-03-26 00:00:00.000000000 Z
11
+ date: 2021-05-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '2.2'
19
+ version: 4.0.2
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '2.2'
26
+ version: 4.0.2
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []