hugging-face 0.1.0 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 81267c0c030b98028e201d0f52471d12a2aea6f6fe519354534522528e331d02
4
- data.tar.gz: 235e128fdd8c7f534aace15ff606f89e432cb7fd120a44f94d3fc55cd4fa491a
3
+ metadata.gz: 93e213102e56d8e86de856912a1963f36e4b0a0d60f21c8a1f8d99b23a0f47d7
4
+ data.tar.gz: 877bd5a93e54ccbcb82943c19e4ac889657686aa656b0bd3bb911732b1d5d6ff
5
5
  SHA512:
6
- metadata.gz: 7d620ad3f025f8f0ccae02a266b8ee74c41c019d50e76fbf3e24688a71f213cacd6c342cea23f9854522480c2f5945b4973a5ea58fba5f860bb16c8f2244971e
7
- data.tar.gz: 1aa4a1fce880f38ae5367f168cb53ce3a93cf58450ab5d8d0a7ef6f5abe316efb713a1dcc97e9e592a91ba03484c915dc73dfe7401189fd8b8ba1b8cebfafc75
6
+ metadata.gz: ee534a915fa831884e42519f178ae709a08c201d79976d97b52319dd5322fc979c3dfe711a3e2210b0d53c0132490a089ccd2ca7c73730919f586582be0dc884
7
+ data.tar.gz: c4df146833ea986a86d9b9c060f594535b7bdf558dd9d28396dd100869c4a899dc4ddad1bdd2eba1ff8a09610ee4cada09a271adda9f4d6e58d21c5181c925a6
data/Gemfile CHANGED
@@ -8,3 +8,4 @@ gemspec
8
8
  gem "rake", "~> 13.0"
9
9
 
10
10
  gem "rspec", "~> 3.0"
11
+ gem "webmock", "~> 3.0"
data/Gemfile.lock CHANGED
@@ -1,18 +1,25 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- hugging-face (0.1.0)
4
+ hugging-face (0.2.0)
5
5
  faraday (~> 2.7)
6
6
 
7
7
  GEM
8
8
  remote: https://rubygems.org/
9
9
  specs:
10
+ addressable (2.8.4)
11
+ public_suffix (>= 2.0.2, < 6.0)
12
+ crack (0.4.5)
13
+ rexml
10
14
  diff-lcs (1.5.0)
11
15
  faraday (2.7.4)
12
16
  faraday-net_http (>= 2.0, < 3.1)
13
17
  ruby2_keywords (>= 0.0.4)
14
18
  faraday-net_http (3.0.2)
19
+ hashdiff (1.0.1)
20
+ public_suffix (5.0.1)
15
21
  rake (13.0.6)
22
+ rexml (3.2.5)
16
23
  rspec (3.12.0)
17
24
  rspec-core (~> 3.12.0)
18
25
  rspec-expectations (~> 3.12.0)
@@ -27,14 +34,20 @@ GEM
27
34
  rspec-support (~> 3.12.0)
28
35
  rspec-support (3.12.0)
29
36
  ruby2_keywords (0.0.5)
37
+ webmock (3.18.1)
38
+ addressable (>= 2.8.0)
39
+ crack (>= 0.3.2)
40
+ hashdiff (>= 0.4.0, < 2.0.0)
30
41
 
31
42
  PLATFORMS
32
43
  arm64-darwin-21
44
+ x86_64-linux
33
45
 
34
46
  DEPENDENCIES
35
47
  hugging-face!
36
48
  rake (~> 13.0)
37
49
  rspec (~> 3.0)
50
+ webmock (~> 3.0)
38
51
 
39
52
  BUNDLED WITH
40
53
  2.4.0
data/README.md CHANGED
@@ -18,7 +18,36 @@ $ gem install hugging-face
18
18
 
19
19
  ## Usage
20
20
 
21
- TODO: Write usage instructions here
21
+ ```ruby
22
+ require "hugging_face"
23
+ ```
24
+
25
+ Instantiate a HuggingFace Inference API client:
26
+
27
+ ```ruby
28
+ client = HuggingFace::InferenceApi.new(api_token: ENV['HUGGING_FACE_API_KEY'])
29
+ ```
30
+
31
+ Question answering:
32
+
33
+ ```ruby
34
+ client.question_answering(
35
+ question: 'What is my name?',
36
+ context: 'I am the only child. My father named his son John.'
37
+ )
38
+ ```
39
+
40
+ Text generation:
41
+
42
+ ```ruby
43
+ client.text_generation(input: 'Can you please let us know more details about your ')
44
+ ```
45
+
46
+ Summarization:
47
+
48
+ ```ruby
49
+ client.summarization(input: 'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.')
50
+ ```
22
51
 
23
52
  ## Development
24
53
 
@@ -33,3 +62,4 @@ Bug reports and pull requests are welcome on GitHub at https://github.com/alchap
33
62
  ## Code of Conduct
34
63
 
35
64
  Everyone interacting in the HuggingFace project's codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/alchaplinsky/hugging-face/blob/main/CODE_OF_CONDUCT.md).
65
+
@@ -0,0 +1,46 @@
1
+ require 'logger'
2
+ require 'faraday'
3
+
4
require 'json' # #request uses Object#to_json and #parse_response uses JSON.parse

module HuggingFace
  # Shared HTTP plumbing for Hugging Face API clients.
  #
  # Holds the authorization headers, builds Faraday connections, POSTs JSON
  # payloads and maps unsuccessful responses to HuggingFace::ServiceUnavailable
  # (HTTP 503) or HuggingFace::Error (any other failure).
  class BaseApi
    HTTP_SERVICE_UNAVAILABLE = 503
    JSON_CONTENT_TYPE = 'application/json'.freeze

    # @param api_token [String] Hugging Face access token, sent as a Bearer token
    # @raise [ArgumentError] if +api_token+ is nil or empty
    def initialize(api_token:)
      # Fail fast with a clear message instead of a TypeError from String#+.
      if api_token.nil? || api_token.to_s.empty?
        raise ArgumentError, 'api_token must be a non-empty String'
      end

      @headers = {
        'Authorization' => "Bearer #{api_token}",
        'Content-Type' => JSON_CONTENT_TYPE
      }
    end

    private

    # Returns a Faraday connection to +url+ that carries the auth/JSON headers.
    def build_connection(url)
      Faraday.new(url, headers: @headers)
    end

    # POSTs +input+ (serialised with #to_json) over +connection+.
    #
    # @return [Object] parsed JSON for JSON responses, otherwise the raw body
    # @raise [HuggingFace::ServiceUnavailable] when the status is 503
    # @raise [HuggingFace::Error] for any other unsuccessful status
    def request(connection:, input:)
      response = connection.post { |req| req.body = input.to_json }
      return parse_response(response) if response.success?

      raise ServiceUnavailable, response.body if response.status == HTTP_SERVICE_UNAVAILABLE

      raise Error, response.body
    end

    # Decodes JSON bodies; bodies of any other media type are returned verbatim.
    def parse_response(response)
      return response.body unless json_response?(response)

      JSON.parse(response.body)
    end

    # True when the Content-Type media type is application/json. Parameters
    # such as "; charset=utf-8" and letter case are ignored.
    def json_response?(response)
      media_type = response.headers['Content-Type'].to_s.split(';').first.to_s.strip
      media_type.casecmp?(JSON_CONTENT_TYPE)
    end

    # Lazily created logger available to subclasses (e.g. for retry diagnostics).
    def logger
      @logger ||= Logger.new(STDOUT)
    end
  end
end
46
+
@@ -0,0 +1,64 @@
1
module HuggingFace
  # Client for the Hugging Face hosted Inference API.
  #
  # Each task method POSTs a JSON payload to a model endpoint through the
  # inherited BaseApi#request and returns its result. ServiceUnavailable
  # (HTTP 503) failures are retried once per second, up to MAX_RETRY times.
  class InferenceApi < BaseApi
    HOST = 'https://api-inference.huggingface.co'.freeze

    # Retry connecting to the model for roughly one minute (60 retries, 1 s apart).
    MAX_RETRY = 60

    # Default models; each task method accepts a `model:` keyword to override them.
    QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'.freeze
    SUMMARIZATION_MODEL = 'sshleifer/distilbart-xsum-12-6'.freeze
    GENERATION_MODEL = 'distilgpt2'.freeze
    # Misspelled name kept because it is public; EMBEDDING_MODEL is the preferred alias.
    EMBEDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'.freeze
    EMBEDDING_MODEL = EMBEDING_MODEL

    # Pipeline endpoint used to compute embeddings.
    FEATURE_EXTRACTION_PIPELINE = 'feature-extraction'.freeze

    # POSTs +input+ unchanged to +model+.
    #
    # @param input [Object] request payload, serialised with #to_json
    # @param model [String] Hugging Face model id
    # @return [Object] parsed JSON, or the raw body for non-JSON responses
    def call(input:, model:)
      request connection: connection(model), input: input
    end

    # Answers +question+ using the text in +context+.
    #
    # @param model [String] model id (defaults to QUESTION_ANSWERING_MODEL)
    def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
      request connection: connection(model), input: { question: question, context: context }
    end

    # Summarizes +input+ text.
    #
    # @param model [String] model id (defaults to SUMMARIZATION_MODEL)
    def summarization(input:, model: SUMMARIZATION_MODEL)
      request connection: connection(model), input: { inputs: input }
    end

    # Continues the +input+ prompt.
    #
    # @param model [String] model id (defaults to GENERATION_MODEL)
    def text_generation(input:, model: GENERATION_MODEL)
      request connection: connection(model), input: { inputs: input }
    end

    # Returns embeddings for +input+ from the feature-extraction pipeline.
    #
    # @param model [String] model id (defaults to EMBEDING_MODEL); any model is
    #   routed through the feature-extraction pipeline endpoint
    def embedding(input:, model: EMBEDING_MODEL)
      request connection: connection(model, pipeline: FEATURE_EXTRACTION_PIPELINE),
              input: { inputs: input }
    end

    private

    # Builds a connection for +model+. A +pipeline+ selects the
    # "/pipeline/<pipeline>/<model>" endpoint; otherwise "/models/<model>" is
    # used. The default embedding model always goes through feature-extraction,
    # which keeps #call(model: EMBEDING_MODEL) behaving as before.
    def connection(model, pipeline: nil)
      pipeline ||= FEATURE_EXTRACTION_PIPELINE if model == EMBEDING_MODEL
      path = pipeline ? "pipeline/#{pipeline}/#{model}" : "models/#{model}"
      build_connection "#{HOST}/#{path}"
    end

    # Delegates to BaseApi#request and retries ServiceUnavailable up to
    # MAX_RETRY times, sleeping one second between attempts. Any other error,
    # or the final ServiceUnavailable, propagates to the caller.
    def request(connection:, input:)
      retries = 0

      begin
        super(connection: connection, input: input)
      rescue ServiceUnavailable
        raise if retries >= MAX_RETRY

        logger.debug('Service unavailable, retrying...')
        retries += 1
        sleep 1
        retry
      end
    end
  end
