torch-rb 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +46 -0
- data/README.md +426 -0
- data/ext/torch/ext.cpp +839 -0
- data/ext/torch/extconf.rb +25 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +422 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +85 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +37 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +100 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +85 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +14 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/tensor.rb +196 -0
- data/lib/torch/utils/data/data_loader.rb +27 -0
- data/lib/torch/utils/data/tensor_dataset.rb +22 -0
- data/lib/torch/version.rb +3 -0
- metadata +169 -0
@@ -0,0 +1,65 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
|
2
|
+
module Torch
  module Optim
    # Averaged Stochastic Gradient Descent.
    #
    # Every step applies a multiplicative decay of (1 - lambd * eta) to the
    # parameter, takes a gradient step with the per-parameter rate eta, and keeps
    # a running average of the parameter in state[:ax]. After each step eta and
    # mu are recomputed from the step count.
    class ASGD < Optimizer
      # @param params [Array] tensors (or param-group hashes) to optimize
      # @param lr [Numeric] base learning rate
      # @param lambd [Numeric] decay term
      # @param alpha [Numeric] exponent used in the eta update
      # @param t0 [Numeric] step count after which mu drops below 1
      # @param weight_decay [Numeric] coefficient of p.data added to the gradient
      # @raise [ArgumentError] if lr or weight_decay is negative
      def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a parameter has a sparse gradient
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "ASGD does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.empty?
              state[:step] = 0
              state[:eta] = group[:lr]
              state[:mu] = 1
              state[:ax] = Torch.zeros_like(p.data)
            end

            state[:step] += 1

            # out-of-place so the stored gradient is not modified
            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            # decay term
            p.data.mul!(1 - group[:lambd] * state[:eta])

            # update parameter
            p.data.add!(-state[:eta], grad)

            # averaging
            if state[:mu] != 1
              state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
            else
              state[:ax].copy!(p.data)
            end

            # update eta and mu; fdiv / 1.0 keep this true division even when the
            # hyperparameters are Integers (Integer#/ would truncate mu to 0 and
            # freeze the running average once step - t0 >= 2)
            state[:eta] = group[:lr].fdiv((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha])
            state[:mu] = 1.0 / [1, state[:step] - group[:t0]].max
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Torch
  module Optim
    module LRScheduler
      # Base class for learning-rate schedulers (ported from PyTorch's _LRScheduler).
      #
      # Subclasses implement #get_lr, which returns one learning rate per optimizer
      # param group for the current @last_epoch. #step writes those values back
      # into the optimizer's param groups.
      class LRScheduler
        # @param optimizer [#param_groups] optimizer whose param-group :lr values are scheduled
        # @param last_epoch [Integer] -1 to start a fresh schedule, otherwise the epoch to resume at
        # @raise [KeyError] when resuming and a param group has no :initial_lr
        def initialize(optimizer, last_epoch)
          @optimizer = optimizer
          if last_epoch == -1
            # remember each group's starting lr; the schedule is always computed from it
            optimizer.param_groups.each do |group|
              group[:initial_lr] ||= group[:lr]
            end
            last_epoch = 0
          else
            # resuming: the base lrs must already have been recorded (e.g. by an earlier scheduler)
            optimizer.param_groups.each_with_index do |group, i|
              unless group[:initial_lr]
                raise KeyError, "param :initial_lr is not specified in param_groups[#{i}] when resuming an optimizer"
              end
            end
          end
          @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
          @last_epoch = last_epoch

          @step_count = 0
          step(last_epoch)
        end

        # Advances the schedule and writes the new learning rates into the optimizer.
        #
        # @param epoch [Integer, nil] epoch to move to; defaults to the previous epoch + 1
        def step(epoch = nil)
          @step_count += 1
          epoch ||= @last_epoch + 1
          @last_epoch = epoch
          @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
            param_group[:lr] = lr
          end
        end

        # @return [Array<Numeric>] learning rate for each param group at @last_epoch
        # @raise [NotImplementedError] subclasses must override this
        def get_lr
          raise NotImplementedError, "#{self.class.name} must implement get_lr"
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Torch
  module Optim
    module LRScheduler
      # Decays each param group's initial learning rate by +gamma+ once every
      # +step_size+ epochs: lr = initial_lr * gamma ** floor(last_epoch / step_size).
      class StepLR < LRScheduler
        # @param optimizer [Optimizer] wrapped optimizer
        # @param step_size [Numeric] number of epochs between decays
        # @param gamma [Numeric] multiplicative decay factor
        # @param last_epoch [Integer] -1 to start a fresh schedule
        def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
          # assigned before super because the base constructor calls step, which calls get_lr
          @step_size = step_size
          @gamma = gamma
          super(optimizer, last_epoch)
        end

        # @return [Array<Numeric>] learning rate for each param group at @last_epoch
        def get_lr
          decays = (@last_epoch / @step_size).floor
          factor = @gamma ** decays
          @base_lrs.map { |base_lr| base_lr * factor }
        end
      end
    end
  end
