torch-rb 0.5.3 → 0.6.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -3,4 +3,4 @@
3
3
 
4
4
  #pragma once
5
5
 
6
- void add_tensor_functions(Module m);
6
+ void add_tensor_functions(Rice::Module& m);
@@ -0,0 +1,86 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/Module.hpp>
4
+
5
+ #include "torch_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_torch(Rice::Module& m) {
10
+ m.add_handler<torch::Error>(handle_error);
11
+ add_torch_functions(m);
12
+ m.define_singleton_method(
13
+ "grad_enabled?",
14
+ *[]() {
15
+ return torch::GradMode::is_enabled();
16
+ })
17
+ .define_singleton_method(
18
+ "_set_grad_enabled",
19
+ *[](bool enabled) {
20
+ torch::GradMode::set_enabled(enabled);
21
+ })
22
+ .define_singleton_method(
23
+ "manual_seed",
24
+ *[](uint64_t seed) {
25
+ return torch::manual_seed(seed);
26
+ })
27
+ // config
28
+ .define_singleton_method(
29
+ "show_config",
30
+ *[] {
31
+ return torch::show_config();
32
+ })
33
+ .define_singleton_method(
34
+ "parallel_info",
35
+ *[] {
36
+ return torch::get_parallel_info();
37
+ })
38
+ // begin operations
39
+ .define_singleton_method(
40
+ "_save",
41
+ *[](const torch::IValue &value) {
42
+ auto v = torch::pickle_save(value);
43
+ std::string str(v.begin(), v.end());
44
+ return str;
45
+ })
46
+ .define_singleton_method(
47
+ "_load",
48
+ *[](const std::string &s) {
49
+ std::vector<char> v;
50
+ std::copy(s.begin(), s.end(), std::back_inserter(v));
51
+ // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
52
+ return torch::pickle_load(v);
53
+ })
54
+ .define_singleton_method(
55
+ "_from_blob",
56
+ *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
57
+ void *data = const_cast<char *>(s.c_str());
58
+ return torch::from_blob(data, size, options);
59
+ })
60
+ .define_singleton_method(
61
+ "_tensor",
62
+ *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
63
+ auto dtype = options.dtype();
64
+ torch::Tensor t;
65
+ if (dtype == torch::kBool) {
66
+ std::vector<uint8_t> vec;
67
+ for (long i = 0; i < a.size(); i++) {
68
+ vec.push_back(from_ruby<bool>(a[i]));
69
+ }
70
+ t = torch::tensor(vec, options);
71
+ } else {
72
+ std::vector<float> vec;
73
+ for (long i = 0; i < a.size(); i++) {
74
+ vec.push_back(from_ruby<float>(a[i]));
75
+ }
76
+ // hack for requires_grad error
77
+ if (options.requires_grad()) {
78
+ t = torch::tensor(vec, options.requires_grad(c10::nullopt));
79
+ t.set_requires_grad(true);
80
+ } else {
81
+ t = torch::tensor(vec, options);
82
+ }
83
+ }
84
+ return t.reshape(size);
85
+ });
86
+ }
@@ -3,4 +3,4 @@
3
3
 
4
4
  #pragma once
5
5
 
6
- void add_torch_functions(Module m);
6
+ void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,13 +1,20 @@
1
1
  #pragma once
2
2
 
3
+ #include <rice/Exception.hpp>
3
4
  #include <rice/Symbol.hpp>
4
5
 
6
+ // TODO find better place
7
+ inline void handle_error(torch::Error const & ex)
8
+ {
9
+ throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
10
+ }
11
+
5
12
  // keep THP prefix for now to make it easier to compare code
6
13
 
7
14
  extern VALUE THPVariableClass;
8
15
 
9
16
  inline VALUE THPUtils_internSymbol(const std::string& str) {
10
- return Symbol(str);
17
+ return Rice::Symbol(str);
11
18
  }
12
19
 
13
20
  inline std::string THPUtils_unpackSymbol(VALUE obj) {
data/lib/torch.rb CHANGED
@@ -337,25 +337,24 @@ module Torch
337
337
  }
338
338
  end
339
339
 
340
- def no_grad
341
- previous_value = grad_enabled?
342
- begin
343
- _set_grad_enabled(false)
344
- yield
345
- ensure
346
- _set_grad_enabled(previous_value)
347
- end
340
+ def no_grad(&block)
341
+ grad_enabled(false, &block)
342
+ end
343
+
344
+ def enable_grad(&block)
345
+ grad_enabled(true, &block)
348
346
  end
349
347
 
350
- def enable_grad
348
+ def grad_enabled(value)
351
349
  previous_value = grad_enabled?
352
350
  begin
353
- _set_grad_enabled(true)
351
+ _set_grad_enabled(value)
354
352
  yield
355
353
  ensure
356
354
  _set_grad_enabled(previous_value)
357
355
  end
358
356
  end
357
+ alias_method :set_grad_enabled, :grad_enabled
359
358
 
360
359
  def device(str)
361
360
  Device.new(str)
