torch-rb 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +46 -0
- data/README.md +426 -0
- data/ext/torch/ext.cpp +839 -0
- data/ext/torch/extconf.rb +25 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +422 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +85 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +37 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +100 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +85 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +14 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/tensor.rb +196 -0
- data/lib/torch/utils/data/data_loader.rb +27 -0
- data/lib/torch/utils/data/tensor_dataset.rb +22 -0
- data/lib/torch/version.rb +3 -0
- metadata +169 -0
@@ -0,0 +1,65 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
module Torch
  module Optim
    # Averaged Stochastic Gradient Descent.
    #
    # Besides updating each parameter, keeps a per-parameter running average
    # of the iterates in `state[:ax]`.
    class ASGD < Optimizer
      # @param params [Array] parameters (or parameter-group hashes) to optimize
      # @param lr [Numeric] learning rate
      # @param lambd [Numeric] decay term applied to the parameter each step
      # @param alpha [Numeric] power used in the eta (step size) update
      # @param t0 [Numeric] averaging only kicks in once the step count exceeds t0 + 1;
      #   before that mu is 1 and the average simply copies the parameter
      # @param weight_decay [Numeric] L2 penalty added to the gradient
      # @raise [ArgumentError] if lr or weight_decay is negative
      def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's return value, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            raise Error, "ASGD does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.empty?
              state[:step] = 0
              state[:eta] = group[:lr]
              state[:mu] = 1
              state[:ax] = Torch.zeros_like(p.data)
            end

            state[:step] += 1

            # out-of-place so the stored gradient is left untouched
            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            # decay term
            p.data.mul!(1 - group[:lambd] * state[:eta])

            # update parameter
            p.data.add!(-state[:eta], grad)

            # averaging
            if state[:mu] != 1
              state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
            else
              state[:ax].copy!(p.data)
            end

            # update eta and mu
            # Both are true (float) divisions, as in the Python original. With Integer
            # hyperparameters (e.g. t0: 100) Integer#/ would truncate mu to 0 and freeze
            # the running average.
            state[:eta] = group[:lr] / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha]).to_f
            state[:mu] = 1.0 / [1, state[:step] - group[:t0]].max
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Torch
  module Optim
    module LRScheduler
      # Base class for learning-rate schedulers.
      #
      # Subclasses provide `get_lr`, which returns one learning rate per
      # parameter group of the wrapped optimizer. `step` writes those rates
      # back into the optimizer's param groups.
      class LRScheduler
        # @param optimizer object exposing `param_groups` (an Array of Hashes with an :lr key)
        # @param last_epoch [Integer] index of the last epoch; only -1 (fresh start) is supported
        # @raise [NotImplementedYet] when resuming from any other last_epoch
        def initialize(optimizer, last_epoch)
          @optimizer = optimizer
          # resuming a schedule from a checkpoint is not supported yet
          raise NotImplementedYet unless last_epoch == -1

          # remember each group's starting lr, unless one was already recorded
          optimizer.param_groups.each { |group| group[:initial_lr] ||= group[:lr] }

          @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
          @last_epoch = 0
          @step_count = 0
          step(@last_epoch)
        end

        # Advances the schedule and applies the learning rates from `get_lr`.
        #
        # @param epoch [Integer, nil] epoch to move to; defaults to the next one
        def step(epoch = nil)
          @step_count += 1
          @last_epoch = epoch || @last_epoch + 1
          @optimizer.param_groups.zip(get_lr).each { |group, new_lr| group[:lr] = new_lr }
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Torch
  module Optim
    module LRScheduler
      # Multiplies each parameter group's base learning rate by `gamma`
      # once every `step_size` epochs.
      class StepLR < LRScheduler
        # @param optimizer optimizer whose param groups are scheduled
        # @param step_size [Numeric] number of epochs between decays
        # @param gamma [Numeric] multiplicative decay factor
        # @param last_epoch [Integer] index of the last epoch (-1 to start fresh)
        def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
          # set before super: the superclass constructor calls step, which calls get_lr
          @step_size = step_size
          @gamma = gamma
          super(optimizer, last_epoch)
        end

        # @return [Array<Numeric>] base_lr * gamma ** floor(last_epoch / step_size) for each group
        def get_lr
          decay = @gamma ** (@last_epoch / @step_size).floor
          @base_lrs.map { |base_lr| base_lr * decay }
        end
      end
    end
  end
