torchtext 0.1.0 → 0.2.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 86469f8148e519b940a643f81b5317d3e180d6ebc031da14cb0599b48e3f6556
-   data.tar.gz: 499079c8a32de3ea6704b58a04ad8511f7a6784cc138b08e87696c69d7835863
+   metadata.gz: f9f88b060fcb69a0df1258c1a5399b2518ca9b1c95cd3c3f999bd81f9df0cc41
+   data.tar.gz: 3b04fab3f8e9a3ef86e3c46aa64f2ba98a8d85c3dd42889cb728d0c7f39e7be2
  SHA512:
-   metadata.gz: e3ea0d3719d35a58b757ac3d11adeda30912f35f69f7de37047ef702c556e5384862f950e055565db8396d8495a760b4919fd416affbdc0fd815dc14ed02e3a3
-   data.tar.gz: 16d2817864dc4bba2d54ca4a7288bc609b95ab5d59da51c67eccf70107fd78e67af141b765d01b8eea82c0a112b3dc779c174d7864afee6fb5651a5d787df7c5
+   metadata.gz: 4669a0ce275ec3ef0a1331c39835b81a8cff582c62012adc62ae8c59ea941befb53cf232198922bcbdd9795930e970a4e8970a851d66969531c75dd83e7bd40c
+   data.tar.gz: a3bb8659e334540849a011a0f526136650a4bd88d5a9da6b48f6e3d14dd889b08fe730c3e517837e6989e06c9a8b515d8a34958e3df267a80eedd465a239123f
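The `checksums.yaml` entries are SHA256 and SHA512 digests of the `metadata.gz` and `data.tar.gz` members of the `.gem` file, which is itself a plain tar archive. A minimal sketch for recomputing such digests with Ruby's standard `digest` library follows; `gem_member_digests` is a hypothetical helper name, and the sketch assumes the `.gem` has already been unpacked into a directory so each member is a file on disk.

```ruby
require "digest"

# Hypothetical helper: compute the two digests that checksums.yaml records
# for one member of an unpacked .gem (e.g. "data.tar.gz" or "metadata.gz").
def gem_member_digests(path)
  {
    "SHA256" => Digest::SHA256.file(path).hexdigest,
    "SHA512" => Digest::SHA512.file(path).hexdigest
  }
end
```

The returned hex strings can be compared directly with the `+` values in the 0.2.0 `checksums.yaml`.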
data/CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
+ ## 0.2.0 (2023-01-30)
+
+ - Added `Functional` module
+ - Dropped support for Ruby < 2.7
+
+ ## 0.1.1 (2021-07-15)
+
+ - Added `NN` module
+ - Added `bleu_score` method
+
  ## 0.1.0 (2020-08-24)
 
  - First release
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
  BSD 3-Clause License
 
  Copyright (c) James Bradbury and Soumith Chintala 2016,
- Copyright (c) Andrew Kane 2020,
+ Copyright (c) Andrew Kane 2020-2023,
  All rights reserved.
 
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,13 +1,15 @@
- # TorchText
+ # TorchText Ruby
 
  :fire: Data loaders and abstractions for text and NLP - for Ruby
 
+ [![Build Status](https://github.com/ankane/torchtext-ruby/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchtext-ruby/actions)
+
  ## Installation
 
  Add this line to your application’s Gemfile:
 
  ```ruby
- gem 'torchtext'
+ gem "torchtext"
  ```
 
  ## Getting Started
@@ -19,7 +21,7 @@ This library follows the [Python API](https://pytorch.org/text/). Many methods a
  Text classification
 
  - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)
- - [Ruby code](examples/text_classification)
+ - [Ruby code](examples/text_classification.rb)
 
  ## Datasets
 
@@ -33,6 +35,37 @@ Supported datasets are:
 
  - [AG_NEWS](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
 
+ ## Data Utils
+
+ Supports:
+
+ - tokenizer
+ - ngrams_iterator
+
+ ## Data Metrics
+
+ Compute the BLEU score
+
+ ```ruby
+ candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
+ references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
+ TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
+ ```
+
+ ## NN
+
+ Supports:
+
+ - InProjContainer
+ - MultiheadAttentionContainer
+ - ScaledDotProduct
+
+ ## Vocab
+
+ Supports:
+
+ - Vocab
+
  ## Disclaimer
 
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
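The README now lists `ngrams_iterator` among the data utilities, and `Metrics.bleu_score` builds its n-gram counts from it by splitting each yielded string back on spaces. The sketch below approximates that behavior in plain Ruby, assuming the gem mirrors Python torchtext's `ngrams_iterator` (all unigrams first, then space-joined n-grams of increasing length); `ngrams_sketch` is a hypothetical name, not the gem's method.

```ruby
# Hypothetical approximation of ngrams_iterator (assumed to follow Python torchtext):
# all unigrams first, then space-joined n-grams for n = 2..max_n.
def ngrams_sketch(tokens, max_n)
  (1..max_n).flat_map { |n| tokens.each_cons(n).map { |gram| gram.join(" ") } }
end

ngrams = ngrams_sketch(["here", "we", "are"], 2)
# => ["here", "we", "are", "here we", "we are"]

# Metrics splits each string back into a token array to use as a hash key
ngrams.map { |x| x.split(" ") }
# => [["here"], ["we"], ["are"], ["here", "we"], ["we", "are"]]
```

Counting only needs the multiset of n-grams, so the exact yield order does not affect the BLEU result.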
@@ -41,22 +74,22 @@ If you’re a dataset owner and wish to update any details or remove it from thi
 
  ## History
 
- View the [changelog](https://github.com/ankane/torchtext/blob/master/CHANGELOG.md)
+ View the [changelog](https://github.com/ankane/torchtext-ruby/blob/master/CHANGELOG.md)
 
  ## Contributing
 
  Everyone is encouraged to help improve this project. Here are a few ways you can help:
 
- - [Report bugs](https://github.com/ankane/torchtext/issues)
- - Fix bugs and [submit pull requests](https://github.com/ankane/torchtext/pulls)
+ - [Report bugs](https://github.com/ankane/torchtext-ruby/issues)
+ - Fix bugs and [submit pull requests](https://github.com/ankane/torchtext-ruby/pulls)
  - Write, clarify, or fix documentation
  - Suggest or add new features
 
  To get started with development:
 
  ```sh
- git clone https://github.com/ankane/torchtext.git
- cd torchtext
+ git clone https://github.com/ankane/torchtext-ruby.git
+ cd torchtext-ruby
  bundle install
  bundle exec rake test
  ```
data/lib/torchtext/data/functional.rb ADDED
@@ -0,0 +1,11 @@
+ module TorchText
+   module Data
+     module Functional
+       class << self
+         def simple_space_split(iterator)
+           iterator.map(&:split)
+         end
+       end
+     end
+   end
+ end
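`simple_space_split` maps `String#split` with no argument over its input, so each string is split on runs of whitespace and leading or trailing whitespace is ignored. The snippet copies that one-line body into a standalone method so it runs without the gem; with the gem loaded, the equivalent call is `TorchText::Data::Functional.simple_space_split(lines)`.

```ruby
# Same one-line body as TorchText::Data::Functional.simple_space_split,
# copied here so it runs without the gem.
def simple_space_split(iterator)
  iterator.map(&:split)
end

simple_space_split(["Attention is all you need", "  Sentencepiece   encode as pieces "])
# => [["Attention", "is", "all", "you", "need"], ["Sentencepiece", "encode", "as", "pieces"]]
```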
data/lib/torchtext/data/metrics.rb ADDED
@@ -0,0 +1,68 @@
+ module TorchText
+   module Data
+     module Metrics
+       class << self
+         def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
+           unless max_n == weights.length
+             raise "Length of the \"weights\" list has be equal to max_n"
+           end
+           unless candidate_corpus.length == references_corpus.length
+             raise "The length of candidate and reference corpus should be the same"
+           end
+
+           clipped_counts = Torch.zeros(max_n)
+           total_counts = Torch.zeros(max_n)
+           weights = Torch.tensor(weights)
+
+           candidate_len = 0.0
+           refs_len = 0.0
+
+           candidate_corpus.zip(references_corpus) do |candidate, refs|
+             candidate_len += candidate.length
+
+             # Get the length of the reference that's closest in length to the candidate
+             refs_len_list = refs.map { |ref| ref.length.to_f }
+             refs_len += refs_len_list.min_by { |x| (candidate.length - x).abs }
+
+             reference_counters = compute_ngram_counter(refs[0], max_n)
+             refs[1..-1].each do |ref|
+               reference_counters = reference_counters.merge(compute_ngram_counter(ref, max_n)) { |_, v1, v2| v1 > v2 ? v1 : v2 }
+             end
+
+             candidate_counter = compute_ngram_counter(candidate, max_n)
+
+             shared_keys = candidate_counter.keys & reference_counters.keys
+             clipped_counter = candidate_counter.slice(*shared_keys).merge(reference_counters.slice(*shared_keys)) { |_, v1, v2| v1 < v2 ? v1 : v2 }
+
+             clipped_counter.each_key do |ngram|
+               clipped_counts[ngram.length - 1] += clipped_counter[ngram]
+             end
+
+             candidate_counter.each_key do |ngram|
+               total_counts[ngram.length - 1] += candidate_counter[ngram]
+             end
+           end
+
+           if clipped_counts.to_a.min == 0
+             0.0
+           else
+             pn = clipped_counts / total_counts
+             log_pn = weights * Torch.log(pn)
+             score = Torch.exp(log_pn.sum)
+
+             bp = Math.exp([1 - refs_len / candidate_len, 0].min)
+
+             bp * score.item
+           end
+         end
+
+         private
+
+         def compute_ngram_counter(tokens, max_n)
+           raise "Failed assert" unless max_n > 0
+           Hash[TorchText::Data::Utils.ngrams_iterator(tokens, max_n).map { |x| x.split(" ") }.group_by { |v| v }.map { |k, v| [k, v.size] }]
+         end
+       end
+     end
+   end
+ end
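`bleu_score` computes corpus-level BLEU: per-order n-gram counts from each candidate are clipped by the maximum count of that n-gram in any one reference, the clipped precisions are combined as a weighted geometric mean, and the result is scaled by a brevity penalty based on the closest reference lengths. The plain-Ruby sketch below follows the same steps with Float arrays in place of Torch tensors, so it runs without torch-rb. `ngram_counts` and `bleu_sketch` are hypothetical names; the sketch uses `each_cons` instead of `ngrams_iterator`, and because it works in 64-bit floats its last digits may differ slightly from the gem's float32 tensor result.

```ruby
# Hypothetical helper: count every n-gram (n = 1..max_n) of a token list.
# N-grams are arrays of tokens, like the keys compute_ngram_counter builds.
def ngram_counts(tokens, max_n)
  counts = Hash.new(0)
  (1..max_n).each do |n|
    tokens.each_cons(n) { |gram| counts[gram] += 1 }
  end
  counts
end

# Hypothetical plain-Ruby BLEU following the same steps as Metrics.bleu_score.
def bleu_sketch(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
  raise ArgumentError, "weights must have max_n entries" unless weights.length == max_n
  raise ArgumentError, "corpus lengths differ" unless candidate_corpus.length == references_corpus.length

  clipped = Array.new(max_n, 0)
  total = Array.new(max_n, 0)
  candidate_len = 0
  refs_len = 0

  candidate_corpus.zip(references_corpus) do |candidate, refs|
    candidate_len += candidate.length
    # reference length closest to the candidate's (the first one wins a tie)
    refs_len += refs.map(&:length).min_by { |len| (candidate.length - len).abs }

    # for each n-gram, the highest count seen in any single reference
    max_ref = Hash.new(0)
    refs.each do |ref|
      ngram_counts(ref, max_n).each { |gram, c| max_ref[gram] = [max_ref[gram], c].max }
    end

    # total candidate counts, and counts clipped by the reference maxima
    ngram_counts(candidate, max_n).each do |gram, c|
      total[gram.length - 1] += c
      clipped[gram.length - 1] += [c, max_ref[gram]].min
    end
  end

  return 0.0 if clipped.min.zero?

  log_precision = (0...max_n).sum { |i| weights[i] * Math.log(clipped[i].to_f / total[i]) }
  brevity_penalty = Math.exp([1 - refs_len.to_f / candidate_len, 0].min)
  brevity_penalty * Math.exp(log_precision)
end

candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
bleu_sketch(candidate_corpus, references_corpus) # ≈ 0.8409 (= 0.5 ** 0.25)
```

For the README's example corpus the clipped precisions are 4/6, 3/4, 2/2 and 1/1, and the candidate and closest-reference lengths are both 6, so the brevity penalty is 1 and the score is (2/3 · 3/4)^0.25 = 0.5^0.25.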
@@ -139,7 +139,7 @@ module TorchText
  return to_path
  end
 
- raise "Not implemented yet"
+ raise "We currently only support tar.gz and tgz archives"
  end
  end
 
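This hunk only rewords the error raised for unsupported archive formats; the extraction code itself is outside the diff. For reference, a minimal standard-library sketch of extracting the one format the new message admits follows. `extract_tar_gz` is a hypothetical name and this is not the gem's implementation; it reuses `rubygems/package`, which `lib/torchtext.rb` already requires, rejects entries whose path contains `..`, and is otherwise not hardened (symlinks, for instance, are simply skipped).

```ruby
require "fileutils"
require "rubygems/package"
require "zlib"

# Hypothetical helper (not the gem's code): extract a .tar.gz/.tgz archive
# into to_path and return the paths of the regular files written.
def extract_tar_gz(from_path, to_path)
  unless from_path.end_with?(".tar.gz", ".tgz")
    raise ArgumentError, "We currently only support tar.gz and tgz archives"
  end

  written = []
  File.open(from_path, "rb") do |file|
    Zlib::GzipReader.wrap(file) do |gz|
      Gem::Package::TarReader.new(gz) do |tar|
        tar.each do |entry|
          # refuse entries that would escape to_path
          raise ArgumentError, "unsafe path: #{entry.full_name}" if entry.full_name.split("/").include?("..")

          dest = File.join(to_path, entry.full_name)
          if entry.directory?
            FileUtils.mkdir_p(dest)
          elsif entry.file?
            FileUtils.mkdir_p(File.dirname(dest))
            # Entry#read returns nil for an empty file, hence to_s
            File.binwrite(dest, entry.read.to_s)
            written << dest
          end
        end
      end
    end
  end
  written
end
```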
data/lib/torchtext/nn/in_proj_container.rb ADDED
@@ -0,0 +1,16 @@
+ module TorchText
+   module NN
+     class InProjContainer < Torch::NN::Module
+       def initialize(query_proj, key_proj, value_proj)
+         super()
+         @query_proj = query_proj
+         @key_proj = key_proj
+         @value_proj = value_proj
+       end
+
+       def forward(query, key, value)
+         [@query_proj.call(query), @key_proj.call(key), @value_proj.call(value)]
+       end
+     end
+   end
+ end
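`InProjContainer#forward` only ever calls `#call` on its three projections, so any callable works; with torch-rb these would typically be layers such as `Torch::NN::Linear.new(embed_dim, embed_dim)`. The sketch below shows that contract without torch-rb, using a hypothetical `InProjSketch` class and lambdas standing in for the projection layers.

```ruby
# Hypothetical stand-in for InProjContainer without torch-rb: it only needs
# three objects that respond to #call, applied to query, key and value in turn.
class InProjSketch
  def initialize(query_proj, key_proj, value_proj)
    @query_proj = query_proj
    @key_proj = key_proj
    @value_proj = value_proj
  end

  def call(query, key, value)
    [@query_proj.call(query), @key_proj.call(key), @value_proj.call(value)]
  end
end

double = ->(xs) { xs.map { |x| x * 2 } }
identity = ->(xs) { xs }
negate = ->(xs) { xs.map(&:-@) }
InProjSketch.new(double, identity, negate).call([1, 2], [3], [4, 5])
# => [[2, 4], [3], [-4, -5]]
```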
data/lib/torchtext/nn/multihead_attention_container.rb ADDED
@@ -0,0 +1,50 @@
+ module TorchText
+   module NN
+     class MultiheadAttentionContainer < Torch::NN::Module
+       def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
+         super()
+         @nhead = nhead
+         @in_proj_container = in_proj_container
+         @attention_layer = attention_layer
+         @out_proj = out_proj
+         @batch_first = batch_first
+       end
+
+       def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
+         if @batch_first
+           query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+         end
+
+         tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
+         q, k, v = @in_proj_container.call(query, key, value)
+         unless q.size(-1) % @nhead == 0
+           raise "query's embed_dim must be divisible by the number of heads"
+         end
+         head_dim = q.size(-1).div(@nhead)
+         q = q.reshape(tgt_len, bsz * @nhead, head_dim)
+
+         unless k.size(-1) % @nhead == 0
+           raise "key's embed_dim must be divisible by the number of heads"
+         end
+         head_dim = k.size(-1).div(@nhead)
+         k = k.reshape(src_len, bsz * @nhead, head_dim)
+
+         unless v.size(-1) % @nhead == 0
+           raise "value's embed_dim must be divisible by the number of heads"
+         end
+         head_dim = v.size(-1).div(@nhead)
+         v = v.reshape(src_len, bsz * @nhead, head_dim)
+
+         attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
+         attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
+         attn_output = @out_proj.call(attn_output)
+
+         if @batch_first
+           attn_output = attn_output.transpose(-3, -2)
+         end
+
+         [attn_output, attn_output_weights]
+       end
+     end
+   end
+ end
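The container's main job is shape bookkeeping: optionally swap the batch and sequence dimensions, check that the projected `embed_dim` divides evenly by `nhead`, then fold the heads into the batch dimension so the attention layer sees `bsz * nhead` independent sequences. The sketch traces those shapes with plain Arrays of dimension sizes; `head_split_shape` is hypothetical, and it assumes the in-projections preserve `embed_dim`, as square linear layers do.

```ruby
# Hypothetical helper that traces the shape changes in
# MultiheadAttentionContainer#forward using plain Arrays of dimension sizes.
# Input shape: [seq_len, batch, embed_dim], or [batch, seq_len, embed_dim] with batch_first.
def head_split_shape(shape, nhead, batch_first: false)
  # transpose(-3, -2) on a 3-D tensor swaps the first two dimensions
  shape = [shape[1], shape[0], shape[2]] if batch_first
  seq_len, bsz, embed_dim = shape
  unless (embed_dim % nhead).zero?
    raise ArgumentError, "embed_dim must be divisible by the number of heads"
  end
  # reshape(seq_len, bsz * nhead, head_dim): each head becomes its own batch entry
  [seq_len, bsz * nhead, embed_dim / nhead]
end

head_split_shape([21, 64, 10], 5)                    # => [21, 320, 2]
head_split_shape([64, 21, 10], 5, batch_first: true) # => [21, 320, 2]
```

After attention, the container reverses this with `reshape(tgt_len, bsz, embed_dim)`, which only works because `nhead * head_dim` equals `embed_dim`.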
data/lib/torchtext/nn/scaled_dot_product.rb ADDED
@@ -0,0 +1,72 @@
+ module TorchText
+   module NN
+     class ScaledDotProduct < Torch::NN::Module
+       def initialize(dropout: 0.0, batch_first: false)
+         super()
+         @dropout = dropout
+         @batch_first = batch_first
+       end
+
+       def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
+         if @batch_first
+           query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+         end
+
+         if !bias_k.nil? && !bias_v.nil?
+           unless key.size(-1) == bias_k.size(-1) && key.size(-2) == bias_k.size(-2) && bias_k.size(-3) == 1
+             raise "Shape of bias_k is not supported"
+           end
+           unless value.size(-1) == bias_v.size(-1) && value.size(-2) == bias_v.size(-2) && bias_v.size(-3) == 1
+             raise "Shape of bias_v is not supported"
+           end
+           key = Torch.cat([key, bias_k])
+           value = Torch.cat([value, bias_v])
+           if !attn_mask.nil?
+             attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1])
+           end
+         end
+
+         tgt_len, head_dim = query.size(-3), query.size(-1)
+         unless query.size(-1) == key.size(-1) && key.size(-1) == value.size(-1)
+           raise "The feature dim of query, key, value must be equal."
+         end
+         unless key.size() == value.size()
+           raise "Shape of key, value must match"
+         end
+         src_len = key.size(-3)
+         batch_heads = [query.size(-2), key.size(-2)].max
+
+         # Scale query
+         query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
+         query = query * (head_dim.to_f ** -0.5)
+         if !attn_mask.nil?
+           if attn_mask.dim() != 3
+             raise RuntimeError, "attn_mask must be a 3D tensor."
+           end
+           if (attn_mask.size(-1) != src_len) || (attn_mask.size(-2) != tgt_len) || (attn_mask.size(-3) != 1 && attn_mask.size(-3) != batch_heads)
+             raise RuntimeError, "The size of the attn_mask is not correct."
+           end
+           if attn_mask.dtype != :bool
+             raise RuntimeError, "Only bool tensor is supported for attn_mask"
+           end
+         end
+
+         # Dot product of q, k
+         attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
+         if !attn_mask.nil?
+           # TODO confirm last argument
+           attn_output_weights.masked_fill!(attn_mask, -1e8, nil)
+         end
+         attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
+         attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
+         attn_output = Torch.matmul(attn_output_weights, value)
+
+         if @batch_first
+           [attn_output, attn_output_weights]
+         else
+           [attn_output.transpose(-3, -2), attn_output_weights]
+         end
+       end
+     end
+   end
+ end
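Stripped of batching and validation, `ScaledDotProduct#forward` computes softmax(QKᵀ / √head_dim) · V, with positions where the boolean mask is true filled with -1e8 before the softmax. The plain-Ruby sketch below does this for a single head using nested Arrays; `scaled_dot_product_sketch` is hypothetical. It omits dropout (equivalent to `dropout: 0.0` or eval mode), `bias_k`/`bias_v`, and batching, and it scales the scores rather than the query, which gives the same result.

```ruby
# Hypothetical plain-Ruby sketch of one attention head (no batching, no dropout,
# no bias_k/bias_v). query: tgt_len x head_dim, key/value: src_len x head_dim,
# mask: optional tgt_len x src_len booleans where true means "masked out".
def scaled_dot_product_sketch(query, key, value, mask: nil)
  scale = query.first.length.to_f**-0.5

  weights = query.each_with_index.map do |q, i|
    # scaled dot products; masked positions get -1e8 like masked_fill! in the gem
    scores = key.each_with_index.map do |k, j|
      mask && mask[i][j] ? -1e8 : q.zip(k).sum { |a, b| a * b } * scale
    end
    # softmax over source positions, shifted by the max for numerical stability
    max = scores.max
    exps = scores.map { |s| Math.exp(s - max) }
    total = exps.sum
    exps.map { |e| e / total }
  end

  # each output row is the attention-weighted sum of the value rows
  output = weights.map do |row|
    (0...value.first.length).map do |d|
      row.each_with_index.sum { |w, j| w * value[j][d] }
    end
  end

  [output, weights]
end

query = [[1.0, 0.0]]
key = [[1.0, 0.0], [0.0, 1.0]]
value = [[1.0, 2.0], [3.0, 4.0]]
scaled_dot_product_sketch(query, key, value, mask: [[false, true]])
# => [[[1.0, 2.0]], [[1.0, 0.0]]]
```

Masking the second key leaves all of the weight on the first, so the output row is just the first value row.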
data/lib/torchtext/version.rb CHANGED
@@ -1,3 +1,3 @@
  module TorchText
-   VERSION = "0.1.0"
+   VERSION = "0.2.0"
  end
data/lib/torchtext.rb CHANGED
@@ -8,11 +8,16 @@ require "rubygems/package"
  require "set"
 
  # modules
- require "torchtext/data/utils"
- require "torchtext/datasets/text_classification"
- require "torchtext/datasets/text_classification_dataset"
- require "torchtext/vocab"
- require "torchtext/version"
+ require_relative "torchtext/data/functional"
+ require_relative "torchtext/data/metrics"
+ require_relative "torchtext/data/utils"
+ require_relative "torchtext/datasets/text_classification"
+ require_relative "torchtext/datasets/text_classification_dataset"
+ require_relative "torchtext/nn/in_proj_container"
+ require_relative "torchtext/nn/multihead_attention_container"
+ require_relative "torchtext/nn/scaled_dot_product"
+ require_relative "torchtext/vocab"
+ require_relative "torchtext/version"
 
  module TorchText
    class Error < StandardError; end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchtext
  version: !ruby/object:Gem::Version
-   version: 0.1.0
+   version: 0.2.0
  platform: ruby
  authors:
  - Andrew Kane
- autorequire:
+ autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-08-24 00:00:00.000000000 Z
+ date: 2023-01-30 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: torch-rb
@@ -16,58 +16,16 @@ dependencies:
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: 0.3.2
+         version: 0.11.1
    type: :runtime
    prerelease: false
    version_requirements: !ruby/object:Gem::Requirement
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: 0.3.2
- - !ruby/object:Gem::Dependency
-   name: bundler
-   requirement: !ruby/object:Gem::Requirement
-     requirements:
-     - - ">="
-       - !ruby/object:Gem::Version
-         version: '0'
-   type: :development
-   prerelease: false
-   version_requirements: !ruby/object:Gem::Requirement
-     requirements:
-     - - ">="
-       - !ruby/object:Gem::Version
-         version: '0'
- - !ruby/object:Gem::Dependency
-   name: rake
-   requirement: !ruby/object:Gem::Requirement
-     requirements:
-     - - ">="
-       - !ruby/object:Gem::Version
-         version: '0'
-   type: :development
-   prerelease: false
-   version_requirements: !ruby/object:Gem::Requirement
-     requirements:
-     - - ">="
-       - !ruby/object:Gem::Version
-         version: '0'
- - !ruby/object:Gem::Dependency
-   name: minitest
-   requirement: !ruby/object:Gem::Requirement
-     requirements:
-     - - ">="
-       - !ruby/object:Gem::Version
-         version: '5'
-   type: :development
-   prerelease: false
-   version_requirements: !ruby/object:Gem::Requirement
-     requirements:
-     - - ">="
-       - !ruby/object:Gem::Version
-         version: '5'
- description:
- email: andrew@chartkick.com
+         version: 0.11.1
+ description:
+ email: andrew@ankane.org
  executables: []
  extensions: []
  extra_rdoc_files: []
@@ -76,16 +34,21 @@ files:
  - LICENSE.txt
  - README.md
  - lib/torchtext.rb
+ - lib/torchtext/data/functional.rb
+ - lib/torchtext/data/metrics.rb
  - lib/torchtext/data/utils.rb
  - lib/torchtext/datasets/text_classification.rb
  - lib/torchtext/datasets/text_classification_dataset.rb
+ - lib/torchtext/nn/in_proj_container.rb
+ - lib/torchtext/nn/multihead_attention_container.rb
+ - lib/torchtext/nn/scaled_dot_product.rb
  - lib/torchtext/version.rb
  - lib/torchtext/vocab.rb
- homepage: https://github.com/ankane/torchtext
+ homepage: https://github.com/ankane/torchtext-ruby
  licenses:
  - BSD-3-Clause
  metadata: {}
- post_install_message:
+ post_install_message:
  rdoc_options: []
  require_paths:
  - lib
@@ -93,15 +56,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
    requirements:
    - - ">="
      - !ruby/object:Gem::Version
-       version: '2.5'
+       version: '2.7'
  required_rubygems_version: !ruby/object:Gem::Requirement
    requirements:
    - - ">="
      - !ruby/object:Gem::Version
        version: '0'
  requirements: []
- rubygems_version: 3.1.2
- signing_key:
+ rubygems_version: 3.4.1
+ signing_key:
  specification_version: 4
  summary: Data loaders and abstractions for text and NLP
  test_files: []
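Two runtime constraints tighten in 0.2.0: `required_ruby_version` moves from `>= 2.5` to `>= 2.7`, and the `torch-rb` dependency moves from `>= 0.3.2` to `>= 0.11.1`. RubyGems and Bundler enforce both at install time; the sketch below evaluates them by hand with RubyGems' own `Gem::Requirement` and `Gem::Version` classes. `torchtext_020_compatible?` is a hypothetical helper, and the constraint strings are copied from the metadata diff.

```ruby
require "rubygems"

# Hypothetical check of torchtext 0.2.0's declared constraints, copied from the
# metadata diff: required_ruby_version >= 2.7 and torch-rb >= 0.11.1.
RUBY_REQUIREMENT = Gem::Requirement.new(">= 2.7")
TORCH_RB_REQUIREMENT = Gem::Requirement.new(">= 0.11.1")

def torchtext_020_compatible?(ruby_version, torch_rb_version)
  RUBY_REQUIREMENT.satisfied_by?(Gem::Version.new(ruby_version)) &&
    TORCH_RB_REQUIREMENT.satisfied_by?(Gem::Version.new(torch_rb_version))
end

torchtext_020_compatible?(RUBY_VERSION, "0.11.1")
```

`Gem::Version` compares segment by segment as integers, so `0.3.2` sorts below `0.11.1` even though a plain string comparison would put it above.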