torch-rb 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +55 -3
- data/ext/torch/ext.cpp +68 -7
- data/lib/torch.rb +27 -3
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -5
- data/lib/torch/nn/conv2d.rb +4 -5
- data/lib/torch/nn/functional.rb +16 -2
- data/lib/torch/nn/module.rb +0 -1
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/optim/optimizer.rb +6 -0
- data/lib/torch/optim/sgd.rb +28 -0
- data/lib/torch/tensor.rb +30 -11
- data/lib/torch/utils/data/data_loader.rb +9 -0
- data/lib/torch/utils/data/tensor_dataset.rb +5 -1
- data/lib/torch/version.rb +1 -1
- metadata +3 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 30089078de4039df111087e5c27e0cb10d6f36398c0e8d5cc774e9b642a8e133
|
4
|
+
data.tar.gz: 89eb9e183b395dd67cd9cf228749cf26402993bb561de973f1ba7438bc372b04
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 027a069b00ac1329c007ddaf471a21b57a82a823ad974a937f832d17720b8e26474c64c79e9a29ec71bac433abb3d74d6a7cf407f0a983bb3c0cafb5b5c7532f
|
7
|
+
data.tar.gz: 6d7ef10b53db0df39eda13d07aa9b52b4afac0965674919b5cc517e7b53f59a9010cb647e50d62bc06154f7d8f3ef632d5897e4f7774372d7ab1b44b2cb6ca82
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.1.2 (2019-11-27)
|
2
|
+
|
3
|
+
- Added SGD optimizer
|
4
|
+
- Added support for gradient to `backward` method
|
5
|
+
- Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
|
6
|
+
- Improved indexing
|
7
|
+
- Fixed `zero_grad`
|
8
|
+
- Fixed error with infinite values
|
9
|
+
|
1
10
|
## 0.1.1 (2019-11-26)
|
2
11
|
|
3
12
|
- Added support for `uint8` and `int8` types
|
data/README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
This gem is currently experimental. There may be breaking changes between each release.
|
5
|
+
This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
|
6
6
|
|
7
7
|
[![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)
|
8
8
|
|
@@ -223,7 +223,7 @@ class Net < Torch::NN::Module
|
|
223
223
|
end
|
224
224
|
```
|
225
225
|
|
226
|
-
|
226
|
+
Create an instance of it
|
227
227
|
|
228
228
|
```ruby
|
229
229
|
net = Net.new
|
@@ -231,6 +231,58 @@ input = Torch.randn(1, 1, 32, 32)
|
|
231
231
|
net.call(input)
|
232
232
|
```
|
233
233
|
|
234
|
+
Get trainable parameters
|
235
|
+
|
236
|
+
```ruby
|
237
|
+
net.parameters
|
238
|
+
```
|
239
|
+
|
240
|
+
Zero the gradient buffers and backprop with random gradients
|
241
|
+
|
242
|
+
```ruby
|
243
|
+
net.zero_grad
|
244
|
+
out.backward(Torch.randn(1, 10))
|
245
|
+
```
|
246
|
+
|
247
|
+
Define a loss function
|
248
|
+
|
249
|
+
```ruby
|
250
|
+
output = net.call(input)
|
251
|
+
target = Torch.randn(10)
|
252
|
+
target = target.view(1, -1)
|
253
|
+
criterion = Torch::NN::MSELoss.new
|
254
|
+
loss = criterion.call(output, target)
|
255
|
+
```
|
256
|
+
|
257
|
+
Backprop
|
258
|
+
|
259
|
+
```ruby
|
260
|
+
net.zero_grad
|
261
|
+
p net.conv1.bias.grad
|
262
|
+
loss.backward
|
263
|
+
p net.conv1.bias.grad
|
264
|
+
```
|
265
|
+
|
266
|
+
Update the weights
|
267
|
+
|
268
|
+
```ruby
|
269
|
+
learning_rate = 0.01
|
270
|
+
net.parameters.each do |f|
|
271
|
+
f.data.sub!(f.grad.data * learning_rate)
|
272
|
+
end
|
273
|
+
```
|
274
|
+
|
275
|
+
Use an optimizer
|
276
|
+
|
277
|
+
```ruby
|
278
|
+
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
|
279
|
+
optimizer.zero_grad
|
280
|
+
output = net.call(input)
|
281
|
+
loss = criterion.call(output, target)
|
282
|
+
loss.backward
|
283
|
+
optimizer.step
|
284
|
+
```
|
285
|
+
|
234
286
|
### Tensor Creation
|
235
287
|
|
236
288
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -351,7 +403,7 @@ To get started with development:
|
|
351
403
|
|
352
404
|
```sh
|
353
405
|
git clone https://github.com/ankane/torch-rb.git
|
354
|
-
cd torch
|
406
|
+
cd torch-rb
|
355
407
|
bundle install
|
356
408
|
bundle exec rake compile
|
357
409
|
bundle exec rake test
|
data/ext/torch/ext.cpp
CHANGED
@@ -196,6 +196,16 @@ void Init_ext()
|
|
196
196
|
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
197
197
|
return torch::sum(input, dim, keepdim);
|
198
198
|
})
|
199
|
+
.define_singleton_method(
|
200
|
+
"_argmax",
|
201
|
+
*[](torch::Tensor& input) {
|
202
|
+
return torch::argmax(input);
|
203
|
+
})
|
204
|
+
.define_singleton_method(
|
205
|
+
"_argmax_dim",
|
206
|
+
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
207
|
+
return torch::argmax(input, dim, keepdim);
|
208
|
+
})
|
199
209
|
.define_singleton_method(
|
200
210
|
"_norm",
|
201
211
|
*[](torch::Tensor& input) {
|
@@ -236,6 +246,11 @@ void Init_ext()
|
|
236
246
|
*[](torch::Tensor& input, torch::Tensor& other) {
|
237
247
|
return torch::matmul(input, other);
|
238
248
|
})
|
249
|
+
.define_singleton_method(
|
250
|
+
"_eq",
|
251
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
252
|
+
return torch::eq(input, other);
|
253
|
+
})
|
239
254
|
.define_singleton_method(
|
240
255
|
"_add",
|
241
256
|
*[](torch::Tensor& input, torch::Tensor& other) {
|
@@ -301,15 +316,30 @@ void Init_ext()
|
|
301
316
|
*[](torch::Tensor& input) {
|
302
317
|
return torch::neg(input);
|
303
318
|
})
|
319
|
+
.define_singleton_method(
|
320
|
+
"_reshape",
|
321
|
+
*[](torch::Tensor& input, IntArrayRef shape) {
|
322
|
+
return torch::reshape(input, shape);
|
323
|
+
})
|
304
324
|
.define_singleton_method(
|
305
325
|
"relu",
|
306
326
|
*[](torch::Tensor& input) {
|
307
327
|
return torch::relu(input);
|
308
328
|
})
|
329
|
+
.define_singleton_method(
|
330
|
+
"prelu",
|
331
|
+
*[](torch::Tensor& input, torch::Tensor& weight) {
|
332
|
+
return torch::prelu(input, weight);
|
333
|
+
})
|
334
|
+
.define_singleton_method(
|
335
|
+
"leaky_relu",
|
336
|
+
*[](torch::Tensor& input, Scalar negative_slope = 0.01) {
|
337
|
+
return torch::leaky_relu(input, negative_slope);
|
338
|
+
})
|
309
339
|
.define_singleton_method(
|
310
340
|
"conv2d",
|
311
|
-
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
|
312
|
-
return torch::conv2d(input, weight, bias);
|
341
|
+
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
|
342
|
+
return torch::conv2d(input, weight, bias, stride, padding);
|
313
343
|
})
|
314
344
|
.define_singleton_method(
|
315
345
|
"linear",
|
@@ -321,6 +351,11 @@ void Init_ext()
|
|
321
351
|
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
322
352
|
return torch::max_pool2d(input, kernel_size);
|
323
353
|
})
|
354
|
+
.define_singleton_method(
|
355
|
+
"avg_pool2d",
|
356
|
+
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
357
|
+
return torch::avg_pool2d(input, kernel_size);
|
358
|
+
})
|
324
359
|
.define_singleton_method(
|
325
360
|
"mse_loss",
|
326
361
|
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
|
@@ -366,9 +401,14 @@ void Init_ext()
|
|
366
401
|
return self.detach_();
|
367
402
|
})
|
368
403
|
.define_method(
|
369
|
-
"
|
370
|
-
*[](torch::Tensor& self, int64_t index) {
|
371
|
-
return self
|
404
|
+
"_select",
|
405
|
+
*[](torch::Tensor& self, int64_t dim, int64_t index) {
|
406
|
+
return self.select(dim, index);
|
407
|
+
})
|
408
|
+
.define_method(
|
409
|
+
"_slice",
|
410
|
+
*[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
411
|
+
return self.slice(dim, start, end, step);
|
372
412
|
})
|
373
413
|
.define_method(
|
374
414
|
"_requires_grad!",
|
@@ -376,10 +416,15 @@ void Init_ext()
|
|
376
416
|
return self.set_requires_grad(requires_grad);
|
377
417
|
})
|
378
418
|
.define_method(
|
379
|
-
"
|
419
|
+
"_backward",
|
380
420
|
*[](torch::Tensor& self) {
|
381
421
|
return self.backward();
|
382
422
|
})
|
423
|
+
.define_method(
|
424
|
+
"_backward_gradient",
|
425
|
+
*[](torch::Tensor& self, const torch::Tensor& gradient) {
|
426
|
+
return self.backward(gradient);
|
427
|
+
})
|
383
428
|
.define_method(
|
384
429
|
"grad",
|
385
430
|
*[](torch::Tensor& self) {
|
@@ -390,6 +435,11 @@ void Init_ext()
|
|
390
435
|
*[](torch::Tensor& self) {
|
391
436
|
return (int) at::typeMetaToScalarType(self.dtype());
|
392
437
|
})
|
438
|
+
.define_method(
|
439
|
+
"_type",
|
440
|
+
*[](torch::Tensor& self, int dtype) {
|
441
|
+
return self.toType((torch::ScalarType) dtype);
|
442
|
+
})
|
393
443
|
.define_method(
|
394
444
|
"_layout",
|
395
445
|
*[](torch::Tensor& self) {
|
@@ -434,6 +484,11 @@ void Init_ext()
|
|
434
484
|
*[](torch::Tensor& self, int64_t dim) {
|
435
485
|
return self.log_softmax(dim);
|
436
486
|
})
|
487
|
+
.define_method(
|
488
|
+
"data",
|
489
|
+
*[](torch::Tensor& self) {
|
490
|
+
return self.data();
|
491
|
+
})
|
437
492
|
.define_method(
|
438
493
|
"_data",
|
439
494
|
*[](torch::Tensor& self) {
|
@@ -553,9 +608,15 @@ void Init_ext()
|
|
553
608
|
});
|
554
609
|
|
555
610
|
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
|
611
|
+
// TODO return grad or nil to remove need for 2nd function
|
556
612
|
.define_method(
|
557
|
-
"
|
613
|
+
"_grad",
|
558
614
|
*[](torch::autograd::Variable& self) {
|
559
615
|
return self.grad();
|
616
|
+
})
|
617
|
+
.define_method(
|
618
|
+
"_grad_defined",
|
619
|
+
*[](torch::autograd::Variable& self) {
|
620
|
+
return self.grad().defined();
|
560
621
|
});
|
561
622
|
}
|
data/lib/torch.rb
CHANGED
@@ -6,6 +6,10 @@ require "torch/inspector"
|
|
6
6
|
require "torch/tensor"
|
7
7
|
require "torch/version"
|
8
8
|
|
9
|
+
# optim
|
10
|
+
require "torch/optim/optimizer"
|
11
|
+
require "torch/optim/sgd"
|
12
|
+
|
9
13
|
# nn
|
10
14
|
require "torch/nn/module"
|
11
15
|
require "torch/nn/init"
|
@@ -55,9 +59,13 @@ module Torch
|
|
55
59
|
|
56
60
|
class << self
|
57
61
|
# Torch.float, Torch.long, etc
|
58
|
-
DTYPE_TO_ENUM.each_key do |
|
59
|
-
define_method(
|
60
|
-
|
62
|
+
DTYPE_TO_ENUM.each_key do |dtype|
|
63
|
+
define_method(dtype) do
|
64
|
+
dtype
|
65
|
+
end
|
66
|
+
|
67
|
+
Tensor.define_method(dtype) do
|
68
|
+
type(dtype)
|
61
69
|
end
|
62
70
|
end
|
63
71
|
|
@@ -240,6 +248,18 @@ module Torch
|
|
240
248
|
end
|
241
249
|
end
|
242
250
|
|
251
|
+
def argmax(input, dim = nil, keepdim: false)
|
252
|
+
if dim
|
253
|
+
_argmax_dim(input, dim, keepdim)
|
254
|
+
else
|
255
|
+
_argmax(input)
|
256
|
+
end
|
257
|
+
end
|
258
|
+
|
259
|
+
def eq(input, other)
|
260
|
+
_eq(input, other)
|
261
|
+
end
|
262
|
+
|
243
263
|
def norm(input)
|
244
264
|
_norm(input)
|
245
265
|
end
|
@@ -276,6 +296,10 @@ module Torch
|
|
276
296
|
_matmul(input, other)
|
277
297
|
end
|
278
298
|
|
299
|
+
def reshape(input, shape)
|
300
|
+
_reshape(input, shape)
|
301
|
+
end
|
302
|
+
|
279
303
|
private
|
280
304
|
|
281
305
|
def execute_op(op, input, other, out: nil)
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module Inspector
|
3
|
+
# TODO make more performance, especially when summarizing
|
3
4
|
def inspect
|
4
5
|
data =
|
5
6
|
if numel == 0
|
@@ -20,7 +21,7 @@ module Torch
|
|
20
21
|
if floating_point?
|
21
22
|
sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
|
22
23
|
|
23
|
-
all_int = values.all? { |v| v == v.to_i }
|
24
|
+
all_int = values.all? { |v| v.finite? && v == v.to_i }
|
24
25
|
decimal = all_int ? 1 : 4
|
25
26
|
|
26
27
|
total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
|
@@ -35,7 +36,9 @@ module Torch
|
|
35
36
|
fmt = "%#{total}d"
|
36
37
|
end
|
37
38
|
|
38
|
-
|
39
|
+
summarize = numel > 1000
|
40
|
+
|
41
|
+
inspect_level(to_a, fmt, dim - 1, 0, summarize)
|
39
42
|
end
|
40
43
|
|
41
44
|
attributes = []
|
@@ -51,11 +54,30 @@ module Torch
|
|
51
54
|
|
52
55
|
private
|
53
56
|
|
54
|
-
|
57
|
+
# TODO DRY code
|
58
|
+
def inspect_level(arr, fmt, total, level, summarize)
|
55
59
|
if level == total
|
56
|
-
|
60
|
+
cols =
|
61
|
+
if summarize && arr.size > 7
|
62
|
+
arr[0..2].map { |v| fmt % v } +
|
63
|
+
["..."] +
|
64
|
+
arr[-3..-1].map { |v| fmt % v }
|
65
|
+
else
|
66
|
+
arr.map { |v| fmt % v }
|
67
|
+
end
|
68
|
+
|
69
|
+
"[#{cols.join(", ")}]"
|
57
70
|
else
|
58
|
-
|
71
|
+
rows =
|
72
|
+
if summarize && arr.size > 7
|
73
|
+
arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
|
74
|
+
["..."] +
|
75
|
+
arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
|
76
|
+
else
|
77
|
+
arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
|
78
|
+
end
|
79
|
+
|
80
|
+
"[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
|
59
81
|
end
|
60
82
|
end
|
61
83
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -3,13 +3,12 @@ module Torch
|
|
3
3
|
class Conv2d < Module
|
4
4
|
attr_reader :bias, :weight
|
5
5
|
|
6
|
-
def initialize(in_channels, out_channels, kernel_size
|
6
|
+
def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1)
|
7
7
|
@in_channels = in_channels
|
8
8
|
@out_channels = out_channels
|
9
9
|
@kernel_size = pair(kernel_size)
|
10
|
-
@stride = pair(
|
11
|
-
|
12
|
-
# @padding = pair(padding)
|
10
|
+
@stride = pair(stride)
|
11
|
+
@padding = pair(padding)
|
13
12
|
# @dilation = pair(dilation)
|
14
13
|
|
15
14
|
# TODO divide by groups
|
@@ -29,7 +28,7 @@ module Torch
|
|
29
28
|
end
|
30
29
|
|
31
30
|
def call(input)
|
32
|
-
F.conv2d(input, @weight, @bias
|
31
|
+
F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups)
|
33
32
|
end
|
34
33
|
|
35
34
|
def inspect
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -6,8 +6,17 @@ module Torch
|
|
6
6
|
Torch.relu(input)
|
7
7
|
end
|
8
8
|
|
9
|
-
def conv2d(input, weight, bias)
|
10
|
-
|
9
|
+
def conv2d(input, weight, bias, stride: 1, padding: 0)
|
10
|
+
# TODO pair stride and padding when needed
|
11
|
+
Torch.conv2d(input, weight, bias, stride, padding)
|
12
|
+
end
|
13
|
+
|
14
|
+
def prelu(input, weight)
|
15
|
+
Torch.prelu(input, weight)
|
16
|
+
end
|
17
|
+
|
18
|
+
def leaky_relu(input, negative_slope = 0.01)
|
19
|
+
Torch.leaky_relu(input, negative_slope)
|
11
20
|
end
|
12
21
|
|
13
22
|
def max_pool2d(input, kernel_size)
|
@@ -15,6 +24,11 @@ module Torch
|
|
15
24
|
Torch.max_pool2d(input, kernel_size)
|
16
25
|
end
|
17
26
|
|
27
|
+
def avg_pool2d(input, kernel_size)
|
28
|
+
kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
|
29
|
+
Torch.avg_pool2d(input, kernel_size)
|
30
|
+
end
|
31
|
+
|
18
32
|
def linear(input, weight, bias)
|
19
33
|
Torch.linear(input, weight, bias)
|
20
34
|
end
|
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/optim/optimizer.rb
ADDED
data/lib/torch/optim/sgd.rb
ADDED
@@ -0,0 +1,28 @@
|
|
1
|
+
module Torch
|
2
|
+
module Optim
|
3
|
+
class SGD < Optimizer
|
4
|
+
def initialize(params, lr:)
|
5
|
+
@params = params
|
6
|
+
@lr = lr
|
7
|
+
end
|
8
|
+
|
9
|
+
def zero_grad
|
10
|
+
@params.each do |param|
|
11
|
+
if param.grad
|
12
|
+
param.grad.detach!
|
13
|
+
param.grad.zero!
|
14
|
+
end
|
15
|
+
end
|
16
|
+
end
|
17
|
+
|
18
|
+
def step
|
19
|
+
@params.each do |param|
|
20
|
+
next unless param.grad
|
21
|
+
d_p = param.grad.data
|
22
|
+
# same as param.data.add!(-@lr, d_p)
|
23
|
+
param.data.sub!(d_p * @lr)
|
24
|
+
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -28,7 +28,7 @@ module Torch
|
|
28
28
|
end
|
29
29
|
|
30
30
|
def to_a
|
31
|
-
|
31
|
+
reshape_arr(_data, shape)
|
32
32
|
end
|
33
33
|
|
34
34
|
def size(dim = nil)
|
@@ -54,8 +54,12 @@ module Torch
|
|
54
54
|
_data.first
|
55
55
|
end
|
56
56
|
|
57
|
-
def
|
58
|
-
|
57
|
+
def backward(gradient = nil)
|
58
|
+
if gradient
|
59
|
+
_backward_gradient(gradient)
|
60
|
+
else
|
61
|
+
_backward
|
62
|
+
end
|
59
63
|
end
|
60
64
|
|
61
65
|
# TODO read directly from memory
|
@@ -74,8 +78,14 @@ module Torch
|
|
74
78
|
_requires_grad!(requires_grad)
|
75
79
|
end
|
76
80
|
|
81
|
+
def type(dtype)
|
82
|
+
enum = DTYPE_TO_ENUM[dtype]
|
83
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
84
|
+
_type(enum)
|
85
|
+
end
|
86
|
+
|
77
87
|
# operations
|
78
|
-
%w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
|
88
|
+
%w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op|
|
79
89
|
define_method(op) do |*args, **options, &block|
|
80
90
|
if options.any?
|
81
91
|
Torch.send(op, self, *args, **options, &block)
|
@@ -117,18 +127,27 @@ module Torch
|
|
117
127
|
item <=> other
|
118
128
|
end
|
119
129
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
130
|
+
def [](*indexes)
|
131
|
+
result = self
|
132
|
+
dim = 0
|
133
|
+
indexes.each_with_index do |index|
|
134
|
+
if index.is_a?(Numeric)
|
135
|
+
result = result._select(dim, index)
|
136
|
+
elsif index.is_a?(Range)
|
137
|
+
finish = index.end
|
138
|
+
finish += 1 unless index.exclude_end?
|
139
|
+
result = result._slice(dim, index.begin, finish, 1)
|
140
|
+
dim += 1
|
141
|
+
else
|
142
|
+
raise Error, "Unsupported index type"
|
143
|
+
end
|
125
144
|
end
|
126
|
-
|
145
|
+
result
|
127
146
|
end
|
128
147
|
|
129
148
|
private
|
130
149
|
|
131
|
-
def
|
150
|
+
def reshape_arr(arr, dims)
|
132
151
|
if dims.empty?
|
133
152
|
arr
|
134
153
|
else
|
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -6,6 +6,15 @@ module Torch
|
|
6
6
|
@dataset = dataset
|
7
7
|
@batch_size = batch_size
|
8
8
|
end
|
9
|
+
|
10
|
+
def each
|
11
|
+
size = @dataset.size
|
12
|
+
start_index = 0
|
13
|
+
while start_index < size
|
14
|
+
yield @dataset[start_index...(start_index + @batch_size)]
|
15
|
+
start_index += @batch_size
|
16
|
+
end
|
17
|
+
end
|
9
18
|
end
|
10
19
|
end
|
11
20
|
end
|
data/lib/torch/utils/data/tensor_dataset.rb
CHANGED
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
@@ -119,6 +119,8 @@ files:
|
|
119
119
|
- lib/torch/nn/parameter.rb
|
120
120
|
- lib/torch/nn/relu.rb
|
121
121
|
- lib/torch/nn/sequential.rb
|
122
|
+
- lib/torch/optim/optimizer.rb
|
123
|
+
- lib/torch/optim/sgd.rb
|
122
124
|
- lib/torch/tensor.rb
|
123
125
|
- lib/torch/utils/data/data_loader.rb
|
124
126
|
- lib/torch/utils/data/tensor_dataset.rb
|