torch-rb 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +55 -3
- data/ext/torch/ext.cpp +68 -7
- data/lib/torch.rb +27 -3
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -5
- data/lib/torch/nn/conv2d.rb +4 -5
- data/lib/torch/nn/functional.rb +16 -2
- data/lib/torch/nn/module.rb +0 -1
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/optim/optimizer.rb +6 -0
- data/lib/torch/optim/sgd.rb +28 -0
- data/lib/torch/tensor.rb +30 -11
- data/lib/torch/utils/data/data_loader.rb +9 -0
- data/lib/torch/utils/data/tensor_dataset.rb +5 -1
- data/lib/torch/version.rb +1 -1
- metadata +3 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 30089078de4039df111087e5c27e0cb10d6f36398c0e8d5cc774e9b642a8e133
|
4
|
+
data.tar.gz: 89eb9e183b395dd67cd9cf228749cf26402993bb561de973f1ba7438bc372b04
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 027a069b00ac1329c007ddaf471a21b57a82a823ad974a937f832d17720b8e26474c64c79e9a29ec71bac433abb3d74d6a7cf407f0a983bb3c0cafb5b5c7532f
|
7
|
+
data.tar.gz: 6d7ef10b53db0df39eda13d07aa9b52b4afac0965674919b5cc517e7b53f59a9010cb647e50d62bc06154f7d8f3ef632d5897e4f7774372d7ab1b44b2cb6ca82
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.1.2 (2019-11-27)
|
2
|
+
|
3
|
+
- Added SGD optimizer
|
4
|
+
- Added support for gradient to `backward` method
|
5
|
+
- Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
|
6
|
+
- Improved indexing
|
7
|
+
- Fixed `zero_grad`
|
8
|
+
- Fixed error with infinite values
|
9
|
+
|
1
10
|
## 0.1.1 (2019-11-26)
|
2
11
|
|
3
12
|
- Added support for `uint8` and `int8` types
|
data/README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
This gem is currently experimental. There may be breaking changes between each release.
|
5
|
+
This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
|
6
6
|
|
7
7
|
[](https://travis-ci.org/ankane/torch-rb)
|
8
8
|
|
@@ -223,7 +223,7 @@ class Net < Torch::NN::Module
|
|
223
223
|
end
|
224
224
|
```
|
225
225
|
|
226
|
-
|
226
|
+
Create an instance of it
|
227
227
|
|
228
228
|
```ruby
|
229
229
|
net = Net.new
|
@@ -231,6 +231,58 @@ input = Torch.randn(1, 1, 32, 32)
|
|
231
231
|
net.call(input)
|
232
232
|
```
|
233
233
|
|
234
|
+
Get trainable parameters
|
235
|
+
|
236
|
+
```ruby
|
237
|
+
net.parameters
|
238
|
+
```
|
239
|
+
|
240
|
+
Zero the gradient buffers and backprop with random gradients
|
241
|
+
|
242
|
+
```ruby
|
243
|
+
net.zero_grad
|
244
|
+
out.backward(Torch.randn(1, 10))
|
245
|
+
```
|
246
|
+
|
247
|
+
Define a loss function
|
248
|
+
|
249
|
+
```ruby
|
250
|
+
output = net.call(input)
|
251
|
+
target = Torch.randn(10)
|
252
|
+
target = target.view(1, -1)
|
253
|
+
criterion = Torch::NN::MSELoss.new
|
254
|
+
loss = criterion.call(output, target)
|
255
|
+
```
|
256
|
+
|
257
|
+
Backprop
|
258
|
+
|
259
|
+
```ruby
|
260
|
+
net.zero_grad
|
261
|
+
p net.conv1.bias.grad
|
262
|
+
loss.backward
|
263
|
+
p net.conv1.bias.grad
|
264
|
+
```
|
265
|
+
|
266
|
+
Update the weights
|
267
|
+
|
268
|
+
```ruby
|
269
|
+
learning_rate = 0.01
|
270
|
+
net.parameters.each do |f|
|
271
|
+
f.data.sub!(f.grad.data * learning_rate)
|
272
|
+
end
|
273
|
+
```
|
274
|
+
|
275
|
+
Use an optimizer
|
276
|
+
|
277
|
+
```ruby
|
278
|
+
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
|
279
|
+
optimizer.zero_grad
|
280
|
+
output = net.call(input)
|
281
|
+
loss = criterion.call(output, target)
|
282
|
+
loss.backward
|
283
|
+
optimizer.step
|
284
|
+
```
|
285
|
+
|
234
286
|
### Tensor Creation
|
235
287
|
|
236
288
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -351,7 +403,7 @@ To get started with development:
|
|
351
403
|
|
352
404
|
```sh
|
353
405
|
git clone https://github.com/ankane/torch-rb.git
|
354
|
-
cd torch
|
406
|
+
cd torch-rb
|
355
407
|
bundle install
|
356
408
|
bundle exec rake compile
|
357
409
|
bundle exec rake test
|
data/ext/torch/ext.cpp
CHANGED
@@ -196,6 +196,16 @@ void Init_ext()
|
|
196
196
|
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
197
197
|
return torch::sum(input, dim, keepdim);
|
198
198
|
})
|
199
|
+
.define_singleton_method(
|
200
|
+
"_argmax",
|
201
|
+
*[](torch::Tensor& input) {
|
202
|
+
return torch::argmax(input);
|
203
|
+
})
|
204
|
+
.define_singleton_method(
|
205
|
+
"_argmax_dim",
|
206
|
+
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
207
|
+
return torch::argmax(input, dim, keepdim);
|
208
|
+
})
|
199
209
|
.define_singleton_method(
|
200
210
|
"_norm",
|
201
211
|
*[](torch::Tensor& input) {
|
@@ -236,6 +246,11 @@ void Init_ext()
|
|
236
246
|
*[](torch::Tensor& input, torch::Tensor& other) {
|
237
247
|
return torch::matmul(input, other);
|
238
248
|
})
|
249
|
+
.define_singleton_method(
|
250
|
+
"_eq",
|
251
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
252
|
+
return torch::eq(input, other);
|
253
|
+
})
|
239
254
|
.define_singleton_method(
|
240
255
|
"_add",
|
241
256
|
*[](torch::Tensor& input, torch::Tensor& other) {
|
@@ -301,15 +316,30 @@ void Init_ext()
|
|
301
316
|
*[](torch::Tensor& input) {
|
302
317
|
return torch::neg(input);
|
303
318
|
})
|
319
|
+
.define_singleton_method(
|
320
|
+
"_reshape",
|
321
|
+
*[](torch::Tensor& input, IntArrayRef shape) {
|
322
|
+
return torch::reshape(input, shape);
|
323
|
+
})
|
304
324
|
.define_singleton_method(
|
305
325
|
"relu",
|
306
326
|
*[](torch::Tensor& input) {
|
307
327
|
return torch::relu(input);
|
308
328
|
})
|
329
|
+
.define_singleton_method(
|
330
|
+
"prelu",
|
331
|
+
*[](torch::Tensor& input, torch::Tensor& weight) {
|
332
|
+
return torch::prelu(input, weight);
|
333
|
+
})
|
334
|
+
.define_singleton_method(
|
335
|
+
"leaky_relu",
|
336
|
+
*[](torch::Tensor& input, Scalar negative_slope = 0.01) {
|
337
|
+
return torch::leaky_relu(input, negative_slope);
|
338
|
+
})
|
309
339
|
.define_singleton_method(
|
310
340
|
"conv2d",
|
311
|
-
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
|
312
|
-
return torch::conv2d(input, weight, bias);
|
341
|
+
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
|
342
|
+
return torch::conv2d(input, weight, bias, stride, padding);
|
313
343
|
})
|
314
344
|
.define_singleton_method(
|
315
345
|
"linear",
|
@@ -321,6 +351,11 @@ void Init_ext()
|
|
321
351
|
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
322
352
|
return torch::max_pool2d(input, kernel_size);
|
323
353
|
})
|
354
|
+
.define_singleton_method(
|
355
|
+
"avg_pool2d",
|
356
|
+
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
357
|
+
return torch::avg_pool2d(input, kernel_size);
|
358
|
+
})
|
324
359
|
.define_singleton_method(
|
325
360
|
"mse_loss",
|
326
361
|
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
|
@@ -366,9 +401,14 @@ void Init_ext()
|
|
366
401
|
return self.detach_();
|
367
402
|
})
|
368
403
|
.define_method(
|
369
|
-
"
|
370
|
-
*[](torch::Tensor& self, int64_t index) {
|
371
|
-
return self
|
404
|
+
"_select",
|
405
|
+
*[](torch::Tensor& self, int64_t dim, int64_t index) {
|
406
|
+
return self.select(dim, index);
|
407
|
+
})
|
408
|
+
.define_method(
|
409
|
+
"_slice",
|
410
|
+
*[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
411
|
+
return self.slice(dim, start, end, step);
|
372
412
|
})
|
373
413
|
.define_method(
|
374
414
|
"_requires_grad!",
|
@@ -376,10 +416,15 @@ void Init_ext()
|
|
376
416
|
return self.set_requires_grad(requires_grad);
|
377
417
|
})
|
378
418
|
.define_method(
|
379
|
-
"
|
419
|
+
"_backward",
|
380
420
|
*[](torch::Tensor& self) {
|
381
421
|
return self.backward();
|
382
422
|
})
|
423
|
+
.define_method(
|
424
|
+
"_backward_gradient",
|
425
|
+
*[](torch::Tensor& self, const torch::Tensor& gradient) {
|
426
|
+
return self.backward(gradient);
|
427
|
+
})
|
383
428
|
.define_method(
|
384
429
|
"grad",
|
385
430
|
*[](torch::Tensor& self) {
|
@@ -390,6 +435,11 @@ void Init_ext()
|
|
390
435
|
*[](torch::Tensor& self) {
|
391
436
|
return (int) at::typeMetaToScalarType(self.dtype());
|
392
437
|
})
|
438
|
+
.define_method(
|
439
|
+
"_type",
|
440
|
+
*[](torch::Tensor& self, int dtype) {
|
441
|
+
return self.toType((torch::ScalarType) dtype);
|
442
|
+
})
|
393
443
|
.define_method(
|
394
444
|
"_layout",
|
395
445
|
*[](torch::Tensor& self) {
|
@@ -434,6 +484,11 @@ void Init_ext()
|
|
434
484
|
*[](torch::Tensor& self, int64_t dim) {
|
435
485
|
return self.log_softmax(dim);
|
436
486
|
})
|
487
|
+
.define_method(
|
488
|
+
"data",
|
489
|
+
*[](torch::Tensor& self) {
|
490
|
+
return self.data();
|
491
|
+
})
|
437
492
|
.define_method(
|
438
493
|
"_data",
|
439
494
|
*[](torch::Tensor& self) {
|
@@ -553,9 +608,15 @@ void Init_ext()
|
|
553
608
|
});
|
554
609
|
|
555
610
|
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
|
611
|
+
// TODO return grad or nil to remove need for 2nd function
|
556
612
|
.define_method(
|
557
|
-
"
|
613
|
+
"_grad",
|
558
614
|
*[](torch::autograd::Variable& self) {
|
559
615
|
return self.grad();
|
616
|
+
})
|
617
|
+
.define_method(
|
618
|
+
"_grad_defined",
|
619
|
+
*[](torch::autograd::Variable& self) {
|
620
|
+
return self.grad().defined();
|
560
621
|
});
|
561
622
|
}
|
data/lib/torch.rb
CHANGED
@@ -6,6 +6,10 @@ require "torch/inspector"
|
|
6
6
|
require "torch/tensor"
|
7
7
|
require "torch/version"
|
8
8
|
|
9
|
+
# optim
|
10
|
+
require "torch/optim/optimizer"
|
11
|
+
require "torch/optim/sgd"
|
12
|
+
|
9
13
|
# nn
|
10
14
|
require "torch/nn/module"
|
11
15
|
require "torch/nn/init"
|
@@ -55,9 +59,13 @@ module Torch
|
|
55
59
|
|
56
60
|
class << self
|
57
61
|
# Torch.float, Torch.long, etc
|
58
|
-
DTYPE_TO_ENUM.each_key do |
|
59
|
-
define_method(
|
60
|
-
|
62
|
+
DTYPE_TO_ENUM.each_key do |dtype|
|
63
|
+
define_method(dtype) do
|
64
|
+
dtype
|
65
|
+
end
|
66
|
+
|
67
|
+
Tensor.define_method(dtype) do
|
68
|
+
type(dtype)
|
61
69
|
end
|
62
70
|
end
|
63
71
|
|
@@ -240,6 +248,18 @@ module Torch
|
|
240
248
|
end
|
241
249
|
end
|
242
250
|
|
251
|
+
def argmax(input, dim = nil, keepdim: false)
|
252
|
+
if dim
|
253
|
+
_argmax_dim(input, dim, keepdim)
|
254
|
+
else
|
255
|
+
_argmax(input)
|
256
|
+
end
|
257
|
+
end
|
258
|
+
|
259
|
+
def eq(input, other)
|
260
|
+
_eq(input, other)
|
261
|
+
end
|
262
|
+
|
243
263
|
def norm(input)
|
244
264
|
_norm(input)
|
245
265
|
end
|
@@ -276,6 +296,10 @@ module Torch
|
|
276
296
|
_matmul(input, other)
|
277
297
|
end
|
278
298
|
|
299
|
+
def reshape(input, shape)
|
300
|
+
_reshape(input, shape)
|
301
|
+
end
|
302
|
+
|
279
303
|
private
|
280
304
|
|
281
305
|
def execute_op(op, input, other, out: nil)
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module Inspector
|
3
|
+
# TODO make more performance, especially when summarizing
|
3
4
|
def inspect
|
4
5
|
data =
|
5
6
|
if numel == 0
|
@@ -20,7 +21,7 @@ module Torch
|
|
20
21
|
if floating_point?
|
21
22
|
sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
|
22
23
|
|
23
|
-
all_int = values.all? { |v| v == v.to_i }
|
24
|
+
all_int = values.all? { |v| v.finite? && v == v.to_i }
|
24
25
|
decimal = all_int ? 1 : 4
|
25
26
|
|
26
27
|
total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
|
@@ -35,7 +36,9 @@ module Torch
|
|
35
36
|
fmt = "%#{total}d"
|
36
37
|
end
|
37
38
|
|
38
|
-
|
39
|
+
summarize = numel > 1000
|
40
|
+
|
41
|
+
inspect_level(to_a, fmt, dim - 1, 0, summarize)
|
39
42
|
end
|
40
43
|
|
41
44
|
attributes = []
|
@@ -51,11 +54,30 @@ module Torch
|
|
51
54
|
|
52
55
|
private
|
53
56
|
|
54
|
-
|
57
|
+
# TODO DRY code
|
58
|
+
def inspect_level(arr, fmt, total, level, summarize)
|
55
59
|
if level == total
|
56
|
-
|
60
|
+
cols =
|
61
|
+
if summarize && arr.size > 7
|
62
|
+
arr[0..2].map { |v| fmt % v } +
|
63
|
+
["..."] +
|
64
|
+
arr[-3..-1].map { |v| fmt % v }
|
65
|
+
else
|
66
|
+
arr.map { |v| fmt % v }
|
67
|
+
end
|
68
|
+
|
69
|
+
"[#{cols.join(", ")}]"
|
57
70
|
else
|
58
|
-
|
71
|
+
rows =
|
72
|
+
if summarize && arr.size > 7
|
73
|
+
arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
|
74
|
+
["..."] +
|
75
|
+
arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
|
76
|
+
else
|
77
|
+
arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
|
78
|
+
end
|
79
|
+
|
80
|
+
"[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
|
59
81
|
end
|
60
82
|
end
|
61
83
|
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -3,13 +3,12 @@ module Torch
|
|
3
3
|
class Conv2d < Module
|
4
4
|
attr_reader :bias, :weight
|
5
5
|
|
6
|
-
def initialize(in_channels, out_channels, kernel_size
|
6
|
+
def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1)
|
7
7
|
@in_channels = in_channels
|
8
8
|
@out_channels = out_channels
|
9
9
|
@kernel_size = pair(kernel_size)
|
10
|
-
@stride = pair(
|
11
|
-
|
12
|
-
# @padding = pair(padding)
|
10
|
+
@stride = pair(stride)
|
11
|
+
@padding = pair(padding)
|
13
12
|
# @dilation = pair(dilation)
|
14
13
|
|
15
14
|
# TODO divide by groups
|
@@ -29,7 +28,7 @@ module Torch
|
|
29
28
|
end
|
30
29
|
|
31
30
|
def call(input)
|
32
|
-
F.conv2d(input, @weight, @bias
|
31
|
+
F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups)
|
33
32
|
end
|
34
33
|
|
35
34
|
def inspect
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -6,8 +6,17 @@ module Torch
|
|
6
6
|
Torch.relu(input)
|
7
7
|
end
|
8
8
|
|
9
|
-
def conv2d(input, weight, bias)
|
10
|
-
|
9
|
+
def conv2d(input, weight, bias, stride: 1, padding: 0)
|
10
|
+
# TODO pair stride and padding when needed
|
11
|
+
Torch.conv2d(input, weight, bias, stride, padding)
|
12
|
+
end
|
13
|
+
|
14
|
+
def prelu(input, weight)
|
15
|
+
Torch.prelu(input, weight)
|
16
|
+
end
|
17
|
+
|
18
|
+
def leaky_relu(input, negative_slope = 0.01)
|
19
|
+
Torch.leaky_relu(input, negative_slope)
|
11
20
|
end
|
12
21
|
|
13
22
|
def max_pool2d(input, kernel_size)
|
@@ -15,6 +24,11 @@ module Torch
|
|
15
24
|
Torch.max_pool2d(input, kernel_size)
|
16
25
|
end
|
17
26
|
|
27
|
+
def avg_pool2d(input, kernel_size)
|
28
|
+
kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
|
29
|
+
Torch.avg_pool2d(input, kernel_size)
|
30
|
+
end
|
31
|
+
|
18
32
|
def linear(input, weight, bias)
|
19
33
|
Torch.linear(input, weight, bias)
|
20
34
|
end
|
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/nn/parameter.rb
CHANGED
@@ -0,0 +1,28 @@
|
|
1
|
+
module Torch
|
2
|
+
module Optim
|
3
|
+
class SGD < Optimizer
|
4
|
+
def initialize(params, lr:)
|
5
|
+
@params = params
|
6
|
+
@lr = lr
|
7
|
+
end
|
8
|
+
|
9
|
+
def zero_grad
|
10
|
+
@params.each do |param|
|
11
|
+
if param.grad
|
12
|
+
param.grad.detach!
|
13
|
+
param.grad.zero!
|
14
|
+
end
|
15
|
+
end
|
16
|
+
end
|
17
|
+
|
18
|
+
def step
|
19
|
+
@params.each do |param|
|
20
|
+
next unless param.grad
|
21
|
+
d_p = param.grad.data
|
22
|
+
# same as param.data.add!(-@lr, d_p)
|
23
|
+
param.data.sub!(d_p * @lr)
|
24
|
+
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -28,7 +28,7 @@ module Torch
|
|
28
28
|
end
|
29
29
|
|
30
30
|
def to_a
|
31
|
-
|
31
|
+
reshape_arr(_data, shape)
|
32
32
|
end
|
33
33
|
|
34
34
|
def size(dim = nil)
|
@@ -54,8 +54,12 @@ module Torch
|
|
54
54
|
_data.first
|
55
55
|
end
|
56
56
|
|
57
|
-
def
|
58
|
-
|
57
|
+
def backward(gradient = nil)
|
58
|
+
if gradient
|
59
|
+
_backward_gradient(gradient)
|
60
|
+
else
|
61
|
+
_backward
|
62
|
+
end
|
59
63
|
end
|
60
64
|
|
61
65
|
# TODO read directly from memory
|
@@ -74,8 +78,14 @@ module Torch
|
|
74
78
|
_requires_grad!(requires_grad)
|
75
79
|
end
|
76
80
|
|
81
|
+
def type(dtype)
|
82
|
+
enum = DTYPE_TO_ENUM[dtype]
|
83
|
+
raise Error, "Unknown type: #{dtype}" unless enum
|
84
|
+
_type(enum)
|
85
|
+
end
|
86
|
+
|
77
87
|
# operations
|
78
|
-
%w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
|
88
|
+
%w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op|
|
79
89
|
define_method(op) do |*args, **options, &block|
|
80
90
|
if options.any?
|
81
91
|
Torch.send(op, self, *args, **options, &block)
|
@@ -117,18 +127,27 @@ module Torch
|
|
117
127
|
item <=> other
|
118
128
|
end
|
119
129
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
130
|
+
def [](*indexes)
|
131
|
+
result = self
|
132
|
+
dim = 0
|
133
|
+
indexes.each_with_index do |index|
|
134
|
+
if index.is_a?(Numeric)
|
135
|
+
result = result._select(dim, index)
|
136
|
+
elsif index.is_a?(Range)
|
137
|
+
finish = index.end
|
138
|
+
finish += 1 unless index.exclude_end?
|
139
|
+
result = result._slice(dim, index.begin, finish, 1)
|
140
|
+
dim += 1
|
141
|
+
else
|
142
|
+
raise Error, "Unsupported index type"
|
143
|
+
end
|
125
144
|
end
|
126
|
-
|
145
|
+
result
|
127
146
|
end
|
128
147
|
|
129
148
|
private
|
130
149
|
|
131
|
-
def
|
150
|
+
def reshape_arr(arr, dims)
|
132
151
|
if dims.empty?
|
133
152
|
arr
|
134
153
|
else
|
@@ -6,6 +6,15 @@ module Torch
|
|
6
6
|
@dataset = dataset
|
7
7
|
@batch_size = batch_size
|
8
8
|
end
|
9
|
+
|
10
|
+
def each
|
11
|
+
size = @dataset.size
|
12
|
+
start_index = 0
|
13
|
+
while start_index < size
|
14
|
+
yield @dataset[start_index...(start_index + @batch_size)]
|
15
|
+
start_index += @batch_size
|
16
|
+
end
|
17
|
+
end
|
9
18
|
end
|
10
19
|
end
|
11
20
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
@@ -119,6 +119,8 @@ files:
|
|
119
119
|
- lib/torch/nn/parameter.rb
|
120
120
|
- lib/torch/nn/relu.rb
|
121
121
|
- lib/torch/nn/sequential.rb
|
122
|
+
- lib/torch/optim/optimizer.rb
|
123
|
+
- lib/torch/optim/sgd.rb
|
122
124
|
- lib/torch/tensor.rb
|
123
125
|
- lib/torch/utils/data/data_loader.rb
|
124
126
|
- lib/torch/utils/data/tensor_dataset.rb
|