torch-rb 0.3.7 → 0.4.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
data/ext/torch/utils.h
ADDED
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <rice/Symbol.hpp>
+
+// keep THP prefix for now to make it easier to compare code
+
+extern VALUE THPVariableClass;
+
+inline VALUE THPUtils_internSymbol(const std::string& str) {
+  return Symbol(str);
+}
+
+inline std::string THPUtils_unpackSymbol(VALUE obj) {
+  Check_Type(obj, T_SYMBOL);
+  obj = rb_funcall(obj, rb_intern("to_s"), 0);
+  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+}
+
+inline std::string THPUtils_unpackString(VALUE obj) {
+  Check_Type(obj, T_STRING);
+  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+}
+
+inline bool THPUtils_checkSymbol(VALUE obj) {
+  return SYMBOL_P(obj);
+}
+
+inline bool THPUtils_checkIndex(VALUE obj) {
+  return FIXNUM_P(obj);
+}
+
+inline bool THPUtils_checkScalar(VALUE obj) {
+  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
+}
+
+inline bool THPVariable_Check(VALUE obj) {
+  return rb_obj_is_kind_of(obj, THPVariableClass);
+}
+
+inline bool THPVariable_CheckExact(VALUE obj) {
+  return rb_obj_is_instance_of(obj, THPVariableClass);
+}
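These helpers back the new Ruby argument parser: they decide whether a Ruby object can be treated as a symbol, an index, a scalar, or a tensor. A rough Ruby-side sketch of what the checks accept (illustrative only, not part of the gem):

# THPUtils_checkScalar corresponds to Integer, Float, or Complex values;
# THPVariable_Check corresponds to anything that is a Torch::Tensor.
scalar = 1.5
tensor = Torch.tensor([1, 2, 3])
[Integer, Float, Complex].any? { |c| scalar.is_a?(c) }  # => true
tensor.is_a?(Torch::Tensor)                             # => true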
data/ext/torch/{templates.cpp → wrap_outputs.h}
RENAMED
@@ -1,43 +1,44 @@
+#pragma once
+
 #include <torch/torch.h>
 #include <rice/Object.hpp>
-#include "templates.hpp"
 
-Object wrap(bool x) {
+inline Object wrap(bool x) {
   return to_ruby<bool>(x);
 }
 
-Object wrap(int64_t x) {
+inline Object wrap(int64_t x) {
   return to_ruby<int64_t>(x);
 }
 
-Object wrap(double x) {
+inline Object wrap(double x) {
   return to_ruby<double>(x);
 }
 
-Object wrap(torch::Tensor x) {
+inline Object wrap(torch::Tensor x) {
   return to_ruby<torch::Tensor>(x);
 }
 
-Object wrap(torch::Scalar x) {
+inline Object wrap(torch::Scalar x) {
   return to_ruby<torch::Scalar>(x);
 }
 
-Object wrap(torch::ScalarType x) {
+inline Object wrap(torch::ScalarType x) {
   return to_ruby<torch::ScalarType>(x);
 }
 
-Object wrap(torch::QScheme x) {
+inline Object wrap(torch::QScheme x) {
   return to_ruby<torch::QScheme>(x);
 }
 
-Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
   return Object(a);
 }
 
-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
   return Object(a);
 }
 
-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   return Object(a);
 }
 
-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   return Object(a);
 }
 
-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
   return Object(a);
 }
 
-Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -82,7 +83,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
   return Object(a);
 }
 
-Object wrap(
+inline Object wrap(torch::TensorList x) {
   Array a;
   for (auto& t : x) {
     a.push(to_ruby<torch::Tensor>(t));
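These wrap overloads are what turn native return values into Ruby objects: single tensors come back as Torch::Tensor, and tuple results come back as Ruby arrays that can be destructured. A hedged usage sketch (illustrative; topk is one op that returns a tensor pair):

values, indices = Torch.topk(Torch.tensor([1.0, 3.0, 2.0]), 2)
values.class   # => Torch::Tensor
indices.class  # => Torch::Tensor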
data/lib/torch.rb
CHANGED
@@ -7,11 +7,6 @@ require "net/http"
 require "set"
 require "tmpdir"
 
-# native functions
-require "torch/native/generator"
-require "torch/native/parser"
-require "torch/native/dispatcher"
-
 # modules
 require "torch/inspector"
 require "torch/tensor"
@@ -374,63 +369,6 @@ module Torch
 
     # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
 
-    def arange(start, finish = nil, step = 1, **options)
-      # ruby doesn't support start = 0, finish, step = 1, ...
-      if finish.nil?
-        finish = start
-        start = 0
-      end
-      _arange(start, finish, step, tensor_options(**options))
-    end
-
-    def empty(*size, **options)
-      _empty(tensor_size(size), tensor_options(**options))
-    end
-
-    def eye(n, m = nil, **options)
-      _eye(n, m || n, tensor_options(**options))
-    end
-
-    def full(size, fill_value, **options)
-      _full(size, fill_value, tensor_options(**options))
-    end
-
-    def linspace(start, finish, steps = 100, **options)
-      _linspace(start, finish, steps, tensor_options(**options))
-    end
-
-    def logspace(start, finish, steps = 100, base = 10.0, **options)
-      _logspace(start, finish, steps, base, tensor_options(**options))
-    end
-
-    def ones(*size, **options)
-      _ones(tensor_size(size), tensor_options(**options))
-    end
-
-    def rand(*size, **options)
-      _rand(tensor_size(size), tensor_options(**options))
-    end
-
-    def randint(low = 0, high, size, **options)
-      _randint(low, high, size, tensor_options(**options))
-    end
-
-    def randn(*size, **options)
-      _randn(tensor_size(size), tensor_options(**options))
-    end
-
-    def randperm(n, **options)
-      # dtype hack in Python
-      # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
-      options[:dtype] ||= :int64
-
-      _randperm(n, tensor_options(**options))
-    end
-
-    def zeros(*size, **options)
-      _zeros(tensor_size(size), tensor_options(**options))
-    end
-
     def tensor(data, **options)
       if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
         numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
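Removing these hand-written wrappers does not remove the methods themselves: together with the new data/codegen files listed above, creation functions are now handled by the code-generated bindings and the new C++ argument parser. Typical calls keep working; an illustrative sketch, assuming the generated bindings keep the same defaults as the removed Ruby code:

Torch.arange(5)                     # start still defaults to 0
Torch.ones(2, 3, dtype: :float64)
Torch.randperm(10)                  # dtype still defaults to :int64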
data/lib/torch/nn/functional.rb
CHANGED
@@ -394,15 +394,15 @@ module Torch
       # loss functions
 
       def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-        NN.binary_cross_entropy(input, target, weight, reduction)
+        NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
       end
 
       def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-        Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
       end
 
       def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-        Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+        Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
       end
 
       def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -411,34 +411,34 @@ module Torch
 
       def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
         # call to_a on input_lengths and target_lengths for C++
-        Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
       end
 
       def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-        Torch.hinge_embedding_loss(input, target, margin, reduction)
+        Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
       end
 
       def kl_div(input, target, reduction: "mean")
-        Torch.kl_div(input, target, reduction)
+        Torch.kl_div(input, target, to_reduction(reduction))
       end
 
       def l1_loss(input, target, reduction: "mean")
-        NN.l1_loss(input, target, reduction)
+        NN.l1_loss(input, target, to_reduction(reduction))
       end
 
       def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-        Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+        Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
       end
 
       def mse_loss(input, target, reduction: "mean")
         if target.size != input.size
           warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
         end
-        NN.mse_loss(input, target, reduction)
+        NN.mse_loss(input, target, to_reduction(reduction))
       end
 
       def multilabel_margin_loss(input, target, reduction: "mean")
-        NN.multilabel_margin_loss(input, target, reduction)
+        NN.multilabel_margin_loss(input, target, to_reduction(reduction))
       end
 
       def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -446,27 +446,27 @@ module Torch
       end
 
       def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-        NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+        NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
       end
 
       def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-        NN.nll_loss(input, target, weight, reduction, ignore_index)
+        NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
       end
 
       def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-        Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
      end
 
       def soft_margin_loss(input, target, reduction: "mean")
-        NN.soft_margin_loss(input, target, reduction)
+        NN.soft_margin_loss(input, target, to_reduction(reduction))
       end
 
       def smooth_l1_loss(input, target, reduction: "mean")
-        NN.smooth_l1_loss(input, target, reduction)
+        NN.smooth_l1_loss(input, target, to_reduction(reduction))
       end
 
       def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-        Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
       end
 
       # vision
@@ -542,6 +542,20 @@ module Torch
 
       private
 
+      # see _reduction.py
+      def to_reduction(v)
+        case v.to_s
+        when "none"
+          0
+        when "mean"
+          1
+        when "sum"
+          2
+        else
+          raise ArgumentError, "#{v} is not a valid value for reduction"
+        end
+      end
+
       def softmax_dim(ndim)
         ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
       end
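Every loss helper now funnels the user-facing reduction: string through the new to_reduction helper, which converts it to the integer enum the native ops expect ("none" → 0, "mean" → 1, "sum" → 2, mirroring PyTorch's _reduction.py). An illustrative call, with made-up input values:

input  = Torch.tensor([[0.2, 0.8]])
target = Torch.tensor([[0.0, 1.0]])
Torch::NN::Functional.mse_loss(input, target, reduction: "sum")  # passes 2 to the native op
Torch::NN::Functional.mse_loss(input, target, reduction: "avg")  # raises ArgumentError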
data/lib/torch/nn/init.rb
CHANGED
@@ -14,25 +14,11 @@ module Torch
         _normal!(tensor, mean, std)
       end
 
-
-
-
-
-      def ones!(tensor)
-        _ones!(tensor)
-      end
-
-      def zeros!(tensor)
-        _zeros!(tensor)
-      end
-
-      def eye!(tensor)
-        _eye!(tensor)
-      end
-
-      def dirac!(tensor)
-        _dirac!(tensor)
-      end
+      alias_method :constant!, :_constant!
+      alias_method :ones!, :_ones!
+      alias_method :zeros!, :_zeros!
+      alias_method :eye!, :_eye!
+      alias_method :dirac!, :_dirac!
 
       def xavier_uniform!(tensor, gain: 1.0)
         _xavier_uniform!(tensor, gain)
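With the wrapper methods replaced by alias_method, the public init helpers dispatch straight to the native in-place ops. Usage stays the same; a short illustrative sketch:

w = Torch.empty(3, 3)
Torch::NN::Init.zeros!(w)
Torch::NN::Init.eye!(w)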
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -45,7 +45,7 @@ module Torch
           square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
           std = square_avg.add(eps).sqrt!
           delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
-          p.data.add!(-group[:lr]
+          p.data.add!(delta, alpha: -group[:lr])
           acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
         end
       end
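This change, and the matching one-line changes in the Adam, Adamax, AdamW, ASGD, and SGD steps below, move the in-place parameter update from the old positional form to the add!(other, alpha: scale) keyword form. A minimal sketch of the new call, with illustrative values:

x = Torch.ones(3)
d = Torch.ones(3)
x.add!(d, alpha: -0.1)   # in place: x = x + (-0.1) * d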
data/lib/torch/optim/adam.rb
CHANGED
@@ -53,11 +53,11 @@ module Torch
           bias_correction2 = 1 - beta2 ** state[:step]
 
           if group[:weight_decay] != 0
-            grad.add!(group[:weight_decay]
+            grad.add!(p.data, alpha: group[:weight_decay])
           end
 
           # Decay the first and second moment running average coefficient
-          exp_avg.mul!(beta1).add!(1 - beta1
+          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
           exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
data/lib/torch/optim/adamax.rb
CHANGED
@@ -46,7 +46,7 @@ module Torch
           end
 
           # Update biased first moment estimate.
-          exp_avg.mul!(beta1).add!(1 - beta1
+          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
           # Update the exponentially weighted infinity norm.
           norm_buf = Torch.cat([
             exp_inf.mul!(beta2).unsqueeze(0),
data/lib/torch/optim/adamw.rb
CHANGED
@@ -58,7 +58,7 @@ module Torch
           bias_correction2 = 1 - beta2 ** state[:step]
 
           # Decay the first and second moment running average coefficient
-          exp_avg.mul!(beta1).add!(1 - beta1
+          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
           exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -32,7 +32,7 @@ module Torch
          next unless p.grad
          d_p = p.grad.data
          if weight_decay != 0
-           d_p.add!(
+           d_p.add!(p.data, alpha: weight_decay)
          end
          if momentum != 0
            param_state = @state[p]
@@ -40,7 +40,7 @@ module Torch
              buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
            else
              buf = param_state[:momentum_buffer]
-             buf.mul!(momentum).add!(1 - dampening
+             buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
            end
            if nesterov
              d_p = d_p.add(momentum, buf)
@@ -49,7 +49,7 @@ module Torch
            end
          end
 
-         p.data.add!(-group[:lr]
+         p.data.add!(d_p, alpha: -group[:lr])
        end
      end
 