torchtext 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 86469f8148e519b940a643f81b5317d3e180d6ebc031da14cb0599b48e3f6556
4
- data.tar.gz: 499079c8a32de3ea6704b58a04ad8511f7a6784cc138b08e87696c69d7835863
3
+ metadata.gz: d5c6ff21a492fa03ce88ed7b823d29310e67b3d0c86825ccd2b6092f330bdbc1
4
+ data.tar.gz: 7df74bb05bb110ae37e9d962adc9dc020088203ed1b9e920fe3c4e9421f8e714
5
5
  SHA512:
6
- metadata.gz: e3ea0d3719d35a58b757ac3d11adeda30912f35f69f7de37047ef702c556e5384862f950e055565db8396d8495a760b4919fd416affbdc0fd815dc14ed02e3a3
7
- data.tar.gz: 16d2817864dc4bba2d54ca4a7288bc609b95ab5d59da51c67eccf70107fd78e67af141b765d01b8eea82c0a112b3dc779c174d7864afee6fb5651a5d787df7c5
6
+ metadata.gz: e2d0da285afa14f72380ad688ce63c316a6e46e0cceb7f332b15d298253ffe05d9f6bf9244aa630bc20719a794fe6b316ac7861788e7775ac104e7741dcf167b
7
+ data.tar.gz: '0202023381da0eefdd6c27587ad1798e9052a8798d3d32f8e71b19d7a5303aaf5a01f16c77d486d1d8249ef89cf4009f146855f6e3d1f4cf2c7e3fc9e66ce830'
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.1.1 (2021-07-15)
2
+
3
+ - Added `NN` module
4
+ - Added `bleu_score` method
5
+
1
6
  ## 0.1.0 (2020-08-24)
2
7
 
3
8
  - First release
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: Data loaders and abstractions for text and NLP - for Ruby
4
4
 
