torchtext 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +34 -1
- data/lib/torchtext.rb +4 -0
- data/lib/torchtext/data/metrics.rb +68 -0
- data/lib/torchtext/datasets/text_classification.rb +1 -1
- data/lib/torchtext/nn/in_proj_container.rb +16 -0
- data/lib/torchtext/nn/multihead_attention_container.rb +50 -0
- data/lib/torchtext/nn/scaled_dot_product.rb +72 -0
- data/lib/torchtext/version.rb +1 -1
- metadata +11 -7
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: d5c6ff21a492fa03ce88ed7b823d29310e67b3d0c86825ccd2b6092f330bdbc1
+  data.tar.gz: 7df74bb05bb110ae37e9d962adc9dc020088203ed1b9e920fe3c4e9421f8e714
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: e2d0da285afa14f72380ad688ce63c316a6e46e0cceb7f332b15d298253ffe05d9f6bf9244aa630bc20719a794fe6b316ac7861788e7775ac104e7741dcf167b
+  data.tar.gz: '0202023381da0eefdd6c27587ad1798e9052a8798d3d32f8e71b19d7a5303aaf5a01f16c77d486d1d8249ef89cf4009f146855f6e3d1f4cf2c7e3fc9e66ce830'
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,6 +2,8 @@
 
 :fire: Data loaders and abstractions for text and NLP - for Ruby
 
+[](https://github.com/ankane/torchtext/actions)
+
 ## Installation
 
 Add this line to your application’s Gemfile:
@@ -19,7 +21,7 @@ This library follows the [Python API](https://pytorch.org/text/). Many methods a
 Text classification
 
 - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)
-- [Ruby code](examples/text_classification)
+- [Ruby code](examples/text_classification.rb)
 
 ## Datasets
 
@@ -33,6 +35,37 @@ Supported datasets are:
 
 - [AG_NEWS](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
 
+## Data Utils
+
+Supports:
+
+- tokenizer
+- ngrams_iterator
+
+## Data Metrics
+
+Compute the BLEU score
+
+```ruby
+candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
+references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
+TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
+```
+
+## NN
+
+Supports:
+
+- InProjContainer
+- MultiheadAttentionContainer
+- ScaledDotProduct
+
+## Vocab
+
+Supports:
+
+- Vocab
+
 ## Disclaimer
 
 This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
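As a quick illustration of the "Data Utils" entries listed above, here is a minimal sketch. The `TorchText::Data::Utils.tokenizer("basic_english")` factory name is an assumption modeled on the Python API; `TorchText::Data::Utils.ngrams_iterator(tokens, max_n)` is the call the new metrics.rb file below actually uses, and it yields space-joined n-gram strings.

```ruby
require "torchtext"

# assumed helper: returns a callable basic English tokenizer (Python API analogue)
tokenizer = TorchText::Data::Utils.tokenizer("basic_english")
tokens = tokenizer.call("You can now install TorchText for Ruby")

# grounded by metrics.rb: enumerate unigrams and bigrams as strings
TorchText::Data::Utils.ngrams_iterator(tokens, 2).each { |ngram| puts ngram }
```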
data/lib/torchtext.rb
CHANGED
@@ -9,8 +9,12 @@ require "set"
 
 # modules
 require "torchtext/data/utils"
+require "torchtext/data/metrics"
 require "torchtext/datasets/text_classification"
 require "torchtext/datasets/text_classification_dataset"
+require "torchtext/nn/in_proj_container"
+require "torchtext/nn/multihead_attention_container"
+require "torchtext/nn/scaled_dot_product"
 require "torchtext/vocab"
 require "torchtext/version"
 
data/lib/torchtext/data/metrics.rb
ADDED
@@ -0,0 +1,68 @@
+module TorchText
+  module Data
+    module Metrics
+      class << self
+        def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
+          unless max_n == weights.length
+            raise "Length of the \"weights\" list has be equal to max_n"
+          end
+          unless candidate_corpus.length == references_corpus.length
+            raise "The length of candidate and reference corpus should be the same"
+          end
+
+          clipped_counts = Torch.zeros(max_n)
+          total_counts = Torch.zeros(max_n)
+          weights = Torch.tensor(weights)
+
+          candidate_len = 0.0
+          refs_len = 0.0
+
+          candidate_corpus.zip(references_corpus) do |candidate, refs|
+            candidate_len += candidate.length
+
+            # Get the length of the reference that's closest in length to the candidate
+            refs_len_list = refs.map { |ref| ref.length.to_f }
+            refs_len += refs_len_list.min_by { |x| (candidate.length - x).abs }
+
+            reference_counters = compute_ngram_counter(refs[0], max_n)
+            refs[1..-1].each do |ref|
+              reference_counters = reference_counters.merge(compute_ngram_counter(ref, max_n)) { |_, v1, v2| v1 > v2 ? v1 : v2 }
+            end
+
+            candidate_counter = compute_ngram_counter(candidate, max_n)
+
+            shared_keys = candidate_counter.keys & reference_counters.keys
+            clipped_counter = candidate_counter.slice(*shared_keys).merge(reference_counters.slice(*shared_keys)) { |_, v1, v2| v1 < v2 ? v1 : v2 }
+
+            clipped_counter.each_key do |ngram|
+              clipped_counts[ngram.length - 1] += clipped_counter[ngram]
+            end
+
+            candidate_counter.each_key do |ngram|
+              total_counts[ngram.length - 1] += candidate_counter[ngram]
+            end
+          end
+
+          if clipped_counts.to_a.min == 0
+            0.0
+          else
+            pn = clipped_counts / total_counts
+            log_pn = weights * Torch.log(pn)
+            score = Torch.exp(log_pn.sum)
+
+            bp = Math.exp([1 - refs_len / candidate_len, 0].min)
+
+            bp * score.item
+          end
+        end
+
+        private
+
+        def compute_ngram_counter(tokens, max_n)
+          raise "Failed assert" unless max_n > 0
+          Hash[TorchText::Data::Utils.ngrams_iterator(tokens, max_n).map { |x| x.split(" ") }.group_by { |v| v }.map { |k, v| [k, v.size] }]
+        end
+      end
+    end
+  end
+end
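The README example earlier uses the default arguments; the signature above also exposes `max_n` and `weights`, which must have matching lengths or the method raises. A small sketch with illustrative values:

```ruby
# Bigram-only BLEU: max_n must equal weights.length
candidate_corpus = [["the", "cat", "sat"]]
references_corpus = [[["the", "cat", "sat", "down"]]]

TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus, max_n: 2, weights: [0.5, 0.5])
```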
data/lib/torchtext/nn/in_proj_container.rb
ADDED
@@ -0,0 +1,16 @@
+module TorchText
+  module NN
+    class InProjContainer < Torch::NN::Module
+      def initialize(query_proj, key_proj, value_proj)
+        super()
+        @query_proj = query_proj
+        @key_proj = key_proj
+        @value_proj = value_proj
+      end
+
+      def forward(query, key, value)
+        [@query_proj.call(query), @key_proj.call(key), @value_proj.call(value)]
+      end
+    end
+  end
+end
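`InProjContainer` simply applies one projection per input and returns the triple, so any callable torch-rb modules work. A minimal sketch, assuming `Torch::NN::Linear` projections and an illustrative embed_dim of 10:

```ruby
embed_dim = 10

container = TorchText::NN::InProjContainer.new(
  Torch::NN::Linear.new(embed_dim, embed_dim),
  Torch::NN::Linear.new(embed_dim, embed_dim),
  Torch::NN::Linear.new(embed_dim, embed_dim)
)

# shapes: (seq_len, batch_size, embed_dim)
query = Torch.rand(21, 16, embed_dim)
key = value = Torch.rand(16, 16, embed_dim)
q, k, v = container.call(query, key, value)
```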
data/lib/torchtext/nn/multihead_attention_container.rb
ADDED
@@ -0,0 +1,50 @@
+module TorchText
+  module NN
+    class MultiheadAttentionContainer < Torch::NN::Module
+      def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
+        super()
+        @nhead = nhead
+        @in_proj_container = in_proj_container
+        @attention_layer = attention_layer
+        @out_proj = out_proj
+        @batch_first = batch_first
+      end
+
+      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
+        if @batch_first
+          query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+        end
+
+        tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
+        q, k, v = @in_proj_container.call(query, key, value)
+        unless q.size(-1) % @nhead == 0
+          raise "query's embed_dim must be divisible by the number of heads"
+        end
+        head_dim = q.size(-1).div(@nhead)
+        q = q.reshape(tgt_len, bsz * @nhead, head_dim)
+
+        unless k.size(-1) % @nhead == 0
+          raise "key's embed_dim must be divisible by the number of heads"
+        end
+        head_dim = k.size(-1).div(@nhead)
+        k = k.reshape(src_len, bsz * @nhead, head_dim)
+
+        unless v.size(-1) % @nhead == 0
+          raise "value's embed_dim must be divisible by the number of heads"
+        end
+        head_dim = v.size(-1).div(@nhead)
+        v = v.reshape(src_len, bsz * @nhead, head_dim)
+
+        attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
+        attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
+        attn_output = @out_proj.call(attn_output)
+
+        if @batch_first
+          attn_output = attn_output.transpose(-3, -2)
+        end
+
+        [attn_output, attn_output_weights]
+      end
+    end
+  end
+end
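The container expects an in-projection, an attention layer, and an output projection; it splits the projected tensors into `nhead` heads before calling the attention layer and merges them back afterwards. A minimal assembly sketch with assumed dimensions (embed_dim divisible by nhead, inputs shaped (seq_len, batch_size, embed_dim)):

```ruby
embed_dim, nhead = 10, 5

mha = TorchText::NN::MultiheadAttentionContainer.new(
  nhead,
  TorchText::NN::InProjContainer.new(
    Torch::NN::Linear.new(embed_dim, embed_dim),
    Torch::NN::Linear.new(embed_dim, embed_dim),
    Torch::NN::Linear.new(embed_dim, embed_dim)
  ),
  TorchText::NN::ScaledDotProduct.new,
  Torch::NN::Linear.new(embed_dim, embed_dim)
)

# (tgt_len, batch_size, embed_dim) and (src_len, batch_size, embed_dim)
query = Torch.rand(21, 64, embed_dim)
key = value = Torch.rand(16, 64, embed_dim)

attn_output, attn_weights = mha.call(query, key, value)
attn_output.size # => [21, 64, 10]
```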
data/lib/torchtext/nn/scaled_dot_product.rb
ADDED
@@ -0,0 +1,72 @@
+module TorchText
+  module NN
+    class ScaledDotProduct < Torch::NN::Module
+      def initialize(dropout: 0.0, batch_first: false)
+        super()
+        @dropout = dropout
+        @batch_first = batch_first
+      end
+
+      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
+        if @batch_first
+          query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+        end
+
+        if !bias_k.nil? && !bias_v.nil?
+          unless key.size(-1) == bias_k.size(-1) && key.size(-2) == bias_k.size(-2) && bias_k.size(-3) == 1
+            raise "Shape of bias_k is not supported"
+          end
+          unless value.size(-1) == bias_v.size(-1) && value.size(-2) == bias_v.size(-2) && bias_v.size(-3) == 1
+            raise "Shape of bias_v is not supported"
+          end
+          key = Torch.cat([key, bias_k])
+          value = Torch.cat([value, bias_v])
+          if !attn_mask.nil?
+            attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1])
+          end
+        end
+
+        tgt_len, head_dim = query.size(-3), query.size(-1)
+        unless query.size(-1) == key.size(-1) && key.size(-1) == value.size(-1)
+          raise "The feature dim of query, key, value must be equal."
+        end
+        unless key.size() == value.size()
+          raise "Shape of key, value must match"
+        end
+        src_len = key.size(-3)
+        batch_heads = [query.size(-2), key.size(-2)].max
+
+        # Scale query
+        query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
+        query = query * (head_dim.to_f ** -0.5)
+        if !attn_mask.nil?
+          if attn_mask.dim() != 3
+            raise RuntimeError, "attn_mask must be a 3D tensor."
+          end
+          if (attn_mask.size(-1) != src_len) || (attn_mask.size(-2) != tgt_len) || (attn_mask.size(-3) != 1 && attn_mask.size(-3) != batch_heads)
+            raise RuntimeError, "The size of the attn_mask is not correct."
+          end
+          if attn_mask.dtype != :bool
+            raise RuntimeError, "Only bool tensor is supported for attn_mask"
+          end
+        end
+
+        # Dot product of q, k
+        attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
+        if !attn_mask.nil?
+          # TODO confirm last argument
+          attn_output_weights.masked_fill!(attn_mask, -1e8, nil)
+        end
+        attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
+        attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
+        attn_output = Torch.matmul(attn_output_weights, value)
+
+        if @batch_first
+          [attn_output, attn_output_weights]
+        else
+          [attn_output.transpose(-3, -2), attn_output_weights]
+        end
+      end
+    end
+  end
+end
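`ScaledDotProduct` can also be used on its own; its inputs are already split into heads, i.e. shaped (seq_len, batch_size * nhead, head_dim), and the constructor options control dropout and whether the batch dimension comes first. A minimal sketch with assumed shapes:

```ruby
sdp = TorchText::NN::ScaledDotProduct.new(dropout: 0.1)

# (seq_len, batch_size * nhead, head_dim)
q = Torch.rand(21, 256, 3)
k = v = Torch.rand(16, 256, 3)

attn_output, attn_weights = sdp.call(q, k, v)
attn_output.size  # => [21, 256, 3]
attn_weights.size # => [256, 21, 16]
```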
data/lib/torchtext/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torchtext
 version: !ruby/object:Gem::Version
-  version: 0.1.0
+  version: 0.1.1
 platform: ruby
 authors:
 - Andrew Kane
-autorequire:
+autorequire:
 bindir: bin
 cert_chain: []
-date:
+date: 2021-07-15 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: torch-rb
@@ -66,7 +66,7 @@ dependencies:
     - - ">="
       - !ruby/object:Gem::Version
         version: '5'
-description:
+description:
 email: andrew@chartkick.com
 executables: []
 extensions: []
@@ -76,16 +76,20 @@ files:
 - LICENSE.txt
 - README.md
 - lib/torchtext.rb
+- lib/torchtext/data/metrics.rb
 - lib/torchtext/data/utils.rb
 - lib/torchtext/datasets/text_classification.rb
 - lib/torchtext/datasets/text_classification_dataset.rb
+- lib/torchtext/nn/in_proj_container.rb
+- lib/torchtext/nn/multihead_attention_container.rb
+- lib/torchtext/nn/scaled_dot_product.rb
 - lib/torchtext/version.rb
 - lib/torchtext/vocab.rb
 homepage: https://github.com/ankane/torchtext
 licenses:
 - BSD-3-Clause
 metadata: {}
-post_install_message:
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -100,8 +104,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.
-signing_key:
+rubygems_version: 3.2.22
+signing_key:
 specification_version: 4
 summary: Data loaders and abstractions for text and NLP
 test_files: []