@@ -1,6 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Linear < Module
4
+ attr_reader :in_features, :out_features
5
+
4
6
  def initialize(in_features, out_features, bias: true)
5
7
  super()
6
8
  @in_features = in_features
@@ -286,6 +286,12 @@ module Torch
286
286
  named_buffers[name]
287
287
  elsif named_modules.key?(name)
288
288
  named_modules[name]
289
+ elsif method.end_with?("=") && named_modules.key?(method[0..-2])
290
+ if instance_variable_defined?("@#{method[0..-2]}")
291
+ instance_variable_set("@#{method[0..-2]}", *args)
292
+ else
293
+ raise NotImplementedYet
294
+ end
289
295
  else
290
296
  super
291
297
  end
@@ -3,7 +3,7 @@ module Torch
3
3
  class Parameter < Tensor
4
4
  def self.new(data = nil, requires_grad: true)
5
5
  data = Tensor.new unless data
6
- Tensor._make_subclass(data, requires_grad)
6
+ _make_subclass(data, requires_grad)
7
7
  end
8
8
 
9
9
  def inspect
data/lib/torch/tensor.rb CHANGED
@@ -135,6 +135,10 @@ module Torch
135
135
  Torch.ones_like(Torch.empty(*size), **options)
136
136
  end
137
137
 
138
+ def requires_grad=(requires_grad)
139
+ _requires_grad!(requires_grad)
140
+ end
141
+
138
142
  def requires_grad!(requires_grad = true)
139
143
  _requires_grad!(requires_grad)
140
144
  end
@@ -60,7 +60,7 @@ module Torch
60
60
  when Array
61
61
  batch.transpose.map { |v| default_convert(v) }
62
62
  else
63
- raise NotImplementedYet
63
+ batch
64
64
  end
65
65
  end
66
66
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.5.3"
2
+ VERSION = "0.6.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.3
4
+ version: 0.6.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-01-14 00:00:00.000000000 Z
11
+ date: 2021-03-26 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -24,92 +24,8 @@ dependencies:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
26
  version: '2.2'
27
- - !ruby/object:Gem::Dependency
28
- name: bundler
29
- requirement: !ruby/object:Gem::Requirement
30
- requirements:
31
- - - ">="
32
- - !ruby/object:Gem::Version
33
- version: '0'
34
- type: :development
35
- prerelease: false
36
- version_requirements: !ruby/object:Gem::Requirement
37
- requirements:
38
- - - ">="
39
- - !ruby/object:Gem::Version
40
- version: '0'
41
- - !ruby/object:Gem::Dependency
42
- name: rake
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: rake-compiler
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: minitest
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '5'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '5'
83
- - !ruby/object:Gem::Dependency
84
- name: numo-narray
85
- requirement: !ruby/object:Gem::Requirement
86
- requirements:
87
- - - ">="
88
- - !ruby/object:Gem::Version
89
- version: '0'
90
- type: :development
91
- prerelease: false
92
- version_requirements: !ruby/object:Gem::Requirement
93
- requirements:
94
- - - ">="
95
- - !ruby/object:Gem::Version
96
- version: '0'
97
- - !ruby/object:Gem::Dependency
98
- name: torchvision
99
- requirement: !ruby/object:Gem::Requirement
100
- requirements:
101
- - - ">="
102
- - !ruby/object:Gem::Version
103
- version: 0.1.1
104
- type: :development
105
- prerelease: false
106
- version_requirements: !ruby/object:Gem::Requirement
107
- requirements:
108
- - - ">="
109
- - !ruby/object:Gem::Version
110
- version: 0.1.1
111
27
  description:
112
- email: andrew@chartkick.com
28
+ email: andrew@ankane.org
113
29
  executables: []
114
30
  extensions:
115
31
  - ext/torch/extconf.rb
@@ -121,13 +37,20 @@ files:
121
37
  - codegen/function.rb
122
38
  - codegen/generate_functions.rb
123
39
  - codegen/native_functions.yaml
40
+ - ext/torch/cuda.cpp
41
+ - ext/torch/device.cpp
124
42
  - ext/torch/ext.cpp
125
43
  - ext/torch/extconf.rb
44
+ - ext/torch/ivalue.cpp
45
+ - ext/torch/nn.cpp
126
46
  - ext/torch/nn_functions.h
47
+ - ext/torch/random.cpp
127
48
  - ext/torch/ruby_arg_parser.cpp
128
49
  - ext/torch/ruby_arg_parser.h
129
50
  - ext/torch/templates.h
51
+ - ext/torch/tensor.cpp
130
52
  - ext/torch/tensor_functions.h
53
+ - ext/torch/torch.cpp
131
54
  - ext/torch/torch_functions.h
132
55
  - ext/torch/utils.h
133
56
  - ext/torch/wrap_outputs.h
@@ -282,7 +205,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
282
205
  requirements:
283
206
  - - ">="
284
207
  - !ruby/object:Gem::Version
285
- version: '2.4'
208
+ version: '2.6'
286
209
  required_rubygems_version: !ruby/object:Gem::Requirement
287
210
  requirements:
288
211
  - - ">="