torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
data/ext/torch/utils.h
ADDED
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <rice/Symbol.hpp>
+
+// keep THP prefix for now to make it easier to compare code
+
+extern VALUE THPVariableClass;
+
+inline VALUE THPUtils_internSymbol(const std::string& str) {
+  return Symbol(str);
+}
+
+inline std::string THPUtils_unpackSymbol(VALUE obj) {
+  Check_Type(obj, T_SYMBOL);
+  obj = rb_funcall(obj, rb_intern("to_s"), 0);
+  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+}
+
+inline std::string THPUtils_unpackString(VALUE obj) {
+  Check_Type(obj, T_STRING);
+  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+}
+
+inline bool THPUtils_checkSymbol(VALUE obj) {
+  return SYMBOL_P(obj);
+}
+
+inline bool THPUtils_checkIndex(VALUE obj) {
+  return FIXNUM_P(obj);
+}
+
+inline bool THPUtils_checkScalar(VALUE obj) {
+  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
+}
+
+inline bool THPVariable_Check(VALUE obj) {
+  return rb_obj_is_kind_of(obj, THPVariableClass);
+}
+
+inline bool THPVariable_CheckExact(VALUE obj) {
+  return rb_obj_is_instance_of(obj, THPVariableClass);
+}
data/ext/torch/{templates.cpp → wrap_outputs.h}
RENAMED
@@ -1,43 +1,44 @@
+#pragma once
+
 #include <torch/torch.h>
 #include <rice/Object.hpp>
-#include "templates.hpp"

-Object wrap(bool x) {
+inline Object wrap(bool x) {
   return to_ruby<bool>(x);
 }

-Object wrap(int64_t x) {
+inline Object wrap(int64_t x) {
   return to_ruby<int64_t>(x);
 }

-Object wrap(double x) {
+inline Object wrap(double x) {
   return to_ruby<double>(x);
 }

-Object wrap(torch::Tensor x) {
+inline Object wrap(torch::Tensor x) {
   return to_ruby<torch::Tensor>(x);
 }

-Object wrap(torch::Scalar x) {
+inline Object wrap(torch::Scalar x) {
   return to_ruby<torch::Scalar>(x);
 }

-Object wrap(torch::ScalarType x) {
+inline Object wrap(torch::ScalarType x) {
   return to_ruby<torch::ScalarType>(x);
 }

-Object wrap(torch::QScheme x) {
+inline Object wrap(torch::QScheme x) {
   return to_ruby<torch::QScheme>(x);
 }

-Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
   return Object(a);
 }

-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
   return Object(a);
 }

-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   return Object(a);
 }

-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
   return Object(a);
 }

-Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
   return Object(a);
 }

-Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
   Array a;
   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -82,7 +83,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
   return Object(a);
 }

-Object wrap(
+inline Object wrap(torch::TensorList x) {
   Array a;
   for (auto& t : x) {
     a.push(to_ruby<torch::Tensor>(t));
data/lib/torch.rb
CHANGED
@@ -7,11 +7,6 @@ require "net/http"
 require "set"
 require "tmpdir"

-# native functions
-require "torch/native/generator"
-require "torch/native/parser"
-require "torch/native/dispatcher"
-
 # modules
 require "torch/inspector"
 require "torch/tensor"
@@ -374,63 +369,6 @@ module Torch

     # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

-    def arange(start, finish = nil, step = 1, **options)
-      # ruby doesn't support start = 0, finish, step = 1, ...
-      if finish.nil?
-        finish = start
-        start = 0
-      end
-      _arange(start, finish, step, tensor_options(**options))
-    end
-
-    def empty(*size, **options)
-      _empty(tensor_size(size), tensor_options(**options))
-    end
-
-    def eye(n, m = nil, **options)
-      _eye(n, m || n, tensor_options(**options))
-    end
-
-    def full(size, fill_value, **options)
-      _full(size, fill_value, tensor_options(**options))
-    end
-
-    def linspace(start, finish, steps = 100, **options)
-      _linspace(start, finish, steps, tensor_options(**options))
-    end
-
-    def logspace(start, finish, steps = 100, base = 10.0, **options)
-      _logspace(start, finish, steps, base, tensor_options(**options))
-    end
-
-    def ones(*size, **options)
-      _ones(tensor_size(size), tensor_options(**options))
-    end
-
-    def rand(*size, **options)
-      _rand(tensor_size(size), tensor_options(**options))
-    end
-
-    def randint(low = 0, high, size, **options)
-      _randint(low, high, size, tensor_options(**options))
-    end
-
-    def randn(*size, **options)
-      _randn(tensor_size(size), tensor_options(**options))
-    end
-
-    def randperm(n, **options)
-      # dtype hack in Python
-      # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
-      options[:dtype] ||= :int64
-
-      _randperm(n, tensor_options(**options))
-    end
-
-    def zeros(*size, **options)
-      _zeros(tensor_size(size), tensor_options(**options))
-    end
-
     def tensor(data, **options)
       if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
         numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
data/lib/torch/nn/functional.rb
CHANGED
@@ -394,15 +394,15 @@ module Torch
         # loss functions

         def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-          NN.binary_cross_entropy(input, target, weight, reduction)
+          NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
         end

         def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+          Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
         end

         def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-          Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+          Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
         end

         def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -411,34 +411,34 @@ module Torch

         def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
           # call to_a on input_lengths and target_lengths for C++
-          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
         end

         def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-          Torch.hinge_embedding_loss(input, target, margin, reduction)
+          Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
         end

         def kl_div(input, target, reduction: "mean")
-          Torch.kl_div(input, target, reduction)
+          Torch.kl_div(input, target, to_reduction(reduction))
         end

         def l1_loss(input, target, reduction: "mean")
-          NN.l1_loss(input, target, reduction)
+          NN.l1_loss(input, target, to_reduction(reduction))
         end

         def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-          Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+          Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
         end

         def mse_loss(input, target, reduction: "mean")
           if target.size != input.size
             warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
           end
-          NN.mse_loss(input, target, reduction)
+          NN.mse_loss(input, target, to_reduction(reduction))
         end

         def multilabel_margin_loss(input, target, reduction: "mean")
-          NN.multilabel_margin_loss(input, target, reduction)
+          NN.multilabel_margin_loss(input, target, to_reduction(reduction))
         end

         def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -446,27 +446,27 @@ module Torch
         end

         def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-          NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+          NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
         end

         def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-          NN.nll_loss(input, target, weight, reduction, ignore_index)
+          NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
         end

         def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
         end

         def soft_margin_loss(input, target, reduction: "mean")
-          NN.soft_margin_loss(input, target, reduction)
+          NN.soft_margin_loss(input, target, to_reduction(reduction))
         end

         def smooth_l1_loss(input, target, reduction: "mean")
-          NN.smooth_l1_loss(input, target, reduction)
+          NN.smooth_l1_loss(input, target, to_reduction(reduction))
         end

         def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+          Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
         end

         # vision
@@ -542,6 +542,20 @@ module Torch

         private

+        # see _reduction.py
+        def to_reduction(v)
+          case v.to_s
+          when "none"
+            0
+          when "mean"
+            1
+          when "sum"
+            2
+          else
+            raise ArgumentError, "#{v} is not a valid value for reduction"
+          end
+        end
+
         def softmax_dim(ndim)
           ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
         end
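The new private to_reduction helper mirrors PyTorch's _reduction.py: callers still pass reduction as a string, and the loss functions map it to the integer enum the native functions now expect ("none" => 0, "mean" => 1, "sum" => 2). A small usage sketch, with tensor values chosen only for illustration:

  require "torch"

  F = Torch::NN::Functional

  input  = Torch.tensor([1.0, 2.0, 3.0])
  target = Torch.tensor([1.0, 2.0, 2.0])

  # The string is converted internally by to_reduction before reaching C++
  F.mse_loss(input, target, reduction: "sum")   # squared errors summed => 1.0
  F.mse_loss(input, target)                     # default "mean" => ~0.3333
  # F.mse_loss(input, target, reduction: "bogus")  # would raise ArgumentError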
data/lib/torch/nn/init.rb
CHANGED
@@ -14,25 +14,11 @@ module Torch
           _normal!(tensor, mean, std)
         end

-        def constant!(tensor, val)
-          _constant!(tensor, val)
-        end
-
-        def ones!(tensor)
-          _ones!(tensor)
-        end
-
-        def zeros!(tensor)
-          _zeros!(tensor)
-        end
-
-        def eye!(tensor)
-          _eye!(tensor)
-        end
-
-        def dirac!(tensor)
-          _dirac!(tensor)
-        end
+        alias_method :constant!, :_constant!
+        alias_method :ones!, :_ones!
+        alias_method :zeros!, :_zeros!
+        alias_method :eye!, :_eye!
+        alias_method :dirac!, :_dirac!

         def xavier_uniform!(tensor, gain: 1.0)
           _xavier_uniform!(tensor, gain)
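The individual wrapper methods collapse into alias_method calls, so the public init API keeps the same names. A short sketch of calls that should continue to work, assuming the Init methods remain module-level as in 0.3.x:

  require "torch"

  t = Torch.empty(3, 3)

  # Same public names as before, now aliases of the underlying native bindings
  Torch::NN::Init.ones!(t)
  Torch::NN::Init.constant!(t, 0.5)
  Torch::NN::Init.eye!(t)
  Torch::NN::Init.zeros!(t)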
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -45,7 +45,7 @@ module Torch
             square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
             std = square_avg.add(eps).sqrt!
             delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
-            p.data.add!(-group[:lr], delta)
+            p.data.add!(delta, alpha: -group[:lr])
             acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
           end
         end
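This hunk, and the adam, adamax, adamw, and sgd hunks below, all make the same change: in-place add! calls drop the old positional scalar-first form in favor of a keyword alpha: multiplier. A minimal sketch of the new call form, with values chosen only for illustration:

  require "torch"

  a = Torch.ones(3)
  b = Torch.full([3], 2.0)

  # 0.3.x style (removed above): a.add!(2.0, b)  -- the scalar multiplier came first
  # 0.4.0 style: the multiplier is the alpha keyword, i.e. a += alpha * b
  a.add!(b, alpha: 2.0)   # a is now [5.0, 5.0, 5.0]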
data/lib/torch/optim/adam.rb
CHANGED
@@ -53,11 +53,11 @@ module Torch
             bias_correction2 = 1 - beta2 ** state[:step]

             if group[:weight_decay] != 0
-              grad.add!(group[:weight_decay], p.data)
+              grad.add!(p.data, alpha: group[:weight_decay])
             end

             # Decay the first and second moment running average coefficient
-            exp_avg.mul!(beta1).add!(1 - beta1, grad)
+            exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
             exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
             if amsgrad
               # Maintains the maximum of all 2nd moment running avg. till now
data/lib/torch/optim/adamax.rb
CHANGED
@@ -46,7 +46,7 @@ module Torch
             end

             # Update biased first moment estimate.
-            exp_avg.mul!(beta1).add!(1 - beta1, grad)
+            exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
             # Update the exponentially weighted infinity norm.
             norm_buf = Torch.cat([
               exp_inf.mul!(beta2).unsqueeze(0),
CHANGED
@@ -58,7 +58,7 @@ module Torch
|
             bias_correction2 = 1 - beta2 ** state[:step]

             # Decay the first and second moment running average coefficient
-            exp_avg.mul!(beta1).add!(1 - beta1, grad)
+            exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
             exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
             if amsgrad
               # Maintains the maximum of all 2nd moment running avg. till now
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -32,7 +32,7 @@ module Torch
             next unless p.grad
             d_p = p.grad.data
             if weight_decay != 0
-              d_p.add!(weight_decay, p.data)
+              d_p.add!(p.data, alpha: weight_decay)
             end
             if momentum != 0
               param_state = @state[p]
@@ -40,7 +40,7 @@ module Torch
                 buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
               else
                 buf = param_state[:momentum_buffer]
-                buf.mul!(momentum).add!(1 - dampening, d_p)
+                buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
               end
               if nesterov
                 d_p = d_p.add(momentum, buf)
@@ -49,7 +49,7 @@ module Torch
               end
             end

-            p.data.add!(-group[:lr], d_p)
+            p.data.add!(d_p, alpha: -group[:lr])
           end
         end