hugging-face 0.1.0 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/Gemfile +1 -0
- data/Gemfile.lock +14 -1
- data/README.md +31 -1
- data/lib/hugging_face/base_api.rb +46 -0
- data/lib/hugging_face/inference_api.rb +64 -0
- data/lib/hugging_face/version.rb +1 -1
- data/lib/hugging_face.rb +3 -1
- metadata +4 -3
- data/lib/hugging_face/interface_api.rb +0 -63
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 93e213102e56d8e86de856912a1963f36e4b0a0d60f21c8a1f8d99b23a0f47d7
|
4
|
+
data.tar.gz: 877bd5a93e54ccbcb82943c19e4ac889657686aa656b0bd3bb911732b1d5d6ff
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ee534a915fa831884e42519f178ae709a08c201d79976d97b52319dd5322fc979c3dfe711a3e2210b0d53c0132490a089ccd2ca7c73730919f586582be0dc884
|
7
|
+
data.tar.gz: c4df146833ea986a86d9b9c060f594535b7bdf558dd9d28396dd100869c4a899dc4ddad1bdd2eba1ff8a09610ee4cada09a271adda9f4d6e58d21c5181c925a6
|
data/Gemfile
CHANGED
data/Gemfile.lock
CHANGED
@@ -1,18 +1,25 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
hugging-face (0.
|
4
|
+
hugging-face (0.2.0)
|
5
5
|
faraday (~> 2.7)
|
6
6
|
|
7
7
|
GEM
|
8
8
|
remote: https://rubygems.org/
|
9
9
|
specs:
|
10
|
+
addressable (2.8.4)
|
11
|
+
public_suffix (>= 2.0.2, < 6.0)
|
12
|
+
crack (0.4.5)
|
13
|
+
rexml
|
10
14
|
diff-lcs (1.5.0)
|
11
15
|
faraday (2.7.4)
|
12
16
|
faraday-net_http (>= 2.0, < 3.1)
|
13
17
|
ruby2_keywords (>= 0.0.4)
|
14
18
|
faraday-net_http (3.0.2)
|
19
|
+
hashdiff (1.0.1)
|
20
|
+
public_suffix (5.0.1)
|
15
21
|
rake (13.0.6)
|
22
|
+
rexml (3.2.5)
|
16
23
|
rspec (3.12.0)
|
17
24
|
rspec-core (~> 3.12.0)
|
18
25
|
rspec-expectations (~> 3.12.0)
|
@@ -27,14 +34,20 @@ GEM
|
|
27
34
|
rspec-support (~> 3.12.0)
|
28
35
|
rspec-support (3.12.0)
|
29
36
|
ruby2_keywords (0.0.5)
|
37
|
+
webmock (3.18.1)
|
38
|
+
addressable (>= 2.8.0)
|
39
|
+
crack (>= 0.3.2)
|
40
|
+
hashdiff (>= 0.4.0, < 2.0.0)
|
30
41
|
|
31
42
|
PLATFORMS
|
32
43
|
arm64-darwin-21
|
44
|
+
x86_64-linux
|
33
45
|
|
34
46
|
DEPENDENCIES
|
35
47
|
hugging-face!
|
36
48
|
rake (~> 13.0)
|
37
49
|
rspec (~> 3.0)
|
50
|
+
webmock (~> 3.0)
|
38
51
|
|
39
52
|
BUNDLED WITH
|
40
53
|
2.4.0
|
data/README.md
CHANGED
@@ -18,7 +18,36 @@ $ gem install hugging-face
|
|
18
18
|
|
19
19
|
## Usage
|
20
20
|
|
21
|
-
|
21
|
+
```ruby
|
22
|
+
require "hugging_face"
|
23
|
+
```
|
24
|
+
|
25
|
+
Instantiate a HuggingFace Inference API client:
|
26
|
+
|
27
|
+
```ruby
|
28
|
+
client = HuggingFace::InferenceApi.new(api_token: ENV['HUGGING_FACE_API_KEY'])
|
29
|
+
```
|
30
|
+
|
31
|
+
Question answering:
|
32
|
+
|
33
|
+
```ruby
|
34
|
+
client.question_answering(
|
35
|
+
question: 'What is my name?',
|
36
|
+
context: 'I am the only child. My father named his son John.'
|
37
|
+
)
|
38
|
+
```
|
39
|
+
|
40
|
+
Text generation:
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
client.text_generation(input: 'Can you please let us know more details about your ')
|
44
|
+
```
|
45
|
+
|
46
|
+
Summarization:
|
47
|
+
|
48
|
+
```ruby
|
49
|
+
client.summarization(input: 'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.')
|
50
|
+
```
|
22
51
|
|
23
52
|
## Development
|
24
53
|
|
@@ -33,3 +62,4 @@ Bug reports and pull requests are welcome on GitHub at https://github.com/alchap
|
|
33
62
|
## Code of Conduct
|
34
63
|
|
35
64
|
Everyone interacting in the HuggingFace project's codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/alchaplinsky/hugging-face/blob/main/CODE_OF_CONDUCT.md).
|
65
|
+
|
@@ -0,0 +1,46 @@
|
|
1
|
+
require 'logger'
|
2
|
+
require 'faraday'
|
3
|
+
|
4
|
+
module HuggingFace
|
5
|
+
class BaseApi
|
6
|
+
HTTP_SERVICE_UNAVAILABLE = 503
|
7
|
+
JSON_CONTENT_TYPE = 'application/json'
|
8
|
+
|
9
|
+
def initialize(api_token:)
|
10
|
+
@headers = {
|
11
|
+
'Authorization' => 'Bearer ' + api_token,
|
12
|
+
'Content-Type' => JSON_CONTENT_TYPE
|
13
|
+
}
|
14
|
+
end
|
15
|
+
|
16
|
+
private
|
17
|
+
|
18
|
+
def build_connection(url)
|
19
|
+
Faraday.new(url, headers: @headers)
|
20
|
+
end
|
21
|
+
|
22
|
+
def request(connection:, input:)
|
23
|
+
response = connection.post { |req| req.body = input.to_json }
|
24
|
+
|
25
|
+
if response.success?
|
26
|
+
return parse_response response
|
27
|
+
else
|
28
|
+
raise ServiceUnavailable.new response.body if response.status == HTTP_SERVICE_UNAVAILABLE
|
29
|
+
raise Error.new response.body
|
30
|
+
end
|
31
|
+
end
|
32
|
+
|
33
|
+
def parse_response(response)
|
34
|
+
if response.headers['Content-Type'] == JSON_CONTENT_TYPE
|
35
|
+
JSON.parse(response.body)
|
36
|
+
else
|
37
|
+
response.body
|
38
|
+
end
|
39
|
+
end
|
40
|
+
|
41
|
+
def logger
|
42
|
+
@logger ||= Logger.new(STDOUT)
|
43
|
+
end
|
44
|
+
end
|
45
|
+
end
|
46
|
+
|
@@ -0,0 +1,64 @@
|
|
1
|
+
module HuggingFace
|
2
|
+
class InferenceApi < BaseApi
|
3
|
+
HOST = "https://api-inference.huggingface.co"
|
4
|
+
|
5
|
+
# Retry connecting to the model for 1 minute
|
6
|
+
MAX_RETRY = 60
|
7
|
+
|
8
|
+
# Default models that can be overridden by the 'model' param
|
9
|
+
QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
|
10
|
+
SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
|
11
|
+
GENERATION_MODEL = "distilgpt2"
|
12
|
+
EMBEDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
13
|
+
|
14
|
+
def call(input:, model:)
|
15
|
+
request(connection: connection(model), input: input)
|
16
|
+
end
|
17
|
+
|
18
|
+
def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
|
19
|
+
input = { question: question, context: context }
|
20
|
+
|
21
|
+
request connection: connection(model), input: input
|
22
|
+
end
|
23
|
+
|
24
|
+
def summarization(input:, model: SUMMARIZATION_MODEL)
|
25
|
+
request connection: connection(model), input: { inputs: input }
|
26
|
+
end
|
27
|
+
|
28
|
+
def text_generation(input:, model: GENERATION_MODEL)
|
29
|
+
request connection: connection(model), input: { inputs: input }
|
30
|
+
end
|
31
|
+
|
32
|
+
def embedding(input:)
|
33
|
+
request connection: connection(EMBEDING_MODEL), input: { inputs: input }
|
34
|
+
end
|
35
|
+
|
36
|
+
private
|
37
|
+
|
38
|
+
def connection(model)
|
39
|
+
if model == EMBEDING_MODEL
|
40
|
+
build_connection "#{HOST}/pipeline/feature-extraction/#{model}"
|
41
|
+
else
|
42
|
+
build_connection "#{HOST}/models/#{model}"
|
43
|
+
end
|
44
|
+
end
|
45
|
+
|
46
|
+
def request(connection:, input:)
|
47
|
+
retries = 0
|
48
|
+
|
49
|
+
begin
|
50
|
+
return super(connection: connection, input: input)
|
51
|
+
rescue ServiceUnavailable => exception
|
52
|
+
|
53
|
+
if retries < MAX_RETRY
|
54
|
+
logger.debug('Service unavailable, retrying...')
|
55
|
+
retries += 1
|
56
|
+
sleep 1
|
57
|
+
retry
|
58
|
+
else
|
59
|
+
raise exception
|
60
|
+
end
|
61
|
+
end
|
62
|
+
end
|
63
|
+
end
|
64
|
+
end
|
data/lib/hugging_face/version.rb
CHANGED
data/lib/hugging_face.rb
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
3
|
require_relative "hugging_face/version"
|
4
|
-
require_relative "hugging_face/interface_api"
|
4
|
+
require_relative "hugging_face/base_api"
|
5
|
+
require_relative "hugging_face/inference_api"
|
5
6
|
|
6
7
|
module HuggingFace
|
7
8
|
class Error < StandardError; end
|
9
|
+
class ServiceUnavailable < Error; end
|
8
10
|
# Your code goes here...
|
9
11
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: hugging-face
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.0
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Alex Chaplinsky
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-05-
|
11
|
+
date: 2023-05-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: faraday
|
@@ -39,7 +39,8 @@ files:
|
|
39
39
|
- README.md
|
40
40
|
- Rakefile
|
41
41
|
- lib/hugging_face.rb
|
42
|
-
- lib/hugging_face/interface_api.rb
|
42
|
+
- lib/hugging_face/base_api.rb
|
43
|
+
- lib/hugging_face/inference_api.rb
|
43
44
|
- lib/hugging_face/version.rb
|
44
45
|
- sig/hugging_face.rbs
|
45
46
|
homepage: https://rubygems.org/gems/hugging-face
|
@@ -1,63 +0,0 @@
|
|
1
|
-
require 'faraday'
|
2
|
-
|
3
|
-
module HuggingFace
|
4
|
-
class InterfaceApi
|
5
|
-
HOST = "https://api-inference.huggingface.co"
|
6
|
-
MAX_RETRY = 2
|
7
|
-
HTTP_SEVICE_UNAVAILABLE = 503
|
8
|
-
|
9
|
-
QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
|
10
|
-
SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
|
11
|
-
GENERATION_MODEL = "distilgpt2"
|
12
|
-
|
13
|
-
def initialize(api_token:)
|
14
|
-
@headers = {
|
15
|
-
'Authorization' => 'Bearer ' + api_token,
|
16
|
-
'Content-Type' => 'application/json'
|
17
|
-
}
|
18
|
-
end
|
19
|
-
|
20
|
-
def call(input:, model:)
|
21
|
-
request(connection: connection(model), input: input)
|
22
|
-
end
|
23
|
-
|
24
|
-
def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
|
25
|
-
input = { question: question, context: context }
|
26
|
-
|
27
|
-
request(connection: connection(model), input: input)
|
28
|
-
end
|
29
|
-
|
30
|
-
def summarization(input:, model: SUMMARIZATION_MODEL)
|
31
|
-
request(connection: connection(model), input: { inputs: input })
|
32
|
-
end
|
33
|
-
|
34
|
-
def text_generation(input:, model: GENERATION_MODEL)
|
35
|
-
request(connection: connection(model), input: { inputs: input })
|
36
|
-
end
|
37
|
-
|
38
|
-
private
|
39
|
-
|
40
|
-
def request(connection:, input:)
|
41
|
-
retries = 0
|
42
|
-
while retries < MAX_RETRY
|
43
|
-
response = connection.post { |req| req.body = input.to_json }
|
44
|
-
|
45
|
-
break if response.success?
|
46
|
-
|
47
|
-
if response.status == HTTP_SEVICE_UNAVAILABLE
|
48
|
-
retries += 1
|
49
|
-
sleep 1
|
50
|
-
redo
|
51
|
-
end
|
52
|
-
|
53
|
-
raise "Error: #{response.body}"
|
54
|
-
end
|
55
|
-
|
56
|
-
return JSON.parse(response.body)
|
57
|
-
end
|
58
|
-
|
59
|
-
def connection(model)
|
60
|
-
Faraday.new(url: "#{HOST}/models/#{model}" , headers: @headers)
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|