5
+ [![Build Status](https://github.com/ankane/torchtext/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchtext/actions)
6
+
5
7
  ## Installation
6
8
 
7
9
  Add this line to your application’s Gemfile:
@@ -19,7 +21,7 @@ This library follows the [Python API](https://pytorch.org/text/). Many methods a
19
21
  Text classification
20
22
 
21
23
  - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)
22
- - [Ruby code](examples/text_classification)
24
+ - [Ruby code](examples/text_classification.rb)
23
25
 
24
26
  ## Datasets
25
27
 
@@ -33,6 +35,37 @@ Supported datasets are:
33
35
 
34
36
  - [AG_NEWS](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
35
37
 
38
+ ## Data Utils
39
+
40
+ Supports:
41
+
42
+ - tokenizer
43
+ - ngrams_iterator
44
+
45
+ ## Data Metrics
46
+
47
+ Compute the BLEU score
48
+
49
+ ```ruby
50
+ candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
51
+ references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
52
+ TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
53
+ ```
54
+
55
+ ## NN
56
+
57
+ Supports:
58
+
59
+ - InProjContainer
60
+ - MultiheadAttentionContainer
61
+ - ScaledDotProduct
62
+
63
+ ## Vocab
64
+
65
+ Supports:
66
+
67
+ - Vocab
68
+
36
69
  ## Disclaimer
37
70
 
38
71
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchtext.rb CHANGED
@@ -9,8 +9,12 @@ require "set"
9
9
 
10
10
  # modules
11
11
  require "torchtext/data/utils"
12
+ require "torchtext/data/metrics"
12
13
  require "torchtext/datasets/text_classification"
13
14
  require "torchtext/datasets/text_classification_dataset"
15
+ require "torchtext/nn/in_proj_container"
16
+ require "torchtext/nn/multihead_attention_container"
17
+ require "torchtext/nn/scaled_dot_product"
14
18
  require "torchtext/vocab"
15
19
  require "torchtext/version"
16
20
 
@@ -0,0 +1,68 @@
1
module TorchText
  module Data
    # Corpus-level text generation metrics.
    module Metrics
      class << self
        # Computes a corpus-level BLEU score for tokenized candidates against
        # one or more tokenized references per candidate.
        #
        # N-gram counts are accumulated over the whole corpus: candidate counts
        # are clipped by the per-n-gram maximum across that candidate's
        # references, and the brevity penalty uses, for each candidate, the
        # reference length closest to the candidate length.
        #
        # @param candidate_corpus [Array<Array<String>>] tokenized candidate sentences
        # @param references_corpus [Array<Array<Array<String>>>] for each candidate,
        #   a non-empty list of tokenized reference sentences
        # @param max_n [Integer] highest n-gram order to count
        # @param weights [Array<Float>] one weight per n-gram order; its length
        #   must equal max_n
        # @return [Float] 0.0 when any n-gram order has no clipped matches,
        #   otherwise the brevity penalty times the weighted geometric mean of
        #   the n-gram precisions
        # @raise [RuntimeError] if weights.length != max_n, if the two corpora
        #   differ in length, or if a candidate has no references
        def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
          unless max_n == weights.length
            raise "Length of the \"weights\" list has to be equal to max_n"
          end
          unless candidate_corpus.length == references_corpus.length
            raise "The length of candidate and reference corpus should be the same"
          end
          # Without this guard an empty reference list surfaces later as an
          # opaque TypeError while summing reference lengths.
          if references_corpus.any?(&:empty?)
            raise "Each candidate must have at least one reference"
          end

          clipped_counts = Torch.zeros(max_n)
          total_counts = Torch.zeros(max_n)
          weights = Torch.tensor(weights)

          candidate_len = 0.0
          refs_len = 0.0

          candidate_corpus.zip(references_corpus) do |candidate, refs|
            candidate_len += candidate.length

            # Use the length of the reference that's closest in length to the candidate
            refs_len += refs.map { |ref| ref.length.to_f }.min_by { |len| (candidate.length - len).abs }

            # For every n-gram keep the highest count seen in any single reference
            reference_counters =
              refs
                .map { |ref| compute_ngram_counter(ref, max_n) }
                .reduce { |merged, counter| merged.merge(counter) { |_, a, b| a > b ? a : b } }

            compute_ngram_counter(candidate, max_n).each do |ngram, count|
              order = ngram.length - 1
              total_counts[order] += count

              # Only n-grams that also occur in a reference contribute, and
              # never more often than the reference allows.
              ref_count = reference_counters[ngram]
              clipped_counts[order] += (count < ref_count ? count : ref_count) if ref_count
            end
          end

          # A zero count for any order would make log(pn) -Infinity
          return 0.0 if clipped_counts.to_a.min == 0

          pn = clipped_counts / total_counts
          log_pn = weights * Torch.log(pn)
          score = Torch.exp(log_pn.sum)

          # Brevity penalty: only penalizes candidates shorter than the references
          bp = Math.exp([1 - refs_len / candidate_len, 0].min)

          bp * score.item
        end

        private

        # Counts every n-gram of order 1..max_n in tokens.
        #
        # NOTE(review): relies on Utils.ngrams_iterator yielding n-grams as
        # space-joined strings (each is split on " " here), so a token that
        # itself contains a space would be counted as a longer n-gram - verify
        # against the Utils implementation.
        #
        # @param tokens [Array<String>] a tokenized sentence
        # @param max_n [Integer] highest n-gram order, must be positive
        # @return [Hash{Array<String> => Integer}] n-gram (as token array) => occurrences
        # @raise [RuntimeError] if max_n is not positive
        def compute_ngram_counter(tokens, max_n)
          raise "Failed assert" unless max_n > 0

          TorchText::Data::Utils
            .ngrams_iterator(tokens, max_n)
            .map { |ngram| ngram.split(" ") }
            .group_by(&:itself)
            .transform_values(&:size)
        end
      end
    end
  end
end
@@ -139,7 +139,7 @@ module TorchText
139
139
  return to_path
140
140
  end
141
141
 
142
- raise "Not implemented yet"
142
+ raise "We currently only support tar.gz and tgz archives"
143
143
  end
144
144
  end
145
145
 
@@ -0,0 +1,16 @@
1
module TorchText
  module NN
    # Holds the three input projections (query, key, value) and applies each
    # one to its matching input.
    class InProjContainer < Torch::NN::Module
      # @param query_proj [#call] projection applied to the query
      # @param key_proj [#call] projection applied to the key
      # @param value_proj [#call] projection applied to the value
      def initialize(query_proj, key_proj, value_proj)
        super()
        @query_proj, @key_proj, @value_proj = query_proj, key_proj, value_proj
      end

      # Projects query, key and value with their respective layers.
      #
      # @return [Array] the projected query, key and value, in that order
      def forward(query, key, value)
        projections = [@query_proj, @key_proj, @value_proj]
        projections.zip([query, key, value]).map { |projection, input| projection.call(input) }
      end
    end
  end
end
@@ -0,0 +1,50 @@
1
module TorchText
  module NN
    # Wires together an input projection container, an attention layer and an
    # output projection into a multi-head attention block.
    class MultiheadAttentionContainer < Torch::NN::Module
      # @param nhead [Integer] number of attention heads
      # @param in_proj_container [#call] called with (query, key, value); must
      #   return the three projected tensors
      # @param attention_layer [#call] called with the per-head q, k, v plus
      #   attn_mask/bias_k/bias_v; must return [output, weights]
      # @param out_proj [#call] projection applied to the attention output
      # @param batch_first [Boolean] when true, dims -3 and -2 of the inputs
      #   and of the returned output are swapped
      def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
        super()
        @nhead = nhead
        @in_proj_container = in_proj_container
        @attention_layer = attention_layer
        @out_proj = out_proj
        @batch_first = batch_first
      end

      # Runs projection, per-head attention and output projection.
      #
      # NOTE(review): after the optional batch_first swap the code reads the
      # sequence length from dim -3, the batch size from dim -2 and the
      # embedding size from dim -1 - confirm callers pass that layout.
      #
      # @return [Array] [attn_output, attn_output_weights]
      # @raise [RuntimeError] if a projected last dim is not divisible by nhead
      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
        query, key, value = [query, key, value].map { |t| t.transpose(-3, -2) } if @batch_first

        tgt_len = query.size(-3)
        src_len = key.size(-3)
        bsz = query.size(-2)
        embed_dim = query.size(-1)

        q, k, v = @in_proj_container.call(query, key, value)
        q = split_heads(q, tgt_len, bsz, "query's embed_dim must be divisible by the number of heads")
        k = split_heads(k, src_len, bsz, "key's embed_dim must be divisible by the number of heads")
        v = split_heads(v, src_len, bsz, "value's embed_dim must be divisible by the number of heads")

        attn_output, attn_output_weights =
          @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)

        # Fold the heads back into the original embedding size before projecting out
        attn_output = @out_proj.call(attn_output.reshape(tgt_len, bsz, embed_dim))
        attn_output = attn_output.transpose(-3, -2) if @batch_first

        [attn_output, attn_output_weights]
      end

      private

      # Reshapes a projected tensor to (seq_len, bsz * nhead, head_dim), raising
      # error_message when its last dim cannot be split evenly across the heads.
      def split_heads(tensor, seq_len, bsz, error_message)
        projected_dim = tensor.size(-1)
        raise error_message unless projected_dim % @nhead == 0

        tensor.reshape(seq_len, bsz * @nhead, projected_dim.div(@nhead))
      end
    end
  end
end
@@ -0,0 +1,72 @@
1
module TorchText
  module NN
    # Scaled dot-product attention: softmax(q * k^T / sqrt(head_dim)) * v, with
    # optional boolean masking, key/value bias rows and dropout on the weights.
    class ScaledDotProduct < Torch::NN::Module
      # @param dropout [Float] dropout probability applied to the attention weights
      # @param batch_first [Boolean] when true, dims -3 and -2 of the inputs are
      #   swapped on entry and the output is returned without the final swap
      def initialize(dropout: 0.0, batch_first: false)
        super()
        @dropout = dropout
        @batch_first = batch_first
      end

      # Computes the attention output and the attention weights.
      #
      # bias_k and bias_v are only used when both are given; each must have
      # size 1 on dim -3 and match key/value on dims -2 and -1.
      #
      # @param attn_mask [Object, nil] if given, a 3-D bool tensor; positions
      #   where it is true are filled with -1e8 before the softmax
      # @return [Array] [attn_output, attn_output_weights]
      # @raise [RuntimeError] on unsupported bias shapes, mismatched
      #   query/key/value shapes, or an invalid attn_mask
      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
        if @batch_first
          query, key, value = [query, key, value].map { |t| t.transpose(-3, -2) }
        end

        unless bias_k.nil? || bias_v.nil?
          if key.size(-1) != bias_k.size(-1) || key.size(-2) != bias_k.size(-2) || bias_k.size(-3) != 1
            raise "Shape of bias_k is not supported"
          end
          if value.size(-1) != bias_v.size(-1) || value.size(-2) != bias_v.size(-2) || bias_v.size(-3) != 1
            raise "Shape of bias_v is not supported"
          end

          # The bias adds one extra source position, so the mask gains one
          # trailing entry on its last dim to stay aligned.
          key = Torch.cat([key, bias_k])
          value = Torch.cat([value, bias_v])
          attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1]) unless attn_mask.nil?
        end

        tgt_len = query.size(-3)
        head_dim = query.size(-1)
        if query.size(-1) != key.size(-1) || key.size(-1) != value.size(-1)
          raise "The feature dim of query, key, value must be equal."
        end
        raise "Shape of key, value must match" if key.size != value.size

        src_len = key.size(-3)
        batch_heads = [query.size(-2), key.size(-2)].max

        # Scale query
        query, key, value = [query, key, value].map { |t| t.transpose(-2, -3) }
        query = query * (head_dim.to_f ** -0.5)

        unless attn_mask.nil?
          raise RuntimeError, "attn_mask must be a 3D tensor." if attn_mask.dim != 3

          mask_size_ok =
            attn_mask.size(-1) == src_len &&
            attn_mask.size(-2) == tgt_len &&
            (attn_mask.size(-3) == 1 || attn_mask.size(-3) == batch_heads)
          raise RuntimeError, "The size of the attn_mask is not correct." unless mask_size_ok

          raise RuntimeError, "Only bool tensor is supported for attn_mask" if attn_mask.dtype != :bool
        end

        # Dot product of q, k
        attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
        unless attn_mask.nil?
          # TODO confirm last argument
          attn_output_weights.masked_fill!(attn_mask, -1e8, nil)
        end
        attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
        # NOTE(review): @training is not assigned in this class; presumably it
        # is maintained by Torch::NN::Module - confirm.
        attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
        attn_output = Torch.matmul(attn_output_weights, value)

        # The matmul result has dims -3 and -2 swapped relative to the
        # non-batch-first input layout, so swap them back in that mode.
        attn_output = attn_output.transpose(-3, -2) unless @batch_first

        [attn_output, attn_output_weights]
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module TorchText
2
- VERSION = "0.1.0"
2
+ VERSION = "0.1.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchtext
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-24 00:00:00.000000000 Z
11
+ date: 2021-07-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -66,7 +66,7 @@ dependencies:
66
66
  - - ">="
