informers 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +22 -13
- data/lib/informers/fill_mask.rb +108 -0
- data/lib/informers/question_answering.rb +0 -3
- data/lib/informers/text_generation.rb +44 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +3 -0
- data/vendor/LICENSE-gpt2.txt +24 -0
- data/vendor/LICENSE-roberta.txt +21 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
- metadata +18 -52
checksums.yaml
CHANGED
```diff
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 2baf3ed7ae9b6bf6a1347f0dc880ae3a48f26daa518112e37d6bf03927faed67
+  data.tar.gz: 03cd4f92aa6a062fc23ca712369a8cf1db5300bb53b1eb99ad8d71574a1a8ce6
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: cfef17a6c7b9a574c43f3f45cc4f20bb36c1d764f6c68f47036f41a7af9a54aecf1a678eda1e3f3f7b0da26ff8131e22dc13d56e16e461366b67d8b6b0d77e97
+  data.tar.gz: 8c99136eb43350c118402e0ac076055d4ab563e6f185d06c9b826c5d592a3955f64b2ea5284d32583e4725229201514a30527401f454cde45944fb54f9dd0b97
```
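A `.gem` file is a plain tar archive containing `metadata.gz`, `data.tar.gz` and `checksums.yaml.gz`, and the hashes above are digests of the first two. The following is a minimal sketch of checking one of them with Ruby's standard `digest` library after unpacking the archive (for example with `tar -xf informers-0.1.3.gem`); `sha256_matches?` is a hypothetical helper, and the demo hashes a temporary file rather than a real archive.

```ruby
require "digest"
require "tempfile"

# Hypothetical helper: true when the file's SHA256 matches the expected hex
# digest (for example the metadata.gz entry under SHA256 in checksums.yaml).
def sha256_matches?(path, expected_hex)
  Digest::SHA256.file(path).hexdigest == expected_hex.strip.downcase
end

# Demo on a throwaway file instead of a real metadata.gz.
Tempfile.create("checksum-demo") do |f|
  f.write("hello")
  f.flush
  puts sha256_matches?(f.path, Digest::SHA256.hexdigest("hello")) # prints true
  puts sha256_matches?(f.path, "0" * 64)                          # prints false
end
```

The `SHA512` entries can be checked the same way with `Digest::SHA512`.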
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
````diff
@@ -7,7 +7,7 @@ Supports:
 - Sentiment analysis
 - Question answering
 - Named-entity recognition
-- Text generation
+- Text generation
 - Summarization - *in development*
 - Translation - *in development*
 
````
````diff
@@ -21,17 +21,12 @@ Add this line to your application’s Gemfile:
 gem 'informers'
 ```
 
-On Mac, also install OpenMP:
-
-```sh
-brew install libomp
-```
-
 ## Getting Started
 
 - [Sentiment analysis](#sentiment-analysis)
 - [Question answering](#question-answering)
 - [Named-entity recognition](#named-entity-recognition)
+- [Text Generation](#text-generation)
 
 ### Sentiment Analysis
 
````
````diff
@@ -58,11 +53,7 @@ model.predict(["This is super cool", "I didn't like it"])
 
 ### Question Answering
 
-First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx)
-
-```ruby
-gem 'numo-narray'
-```
+First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
 
 Ask a question with some context
 
````
````diff
@@ -101,6 +92,23 @@ This returns
 ]
 ```
 
+### Text Generation
+
+First, export the [pretrained model](tools/export.md).
+
+Pass a prompt
+
+```ruby
+model = Informers::TextGeneration.new("text-generation.onnx")
+model.predict("As far as I am concerned, I will", max_length: 50)
+```
+
+This returns
+
+```text
+As far as I am concerned, I will be the first to admit that I am not a fan of the idea of a "free market." I think that the idea of a free market is a bit of a stretch. I think that the idea
+```
+
 ## Models
 
 Task | Description | Contributor | License | Link
````
````diff
@@ -108,8 +116,9 @@ Task | Description | Contributor | License | Link
 Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
 Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
 Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
+Text generation | GPT-2 | Hugging Face | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
 
-
+Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
 
 ## Deployment
 
````
data/lib/informers/fill_mask.rb
ADDED
```diff
@@ -0,0 +1,108 @@
+# Copyright 2018 The HuggingFace Inc. team.
+# Copyright 2021 Andrew Kane.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Informers
+  class FillMask
+    def initialize(model_path)
+      encoder_path = File.expand_path("../../vendor/roberta.bin", __dir__)
+      @encoder = BlingFire.load_model(encoder_path, prefix: false)
+
+      decoder_path = File.expand_path("../../vendor/roberta.i2w", __dir__)
+      @decoder = BlingFire.load_model(decoder_path)
+
+      @model = OnnxRuntime::Model.new(model_path)
+    end
+
+    def predict(texts)
+      singular = !texts.is_a?(Array)
+      texts = [texts] if singular
+
+      mask_token = 50264
+
+      # tokenize
+      input_ids =
+        texts.map do |text|
+          tokens = @encoder.text_to_ids(text, nil, 3) # unk token
+
+          # add mask token
+          mask_sequence = [28696, 43776, 15698]
+          masks = []
+          (tokens.size - 2).times do |i|
+            masks << i if tokens[i..(i + 2)] == mask_sequence
+          end
+          masks.reverse.each do |mask|
+            tokens = tokens[0...mask] + [mask_token] + tokens[(mask + 3)..-1]
+          end
+
+          tokens.unshift(0) # cls token
+          tokens << 2 # sep token
+
+          tokens
+        end
+
+      max_tokens = input_ids.map(&:size).max
+      attention_mask = []
+      input_ids.each do |ids|
+        zeros = [0] * (max_tokens - ids.size)
+
+        mask = ([1] * ids.size) + zeros
+        attention_mask << mask
+
+        ids.concat(zeros)
+      end
+
+      input = {
+        input_ids: input_ids,
+        attention_mask: attention_mask
+      }
+
+      masked_index = input_ids.map { |v| v.each_index.select { |i| v[i] == mask_token } }
+      masked_index.each do |v|
+        raise "No mask_token (<mask>) found on the input" if v.size < 1
+        raise "More than one mask_token (<mask>) is not supported" if v.size > 1
+      end
+
+      outputs = @model.predict(input)["output_0"]
+      batch_size = outputs.size
+
+      results = []
+      batch_size.times do |i|
+        result = []
+
+        logits = outputs[i][masked_index[i][0]]
+        values = logits.map { |v| Math.exp(v) }
+        sum = values.sum
+        probs = values.map { |v| v / sum }
+        res = probs.each_with_index.sort_by { |v| -v[0] }.first(5)
+
+        res.each do |(v, p)|
+          tokens = input[:input_ids][i].dup
+          tokens[masked_index[i][0]] = p
+          result << {
+            sequence: @decoder.ids_to_text(tokens),
+            score: v,
+            token: p,
+            # TODO figure out prefix space
+            token_str: @decoder.ids_to_text([p], skip_special_tokens: false)
+          }
+        end
+
+        results += [result]
+      end
+
+      singular ? results.first : results
+    end
+  end
+end
```
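`FillMask#predict` boils down to three array transformations around the ONNX call: it collapses the three-id sequence `[28696, 43776, 15698]` (presumably how the RoBERTa tokenizer splits the literal text `<mask>`) into the single mask id `50264` and wraps each sequence in the cls (`0`) and sep (`2`) ids; it right-pads the batch with `0` and builds a matching attention mask; and it applies a softmax to the logits at the mask position and keeps the five most probable ids. The plain-Ruby sketch below reproduces those steps so they can be run without BlingFire or ONNX Runtime. The method and constant names are hypothetical, and the max-subtraction inside `top_k` is a stability tweak the gem's code does not include.

```ruby
# Ids used by fill_mask.rb: 50264 is the single mask id fed to the model;
# MASK_SEQUENCE is what the gem searches for in the tokenizer output
# (assumed here to be how the literal text "<mask>" is split).
MASK_TOKEN = 50264
MASK_SEQUENCE = [28696, 43776, 15698].freeze
CLS_ID = 0
SEP_ID = 2

# Collapse each MASK_SEQUENCE run into MASK_TOKEN, then add cls/sep ids.
def collapse_masks(tokens)
  starts = (0..tokens.size - 3).select { |i| tokens[i, 3] == MASK_SEQUENCE }
  # Replace from the right so earlier start indexes stay valid.
  starts.reverse_each do |s|
    tokens = tokens[0...s] + [MASK_TOKEN] + tokens[(s + 3)..]
  end
  [CLS_ID] + tokens + [SEP_ID]
end

# Right-pad every sequence with 0 and build the matching attention mask
# (1 = real token, 0 = padding), as the gem does before calling the model.
def pad_batch(batch)
  width = batch.map(&:size).max
  ids = batch.map { |seq| seq + [0] * (width - seq.size) }
  mask = batch.map { |seq| [1] * seq.size + [0] * (width - seq.size) }
  [ids, mask]
end

# Softmax over the vocabulary logits at the mask position, then the k most
# probable [probability, token_id] pairs. Subtracting the max only improves
# numerical stability; the resulting probabilities are unchanged.
def top_k(logits, k = 5)
  max = logits.max
  exps = logits.map { |v| Math.exp(v - max) }
  sum = exps.sum
  exps.each_with_index.map { |e, id| [e / sum, id] }.sort_by { |prob, _| -prob }.first(k)
end

tokens = collapse_masks([10, 28696, 43776, 15698, 11])
p tokens # => [0, 10, 50264, 11, 2]
ids, mask = pad_batch([tokens, [0, 5, 2]])
p ids  # => [[0, 10, 50264, 11, 2], [0, 5, 2, 0, 0]]
p mask # => [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
p top_k([1.0, 3.0, 2.0], 2).map(&:last) # => [1, 2]
```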
data/lib/informers/question_answering.rb
CHANGED
```diff
@@ -16,9 +16,6 @@
 module Informers
   class QuestionAnswering
     def initialize(model_path)
-      # make sure Numo is available
-      require "numo/narray"
-
       tokenizer_path = File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
       @tokenizer = BlingFire.load_model(tokenizer_path)
       @model = OnnxRuntime::Model.new(model_path)
```
data/lib/informers/text_generation.rb
ADDED
```diff
@@ -0,0 +1,44 @@
+# Copyright 2018 The HuggingFace Inc. team.
+# Copyright 2021 Andrew Kane.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Informers
+  class TextGeneration
+    def initialize(model_path)
+      encoder_path = File.expand_path("../../vendor/gpt2.bin", __dir__)
+      @encoder = BlingFire.load_model(encoder_path, prefix: false)
+
+      decoder_path = File.expand_path("../../vendor/gpt2.i2w", __dir__)
+      @decoder = BlingFire.load_model(decoder_path)
+
+      @model = OnnxRuntime::Model.new(model_path)
+    end
+
+    def predict(text, max_length: 50)
+      tokens = @encoder.text_to_ids(text)
+
+      input = {
+        input_ids: [tokens]
+      }
+
+      (max_length - tokens.size).times do |i|
+        output = @model.predict(input, output_type: :numo, output_names: ["output_0"])
+        # passed to input_ids
+        tokens << output["output_0"][0, true, true][-1, true].max_index
+      end
+
+      @decoder.ids_to_text(tokens)
+    end
+  end
+end
```
data/lib/informers/version.rb
CHANGED
data/lib/informers.rb
CHANGED
```diff
@@ -1,10 +1,13 @@
 # dependencies
 require "blingfire"
+require "numo/narray"
 require "onnxruntime"
 
 # modules
 require "informers/feature_extraction"
+require "informers/fill_mask"
 require "informers/ner"
 require "informers/question_answering"
 require "informers/sentiment_analysis"
+require "informers/text_generation"
 require "informers/version"
```
data/vendor/LICENSE-gpt2.txt
ADDED
```diff
@@ -0,0 +1,24 @@
+Modified MIT License
+
+Software Copyright (c) 2019 OpenAI
+
+We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
+We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the "Software"), to deal in the Software without restriction,
+including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included
+in all copies or substantial portions of the Software.
+The above copyright notice and this permission notice need not be included
+with content created by the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.
```
data/vendor/LICENSE-roberta.txt
ADDED
```diff
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
```
data/vendor/gpt2.bin
ADDED
Binary file

data/vendor/gpt2.i2w
ADDED
Binary file

data/vendor/roberta.bin
ADDED
Binary file

data/vendor/roberta.i2w
ADDED
Binary file

metadata
CHANGED
```diff
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: informers
 version: !ruby/object:Gem::Version
-  version: 0.1.
+  version: 0.1.3
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date:
+date: 2021-09-25 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: blingfire
```
```diff
@@ -16,16 +16,16 @@ dependencies:
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: 0.1.
+        version: 0.1.7
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: 0.1.
+        version: 0.1.7
 - !ruby/object:Gem::Dependency
-  name:
+  name: numo-narray
   requirement: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
```
```diff
@@ -39,61 +39,19 @@ dependencies:
       - !ruby/object:Gem::Version
         version: '0'
 - !ruby/object:Gem::Dependency
-  name:
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: rake
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: minitest
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '5'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '5'
-- !ruby/object:Gem::Dependency
-  name: numo-narray
+  name: onnxruntime
   requirement: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version:
-  type: :
+        version: 0.5.1
+  type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version:
+        version: 0.5.1
 description:
 email: andrew@chartkick.com
 executables: []
```
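The dependency hunks raise the runtime floors to `blingfire >= 0.1.7` and `onnxruntime >= 0.5.1` and remove the `rake` and `minitest` development-dependency entries from the published metadata. RubyGems evaluates such constraints with `Gem::Requirement#satisfied_by?`. A small sketch, where `unmet_dependencies` and the installed-version hash are hypothetical, shows which installed versions would fail the new floors:

```ruby
require "rubygems" # Gem::Requirement and Gem::Version ship with RubyGems

# Runtime constraints from the informers 0.1.3 gemspec metadata.
INFORMERS_0_1_3_DEPS = {
  "blingfire" => Gem::Requirement.new(">= 0.1.7"),
  "onnxruntime" => Gem::Requirement.new(">= 0.5.1")
}.freeze

# Hypothetical helper: names whose installed version (name => version string)
# is missing or does not satisfy the constraint.
def unmet_dependencies(installed, deps = INFORMERS_0_1_3_DEPS)
  deps.reject do |name, req|
    version = installed[name]
    version && req.satisfied_by?(Gem::Version.new(version))
  end.keys
end

p unmet_dependencies({ "blingfire" => "0.1.6", "onnxruntime" => "0.6.0" })
# => ["blingfire"]
```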
```diff
@@ -105,14 +63,22 @@ files:
 - README.md
 - lib/informers.rb
 - lib/informers/feature_extraction.rb
+- lib/informers/fill_mask.rb
 - lib/informers/ner.rb
 - lib/informers/question_answering.rb
 - lib/informers/sentiment_analysis.rb
+- lib/informers/text_generation.rb
 - lib/informers/version.rb
 - vendor/LICENSE-bert.txt
 - vendor/LICENSE-blingfire.txt
+- vendor/LICENSE-gpt2.txt
+- vendor/LICENSE-roberta.txt
 - vendor/bert_base_cased_tok.bin
 - vendor/bert_base_tok.bin
+- vendor/gpt2.bin
+- vendor/gpt2.i2w
+- vendor/roberta.bin
+- vendor/roberta.i2w
 homepage: https://github.com/ankane/informers
 licenses:
 - Apache-2.0
```
```diff
@@ -132,7 +98,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
     - !ruby/object:Gem::Version
       version: '0'
 requirements: []
-rubygems_version: 3.
+rubygems_version: 3.2.22
 signing_key:
 specification_version: 4
 summary: State-of-the-art natural language processing for Ruby
```