torch-rb 0.6.0 → 0.7.0

data/ext/torch/torch.cpp CHANGED
@@ -1,6 +1,6 @@
 #include <torch/torch.h>
 
-#include <rice/Module.hpp>
+#include <rice/rice.hpp>
 
 #include "torch_functions.h"
 #include "templates.h"
@@ -9,69 +9,78 @@
 void init_torch(Rice::Module& m) {
   m.add_handler<torch::Error>(handle_error);
   add_torch_functions(m);
-  m.define_singleton_method(
+  m.define_singleton_function(
       "grad_enabled?",
-      *[]() {
+      []() {
         return torch::GradMode::is_enabled();
       })
-    .define_singleton_method(
+    .define_singleton_function(
       "_set_grad_enabled",
-      *[](bool enabled) {
+      [](bool enabled) {
         torch::GradMode::set_enabled(enabled);
       })
-    .define_singleton_method(
+    .define_singleton_function(
       "manual_seed",
-      *[](uint64_t seed) {
+      [](uint64_t seed) {
        return torch::manual_seed(seed);
       })
     // config
-    .define_singleton_method(
+    .define_singleton_function(
       "show_config",
-      *[] {
+      [] {
        return torch::show_config();
       })
-    .define_singleton_method(
+    .define_singleton_function(
       "parallel_info",
-      *[] {
+      [] {
        return torch::get_parallel_info();
       })
     // begin operations
-    .define_singleton_method(
+    .define_singleton_function(
       "_save",
-      *[](const torch::IValue &value) {
+      [](const torch::IValue &value) {
        auto v = torch::pickle_save(value);
        std::string str(v.begin(), v.end());
        return str;
       })
-    .define_singleton_method(
+    .define_singleton_function(
       "_load",
-      *[](const std::string &s) {
+      [](const std::string &s) {
        std::vector<char> v;
        std::copy(s.begin(), s.end(), std::back_inserter(v));
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
       })
-    .define_singleton_method(
+    .define_singleton_function(
       "_from_blob",
-      *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
       })
-    .define_singleton_method(
+    .define_singleton_function(
       "_tensor",
-      *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        auto dtype = options.dtype();
        torch::Tensor t;
        if (dtype == torch::kBool) {
          std::vector<uint8_t> vec;
          for (long i = 0; i < a.size(); i++) {
-            vec.push_back(from_ruby<bool>(a[i]));
+            vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+          }
+          t = torch::tensor(vec, options);
+        } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+          // TODO use template
+          std::vector<c10::complex<double>> vec;
+          Object obj;
+          for (long i = 0; i < a.size(); i++) {
+            obj = a[i];
+            vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
          }
          t = torch::tensor(vec, options);
        } else {
          std::vector<float> vec;
          for (long i = 0; i < a.size(); i++) {
-            vec.push_back(from_ruby<float>(a[i]));
+            vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
          }
          // hack for requires_grad error
          if (options.requires_grad()) {
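
The new kComplexFloat/kComplexDouble branch in `_tensor` builds each element by calling `real` and `imag` on the Ruby object, so Ruby `Complex` values (and other numerics, which also respond to those methods) can be passed straight through. A minimal illustrative sketch of the Ruby side, not taken from the gem's tests:

    require "torch"

    # Complex#real and Complex#imag are what the binding above calls per element
    x = Torch.tensor([Complex(1, 2), Complex(3, -4)], dtype: :complex64)
    x.dtype  # expected to be :complex64
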
data/ext/torch/utils.h CHANGED
@@ -1,11 +1,10 @@
 #pragma once
 
-#include <rice/Exception.hpp>
-#include <rice/Symbol.hpp>
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
 
 // TODO find better place
-inline void handle_error(torch::Error const & ex)
-{
+inline void handle_error(torch::Error const & ex) {
   throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
 
@@ -1,99 +1,106 @@
 #pragma once
 
 #include <torch/torch.h>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
-inline Object wrap(bool x) {
-  return to_ruby<bool>(x);
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }
 
-inline Object wrap(int64_t x) {
-  return to_ruby<int64_t>(x);
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }
 
-inline Object wrap(double x) {
-  return to_ruby<double>(x);
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }
 
-inline Object wrap(torch::Tensor x) {
-  return to_ruby<torch::Tensor>(x);
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }
 
-inline Object wrap(torch::Scalar x) {
-  return to_ruby<torch::Scalar>(x);
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }
 
-inline Object wrap(torch::ScalarType x) {
-  return to_ruby<torch::ScalarType>(x);
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }
 
-inline Object wrap(torch::QScheme x) {
-  return to_ruby<torch::QScheme>(x);
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<double>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(torch::TensorList x) {
-  Array a;
-  for (auto& t : x) {
-    a.push(to_ruby<torch::Tensor>(t));
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return Object(a);
+  return a;
 }
 
-inline Object wrap(std::tuple<double, double> x) {
-  Array a;
-  a.push(to_ruby<double>(std::get<0>(x)));
-  a.push(to_ruby<double>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
 }
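
With the Rice 4 API, `wrap` now returns a raw `VALUE` and builds tuple results with `rb_ary_new3`, so operations that return multiple tensors still surface in Ruby as plain arrays that can be destructured. A small usage sketch, assuming `Torch.sort` is among the generated functions returning a values/indices pair:

    values, indices = Torch.sort(Torch.tensor([3.0, 1.0, 2.0]))
    # values should hold [1.0, 2.0, 3.0] and indices [1, 2, 0]
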
data/lib/torch.rb CHANGED
@@ -238,8 +238,11 @@ module Torch
     double: 7,
     float64: 7,
     complex_half: 8,
+    complex32: 8,
     complex_float: 9,
+    complex64: 9,
     complex_double: 10,
+    complex128: 10,
     bool: 11,
     qint8: 12,
     quint8: 13,
@@ -394,6 +397,8 @@ module Torch
         options[:dtype] = :int64
       elsif data.all? { |v| v == true || v == false }
         options[:dtype] = :bool
+      elsif data.any? { |v| v.is_a?(Complex) }
+        options[:dtype] = :complex64
       end
     end
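
Taken together, the new aliases and the inference branch mean complex dtypes can be requested by their PyTorch-style names or picked up automatically from the data. An illustrative sketch, inferred from the diff above rather than from the gem's docs:

    # explicit dtype via the new alias (maps to the same enum as :complex_float)
    a = Torch.tensor([1, 2], dtype: :complex64)

    # data containing Ruby Complex values should now be inferred as :complex64
    b = Torch.tensor([Complex(1, 1), Complex(2, 0.5)])
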
 
@@ -96,8 +96,11 @@ module Torch
         ret = "%.#{PRINT_OPTS[:precision]}f" % value
       end
     elsif @complex_dtype
-      p = PRINT_OPTS[:precision]
-      raise NotImplementedYet
+      # TODO use float formatter for each part
+      precision = PRINT_OPTS[:precision]
+      imag = value.imag
+      sign = imag >= 0 ? "+" : "-"
+      ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
     else
       ret = value.to_s
     end
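
With the formatter change, inspecting a complex tensor should no longer raise NotImplementedYet; each entry is rendered as a real part followed by a signed imaginary part at the configured precision. A rough sketch of the expected output, assuming the default precision of 4 (exact spacing and wrapper text are assumptions):

    p Torch.tensor([Complex(1, 2), Complex(3, -4)], dtype: :complex64)
    # => tensor([1.0000+2.0000i, 3.0000-4.0000i])
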
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.6.0"
+  VERSION = "0.7.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.6.0
+  version: 0.7.0
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2021-03-26 00:00:00.000000000 Z
+date: 2021-05-23 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -16,14 +16,14 @@ dependencies:
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: '2.2'
+        version: 4.0.2
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
      - !ruby/object:Gem::Version
-        version: '2.2'
+        version: 4.0.2
 description:
 email: andrew@ankane.org
 executables: []