torchtext 0.1.0 → 0.2.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 86469f8148e519b940a643f81b5317d3e180d6ebc031da14cb0599b48e3f6556
4
- data.tar.gz: 499079c8a32de3ea6704b58a04ad8511f7a6784cc138b08e87696c69d7835863
3
+ metadata.gz: f9f88b060fcb69a0df1258c1a5399b2518ca9b1c95cd3c3f999bd81f9df0cc41
4
+ data.tar.gz: 3b04fab3f8e9a3ef86e3c46aa64f2ba98a8d85c3dd42889cb728d0c7f39e7be2
5
5
  SHA512:
6
- metadata.gz: e3ea0d3719d35a58b757ac3d11adeda30912f35f69f7de37047ef702c556e5384862f950e055565db8396d8495a760b4919fd416affbdc0fd815dc14ed02e3a3
7
- data.tar.gz: 16d2817864dc4bba2d54ca4a7288bc609b95ab5d59da51c67eccf70107fd78e67af141b765d01b8eea82c0a112b3dc779c174d7864afee6fb5651a5d787df7c5
6
+ metadata.gz: 4669a0ce275ec3ef0a1331c39835b81a8cff582c62012adc62ae8c59ea941befb53cf232198922bcbdd9795930e970a4e8970a851d66969531c75dd83e7bd40c
7
+ data.tar.gz: a3bb8659e334540849a011a0f526136650a4bd88d5a9da6b48f6e3d14dd889b08fe730c3e517837e6989e06c9a8b515d8a34958e3df267a80eedd465a239123f
data/CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
1
+ ## 0.2.0 (2023-01-30)
2
+
3
+ - Added `Functional` module
4
+ - Dropped support for Ruby < 2.7
5
+
6
+ ## 0.1.1 (2021-07-15)
7
+
8
+ - Added `NN` module
9
+ - Added `bleu_score` method
10
+
1
11
  ## 0.1.0 (2020-08-24)
2
12
 
3
13
  - First release
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
1
1
  BSD 3-Clause License
2
2
 
3
3
  Copyright (c) James Bradbury and Soumith Chintala 2016,
4
- Copyright (c) Andrew Kane 2020,
4
+ Copyright (c) Andrew Kane 2020-2023,
5
5
  All rights reserved.
6
6
 
7
7
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,13 +1,15 @@
1
- # TorchText
1
+ # TorchText Ruby
2
2
 
3
3
  :fire: Data loaders and abstractions for text and NLP - for Ruby
4
4
 
5
+ [![Build Status](https://github.com/ankane/torchtext-ruby/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchtext-ruby/actions)
6
+
5
7
  ## Installation
6
8
 
7
9
  Add this line to your application’s Gemfile:
8
10
 
9
11
  ```ruby
10
- gem 'torchtext'
12
+ gem "torchtext"
11
13
  ```
12
14
 
13
15
  ## Getting Started
@@ -19,7 +21,7 @@ This library follows the [Python API](https://pytorch.org/text/). Many methods a
19
21
  Text classification
20
22
 
21
23
  - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)
22
- - [Ruby code](examples/text_classification)
24
+ - [Ruby code](examples/text_classification.rb)
23
25
 
24
26
  ## Datasets
25
27
 
@@ -33,6 +35,37 @@ Supported datasets are:
33
35
 
34
36
  - [AG_NEWS](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
35
37
 
38
+ ## Data Utils
39
+
40
+ Supports:
41
+
42
+ - tokenizer
43
+ - ngrams_iterator
44
+
45
+ ## Data Metrics
46
+
47
+ Compute the BLEU score of a candidate corpus against its references
48
+
49
+ ```ruby
50
+ candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
51
+ references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
52
+ TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
53
+ ```
54
+
55
+ ## NN
56
+
57
+ Supports:
58
+
59
+ - InProjContainer
60
+ - MultiheadAttentionContainer
61
+ - ScaledDotProduct
62
+
63
+ ## Vocab
64
+
65
+ Supports:
66
+
67
+ - Vocab
68
+
36
69
  ## Disclaimer
37
70
 
38
71
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
@@ -41,22 +74,22 @@ If you’re a dataset owner and wish to update any details or remove it from thi
41
74
 
42
75
  ## History
43
76
 
44
- View the [changelog](https://github.com/ankane/torchtext/blob/master/CHANGELOG.md)
77
+ View the [changelog](https://github.com/ankane/torchtext-ruby/blob/master/CHANGELOG.md)
45
78
 
46
79
  ## Contributing
47
80
 
48
81
  Everyone is encouraged to help improve this project. Here are a few ways you can help:
49
82
 
50
- - [Report bugs](https://github.com/ankane/torchtext/issues)
51
- - Fix bugs and [submit pull requests](https://github.com/ankane/torchtext/pulls)
83
+ - [Report bugs](https://github.com/ankane/torchtext-ruby/issues)
84
+ - Fix bugs and [submit pull requests](https://github.com/ankane/torchtext-ruby/pulls)
52
85
  - Write, clarify, or fix documentation
53
86
  - Suggest or add new features
54
87
 
55
88
  To get started with development:
56
89
 
57
90
  ```sh
58
- git clone https://github.com/ankane/torchtext.git
59
- cd torchtext
91
+ git clone https://github.com/ankane/torchtext-ruby.git
92
+ cd torchtext-ruby
60
93
  bundle install
61
94
  bundle exec rake test
62
95
  ```
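The README lists `ngrams_iterator` under Data Utils, and `compute_ngram_counter` in the metrics file below splits its output back on spaces. Assuming the Ruby port keeps the ordering of Python torchtext's implementation (all unigrams first, then space-joined n-grams of length 2 up to the limit), its behavior can be sketched in plain Ruby. `NgramsSketch` and `ngrams` are hypothetical names, not the gem's API.

```ruby
# Hypothetical sketch (not the gem's code): yields every unigram first,
# then every n-gram of length 2..max_n joined with a single space.
module NgramsSketch
  def self.ngrams(tokens, max_n)
    return enum_for(:ngrams, tokens, max_n) unless block_given?

    tokens.each { |token| yield token }
    (2..max_n).each do |n|
      # each_cons(n) slides a window of n consecutive tokens
      tokens.each_cons(n) { |gram| yield gram.join(" ") }
    end
  end
end

NgramsSketch.ngrams(["here", "we", "are"], 2).to_a
# => ["here", "we", "are", "here we", "we are"]
```

Because the n-grams come back as space-joined strings, a token that itself contains a space would turn into extra tokens when the metrics code calls `x.split(" ")`. Counting `each_cons` windows directly on token arrays avoids that round trip.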
data/lib/torchtext/data/functional.rb ADDED
@@ -0,0 +1,11 @@
1
+ module TorchText
2
+ module Data
3
+ module Functional
4
+ class << self
5
+ def simple_space_split(iterator)
6
+ iterator.map(&:split)
7
+ end
8
+ end
9
+ end
10
+ end
11
+ end
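`simple_space_split` maps `String#split` (with no argument, so it splits on runs of whitespace) over whatever it receives. The result is therefore eager for an Array and lazy for an `Enumerator::Lazy`. The sketch below wraps the same one-liner in a hypothetical `FunctionalSketch` module so it runs without the gem.

```ruby
# Same one-liner as in the diff, wrapped in a hypothetical module so it runs standalone.
module FunctionalSketch
  def self.simple_space_split(iterator)
    # String#split with no argument splits on runs of whitespace
    # and ignores leading and trailing whitespace
    iterator.map(&:split)
  end
end

lines = ["hello world", "  torch   text "]
FunctionalSketch.simple_space_split(lines)
# => [["hello", "world"], ["torch", "text"]]

# On a lazy enumerator, map stays lazy, so lines are split only as they are consumed
FunctionalSketch.simple_space_split(lines.lazy).first(1)
# => [["hello", "world"]]
```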
data/lib/torchtext/data/metrics.rb ADDED
@@ -0,0 +1,68 @@
1
+ module TorchText
2
+ module Data
3
+ module Metrics
4
+ class << self
5
+ def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
6
+ unless max_n == weights.length
7
+ raise "Length of the \"weights\" list must be equal to max_n"
8
+ end
9
+ unless candidate_corpus.length == references_corpus.length
10
+ raise "The length of candidate and reference corpus should be the same"
11
+ end
12
+
13
+ clipped_counts = Torch.zeros(max_n)
14
+ total_counts = Torch.zeros(max_n)
15
+ weights = Torch.tensor(weights)
16
+
17
+ candidate_len = 0.0
18
+ refs_len = 0.0
19
+
20
+ candidate_corpus.zip(references_corpus) do |candidate, refs|
21
+ candidate_len += candidate.length
22
+
23
+ # Get the length of the reference that's closest in length to the candidate
24
+ refs_len_list = refs.map { |ref| ref.length.to_f }
25
+ refs_len += refs_len_list.min_by { |x| (candidate.length - x).abs }
26
+
27
+ reference_counters = compute_ngram_counter(refs[0], max_n)
28
+ refs[1..-1].each do |ref|
29
+ reference_counters = reference_counters.merge(compute_ngram_counter(ref, max_n)) { |_, v1, v2| v1 > v2 ? v1 : v2 }
30
+ end
31
+
32
+ candidate_counter = compute_ngram_counter(candidate, max_n)
33
+
34
+ shared_keys = candidate_counter.keys & reference_counters.keys
35
+ clipped_counter = candidate_counter.slice(*shared_keys).merge(reference_counters.slice(*shared_keys)) { |_, v1, v2| v1 < v2 ? v1 : v2 }
36
+
37
+ clipped_counter.each_key do |ngram|
38
+ clipped_counts[ngram.length - 1] += clipped_counter[ngram]
39
+ end
40
+
41
+ candidate_counter.each_key do |ngram|
42
+ total_counts[ngram.length - 1] += candidate_counter[ngram]
43
+ end
44
+ end
45
+
46
+ if clipped_counts.to_a.min == 0
47
+ 0.0
48
+ else
49
+ pn = clipped_counts / total_counts
50
+ log_pn = weights * Torch.log(pn)
51
+ score = Torch.exp(log_pn.sum)
52
+
53
+ bp = Math.exp([1 - refs_len / candidate_len, 0].min)
54
+
55
+ bp * score.item
56
+ end
57
+ end
58
+
59
+ private
60
+
61
+ def compute_ngram_counter(tokens, max_n)
62
+ raise "max_n must be greater than 0" unless max_n > 0
63
+ Hash[TorchText::Data::Utils.ngrams_iterator(tokens, max_n).map { |x| x.split(" ") }.group_by { |v| v }.map { |k, v| [k, v.size] }]
64
+ end
65
+ end
66
+ end
67
+ end
68
+ end
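To make the tensor code above easier to follow, here is the same corpus-level BLEU computation on plain Ruby arrays, as a hypothetical sketch (`bleu_sketch` and `ngram_counts` are not part of the gem). It counts n-grams with `each_cons` on token arrays instead of joining and re-splitting strings, clips each candidate count by its maximum count in any single reference, takes the weighted geometric mean of the clipped precisions, and multiplies by the brevity penalty exp(min(1 - r/c, 0)), where r sums, per sentence, the reference length closest to the candidate length.

```ruby
# Hypothetical helper: counts every n-gram (as a token array) of length 1..max_n
def ngram_counts(tokens, max_n)
  counts = Hash.new(0)
  (1..max_n).each do |n|
    tokens.each_cons(n) { |gram| counts[gram] += 1 }
  end
  counts
end

# Hypothetical pure-Ruby sketch of corpus-level BLEU (no Torch tensors)
def bleu_sketch(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
  raise ArgumentError, "weights must have max_n entries" unless weights.length == max_n
  raise ArgumentError, "corpora must have the same length" unless candidate_corpus.length == references_corpus.length

  clipped = Array.new(max_n, 0) # matched n-grams per order, clipped by references
  total = Array.new(max_n, 0)   # all candidate n-grams per order
  candidate_len = 0
  refs_len = 0

  candidate_corpus.zip(references_corpus) do |candidate, refs|
    candidate_len += candidate.length
    # Reference length closest to the candidate length (first one wins ties)
    refs_len += refs.map(&:length).min_by { |len| (candidate.length - len).abs }

    # Highest count of each n-gram in any single reference
    max_ref = Hash.new(0)
    refs.each do |ref|
      ngram_counts(ref, max_n).each { |gram, c| max_ref[gram] = [max_ref[gram], c].max }
    end

    ngram_counts(candidate, max_n).each do |gram, c|
      total[gram.length - 1] += c
      clipped[gram.length - 1] += [c, max_ref[gram]].min
    end
  end

  # Any zero precision makes the geometric mean zero
  return 0.0 if clipped.min.zero?

  log_precision = weights.each_with_index.sum { |w, i| w * Math.log(clipped[i].to_f / total[i]) }
  brevity_penalty = Math.exp([1 - refs_len.to_f / candidate_len, 0].min)
  brevity_penalty * Math.exp(log_precision)
end

candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
references_corpus = [[["My", "full", "pytorch", "test"], ["Completely", "Different"]], [["No", "Match"]]]
bleu_sketch(candidate_corpus, references_corpus)
```

With the README's example corpora, this returns about 0.8409 (0.5 ** 0.25): every n-gram of the first sentence matches, the second sentence adds two unmatched unigrams and one unmatched bigram (precisions 4/6, 3/4, 2/2, 1/1), and the brevity penalty is 1 because the closest reference lengths sum to the candidate length.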
@@ -139,7 +139,7 @@ module TorchText
139
139
  return to_path
140
140
  end
141
141
 
142
- raise "Not implemented yet"
142
+ raise "We currently only support tar.gz and tgz archives"
143
143
  end
144
144
  end
145
145
 
data/lib/torchtext/nn/in_proj_container.rb ADDED
@@ -0,0 +1,16 @@
1
+ module TorchText
2
+ module NN
3
+ class InProjContainer < Torch::NN::Module
4
+ def initialize(query_proj, key_proj, value_proj)
5
+ super()
6
+ @query_proj = query_proj
7
+ @key_proj = key_proj
8
+ @value_proj = value_proj
9
+ end
10
+
11
+ def forward(query, key, value)
12
+ [@query_proj.call(query), @key_proj.call(key), @value_proj.call(value)]
13
+ end
14
+ end
15
+ end
16
+ end
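`InProjContainer` only stores three modules and applies each one to a single input, so its contract can be shown without Torch. In real use the projections would presumably be `Torch::NN::Linear` layers, as in the Python torchtext examples; the class `InProjSketch` and the scaling lambdas below are hypothetical stand-ins. Because the projections are independent, query, key, and value may come from different sources, as long as the projected feature sizes agree downstream.

```ruby
# Hypothetical torch-free stand-in for InProjContainer: any objects that
# respond to #call (lambdas here) can play the role of the three projections.
class InProjSketch
  def initialize(query_proj, key_proj, value_proj)
    @query_proj = query_proj
    @key_proj = key_proj
    @value_proj = value_proj
  end

  # Each input goes through its own projection; no weights are shared
  def call(query, key, value)
    [@query_proj.call(query), @key_proj.call(key), @value_proj.call(value)]
  end
end

# Toy "projections" that scale every element by a constant
scale = ->(factor) { ->(xs) { xs.map { |x| x * factor } } }
in_proj = InProjSketch.new(scale.(1), scale.(2), scale.(3))
in_proj.call([1, 2], [1, 2], [1, 2])
# => [[1, 2], [2, 4], [3, 6]]
```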
data/lib/torchtext/nn/multihead_attention_container.rb ADDED
@@ -0,0 +1,50 @@
1
+ module TorchText
2
+ module NN
3
+ class MultiheadAttentionContainer < Torch::NN::Module
4
+ def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
5
+ super()
6
+ @nhead = nhead
7
+ @in_proj_container = in_proj_container
8
+ @attention_layer = attention_layer
9
+ @out_proj = out_proj
10
+ @batch_first = batch_first
11
+ end
12
+
13
+ def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
14
+ if @batch_first
15
+ query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
16
+ end
17
+
18
+ tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
19
+ q, k, v = @in_proj_container.call(query, key, value)
20
+ unless q.size(-1) % @nhead == 0
21
+ raise "query's embed_dim must be divisible by the number of heads"
22
+ end
23
+ head_dim = q.size(-1).div(@nhead)
24
+ q = q.reshape(tgt_len, bsz * @nhead, head_dim)
25
+
26
+ unless k.size(-1) % @nhead == 0
27
+ raise "key's embed_dim must be divisible by the number of heads"
28
+ end
29
+ head_dim = k.size(-1).div(@nhead)
30
+ k = k.reshape(src_len, bsz * @nhead, head_dim)
31
+
32
+ unless v.size(-1) % @nhead == 0
33
+ raise "value's embed_dim must be divisible by the number of heads"
34
+ end
35
+ head_dim = v.size(-1).div(@nhead)
36
+ v = v.reshape(src_len, bsz * @nhead, head_dim)
37
+
38
+ attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
39
+ attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
40
+ attn_output = @out_proj.call(attn_output)
41
+
42
+ if @batch_first
43
+ attn_output = attn_output.transpose(-3, -2)
44
+ end
45
+
46
+ [attn_output, attn_output_weights]
47
+ end
48
+ end
49
+ end
50
+ end
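The three reshapes in `MultiheadAttentionContainer#forward` fold the heads into the batch axis, so `ScaledDotProduct` can treat every (batch, head) pair as an independent sequence. Assuming the in-projections keep the embedding size `embed_dim`, the resulting shapes can be computed without Torch. `per_head_shapes` and `head_row` are hypothetical helpers; `head_row` shows which row of the `bsz * nhead` axis holds a given batch element and head after a row-major reshape, which is also why an `attn_mask` whose first dimension is not 1 must have size `bsz * nhead`.

```ruby
# Hypothetical helper (not in the gem): shapes produced by the reshapes in
# MultiheadAttentionContainer#forward for sequence-first input.
def per_head_shapes(tgt_len:, src_len:, bsz:, embed_dim:, nhead:)
  unless (embed_dim % nhead).zero?
    raise ArgumentError, "embed_dim must be divisible by the number of heads"
  end

  head_dim = embed_dim / nhead
  {
    q: [tgt_len, bsz * nhead, head_dim], # query: (L, N, E) -> (L, N * H, E / H)
    k: [src_len, bsz * nhead, head_dim], # key:   (S, N, E) -> (S, N * H, E / H)
    v: [src_len, bsz * nhead, head_dim], # value: (S, N, E) -> (S, N * H, E / H)
    output: [tgt_len, bsz, embed_dim]    # attention output reshaped back
  }
end

# Row of the flattened (N * H) axis holding batch element b, head h,
# when (N, E) is reshaped to (N * H, E / H) in row-major order
def head_row(b, h, nhead)
  b * nhead + h
end

per_head_shapes(tgt_len: 21, src_len: 21, bsz: 64, embed_dim: 10, nhead: 5)[:q]
# => [21, 320, 2]
```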
data/lib/torchtext/nn/scaled_dot_product.rb ADDED
@@ -0,0 +1,72 @@
1
+ module TorchText
2
+ module NN
3
+ class ScaledDotProduct < Torch::NN::Module
4
+ def initialize(dropout: 0.0, batch_first: false)
5
+ super()
6
+ @dropout = dropout
7
+ @batch_first = batch_first
8
+ end
9
+
10
+ def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
11
+ if @batch_first
12
+ query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
13
+ end
14
+
15
+ if !bias_k.nil? && !bias_v.nil?
16
+ unless key.size(-1) == bias_k.size(-1) && key.size(-2) == bias_k.size(-2) && bias_k.size(-3) == 1
17
+ raise "Shape of bias_k is not supported"
18
+ end
19
+ unless value.size(-1) == bias_v.size(-1) && value.size(-2) == bias_v.size(-2) && bias_v.size(-3) == 1
20
+ raise "Shape of bias_v is not supported"
21
+ end
22
+ key = Torch.cat([key, bias_k])
23
+ value = Torch.cat([value, bias_v])
24
+ if !attn_mask.nil?
25
+ attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1])
26
+ end
27
+ end
28
+
29
+ tgt_len, head_dim = query.size(-3), query.size(-1)
30
+ unless query.size(-1) == key.size(-1) && key.size(-1) == value.size(-1)
31
+ raise "The feature dim of query, key, value must be equal."
32
+ end
33
+ unless key.size() == value.size()
34
+ raise "Shape of key, value must match"
35
+ end
36
+ src_len = key.size(-3)
37
+ batch_heads = [query.size(-2), key.size(-2)].max
38
+
39
+ # Scale query
40
+ query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
41
+ query = query * (head_dim.to_f ** -0.5)
42
+ if !attn_mask.nil?
43
+ if attn_mask.dim() != 3
44
+ raise RuntimeError, "attn_mask must be a 3D tensor."
45
+ end
46
+ if (attn_mask.size(-1) != src_len) || (attn_mask.size(-2) != tgt_len) || (attn_mask.size(-3) != 1 && attn_mask.size(-3) != batch_heads)
47
+ raise RuntimeError, "The size of the attn_mask is not correct."
48
+ end
49
+ if attn_mask.dtype != :bool
50
+ raise RuntimeError, "Only bool tensor is supported for attn_mask"
51
+ end
52
+ end
53
+
54
+ # Dot product of q, k
55
+ attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
56
+ if !attn_mask.nil?
57
+ # TODO confirm last argument
58
+ attn_output_weights.masked_fill!(attn_mask, -1e8, nil)
59
+ end
60
+ attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
61
+ attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
62
+ attn_output = Torch.matmul(attn_output_weights, value)
63
+
64
+ if @batch_first
65
+ [attn_output, attn_output_weights]
66
+ else
67
+ [attn_output.transpose(-3, -2), attn_output_weights]
68
+ end
69
+ end
70
+ end
71
+ end
72
+ end
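`ScaledDotProduct` computes softmax(Q Kᵀ / sqrt(d)) V, filling positions where `attn_mask` is true with -1e8 before the softmax. A hypothetical single-head version on plain nested arrays makes the arithmetic concrete; `softmax` and `attention_sketch` are illustrative names, and dropout, `bias_k`/`bias_v`, and batching are left out. Scaling the scores after the dot product, as here, is equivalent to the gem's scaling of the query beforehand.

```ruby
# Numerically stable softmax over one row of scores
def softmax(row)
  max = row.max
  exps = row.map { |x| Math.exp(x - max) }
  total = exps.sum
  exps.map { |e| e / total }
end

# query: tgt_len rows; key/value: src_len rows; all rows have feature size d.
# mask[i][j] == true blocks query i from attending to key j.
def attention_sketch(query, key, value, mask: nil)
  scale = 1.0 / Math.sqrt(query.first.length)
  weights = query.each_with_index.map do |q, i|
    scores = key.each_with_index.map do |k, j|
      # -1e8 is the same fill value the gem passes to masked_fill!
      mask && mask[i][j] ? -1e8 : q.zip(k).sum { |a, b| a * b } * scale
    end
    softmax(scores)
  end
  # Each output row is the attention-weighted sum of the value rows
  output = weights.map do |w|
    (0...value.first.length).map { |d| w.each_with_index.sum { |wj, j| wj * value[j][d] } }
  end
  [output, weights]
end

query = [[1.0, 0.0]]
key = [[1.0, 0.0], [1.0, 0.0]]
value = [[2.0, 0.0], [0.0, 2.0]]
attention_sketch(query, key, value)                        # => [[[1.0, 1.0]], [[0.5, 0.5]]]
attention_sketch(query, key, value, mask: [[false, true]]) # => [[[2.0, 0.0]], [[1.0, 0.0]]]
```

Using a large negative constant rather than negative infinity means a row whose keys are all masked gets uniform weights instead of NaN.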
data/lib/torchtext/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module TorchText
2
- VERSION = "0.1.0"
2
+ VERSION = "0.2.0"
3
3
  end
data/lib/torchtext.rb CHANGED
@@ -8,11 +8,16 @@ require "rubygems/package"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "torchtext/data/utils"
12
- require "torchtext/datasets/text_classification"
13
- require "torchtext/datasets/text_classification_dataset"
14
- require "torchtext/vocab"
15
- require "torchtext/version"
11
+ require_relative "torchtext/data/functional"
12
+ require_relative "torchtext/data/metrics"
13
+ require_relative "torchtext/data/utils"
14
+ require_relative "torchtext/datasets/text_classification"
15
+ require_relative "torchtext/datasets/text_classification_dataset"
16
+ require_relative "torchtext/nn/in_proj_container"
17
+ require_relative "torchtext/nn/multihead_attention_container"
18
+ require_relative "torchtext/nn/scaled_dot_product"
19
+ require_relative "torchtext/vocab"
20
+ require_relative "torchtext/version"
16
21
 
17
22
  module TorchText
18
23
  class Error < StandardError; end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchtext
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-24 00:00:00.000000000 Z
11
+ date: 2023-01-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -16,58 +16,16 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.3.2
19
+ version: 0.11.1
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.3.2
27
- - !ruby/object:Gem::Dependency
28
- name: bundler
29
- requirement: !ruby/object:Gem::Requirement
30
- requirements:
31
- - - ">="
32
- - !ruby/object:Gem::Version
33
- version: '0'
34
- type: :development
35
- prerelease: false
36
- version_requirements: !ruby/object:Gem::Requirement
37
- requirements:
38
- - - ">="
39
- - !ruby/object:Gem::Version
40
- version: '0'
41
- - !ruby/object:Gem::Dependency
42
- name: rake
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: minitest
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '5'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '5'
69
- description:
70
- email: andrew@chartkick.com
26
+ version: 0.11.1
27
+ description:
28
+ email: andrew@ankane.org
71
29
  executables: []
72
30
  extensions: []
73
31
  extra_rdoc_files: []
@@ -76,16 +34,21 @@ files:
76
34
  - LICENSE.txt
77
35
  - README.md
78
36
  - lib/torchtext.rb
37
+ - lib/torchtext/data/functional.rb
38
+ - lib/torchtext/data/metrics.rb
79
39
  - lib/torchtext/data/utils.rb
80
40
  - lib/torchtext/datasets/text_classification.rb
81
41
  - lib/torchtext/datasets/text_classification_dataset.rb
42
+ - lib/torchtext/nn/in_proj_container.rb
43
+ - lib/torchtext/nn/multihead_attention_container.rb
44
+ - lib/torchtext/nn/scaled_dot_product.rb
82
45
  - lib/torchtext/version.rb
83
46
  - lib/torchtext/vocab.rb
84
- homepage: https://github.com/ankane/torchtext
47
+ homepage: https://github.com/ankane/torchtext-ruby
85
48
  licenses:
86
49
  - BSD-3-Clause
87
50
  metadata: {}
88
- post_install_message:
51
+ post_install_message:
89
52
  rdoc_options: []
90
53
  require_paths:
91
54
  - lib
@@ -93,15 +56,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
93
56
  requirements:
94
57
  - - ">="
95
58
  - !ruby/object:Gem::Version
96
- version: '2.5'
59
+ version: '2.7'
97
60
  required_rubygems_version: !ruby/object:Gem::Requirement
98
61
  requirements:
99
62
  - - ">="
100
63
  - !ruby/object:Gem::Version
101
64
  version: '0'
102
65
  requirements: []
103
- rubygems_version: 3.1.2
104
- signing_key:
66
+ rubygems_version: 3.4.1
67
+ signing_key:
105
68
  specification_version: 4
106
69
  summary: Data loaders and abstractions for text and NLP
107
70
  test_files: []