torch-rb 0.7.0 → 0.8.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +41 -4
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/ext.cpp +8 -0
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.h +7 -1
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +1 -0
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/parameter.rb +6 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +12 -0
- data/lib/torch/tensor.rb +8 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +6 -0
- metadata +18 -3
data/lib/torch/nn/transformer.rb
ADDED
@@ -0,0 +1,92 @@
+require_relative 'transformer_encoder_layer'
+require_relative 'transformer_encoder'
+require_relative 'transformer_decoder_layer'
+require_relative 'transformer_decoder'
+
+module Torch
+  module NN
+    class Transformer < Module
+      def initialize(
+        d_model: 512, nhead: 8,
+        num_encoder_layers: 6, num_decoder_layers: 6,
+        dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+        custom_encoder: nil, custom_decoder: nil,
+        layer_norm_eps: 1e-5, batch_first: false
+      )
+
+        super()
+
+        @encoder =
+          if custom_encoder
+            custom_encoder
+          else
+            encoder_layer = TransformerEncoderLayer.new(
+              d_model, nhead,
+              dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
+              layer_norm_eps: layer_norm_eps, batch_first: batch_first
+            )
+            encoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
+            TransformerEncoder.new(encoder_layer, num_encoder_layers, norm: encoder_norm)
+          end
+
+        @decoder =
+          if custom_decoder
+            custom_decoder
+          else
+            decoder_layer = TransformerDecoderLayer.new(
+              d_model, nhead,
+              dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
+              layer_norm_eps: layer_norm_eps, batch_first: batch_first
+            )
+            decoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
+            TransformerDecoder.new(decoder_layer, num_decoder_layers, norm: decoder_norm)
+          end
+
+        reset_parameters
+
+        @d_model = d_model
+        @nhead = nhead
+        @batch_first = batch_first
+      end
+
+      attr_reader :d_model, :nhead, :encoder, :decoder
+
+      def batch_first?
+        !!@batch_first
+      end
+
+      def reset_parameters
+        parameters.each { |p| Init.xavier_uniform!(p) if p.dim > 1 }
+      end
+
+      def forward(
+        src, tgt,
+        src_mask: nil, tgt_mask: nil, memory_mask: nil,
+        src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
+      )
+
+        if (!batch_first? && src.size(1) != tgt.size(1)) ||
+            (batch_first? && src.size(0) != tgt.size(0))
+
+          raise ArgumentError, "The batch number of src and tgt must be equal"
+        end
+
+        if src.size(2) != d_model || tgt.size(2) != d_model
+          raise ArgumentError, "The feature number of src and tgt must be equal to d_model"
+        end
+
+        memory = @encoder.(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
+        @decoder.(
+          tgt, memory,
+          tgt_mask: tgt_mask, memory_mask: memory_mask,
+          tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask
+        )
+      end
+
+      def generate_square_subsequent_mask(sz)
+        mask = Torch.triu(Torch.ones([sz, sz])).eq(1).transpose(0, 1)
+        mask.float.masked_fill!(mask.eq(0), -Float::INFINITY).masked_fill!(mask.eq(1), 0.0)
+      end
+    end
+  end
+end
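For context (not part of the diff): a minimal usage sketch of the new `Torch::NN::Transformer`, with arbitrary illustration sizes. It assumes the default `batch_first: false`, so inputs are shaped `(seq_len, batch_size, d_model)` as the size checks in `forward` imply.

```ruby
require "torch"

# arbitrary small sizes for illustration
model = Torch::NN::Transformer.new(
  d_model: 32, nhead: 4,
  num_encoder_layers: 2, num_decoder_layers: 2
)

# batch_first: false (default) => (seq_len, batch_size, d_model)
src = Torch.rand(10, 3, 32)
tgt = Torch.rand(7, 3, 32)

# causal mask: position i in tgt may only attend to positions <= i
tgt_mask = model.generate_square_subsequent_mask(7)

out = model.call(src, tgt, tgt_mask: tgt_mask)
p out.shape # => [7, 3, 32]
```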
data/lib/torch/nn/transformer_decoder.rb
ADDED
@@ -0,0 +1,25 @@
+module Torch
+  module NN
+    class TransformerDecoder < Module
+      def initialize(decoder_layer, num_layers, norm: nil)
+        super()
+
+        @layers = _clones(decoder_layer, num_layers)
+        @num_layers = num_layers
+        @norm = norm
+      end
+
+      def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+        output = tgt
+
+        @layers.each do |mod|
+          output = mod.call(output, memory, tgt_mask: tgt_mask, memory_mask: memory_mask, tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask)
+        end
+
+        output = @norm.call(output) if @norm
+
+        output
+      end
+    end
+  end
+end
data/lib/torch/nn/transformer_decoder_layer.rb
ADDED
@@ -0,0 +1,43 @@
+module Torch
+  module NN
+    class TransformerDecoderLayer < Module
+      def initialize(
+        d_model, n_head,
+        dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+        layer_norm_eps: 1e-5, batch_first: false
+      )
+
+        super()
+
+        @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+        @multihead_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+
+        @linear1 = Linear.new(d_model, dim_feedforward)
+        @dropout = Dropout.new(p: dropout)
+        @linear2 = Linear.new(dim_feedforward, d_model)
+
+        @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
+        @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
+        @norm3 = LayerNorm.new(d_model, eps: layer_norm_eps)
+
+        @dropout1 = Dropout.new(p: dropout)
+        @dropout2 = Dropout.new(p: dropout)
+        @dropout3 = Dropout.new(p: dropout)
+
+        @activation = _activation_fn(activation)
+      end
+
+      def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
+        tgt2 = @self_attn.(tgt, tgt, tgt, attn_mask: tgt_mask, key_padding_mask: tgt_key_padding_mask).first
+        tgt += @dropout1.(tgt2)
+        tgt = @norm1.(tgt)
+        tgt2 = @multihead_attn.(tgt, memory, memory, attn_mask: memory_mask, key_padding_mask: memory_key_padding_mask).first
+        tgt += @dropout2.(tgt2)
+        tgt = @norm2.(tgt)
+        tgt2 = @linear2.(@dropout.(@activation.(@linear1.(tgt))))
+        tgt += @dropout3.(tgt2)
+        @norm3.(tgt)
+      end
+    end
+  end
+end
data/lib/torch/nn/transformer_encoder.rb
ADDED
@@ -0,0 +1,25 @@
+module Torch
+  module NN
+    class TransformerEncoder < Module
+      def initialize(encoder_layer, num_layers, norm: nil)
+        super()
+
+        @layers = _clones(encoder_layer, num_layers)
+        @num_layers = num_layers
+        @norm = norm
+      end
+
+      def forward(src, mask: nil, src_key_padding_mask: nil)
+        output = src
+
+        @layers.each do |mod|
+          output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
+        end
+
+        output = @norm.call(output) if @norm
+
+        output
+      end
+    end
+  end
+end
data/lib/torch/nn/transformer_encoder_layer.rb
ADDED
@@ -0,0 +1,36 @@
+module Torch
+  module NN
+    class TransformerEncoderLayer < Module
+      def initialize(
+        d_model, n_head,
+        dim_feedforward: 2048, dropout: 0.1, activation: :relu,
+        layer_norm_eps: 1e-5, batch_first: false
+      )
+
+        super()
+
+        @self_attn = MultiheadAttention.new(d_model, n_head, dropout: dropout, batch_first: batch_first)
+        @linear1 = Linear.new(d_model, dim_feedforward)
+        @dropout = Dropout.new(p: dropout)
+        @linear2 = Linear.new(dim_feedforward, d_model)
+
+        @norm1 = LayerNorm.new(d_model, eps: layer_norm_eps)
+        @norm2 = LayerNorm.new(d_model, eps: layer_norm_eps)
+
+        @dropout1 = Dropout.new(p: dropout)
+        @dropout2 = Dropout.new(p: dropout)
+
+        @activation = _activation_fn(activation)
+      end
+
+      def forward(src, src_mask: nil, src_key_padding_mask: nil)
+        src2 = @self_attn.(src, src, src, attn_mask: src_mask, key_padding_mask: src_key_padding_mask).first
+        src += @dropout1.(src2)
+        src = @norm1.(src)
+        src2 = @linear2.(@dropout.(@activation.(@linear1.(src))))
+        src += @dropout2.(src2)
+        @norm2.(src)
+      end
+    end
+  end
+end
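The encoder pieces can also be stacked on their own. A sketch (again with arbitrary sizes, not part of the diff) that mirrors what `Transformer#initialize` does internally:

```ruby
require "torch"

layer = Torch::NN::TransformerEncoderLayer.new(32, 4, dim_feedforward: 64)
norm = Torch::NN::LayerNorm.new(32)
encoder = Torch::NN::TransformerEncoder.new(layer, 2, norm: norm)

src = Torch.rand(10, 3, 32) # (seq_len, batch_size, d_model)
out = encoder.call(src)
p out.shape # => [10, 3, 32]
```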
data/lib/torch/nn/utils.rb
CHANGED
@@ -20,6 +20,18 @@ module Torch
       def _ntuple(n, value)
         value.is_a?(Array) ? value : [value] * n
       end
+
+      def _clones(mod, n)
+        ModuleList.new(n.times.map { mod.deep_dup })
+      end
+
+      def _activation_fn(activation)
+        case activation.to_sym
+        when :relu then F.method(:relu)
+        when :gelu then F.method(:gelu)
+        else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
+        end
+      end
     end
   end
 end
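A short sketch (not part of the diff) of what these helpers do: `_activation_fn` returns a bound `Method` for `F.relu` or `F.gelu`, and `_clones` deep-duplicates a layer into a `ModuleList` so the stacked copies do not share parameters.

```ruby
# what _activation_fn(:relu) returns: a Method object for F.relu
act = Torch::NN::Functional.method(:relu)
x = Torch.tensor([-1.0, 2.0])
p act.call(x).to_a # => [0.0, 2.0]

# _clones(layer, n) is roughly:
#   ModuleList.new(n.times.map { layer.deep_dup })
# i.e. n independent copies, each with its own parameters
```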
data/lib/torch/tensor.rb
CHANGED
@@ -19,6 +19,8 @@ module Torch
     alias_method :&, :logical_and
     alias_method :|, :logical_or
     alias_method :^, :logical_xor
+    alias_method :<<, :__lshift__
+    alias_method :>>, :__rshift__
 
     def self.new(*args)
       FloatTensor.new(*args)
@@ -183,5 +185,11 @@ module Torch
     def stft(*args)
       Torch.stft(*args)
     end
+
+    def dup
+      Torch.no_grad do
+        clone
+      end
+    end
   end
 end
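A quick sketch (not part of the diff) of the new tensor methods: `<<`/`>>` delegate to the bitwise shift functions, and `dup` takes a `clone` inside `Torch.no_grad`, so the copy is detached from the autograd graph.

```ruby
x = Torch.tensor([1, 2, 4])
p((x << 1).to_a) # => [2, 4, 8]
p((x >> 1).to_a) # => [0, 1, 2]

w = Torch.ones(2, requires_grad: true)
copy = w.dup          # clone taken under Torch.no_grad
p copy.requires_grad  # => false
```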
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -39,6 +39,7 @@ require "torch/nn/utils"
 
 # nn containers
 require "torch/nn/module"
+require "torch/nn/module_list"
 require "torch/nn/sequential"
 
 # nn convolution layers
@@ -143,6 +144,10 @@ require "torch/nn/softmin"
 require "torch/nn/embedding"
 require "torch/nn/embedding_bag"
 
+# attention is all you need
+require "torch/nn/multihead_attention"
+require "torch/nn/transformer"
+
 # nn distance functions
 require "torch/nn/cosine_similarity"
 require "torch/nn/pairwise_distance"
@@ -174,6 +179,7 @@ require "torch/nn/upsample"
 
 # nn other
 require "torch/nn/functional"
+require "torch/nn/functional_attention"
 require "torch/nn/init"
 
 # utils
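The new requires also expose `MultiheadAttention` directly. A sketch (not part of the diff, sizes arbitrary): the constructor takes `embed_dim` and `num_heads` positionally as the layer code above does, and the call returns a pair whose first element is the attention output (the layers take `.first`); treating the second element as attention weights mirrors PyTorch and is an assumption here.

```ruby
require "torch"

mha = Torch::NN::MultiheadAttention.new(32, 4) # embed_dim, num_heads

# default batch_first: false => (seq_len, batch_size, embed_dim)
query = Torch.rand(5, 3, 32)
key   = Torch.rand(7, 3, 32)
value = Torch.rand(7, 3, 32)

attn_output, attn_weights = mha.call(query, key, value)
p attn_output.shape # => [5, 3, 32]
```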
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.7.0
+  version: 0.8.3
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2021-
+date: 2021-10-17 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -37,16 +37,23 @@ files:
 - codegen/function.rb
 - codegen/generate_functions.rb
 - codegen/native_functions.yaml
+- ext/torch/backends.cpp
 - ext/torch/cuda.cpp
 - ext/torch/device.cpp
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
+- ext/torch/fft.cpp
+- ext/torch/fft_functions.h
 - ext/torch/ivalue.cpp
+- ext/torch/linalg.cpp
+- ext/torch/linalg_functions.h
 - ext/torch/nn.cpp
 - ext/torch/nn_functions.h
 - ext/torch/random.cpp
 - ext/torch/ruby_arg_parser.cpp
 - ext/torch/ruby_arg_parser.h
+- ext/torch/special.cpp
+- ext/torch/special_functions.h
 - ext/torch/templates.h
 - ext/torch/tensor.cpp
 - ext/torch/tensor_functions.h
@@ -99,6 +106,7 @@ files:
 - lib/torch/nn/feature_alpha_dropout.rb
 - lib/torch/nn/fold.rb
 - lib/torch/nn/functional.rb
+- lib/torch/nn/functional_attention.rb
 - lib/torch/nn/group_norm.rb
 - lib/torch/nn/gru.rb
 - lib/torch/nn/hardshrink.rb
@@ -132,10 +140,12 @@ files:
 - lib/torch/nn/max_unpool3d.rb
 - lib/torch/nn/max_unpoolnd.rb
 - lib/torch/nn/module.rb
+- lib/torch/nn/module_list.rb
 - lib/torch/nn/mse_loss.rb
 - lib/torch/nn/multi_label_margin_loss.rb
 - lib/torch/nn/multi_label_soft_margin_loss.rb
 - lib/torch/nn/multi_margin_loss.rb
+- lib/torch/nn/multihead_attention.rb
 - lib/torch/nn/nll_loss.rb
 - lib/torch/nn/pairwise_distance.rb
 - lib/torch/nn/parameter.rb
@@ -163,6 +173,11 @@ files:
 - lib/torch/nn/softsign.rb
 - lib/torch/nn/tanh.rb
 - lib/torch/nn/tanhshrink.rb
+- lib/torch/nn/transformer.rb
+- lib/torch/nn/transformer_decoder.rb
+- lib/torch/nn/transformer_decoder_layer.rb
+- lib/torch/nn/transformer_encoder.rb
+- lib/torch/nn/transformer_encoder_layer.rb
 - lib/torch/nn/triplet_margin_loss.rb
 - lib/torch/nn/unfold.rb
 - lib/torch/nn/upsample.rb
@@ -212,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.2.
+rubygems_version: 3.2.22
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch