huggingface-hub 0.24.6__py3-none-any.whl → 0.25.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of huggingface-hub has been flagged as potentially problematic.
- huggingface_hub/__init__.py +21 -1
- huggingface_hub/_commit_api.py +4 -4
- huggingface_hub/_inference_endpoints.py +13 -1
- huggingface_hub/_local_folder.py +191 -4
- huggingface_hub/_login.py +6 -6
- huggingface_hub/_snapshot_download.py +8 -17
- huggingface_hub/_space_api.py +5 -0
- huggingface_hub/_tensorboard_logger.py +29 -13
- huggingface_hub/_upload_large_folder.py +573 -0
- huggingface_hub/_webhooks_server.py +1 -1
- huggingface_hub/commands/_cli_utils.py +5 -0
- huggingface_hub/commands/download.py +8 -0
- huggingface_hub/commands/huggingface_cli.py +6 -1
- huggingface_hub/commands/lfs.py +2 -1
- huggingface_hub/commands/repo_files.py +2 -2
- huggingface_hub/commands/scan_cache.py +99 -57
- huggingface_hub/commands/tag.py +1 -1
- huggingface_hub/commands/upload.py +2 -1
- huggingface_hub/commands/upload_large_folder.py +129 -0
- huggingface_hub/commands/version.py +37 -0
- huggingface_hub/community.py +2 -2
- huggingface_hub/errors.py +218 -1
- huggingface_hub/fastai_utils.py +2 -3
- huggingface_hub/file_download.py +63 -63
- huggingface_hub/hf_api.py +758 -314
- huggingface_hub/hf_file_system.py +15 -23
- huggingface_hub/hub_mixin.py +27 -25
- huggingface_hub/inference/_client.py +78 -127
- huggingface_hub/inference/_generated/_async_client.py +169 -144
- huggingface_hub/inference/_generated/types/base.py +0 -9
- huggingface_hub/inference/_templating.py +2 -3
- huggingface_hub/inference_api.py +2 -2
- huggingface_hub/keras_mixin.py +2 -2
- huggingface_hub/lfs.py +7 -98
- huggingface_hub/repocard.py +6 -5
- huggingface_hub/repository.py +5 -5
- huggingface_hub/serialization/_torch.py +64 -11
- huggingface_hub/utils/__init__.py +13 -14
- huggingface_hub/utils/_cache_manager.py +97 -14
- huggingface_hub/utils/_fixes.py +18 -2
- huggingface_hub/utils/_http.py +228 -2
- huggingface_hub/utils/_lfs.py +110 -0
- huggingface_hub/utils/_runtime.py +7 -1
- huggingface_hub/utils/_token.py +3 -2
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/METADATA +2 -2
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/RECORD +50 -48
- huggingface_hub/inference/_types.py +0 -52
- huggingface_hub/utils/_errors.py +0 -397
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py
CHANGED

@@ -32,6 +32,7 @@ from typing import (
     List,
     Literal,
     Optional,
+    Set,
     Union,
     overload,
 )
@@ -86,9 +87,6 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._types import (
-    ConversationalOutput,  # soon to be removed
-)
 from huggingface_hub.utils import (
     build_hf_headers,
 )
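Together with the removal of huggingface_hub/inference/_types.py listed above, this drops the deprecated `conversational` task (its method is deleted further down in this file). A minimal migration sketch, with an illustrative chat message:

    import asyncio
    from huggingface_hub import AsyncInferenceClient

    async def main() -> None:
        client = AsyncInferenceClient()
        # Before 0.25: output = await client.conversational("Hi, who are you?")
        # From 0.25 on, express the exchange as chat messages:
        output = await client.chat_completion(
            messages=[{"role": "user", "content": "Hi, who are you?"}],
            max_tokens=100,
        )
        print(output.choices[0].message.content)

    asyncio.run(main())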
@@ -99,6 +97,7 @@ from .._common import _async_yield_from, _import_aiohttp

 if TYPE_CHECKING:
     import numpy as np
+    from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image

 logger = logging.getLogger(__name__)
@@ -120,7 +119,9 @@ class AsyncInferenceClient:
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
             Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
-            arguments are mutually exclusive
+            arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
+            path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
+            documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         token (`str` or `bool`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
             Pass `token=False` if you don't want to send your token to the server.
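The documented `base_url` behavior means an OpenAI-compatible server (e.g. a TGI deployment) can be targeted directly and the client completes the route. A sketch, assuming a local endpoint on port 8080 (URL illustrative):

    from huggingface_hub import AsyncInferenceClient

    # `model` and `base_url` are mutually exclusive; with `base_url`, the
    # `/v1/chat/completions` suffix is appended automatically for chat completion.
    client = AsyncInferenceClient(base_url="http://localhost:8080")

    async def ask(prompt: str) -> str:
        output = await client.chat_completion(messages=[{"role": "user", "content": prompt}], max_tokens=64)
        return output.choices[0].message.content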
@@ -134,6 +135,10 @@ class AsyncInferenceClient:
             Values in this dictionary will override the default values.
         cookies (`Dict[str, str]`, `optional`):
             Additional cookies to send to the server.
+        trust_env ('bool', 'optional'):
+            Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
+        proxies (`Any`, `optional`):
+            Proxies to use for the request.
         base_url (`str`, `optional`):
             Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
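The two new constructor arguments map onto aiohttp: `trust_env=True` lets the session pick up HTTP_PROXY/HTTPS_PROXY/NO_PROXY from the environment, while `proxies` is forwarded as the `proxy` argument of each request. A sketch (model name and proxy URL illustrative):

    from huggingface_hub import AsyncInferenceClient

    client = AsyncInferenceClient(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        trust_env=True,                             # read proxy settings from the environment
        # proxies="http://proxy.example.com:3128",  # or pass an explicit proxy instead
    )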
@@ -151,6 +156,7 @@ class AsyncInferenceClient:
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
+        trust_env: bool = False,
         proxies: Optional[Any] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
@@ -160,7 +166,8 @@ class AsyncInferenceClient:
             raise ValueError(
                 "Received both `model` and `base_url` arguments. Please provide only one of them."
                 " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
-                "
+                " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url."
+                " When passing a URL as `model`, the client will not append any suffix path to it."
             )
         if token is not None and api_key is not None:
             raise ValueError(
@@ -176,11 +183,15 @@ class AsyncInferenceClient:
             self.headers.update(headers)
         self.cookies = cookies
         self.timeout = timeout
+        self.trust_env = trust_env
         self.proxies = proxies

         # OpenAI compatibility
         self.base_url = base_url

+        # Keep track of the sessions to close them properly
+        self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()
+
     def __repr__(self):
         return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

@@ -265,7 +276,7 @@ class AsyncInferenceClient:
             warnings.warn("Ignoring `json` as `data` is passed as binary.")

         # Set Accept header if relevant
-        headers =
+        headers = dict()
         if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
             headers["Accept"] = "image/png"

@@ -275,12 +286,10 @@ class AsyncInferenceClient:
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-
-                headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
-            )
+            session = self._get_client_session(headers=headers)

             try:
-                response = await
+                response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                 response_error_payload = None
                 if response.status != 200:
                     try:
@@ -289,18 +298,18 @@ class AsyncInferenceClient:
                         pass
                 response.raise_for_status()
                 if stream:
-                    return _async_yield_from(
+                    return _async_yield_from(session, response)
                 else:
                     content = await response.read()
-                    await
+                    await session.close()
                     return content
             except asyncio.TimeoutError as error:
-                await
+                await session.close()
                 # Convert any `TimeoutError` to a `InferenceTimeoutError`
                 raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
-                await
+                await session.close()
                 if response.status == 422 and task is not None:
                     error.message += f". Make sure '{task}' task is supported by the model."
                 if response.status == 503:
@@ -322,9 +331,35 @@ class AsyncInferenceClient:
                     continue
                 raise error
             except Exception:
-                await
+                await session.close()
                 raise

+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+    def __del__(self):
+        if len(self._sessions) > 0:
+            warnings.warn(
+                "Deleting 'AsyncInferenceClient' client but some sessions are still open. "
+                "This can happen if you've stopped streaming data from the server before the stream was complete. "
+                "To close the client properly, you must call `await client.close()` "
+                "or use an async context (e.g. `async with AsyncInferenceClient(): ...`."
+            )
+
+    async def close(self):
+        """Close all open sessions.
+
+        By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
+        are streaming data from the server and you stop before the stream is complete, you must call this method to
+        close the session properly.
+
+        Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
+        """
+        await asyncio.gather(*[session.close() for session in self._sessions.keys()])
+
     async def audio_classification(
         self,
         audio: ContentT,
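The new `__aenter__`/`__aexit__`/`close()` trio mostly matters for streaming, where a response can be abandoned before the stream is exhausted. A usage sketch (prompt and parameters illustrative):

    import asyncio
    from huggingface_hub import AsyncInferenceClient

    async def main() -> None:
        # The async context manager closes every tracked aiohttp session on exit,
        # even if a streamed response was abandoned midway.
        async with AsyncInferenceClient() as client:
            stream = await client.text_generation("Once upon a time", stream=True, max_new_tokens=20)
            async for token in stream:
                print(token, end="")

    asyncio.run(main())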
@@ -815,134 +850,66 @@ class AsyncInferenceClient:
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
         """
-
-        # `self.xxx` takes precedence over the method argument only in `chat_completion`
-        # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-        # server, we need to handle it differently
-        model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-        is_url = model.startswith(("http://", "https://"))
-
-        # First, resolve the model chat completions URL
-        if model == self.base_url:
-            # base_url passed => add server route
-            model_url = model.rstrip("/")
-            if not model_url.endswith("/v1"):
-                model_url += "/v1"
-            model_url += "/chat/completions"
-        elif is_url:
-            # model is a URL => use it directly
-            model_url = model
-        else:
-            # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
+        model_url = self._resolve_chat_completion_url(model)

         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's a ID on the Hub => use it. Otherwise, we use a random string.
-        model_id = model
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            stream=stream,
-        ),
+        model_id = model or self.model or "tgi"
+        if model_id.startswith(("http://", "https://")):
+            model_id = "tgi"  # dummy value
+
+        payload = dict(
+            model=model_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tool_prompt=tool_prompt,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
             stream=stream,
         )
+        payload = {key: value for key, value in payload.items() if value is not None}
+        data = await self.post(model=model_url, json=payload, stream=stream)

         if stream:
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]

         return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

-
-
-
-
-        past_user_inputs: Optional[List[str]] = None,
-        *,
-        parameters: Optional[Dict[str, Any]] = None,
-        model: Optional[str] = None,
-    ) -> ConversationalOutput:
-        """
-        Generate conversational responses based on the given input text (i.e. chat with the API).
-
-        <Tip warning={true}>
-
-        [`InferenceClient.conversational`] API is deprecated and will be removed in a future release. Please use
-        [`InferenceClient.chat_completion`] instead.
+    def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+        # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+        # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
+        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")

-
+        # Resolve URL if it's a model ID
+        model_url = (
+            model_id_or_url
+            if model_id_or_url.startswith(("http://", "https://"))
+            else self._resolve_url(model_id_or_url, task="text-generation")
+        )

-
-
-                The last input from the user in the conversation.
-            generated_responses (`List[str]`, *optional*):
-                A list of strings corresponding to the earlier replies from the model. Defaults to None.
-            past_user_inputs (`List[str]`, *optional*):
-                A list of strings corresponding to the earlier replies from the user. Should be the same length as
-                `generated_responses`. Defaults to None.
-            parameters (`Dict[str, Any]`, *optional*):
-                Additional parameters for the conversational task. Defaults to None. For more details about the available
-                parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
-            model (`str`, *optional*):
-                The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
-                Defaults to None.
+        # Strip trailing /
+        model_url = model_url.rstrip("/")

-
-
+        # Append /chat/completions if not already present
+        if model_url.endswith("/v1"):
+            model_url += "/chat/completions"

-
-
-
-            `aiohttp.ClientResponseError`:
-                If the request fails with an HTTP error status code other than HTTP 503.
+        # Append /v1/chat/completions if not already present
+        if not model_url.endswith("/chat/completions"):
+            model_url += "/v1/chat/completions"

-
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> output = await client.conversational("Hi, who are you?")
-        >>> output
-        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
-        >>> await client.conversational(
-        ...     "Wow, that's scary!",
-        ...     generated_responses=output["conversation"]["generated_responses"],
-        ...     past_user_inputs=output["conversation"]["past_user_inputs"],
-        ... )
-        ```
-        """
-        warnings.warn(
-            "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
-            "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
-            FutureWarning,
-        )
-        payload: Dict[str, Any] = {"inputs": {"text": text}}
-        if generated_responses is not None:
-            payload["inputs"]["generated_responses"] = generated_responses
-        if past_user_inputs is not None:
-            payload["inputs"]["past_user_inputs"] = past_user_inputs
-        if parameters is not None:
-            payload["parameters"] = parameters
-        response = await self.post(json=payload, model=model, task="conversational")
-        return _bytes_to_dict(response)  # type: ignore
+        return model_url

     async def document_question_answering(
         self,
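The refactored `_resolve_chat_completion_url` helper makes the routing rules explicit. Expected resolutions, with illustrative endpoints and model ID:

    from huggingface_hub import AsyncInferenceClient

    AsyncInferenceClient(base_url="http://localhost:8080")
    # -> posts to http://localhost:8080/v1/chat/completions

    AsyncInferenceClient(base_url="http://localhost:8080/v1")
    # -> posts to http://localhost:8080/v1/chat/completions (suffix completed, not duplicated)

    AsyncInferenceClient(base_url="http://localhost:8080/v1/chat/completions")
    # -> used as-is

    AsyncInferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
    # -> model ID resolved via the Inference API, then /v1/chat/completions is appended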
@@ -1373,8 +1340,8 @@ class AsyncInferenceClient:
             models_by_task.setdefault(model["task"], []).append(model["model_id"])

         async def _fetch_framework(framework: str) -> None:
-            async with
-                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
+            async with self._get_client_session() as client:
+                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
                 response.raise_for_status()
                 _unpack_response(framework, await response.json())

@@ -1757,7 +1724,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1786,7 +1754,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1815,7 +1784,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1844,7 +1814,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1873,7 +1844,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1901,7 +1873,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1966,8 +1939,10 @@ class AsyncInferenceClient:
                 Whether to prepend the prompt to the generated text
             seed (`int`, *optional*):
                 Random sampling seed
+            stop (`List[str]`, *optional*):
+                Stop generating tokens if a member of `stop` is generated.
             stop_sequences (`List[str]`, *optional*):
-
+                Deprecated argument. Use `stop` instead.
             temperature (`float`, *optional*):
                 The value used to module the logits distribution.
             top_n_tokens (`int`, *optional*):
@@ -2112,6 +2087,15 @@ class AsyncInferenceClient:
             )
             decoder_input_details = False

+        if stop_sequences is not None:
+            warnings.warn(
+                "`stop_sequences` is a deprecated argument for `text_generation` task"
+                " and will be removed in version '0.28.0'. Use `stop` instead.",
+                FutureWarning,
+            )
+            if stop is None:
+                stop = stop_sequences  # use deprecated arg if provided
+
         # Build payload
         parameters = {
             "adapter_id": adapter_id,
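The `stop_sequences` parameter of `text_generation` is kept as a deprecated alias: passing it still works but raises a `FutureWarning` and is copied into `stop` when `stop` is not set. A migration sketch (model and prompt illustrative):

    from huggingface_hub import AsyncInferenceClient

    async def generate() -> str:
        client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
        # Before 0.25: stop_sequences=["\n"] (still accepted, but warns until removal in 0.28.0)
        return await client.text_generation(
            "Q: What is the capital of France?\nA:",
            stop=["\n"],
            max_new_tokens=16,
        )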
@@ -2125,7 +2109,7 @@ class AsyncInferenceClient:
             "repetition_penalty": repetition_penalty,
             "return_full_text": return_full_text,
             "seed": seed,
-            "stop":
+            "stop": stop if stop is not None else [],
             "temperature": temperature,
             "top_k": top_k,
             "top_n_tokens": top_n_tokens,
@@ -2195,7 +2179,7 @@ class AsyncInferenceClient:
                 repetition_penalty=repetition_penalty,
                 return_full_text=return_full_text,
                 seed=seed,
-
+                stop=stop,
                 temperature=temperature,
                 top_k=top_k,
                 top_n_tokens=top_n_tokens,
@@ -2655,6 +2639,47 @@ class AsyncInferenceClient:
         )
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

+    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
+        aiohttp = _import_aiohttp()
+        client_headers = self.headers.copy()
+        if headers is not None:
+            client_headers.update(headers)
+
+        # Return a new aiohttp ClientSession with correct settings.
+        session = aiohttp.ClientSession(
+            headers=client_headers,
+            cookies=self.cookies,
+            timeout=aiohttp.ClientTimeout(self.timeout),
+            trust_env=self.trust_env,
+        )
+
+        # Keep track of sessions to close them later
+        self._sessions[session] = set()
+
+        # Override the `._request` method to register responses to be closed
+        session._wrapped_request = session._request
+
+        async def _request(method, url, **kwargs):
+            response = await session._wrapped_request(method, url, **kwargs)
+            self._sessions[session].add(response)
+            return response
+
+        session._request = _request
+
+        # Override the 'close' method to
+        # 1. close ongoing responses
+        # 2. deregister the session when closed
+        session._close = session.close
+
+        async def close_session():
+            for response in self._sessions[session]:
+                response.close()
+            await session._close()
+            self._sessions.pop(session, None)
+
+        session.close = close_session
+        return session
+
     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model or self.base_url

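Because `_get_client_session` wraps `session._request` and `session.close`, every response created by a session is tracked and released before the session itself is closed and removed from `client._sessions`. A sketch of the scenario this covers (prompt illustrative):

    from huggingface_hub import AsyncInferenceClient

    async def peek_stream() -> None:
        client = AsyncInferenceClient()
        stream = await client.text_generation("Hello", stream=True, max_new_tokens=50)
        async for _token in stream:
            break  # abandon the stream early: the underlying session stays open
        # The patched close() first closes the tracked responses, then the session,
        # and finally deregisters it from client._sessions.
        await client.close()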
@@ -2761,8 +2786,8 @@ class AsyncInferenceClient:
         else:
             url = f"{INFERENCE_ENDPOINT}/models/{model}/info"

-        async with
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             return await response.json()

@@ -2798,8 +2823,8 @@ class AsyncInferenceClient:
             )
         url = model.rstrip("/") + "/health"

-        async with
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             return response.status == 200

     async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
@@ -2840,8 +2865,8 @@ class AsyncInferenceClient:
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
         url = f"{INFERENCE_ENDPOINT}/status/{model}"

-        async with
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             response_data = await response.json()


huggingface_hub/inference/_generated/types/base.py
CHANGED

@@ -15,7 +15,6 @@

 import inspect
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Any, Dict, List, Type, TypeVar, Union, get_args

@@ -135,14 +134,6 @@ class BaseInferenceType(dict):
             self[__name] = __value
         return

-    def __getitem__(self, __key: Any) -> Any:
-        warnings.warn(
-            f"Accessing '{self.__class__.__name__}' values through dict is deprecated and "
-            "will be removed from version '0.25'. Use dataclass attributes instead.",
-            FutureWarning,
-        )
-        return super().__getitem__(__key)
-

 def normalize_key(key: str) -> str:
     # e.g "content-type" -> "content_type", "Accept" -> "accept"
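With the `__getitem__` deprecation shim gone from `BaseInferenceType`, reading generated output types through dataclass attributes is the supported pattern. A small sketch (model ID illustrative):

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="distilbert-base-uncased-finetuned-sst-2-english")
    result = client.text_classification("I love this!")[0]
    print(result.label, result.score)  # dataclass attributes instead of result["label"]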
huggingface_hub/inference/_templating.py
CHANGED

@@ -1,9 +1,8 @@
 from functools import lru_cache
 from typing import Callable, Dict, List, Optional, Union

-from
-
-from ..utils import HfHubHTTPError, RepositoryNotFoundError, is_minijinja_available
+from ..errors import HfHubHTTPError, RepositoryNotFoundError, TemplateError
+from ..utils import is_minijinja_available


 def _import_minijinja():
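Consistent with the large addition to huggingface_hub/errors.py in the file list, exception classes are now imported from `huggingface_hub.errors` rather than scattered across `utils`. A sketch of catching them from the new location (repo ID illustrative):

    from huggingface_hub import HfApi
    from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError

    try:
        HfApi().model_info("this-repo/does-not-exist")
    except RepositoryNotFoundError:
        print("No such repository.")
    except HfHubHTTPError as err:
        print(f"Hub request failed: {err}")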
huggingface_hub/inference_api.py
CHANGED

@@ -1,7 +1,7 @@
 import io
 from typing import Any, Dict, List, Optional, Union

-from .
+from . import constants
 from .hf_api import HfApi
 from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
 from .utils._deprecation import _deprecate_method
@@ -149,7 +149,7 @@ class InferenceApi:
         assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
         self.task = model_info.pipeline_tag

-        self.api_url = f"{INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
+        self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"

     def __repr__(self):
         # Do not add headers to repr to avoid leaking token.

huggingface_hub/keras_mixin.py
CHANGED

@@ -16,7 +16,7 @@ from huggingface_hub.utils import (
     yaml_dump,
 )

-from .
+from . import constants
 from .hf_api import HfApi
 from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
 from .utils._typing import CallableT
@@ -202,7 +202,7 @@ def save_pretrained_keras(
         if not isinstance(config, dict):
             raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'")

-        with (save_directory / CONFIG_NAME).open("w") as f:
+        with (save_directory / constants.CONFIG_NAME).open("w") as f:
             json.dump(config, f)

         metadata = {}