end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
|
2
|
+
module Torch
  module Optim
    # Base class for all optimizers (ported from PyTorch's torch.optim.Optimizer).
    #
    # Holds a list of param groups (hashes with a :params array plus per-group
    # hyperparameters) and per-parameter state keyed by the parameter object.
    class Optimizer
      attr_reader :param_groups

      # @param params [Array<Tensor>, Array<Hash>, Enumerable] tensors to optimize, or
      #   param-group hashes that each contain a :params key
      # @param defaults [Hash] hyperparameters merged into every group that does not set them
      # @raise [ArgumentError] if params is a single Tensor or is empty
      def initialize(params, defaults)
        @defaults = defaults
        # per-parameter state; an empty hash is created on first lookup
        @state = Hash.new { |hash, key| hash[key] = {} }
        @param_groups = []

        # Tensor#to_a returns the tensor's data, so a lone Tensor must be rejected explicitly
        if params.is_a?(Tensor)
          raise ArgumentError, "params argument given to the optimizer should be an Array of Tensors or Hashes, but got a Tensor"
        end

        param_groups = params.to_a
        raise ArgumentError, "optimizer got an empty parameter list" if param_groups.empty?
        param_groups = [{params: param_groups}] unless param_groups.first.is_a?(Hash)

        param_groups.each do |param_group|
          add_param_group(param_group)
        end
      end

      # Adds a param group, filling in missing hyperparameters from the defaults.
      #
      # @param param_group [Hash] must contain :params (a Tensor or a collection of Tensors)
      # @raise [KeyError] if :params is missing
      # @raise [ArgumentError] if a parameter already belongs to another group
      def add_param_group(param_group)
        params = param_group.fetch(:params)
        # copy into a fresh Array so later changes to the caller's collection don't leak in
        params = params.is_a?(Tensor) ? [params] : params.to_a.dup

        # a parameter present in two groups would be updated twice per step
        existing = {}.compare_by_identity
        @param_groups.each do |group|
          group[:params].each { |p| existing[p] = true }
        end
        if params.any? { |p| existing.key?(p) }
          raise ArgumentError, "some parameters appear in more than one parameter group"
        end

        # merge builds a new Hash, so the caller's param_group is not mutated
        @param_groups << @defaults.merge(param_group).merge(params: params)
      end

      def load_state_dict(state_dict)
        raise NotImplementedYet
      end

      # @return [Hash] :state maps each Tensor parameter's object_id (other keys kept as-is)
      #   to its state hash; :param_groups holds each group's hyperparameters with :params
      #   replaced by the parameters' object_ids
      def state_dict
        pack_group = lambda do |group|
          packed = group.reject { |k, _| k == :params }
          packed[:params] = group[:params].map(&:object_id)
          packed
        end

        param_groups = @param_groups.map { |g| pack_group.call(g) }
        packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h

        {
          state: packed_state,
          param_groups: param_groups
        }
      end

      # Detaches and zeroes the gradient of every parameter that has one.
      def zero_grad
        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            p.grad.detach!
            p.grad.zero!
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
|
2
|
+
module Torch
  module Optim
    # RMSprop: scales each gradient by a running root-mean-square of past
    # gradients, optionally centered (subtracting the squared running mean of
    # the gradient) and optionally with momentum.
    class RMSprop < Optimizer
      # @param params [Array] tensors (or param-group hashes) to optimize
      # @param lr [Numeric] learning rate
      # @param alpha [Numeric] smoothing constant for the running averages
      # @param eps [Numeric] added to the denominator
      # @param weight_decay [Numeric] coefficient of p.data added to the gradient
      # @param momentum [Numeric] momentum factor; 0 disables the momentum buffer
      # @param centered [Boolean] whether to subtract the squared running mean of the gradient
      # @raise [ArgumentError] if any numeric hyperparameter is negative
      def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
        # alpha is the smoothing constant, not a momentum setting
        raise ArgumentError, "Invalid alpha value: #{alpha}" if alpha < 0

        defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a parameter has a sparse gradient
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "RMSprop does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.empty?
              state[:step] = 0
              state[:square_avg] = Torch.zeros_like(p.data)
              state[:momentum_buffer] = Torch.zeros_like(p.data) if group[:momentum] > 0
              state[:grad_avg] = Torch.zeros_like(p.data) if group[:centered]
            end

            square_avg = state[:square_avg]
            alpha = group[:alpha]

            state[:step] += 1

            # out-of-place so the stored gradient is not modified
            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)

            if group[:centered]
              grad_avg = state[:grad_avg]
              grad_avg.mul!(alpha).add!(1 - alpha, grad)
              # sqrt(square_avg - grad_avg^2) + eps
              avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
            else
              avg = square_avg.sqrt.add!(group[:eps])
            end

            if group[:momentum] > 0
              buf = state[:momentum_buffer]
              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
              p.data.add!(-group[:lr], buf)
            else
              p.data.addcdiv!(-group[:lr], grad, avg)
            end
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
|
2
|
+
module Torch
  module Optim
    # Resilient backpropagation: per-element step sizes grow by etaplus while
    # the gradient keeps its sign and shrink by etaminus when it flips.
    class Rprop < Optimizer
      # @param params [Array] tensors (or param-group hashes) to optimize
      # @param lr [Numeric] initial per-element step size
      # @param etas [Array(Numeric, Numeric)] [etaminus, etaplus]; requires 0 < etaminus < 1 < etaplus
      # @param step_sizes [Array(Numeric, Numeric)] [min, max] clamp for the step sizes
      # @raise [ArgumentError] if lr is negative or the etas are out of range
      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        # strict bounds: etaminus == 0 would zero the step sizes, and etaplus == 1 would never grow them
        unless 0 < etas[0] && etas[0] < 1 && 1 < etas[1]
          raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}"
        end

        defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # NOTE: currently always raises NotImplementedYet, because the algorithm needs
      # masked assignment (Tensor#[]=), which is not implemented yet. The code after
      # the raise is the intended port and is not reached.
      def step(closure = nil)
        # TODO implement []=
        raise NotImplementedYet

        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            if grad.sparse?
              raise Error, "Rprop does not support sparse gradients"
            end
            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:prev] = Torch.zeros_like(p.data)
              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
            end

            etaminus, etaplus = group[:etas]
            step_size_min, step_size_max = group[:step_sizes]
            step_size = state[:step_size]

            state[:step] += 1

            sign = grad.mul(state[:prev]).sign
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

            # update stepsizes with step size updates
            step_size.mul!(sign).clamp!(step_size_min, step_size_max)

            # for dir<0, dfdx=0
            # for dir>=0 dfdx=dfdx
            grad = grad.clone
            grad[sign.eq(etaminus)] = 0

            # update parameters
            p.data.addcmul!(-1, grad.sign, step_size)

            state[:prev].copy!(grad)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
