torchtext 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 86469f8148e519b940a643f81b5317d3e180d6ebc031da14cb0599b48e3f6556
4
- data.tar.gz: 499079c8a32de3ea6704b58a04ad8511f7a6784cc138b08e87696c69d7835863
3
+ metadata.gz: d5c6ff21a492fa03ce88ed7b823d29310e67b3d0c86825ccd2b6092f330bdbc1
4
+ data.tar.gz: 7df74bb05bb110ae37e9d962adc9dc020088203ed1b9e920fe3c4e9421f8e714
5
5
  SHA512:
6
- metadata.gz: e3ea0d3719d35a58b757ac3d11adeda30912f35f69f7de37047ef702c556e5384862f950e055565db8396d8495a760b4919fd416affbdc0fd815dc14ed02e3a3
7
- data.tar.gz: 16d2817864dc4bba2d54ca4a7288bc609b95ab5d59da51c67eccf70107fd78e67af141b765d01b8eea82c0a112b3dc779c174d7864afee6fb5651a5d787df7c5
6
+ metadata.gz: e2d0da285afa14f72380ad688ce63c316a6e46e0cceb7f332b15d298253ffe05d9f6bf9244aa630bc20719a794fe6b316ac7861788e7775ac104e7741dcf167b
7
+ data.tar.gz: '0202023381da0eefdd6c27587ad1798e9052a8798d3d32f8e71b19d7a5303aaf5a01f16c77d486d1d8249ef89cf4009f146855f6e3d1f4cf2c7e3fc9e66ce830'
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.1.1 (2021-07-15)
2
+
3
+ - Added `NN` module
4
+ - Added `bleu_score` method
5
+
1
6
  ## 0.1.0 (2020-08-24)
2
7
 
3
8
  - First release
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: Data loaders and abstractions for text and NLP - for Ruby
4
4
 
