hugging-face 0.2.0 → 0.3.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 053432b377155f8768a2e25521b62c0fc5fdbb3a619f45875c346c5f3c126e81
-   data.tar.gz: d8551f8263129317a2692e5b6d3994617c1ae66285ce21a78a1796a2ad5f9f85
+   metadata.gz: 93e213102e56d8e86de856912a1963f36e4b0a0d60f21c8a1f8d99b23a0f47d7
+   data.tar.gz: 877bd5a93e54ccbcb82943c19e4ac889657686aa656b0bd3bb911732b1d5d6ff
  SHA512:
-   metadata.gz: 933fad128d5ad66370e0e24271498cc9730e0f71e9ed920cf5a3107f8971bb5382586d1bc3d887b4d490b0fbbb8f0498e9b503dab4a669469bc1349824e023a3
-   data.tar.gz: 0654e009acdbea3097f00f13f7c2b4389679e52b4b15db629ca2703c4267f2e27dcbfe866ee644c9cb5c23d56a5f5960d37a18334c7d0a21ca55815941aec281
+   metadata.gz: ee534a915fa831884e42519f178ae709a08c201d79976d97b52319dd5322fc979c3dfe711a3e2210b0d53c0132490a089ccd2ca7c73730919f586582be0dc884
+   data.tar.gz: c4df146833ea986a86d9b9c060f594535b7bdf558dd9d28396dd100869c4a899dc4ddad1bdd2eba1ff8a09610ee4cada09a271adda9f4d6e58d21c5181c925a6
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
  PATH
    remote: .
    specs:
-     hugging-face (0.1.0)
+     hugging-face (0.2.0)
        faraday (~> 2.7)

  GEM
data/README.md CHANGED
@@ -28,7 +28,7 @@ Instantiate a HuggigFace Inference API client:
  client = HuggingFace::InferenceApi.new(api_key: ENV['HUGGING_FACE_API_KEY'])
  ```

- Questiion answering:
+ Question answering:

  ```ruby
  client.question_answering(
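The README hunk above is cut off mid-call by the diff context. For reference, `question_answering` takes `question:` and `context:` keywords (see the inference_api.rb changes further down), so the full snippet looks roughly like this sketch, with illustrative strings:

```ruby
client.question_answering(
  question: "What is my name?",                        # illustrative value
  context: "My name is Clara and I live in Berkeley."  # illustrative value
)
```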
data/lib/hugging_face/base_api.rb CHANGED
@@ -15,7 +15,7 @@ module HuggingFace

    private

-   def connection(url)
+   def build_connection(url)
      Faraday.new(url, headers: @headers)
    end

data/lib/hugging_face/inference_api.rb CHANGED
@@ -1,11 +1,15 @@
  module HuggingFace
    class InferenceApi < BaseApi
      HOST = "https://api-inference.huggingface.co"
-     MAX_RETRY = 20

+     # Retry connecting to the model for 1 minute
+     MAX_RETRY = 60
+
+     # Deafult models that can be overriden by 'model' param
      QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
      SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
      GENERATION_MODEL = "distilgpt2"
+     EMBEDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

      def call(input:, model:)
        request(connection: connection(model), input: input)
@@ -14,21 +18,29 @@ module HuggingFace
      def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
        input = { question: question, context: context }

-       request(connection: connection(model), input: input)
+       request connection: connection(model), input: input
      end

      def summarization(input:, model: SUMMARIZATION_MODEL)
-       request(connection: connection(model), input: { inputs: input })
+       request connection: connection(model), input: { inputs: input }
      end

      def text_generation(input:, model: GENERATION_MODEL)
-       request(connection: connection(model), input: { inputs: input })
+       request connection: connection(model), input: { inputs: input }
+     end
+
+     def embedding(input:)
+       request connection: connection(EMBEDING_MODEL), input: { inputs: input }
      end

      private

      def connection(model)
-       super "#{HOST}/models/#{model}"
+       if model == EMBEDING_MODEL
+         build_connection "#{HOST}/pipeline/feature-extraction/#{model}"
+       else
+         build_connection "#{HOST}/models/#{model}"
+       end
      end

      def request(connection:, input:)
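For orientation, the new `embedding` method added above can be exercised as in the sketch below. It is not taken from the gem's documentation: it assumes the library is required as `hugging_face`, reuses the `HUGGING_FACE_API_KEY` environment variable from the README, and makes a live call to the hosted Inference API.

```ruby
require "hugging_face" # assumed entry point for the hugging-face gem

# Client setup as shown in the README above.
client = HuggingFace::InferenceApi.new(api_key: ENV["HUGGING_FACE_API_KEY"])

# embedding routes the request to the feature-extraction pipeline for
# sentence-transformers/all-MiniLM-L6-v2 and returns the model's output,
# typically an array of floats for a single input string.
vector = client.embedding(input: "Ruby client for the Hugging Face Inference API")
puts vector.inspect
```

The response shape follows whatever the feature-extraction pipeline returns for the given `input`, so callers may want to normalize it before storing or comparing vectors.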
data/lib/hugging_face/version.rb CHANGED
@@ -1,5 +1,5 @@
  # frozen_string_literal: true

  module HuggingFace
-   VERSION = "0.2.0"
+   VERSION = "0.3.0"
  end
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: hugging-face
  version: !ruby/object:Gem::Version
-   version: 0.2.0
+   version: 0.3.0
  platform: ruby
  authors:
  - Alex Chaplinsky