end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
module Torch
  module Optim
    # Base class for optimizers.
    #
    # Holds a list of parameter groups (Hashes with a :params list plus
    # per-group hyperparameters) and lazily created per-parameter state.
    # Subclasses implement `step`.
    class Optimizer
      attr_reader :param_groups

      # @param params [Enumerable, Hash] parameters to optimize, a list of parameter-group
      #   Hashes, or a single parameter-group Hash
      # @param defaults [Hash] hyperparameters applied to every group that does not override them
      # @raise [ArgumentError] if no parameters are given
      def initialize(params, defaults)
        @defaults = defaults
        # per-parameter state, created on first access
        @state = Hash.new { |hash, key| hash[key] = {} }
        @param_groups = []

        groups =
          case params
          when Hash then [params]           # a single parameter-group Hash
          when Enumerable then params.to_a  # mirrors list(params): accepts Enumerators, Sets, ...
          else params
          end

        raise ArgumentError, "optimizer got an empty parameter list" if groups.empty?
        # a plain list of parameters forms one group
        groups = [{params: groups}] unless groups.first.is_a?(Hash)

        groups.each { |group| add_param_group(group) }
      end

      # Adds a parameter group, filling in missing hyperparameters from the defaults.
      #
      # @param param_group [Hash] must contain :params
      def add_param_group(param_group)
        # TODO more advanced logic
        @param_groups << @defaults.merge(param_group)
      end

      # @raise [NotImplementedYet] always; loading optimizer state is not supported yet
      def load_state_dict(state_dict)
        raise NotImplementedYet
      end

      # Snapshot of the optimizer's hyperparameters and state.
      #
      # Parameters are identified by their object_id.
      #
      # @return [Hash] with :state (per-parameter state keyed by object_id) and
      #   :param_groups (group hyperparameters with :params replaced by object_ids)
      def state_dict
        pack_group = lambda do |group|
          packed = group.reject { |key, _| key == :params }
          packed[:params] = group[:params].map(&:object_id)
          packed
        end

        packed_state = @state.map { |key, value| [key.is_a?(Tensor) ? key.object_id : key, value] }.to_h

        {
          state: packed_state,
          param_groups: @param_groups.map(&pack_group)
        }
      end

      # Detaches and zeroes the gradient of every parameter that has one.
      def zero_grad
        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            p.grad.detach!
            p.grad.zero!
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
module Torch
  module Optim
    # RMSprop: scales each update by a running average of squared gradients.
    # Optionally centered (subtracts the squared running mean gradient) and
    # optionally with momentum.
    class RMSprop < Optimizer
      # @param params [Array] parameters (or parameter-group hashes) to optimize
      # @param lr [Numeric] learning rate
      # @param alpha [Numeric] smoothing constant for the running averages
      # @param eps [Numeric] term added to the denominator
      # @param weight_decay [Numeric] L2 penalty added to the gradient
      # @param momentum [Numeric] momentum factor (0 disables the momentum buffer)
      # @param centered [Boolean] also track the mean gradient and subtract its square
      # @raise [ArgumentError] if any numeric hyperparameter is negative
      def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
        # alpha is the smoothing constant, not a momentum parameter
        raise ArgumentError, "Invalid alpha value: #{alpha}" if alpha < 0

        defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's return value, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            raise Error, "RMSprop does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.empty?
              state[:step] = 0
              state[:square_avg] = Torch.zeros_like(p.data)
              state[:momentum_buffer] = Torch.zeros_like(p.data) if group[:momentum] > 0
              state[:grad_avg] = Torch.zeros_like(p.data) if group[:centered]
            end

            square_avg = state[:square_avg]
            alpha = group[:alpha]

            state[:step] += 1

            # out-of-place so the stored gradient is left untouched
            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)

            if group[:centered]
              grad_avg = state[:grad_avg]
              grad_avg.mul!(alpha).add!(1 - alpha, grad)
              avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
            else
              avg = square_avg.sqrt.add!(group[:eps])
            end

            if group[:momentum] > 0
              buf = state[:momentum_buffer]
              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
              p.data.add!(-group[:lr], buf)
            else
              p.data.addcdiv!(-group[:lr], grad, avg)
            end
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
module Torch
  module Optim
    # Resilient backpropagation.
    #
    # Only construction works for now: `step` raises NotImplementedYet
    # until masked assignment (Tensor#[]=) is available.
    class Rprop < Optimizer
      # @param params [Array] parameters (or parameter-group hashes) to optimize
      # @param lr [Numeric] initial per-element step size
      # @param etas [Array(Numeric, Numeric)] [decrease factor, increase factor]
      # @param step_sizes [Array(Numeric, Numeric)] [minimum, maximum] step size
      # @raise [ArgumentError] if lr is negative or etas are out of range
      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0

        decrease = etas[0]
        increase = etas[1]
        if decrease < 0 || decrease >= 1 || increase < 1
          raise ArgumentError, "Invalid eta values: #{decrease}, #{increase}"
        end

        super(params, {lr: lr, etas: etas, step_sizes: step_sizes})
      end

      # @raise [NotImplementedYet] always, for now
      def step(closure = nil)
        # TODO implement []=
        raise NotImplementedYet

        # Everything below is currently unreachable; it mirrors the upstream
        # algorithm and will run once Tensor#[]= exists.
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          etaminus, etaplus = group[:etas]
          step_size_min, step_size_max = group[:step_sizes]

          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            raise Error, "Rprop does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:prev] = Torch.zeros_like(p.data)
              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
            end

            step_size = state[:step_size]
            state[:step] += 1

            # per-element factor: etaplus where the gradient kept its sign,
            # etaminus where it flipped, 1 where either was zero
            factor = grad.mul(state[:prev]).sign
            factor[factor.gt(0)] = etaplus
            factor[factor.lt(0)] = etaminus
            factor[factor.eq(0)] = 1

            # update stepsizes with step size updates
            step_size.mul!(factor).clamp!(step_size_min, step_size_max)

            # for dir<0, dfdx=0
            # for dir>=0 dfdx=dfdx
            grad = grad.clone
            grad[factor.eq(etaminus)] = 0

            # update parameters
            p.data.addcmul!(-1, grad.sign, step_size)

            state[:prev].copy!(grad)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
