torch-rb 0.5.3 → 0.6.0

This diff shows the changes between publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
data/ext/torch/tensor_functions.h CHANGED
@@ -3,4 +3,4 @@
 
 #pragma once
 
-void add_tensor_functions(Module m);
+void add_tensor_functions(Rice::Module& m);
data/ext/torch/torch.cpp ADDED
@@ -0,0 +1,86 @@
+#include <torch/torch.h>
+
+#include <rice/Module.hpp>
+
+#include "torch_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_torch(Rice::Module& m) {
+  m.add_handler<torch::Error>(handle_error);
+  add_torch_functions(m);
+  m.define_singleton_method(
+      "grad_enabled?",
+      *[]() {
+        return torch::GradMode::is_enabled();
+      })
+    .define_singleton_method(
+      "_set_grad_enabled",
+      *[](bool enabled) {
+        torch::GradMode::set_enabled(enabled);
+      })
+    .define_singleton_method(
+      "manual_seed",
+      *[](uint64_t seed) {
+        return torch::manual_seed(seed);
+      })
+    // config
+    .define_singleton_method(
+      "show_config",
+      *[] {
+        return torch::show_config();
+      })
+    .define_singleton_method(
+      "parallel_info",
+      *[] {
+        return torch::get_parallel_info();
+      })
+    // begin operations
+    .define_singleton_method(
+      "_save",
+      *[](const torch::IValue &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
+      })
+    .define_singleton_method(
+      "_load",
+      *[](const std::string &s) {
+        std::vector<char> v;
+        std::copy(s.begin(), s.end(), std::back_inserter(v));
+        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+        return torch::pickle_load(v);
+      })
+    .define_singleton_method(
+      "_from_blob",
+      *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        void *data = const_cast<char *>(s.c_str());
+        return torch::from_blob(data, size, options);
+      })
+    .define_singleton_method(
+      "_tensor",
+      *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          std::vector<uint8_t> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(from_ruby<bool>(a[i]));
+          }
+          t = torch::tensor(vec, options);
+        } else {
+          std::vector<float> vec;
+          for (long i = 0; i < a.size(); i++) {
+            vec.push_back(from_ruby<float>(a[i]));
+          }
+          // hack for requires_grad error
+          if (options.requires_grad()) {
+            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+            t.set_requires_grad(true);
+          } else {
+            t = torch::tensor(vec, options);
+          }
+        }
+        return t.reshape(size);
+      });
+}
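
These singleton methods are defined directly on the Ruby Torch module, so the non-underscored ones are callable as-is; the underscored ones (_save, _load, _tensor) are internal and reached through the gem's public wrappers. A minimal usage sketch (Torch.tensor delegating to _tensor reflects the gem's existing Ruby layer, shown here only for orientation):

    require "torch"

    Torch.manual_seed(1)            # seeds the libtorch RNG
    puts Torch.show_config          # build configuration from torch::show_config
    puts Torch.parallel_info        # threading info from torch::get_parallel_info
    Torch.grad_enabled?             # => true by default
    t = Torch.tensor([1.0, 2.0])    # routed through _tensor above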
data/ext/torch/torch_functions.h CHANGED
@@ -3,4 +3,4 @@
 
 #pragma once
 
-void add_torch_functions(Module m);
+void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,13 +1,20 @@
 #pragma once
 
+#include <rice/Exception.hpp>
 #include <rice/Symbol.hpp>
 
+// TODO find better place
+inline void handle_error(torch::Error const & ex)
+{
+  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+}
+
 // keep THP prefix for now to make it easier to compare code
 
 extern VALUE THPVariableClass;
 
 inline VALUE THPUtils_internSymbol(const std::string& str) {
-  return Symbol(str);
+  return Rice::Symbol(str);
 }
 
 inline std::string THPUtils_unpackSymbol(VALUE obj) {
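
Since init_torch registers handle_error for torch::Error, libtorch exceptions surface in Ruby as RuntimeError and can be rescued normally. A small sketch (the shape mismatch below is one way to trigger a torch::Error):

    begin
      Torch.tensor([1.0, 2.0, 3.0]) + Torch.tensor([1.0, 2.0])
    rescue RuntimeError => e
      puts e.message  # libtorch's message, without the C++ backtrace
    end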
data/lib/torch.rb CHANGED
@@ -337,25 +337,24 @@ module Torch
       }
     end
 
-    def no_grad
-      previous_value = grad_enabled?
-      begin
-        _set_grad_enabled(false)
-        yield
-      ensure
-        _set_grad_enabled(previous_value)
-      end
+    def no_grad(&block)
+      grad_enabled(false, &block)
+    end
+
+    def enable_grad(&block)
+      grad_enabled(true, &block)
     end
 
-    def enable_grad
+    def grad_enabled(value)
       previous_value = grad_enabled?
       begin
-        _set_grad_enabled(true)
+        _set_grad_enabled(value)
         yield
       ensure
         _set_grad_enabled(previous_value)
      end
     end
+    alias_method :set_grad_enabled, :grad_enabled
 
     def device(str)
       Device.new(str)
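
no_grad and enable_grad keep their block form, and the new grad_enabled/set_grad_enabled takes the flag explicitly. A minimal sketch:

    x = Torch.tensor([1.0, 2.0], requires_grad: true)

    Torch.no_grad do
      x * 2                     # computed without recording a graph
    end

    Torch.set_grad_enabled(false) do
      Torch.grad_enabled?       # => false
    end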
data/lib/torch/nn/linear.rb CHANGED
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class Linear < Module
+      attr_reader :in_features, :out_features
+
       def initialize(in_features, out_features, bias: true)
         super()
         @in_features = in_features
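
With the new readers, a layer's dimensions can be inspected directly:

    layer = Torch::NN::Linear.new(10, 2)
    layer.in_features   # => 10
    layer.out_features  # => 2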
data/lib/torch/nn/module.rb CHANGED
@@ -286,6 +286,12 @@ module Torch
           named_buffers[name]
         elsif named_modules.key?(name)
           named_modules[name]
+        elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+          if instance_variable_defined?("@#{method[0..-2]}")
+            instance_variable_set("@#{method[0..-2]}", *args)
+          else
+            raise NotImplementedYet
+          end
         else
           super
         end
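
This makes assignment to a registered submodule work through method_missing, for example swapping out a classification head. A hedged sketch, assuming a hypothetical model whose initialize set an @fc submodule:

    # hypothetical: model's initialize did @fc = Torch::NN::Linear.new(512, 1000)
    model.fc = Torch::NN::Linear.new(512, 10)  # handled by the new elsif branch
    model.fc.out_features                      # => 10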
data/lib/torch/nn/parameter.rb CHANGED
@@ -3,7 +3,7 @@ module Torch
     class Parameter < Tensor
       def self.new(data = nil, requires_grad: true)
         data = Tensor.new unless data
-        Tensor._make_subclass(data, requires_grad)
+        _make_subclass(data, requires_grad)
       end
 
       def inspect
data/lib/torch/tensor.rb CHANGED
@@ -135,6 +135,10 @@ module Torch
       Torch.ones_like(Torch.empty(*size), **options)
     end
 
+    def requires_grad=(requires_grad)
+      _requires_grad!(requires_grad)
+    end
+
     def requires_grad!(requires_grad = true)
       _requires_grad!(requires_grad)
     end
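
The new writer mirrors PyTorch's attribute syntax alongside the existing bang method:

    x = Torch.ones(2, 2)
    x.requires_grad = true   # new writer, equivalent to...
    x.requires_grad!(true)   # ...the existing bang form shown above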
data/lib/torch/utils/data/data_loader.rb CHANGED
@@ -60,7 +60,7 @@ module Torch
          when Array
            batch.transpose.map { |v| default_convert(v) }
          else
-           raise NotImplementedYet
+           batch
          end
        end
 
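Non-tensor batch elements such as string labels now pass through collation unchanged instead of raising NotImplementedYet. A hedged sketch, assuming a plain array of [tensor, label] pairs works as the dataset here:

    examples = [[Torch.tensor([1.0]), "cat"], [Torch.tensor([2.0]), "dog"]]
    loader = Torch::Utils::Data::DataLoader.new(examples, batch_size: 2)
    loader.each do |tensors, labels|
      labels  # => ["cat", "dog"] (previously this raised)
    end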
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.5.3"
+  VERSION = "0.6.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.5.3
+  version: 0.6.0
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2021-01-14 00:00:00.000000000 Z
+date: 2021-03-26 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -24,92 +24,8 @@ dependencies:
     - - ">="
       - !ruby/object:Gem::Version
         version: '2.2'
-- !ruby/object:Gem::Dependency
-  name: bundler
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: rake
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: rake-compiler
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: minitest
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '5'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '5'
-- !ruby/object:Gem::Dependency
-  name: numo-narray
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: torchvision
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: 0.1.1
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: 0.1.1
 description:
-email: andrew@chartkick.com
+email: andrew@ankane.org
 executables: []
 extensions:
 - ext/torch/extconf.rb
@@ -121,13 +37,20 @@ files:
 - codegen/function.rb
 - codegen/generate_functions.rb
 - codegen/native_functions.yaml
+- ext/torch/cuda.cpp
+- ext/torch/device.cpp
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
+- ext/torch/ivalue.cpp
+- ext/torch/nn.cpp
 - ext/torch/nn_functions.h
+- ext/torch/random.cpp
 - ext/torch/ruby_arg_parser.cpp
 - ext/torch/ruby_arg_parser.h
 - ext/torch/templates.h
+- ext/torch/tensor.cpp
 - ext/torch/tensor_functions.h
+- ext/torch/torch.cpp
 - ext/torch/torch_functions.h
 - ext/torch/utils.h
 - ext/torch/wrap_outputs.h
@@ -282,7 +205,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
   requirements:
   - - ">="
     - !ruby/object:Gem::Version
-      version: '2.4'
+      version: '2.6'
 required_rubygems_version: !ruby/object:Gem::Requirement
   requirements:
   - - ">="