hugging-face 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile.lock +1 -1
- data/README.md +1 -1
- data/lib/hugging_face/base_api.rb +1 -1
- data/lib/hugging_face/inference_api.rb +17 -5
- data/lib/hugging_face/version.rb +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 93e213102e56d8e86de856912a1963f36e4b0a0d60f21c8a1f8d99b23a0f47d7
|
4
|
+
data.tar.gz: 877bd5a93e54ccbcb82943c19e4ac889657686aa656b0bd3bb911732b1d5d6ff
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ee534a915fa831884e42519f178ae709a08c201d79976d97b52319dd5322fc979c3dfe711a3e2210b0d53c0132490a089ccd2ca7c73730919f586582be0dc884
|
7
|
+
data.tar.gz: c4df146833ea986a86d9b9c060f594535b7bdf558dd9d28396dd100869c4a899dc4ddad1bdd2eba1ff8a09610ee4cada09a271adda9f4d6e58d21c5181c925a6
|
data/Gemfile.lock
CHANGED
data/README.md
CHANGED
data/lib/hugging_face/inference_api.rb
CHANGED
@@ -1,11 +1,15 @@
|
|
1
1
|
module HuggingFace
|
2
2
|
class InferenceApi < BaseApi
|
3
3
|
HOST = "https://api-inference.huggingface.co"
|
4
|
-
MAX_RETRY = 20
|
5
4
|
|
5
|
+
# Retry connecting to the model for 1 minute
|
6
|
+
MAX_RETRY = 60
|
7
|
+
|
8
|
+
# Default models that can be overridden by 'model' param
|
6
9
|
QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
|
7
10
|
SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
|
8
11
|
GENERATION_MODEL = "distilgpt2"
|
12
|
+
EMBEDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
9
13
|
|
10
14
|
def call(input:, model:)
|
11
15
|
request(connection: connection(model), input: input)
|
@@ -14,21 +18,29 @@ module HuggingFace
|
|
14
18
|
def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
|
15
19
|
input = { question: question, context: context }
|
16
20
|
|
17
|
-
request
|
21
|
+
request connection: connection(model), input: input
|
18
22
|
end
|
19
23
|
|
20
24
|
def summarization(input:, model: SUMMARIZATION_MODEL)
|
21
|
-
request
|
25
|
+
request connection: connection(model), input: { inputs: input }
|
22
26
|
end
|
23
27
|
|
24
28
|
def text_generation(input:, model: GENERATION_MODEL)
|
25
|
-
request
|
29
|
+
request connection: connection(model), input: { inputs: input }
|
30
|
+
end
|
31
|
+
|
32
|
+
def embedding(input:)
|
33
|
+
request connection: connection(EMBEDING_MODEL), input: { inputs: input }
|
26
34
|
end
|
27
35
|
|
28
36
|
private
|
29
37
|
|
30
38
|
def connection(model)
|
31
|
-
|
39
|
+
if model == EMBEDING_MODEL
|
40
|
+
build_connection "#{HOST}/pipeline/feature-extraction/#{model}"
|
41
|
+
else
|
42
|
+
build_connection "#{HOST}/models/#{model}"
|
43
|
+
end
|
32
44
|
end
|
33
45
|
|
34
46
|
def request(connection:, input:)
|
data/lib/hugging_face/version.rb
CHANGED