module Torch
  module Optim
    # Stochastic gradient descent, optionally with momentum, dampening,
    # Nesterov momentum and L2 weight decay.
    class SGD < Optimizer
      # @param params [Array] parameters (or parameter-group hashes) to optimize
      # @param lr [Numeric] learning rate (required)
      # @param momentum [Numeric] momentum factor (0 disables the momentum buffer)
      # @param dampening [Numeric] dampening applied to the gradient when updating the buffer
      # @param weight_decay [Numeric] L2 penalty added to the gradient
      # @param nesterov [Boolean] use Nesterov momentum
      # @raise [ArgumentError] on negative hyperparameters, or nesterov without
      #   positive momentum and zero dampening
      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}

        if nesterov && (momentum <= 0 || dampening != 0)
          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
        end

        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's return value, or nil when no closure is given
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          weight_decay = group[:weight_decay]
          momentum = group[:momentum]
          dampening = group[:dampening]
          nesterov = group[:nesterov]

          group[:params].each do |p|
            next unless p.grad
            d_p = p.grad.data
            # out-of-place (as in ASGD/RMSprop) so the stored gradient is not modified
            d_p = d_p.add(weight_decay, p.data) if weight_decay != 0

            if momentum != 0
              param_state = @state[p]
              # Hash#key? checks for the key. Hash#key(value) looks a key up *by value*,
              # always returned nil here, and reset the buffer on every step.
              if param_state.key?(:momentum_buffer)
                buf = param_state[:momentum_buffer]
                buf.mul!(momentum).add!(1 - dampening, d_p)
              else
                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
              end
              d_p = nesterov ? d_p.add(momentum, buf) : buf
            end

            p.data.add!(-group[:lr], d_p)
          end
        end

        loss
      end
    end
  end