5
+ [![Build Status](https://github.com/ankane/torchtext/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchtext/actions)
6
+
5
7
  ## Installation
6
8
 
7
9
  Add this line to your application’s Gemfile:
@@ -19,7 +21,7 @@ This library follows the [Python API](https://pytorch.org/text/). Many methods a
19
21
  Text classification
20
22
 
21
23
  - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)
22
- - [Ruby code](examples/text_classification)
24
+ - [Ruby code](examples/text_classification.rb)
23
25
 
24
26
  ## Datasets
25
27
 
@@ -33,6 +35,37 @@ Supported datasets are:
33
35
 
34
36
  - [AG_NEWS](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
35
37
 
38
+ ## Data Utils
39
+
40
+ Supports:
41
+
42
+ - tokenizer
43
+ - ngrams_iterator
44
+
45
+ ## Data Metrics
46
+
47
+ Compute the BLEU score
48
+
49
+ ```ruby
50
+ candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
51
+ references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
52
+ TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
53
+ ```
54
+
55
+ ## NN
56
+
57
+ Supports:
58
+
59
+ - InProjContainer
60
+ - MultiheadAttentionContainer
61
+ - ScaledDotProduct
62
+
63
+ ## Vocab
64
+
65
+ Supports:
66
+
67
+ - Vocab
68
+
36
69
  ## Disclaimer
37
70
 
38
71
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchtext.rb CHANGED
@@ -9,8 +9,12 @@ require "set"
9
9
 
10
10
  # modules
11
11
  require "torchtext/data/utils"
12
+ require "torchtext/data/metrics"
12
13
  require "torchtext/datasets/text_classification"
13
14
  require "torchtext/datasets/text_classification_dataset"
15
+ require "torchtext/nn/in_proj_container"
16
+ require "torchtext/nn/multihead_attention_container"
17
+ require "torchtext/nn/scaled_dot_product"
14
18
  require "torchtext/vocab"
15
19
  require "torchtext/version"
16
20
 
@@ -0,0 +1,68 @@
1
module TorchText
  module Data
    # Text-generation metrics, ported from Python torchtext's
    # `torchtext.data.metrics`.
    module Metrics
      class << self
        # Computes the corpus-level BLEU score of candidate sentences against
        # one or more reference sentences per candidate.
        #
        # @param candidate_corpus [Array<Array<String>>] one token list per example
        # @param references_corpus [Array<Array<Array<String>>>] for each example,
        #   a list of reference token lists
        # @param max_n [Integer] highest n-gram order counted
        # @param weights [Array<Float>] one weight per n-gram order (length max_n)
        # @return [Float] the BLEU score; 0.0 when any n-gram order has no
        #   clipped matches at all
        # @raise [RuntimeError] when weights.length != max_n or the two corpora
        #   differ in length
        def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
          raise "Length of the \"weights\" list has be equal to max_n" unless max_n == weights.length
          raise "The length of candidate and reference corpus should be the same" unless candidate_corpus.length == references_corpus.length

          # Per-order accumulators: index i holds counts for (i + 1)-grams
          clipped_counts = Torch.zeros(max_n)
          total_counts = Torch.zeros(max_n)
          weight_tensor = Torch.tensor(weights)

          candidate_len = 0.0
          refs_len = 0.0

          candidate_corpus.zip(references_corpus) do |candidate, refs|
            cand_length = candidate.length
            candidate_len += cand_length

            # For the brevity penalty, use the reference length closest to the candidate's
            ref_lengths = refs.map { |ref| ref.length.to_f }
            refs_len += ref_lengths.min_by { |len| (cand_length - len).abs }

            # Merge all references, keeping the largest count seen for each n-gram
            max_ref_counts = refs.drop(1).reduce(compute_ngram_counter(refs[0], max_n)) do |acc, ref|
              acc.merge(compute_ngram_counter(ref, max_n)) { |_ngram, a, b| a > b ? a : b }
            end

            candidate_counts = compute_ngram_counter(candidate, max_n)

            # Clip each candidate n-gram count by its (max) reference count;
            # n-grams absent from every reference are dropped entirely
            common = candidate_counts.keys & max_ref_counts.keys
            clipped = candidate_counts.slice(*common).merge(max_ref_counts.slice(*common)) { |_ngram, a, b| a < b ? a : b }

            clipped.each { |ngram, count| clipped_counts[ngram.length - 1] += count }
            candidate_counts.each { |ngram, count| total_counts[ngram.length - 1] += count }
          end

          return 0.0 if clipped_counts.to_a.min == 0

          precisions = clipped_counts / total_counts
          geometric_mean = Torch.exp((weight_tensor * Torch.log(precisions)).sum)
          brevity_penalty = Math.exp([1 - refs_len / candidate_len, 0].min)
          brevity_penalty * geometric_mean.item
        end

        private

        # Counts the n-grams produced by Utils.ngrams_iterator for tokens.
        # Keys are token arrays (the space-joined n-gram split back apart);
        # values are occurrence counts. Presumably covers every order 1..max_n —
        # see Utils.ngrams_iterator.
        def compute_ngram_counter(tokens, max_n)
          raise "Failed assert" unless max_n > 0

          TorchText::Data::Utils.ngrams_iterator(tokens, max_n)
            .map { |gram| gram.split(" ") }
            .group_by(&:itself)
            .transform_values(&:size)
        end
      end
    end
  end
end
@@ -139,7 +139,7 @@ module TorchText
139
139
  return to_path
140
140
  end
141
141
 
142
- raise "Not implemented yet"
142
+ raise "We currently only support tar.gz and tgz archives"
143
143
  end
144
144
  end
145
145
 
@@ -0,0 +1,16 @@
1
module TorchText
  module NN
    # Holds three projection modules and applies each to its matching input:
    # query_proj to the query, key_proj to the key, value_proj to the value.
    class InProjContainer < Torch::NN::Module
      # @param query_proj [#call] projection applied to the query
      # @param key_proj [#call] projection applied to the key
      # @param value_proj [#call] projection applied to the value
      def initialize(query_proj, key_proj, value_proj)
        super()
        @query_proj, @key_proj, @value_proj = query_proj, key_proj, value_proj
      end

      # @return [Array] the projected query, key and value, in that order
      def forward(query, key, value)
        projected_query = @query_proj.call(query)
        projected_key = @key_proj.call(key)
        projected_value = @value_proj.call(value)
        [projected_query, projected_key, projected_value]
      end
    end
  end
end
@@ -0,0 +1,50 @@
1
module TorchText
  module NN
    # Multi-head attention assembled from pluggable parts: an input projection
    # container, an attention layer and an output projection. Ported from
    # Python torchtext's `MultiheadAttentionContainer`.
    class MultiheadAttentionContainer < Torch::NN::Module
      # @param nhead [Integer] number of attention heads
      # @param in_proj_container [#call] maps (query, key, value) to projected (q, k, v)
      # @param attention_layer [#call] attention applied to the per-head q, k, v
      # @param out_proj [#call] projection applied to the merged attention output
      # @param batch_first [Boolean] when true, dims -3 and -2 of the inputs and
      #   the output are swapped relative to the default (seq, batch, embed) layout
      def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
        super()
        @nhead = nhead
        @in_proj_container = in_proj_container
        @attention_layer = attention_layer
        @out_proj = out_proj
        @batch_first = batch_first
      end

      # @return [Array] the projected attention output and the attention
      #   weights returned by the attention layer
      # @raise [RuntimeError] if a projected embed_dim is not divisible by nhead
      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
        query, key, value = [query, key, value].map { |t| t.transpose(-3, -2) } if @batch_first

        tgt_len = query.size(-3)
        src_len = key.size(-3)
        bsz = query.size(-2)
        embed_dim = query.size(-1)

        q, k, v = @in_proj_container.call(query, key, value)
        q = split_heads(q, tgt_len, bsz, "query's embed_dim must be divisible by the number of heads")
        k = split_heads(k, src_len, bsz, "key's embed_dim must be divisible by the number of heads")
        v = split_heads(v, src_len, bsz, "value's embed_dim must be divisible by the number of heads")

        attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
        attn_output = @out_proj.call(attn_output.reshape(tgt_len, bsz, embed_dim))
        attn_output = attn_output.transpose(-3, -2) if @batch_first

        [attn_output, attn_output_weights]
      end

      private

      # Reshapes a projected tensor to (seq_len, bsz * nhead, head_dim),
      # raising error_message when its last dim is not divisible by nhead.
      def split_heads(tensor, seq_len, bsz, error_message)
        features = tensor.size(-1)
        raise error_message unless features % @nhead == 0

        tensor.reshape(seq_len, bsz * @nhead, features.div(@nhead))
      end
    end
  end
end
@@ -0,0 +1,72 @@
1
module TorchText
  module NN
    # Scaled dot-product attention, ported from Python torchtext's
    # `ScaledDotProduct` module.
    class ScaledDotProduct < Torch::NN::Module
      # @param dropout [Float] dropout probability applied to the attention
      #   weights (passed with the module's training flag)
      # @param batch_first [Boolean] when true, dims -3 and -2 of query/key/value
      #   are swapped relative to the default layout described on #forward
      def initialize(dropout: 0.0, batch_first: false)
        super()
        @dropout = dropout
        @batch_first = batch_first
      end

      # Computes dropout(softmax(Q K^T / sqrt(head_dim))) V.
      #
      # In the default layout the last three dims of query are read as
      # (tgt_len, batch_heads, head_dim) and those of key/value as
      # (src_len, batch_heads, head_dim).
      #
      # @param attn_mask [Torch::Tensor, nil] 3-D bool mask sized
      #   (1 or batch_heads, tgt_len, src_len); true positions are filled with
      #   -1e8 before the softmax
      # @param bias_k [Torch::Tensor, nil] concatenated onto key (its dim -3 must
      #   be 1); only used when bias_v is also given
      # @param bias_v [Torch::Tensor, nil] concatenated onto value (its dim -3
      #   must be 1); only used when bias_k is also given
      # @return [Array] the attention output and the attention weights
      # @raise [RuntimeError] on incompatible shapes or an invalid attn_mask
      def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
        if @batch_first
          query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
        end

        if bias_k && bias_v
          unless key.size(-1) == bias_k.size(-1) && key.size(-2) == bias_k.size(-2) && bias_k.size(-3) == 1
            raise "Shape of bias_k is not supported"
          end
          unless value.size(-1) == bias_v.size(-1) && value.size(-2) == bias_v.size(-2) && bias_v.size(-3) == 1
            raise "Shape of bias_v is not supported"
          end
          key = Torch.cat([key, bias_k])
          value = Torch.cat([value, bias_v])
          # Grow the mask's last dim by one so it covers the appended bias position
          attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1]) if attn_mask
        end

        tgt_len, head_dim = query.size(-3), query.size(-1)
        unless query.size(-1) == key.size(-1) && key.size(-1) == value.size(-1)
          raise "The feature dim of query, key, value must be equal."
        end
        raise "Shape of key, value must match" unless key.size == value.size
        src_len = key.size(-3)
        batch_heads = [query.size(-2), key.size(-2)].max

        # Move the batch*heads dim in front of the sequence dim, then scale the query
        query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
        query = query * (head_dim.to_f ** -0.5)

        if attn_mask
          raise "attn_mask must be a 3D tensor." if attn_mask.dim != 3
          if attn_mask.size(-1) != src_len || attn_mask.size(-2) != tgt_len ||
             (attn_mask.size(-3) != 1 && attn_mask.size(-3) != batch_heads)
            raise "The size of the attn_mask is not correct."
          end
          raise "Only bool tensor is supported for attn_mask" if attn_mask.dtype != :bool
        end

        # Dot product of q and k
        attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
        # masked_fill! takes exactly (mask, value); an extra positional argument
        # is rejected by the argument parser
        attn_output_weights.masked_fill!(attn_mask, -1e8) if attn_mask
        attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
        attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
        attn_output = Torch.matmul(attn_output_weights, value)

        # In the default layout, put the sequence dim back in position -3
        output = @batch_first ? attn_output : attn_output.transpose(-3, -2)
        [output, attn_output_weights]
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module TorchText
2
- VERSION = "0.1.0"
2
+ VERSION = "0.1.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchtext
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-24 00:00:00.000000000 Z
11
+ date: 2021-07-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -66,7 +66,7 @@ dependencies:
66
66
  - - ">="
67
67
  - !ruby/object:Gem::Version
68
68
  version: '5'
69
- description:
69
+ description:
70
70
  email: andrew@chartkick.com
71
71
  executables: []
72
72
  extensions: []
@@ -76,16 +76,20 @@ files:
76
76
  - LICENSE.txt
77
77
  - README.md
78
78
  - lib/torchtext.rb
79
+ - lib/torchtext/data/metrics.rb
79
80
  - lib/torchtext/data/utils.rb
80
81
  - lib/torchtext/datasets/text_classification.rb
81
82
  - lib/torchtext/datasets/text_classification_dataset.rb
83
+ - lib/torchtext/nn/in_proj_container.rb
84
+ - lib/torchtext/nn/multihead_attention_container.rb
85
+ - lib/torchtext/nn/scaled_dot_product.rb
82
86
  - lib/torchtext/version.rb
83
87
  - lib/torchtext/vocab.rb
84
88
  homepage: https://github.com/ankane/torchtext
85
89
  licenses:
86
90
  - BSD-3-Clause
87
91
  metadata: {}
88
- post_install_message:
92
+ post_install_message:
89
93
  rdoc_options: []
90
94
  require_paths:
91
95
  - lib
@@ -100,8 +104,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
100
104
  - !ruby/object:Gem::Version
101
105
  version: '0'
102
106
  requirements: []
103
- rubygems_version: 3.1.2
104
- signing_key:
107
+ rubygems_version: 3.2.22
108
+ signing_key:
105
109
  specification_version: 4
106
110
  summary: Data loaders and abstractions for text and NLP
107
111
  test_files: []