huggingface-hub 0.24.7__py3-none-any.whl → 0.25.0rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those package versions.

Potentially problematic release: this version of huggingface-hub might be problematic.

Files changed (52)
  1. huggingface_hub/__init__.py +21 -1
  2. huggingface_hub/_commit_api.py +4 -4
  3. huggingface_hub/_inference_endpoints.py +13 -1
  4. huggingface_hub/_local_folder.py +191 -4
  5. huggingface_hub/_login.py +6 -6
  6. huggingface_hub/_snapshot_download.py +8 -17
  7. huggingface_hub/_space_api.py +5 -0
  8. huggingface_hub/_tensorboard_logger.py +29 -13
  9. huggingface_hub/_upload_large_folder.py +573 -0
  10. huggingface_hub/_webhooks_server.py +1 -1
  11. huggingface_hub/commands/_cli_utils.py +5 -0
  12. huggingface_hub/commands/download.py +8 -0
  13. huggingface_hub/commands/huggingface_cli.py +6 -1
  14. huggingface_hub/commands/lfs.py +2 -1
  15. huggingface_hub/commands/repo_files.py +2 -2
  16. huggingface_hub/commands/scan_cache.py +99 -57
  17. huggingface_hub/commands/tag.py +1 -1
  18. huggingface_hub/commands/upload.py +2 -1
  19. huggingface_hub/commands/upload_large_folder.py +129 -0
  20. huggingface_hub/commands/version.py +37 -0
  21. huggingface_hub/community.py +2 -2
  22. huggingface_hub/errors.py +218 -1
  23. huggingface_hub/fastai_utils.py +2 -3
  24. huggingface_hub/file_download.py +61 -62
  25. huggingface_hub/hf_api.py +758 -314
  26. huggingface_hub/hf_file_system.py +15 -23
  27. huggingface_hub/hub_mixin.py +27 -25
  28. huggingface_hub/inference/_client.py +78 -127
  29. huggingface_hub/inference/_generated/_async_client.py +169 -144
  30. huggingface_hub/inference/_generated/types/base.py +0 -9
  31. huggingface_hub/inference/_templating.py +2 -3
  32. huggingface_hub/inference_api.py +2 -2
  33. huggingface_hub/keras_mixin.py +2 -2
  34. huggingface_hub/lfs.py +7 -98
  35. huggingface_hub/repocard.py +6 -5
  36. huggingface_hub/repository.py +5 -5
  37. huggingface_hub/serialization/_torch.py +64 -11
  38. huggingface_hub/utils/__init__.py +13 -14
  39. huggingface_hub/utils/_cache_manager.py +97 -14
  40. huggingface_hub/utils/_fixes.py +18 -2
  41. huggingface_hub/utils/_http.py +228 -2
  42. huggingface_hub/utils/_lfs.py +110 -0
  43. huggingface_hub/utils/_runtime.py +7 -1
  44. huggingface_hub/utils/_token.py +3 -2
  45. {huggingface_hub-0.24.7.dist-info → huggingface_hub-0.25.0rc0.dist-info}/METADATA +2 -2
  46. {huggingface_hub-0.24.7.dist-info → huggingface_hub-0.25.0rc0.dist-info}/RECORD +50 -48
  47. huggingface_hub/inference/_types.py +0 -52
  48. huggingface_hub/utils/_errors.py +0 -397
  49. {huggingface_hub-0.24.7.dist-info → huggingface_hub-0.25.0rc0.dist-info}/LICENSE +0 -0
  50. {huggingface_hub-0.24.7.dist-info → huggingface_hub-0.25.0rc0.dist-info}/WHEEL +0 -0
  51. {huggingface_hub-0.24.7.dist-info → huggingface_hub-0.25.0rc0.dist-info}/entry_points.txt +0 -0
  52. {huggingface_hub-0.24.7.dist-info → huggingface_hub-0.25.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/_async_client.py

@@ -32,6 +32,7 @@ from typing import (
     List,
     Literal,
     Optional,
+    Set,
     Union,
     overload,
 )
@@ -86,9 +87,6 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._types import (
-    ConversationalOutput,  # soon to be removed
-)
 from huggingface_hub.utils import (
     build_hf_headers,
 )
@@ -99,6 +97,7 @@ from .._common import _async_yield_from, _import_aiohttp
 
 if TYPE_CHECKING:
     import numpy as np
+    from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image
 
 logger = logging.getLogger(__name__)
@@ -120,7 +119,9 @@ class AsyncInferenceClient:
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
             Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
-            arguments are mutually exclusive and have the exact same behavior.
+            arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
+            path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
+            documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         token (`str` or `bool`, *optional*):
             Hugging Face token. Will default to the locally saved token if not provided.
             Pass `token=False` if you don't want to send your token to the server.
@@ -134,6 +135,10 @@ class AsyncInferenceClient:
             Values in this dictionary will override the default values.
         cookies (`Dict[str, str]`, `optional`):
             Additional cookies to send to the server.
+        trust_env ('bool', 'optional'):
+            Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
+        proxies (`Any`, `optional`):
+            Proxies to use for the request.
         base_url (`str`, `optional`):
             Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
             follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
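The constructor now exposes `trust_env` (new in this release) alongside the newly documented `proxies` argument. A minimal sketch of how they might be used together; the model ID and proxy URL below are placeholders, not values taken from this diff:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    # trust_env=True lets aiohttp pick up HTTP_PROXY / HTTPS_PROXY from the environment;
    # an explicit proxy could instead be passed via `proxies`.
    client = AsyncInferenceClient(
        model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model ID
        trust_env=True,
        # proxies="http://localhost:3128",  # hypothetical explicit proxy
    )
    result = await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
    print(result)
    await client.close()  # close any sessions kept open by the client


asyncio.run(main())
```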
@@ -151,6 +156,7 @@ class AsyncInferenceClient:
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
         cookies: Optional[Dict[str, str]] = None,
+        trust_env: bool = False,
         proxies: Optional[Any] = None,
         # OpenAI compatibility
         base_url: Optional[str] = None,
@@ -160,7 +166,8 @@ class AsyncInferenceClient:
             raise ValueError(
                 "Received both `model` and `base_url` arguments. Please provide only one of them."
                 " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
-                " It has the exact same behavior as `model`."
+                " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url."
+                " When passing a URL as `model`, the client will not append any suffix path to it."
             )
         if token is not None and api_key is not None:
             raise ValueError(
@@ -176,11 +183,15 @@ class AsyncInferenceClient:
             self.headers.update(headers)
         self.cookies = cookies
         self.timeout = timeout
+        self.trust_env = trust_env
         self.proxies = proxies
 
         # OpenAI compatibility
         self.base_url = base_url
 
+        # Keep track of the sessions to close them properly
+        self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()
+
     def __repr__(self):
         return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
 
@@ -265,7 +276,7 @@ class AsyncInferenceClient:
             warnings.warn("Ignoring `json` as `data` is passed as binary.")
 
         # Set Accept header if relevant
-        headers = self.headers.copy()
+        headers = dict()
         if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
             headers["Accept"] = "image/png"
 
@@ -275,12 +286,10 @@ class AsyncInferenceClient:
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-            client = aiohttp.ClientSession(
-                headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
-            )
+            session = self._get_client_session(headers=headers)
 
             try:
-                response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
+                response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                 response_error_payload = None
                 if response.status != 200:
                     try:
@@ -289,18 +298,18 @@ class AsyncInferenceClient:
                         pass
                 response.raise_for_status()
                 if stream:
-                    return _async_yield_from(client, response)
+                    return _async_yield_from(session, response)
                 else:
                     content = await response.read()
-                    await client.close()
+                    await session.close()
                     return content
             except asyncio.TimeoutError as error:
-                await client.close()
+                await session.close()
                 # Convert any `TimeoutError` to a `InferenceTimeoutError`
                 raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
-                await client.close()
+                await session.close()
                 if response.status == 422 and task is not None:
                     error.message += f". Make sure '{task}' task is supported by the model."
                 if response.status == 503:
@@ -322,9 +331,35 @@ class AsyncInferenceClient:
                         continue
                     raise error
             except Exception:
-                await client.close()
+                await session.close()
                 raise
 
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+    def __del__(self):
+        if len(self._sessions) > 0:
+            warnings.warn(
+                "Deleting 'AsyncInferenceClient' client but some sessions are still open. "
+                "This can happen if you've stopped streaming data from the server before the stream was complete. "
+                "To close the client properly, you must call `await client.close()` "
+                "or use an async context (e.g. `async with AsyncInferenceClient(): ...`."
+            )
+
+    async def close(self):
+        """Close all open sessions.
+
+        By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
+        are streaming data from the server and you stop before the stream is complete, you must call this method to
+        close the session properly.
+
+        Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
+        """
+        await asyncio.gather(*[session.close() for session in self._sessions.keys()])
+
     async def audio_classification(
         self,
         audio: ContentT,
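Because a streamed response keeps its `aiohttp.ClientSession` open, the client now tracks sessions and adds `close()` plus async context manager support. A short usage sketch, assuming a chat-capable model ID (placeholder) is reachable:

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    # The async context manager calls `await client.close()` on exit,
    # which also closes sessions left open by an interrupted stream.
    async with AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta") as client:  # placeholder model ID
        stream = await client.chat_completion(
            messages=[{"role": "user", "content": "Say hi in one word."}],
            max_tokens=5,
            stream=True,
        )
        async for chunk in stream:
            print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```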
@@ -815,134 +850,66 @@ class AsyncInferenceClient:
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
         """
-        # Determine model
-        # `self.xxx` takes precedence over the method argument only in `chat_completion`
-        # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-        # server, we need to handle it differently
-        model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-        is_url = model.startswith(("http://", "https://"))
-
-        # First, resolve the model chat completions URL
-        if model == self.base_url:
-            # base_url passed => add server route
-            model_url = model.rstrip("/")
-            if not model_url.endswith("/v1"):
-                model_url += "/v1"
-            model_url += "/chat/completions"
-        elif is_url:
-            # model is a URL => use it directly
-            model_url = model
-        else:
-            # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
+        model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's a ID on the Hub => use it. Otherwise, we use a random string.
-        model_id = model if not is_url and model.count("/") == 1 else "tgi"
-
-        data = await self.post(
-            model=model_url,
-            json=dict(
-                model=model_id,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_tokens=max_tokens,
-                n=n,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tool_prompt=tool_prompt,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                stream=stream,
-            ),
+        model_id = model or self.model or "tgi"
+        if model_id.startswith(("http://", "https://")):
+            model_id = "tgi"  # dummy value
+
+        payload = dict(
+            model=model_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tool_prompt=tool_prompt,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
             stream=stream,
         )
+        payload = {key: value for key, value in payload.items() if value is not None}
+        data = await self.post(model=model_url, json=payload, stream=stream)
 
         if stream:
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]
 
         return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
 
-    async def conversational(
-        self,
-        text: str,
-        generated_responses: Optional[List[str]] = None,
-        past_user_inputs: Optional[List[str]] = None,
-        *,
-        parameters: Optional[Dict[str, Any]] = None,
-        model: Optional[str] = None,
-    ) -> ConversationalOutput:
-        """
-        Generate conversational responses based on the given input text (i.e. chat with the API).
-
-        <Tip warning={true}>
-
-        [`InferenceClient.conversational`] API is deprecated and will be removed in a future release. Please use
-        [`InferenceClient.chat_completion`] instead.
+    def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+        # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+        # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
+        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
 
-        </Tip>
+        # Resolve URL if it's a model ID
+        model_url = (
+            model_id_or_url
+            if model_id_or_url.startswith(("http://", "https://"))
+            else self._resolve_url(model_id_or_url, task="text-generation")
+        )
 
-        Args:
-            text (`str`):
-                The last input from the user in the conversation.
-            generated_responses (`List[str]`, *optional*):
-                A list of strings corresponding to the earlier replies from the model. Defaults to None.
-            past_user_inputs (`List[str]`, *optional*):
-                A list of strings corresponding to the earlier replies from the user. Should be the same length as
-                `generated_responses`. Defaults to None.
-            parameters (`Dict[str, Any]`, *optional*):
-                Additional parameters for the conversational task. Defaults to None. For more details about the available
-                parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
-            model (`str`, *optional*):
-                The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
-                Defaults to None.
+        # Strip trailing /
+        model_url = model_url.rstrip("/")
 
-        Returns:
-            `Dict`: The generated conversational output.
+        # Append /chat/completions if not already present
+        if model_url.endswith("/v1"):
+            model_url += "/chat/completions"
 
-        Raises:
-            [`InferenceTimeoutError`]:
-                If the model is unavailable or the request times out.
-            `aiohttp.ClientResponseError`:
-                If the request fails with an HTTP error status code other than HTTP 503.
+        # Append /v1/chat/completions if not already present
+        if not model_url.endswith("/chat/completions"):
+            model_url += "/v1/chat/completions"
 
-        Example:
-        ```py
-        # Must be run in an async context
-        >>> from huggingface_hub import AsyncInferenceClient
-        >>> client = AsyncInferenceClient()
-        >>> output = await client.conversational("Hi, who are you?")
-        >>> output
-        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
-        >>> await client.conversational(
-        ...     "Wow, that's scary!",
-        ...     generated_responses=output["conversation"]["generated_responses"],
-        ...     past_user_inputs=output["conversation"]["past_user_inputs"],
-        ... )
-        ```
-        """
-        warnings.warn(
-            "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
-            "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
-            FutureWarning,
-        )
-        payload: Dict[str, Any] = {"inputs": {"text": text}}
-        if generated_responses is not None:
-            payload["inputs"]["generated_responses"] = generated_responses
-        if past_user_inputs is not None:
-            payload["inputs"]["past_user_inputs"] = past_user_inputs
-        if parameters is not None:
-            payload["parameters"] = parameters
-        response = await self.post(json=payload, model=model, task="conversational")
-        return _bytes_to_dict(response)  # type: ignore
+        return model_url
 
     async def document_question_answering(
         self,
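The chat-completion URL logic is now centralized in `_resolve_chat_completion_url`, and `None`-valued fields are stripped from the payload before posting. The suffix-handling rules can be restated as a small standalone sketch (illustrative only, not the library's actual helper; it skips the model-ID-to-URL resolution step):

```py
def resolve_chat_completion_suffix(url: str) -> str:
    """Restate the suffix rules from `_resolve_chat_completion_url` (illustrative only)."""
    url = url.rstrip("/")
    if url.endswith("/v1"):
        url += "/chat/completions"  # ".../v1" -> ".../v1/chat/completions"
    if not url.endswith("/chat/completions"):
        url += "/v1/chat/completions"  # bare base URL -> ".../v1/chat/completions"
    return url


# Assumed inputs and the routes they resolve to:
assert resolve_chat_completion_suffix("http://localhost:8080") == "http://localhost:8080/v1/chat/completions"
assert resolve_chat_completion_suffix("http://localhost:8080/v1") == "http://localhost:8080/v1/chat/completions"
assert resolve_chat_completion_suffix("http://localhost:8080/v1/chat/completions") == "http://localhost:8080/v1/chat/completions"
```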
@@ -1373,8 +1340,8 @@ class AsyncInferenceClient:
             models_by_task.setdefault(model["task"], []).append(model["model_id"])
 
         async def _fetch_framework(framework: str) -> None:
-            async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
+            async with self._get_client_session() as client:
+                response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
                 response.raise_for_status()
                 _unpack_response(framework, await response.json())
 
@@ -1757,7 +1724,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1786,7 +1754,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1815,7 +1784,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1844,7 +1814,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1873,7 +1844,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1901,7 +1873,8 @@ class AsyncInferenceClient:
         repetition_penalty: Optional[float] = None,
         return_full_text: Optional[bool] = False,  # Manual default value
         seed: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,  # Same as `stop`
+        stop: Optional[List[str]] = None,
+        stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_n_tokens: Optional[int] = None,
@@ -1966,8 +1939,10 @@ class AsyncInferenceClient:
                 Whether to prepend the prompt to the generated text
             seed (`int`, *optional*):
                 Random sampling seed
+            stop (`List[str]`, *optional*):
+                Stop generating tokens if a member of `stop` is generated.
             stop_sequences (`List[str]`, *optional*):
-                Stop generating tokens if a member of `stop_sequences` is generated
+                Deprecated argument. Use `stop` instead.
             temperature (`float`, *optional*):
                 The value used to module the logits distribution.
             top_n_tokens (`int`, *optional*):
@@ -2112,6 +2087,15 @@ class AsyncInferenceClient:
             )
             decoder_input_details = False
 
+        if stop_sequences is not None:
+            warnings.warn(
+                "`stop_sequences` is a deprecated argument for `text_generation` task"
+                " and will be removed in version '0.28.0'. Use `stop` instead.",
+                FutureWarning,
+            )
+            if stop is None:
+                stop = stop_sequences  # use deprecated arg if provided
+
         # Build payload
         parameters = {
             "adapter_id": adapter_id,
@@ -2125,7 +2109,7 @@ class AsyncInferenceClient:
             "repetition_penalty": repetition_penalty,
             "return_full_text": return_full_text,
             "seed": seed,
-            "stop": stop_sequences if stop_sequences is not None else [],
+            "stop": stop if stop is not None else [],
             "temperature": temperature,
             "top_k": top_k,
             "top_n_tokens": top_n_tokens,
@@ -2195,7 +2179,7 @@ class AsyncInferenceClient:
                 repetition_penalty=repetition_penalty,
                 return_full_text=return_full_text,
                 seed=seed,
-                stop_sequences=stop_sequences,
+                stop=stop,
                 temperature=temperature,
                 top_k=top_k,
                 top_n_tokens=top_n_tokens,
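Taken together, the hunks above switch `text_generation` to a `stop` parameter that matches the server-side name, while `stop_sequences` keeps working but emits a `FutureWarning` until its removal in 0.28.0. A hedged usage sketch (the model ID is a placeholder):

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    async with AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta") as client:  # placeholder model ID
        # Preferred spelling going forward:
        text = await client.text_generation(
            "The three primary colors are",
            max_new_tokens=30,
            stop=["\n"],
        )
        print(text)
        # Passing stop_sequences=["\n"] instead would still work here,
        # but emits a FutureWarning and is slated for removal in 0.28.0.


asyncio.run(main())
```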
@@ -2655,6 +2639,47 @@ class AsyncInferenceClient:
         )
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
 
+    def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
+        aiohttp = _import_aiohttp()
+        client_headers = self.headers.copy()
+        if headers is not None:
+            client_headers.update(headers)
+
+        # Return a new aiohttp ClientSession with correct settings.
+        session = aiohttp.ClientSession(
+            headers=client_headers,
+            cookies=self.cookies,
+            timeout=aiohttp.ClientTimeout(self.timeout),
+            trust_env=self.trust_env,
+        )
+
+        # Keep track of sessions to close them later
+        self._sessions[session] = set()
+
+        # Override the `._request` method to register responses to be closed
+        session._wrapped_request = session._request
+
+        async def _request(method, url, **kwargs):
+            response = await session._wrapped_request(method, url, **kwargs)
+            self._sessions[session].add(response)
+            return response
+
+        session._request = _request
+
+        # Override the 'close' method to
+        # 1. close ongoing responses
+        # 2. deregister the session when closed
+        session._close = session.close
+
+        async def close_session():
+            for response in self._sessions[session]:
+                response.close()
+            await session._close()
+            self._sessions.pop(session, None)
+
+        session.close = close_session
+        return session
+
     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model or self.base_url
 
@@ -2761,8 +2786,8 @@ class AsyncInferenceClient:
         else:
             url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
 
-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             return await response.json()
 
@@ -2798,8 +2823,8 @@ class AsyncInferenceClient:
             )
         url = model.rstrip("/") + "/health"
 
-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             return response.status == 200
 
     async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
@@ -2840,8 +2865,8 @@ class AsyncInferenceClient:
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
         url = f"{INFERENCE_ENDPOINT}/status/{model}"
 
-        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
-            response = await client.get(url)
+        async with self._get_client_session() as client:
+            response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
             response_data = await response.json()
 
huggingface_hub/inference/_generated/types/base.py

@@ -15,7 +15,6 @@
 
 import inspect
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Any, Dict, List, Type, TypeVar, Union, get_args
 
@@ -135,14 +134,6 @@ class BaseInferenceType(dict):
             self[__name] = __value
         return
 
-    def __getitem__(self, __key: Any) -> Any:
-        warnings.warn(
-            f"Accessing '{self.__class__.__name__}' values through dict is deprecated and "
-            "will be removed from version '0.25'. Use dataclass attributes instead.",
-            FutureWarning,
-        )
-        return super().__getitem__(__key)
-
 
 def normalize_key(key: str) -> str:
     # e.g "content-type" -> "content_type", "Accept" -> "accept"
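With the `__getitem__` override gone, the generated output dataclasses no longer warn on dict-style access; attribute access remains the documented way to read them. A short illustration using `parse_obj_as_list` (the payload values are hypothetical):

```py
from huggingface_hub.inference._generated.types import TextClassificationOutputElement

# parse_obj_as_list builds typed dataclass instances from a raw server payload
elements = TextClassificationOutputElement.parse_obj_as_list(
    [{"label": "POSITIVE", "score": 0.98}]  # hypothetical payload
)

# Recommended: read values as dataclass attributes
print(elements[0].label, elements[0].score)
```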
huggingface_hub/inference/_templating.py

@@ -1,9 +1,8 @@
 from functools import lru_cache
 from typing import Callable, Dict, List, Optional, Union
 
-from huggingface_hub.errors import TemplateError
-
-from ..utils import HfHubHTTPError, RepositoryNotFoundError, is_minijinja_available
+from ..errors import HfHubHTTPError, RepositoryNotFoundError, TemplateError
+from ..utils import is_minijinja_available
 
 
 def _import_minijinja():
huggingface_hub/inference_api.py

@@ -1,7 +1,7 @@
 import io
 from typing import Any, Dict, List, Optional, Union
 
-from .constants import INFERENCE_ENDPOINT
+from . import constants
 from .hf_api import HfApi
 from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
 from .utils._deprecation import _deprecate_method
@@ -149,7 +149,7 @@ class InferenceApi:
         assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
         self.task = model_info.pipeline_tag
 
-        self.api_url = f"{INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
+        self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
 
     def __repr__(self):
         # Do not add headers to repr to avoid leaking token.
huggingface_hub/keras_mixin.py

@@ -16,7 +16,7 @@ from huggingface_hub.utils import (
     yaml_dump,
 )
 
-from .constants import CONFIG_NAME
+from . import constants
 from .hf_api import HfApi
 from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
 from .utils._typing import CallableT
@@ -202,7 +202,7 @@ def save_pretrained_keras(
     if not isinstance(config, dict):
         raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'")
 
-    with (save_directory / CONFIG_NAME).open("w") as f:
+    with (save_directory / constants.CONFIG_NAME).open("w") as f:
         json.dump(config, f)
 
     metadata = {}