|
2
|
+
module Torch
  module Optim
    # Stochastic gradient descent with optional momentum, dampening, weight
    # decay and Nesterov momentum.
    class SGD < Optimizer
      # @param params [Array] tensors (or param-group hashes) to optimize
      # @param lr [Numeric] learning rate (required)
      # @param momentum [Numeric] momentum factor; 0 disables momentum
      # @param dampening [Numeric] dampening applied to the gradient when updating the momentum buffer
      # @param weight_decay [Numeric] coefficient of p.data added to the gradient
      # @param nesterov [Boolean] use Nesterov momentum (requires momentum > 0 and dampening == 0)
      # @raise [ArgumentError] on negative hyperparameters or an invalid Nesterov configuration
      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}

        if nesterov && (momentum <= 0 || dampening != 0)
          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
        end

        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          weight_decay = group[:weight_decay]
          momentum = group[:momentum]
          dampening = group[:dampening]
          nesterov = group[:nesterov]

          group[:params].each do |p|
            next unless p.grad

            d_p = p.grad.data
            # out-of-place (like ASGD/RMSprop) so p.grad is not modified; an in-place add
            # would compound weight decay if step runs again without zero_grad
            d_p = d_p.add(weight_decay, p.data) if weight_decay != 0

            if momentum != 0
              param_state = @state[p]
              # key? checks for the key; Hash#key(value) would look up a value and
              # always return nil here, resetting the buffer on every step
              if param_state.key?(:momentum_buffer)
                buf = param_state[:momentum_buffer]
                buf.mul!(momentum).add!(1 - dampening, d_p)
              else
                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
              end
              d_p = nesterov ? d_p.add(momentum, buf) : buf
            end

            p.data.add!(-group[:lr], d_p)
          end
        end

        loss
      end
    end
  end
