torch-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 38e16e7f07d004fd9625f168694356d551c79cbc62b0131fe1403e4c0995f296
4
+ data.tar.gz: 66bf6ae0e4dd373a7542fbfb1cfb9dbd89fc455e16166e6a76d0945b32fecf38
5
+ SHA512:
6
+ metadata.gz: d100e3a21ac877fe93ac61e9b5e0d8a5e61126684fc037dda3e9f703b040188b1e1523aa4111dff4aaf92ada1001597c5f60674b9583b14d31afd18dbf1ff18d
7
+ data.tar.gz: c234dee79e26d3ee25ade2aaddd75f155dea6d59d8b9c5af2c571423a7aaa8a6489f5cfce89f09f390468a951b1644a4212c19525a79816be09214f0938860a8
data/CHANGELOG.md ADDED
@@ -0,0 +1,3 @@
1
+ ## 0.1.0 (2019-11-26)
2
+
3
+ - First release
data/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
1
+ Copyright (c) 2019 Andrew Kane
2
+
3
+ MIT License
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining
6
+ a copy of this software and associated documentation files (the
7
+ "Software"), to deal in the Software without restriction, including
8
+ without limitation the rights to use, copy, modify, merge, publish,
9
+ distribute, sublicense, and/or sell copies of the Software, and to
10
+ permit persons to whom the Software is furnished to do so, subject to
11
+ the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be
14
+ included in all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,363 @@
1
+ # Torch-rb
2
+
3
+ :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
+
5
+ **Note:** This gem is currently experimental. There may be breaking changes between each release.
6
+
7
+ ## Installation
8
+
9
+ First, [install LibTorch](#libtorch-installation). For Homebrew, use:
10
+
11
+ ```sh
12
+ brew install ankane/brew/libtorch
13
+ ```
14
+
15
+ Add this line to your application’s Gemfile:
16
+
17
+ ```ruby
18
+ gem 'torch-rb'
19
+ ```
20
+
21
+ ## Getting Started
22
+
23
+ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
24
+
25
+ - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
26
+ - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
27
+ - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
28
+
29
+ Many methods and options are missing at the moment. PRs welcome!
30
+
31
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
32
+
33
+ ### Tensors
34
+
35
+ Create a tensor from a Ruby array
36
+
37
+ ```ruby
38
+ x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
39
+ ```
40
+
41
+ Get the shape of a tensor
42
+
43
+ ```ruby
44
+ x.shape
45
+ ```
46
+
47
+ There are [many functions](#tensor-creation) to create tensors, like
48
+
49
+ ```ruby
50
+ a = Torch.rand(3)
51
+ b = Torch.zeros(2, 3)
52
+ ```
53
+
54
+ Each tensor has four properties
55
+
56
+ - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
57
+ - `layout` - `:strided` (dense) or `:sparse`
58
+ - `device` - the compute device, like CPU or GPU
59
+ - `requires_grad` - whether or not to record gradients
60
+
61
+ You can specify properties when creating a tensor
62
+
63
+ ```ruby
64
+ Torch.rand(2, 3, dtype: :double, layout: :strided, device: "cpu", requires_grad: true)
65
+ ```
66
+
67
+ ### Operations
68
+
69
+ Create a tensor
70
+
71
+ ```ruby
72
+ x = Torch.tensor([10, 20, 30])
73
+ ```
74
+
75
+ Add
76
+
77
+ ```ruby
78
+ x + 5 # tensor([15, 25, 35])
79
+ ```
80
+
81
+ Subtract
82
+
83
+ ```ruby
84
+ x - 5 # tensor([5, 15, 25])
85
+ ```
86
+
87
+ Multiply
88
+
89
+ ```ruby
90
+ x * 5 # tensor([50, 100, 150])
91
+ ```
92
+
93
+ Divide
94
+
95
+ ```ruby
96
+ x / 5 # tensor([2, 4, 6])
97
+ ```
98
+
99
+ Get the remainder
100
+
101
+ ```ruby
102
+ x % 3 # tensor([1, 2, 0])
103
+ ```
104
+
105
+ Raise to a power
106
+
107
+ ```ruby
108
+ x**2 # tensor([100, 400, 900])
109
+ ```
110
+
111
+ Perform operations with other tensors
112
+
113
+ ```ruby
114
+ y = Torch.tensor([1, 2, 3])
115
+ x + y # tensor([11, 22, 33])
116
+ ```
117
+
118
+ Perform operations in-place
119
+
120
+ ```ruby
121
+ x.add!(5)
122
+ x # tensor([15, 25, 35])
123
+ ```
124
+
125
+ You can also specify an output tensor
126
+
127
+ ```ruby
128
+ result = Torch.empty(3)
129
+ Torch.add(x, y, out: result)
130
+ result # tensor([16, 27, 38])
131
+ ```
132
+
133
+ ### Numo
134
+
135
+ Convert a tensor to a Numo array
136
+
137
+ ```ruby
138
+ a = Torch.ones(5)
139
+ a.numo
140
+ ```
141
+
142
+ Convert a Numo array to a tensor
143
+
144
+ ```ruby
145
+ b = Numo::NArray.cast([1, 2, 3])
146
+ Torch.from_numpy(b)
147
+ ```
148
+
149
+ ### Autograd
150
+
151
+ Create a tensor with `requires_grad: true`
152
+
153
+ ```ruby
154
+ x = Torch.ones(2, 2, requires_grad: true)
155
+ ```
156
+
157
+ Perform operations
158
+
159
+ ```ruby
160
+ y = x + 2
161
+ z = y * y * 3
162
+ out = z.mean
163
+ ```
164
+
165
+ Backprop
166
+
167
+ ```ruby
168
+ out.backward
169
+ ```
170
+
171
+ Get gradients
172
+
173
+ ```ruby
174
+ x.grad
175
+ ```
176
+
177
+ Stop autograd from tracking history
178
+
179
+ ```ruby
180
+ x.requires_grad # true
181
+ (x ** 2).requires_grad # true
182
+
183
+ Torch.no_grad do
184
+ (x ** 2).requires_grad # false
185
+ end
186
+ ```
187
+
188
+ ### Neural Networks
189
+
190
+ Define a neural network
191
+
192
+ ```ruby
193
+ class Net < Torch::NN::Module
194
+ def initialize
195
+ super
196
+ @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
197
+ @conv2 = Torch::NN::Conv2d.new(6, 16, 3)
198
+ @fc1 = Torch::NN::Linear.new(16 * 6 * 6, 120)
199
+ @fc2 = Torch::NN::Linear.new(120, 84)
200
+ @fc3 = Torch::NN::Linear.new(84, 10)
201
+ end
202
+
203
+ def forward(x)
204
+ x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv1.call(x)), [2, 2])
205
+ x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv2.call(x)), 2)
206
+ x = x.view(-1, num_flat_features(x))
207
+ x = Torch::NN::F.relu(@fc1.call(x))
208
+ x = Torch::NN::F.relu(@fc2.call(x))
209
+ x = @fc3.call(x)
210
+ x
211
+ end
212
+
213
+ def num_flat_features(x)
214
+ size = x.size[1..-1]
215
+ num_features = 1
216
+ size.each do |s|
217
+ num_features *= s
218
+ end
219
+ num_features
220
+ end
221
+ end
222
+ ```
223
+
224
+ And run
225
+
226
+ ```ruby
227
+ net = Net.new
228
+ input = Torch.randn(1, 1, 32, 32)
229
+ net.call(input)
230
+ ```
231
+
232
+ ### Tensor Creation
233
+
234
+ Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
235
+
236
+ - `arange` returns a tensor with a sequence of integers
237
+
238
+ ```ruby
239
+ Torch.arange(3) # tensor([0, 1, 2])
240
+ ```
241
+
242
+ - `empty` returns a tensor with uninitialized values
243
+
244
+ ```ruby
245
+ Torch.empty(3)
246
+ ```
247
+
248
+ - `eye` returns an identity matrix
249
+
250
+ ```ruby
251
+ Torch.eye(2) # tensor([[1, 0], [0, 1]])
252
+ ```
253
+
254
+ - `full` returns a tensor filled with a single value
255
+
256
+ ```ruby
257
+ Torch.full([3], 5) # tensor([5, 5, 5])
258
+ ```
259
+
260
+ - `linspace` returns a tensor with values linearly spaced in some interval
261
+
262
+ ```ruby
263
+ Torch.linspace(0, 10, 3) # tensor([0, 5, 10])
264
+ ```
265
+
266
+ - `logspace` returns a tensor with values logarithmically spaced in some interval
267
+
268
+ ```ruby
269
+ Torch.logspace(0, 10, 3) # tensor([1, 1e5, 1e10])
270
+ ```
271
+
272
+ - `ones` returns a tensor filled with all ones
273
+
274
+ ```ruby
275
+ Torch.ones(3) # tensor([1, 1, 1])
276
+ ```
277
+
278
+ - `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
279
+
280
+ ```ruby
281
+ Torch.rand(3)
282
+ ```
283
+
284
+ - `randint` returns a tensor with integers randomly drawn from an interval
285
+
286
+ ```ruby
287
+ Torch.randint(1, 10, [3])
288
+ ```
289
+
290
+ - `randn` returns a tensor filled with values drawn from a unit normal distribution
291
+
292
+ ```ruby
293
+ Torch.randn(3)
294
+ ```
295
+
296
+ - `randperm` returns a tensor filled with a random permutation of integers in some interval
297
+
298
+ ```ruby
299
+ Torch.randperm(3) # tensor([2, 0, 1])
300
+ ```
301
+
302
+ - `zeros` returns a tensor filled with all zeros
303
+
304
+ ```ruby
305
+ Torch.zeros(3) # tensor([0, 0, 0])
306
+ ```
307
+
308
+ ## LibTorch Installation
309
+
310
+ [Download LibTorch](https://pytorch.org/) and run:
311
+
312
+ ```sh
313
+ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
314
+ ```
315
+
316
+ ### Homebrew
317
+
318
+ For Mac, you can use Homebrew.
319
+
320
+ ```sh
321
+ brew install ankane/brew/libtorch
322
+ ```
323
+
324
+ Then install the gem (no need for `--with-torch-dir`).
325
+
326
+ ## rbenv
327
+
328
+ This library uses [Rice](https://github.com/jasonroelofs/rice) to interface with LibTorch. Rice and earlier versions of rbenv don’t play nicely together. If you encounter an error during installation, upgrade ruby-build and reinstall your Ruby version.
329
+
330
+ ```sh
331
+ brew upgrade ruby-build
332
+ rbenv install [version]
333
+ ```
334
+
335
+ ## History
336
+
337
+ View the [changelog](https://github.com/ankane/torch-rb/blob/master/CHANGELOG.md)
338
+
339
+ ## Contributing
340
+
341
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
342
+
343
+ - [Report bugs](https://github.com/ankane/torch-rb/issues)
344
+ - Fix bugs and [submit pull requests](https://github.com/ankane/torch-rb/pulls)
345
+ - Write, clarify, or fix documentation
346
+ - Suggest or add new features
347
+
348
+ To get started with development:
349
+
350
+ ```sh
351
+ git clone https://github.com/ankane/torch-rb.git
352
+ cd torch-rb
353
+ bundle install
354
+ bundle exec rake compile
355
+ bundle exec rake test
356
+ ```
357
+
358
+ Here are some good resources for contributors:
359
+
360
+ - [PyTorch API](https://pytorch.org/docs/stable/torch.html)
361
+ - [PyTorch C++ API](https://pytorch.org/cppdocs/)
362
+ - [Tensor Creation API](https://pytorch.org/cppdocs/notes/tensor_creation.html)
363
+ - [Using the PyTorch C++ Frontend](https://pytorch.org/tutorials/advanced/cpp_frontend.html)
data/ext/torch/ext.cpp ADDED
@@ -0,0 +1,546 @@
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/torch.h>

#include <rice/Array.hpp>
#include <rice/Class.hpp>
#include <rice/Constructor.hpp>

using namespace Rice;
10
+
11
+ template<>
12
+ inline
13
+ long long from_ruby<long long>(Object x)
14
+ {
15
+ return NUM2LL(x);
16
+ }
17
+
18
+ template<>
19
+ inline
20
+ Object to_ruby<long long>(long long const & x)
21
+ {
22
+ return LL2NUM(x);
23
+ }
24
+
25
+ template<>
26
+ inline
27
+ unsigned long long from_ruby<unsigned long long>(Object x)
28
+ {
29
+ return NUM2ULL(x);
30
+ }
31
+
32
+ template<>
33
+ inline
34
+ Object to_ruby<unsigned long long>(unsigned long long const & x)
35
+ {
36
+ return ULL2NUM(x);
37
+ }
38
+
39
+ template<>
40
+ inline
41
+ short from_ruby<short>(Object x)
42
+ {
43
+ return NUM2SHORT(x);
44
+ }
45
+
46
+ template<>
47
+ inline
48
+ Object to_ruby<short>(short const & x)
49
+ {
50
+ return INT2NUM(x);
51
+ }
52
+
53
+ template<>
54
+ inline
55
+ unsigned short from_ruby<unsigned short>(Object x)
56
+ {
57
+ return NUM2USHORT(x);
58
+ }
59
+
60
+ template<>
61
+ inline
62
+ Object to_ruby<unsigned short>(unsigned short const & x)
63
+ {
64
+ return UINT2NUM(x);
65
+ }
66
+
67
+ // need to wrap torch::IntArrayRef() since
68
+ // it doesn't own underlying data
69
+ class IntArrayRef {
70
+ std::vector<int64_t> vec;
71
+ public:
72
+ IntArrayRef(Object o) {
73
+ Array a = Array(o);
74
+ for (size_t i = 0; i < a.size(); i++) {
75
+ vec.push_back(from_ruby<int64_t>(a[i]));
76
+ }
77
+ }
78
+ operator torch::IntArrayRef() {
79
+ return torch::IntArrayRef(vec);
80
+ }
81
+ };
82
+
83
+ template<>
84
+ inline
85
+ IntArrayRef from_ruby<IntArrayRef>(Object x)
86
+ {
87
+ return IntArrayRef(x);
88
+ }
89
+
// Plain C++ stand-in for a full at::Scalar binding ("for now"), used for the
// scalar arguments below: start/end/step, fill values, exponents.
// double matches the precision of Ruby's Float. float would silently round
// every such argument to float32 before it reaches LibTorch, corrupting
// values for float64 and large int64 tensors (e.g. 0.1, 16777217).
using Scalar = double;
92
+
93
+ extern "C"
94
+ void Init_ext()
95
+ {
96
+ Module rb_mTorch = define_module("Torch")
97
+ .define_singleton_method(
98
+ "grad_enabled?",
99
+ *[]() {
100
+ return torch::GradMode::is_enabled();
101
+ })
102
+ .define_singleton_method(
103
+ "_set_grad_enabled",
104
+ *[](bool enabled) {
105
+ torch::GradMode::set_enabled(enabled);
106
+ })
107
+ .define_singleton_method(
108
+ "floating_point?",
109
+ *[](torch::Tensor& input) {
110
+ return torch::is_floating_point(input);
111
+ })
112
+ .define_singleton_method(
113
+ "manual_seed",
114
+ *[](uint64_t seed) {
115
+ return torch::manual_seed(seed);
116
+ })
117
+ // begin tensor creation
118
+ .define_singleton_method(
119
+ "_arange",
120
+ *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
121
+ return torch::arange(start, end, step, options);
122
+ })
123
+ .define_singleton_method(
124
+ "_empty",
125
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
126
+ return torch::empty(size, options);
127
+ })
128
+ .define_singleton_method(
129
+ "_eye",
130
+ *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
131
+ return torch::eye(m, n, options);
132
+ })
133
+ .define_singleton_method(
134
+ "_full",
135
+ *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
136
+ return torch::full(size, fill_value, options);
137
+ })
138
+ .define_singleton_method(
139
+ "_linspace",
140
+ *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
141
+ return torch::linspace(start, end, steps, options);
142
+ })
143
+ .define_singleton_method(
144
+ "_logspace",
145
+ *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
146
+ return torch::logspace(start, end, steps, base, options);
147
+ })
148
+ .define_singleton_method(
149
+ "_ones",
150
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
151
+ return torch::ones(size, options);
152
+ })
153
+ .define_singleton_method(
154
+ "_rand",
155
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
156
+ return torch::rand(size, options);
157
+ })
158
+ .define_singleton_method(
159
+ "_randint",
160
+ *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
161
+ return torch::randint(low, high, size, options);
162
+ })
163
+ .define_singleton_method(
164
+ "_randn",
165
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
166
+ return torch::randn(size, options);
167
+ })
168
+ .define_singleton_method(
169
+ "_randperm",
170
+ *[](int64_t n, const torch::TensorOptions &options) {
171
+ return torch::randperm(n, options);
172
+ })
173
+ .define_singleton_method(
174
+ "_zeros",
175
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
176
+ return torch::zeros(size, options);
177
+ })
178
+ // begin operations
179
+ .define_singleton_method(
180
+ "_mean",
181
+ *[](torch::Tensor& input) {
182
+ return torch::mean(input);
183
+ })
184
+ .define_singleton_method(
185
+ "_mean_dim",
186
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
187
+ return torch::mean(input, dim, keepdim);
188
+ })
189
+ .define_singleton_method(
190
+ "_sum",
191
+ *[](torch::Tensor& input) {
192
+ return torch::sum(input);
193
+ })
194
+ .define_singleton_method(
195
+ "_sum_dim",
196
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
197
+ return torch::sum(input, dim, keepdim);
198
+ })
199
+ .define_singleton_method(
200
+ "_norm",
201
+ *[](torch::Tensor& input) {
202
+ return torch::norm(input);
203
+ })
204
+ .define_singleton_method(
205
+ "_min",
206
+ *[](torch::Tensor& input) {
207
+ return torch::min(input);
208
+ })
209
+ .define_singleton_method(
210
+ "_max",
211
+ *[](torch::Tensor& input) {
212
+ return torch::max(input);
213
+ })
214
+ .define_singleton_method(
215
+ "_exp",
216
+ *[](torch::Tensor& input) {
217
+ return torch::exp(input);
218
+ })
219
+ .define_singleton_method(
220
+ "_log",
221
+ *[](torch::Tensor& input) {
222
+ return torch::log(input);
223
+ })
224
+ .define_singleton_method(
225
+ "_unsqueeze",
226
+ *[](torch::Tensor& input, int64_t dim) {
227
+ return torch::unsqueeze(input, dim);
228
+ })
229
+ .define_singleton_method(
230
+ "_dot",
231
+ *[](torch::Tensor& input, torch::Tensor& tensor) {
232
+ return torch::dot(input, tensor);
233
+ })
234
+ .define_singleton_method(
235
+ "_matmul",
236
+ *[](torch::Tensor& input, torch::Tensor& other) {
237
+ return torch::matmul(input, other);
238
+ })
239
+ .define_singleton_method(
240
+ "_add",
241
+ *[](torch::Tensor& input, torch::Tensor& other) {
242
+ return torch::add(input, other);
243
+ })
244
+ .define_singleton_method(
245
+ "_add_scalar",
246
+ *[](torch::Tensor& input, float other) {
247
+ return torch::add(input, other);
248
+ })
249
+ .define_singleton_method(
250
+ "_add_out",
251
+ *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
252
+ return torch::add_out(out, input, other);
253
+ })
254
+ .define_singleton_method(
255
+ "_sub",
256
+ *[](torch::Tensor& input, torch::Tensor& other) {
257
+ return torch::sub(input, other);
258
+ })
259
+ .define_singleton_method(
260
+ "_sub_scalar",
261
+ *[](torch::Tensor& input, float other) {
262
+ return torch::sub(input, other);
263
+ })
264
+ .define_singleton_method(
265
+ "_mul",
266
+ *[](torch::Tensor& input, torch::Tensor& other) {
267
+ return torch::mul(input, other);
268
+ })
269
+ .define_singleton_method(
270
+ "_mul_scalar",
271
+ *[](torch::Tensor& input, float other) {
272
+ return torch::mul(input, other);
273
+ })
274
+ .define_singleton_method(
275
+ "_div",
276
+ *[](torch::Tensor& input, torch::Tensor& other) {
277
+ return torch::div(input, other);
278
+ })
279
+ .define_singleton_method(
280
+ "_div_scalar",
281
+ *[](torch::Tensor& input, float other) {
282
+ return torch::div(input, other);
283
+ })
284
+ .define_singleton_method(
285
+ "_remainder",
286
+ *[](torch::Tensor& input, torch::Tensor& other) {
287
+ return torch::remainder(input, other);
288
+ })
289
+ .define_singleton_method(
290
+ "_remainder_scalar",
291
+ *[](torch::Tensor& input, float other) {
292
+ return torch::remainder(input, other);
293
+ })
294
+ .define_singleton_method(
295
+ "_pow",
296
+ *[](torch::Tensor& input, Scalar exponent) {
297
+ return torch::pow(input, exponent);
298
+ })
299
+ .define_singleton_method(
300
+ "_neg",
301
+ *[](torch::Tensor& input) {
302
+ return torch::neg(input);
303
+ })
304
+ .define_singleton_method(
305
+ "relu",
306
+ *[](torch::Tensor& input) {
307
+ return torch::relu(input);
308
+ })
309
+ .define_singleton_method(
310
+ "conv2d",
311
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
312
+ return torch::conv2d(input, weight, bias);
313
+ })
314
+ .define_singleton_method(
315
+ "linear",
316
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
317
+ return torch::linear(input, weight, bias);
318
+ })
319
+ .define_singleton_method(
320
+ "max_pool2d",
321
+ *[](torch::Tensor& input, IntArrayRef kernel_size) {
322
+ return torch::max_pool2d(input, kernel_size);
323
+ })
324
+ .define_singleton_method(
325
+ "mse_loss",
326
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
327
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
328
+ return torch::mse_loss(input, target, red);
329
+ })
330
+ .define_singleton_method(
331
+ "nll_loss",
332
+ *[](torch::Tensor& input, torch::Tensor& target) {
333
+ return torch::nll_loss(input, target);
334
+ })
335
+ .define_singleton_method(
336
+ "_tensor",
337
+ *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
338
+ Array a = Array(o);
339
+ std::vector<float> vec;
340
+ for (size_t i = 0; i < a.size(); i++) {
341
+ vec.push_back(from_ruby<float>(a[i]));
342
+ }
343
+ return torch::tensor(vec, options).reshape(size);
344
+ });
345
+
346
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
347
+ .define_method("cuda?", &torch::Tensor::is_cuda)
348
+ .define_method("distributed?", &torch::Tensor::is_distributed)
349
+ .define_method("complex?", &torch::Tensor::is_complex)
350
+ .define_method("floating_point?", &torch::Tensor::is_floating_point)
351
+ .define_method("signed?", &torch::Tensor::is_signed)
352
+ .define_method("sparse?", &torch::Tensor::is_sparse)
353
+ .define_method("quantized?", &torch::Tensor::is_quantized)
354
+ .define_method("dim", &torch::Tensor::dim)
355
+ .define_method("numel", &torch::Tensor::numel)
356
+ .define_method("element_size", &torch::Tensor::element_size)
357
+ .define_method("requires_grad", &torch::Tensor::requires_grad)
358
+ .define_method(
359
+ "zero!",
360
+ *[](torch::Tensor& self) {
361
+ return self.zero_();
362
+ })
363
+ .define_method(
364
+ "detach!",
365
+ *[](torch::Tensor& self) {
366
+ return self.detach_();
367
+ })
368
+ .define_method(
369
+ "_access",
370
+ *[](torch::Tensor& self, int64_t index) {
371
+ return self[index];
372
+ })
373
+ .define_method(
374
+ "_requires_grad!",
375
+ *[](torch::Tensor& self, bool requires_grad) {
376
+ return self.set_requires_grad(requires_grad);
377
+ })
378
+ .define_method(
379
+ "backward",
380
+ *[](torch::Tensor& self) {
381
+ return self.backward();
382
+ })
383
+ .define_method(
384
+ "grad",
385
+ *[](torch::Tensor& self) {
386
+ return self.grad();
387
+ })
388
+ .define_method(
389
+ "_dtype",
390
+ *[](torch::Tensor& self) {
391
+ return (int) at::typeMetaToScalarType(self.dtype());
392
+ })
393
+ .define_method(
394
+ "_layout",
395
+ *[](torch::Tensor& self) {
396
+ std::stringstream s;
397
+ s << self.layout();
398
+ return s.str();
399
+ })
400
+ .define_method(
401
+ "device",
402
+ *[](torch::Tensor& self) {
403
+ std::stringstream s;
404
+ s << self.device();
405
+ return s.str();
406
+ })
407
+ .define_method(
408
+ "_view",
409
+ *[](torch::Tensor& self, IntArrayRef size) {
410
+ return self.view(size);
411
+ })
412
+ .define_method(
413
+ "add!",
414
+ *[](torch::Tensor& self, torch::Tensor& other) {
415
+ self.add_(other);
416
+ })
417
+ .define_method(
418
+ "sub!",
419
+ *[](torch::Tensor& self, torch::Tensor& other) {
420
+ self.sub_(other);
421
+ })
422
+ .define_method(
423
+ "mul!",
424
+ *[](torch::Tensor& self, torch::Tensor& other) {
425
+ self.mul_(other);
426
+ })
427
+ .define_method(
428
+ "div!",
429
+ *[](torch::Tensor& self, torch::Tensor& other) {
430
+ self.div_(other);
431
+ })
432
+ .define_method(
433
+ "log_softmax",
434
+ *[](torch::Tensor& self, int64_t dim) {
435
+ return self.log_softmax(dim);
436
+ })
437
+ .define_method(
438
+ "_data",
439
+ *[](torch::Tensor& self) {
440
+ Array a;
441
+ auto dtype = self.dtype();
442
+
443
+ // TODO DRY if someone knows C++
444
+ // TODO kByte (uint8), kChar (int8), kBool (bool)
445
+ if (dtype == torch::kShort) {
446
+ short* data = self.data_ptr<short>();
447
+ for (int i = 0; i < self.numel(); i++) {
448
+ a.push(data[i]);
449
+ }
450
+ } else if (dtype == torch::kInt) {
451
+ int* data = self.data_ptr<int>();
452
+ for (int i = 0; i < self.numel(); i++) {
453
+ a.push(data[i]);
454
+ }
455
+ } else if (dtype == torch::kLong) {
456
+ long long* data = self.data_ptr<long long>();
457
+ for (int i = 0; i < self.numel(); i++) {
458
+ a.push(data[i]);
459
+ }
460
+ } else if (dtype == torch::kFloat) {
461
+ float* data = self.data_ptr<float>();
462
+ for (int i = 0; i < self.numel(); i++) {
463
+ a.push(data[i]);
464
+ }
465
+ } else if (dtype == torch::kDouble) {
466
+ double* data = self.data_ptr<double>();
467
+ for (int i = 0; i < self.numel(); i++) {
468
+ a.push(data[i]);
469
+ }
470
+ } else {
471
+ throw "Unsupported type";
472
+ }
473
+ return a;
474
+ })
475
+ .define_method(
476
+ "_size",
477
+ *[](torch::Tensor& self, int i) {
478
+ return self.size(i);
479
+ })
480
+ .define_singleton_method(
481
+ "_make_subclass",
482
+ *[](torch::Tensor& rd, bool requires_grad) {
483
+ auto data = torch::autograd::as_variable_ref(rd).detach();
484
+ data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
485
+ auto var = data.set_requires_grad(requires_grad);
486
+ return torch::autograd::Variable(std::move(var));
487
+ });
488
+
489
+ Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
490
+ .define_constructor(Constructor<torch::TensorOptions>())
491
+ .define_method(
492
+ "dtype",
493
+ *[](torch::TensorOptions& self, int dtype) {
494
+ return self.dtype((torch::ScalarType) dtype);
495
+ })
496
+ .define_method(
497
+ "layout",
498
+ *[](torch::TensorOptions& self, std::string layout) {
499
+ torch::Layout l;
500
+ if (layout == "strided") {
501
+ l = torch::kStrided;
502
+ } else {
503
+ throw "Unsupported layout";
504
+ }
505
+ return self.layout(l);
506
+ })
507
+ .define_method(
508
+ "device",
509
+ *[](torch::TensorOptions& self, std::string device) {
510
+ torch::DeviceType d;
511
+ if (device == "cpu") {
512
+ d = torch::kCPU;
513
+ } else if (device == "cuda") {
514
+ d = torch::kCUDA;
515
+ } else {
516
+ throw "Unsupported device";
517
+ }
518
+ return self.device(d);
519
+ })
520
+ .define_method(
521
+ "requires_grad",
522
+ *[](torch::TensorOptions& self, bool requires_grad) {
523
+ return self.requires_grad(requires_grad);
524
+ });
525
+
526
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
527
+
528
+ Module rb_mInit = define_module_under(rb_mNN, "Init")
529
+ .define_singleton_method(
530
+ "kaiming_uniform_",
531
+ *[](torch::Tensor& input, double a) {
532
+ return torch::nn::init::kaiming_uniform_(input, a);
533
+ })
534
+ .define_singleton_method(
535
+ "uniform_",
536
+ *[](torch::Tensor& input, double to, double from) {
537
+ return torch::nn::init::uniform_(input, to, from);
538
+ });
539
+
540
+ Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
541
+ .define_method(
542
+ "grad",
543
+ *[](torch::autograd::Variable& self) {
544
+ return self.grad();
545
+ });
546
+ }