67
67
  - !ruby/object:Gem::Version
68
68
  version: '5'
69
- description:
69
+ description:
70
70
  email: andrew@chartkick.com
71
71
  executables: []
72
72
  extensions: []
@@ -76,16 +76,20 @@ files:
76
76
  - LICENSE.txt
77
77
  - README.md
78
78
  - lib/torchtext.rb
79
+ - lib/torchtext/data/metrics.rb
79
80
  - lib/torchtext/data/utils.rb
80
81
  - lib/torchtext/datasets/text_classification.rb
81
82
  - lib/torchtext/datasets/text_classification_dataset.rb
83
+ - lib/torchtext/nn/in_proj_container.rb
84
+ - lib/torchtext/nn/multihead_attention_container.rb
85
+ - lib/torchtext/nn/scaled_dot_product.rb
82
86
  - lib/torchtext/version.rb
83
87
  - lib/torchtext/vocab.rb
84
88
  homepage: https://github.com/ankane/torchtext
85
89
  licenses:
86
90
  - BSD-3-Clause
87
91
  metadata: {}
88
- post_install_message:
92
+ post_install_message:
89
93
  rdoc_options: []
90
94
  require_paths:
91
95
  - lib
@@ -100,8 +104,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
100
104
  - !ruby/object:Gem::Version
101
105
  version: '0'
102
106
  requirements: []
103
- rubygems_version: 3.1.2
104
- signing_key:
107
+ rubygems_version: 3.2.22
108
+ signing_key:
105
109
  specification_version: 4
106
110
  summary: Data loaders and abstractions for text and NLP
107
111
  test_files: []