end
|
data/lib/torch/tensor.rb
ADDED
@@ -0,0 +1,196 @@
|
|
1
|
+
module Torch
  class Tensor
    include Comparable
    include Inspector

    alias_method :requires_grad?, :requires_grad

    # Returns the given tensor unchanged, or an uninitialized tensor of the given size.
    def self.new(*size)
      if size.length == 1 && size.first.is_a?(Tensor)
        size.first
      else
        Torch.empty(*size)
      end
    end

    # @raise [Error] if the native dtype enum has no Ruby mapping
    def dtype
      dtype = ENUM_TO_DTYPE[_dtype]
      raise Error, "Unknown type: #{_dtype}" unless dtype
      dtype
    end

    def layout
      _layout.downcase.to_sym
    end

    def to_s
      inspect
    end

    # @return [Array] the tensor's data as nested Ruby arrays matching #shape
    def to_a
      reshape_arr(_data, shape)
    end

    # TODO support dtype
    def to(device, non_blocking: false, copy: false)
      device = Device.new(device) if device.is_a?(String)
      _to(device, _dtype, non_blocking, copy)
    end

    # @param dim [Integer, nil] dimension to query; nil returns the full shape
    def size(dim = nil)
      if dim
        _size(dim)
      else
        shape
      end
    end

    def shape
      dim.times.map { |i| size(i) }
    end

    def view(*size)
      _view(size)
    end

    # @raise [Error] unless the tensor has exactly one element
    def item
      if numel != 1
        raise Error, "only one element tensors can be converted to Ruby scalars"
      end
      _data.first
    end

    # unsure if this is correct
    def new
      Torch.empty(0, dtype: dtype)
    end

    def backward(gradient = nil)
      if gradient
        _backward_gradient(gradient)
      else
        _backward
      end
    end

    # TODO read directly from memory
    def numo
      raise Error, "Numo not found" unless defined?(Numo::NArray)
      cls = Torch._dtype_to_numo[dtype]
      raise Error, "Cannot convert #{dtype} to Numo" unless cls
      cls.cast(_data).reshape(*shape)
    end

    def new_ones(*size, **options)
      Torch.ones_like(Torch.empty(*size), **options)
    end

    def requires_grad!(requires_grad = true)
      _requires_grad!(requires_grad)
    end

    # @raise [Error] if dtype is not a known type
    def type(dtype)
      enum = DTYPE_TO_ENUM[dtype]
      raise Error, "Unknown type: #{dtype}" unless enum
      _type(enum)
    end

    # In-place self += value * other.
    def add!(value = 1, other)
      if other.is_a?(Numeric)
        _add_scalar!(other * value)
      else
        # need to use alpha for sparse tensors instead of multiplying
        _add_alpha!(other, value)
      end
    end

    def mul!(other)
      if other.is_a?(Numeric)
        _mul_scalar!(other)
      else
        _mul!(other)
      end
    end

    # operations
    %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
      define_method(op) do |*args, **options, &block|
        if options.any?
          Torch.send(op, self, *args, **options, &block)
        else
          Torch.send(op, self, *args, &block)
        end
      end
    end

    def +(other)
      add(other)
    end

    def -(other)
      sub(other)
    end

    def *(other)
      mul(other)
    end

    def /(other)
      div(other)
    end

    def %(other)
      remainder(other)
    end

    def **(other)
      pow(other)
    end

    def -@
      neg
    end

    def <=>(other)
      item <=> other
    end

    # Indexes the tensor (based on python_variable_indexing.cpp).
    #
    # Each index applies to the next remaining dimension. An Integer selects
    # (and removes) that dimension; a Range slices it with step 1. A negative
    # range end counts from the end of the dimension, so x[0..-1] keeps the whole
    # dimension. Beginless and endless ranges run from the start and to the end.
    #
    # @raise [Error] for any other index type
    def [](*indexes)
      result = self
      dim = 0
      indexes.each do |index|
        if index.is_a?(Numeric)
          result = result._select(dim, index)
        elsif index.is_a?(Range)
          start, finish = range_bounds(index, result.size(dim))
          result = result._slice(dim, start, finish, 1)
          dim += 1
        else
          raise Error, "Unsupported index type"
        end
      end
      result
    end

    # TODO
    # based on python_variable_indexing.cpp
    # def []=(index, value)
    # end

    private

    # Converts a Range into [start, exclusive_end] for _slice over a dimension of dim_size.
    def range_bounds(range, dim_size)
      start = range.begin || 0
      finish = range.end
      if finish.nil?
        finish = dim_size
      else
        # normalize a negative end first; otherwise an inclusive -1 end becomes 0 (an empty slice)
        finish += dim_size if finish < 0
        finish += 1 unless range.exclude_end?
      end
      [start, finish]
    end

    # Rebuilds nested arrays of shape dims from the flat element list arr.
    def reshape_arr(arr, dims)
      if dims.empty?
        arr
      elsif dims.include?(0)
        # each_slice(0) raises ArgumentError, so build the (element-free) nesting directly
        empty_nested(dims)
      else
        arr = arr.flatten
        dims[1..-1].reverse.each do |dim|
          arr = arr.each_slice(dim)
        end
        arr.to_a
      end
    end

    # Nested empty arrays for a shape containing a zero, e.g. [2, 0] => [[], []].
    def empty_nested(dims)
      head, *rest = dims
      return [] if head.zero?
      Array.new(head) { empty_nested(rest) }
    end
  end
end
|