end
|
data/lib/torch/tensor.rb
ADDED
@@ -0,0 +1,196 @@
|
|
1
|
+
module Torch
  # Ruby-side convenience layer over the tensor class.
  #
  # Underscore-prefixed methods (`_data`, `_size`, `_select`, `_slice`, ...)
  # are not defined here; presumably they come from the native extension.
  class Tensor
    include Comparable
    include Inspector

    alias_method :requires_grad?, :requires_grad

    # Returns the argument unchanged when given a single Tensor. Otherwise
    # returns `Torch.empty(*size)`.
    def self.new(*size)
      if size.length == 1 && size.first.is_a?(Tensor)
        size.first
      else
        Torch.empty(*size)
      end
    end

    # @return the dtype mapped from the native dtype enum
    # @raise [Error] if the enum is not in ENUM_TO_DTYPE
    def dtype
      dtype = ENUM_TO_DTYPE[_dtype]
      raise Error, "Unknown type: #{_dtype}" unless dtype
      dtype
    end

    # @return [Symbol] the native layout name, downcased
    def layout
      _layout.downcase.to_sym
    end

    def to_s
      inspect
    end

    # @return [Array] elements as nested Arrays following `shape`
    def to_a
      reshape_arr(_data, shape)
    end

    # TODO support dtype
    def to(device, non_blocking: false, copy: false)
      device = Device.new(device) if device.is_a?(String)
      _to(device, _dtype, non_blocking, copy)
    end

    # @param dim [Integer, nil] dimension to query
    # @return [Integer, Array<Integer>] size of `dim`, or the full shape when dim is nil
    def size(dim = nil)
      if dim
        _size(dim)
      else
        shape
      end
    end

    # @return [Array<Integer>] size of every dimension
    def shape
      dim.times.map { |i| size(i) }
    end

    def view(*size)
      _view(size)
    end

    # @return the single element as a Ruby scalar
    # @raise [Error] unless the tensor has exactly one element
    def item
      if numel != 1
        raise Error, "only one element tensors can be converted to Ruby scalars"
      end
      _data.first
    end

    # unsure if this is correct
    def new
      Torch.empty(0, dtype: dtype)
    end

    def backward(gradient = nil)
      if gradient
        _backward_gradient(gradient)
      else
        _backward
      end
    end

    # TODO read directly from memory
    # @raise [Error] if Numo is not loaded or the dtype has no Numo counterpart
    def numo
      raise Error, "Numo not found" unless defined?(Numo::NArray)
      cls = Torch._dtype_to_numo[dtype]
      raise Error, "Cannot convert #{dtype} to Numo" unless cls
      cls.cast(_data).reshape(*shape)
    end

    def new_ones(*size, **options)
      Torch.ones_like(Torch.empty(*size), **options)
    end

    def requires_grad!(requires_grad = true)
      _requires_grad!(requires_grad)
    end

    # @raise [Error] if dtype is not in DTYPE_TO_ENUM
    def type(dtype)
      enum = DTYPE_TO_ENUM[dtype]
      raise Error, "Unknown type: #{dtype}" unless enum
      _type(enum)
    end

    # In-place `self += value * other`.
    def add!(value = 1, other)
      if other.is_a?(Numeric)
        _add_scalar!(other * value)
      else
        # need to use alpha for sparse tensors instead of multiplying
        _add_alpha!(other, value)
      end
    end

    # In-place multiplication by a scalar or another tensor.
    def mul!(other)
      if other.is_a?(Numeric)
        _mul_scalar!(other)
      else
        _mul!(other)
      end
    end

    # operations: each delegates to Torch.<op>(self, ...)
    %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
      define_method(op) do |*args, **options, &block|
        if options.any?
          Torch.send(op, self, *args, **options, &block)
        else
          Torch.send(op, self, *args, &block)
        end
      end
    end

    def +(other)
      add(other)
    end

    def -(other)
      sub(other)
    end

    def *(other)
      mul(other)
    end

    def /(other)
      div(other)
    end

    def %(other)
      remainder(other)
    end

    def **(other)
      pow(other)
    end

    def -@
      neg
    end

    # Compares by scalar value. Raises Error (via `item`) for multi-element tensors.
    def <=>(other)
      item <=> other
    end

    # based on python_variable_indexing.cpp
    #
    # Each index applies to the next dimension. An Integer selects one entry
    # and removes that dimension. A Range slices it, supporting exclusive,
    # inclusive, negative, endless (`1..`) and beginless (`..2`) ranges.
    #
    # @raise [Error] for any other index type
    def [](*indexes)
      result = self
      dim = 0
      indexes.each do |index|
        if index.is_a?(Numeric)
          # select drops the dimension, so `dim` does not advance
          result = result._select(dim, index)
        elsif index.is_a?(Range)
          start = index.begin || 0
          finish = index.end
          if finish.nil? || (finish == -1 && !index.exclude_end?)
            # Endless range, or inclusive `..-1`: slice through the last element.
            # Adding 1 to -1 would give an end of 0, i.e. an empty slice.
            finish = result.size(dim)
          elsif !index.exclude_end?
            finish += 1
          end
          result = result._slice(dim, start, finish, 1)
          dim += 1
        else
          raise Error, "Unsupported index type"
        end
      end
      result
    end

    # TODO
    # based on python_variable_indexing.cpp
    # def []=(index, value)
    # end

    private

    # Folds the elements of `arr` into nested Arrays of the given dimensions.
    # A 0-dim shape returns `arr` unchanged.
    def reshape_arr(arr, dims)
      if dims.empty?
        arr
      else
        arr = arr.flatten
        dims[1..-1].reverse.each do |dim|
          arr = arr.each_slice(dim)
        end
        arr.to_a
      end
    end
  end
end
|