torch-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 38e16e7f07d004fd9625f168694356d551c79cbc62b0131fe1403e4c0995f296
4
+ data.tar.gz: 66bf6ae0e4dd373a7542fbfb1cfb9dbd89fc455e16166e6a76d0945b32fecf38
5
+ SHA512:
6
+ metadata.gz: d100e3a21ac877fe93ac61e9b5e0d8a5e61126684fc037dda3e9f703b040188b1e1523aa4111dff4aaf92ada1001597c5f60674b9583b14d31afd18dbf1ff18d
7
+ data.tar.gz: c234dee79e26d3ee25ade2aaddd75f155dea6d59d8b9c5af2c571423a7aaa8a6489f5cfce89f09f390468a951b1644a4212c19525a79816be09214f0938860a8
data/CHANGELOG.md ADDED
@@ -0,0 +1,3 @@
1
+ ## 0.1.0 (2019-11-26)
2
+
3
+ - First release
data/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
1
+ Copyright (c) 2019 Andrew Kane
2
+
3
+ MIT License
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining
6
+ a copy of this software and associated documentation files (the
7
+ "Software"), to deal in the Software without restriction, including
8
+ without limitation the rights to use, copy, modify, merge, publish,
9
+ distribute, sublicense, and/or sell copies of the Software, and to
10
+ permit persons to whom the Software is furnished to do so, subject to
11
+ the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be
14
+ included in all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,363 @@
1
+ # Torch-rb
2
+
3
+ :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
+
5
+ **Note:** This gem is currently experimental. There may be breaking changes between each release.
6
+
7
+ ## Installation
8
+
9
+ First, [install LibTorch](#libtorch-installation). For Homebrew, use:
10
+
11
+ ```sh
12
+ brew install ankane/brew/libtorch
13
+ ```
14
+
15
+ Add this line to your application’s Gemfile:
16
+
17
+ ```ruby
18
+ gem 'torch-rb'
19
+ ```
20
+
21
+ ## Getting Started
22
+
23
+ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
24
+
25
+ - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
26
+ - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
27
+ - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
28
+
29
+ Many methods and options are missing at the moment. PRs welcome!
30
+
31
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
32
+
33
+ ### Tensors
34
+
35
+ Create a tensor from a Ruby array
36
+
37
+ ```ruby
38
+ x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
39
+ ```
40
+
41
+ Get the shape of a tensor
42
+
43
+ ```ruby
44
+ x.shape
45
+ ```
46
+
47
+ There are [many functions](#tensor-creation) to create tensors, like
48
+
49
+ ```ruby
50
+ a = Torch.rand(3)
51
+ b = Torch.zeros(2, 3)
52
+ ```
53
+
54
+ Each tensor has four properties
55
+
56
+ - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
57
+ - `layout` - `:strided` (dense) or `:sparse`
58
+ - `device` - the compute device, like CPU or GPU
59
+ - `requires_grad` - whether or not to record gradients
60
+
61
+ You can specify properties when creating a tensor
62
+
63
+ ```ruby
64
+ Torch.rand(2, 3, dtype: :double, layout: :strided, device: "cpu", requires_grad: true)
65
+ ```
66
+
67
+ ### Operations
68
+
69
+ Create a tensor
70
+
71
+ ```ruby
72
+ x = Torch.tensor([10, 20, 30])
73
+ ```
74
+
75
+ Add
76
+
77
+ ```ruby
78
+ x + 5 # tensor([15, 25, 35])
79
+ ```
80
+
81
+ Subtract
82
+
83
+ ```ruby
84
+ x - 5 # tensor([5, 15, 25])
85
+ ```
86
+
87
+ Multiply
88
+
89
+ ```ruby
90
+ x * 5 # tensor([50, 100, 150])
91
+ ```
92
+
93
+ Divide
94
+
95
+ ```ruby
96
+ x / 5 # tensor([2, 4, 6])
97
+ ```
98
+
99
+ Get the remainder
100
+
101
+ ```ruby
102
+ x % 3 # tensor([1, 2, 0])
103
+ ```
104
+
105
+ Raise to a power
106
+
107
+ ```ruby
108
+ x**2 # tensor([100, 400, 900])
109
+ ```
110
+
111
+ Perform operations with other tensors
112
+
113
+ ```ruby
114
+ y = Torch.tensor([1, 2, 3])
115
+ x + y # tensor([11, 22, 33])
116
+ ```
117
+
118
+ Perform operations in-place
119
+
120
+ ```ruby
121
+ x.add!(5)
122
+ x # tensor([15, 25, 35])
123
+ ```
124
+
125
+ You can also specify an output tensor
126
+
127
+ ```ruby
128
+ result = Torch.empty(3)
129
+ Torch.add(x, y, out: result)
130
+ result # tensor([16, 27, 38])
131
+ ```
132
+
133
+ ### Numo
134
+
135
+ Convert a tensor to a Numo array
136
+
137
+ ```ruby
138
+ a = Torch.ones(5)
139
+ a.numo
140
+ ```
141
+
142
+ Convert a Numo array to a tensor
143
+
144
+ ```ruby
145
+ b = Numo::NArray.cast([1, 2, 3])
146
+ Torch.from_numpy(b)
147
+ ```
148
+
149
+ ### Autograd
150
+
151
+ Create a tensor with `requires_grad: true`
152
+
153
+ ```ruby
154
+ x = Torch.ones(2, 2, requires_grad: true)
155
+ ```
156
+
157
+ Perform operations
158
+
159
+ ```ruby
160
+ y = x + 2
161
+ z = y * y * 3
162
+ out = z.mean
163
+ ```
164
+
165
+ Backprop
166
+
167
+ ```ruby
168
+ out.backward
169
+ ```
170
+
171
+ Get gradients
172
+
173
+ ```ruby
174
+ x.grad
175
+ ```
176
+
177
+ Stop autograd from tracking history
178
+
179
+ ```ruby
180
+ x.requires_grad # true
181
+ (x ** 2).requires_grad # true
182
+
183
+ Torch.no_grad do
184
+ (x ** 2).requires_grad # false
185
+ end
186
+ ```
187
+
188
+ ### Neural Networks
189
+
190
+ Define a neural network
191
+
192
+ ```ruby
193
+ class Net < Torch::NN::Module
194
+ def initialize
195
+ super
196
+ @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
197
+ @conv2 = Torch::NN::Conv2d.new(6, 16, 3)
198
+ @fc1 = Torch::NN::Linear.new(16 * 6 * 6, 120)
199
+ @fc2 = Torch::NN::Linear.new(120, 84)
200
+ @fc3 = Torch::NN::Linear.new(84, 10)
201
+ end
202
+
203
+ def forward(x)
204
+ x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv1.call(x)), [2, 2])
205
+ x = Torch::NN::F.max_pool2d(Torch::NN::F.relu(@conv2.call(x)), 2)
206
+ x = x.view(-1, num_flat_features(x))
207
+ x = Torch::NN::F.relu(@fc1.call(x))
208
+ x = Torch::NN::F.relu(@fc2.call(x))
209
+ x = @fc3.call(x)
210
+ x
211
+ end
212
+
213
+ def num_flat_features(x)
214
+ size = x.size[1..-1]
215
+ num_features = 1
216
+ size.each do |s|
217
+ num_features *= s
218
+ end
219
+ num_features
220
+ end
221
+ end
222
+ ```
223
+
224
+ And run
225
+
226
+ ```ruby
227
+ net = Net.new
228
+ input = Torch.randn(1, 1, 32, 32)
229
+ net.call(input)
230
+ ```
231
+
232
+ ### Tensor Creation
233
+
234
+ Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
235
+
236
+ - `arange` returns a tensor with a sequence of integers
237
+
238
+ ```ruby
239
+ Torch.arange(3) # tensor([0, 1, 2])
240
+ ```
241
+
242
+ - `empty` returns a tensor with uninitialized values
243
+
244
+ ```ruby
245
+ Torch.empty(3)
246
+ ```
247
+
248
+ - `eye` returns an identity matrix
249
+
250
+ ```ruby
251
+ Torch.eye(2) # tensor([[1, 0], [0, 1]])
252
+ ```
253
+
254
+ - `full` returns a tensor filled with a single value
255
+
256
+ ```ruby
257
+ Torch.full([3], 5) # tensor([5, 5, 5])
258
+ ```
259
+
260
+ - `linspace` returns a tensor with values linearly spaced in some interval
261
+
262
+ ```ruby
263
+ Torch.linspace(0, 10, 3) # tensor([0, 5, 10])
264
+ ```
265
+
266
+ - `logspace` returns a tensor with values logarithmically spaced in some interval
267
+
268
+ ```ruby
269
+ Torch.logspace(0, 10, 3) # tensor([1, 1e5, 1e10])
270
+ ```
271
+
272
+ - `ones` returns a tensor filled with all ones
273
+
274
+ ```ruby
275
+ Torch.ones(3) # tensor([1, 1, 1])
276
+ ```
277
+
278
+ - `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
279
+
280
+ ```ruby
281
+ Torch.rand(3)
282
+ ```
283
+
284
+ - `randint` returns a tensor with integers randomly drawn from an interval
285
+
286
+ ```ruby
287
+ Torch.randint(1, 10, [3])
288
+ ```
289
+
290
+ - `randn` returns a tensor filled with values drawn from a unit normal distribution
291
+
292
+ ```ruby
293
+ Torch.randn(3)
294
+ ```
295
+
296
+ - `randperm` returns a tensor filled with a random permutation of integers in some interval
297
+
298
+ ```ruby
299
+ Torch.randperm(3) # tensor([2, 0, 1])
300
+ ```
301
+
302
+ - `zeros` returns a tensor filled with all zeros
303
+
304
+ ```ruby
305
+ Torch.zeros(3) # tensor([0, 0, 0])
306
+ ```
307
+
308
+ ## LibTorch Installation
309
+
310
+ [Download LibTorch](https://pytorch.org/) and run:
311
+
312
+ ```sh
313
+ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
314
+ ```
315
+
316
+ ### Homebrew
317
+
318
+ For Mac, you can use Homebrew.
319
+
320
+ ```sh
321
+ brew install ankane/brew/libtorch
322
+ ```
323
+
324
+ Then install the gem (no need for `--with-torch-dir`).
325
+
326
+ ## rbenv
327
+
328
+ This library uses [Rice](https://github.com/jasonroelofs/rice) to interface with LibTorch. Rice and earlier versions of rbenv don’t play nicely together. If you encounter an error during installation, upgrade ruby-build and reinstall your Ruby version.
329
+
330
+ ```sh
331
+ brew upgrade ruby-build
332
+ rbenv install [version]
333
+ ```
334
+
335
+ ## History
336
+
337
+ View the [changelog](https://github.com/ankane/torch-rb/blob/master/CHANGELOG.md)
338
+
339
+ ## Contributing
340
+
341
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
342
+
343
+ - [Report bugs](https://github.com/ankane/torch-rb/issues)
344
+ - Fix bugs and [submit pull requests](https://github.com/ankane/torch-rb/pulls)
345
+ - Write, clarify, or fix documentation
346
+ - Suggest or add new features
347
+
348
+ To get started with development:
349
+
350
+ ```sh
351
+ git clone https://github.com/ankane/torch-rb.git
352
+ cd torch-rb
353
+ bundle install
354
+ bundle exec rake compile
355
+ bundle exec rake test
356
+ ```
357
+
358
+ Here are some good resources for contributors:
359
+
360
+ - [PyTorch API](https://pytorch.org/docs/stable/torch.html)
361
+ - [PyTorch C++ API](https://pytorch.org/cppdocs/)
362
+ - [Tensor Creation API](https://pytorch.org/cppdocs/notes/tensor_creation.html)
363
+ - [Using the PyTorch C++ Frontend](https://pytorch.org/tutorials/advanced/cpp_frontend.html)
data/ext/torch/ext.cpp ADDED
@@ -0,0 +1,546 @@
1
#include <sstream>
#include <stdexcept>

#include <torch/torch.h>

#include <rice/Array.hpp>
#include <rice/Class.hpp>
#include <rice/Constructor.hpp>
8
+
9
+ using namespace Rice;
10
+
11
+ template<>
12
+ inline
13
+ long long from_ruby<long long>(Object x)
14
+ {
15
+ return NUM2LL(x);
16
+ }
17
+
18
+ template<>
19
+ inline
20
+ Object to_ruby<long long>(long long const & x)
21
+ {
22
+ return LL2NUM(x);
23
+ }
24
+
25
+ template<>
26
+ inline
27
+ unsigned long long from_ruby<unsigned long long>(Object x)
28
+ {
29
+ return NUM2ULL(x);
30
+ }
31
+
32
+ template<>
33
+ inline
34
+ Object to_ruby<unsigned long long>(unsigned long long const & x)
35
+ {
36
+ return ULL2NUM(x);
37
+ }
38
+
39
+ template<>
40
+ inline
41
+ short from_ruby<short>(Object x)
42
+ {
43
+ return NUM2SHORT(x);
44
+ }
45
+
46
+ template<>
47
+ inline
48
+ Object to_ruby<short>(short const & x)
49
+ {
50
+ return INT2NUM(x);
51
+ }
52
+
53
+ template<>
54
+ inline
55
+ unsigned short from_ruby<unsigned short>(Object x)
56
+ {
57
+ return NUM2USHORT(x);
58
+ }
59
+
60
+ template<>
61
+ inline
62
+ Object to_ruby<unsigned short>(unsigned short const & x)
63
+ {
64
+ return UINT2NUM(x);
65
+ }
66
+
67
+ // need to wrap torch::IntArrayRef() since
68
+ // it doesn't own underlying data
69
+ class IntArrayRef {
70
+ std::vector<int64_t> vec;
71
+ public:
72
+ IntArrayRef(Object o) {
73
+ Array a = Array(o);
74
+ for (size_t i = 0; i < a.size(); i++) {
75
+ vec.push_back(from_ruby<int64_t>(a[i]));
76
+ }
77
+ }
78
+ operator torch::IntArrayRef() {
79
+ return torch::IntArrayRef(vec);
80
+ }
81
+ };
82
+
83
+ template<>
84
+ inline
85
+ IntArrayRef from_ruby<IntArrayRef>(Object x)
86
+ {
87
+ return IntArrayRef(x);
88
+ }
89
+
90
// Scalar arguments (range bounds, fill values, exponents, scalar operands)
// are received from Ruby as C++ doubles. Ruby Floats are doubles, so a float
// here would silently narrow them to single precision before LibTorch sees
// them (e.g. a fill value of 16777217 would become 16777216).
// Kept as a plain typedef for now rather than binding at::Scalar directly.
typedef double Scalar;
92
+
93
+ extern "C"
94
+ void Init_ext()
95
+ {
96
+ Module rb_mTorch = define_module("Torch")
97
+ .define_singleton_method(
98
+ "grad_enabled?",
99
+ *[]() {
100
+ return torch::GradMode::is_enabled();
101
+ })
102
+ .define_singleton_method(
103
+ "_set_grad_enabled",
104
+ *[](bool enabled) {
105
+ torch::GradMode::set_enabled(enabled);
106
+ })
107
+ .define_singleton_method(
108
+ "floating_point?",
109
+ *[](torch::Tensor& input) {
110
+ return torch::is_floating_point(input);
111
+ })
112
+ .define_singleton_method(
113
+ "manual_seed",
114
+ *[](uint64_t seed) {
115
+ return torch::manual_seed(seed);
116
+ })
117
+ // begin tensor creation
118
+ .define_singleton_method(
119
+ "_arange",
120
+ *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
121
+ return torch::arange(start, end, step, options);
122
+ })
123
+ .define_singleton_method(
124
+ "_empty",
125
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
126
+ return torch::empty(size, options);
127
+ })
128
+ .define_singleton_method(
129
+ "_eye",
130
+ *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
131
+ return torch::eye(m, n, options);
132
+ })
133
+ .define_singleton_method(
134
+ "_full",
135
+ *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
136
+ return torch::full(size, fill_value, options);
137
+ })
138
+ .define_singleton_method(
139
+ "_linspace",
140
+ *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
141
+ return torch::linspace(start, end, steps, options);
142
+ })
143
+ .define_singleton_method(
144
+ "_logspace",
145
+ *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
146
+ return torch::logspace(start, end, steps, base, options);
147
+ })
148
+ .define_singleton_method(
149
+ "_ones",
150
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
151
+ return torch::ones(size, options);
152
+ })
153
+ .define_singleton_method(
154
+ "_rand",
155
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
156
+ return torch::rand(size, options);
157
+ })
158
+ .define_singleton_method(
159
+ "_randint",
160
+ *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
161
+ return torch::randint(low, high, size, options);
162
+ })
163
+ .define_singleton_method(
164
+ "_randn",
165
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
166
+ return torch::randn(size, options);
167
+ })
168
+ .define_singleton_method(
169
+ "_randperm",
170
+ *[](int64_t n, const torch::TensorOptions &options) {
171
+ return torch::randperm(n, options);
172
+ })
173
+ .define_singleton_method(
174
+ "_zeros",
175
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
176
+ return torch::zeros(size, options);
177
+ })
178
+ // begin operations
179
+ .define_singleton_method(
180
+ "_mean",
181
+ *[](torch::Tensor& input) {
182
+ return torch::mean(input);
183
+ })
184
+ .define_singleton_method(
185
+ "_mean_dim",
186
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
187
+ return torch::mean(input, dim, keepdim);
188
+ })
189
+ .define_singleton_method(
190
+ "_sum",
191
+ *[](torch::Tensor& input) {
192
+ return torch::sum(input);
193
+ })
194
+ .define_singleton_method(
195
+ "_sum_dim",
196
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
197
+ return torch::sum(input, dim, keepdim);
198
+ })
199
+ .define_singleton_method(
200
+ "_norm",
201
+ *[](torch::Tensor& input) {
202
+ return torch::norm(input);
203
+ })
204
+ .define_singleton_method(
205
+ "_min",
206
+ *[](torch::Tensor& input) {
207
+ return torch::min(input);
208
+ })
209
+ .define_singleton_method(
210
+ "_max",
211
+ *[](torch::Tensor& input) {
212
+ return torch::max(input);
213
+ })
214
+ .define_singleton_method(
215
+ "_exp",
216
+ *[](torch::Tensor& input) {
217
+ return torch::exp(input);
218
+ })
219
+ .define_singleton_method(
220
+ "_log",
221
+ *[](torch::Tensor& input) {
222
+ return torch::log(input);
223
+ })
224
+ .define_singleton_method(
225
+ "_unsqueeze",
226
+ *[](torch::Tensor& input, int64_t dim) {
227
+ return torch::unsqueeze(input, dim);
228
+ })
229
+ .define_singleton_method(
230
+ "_dot",
231
+ *[](torch::Tensor& input, torch::Tensor& tensor) {
232
+ return torch::dot(input, tensor);
233
+ })
234
+ .define_singleton_method(
235
+ "_matmul",
236
+ *[](torch::Tensor& input, torch::Tensor& other) {
237
+ return torch::matmul(input, other);
238
+ })
239
+ .define_singleton_method(
240
+ "_add",
241
+ *[](torch::Tensor& input, torch::Tensor& other) {
242
+ return torch::add(input, other);
243
+ })
244
+ .define_singleton_method(
245
+ "_add_scalar",
246
+ *[](torch::Tensor& input, float other) {
247
+ return torch::add(input, other);
248
+ })
249
+ .define_singleton_method(
250
+ "_add_out",
251
+ *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
252
+ return torch::add_out(out, input, other);
253
+ })
254
+ .define_singleton_method(
255
+ "_sub",
256
+ *[](torch::Tensor& input, torch::Tensor& other) {
257
+ return torch::sub(input, other);
258
+ })
259
+ .define_singleton_method(
260
+ "_sub_scalar",
261
+ *[](torch::Tensor& input, float other) {
262
+ return torch::sub(input, other);
263
+ })
264
+ .define_singleton_method(
265
+ "_mul",
266
+ *[](torch::Tensor& input, torch::Tensor& other) {
267
+ return torch::mul(input, other);
268
+ })
269
+ .define_singleton_method(
270
+ "_mul_scalar",
271
+ *[](torch::Tensor& input, float other) {
272
+ return torch::mul(input, other);
273
+ })
274
+ .define_singleton_method(
275
+ "_div",
276
+ *[](torch::Tensor& input, torch::Tensor& other) {
277
+ return torch::div(input, other);
278
+ })
279
+ .define_singleton_method(
280
+ "_div_scalar",
281
+ *[](torch::Tensor& input, float other) {
282
+ return torch::div(input, other);
283
+ })
284
+ .define_singleton_method(
285
+ "_remainder",
286
+ *[](torch::Tensor& input, torch::Tensor& other) {
287
+ return torch::remainder(input, other);
288
+ })
289
+ .define_singleton_method(
290
+ "_remainder_scalar",
291
+ *[](torch::Tensor& input, float other) {
292
+ return torch::remainder(input, other);
293
+ })
294
+ .define_singleton_method(
295
+ "_pow",
296
+ *[](torch::Tensor& input, Scalar exponent) {
297
+ return torch::pow(input, exponent);
298
+ })
299
+ .define_singleton_method(
300
+ "_neg",
301
+ *[](torch::Tensor& input) {
302
+ return torch::neg(input);
303
+ })
304
+ .define_singleton_method(
305
+ "relu",
306
+ *[](torch::Tensor& input) {
307
+ return torch::relu(input);
308
+ })
309
+ .define_singleton_method(
310
+ "conv2d",
311
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
312
+ return torch::conv2d(input, weight, bias);
313
+ })
314
+ .define_singleton_method(
315
+ "linear",
316
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
317
+ return torch::linear(input, weight, bias);
318
+ })
319
+ .define_singleton_method(
320
+ "max_pool2d",
321
+ *[](torch::Tensor& input, IntArrayRef kernel_size) {
322
+ return torch::max_pool2d(input, kernel_size);
323
+ })
324
+ .define_singleton_method(
325
+ "mse_loss",
326
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
327
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
328
+ return torch::mse_loss(input, target, red);
329
+ })
330
+ .define_singleton_method(
331
+ "nll_loss",
332
+ *[](torch::Tensor& input, torch::Tensor& target) {
333
+ return torch::nll_loss(input, target);
334
+ })
335
+ .define_singleton_method(
336
+ "_tensor",
337
+ *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
338
+ Array a = Array(o);
339
+ std::vector<float> vec;
340
+ for (size_t i = 0; i < a.size(); i++) {
341
+ vec.push_back(from_ruby<float>(a[i]));
342
+ }
343
+ return torch::tensor(vec, options).reshape(size);
344
+ });
345
+
346
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
347
+ .define_method("cuda?", &torch::Tensor::is_cuda)
348
+ .define_method("distributed?", &torch::Tensor::is_distributed)
349
+ .define_method("complex?", &torch::Tensor::is_complex)
350
+ .define_method("floating_point?", &torch::Tensor::is_floating_point)
351
+ .define_method("signed?", &torch::Tensor::is_signed)
352
+ .define_method("sparse?", &torch::Tensor::is_sparse)
353
+ .define_method("quantized?", &torch::Tensor::is_quantized)
354
+ .define_method("dim", &torch::Tensor::dim)
355
+ .define_method("numel", &torch::Tensor::numel)
356
+ .define_method("element_size", &torch::Tensor::element_size)
357
+ .define_method("requires_grad", &torch::Tensor::requires_grad)
358
+ .define_method(
359
+ "zero!",
360
+ *[](torch::Tensor& self) {
361
+ return self.zero_();
362
+ })
363
+ .define_method(
364
+ "detach!",
365
+ *[](torch::Tensor& self) {
366
+ return self.detach_();
367
+ })
368
+ .define_method(
369
+ "_access",
370
+ *[](torch::Tensor& self, int64_t index) {
371
+ return self[index];
372
+ })
373
+ .define_method(
374
+ "_requires_grad!",
375
+ *[](torch::Tensor& self, bool requires_grad) {
376
+ return self.set_requires_grad(requires_grad);
377
+ })
378
+ .define_method(
379
+ "backward",
380
+ *[](torch::Tensor& self) {
381
+ return self.backward();
382
+ })
383
+ .define_method(
384
+ "grad",
385
+ *[](torch::Tensor& self) {
386
+ return self.grad();
387
+ })
388
+ .define_method(
389
+ "_dtype",
390
+ *[](torch::Tensor& self) {
391
+ return (int) at::typeMetaToScalarType(self.dtype());
392
+ })
393
+ .define_method(
394
+ "_layout",
395
+ *[](torch::Tensor& self) {
396
+ std::stringstream s;
397
+ s << self.layout();
398
+ return s.str();
399
+ })
400
+ .define_method(
401
+ "device",
402
+ *[](torch::Tensor& self) {
403
+ std::stringstream s;
404
+ s << self.device();
405
+ return s.str();
406
+ })
407
+ .define_method(
408
+ "_view",
409
+ *[](torch::Tensor& self, IntArrayRef size) {
410
+ return self.view(size);
411
+ })
412
+ .define_method(
413
+ "add!",
414
+ *[](torch::Tensor& self, torch::Tensor& other) {
415
+ self.add_(other);
416
+ })
417
+ .define_method(
418
+ "sub!",
419
+ *[](torch::Tensor& self, torch::Tensor& other) {
420
+ self.sub_(other);
421
+ })
422
+ .define_method(
423
+ "mul!",
424
+ *[](torch::Tensor& self, torch::Tensor& other) {
425
+ self.mul_(other);
426
+ })
427
+ .define_method(
428
+ "div!",
429
+ *[](torch::Tensor& self, torch::Tensor& other) {
430
+ self.div_(other);
431
+ })
432
+ .define_method(
433
+ "log_softmax",
434
+ *[](torch::Tensor& self, int64_t dim) {
435
+ return self.log_softmax(dim);
436
+ })
437
+ .define_method(
438
+ "_data",
439
+ *[](torch::Tensor& self) {
440
+ Array a;
441
+ auto dtype = self.dtype();
442
+
443
+ // TODO DRY if someone knows C++
444
+ // TODO kByte (uint8), kChar (int8), kBool (bool)
445
+ if (dtype == torch::kShort) {
446
+ short* data = self.data_ptr<short>();
447
+ for (int i = 0; i < self.numel(); i++) {
448
+ a.push(data[i]);
449
+ }
450
+ } else if (dtype == torch::kInt) {
451
+ int* data = self.data_ptr<int>();
452
+ for (int i = 0; i < self.numel(); i++) {
453
+ a.push(data[i]);
454
+ }
455
+ } else if (dtype == torch::kLong) {
456
+ long long* data = self.data_ptr<long long>();
457
+ for (int i = 0; i < self.numel(); i++) {
458
+ a.push(data[i]);
459
+ }
460
+ } else if (dtype == torch::kFloat) {
461
+ float* data = self.data_ptr<float>();
462
+ for (int i = 0; i < self.numel(); i++) {
463
+ a.push(data[i]);
464
+ }
465
+ } else if (dtype == torch::kDouble) {
466
+ double* data = self.data_ptr<double>();
467
+ for (int i = 0; i < self.numel(); i++) {
468
+ a.push(data[i]);
469
+ }
470
+ } else {
471
+ throw "Unsupported type";
472
+ }
473
+ return a;
474
+ })
475
+ .define_method(
476
+ "_size",
477
+ *[](torch::Tensor& self, int i) {
478
+ return self.size(i);
479
+ })
480
+ .define_singleton_method(
481
+ "_make_subclass",
482
+ *[](torch::Tensor& rd, bool requires_grad) {
483
+ auto data = torch::autograd::as_variable_ref(rd).detach();
484
+ data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
485
+ auto var = data.set_requires_grad(requires_grad);
486
+ return torch::autograd::Variable(std::move(var));
487
+ });
488
+
489
+ Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
490
+ .define_constructor(Constructor<torch::TensorOptions>())
491
+ .define_method(
492
+ "dtype",
493
+ *[](torch::TensorOptions& self, int dtype) {
494
+ return self.dtype((torch::ScalarType) dtype);
495
+ })
496
+ .define_method(
497
+ "layout",
498
+ *[](torch::TensorOptions& self, std::string layout) {
499
+ torch::Layout l;
500
+ if (layout == "strided") {
501
+ l = torch::kStrided;
502
+ } else {
503
+ throw "Unsupported layout";
504
+ }
505
+ return self.layout(l);
506
+ })
507
+ .define_method(
508
+ "device",
509
+ *[](torch::TensorOptions& self, std::string device) {
510
+ torch::DeviceType d;
511
+ if (device == "cpu") {
512
+ d = torch::kCPU;
513
+ } else if (device == "cuda") {
514
+ d = torch::kCUDA;
515
+ } else {
516
+ throw "Unsupported device";
517
+ }
518
+ return self.device(d);
519
+ })
520
+ .define_method(
521
+ "requires_grad",
522
+ *[](torch::TensorOptions& self, bool requires_grad) {
523
+ return self.requires_grad(requires_grad);
524
+ });
525
+
526
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
527
+
528
+ Module rb_mInit = define_module_under(rb_mNN, "Init")
529
+ .define_singleton_method(
530
+ "kaiming_uniform_",
531
+ *[](torch::Tensor& input, double a) {
532
+ return torch::nn::init::kaiming_uniform_(input, a);
533
+ })
534
+ .define_singleton_method(
535
+ "uniform_",
536
+ *[](torch::Tensor& input, double to, double from) {
537
+ return torch::nn::init::uniform_(input, to, from);
538
+ });
539
+
540
+ Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
541
+ .define_method(
542
+ "grad",
543
+ *[](torch::autograd::Variable& self) {
544
+ return self.grad();
545
+ });
546
+ }