informers 0.1.2 → 0.1.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 30960cffae248b704482b2faaa2d573cb9d9cf491543d8c9593b8c937d997a9f
4
- data.tar.gz: c6ce6d049ae38eb6a154fb7cd2fbef2cd709a71663f2ffc6951869687ae779e7
3
+ metadata.gz: 2baf3ed7ae9b6bf6a1347f0dc880ae3a48f26daa518112e37d6bf03927faed67
4
+ data.tar.gz: 03cd4f92aa6a062fc23ca712369a8cf1db5300bb53b1eb99ad8d71574a1a8ce6
5
5
  SHA512:
6
- metadata.gz: eb3382ec97e9ffbf7dbada8440290c2c7a2155574d2b5c3a14357dffcf19ac8b36256a0ae2265923cdfb6efed30cc1d1f3106576edfeaa652853c42d18f80063
7
- data.tar.gz: 13bc7da32218b600d49d0289dfcb258b25a2e3cce4793398eb92e8e48c2886cfcf7f944e7f666fa5d4170c0db8eade00599d4d4333974eea5e285b0ecf946b5d
6
+ metadata.gz: cfef17a6c7b9a574c43f3f45cc4f20bb36c1d764f6c68f47036f41a7af9a54aecf1a678eda1e3f3f7b0da26ff8131e22dc13d56e16e461366b67d8b6b0d77e97
7
+ data.tar.gz: 8c99136eb43350c118402e0ac076055d4ab563e6f185d06c9b826c5d592a3955f64b2ea5284d32583e4725229201514a30527401f454cde45944fb54f9dd0b97
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.1.3 (2021-09-25)
2
+
3
+ - Added text generation
4
+ - Added fill mask
5
+
1
6
  ## 0.1.2 (2020-11-24)
2
7
 
3
8
  - Added feature extraction
data/README.md CHANGED
@@ -7,7 +7,7 @@ Supports:
7
7
  - Sentiment analysis
8
8
  - Question answering
9
9
  - Named-entity recognition
10
- - Text generation - *in development*
10
+ - Text generation
11
11
  - Summarization - *in development*
12
12
  - Translation - *in development*
13
13
 
@@ -21,17 +21,12 @@ Add this line to your application’s Gemfile:
21
21
  gem 'informers'
22
22
  ```
23
23
 
24
- On Mac, also install OpenMP:
25
-
26
- ```sh
27
- brew install libomp
28
- ```
29
-
30
24
  ## Getting Started
31
25
 
32
26
  - [Sentiment analysis](#sentiment-analysis)
33
27
  - [Question answering](#question-answering)
34
28
  - [Named-entity recognition](#named-entity-recognition)
29
+ - [Text Generation](#text-generation)
35
30
 
36
31
  ### Sentiment Analysis
37
32
 
@@ -58,11 +53,7 @@ model.predict(["This is super cool", "I didn't like it"])
58
53
 
59
54
  ### Question Answering
60
55
 
61
- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx) and add Numo to your application’s Gemfile:
62
-
63
- ```ruby
64
- gem 'numo-narray'
65
- ```
56
+ First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
66
57
 
67
58
  Ask a question with some context
68
59
 
@@ -101,6 +92,23 @@ This returns
101
92
  ]
102
93
  ```
103
94
 
95
+ ### Text Generation
96
+
97
+ First, export the [pretrained model](tools/export.md).
98
+
99
+ Pass a prompt
100
+
101
+ ```ruby
102
+ model = Informers::TextGeneration.new("text-generation.onnx")
103
+ model.predict("As far as I am concerned, I will", max_length: 50)
104
+ ```
105
+
106
+ This returns
107
+
108
+ ```text
109
+ As far as I am concerned, I will be the first to admit that I am not a fan of the idea of a "free market." I think that the idea of a free market is a bit of a stretch. I think that the idea
110
+ ```
111
+
104
112
  ## Models
105
113
 
106
114
  Task | Description | Contributor | License | Link
@@ -108,8 +116,9 @@ Task | Description | Contributor | License | Link
108
116
  Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
109
117
  Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
110
118
  Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
119
+ Text generation | GPT-2 | Hugging Face | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
111
120
 
112
- Models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
121
+ Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
113
122
 
114
123
  ## Deployment
115
124
 
data/lib/informers/fill_mask.rb ADDED
@@ -0,0 +1,108 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ # Copyright 2021 Andrew Kane.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Informers
+   class FillMask
+     def initialize(model_path)
+       encoder_path = File.expand_path("../../vendor/roberta.bin", __dir__)
+       @encoder = BlingFire.load_model(encoder_path, prefix: false)
+
+       decoder_path = File.expand_path("../../vendor/roberta.i2w", __dir__)
+       @decoder = BlingFire.load_model(decoder_path)
+
+       @model = OnnxRuntime::Model.new(model_path)
+     end
+
+     def predict(texts)
+       singular = !texts.is_a?(Array)
+       texts = [texts] if singular
+
+       mask_token = 50264
+
+       # tokenize
+       input_ids =
+         texts.map do |text|
+           tokens = @encoder.text_to_ids(text, nil, 3) # unk token
+
+           # add mask token
+           mask_sequence = [28696, 43776, 15698]
+           masks = []
+           (tokens.size - 2).times do |i|
+             masks << i if tokens[i..(i + 2)] == mask_sequence
+           end
+           masks.reverse.each do |mask|
+             tokens = tokens[0...mask] + [mask_token] + tokens[(mask + 3)..-1]
+           end
+
+           tokens.unshift(0) # cls token
+           tokens << 2 # sep token
+
+           tokens
+         end
+
+       max_tokens = input_ids.map(&:size).max
+       attention_mask = []
+       input_ids.each do |ids|
+         zeros = [0] * (max_tokens - ids.size)
+
+         mask = ([1] * ids.size) + zeros
+         attention_mask << mask
+
+         ids.concat(zeros)
+       end
+
+       input = {
+         input_ids: input_ids,
+         attention_mask: attention_mask
+       }
+
+       masked_index = input_ids.map { |v| v.each_index.select { |i| v[i] == mask_token } }
+       masked_index.each do |v|
+         raise "No mask_token (<mask>) found on the input" if v.size < 1
+         raise "More than one mask_token (<mask>) is not supported" if v.size > 1
+       end
+
+       outputs = @model.predict(input)["output_0"]
+       batch_size = outputs.size
+
+       results = []
+       batch_size.times do |i|
+         result = []
+
+         logits = outputs[i][masked_index[i][0]]
+         values = logits.map { |v| Math.exp(v) }
+         sum = values.sum
+         probs = values.map { |v| v / sum }
+         res = probs.each_with_index.sort_by { |v| -v[0] }.first(5)
+
+         res.each do |(v, p)|
+           tokens = input[:input_ids][i].dup
+           tokens[masked_index[i][0]] = p
+           result << {
+             sequence: @decoder.ids_to_text(tokens),
+             score: v,
+             token: p,
+             # TODO figure out prefix space
+             token_str: @decoder.ids_to_text([p], skip_special_tokens: false)
+           }
+         end
+
+         results += [result]
+       end
+
+       singular ? results.first : results
+     end
+   end
+ end
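The two least obvious steps in `FillMask#predict` are the mask splice and the scoring. The sketch below isolates them in plain Ruby so they can be read without BlingFire or ONNX Runtime. The helper names `splice_mask` and `top_k_softmax` are hypothetical and not part of the gem; the ids are the ones hard-coded in the file above, and the three-id sequence appears to be what the encoder produces for the literal text `<mask>`. One deliberate difference from the gem: the sketch subtracts the largest logit before exponentiating, a common stability trick that yields the same probabilities mathematically.

```ruby
# Hypothetical helpers illustrating two steps of FillMask#predict.
MASK_SEQUENCE = [28696, 43776, 15698].freeze # ids hard-coded in fill_mask.rb
MASK_TOKEN = 50264                           # single mask id expected by the model

# Replace every occurrence of MASK_SEQUENCE with the single MASK_TOKEN id.
def splice_mask(tokens)
  width = MASK_SEQUENCE.size
  starts = (0..tokens.size - width).select { |i| tokens[i, width] == MASK_SEQUENCE }
  # Work right to left so the earlier start positions stay valid.
  starts.reverse_each do |i|
    tokens = tokens[0...i] + [MASK_TOKEN] + tokens[(i + width)..-1]
  end
  tokens
end

# Turn raw logits into probabilities and return the k best as [probability, token_id] pairs.
def top_k_softmax(logits, k = 5)
  max = logits.max
  exps = logits.map { |v| Math.exp(v - max) } # shift by max to avoid overflow
  sum = exps.sum
  exps.each_with_index.map { |e, i| [e / sum, i] }.max_by(k) { |prob, _| prob }
end

spliced = splice_mask([100, 28696, 43776, 15698, 200])
best = top_k_softmax([0.0, Math.log(3.0)], 1)
puts spliced.inspect
puts best.inspect
```

In the gem, the pairs returned by the scoring step are then written back into a copy of the input ids at the mask position and decoded into the `sequence` and `token_str` fields.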
data/lib/informers/question_answering.rb CHANGED
@@ -16,9 +16,6 @@
16
16
  module Informers
17
17
  class QuestionAnswering
18
18
  def initialize(model_path)
19
- # make sure Numo is available
20
- require "numo/narray"
21
-
22
19
  tokenizer_path = File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
23
20
  @tokenizer = BlingFire.load_model(tokenizer_path)
24
21
  @model = OnnxRuntime::Model.new(model_path)
data/lib/informers/text_generation.rb ADDED
@@ -0,0 +1,44 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ # Copyright 2021 Andrew Kane.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Informers
+   class TextGeneration
+     def initialize(model_path)
+       encoder_path = File.expand_path("../../vendor/gpt2.bin", __dir__)
+       @encoder = BlingFire.load_model(encoder_path, prefix: false)
+
+       decoder_path = File.expand_path("../../vendor/gpt2.i2w", __dir__)
+       @decoder = BlingFire.load_model(decoder_path)
+
+       @model = OnnxRuntime::Model.new(model_path)
+     end
+
+     def predict(text, max_length: 50)
+       tokens = @encoder.text_to_ids(text)
+
+       input = {
+         input_ids: [tokens]
+       }
+
+       (max_length - tokens.size).times do |i|
+         output = @model.predict(input, output_type: :numo, output_names: ["output_0"])
+         # passed to input_ids
+         tokens << output["output_0"][0, true, true][-1, true].max_index
+       end
+
+       @decoder.ids_to_text(tokens)
+     end
+   end
+ end
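The loop in `TextGeneration#predict` is greedy decoding: at each step it takes the logits for the last position and appends the highest-scoring token id, until the sequence (prompt included) reaches `max_length`. The following is a minimal sketch of that idea in plain Ruby, assuming a stand-in block in place of the ONNX model. The names `argmax` and `greedy_generate` and the toy model are hypothetical and exist only for illustration.

```ruby
# Index of the largest value in a row of scores.
def argmax(row)
  row.each_with_index.max_by { |value, _| value }.last
end

# Greedy decoding: repeatedly append the best next token id.
# The block stands in for the model and must return one row of
# vocabulary scores per input position.
def greedy_generate(prompt_ids, max_length:, &logits_for)
  tokens = prompt_ids.dup
  # max_length counts the prompt, so a long prompt leaves fewer steps
  # (and a prompt at or above max_length leaves none).
  (max_length - tokens.size).times do
    logits = logits_for.call(tokens)
    tokens << argmax(logits.last) # only the last position predicts the next token
  end
  tokens
end

# Toy model over a 4-token vocabulary that always prefers (token + 1) % 4.
toy_model = lambda do |tokens|
  tokens.map { |t| Array.new(4) { |j| j == (t + 1) % 4 ? 1.0 : 0.0 } }
end

puts greedy_generate([0], max_length: 5, &toy_model).inspect
```

Because the choice at each step is a plain argmax, the same prompt always yields the same continuation, which matches the deterministic output shown in the README example.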
data/lib/informers/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Informers
2
- VERSION = "0.1.2"
2
+ VERSION = "0.1.3"
3
3
  end
data/lib/informers.rb CHANGED
@@ -1,10 +1,13 @@
1
1
  # dependencies
2
2
  require "blingfire"
3
+ require "numo/narray"
3
4
  require "onnxruntime"
4
5
 
5
6
  # modules
6
7
  require "informers/feature_extraction"
8
+ require "informers/fill_mask"
7
9
  require "informers/ner"
8
10
  require "informers/question_answering"
9
11
  require "informers/sentiment_analysis"
12
+ require "informers/text_generation"
10
13
  require "informers/version"
data/vendor/LICENSE-gpt2.txt ADDED
@@ -0,0 +1,24 @@
1
+ Modified MIT License
2
+
3
+ Software Copyright (c) 2019 OpenAI
4
+
5
+ We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
6
+ We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
9
+ associated documentation files (the "Software"), to deal in the Software without restriction,
10
+ including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
11
+ and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
12
+ subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included
15
+ in all copies or substantial portions of the Software.
16
+ The above copyright notice and this permission notice need not be included
17
+ with content created by the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
20
+ INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
22
+ BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23
+ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ OR OTHER DEALINGS IN THE SOFTWARE.
data/vendor/LICENSE-roberta.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data/vendor/gpt2.bin ADDED
Binary file
data/vendor/gpt2.i2w ADDED
Binary file
data/vendor/roberta.bin ADDED
Binary file
data/vendor/roberta.i2w ADDED
Binary file
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: informers
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.1.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-11-24 00:00:00.000000000 Z
11
+ date: 2021-09-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: blingfire
@@ -16,16 +16,16 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.1.3
19
+ version: 0.1.7
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.1.3
26
+ version: 0.1.7
27
27
  - !ruby/object:Gem::Dependency
28
- name: onnxruntime
28
+ name: numo-narray
29
29
  requirement: !ruby/object:Gem::Requirement
30
30
  requirements:
31
31
  - - ">="
@@ -39,61 +39,19 @@ dependencies:
39
39
  - !ruby/object:Gem::Version
40
40
  version: '0'
41
41
  - !ruby/object:Gem::Dependency
42
- name: bundler
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: rake
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: minitest
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '5'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '5'
83
- - !ruby/object:Gem::Dependency
84
- name: numo-narray
42
+ name: onnxruntime
85
43
  requirement: !ruby/object:Gem::Requirement
86
44
  requirements:
87
45
  - - ">="
88
46
  - !ruby/object:Gem::Version
89
- version: '0'
90
- type: :development
47
+ version: 0.5.1
48
+ type: :runtime
91
49
  prerelease: false
92
50
  version_requirements: !ruby/object:Gem::Requirement
93
51
  requirements:
94
52
  - - ">="
95
53
  - !ruby/object:Gem::Version
96
- version: '0'
54
+ version: 0.5.1
97
55
  description:
98
56
  email: andrew@chartkick.com
99
57
  executables: []
@@ -105,14 +63,22 @@ files:
105
63
  - README.md
106
64
  - lib/informers.rb
107
65
  - lib/informers/feature_extraction.rb
66
+ - lib/informers/fill_mask.rb
108
67
  - lib/informers/ner.rb
109
68
  - lib/informers/question_answering.rb
110
69
  - lib/informers/sentiment_analysis.rb
70
+ - lib/informers/text_generation.rb
111
71
  - lib/informers/version.rb
112
72
  - vendor/LICENSE-bert.txt
113
73
  - vendor/LICENSE-blingfire.txt
74
+ - vendor/LICENSE-gpt2.txt
75
+ - vendor/LICENSE-roberta.txt
114
76
  - vendor/bert_base_cased_tok.bin
115
77
  - vendor/bert_base_tok.bin
78
+ - vendor/gpt2.bin
79
+ - vendor/gpt2.i2w
80
+ - vendor/roberta.bin
81
+ - vendor/roberta.i2w
116
82
  homepage: https://github.com/ankane/informers
117
83
  licenses:
118
84
  - Apache-2.0
@@ -132,7 +98,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
132
98
  - !ruby/object:Gem::Version
133
99
  version: '0'
134
100
  requirements: []
135
- rubygems_version: 3.1.4
101
+ rubygems_version: 3.2.22
136
102
  signing_key:
137
103
  specification_version: 4
138
104
  summary: State-of-the-art natural language processing for Ruby