end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module HuggingFace
4
- VERSION = "0.1.0"
4
+ VERSION = "0.3.0"
5
5
  end
data/lib/hugging_face.rb CHANGED
@@ -1,9 +1,11 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  require_relative "hugging_face/version"
4
- require_relative "hugging_face/interface_api"
4
+ require_relative "hugging_face/base_api"
5
+ require_relative "hugging_face/inference_api"
5
6
 
6
7
  module HuggingFace
7
8
  class Error < StandardError; end
9
+ class ServiceUnavailable < Error; end
8
10
  # Your code goes here...
9
11
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: hugging-face
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Alex Chaplinsky
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-05-15 00:00:00.000000000 Z
11
+ date: 2023-05-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: faraday
@@ -39,7 +39,8 @@ files:
39
39
  - README.md
40
40
  - Rakefile
41
41
  - lib/hugging_face.rb
42
- - lib/hugging_face/interface_api.rb
42
+ - lib/hugging_face/base_api.rb
43
+ - lib/hugging_face/inference_api.rb
43
44
  - lib/hugging_face/version.rb
44
45
  - sig/hugging_face.rbs
45
46
  homepage: https://rubygems.org/gems/hugging-face
@@ -1,63 +0,0 @@
# NOTE(review): removed side of the diff — lib/hugging_face/interface_api.rb is
# deleted in this release and replaced by base_api.rb + inference_api.rb.
# Known defects of this removed version:
#   * `redo` in #request restarts the while-body WITHOUT re-evaluating
#     `retries < MAX_RETRY`, so a model that keeps answering 503 is retried forever.
#   * JSON.parse / #to_json are used but 'json' is never required here.
1
- require 'faraday'
2
-
3
- module HuggingFace
4
- class InterfaceApi
5
- HOST = "https://api-inference.huggingface.co"
6
- MAX_RETRY = 2
7
- HTTP_SEVICE_UNAVAILABLE = 503
8
-
9
- QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
10
- SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
11
- GENERATION_MODEL = "distilgpt2"
12
-
13
- def initialize(api_token:)
14
- @headers = {
15
- 'Authorization' => 'Bearer ' + api_token,
16
- 'Content-Type' => 'application/json'
17
- }
18
- end
19
-
20
- def call(input:, model:)
21
- request(connection: connection(model), input: input)
22
- end
23
-
24
- def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
25
- input = { question: question, context: context }
26
-
27
- request(connection: connection(model), input: input)
28
- end
29
-
30
- def summarization(input:, model: SUMMARIZATION_MODEL)
31
- request(connection: connection(model), input: { inputs: input })
32
- end
33
-
34
- def text_generation(input:, model: GENERATION_MODEL)
35
- request(connection: connection(model), input: { inputs: input })
36
- end
37
-
38
- private
39
-
40
- def request(connection:, input:)
41
- retries = 0
42
- while retries < MAX_RETRY
43
- response = connection.post { |req| req.body = input.to_json }
44
-
45
- break if response.success?
46
-
47
- if response.status == HTTP_SEVICE_UNAVAILABLE
48
- retries += 1
49
- sleep 1
50
# NOTE(review): `redo` skips the while-condition check, so MAX_RETRY is never enforced.
- redo
51
- end
52
-
53
- raise "Error: #{response.body}"
54
- end
55
-
56
- return JSON.parse(response.body)
57
- end
58
-
59
- def connection(model)
60
- Faraday.new(url: "#{HOST}/models/#{model}" , headers: @headers)
61
- end
62
- end
63
- end