torchtext 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +41 -8
- data/lib/torchtext/data/functional.rb +11 -0
- data/lib/torchtext/data/metrics.rb +68 -0
- data/lib/torchtext/datasets/text_classification.rb +1 -1
- data/lib/torchtext/nn/in_proj_container.rb +16 -0
- data/lib/torchtext/nn/multihead_attention_container.rb +50 -0
- data/lib/torchtext/nn/scaled_dot_product.rb +72 -0
- data/lib/torchtext/version.rb +1 -1
- data/lib/torchtext.rb +10 -5
- metadata +17 -54
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f9f88b060fcb69a0df1258c1a5399b2518ca9b1c95cd3c3f999bd81f9df0cc41
|
4
|
+
data.tar.gz: 3b04fab3f8e9a3ef86e3c46aa64f2ba98a8d85c3dd42889cb728d0c7f39e7be2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4669a0ce275ec3ef0a1331c39835b81a8cff582c62012adc62ae8c59ea941befb53cf232198922bcbdd9795930e970a4e8970a851d66969531c75dd83e7bd40c
|
7
|
+
data.tar.gz: a3bb8659e334540849a011a0f526136650a4bd88d5a9da6b48f6e3d14dd889b08fe730c3e517837e6989e06c9a8b515d8a34958e3df267a80eedd465a239123f
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -1,13 +1,15 @@
|
|
1
|
-
# TorchText
|
1
|
+
# TorchText Ruby
|
2
2
|
|
3
3
|
:fire: Data loaders and abstractions for text and NLP - for Ruby
|
4
4
|
|
5
|
+
[Build Status](https://github.com/ankane/torchtext-ruby/actions)
|
6
|
+
|
5
7
|
## Installation
|
6
8
|
|
7
9
|
Add this line to your application’s Gemfile:
|
8
10
|
|
9
11
|
```ruby
|
10
|
-
gem
|
12
|
+
gem "torchtext"
|
11
13
|
```
|
12
14
|
|
13
15
|
## Getting Started
|
@@ -19,7 +21,7 @@ This library follows the [Python API](https://pytorch.org/text/). Many methods a
|
|
19
21
|
Text classification
|
20
22
|
|
21
23
|
- [PyTorch tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)
|
22
|
-
- [Ruby code](examples/text_classification)
|
24
|
+
- [Ruby code](examples/text_classification.rb)
|
23
25
|
|
24
26
|
## Datasets
|
25
27
|
|
@@ -33,6 +35,37 @@ Supported datasets are:
|
|
33
35
|
|
34
36
|
- [AG_NEWS](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
|
35
37
|
|
38
|
+
## Data Utils
|
39
|
+
|
40
|
+
Supports:
|
41
|
+
|
42
|
+
- tokenizer
|
43
|
+
- ngrams_iterator
|
44
|
+
|
45
|
+
## Data Metrics
|
46
|
+
|
47
|
+
Compute the BLEU score
|
48
|
+
|
49
|
+
```ruby
|
50
|
+
candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
|
51
|
+
references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
|
52
|
+
TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
|
53
|
+
```
|
54
|
+
|
55
|
+
## NN
|
56
|
+
|
57
|
+
Supports:
|
58
|
+
|
59
|
+
- InProjContainer
|
60
|
+
- MultiheadAttentionContainer
|
61
|
+
- ScaledDotProduct
|
62
|
+
|
63
|
+
## Vocab
|
64
|
+
|
65
|
+
Supports:
|
66
|
+
|
67
|
+
- Vocab
|
68
|
+
|
36
69
|
## Disclaimer
|
37
70
|
|
38
71
|
This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
|
@@ -41,22 +74,22 @@ If you’re a dataset owner and wish to update any details or remove it from thi
|
|
41
74
|
|
42
75
|
## History
|
43
76
|
|
44
|
-
View the [changelog](https://github.com/ankane/torchtext/blob/master/CHANGELOG.md)
|
77
|
+
View the [changelog](https://github.com/ankane/torchtext-ruby/blob/master/CHANGELOG.md)
|
45
78
|
|
46
79
|
## Contributing
|
47
80
|
|
48
81
|
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
49
82
|
|
50
|
-
- [Report bugs](https://github.com/ankane/torchtext/issues)
|
51
|
-
- Fix bugs and [submit pull requests](https://github.com/ankane/torchtext/pulls)
|
83
|
+
- [Report bugs](https://github.com/ankane/torchtext-ruby/issues)
|
84
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/torchtext-ruby/pulls)
|
52
85
|
- Write, clarify, or fix documentation
|
53
86
|
- Suggest or add new features
|
54
87
|
|
55
88
|
To get started with development:
|
56
89
|
|
57
90
|
```sh
|
58
|
-
git clone https://github.com/ankane/torchtext.git
|
59
|
-
cd torchtext
|
91
|
+
git clone https://github.com/ankane/torchtext-ruby.git
|
92
|
+
cd torchtext-ruby
|
60
93
|
bundle install
|
61
94
|
bundle exec rake test
|
62
95
|
```
|
@@ -0,0 +1,68 @@
|
|
1
|
+
module TorchText
  module Data
    module Metrics
      class << self
        # Computes the corpus-level BLEU score of candidate token sequences
        # against one or more reference token sequences per candidate.
        #
        # @param candidate_corpus [Array<Array<String>>] tokenized candidates
        # @param references_corpus [Array<Array<Array<String>>>] for each
        #   candidate, a list of tokenized references
        # @param max_n [Integer] largest n-gram order to count
        # @param weights [Array<Numeric>] one weight per n-gram order; its
        #   length must equal max_n
        # @return [Float] the BLEU score; 0.0 when some n-gram order has no
        #   clipped matches at all
        # @raise [RuntimeError] if weights.length != max_n, or if the two
        #   corpora have different lengths
        def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
          unless max_n == weights.length
            raise "Length of the \"weights\" list has to be equal to max_n"
          end
          unless candidate_corpus.length == references_corpus.length
            raise "The length of candidate and reference corpus should be the same"
          end

          # Index i holds the counts for (i + 1)-grams
          clipped_counts = Torch.zeros(max_n)
          total_counts = Torch.zeros(max_n)
          weights = Torch.tensor(weights)

          candidate_len = 0.0
          refs_len = 0.0

          candidate_corpus.zip(references_corpus) do |candidate, refs|
            candidate_len += candidate.length

            # Brevity penalty uses the reference whose length is closest to the candidate's
            refs_len += refs.map { |ref| ref.length.to_f }.min_by { |x| (candidate.length - x).abs }

            # For each n-gram, the maximum count it reaches in any single reference
            reference_counters = refs.map { |ref| compute_ngram_counter(ref, max_n) }.reduce do |acc, counter|
              acc.merge(counter) { |_, v1, v2| [v1, v2].max }
            end

            compute_ngram_counter(candidate, max_n).each do |ngram, count|
              order = ngram.length - 1
              ref_count = reference_counters[ngram]
              # Clip candidate counts by the reference maximum (only for n-grams the references contain)
              clipped_counts[order] += [count, ref_count].min if ref_count
              total_counts[order] += count
            end
          end

          if clipped_counts.to_a.min == 0
            0.0
          else
            pn = clipped_counts / total_counts
            log_pn = weights * Torch.log(pn)
            score = Torch.exp(log_pn.sum)

            bp = Math.exp([1 - refs_len / candidate_len, 0].min)

            bp * score.item
          end
        end

        private

        # Counts every n-gram of order 1..max_n in tokens.
        #
        # @return [Hash{Array<String> => Integer}] n-gram (as a token array) => occurrences
        # @raise [RuntimeError] unless max_n > 0
        def compute_ngram_counter(tokens, max_n)
          raise "Failed assert" unless max_n > 0
          # ngrams_iterator yields space-joined n-grams; split back into token arrays
          TorchText::Data::Utils.ngrams_iterator(tokens, max_n).map { |x| x.split(" ") }.tally
        end
      end
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module TorchText
  module NN
    # Holds the three input projections applied to query, key and value
    # before multi-head attention.
    class InProjContainer < Torch::NN::Module
      # @param query_proj callable applied to the query
      # @param key_proj callable applied to the key
      # @param value_proj callable applied to the value
      def initialize(query_proj, key_proj, value_proj)
        super()
        # NOTE(review): ivar names kept unchanged; Torch::NN::Module may derive
        # submodule/parameter names from them — confirm before renaming.
        @query_proj = query_proj
        @key_proj = key_proj
        @value_proj = value_proj
      end

      # Runs each input through its own projection, in query, key, value order.
      #
      # @return [Array] [projected_query, projected_key, projected_value]
      def forward(query, key, value)
        projections = [@query_proj, @key_proj, @value_proj]
        inputs = [query, key, value]
        projections.zip(inputs).map { |projection, input| projection.call(input) }
      end
    end
  end
end
|
@@ -0,0 +1,50 @@
|
|
1
|
+
module TorchText
  module NN
    # Multi-head attention assembled from pluggable parts: an input-projection
    # container, an attention layer, and an output projection.
    class MultiheadAttentionContainer < Torch::NN::Module
      # @param nhead [Integer] number of attention heads
      # @param in_proj_container callable taking (query, key, value) and
      #   returning [q, k, v]
      # @param attention_layer callable taking (q, k, v, attn_mask:, bias_k:, bias_v:)
      #   and returning [attn_output, attn_output_weights]
      # @param out_proj callable applied to the merged attention output
      # @param batch_first [Boolean] when true, dims -3 and -2 of the inputs are
      #   swapped on entry and of the output on exit
      def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
        super()
        @nhead = nhead
        @in_proj_container = in_proj_container
        @attention_layer = attention_layer
        @out_proj = out_proj
        @batch_first = batch_first
      end

      # @return [Array] [attn_output, attn_output_weights]
      # @raise [RuntimeError] if the projected query/key/value embedding size is
      #   not divisible by nhead
      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
        if @batch_first
          query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
        end

        tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
        q, k, v = @in_proj_container.call(query, key, value)
        q = split_heads(q, tgt_len, bsz, "query")
        k = split_heads(k, src_len, bsz, "key")
        v = split_heads(v, src_len, bsz, "value")

        attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
        # Merge the heads back together before the output projection
        attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
        attn_output = @out_proj.call(attn_output)

        attn_output = attn_output.transpose(-3, -2) if @batch_first

        [attn_output, attn_output_weights]
      end

      private

      # Reshapes a projected tensor to (seq_len, bsz * nhead, head_dim).
      #
      # @param name [String] label used in the error message ("query", "key", "value")
      # @raise [RuntimeError] if the last dim is not divisible by the number of heads
      def split_heads(tensor, seq_len, bsz, name)
        unless tensor.size(-1) % @nhead == 0
          raise "#{name}'s embed_dim must be divisible by the number of heads"
        end
        head_dim = tensor.size(-1).div(@nhead)
        tensor.reshape(seq_len, bsz * @nhead, head_dim)
      end
    end
  end
end
|
@@ -0,0 +1,72 @@
|
|
1
|
+
module TorchText
  module NN
    # Scaled dot-product attention: softmax(query · keyᵀ / sqrt(head_dim)) · value,
    # with an optional boolean mask, optional key/value bias rows, and dropout
    # on the attention weights.
    class ScaledDotProduct < Torch::NN::Module
      # @param dropout [Float] dropout probability applied to the attention weights
      # @param batch_first [Boolean] selects which of dims -3/-2 the inputs use for
      #   the sequence; the output is returned in the same layout as the inputs
      def initialize(dropout: 0.0, batch_first: false)
        super()
        @dropout = dropout
        @batch_first = batch_first
      end

      # @param attn_mask optional 3D bool tensor; true entries are masked out
      # @param bias_k optional bias appended to key along dim 0 (requires bias_v)
      # @param bias_v optional bias appended to value along dim 0 (requires bias_k)
      # @return [Array] [attn_output, attn_output_weights]
      # @raise [RuntimeError] on incompatible shapes, or an invalid attn_mask
      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
        if @batch_first
          query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
        end

        if bias_k && bias_v
          unless key.size(-1) == bias_k.size(-1) && key.size(-2) == bias_k.size(-2) && bias_k.size(-3) == 1
            raise "Shape of bias_k is not supported"
          end
          unless value.size(-1) == bias_v.size(-1) && value.size(-2) == bias_v.size(-2) && bias_v.size(-3) == 1
            raise "Shape of bias_v is not supported"
          end
          key = Torch.cat([key, bias_k])
          value = Torch.cat([value, bias_v])
          # The bias row adds one source position; widen the mask's last dim to match
          attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1]) if attn_mask
        end

        tgt_len, head_dim = query.size(-3), query.size(-1)
        unless query.size(-1) == key.size(-1) && key.size(-1) == value.size(-1)
          raise "The feature dim of query, key, value must be equal."
        end
        raise "Shape of key, value must match" unless key.size == value.size
        src_len = key.size(-3)
        batch_heads = [query.size(-2), key.size(-2)].max

        # Scale query
        query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
        query = query * (head_dim.to_f ** -0.5)
        validate_attn_mask(attn_mask, src_len, tgt_len, batch_heads) if attn_mask

        # Dot product of q, k
        attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
        # Masked positions get a large negative score before softmax
        attn_output_weights.masked_fill!(attn_mask, -1e8) if attn_mask
        attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
        attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
        attn_output = Torch.matmul(attn_output_weights, value)

        if @batch_first
          [attn_output, attn_output_weights]
        else
          [attn_output.transpose(-3, -2), attn_output_weights]
        end
      end

      private

      # Ensures attn_mask is a 3D bool tensor whose last two dims are
      # (tgt_len, src_len) and whose first dim is 1 or batch_heads.
      #
      # @raise [RuntimeError] describing the first violated constraint
      def validate_attn_mask(attn_mask, src_len, tgt_len, batch_heads)
        raise RuntimeError, "attn_mask must be a 3D tensor." if attn_mask.dim != 3
        if attn_mask.size(-1) != src_len || attn_mask.size(-2) != tgt_len ||
           (attn_mask.size(-3) != 1 && attn_mask.size(-3) != batch_heads)
          raise RuntimeError, "The size of the attn_mask is not correct."
        end
        raise RuntimeError, "Only bool tensor is supported for attn_mask" if attn_mask.dtype != :bool
      end
    end
  end
end
|
data/lib/torchtext/version.rb
CHANGED
data/lib/torchtext.rb
CHANGED
@@ -8,11 +8,16 @@ require "rubygems/package"
|
|
8
8
|
require "set"
|
9
9
|
|
10
10
|
# modules
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
11
|
+
require_relative "torchtext/data/functional"
|
12
|
+
require_relative "torchtext/data/metrics"
|
13
|
+
require_relative "torchtext/data/utils"
|
14
|
+
require_relative "torchtext/datasets/text_classification"
|
15
|
+
require_relative "torchtext/datasets/text_classification_dataset"
|
16
|
+
require_relative "torchtext/nn/in_proj_container"
|
17
|
+
require_relative "torchtext/nn/multihead_attention_container"
|
18
|
+
require_relative "torchtext/nn/scaled_dot_product"
|
19
|
+
require_relative "torchtext/vocab"
|
20
|
+
require_relative "torchtext/version"
|
16
21
|
|
17
22
|
module TorchText
|
18
23
|
class Error < StandardError; end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchtext
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-01-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: torch-rb
|
@@ -16,58 +16,16 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: 0.
|
19
|
+
version: 0.11.1
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: 0.
|
27
|
-
|
28
|
-
|
29
|
-
requirement: !ruby/object:Gem::Requirement
|
30
|
-
requirements:
|
31
|
-
- - ">="
|
32
|
-
- !ruby/object:Gem::Version
|
33
|
-
version: '0'
|
34
|
-
type: :development
|
35
|
-
prerelease: false
|
36
|
-
version_requirements: !ruby/object:Gem::Requirement
|
37
|
-
requirements:
|
38
|
-
- - ">="
|
39
|
-
- !ruby/object:Gem::Version
|
40
|
-
version: '0'
|
41
|
-
- !ruby/object:Gem::Dependency
|
42
|
-
name: rake
|
43
|
-
requirement: !ruby/object:Gem::Requirement
|
44
|
-
requirements:
|
45
|
-
- - ">="
|
46
|
-
- !ruby/object:Gem::Version
|
47
|
-
version: '0'
|
48
|
-
type: :development
|
49
|
-
prerelease: false
|
50
|
-
version_requirements: !ruby/object:Gem::Requirement
|
51
|
-
requirements:
|
52
|
-
- - ">="
|
53
|
-
- !ruby/object:Gem::Version
|
54
|
-
version: '0'
|
55
|
-
- !ruby/object:Gem::Dependency
|
56
|
-
name: minitest
|
57
|
-
requirement: !ruby/object:Gem::Requirement
|
58
|
-
requirements:
|
59
|
-
- - ">="
|
60
|
-
- !ruby/object:Gem::Version
|
61
|
-
version: '5'
|
62
|
-
type: :development
|
63
|
-
prerelease: false
|
64
|
-
version_requirements: !ruby/object:Gem::Requirement
|
65
|
-
requirements:
|
66
|
-
- - ">="
|
67
|
-
- !ruby/object:Gem::Version
|
68
|
-
version: '5'
|
69
|
-
description:
|
70
|
-
email: andrew@chartkick.com
|
26
|
+
version: 0.11.1
|
27
|
+
description:
|
28
|
+
email: andrew@ankane.org
|
71
29
|
executables: []
|
72
30
|
extensions: []
|
73
31
|
extra_rdoc_files: []
|
@@ -76,16 +34,21 @@ files:
|
|
76
34
|
- LICENSE.txt
|
77
35
|
- README.md
|
78
36
|
- lib/torchtext.rb
|
37
|
+
- lib/torchtext/data/functional.rb
|
38
|
+
- lib/torchtext/data/metrics.rb
|
79
39
|
- lib/torchtext/data/utils.rb
|
80
40
|
- lib/torchtext/datasets/text_classification.rb
|
81
41
|
- lib/torchtext/datasets/text_classification_dataset.rb
|
42
|
+
- lib/torchtext/nn/in_proj_container.rb
|
43
|
+
- lib/torchtext/nn/multihead_attention_container.rb
|
44
|
+
- lib/torchtext/nn/scaled_dot_product.rb
|
82
45
|
- lib/torchtext/version.rb
|
83
46
|
- lib/torchtext/vocab.rb
|
84
|
-
homepage: https://github.com/ankane/torchtext
|
47
|
+
homepage: https://github.com/ankane/torchtext-ruby
|
85
48
|
licenses:
|
86
49
|
- BSD-3-Clause
|
87
50
|
metadata: {}
|
88
|
-
post_install_message:
|
51
|
+
post_install_message:
|
89
52
|
rdoc_options: []
|
90
53
|
require_paths:
|
91
54
|
- lib
|
@@ -93,15 +56,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
93
56
|
requirements:
|
94
57
|
- - ">="
|
95
58
|
- !ruby/object:Gem::Version
|
96
|
-
version: '2.
|
59
|
+
version: '2.7'
|
97
60
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
98
61
|
requirements:
|
99
62
|
- - ">="
|
100
63
|
- !ruby/object:Gem::Version
|
101
64
|
version: '0'
|
102
65
|
requirements: []
|
103
|
-
rubygems_version: 3.1
|
104
|
-
signing_key:
|
66
|
+
rubygems_version: 3.4.1
|
67
|
+
signing_key:
|
105
68
|
specification_version: 4
|
106
69
|
summary: Data loaders and abstractions for text and NLP
|
107
70
|
test_files: []
|