huggingface-hub 0.24.0rc0__py3-none-any.whl → 0.24.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of huggingface-hub might be problematic.

huggingface_hub/__init__.py CHANGED
@@ -46,7 +46,7 @@ import sys
 from typing import TYPE_CHECKING
 
 
-__version__ = "0.24.0.rc0"
+__version__ = "0.24.1"
 
 # Alphabetical order of definitions is ensured in tests
 # WARNING: any comment added in this dictionary definition will be lost when
huggingface_hub/hf_api.py CHANGED
@@ -149,7 +149,6 @@ ExpandModelProperty_T = Literal[
     "downloads",
     "downloadsAllTime",
     "gated",
-    "gitalyUid",
     "inference",
     "lastModified",
     "library_name",
@@ -177,7 +176,6 @@ ExpandDatasetProperty_T = Literal[
     "downloads",
     "downloadsAllTime",
     "gated",
-    "gitalyUid",
     "lastModified",
     "likes",
     "paperswithcode_id",
@@ -192,7 +190,6 @@ ExpandSpaceProperty_T = Literal[
     "cardData",
     "datasets",
     "disabled",
-    "gitalyUid",
     "lastModified",
     "createdAt",
     "likes",
@@ -1633,7 +1630,7 @@ class HfApi:
             expand (`List[ExpandModelProperty_T]`, *optional*):
                 List properties to return in the response. When used, only the properties in the list will be returned.
                 This parameter cannot be used if `full`, `cardData` or `fetch_config` are passed.
-                Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gitalyUid"`, `"inference"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"` and `"widgetData"`.
+                Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"inference"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"` and `"widgetData"`.
             full (`bool`, *optional*):
                 Whether to fetch all model data, including the `last_modified`,
                 the `sha`, the files and the `tags`. This is set to `True` by
@@ -1836,7 +1833,7 @@ class HfApi:
             expand (`List[ExpandDatasetProperty_T]`, *optional*):
                 List properties to return in the response. When used, only the properties in the list will be returned.
                 This parameter cannot be used if `full` is passed.
-                Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gitalyUid"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"` and `"tags"`.
+                Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"` and `"tags"`.
             full (`bool`, *optional*):
                 Whether to fetch all dataset data, including the `last_modified`,
                 the `card_data` and the files. Can contain useful information such as the
@@ -2017,7 +2014,7 @@ class HfApi:
             expand (`List[ExpandSpaceProperty_T]`, *optional*):
                 List properties to return in the response. When used, only the properties in the list will be returned.
                 This parameter cannot be used if `full` is passed.
-                Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"gitalyUid"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"` and `"models"`.
+                Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"` and `"models"`.
             full (`bool`, *optional*):
                 Whether to fetch all Spaces data, including the `last_modified`, `siblings`
                 and `card_data` fields.
@@ -2334,7 +2331,7 @@ class HfApi:
             expand (`List[ExpandModelProperty_T]`, *optional*):
                 List properties to return in the response. When used, only the properties in the list will be returned.
                 This parameter cannot be used if `securityStatus` or `files_metadata` are passed.
-                Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gitalyUid"`, `"inference"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"` and `"widgetData"`.
+                Possible values are `"author"`, `"cardData"`, `"config"`, `"createdAt"`, `"disabled"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"inference"`, `"lastModified"`, `"library_name"`, `"likes"`, `"mask_token"`, `"model-index"`, `"pipeline_tag"`, `"private"`, `"safetensors"`, `"sha"`, `"siblings"`, `"spaces"`, `"tags"`, `"transformersInfo"` and `"widgetData"`.
             token (Union[bool, str, None], optional):
                 A valid user access token (string). Defaults to the locally saved
                 token, which is the recommended method for authentication (see
@@ -2408,7 +2405,7 @@ class HfApi:
             expand (`List[ExpandDatasetProperty_T]`, *optional*):
                 List properties to return in the response. When used, only the properties in the list will be returned.
                 This parameter cannot be used if `files_metadata` is passed.
-                Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"gitalyUid"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"` and `"tags"`.
+                Possible values are `"author"`, `"cardData"`, `"citation"`, `"createdAt"`, `"disabled"`, `"description"`, `"downloads"`, `"downloadsAllTime"`, `"gated"`, `"lastModified"`, `"likes"`, `"paperswithcode_id"`, `"private"`, `"siblings"`, `"sha"` and `"tags"`.
             token (Union[bool, str, None], optional):
                 A valid user access token (string). Defaults to the locally saved
                 token, which is the recommended method for authentication (see
@@ -2481,7 +2478,7 @@ class HfApi:
             expand (`List[ExpandSpaceProperty_T]`, *optional*):
                 List properties to return in the response. When used, only the properties in the list will be returned.
                 This parameter cannot be used if `full` is passed.
-                Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"gitalyUid"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"` and `"models"`.
+                Possible values are `"author"`, `"cardData"`, `"datasets"`, `"disabled"`, `"lastModified"`, `"createdAt"`, `"likes"`, `"private"`, `"runtime"`, `"sdk"`, `"siblings"`, `"sha"`, `"subdomain"`, `"tags"` and `"models"`.
             token (Union[bool, str, None], optional):
                 A valid user access token (string). Defaults to the locally saved
                 token, which is the recommended method for authentication (see
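
With `gitalyUid` gone from the `Expand*Property_T` literals, passing it to `expand` now fails type-checking and is rejected server-side. For illustration, a minimal sketch using some of the remaining properties (the chosen property names and `limit` value are arbitrary):

    from huggingface_hub import HfApi

    api = HfApi()
    # Request only a subset of properties; "gitalyUid" is no longer a valid entry.
    for model in api.list_models(expand=["downloads", "likes", "lastModified"], limit=5):
        print(model.id, model.downloads, model.likes)
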
huggingface_hub/inference/_client.py CHANGED
@@ -66,11 +66,9 @@ from huggingface_hub.inference._common import (
     _fetch_recommended_models,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _is_chat_completion_server,
     _open_as_binary,
-    _set_as_non_chat_completion_server,
     _set_unsupported_text_generation_kwargs,
-    _stream_chat_completion_response_from_bytes,
+    _stream_chat_completion_response,
     _stream_text_generation_response,
     raise_text_generation_error,
 )
@@ -82,8 +80,6 @@ from huggingface_hub.inference._generated.types import (
     ChatCompletionInputTool,
     ChatCompletionInputToolTypeClass,
     ChatCompletionOutput,
-    ChatCompletionOutputComplete,
-    ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
     DocumentQuestionAnsweringOutputElement,
     FillMaskOutputElement,
@@ -189,7 +185,7 @@ class InferenceClient:
             )
 
         self.model: Optional[str] = model
-        self.token: Union[str, bool, None] = token or api_key
+        self.token: Union[str, bool, None] = token if token is not None else api_key
         self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token))  # 'authorization' + 'user-agent'
         if headers is not None:
             self.headers.update(headers)
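
The switch from `token or api_key` to `token if token is not None else api_key` is a truthiness fix: `token=False` is a deliberate opt-out of authentication, but `False or api_key` would silently override it. A small illustration (values are hypothetical):

    token, api_key = False, "hf_xxx"  # hypothetical values
    print(token or api_key)                         # "hf_xxx" -- the explicit False is lost
    print(token if token is not None else api_key)  # False   -- the opt-out is preserved
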
@@ -818,123 +814,52 @@ class InferenceClient:
         # since `chat_completion(..., model=xxx)` is also a payload parameter for the
         # server, we need to handle it differently
         model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+        is_url = model.startswith(("http://", "https://"))
+
+        # First, resolve the model chat completions URL
+        if model == self.base_url:
+            # base_url passed => add server route
+            model_url = model + "/v1/chat/completions"
+        elif is_url:
+            # model is a URL => use it directly
+            model_url = model
+        else:
+            # model is a model ID => resolve it + add server route
+            model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+        # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
+        # If it's an ID on the Hub => use it. Otherwise, use a random string.
+        model_id = model if not is_url and model.count("/") == 1 else "tgi"
+
+        data = self.post(
+            model=model_url,
+            json=dict(
+                model=model_id,
+                messages=messages,
+                frequency_penalty=frequency_penalty,
+                logit_bias=logit_bias,
+                logprobs=logprobs,
+                max_tokens=max_tokens,
+                n=n,
+                presence_penalty=presence_penalty,
+                response_format=response_format,
+                seed=seed,
+                stop=stop,
+                temperature=temperature,
+                tool_choice=tool_choice,
+                tool_prompt=tool_prompt,
+                tools=tools,
+                top_logprobs=top_logprobs,
+                top_p=top_p,
+                stream=stream,
+            ),
+            stream=stream,
+        )
 
-        if _is_chat_completion_server(model):
-            # First, let's consider the server has a `/v1/chat/completions` endpoint.
-            # If that's the case, we don't have to render the chat template client-side.
-            model_url = self._resolve_url(model)
-            if not model_url.endswith("/chat/completions"):
-                model_url += "/v1/chat/completions"
-
-            # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
-            if not model.startswith("http") and model.count("/") == 1:
-                # If it's a ID on the Hub => use it
-                model_id = model
-            else:
-                # Otherwise, we use a random string
-                model_id = "tgi"
-
-            try:
-                data = self.post(
-                    model=model_url,
-                    json=dict(
-                        model=model_id,
-                        messages=messages,
-                        frequency_penalty=frequency_penalty,
-                        logit_bias=logit_bias,
-                        logprobs=logprobs,
-                        max_tokens=max_tokens,
-                        n=n,
-                        presence_penalty=presence_penalty,
-                        response_format=response_format,
-                        seed=seed,
-                        stop=stop,
-                        temperature=temperature,
-                        tool_choice=tool_choice,
-                        tool_prompt=tool_prompt,
-                        tools=tools,
-                        top_logprobs=top_logprobs,
-                        top_p=top_p,
-                        stream=stream,
-                    ),
-                    stream=stream,
-                )
-            except HTTPError as e:
-                if e.response.status_code in (400, 404, 500):
-                    # Let's consider the server is not a chat completion server.
-                    # Then we call again `chat_completion` which will render the chat template client side.
-                    # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
-                    _set_as_non_chat_completion_server(model)
-                    logger.warning(
-                        f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
-                    )
-                    return self.chat_completion(
-                        messages=messages,
-                        model=model,
-                        stream=stream,
-                        max_tokens=max_tokens,
-                        seed=seed,
-                        stop=stop,
-                        temperature=temperature,
-                        top_p=top_p,
-                    )
-                raise
-
-            if stream:
-                return _stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
-
-            return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
-
-        # At this point, we know the server is not a chat completion server.
-        # It means it's a transformers-backed server for which we can send a list of messages directly to the
-        # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
         if stream:
-            raise ValueError(
-                "Streaming token is not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. Please pass `stream=False` as input."
-            )
-        if tool_choice is not None or tool_prompt is not None or tools is not None:
-            warnings.warn(
-                "Tools are not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. The provided tool parameters will be ignored."
-            )
-        if response_format is not None:
-            warnings.warn(
-                "Response format is not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. The provided response format will be ignored."
-            )
+            return _stream_chat_completion_response(data)  # type: ignore[arg-type]
 
-        # generate response
-        text_generation_output = self.text_generation(
-            prompt=messages,  # type: ignore # Not correct type but works implicitly
-            model=model,
-            stream=False,
-            details=False,
-            max_new_tokens=max_tokens,
-            seed=seed,
-            stop_sequences=stop,
-            temperature=temperature,
-            top_p=top_p,
-        )
-
-        # Format as a ChatCompletionOutput with dummy values for fields we can't provide
-        return ChatCompletionOutput(
-            id="dummy",
-            model="dummy",
-            system_fingerprint="dummy",
-            usage=None,  # type: ignore # set to `None` as we don't want to provide false information
-            created=int(time.time()),
-            choices=[
-                ChatCompletionOutputComplete(
-                    finish_reason="unk",  # type: ignore # set to `unk` as we don't want to provide false information
-                    index=0,
-                    message=ChatCompletionOutputMessage(
-                        content=text_generation_output,
-                        role="assistant",
-                    ),
-                )
-            ],
-        )
+        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
 
     def conversational(
         self,
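
`chat_completion` now always targets an OpenAI-compatible `/v1/chat/completions` route instead of probing the server and falling back to client-side chat templating. A minimal usage sketch (the endpoint URL is a placeholder):

    from huggingface_hub import InferenceClient

    # base_url => requests go to "<base_url>/v1/chat/completions"
    client = InferenceClient(base_url="http://localhost:8080")  # placeholder URL
    # Alternatively, pass a Hub model ID: InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
    output = client.chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=16,
    )
    print(output.choices[0].message.content)
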
@@ -2251,7 +2176,12 @@ class InferenceClient:
         if stream:
             return _stream_text_generation_response(bytes_output, details)  # type: ignore
 
-        data = _bytes_to_dict(bytes_output)[0]  # type: ignore[arg-type]
+        data = _bytes_to_dict(bytes_output)  # type: ignore[arg-type]
+
+        # Data can be a single element (dict) or a list of dicts, in which case we take the first element.
+        if isinstance(data, list):
+            data = data[0]
+
         return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
 
     def text_to_image(
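
The old code indexed `[0]` unconditionally, which raised on backends that return a bare JSON object rather than a single-element list. The normalization accepts both shapes, illustrated here with hypothetical payloads:

    # Two response shapes (hypothetical values):
    from_tgi = [{"generated_text": "Hello world"}]   # list-wrapped
    from_other = {"generated_text": "Hello world"}   # bare dict

    for data in (from_tgi, from_other):
        if isinstance(data, list):
            data = data[0]
        print(data["generated_text"])
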
huggingface_hub/inference/_common.py CHANGED
@@ -34,7 +34,6 @@ from typing import (
     Literal,
     NoReturn,
     Optional,
-    Set,
     Union,
     overload,
 )
@@ -61,8 +60,6 @@ from ..utils import (
 )
 from ._generated.types import (
     ChatCompletionStreamOutput,
-    ChatCompletionStreamOutputChoice,
-    ChatCompletionStreamOutputDelta,
     TextGenerationStreamOutput,
 )
 
@@ -271,7 +268,10 @@ def _stream_text_generation_response(
     """Used in `InferenceClient.text_generation`."""
     # Parse ServerSentEvents
     for byte_payload in bytes_output_as_lines:
-        output = _format_text_generation_stream_output(byte_payload, details)
+        try:
+            output = _format_text_generation_stream_output(byte_payload, details)
+        except StopIteration:
+            break
         if output is not None:
             yield output
 
@@ -282,7 +282,10 @@ async def _async_stream_text_generation_response(
     """Used in `AsyncInferenceClient.text_generation`."""
     # Parse ServerSentEvents
     async for byte_payload in bytes_output_as_lines:
-        output = _format_text_generation_stream_output(byte_payload, details)
+        try:
+            output = _format_text_generation_stream_output(byte_payload, details)
+        except StopIteration:
+            break
         if output is not None:
             yield output
 
@@ -293,6 +296,9 @@ def _format_text_generation_stream_output(
     if not byte_payload.startswith(b"data:"):
         return None  # empty line
 
+    if byte_payload == b"data: [DONE]":
+        raise StopIteration("[DONE] signal received.")
+
     # Decode payload
     payload = byte_payload.decode("utf-8")
     json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
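
Recent TGI versions terminate a server-sent-events stream with the literal line `data: [DONE]`, which is not valid JSON; without this guard, `json.loads` would raise. Raising `StopIteration` lets the stream wrappers above break out of their loops cleanly. A self-contained sketch of the same pattern:

    import json
    from typing import Iterable, Iterator

    def parse_sse(lines: Iterable[bytes]) -> Iterator[dict]:
        # Yield decoded JSON payloads, stopping at the "[DONE]" sentinel.
        for line in lines:
            if not line.startswith(b"data:"):
                continue  # skip empty/keep-alive lines
            if line == b"data: [DONE]":
                break
            yield json.loads(line[len(b"data:"):])

    stream = [b'data: {"token": "Hi"}', b"data: [DONE]"]
    print(list(parse_sse(stream)))  # [{'token': 'Hi'}]
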
@@ -306,72 +312,41 @@ def _format_text_generation_stream_output(
     return output.token.text if not details else output
 
 
-def _format_chat_completion_stream_output_from_text_generation(
-    item: TextGenerationStreamOutput, created: int
-) -> ChatCompletionStreamOutput:
-    if item.details is None:
-        # new token generated => return delta
-        return ChatCompletionStreamOutput(
-            # explicitly set 'dummy' values to reduce expectations from users
-            id="dummy",
-            model="dummy",
-            system_fingerprint="dummy",
-            choices=[
-                ChatCompletionStreamOutputChoice(
-                    delta=ChatCompletionStreamOutputDelta(
-                        role="assistant",
-                        content=item.token.text,
-                    ),
-                    finish_reason=None,
-                    index=0,
-                )
-            ],
-            created=created,
-        )
-    else:
-        # generation is completed => return finish reason
-        return ChatCompletionStreamOutput(
-            # explicitly set 'dummy' values to reduce expectations from users
-            id="dummy",
-            model="dummy",
-            system_fingerprint="dummy",
-            choices=[
-                ChatCompletionStreamOutputChoice(
-                    delta=ChatCompletionStreamOutputDelta(role="assistant"),
-                    finish_reason=item.details.finish_reason,
-                    index=0,
-                )
-            ],
-            created=created,
-        )
-
-
-def _stream_chat_completion_response_from_bytes(
+def _stream_chat_completion_response(
     bytes_lines: Iterable[bytes],
 ) -> Iterable[ChatCompletionStreamOutput]:
     """Used in `InferenceClient.chat_completion` if model is served with TGI."""
     for item in bytes_lines:
-        output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
+        try:
+            output = _format_chat_completion_stream_output(item)
+        except StopIteration:
+            break
         if output is not None:
             yield output
 
 
-async def _async_stream_chat_completion_response_from_bytes(
+async def _async_stream_chat_completion_response(
     bytes_lines: AsyncIterable[bytes],
 ) -> AsyncIterable[ChatCompletionStreamOutput]:
     """Used in `AsyncInferenceClient.chat_completion`."""
     async for item in bytes_lines:
-        output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
+        try:
+            output = _format_chat_completion_stream_output(item)
+        except StopIteration:
+            break
         if output is not None:
             yield output
 
 
-def _format_chat_completion_stream_output_from_text_generation_from_bytes(
+def _format_chat_completion_stream_output(
     byte_payload: bytes,
 ) -> Optional[ChatCompletionStreamOutput]:
     if not byte_payload.startswith(b"data:"):
         return None  # empty line
 
+    if byte_payload == b"data: [DONE]":
+        raise StopIteration("[DONE] signal received.")
+
     # Decode payload
     payload = byte_payload.decode("utf-8")
     json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
@@ -413,17 +388,6 @@ def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
     return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])
 
 
-_NON_CHAT_COMPLETION_SERVER: Set[str] = set()
-
-
-def _set_as_non_chat_completion_server(model: str) -> None:
-    _NON_CHAT_COMPLETION_SERVER.add(model)
-
-
-def _is_chat_completion_server(model: str) -> bool:
-    return model not in _NON_CHAT_COMPLETION_SERVER
-
-
 # TEXT GENERATION ERRORS
 # ----------------------
 # Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
huggingface_hub/inference/_generated/_async_client.py CHANGED
@@ -44,7 +44,7 @@ from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
     ModelStatus,
-    _async_stream_chat_completion_response_from_bytes,
+    _async_stream_chat_completion_response,
     _async_stream_text_generation_response,
     _b64_encode,
     _b64_to_image,
@@ -54,9 +54,7 @@ from huggingface_hub.inference._common import (
     _fetch_recommended_models,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _is_chat_completion_server,
     _open_as_binary,
-    _set_as_non_chat_completion_server,
     _set_unsupported_text_generation_kwargs,
     raise_text_generation_error,
 )
@@ -68,8 +66,6 @@ from huggingface_hub.inference._generated.types import (
     ChatCompletionInputTool,
     ChatCompletionInputToolTypeClass,
     ChatCompletionOutput,
-    ChatCompletionOutputComplete,
-    ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
     DocumentQuestionAnsweringOutputElement,
     FillMaskOutputElement,
@@ -174,7 +170,7 @@ class AsyncInferenceClient:
             )
 
         self.model: Optional[str] = model
-        self.token: Union[str, bool, None] = token or api_key
+        self.token: Union[str, bool, None] = token if token is not None else api_key
         self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token))  # 'authorization' + 'user-agent'
         if headers is not None:
             self.headers.update(headers)
@@ -824,123 +820,52 @@ class AsyncInferenceClient:
         # since `chat_completion(..., model=xxx)` is also a payload parameter for the
         # server, we need to handle it differently
         model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+        is_url = model.startswith(("http://", "https://"))
+
+        # First, resolve the model chat completions URL
+        if model == self.base_url:
+            # base_url passed => add server route
+            model_url = model + "/v1/chat/completions"
+        elif is_url:
+            # model is a URL => use it directly
+            model_url = model
+        else:
+            # model is a model ID => resolve it + add server route
+            model_url = self._resolve_url(model) + "/v1/chat/completions"
+
+        # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
+        # If it's an ID on the Hub => use it. Otherwise, use a random string.
+        model_id = model if not is_url and model.count("/") == 1 else "tgi"
+
+        data = await self.post(
+            model=model_url,
+            json=dict(
+                model=model_id,
+                messages=messages,
+                frequency_penalty=frequency_penalty,
+                logit_bias=logit_bias,
+                logprobs=logprobs,
+                max_tokens=max_tokens,
+                n=n,
+                presence_penalty=presence_penalty,
+                response_format=response_format,
+                seed=seed,
+                stop=stop,
+                temperature=temperature,
+                tool_choice=tool_choice,
+                tool_prompt=tool_prompt,
+                tools=tools,
+                top_logprobs=top_logprobs,
+                top_p=top_p,
+                stream=stream,
+            ),
+            stream=stream,
+        )
 
-        if _is_chat_completion_server(model):
-            # First, let's consider the server has a `/v1/chat/completions` endpoint.
-            # If that's the case, we don't have to render the chat template client-side.
-            model_url = self._resolve_url(model)
-            if not model_url.endswith("/chat/completions"):
-                model_url += "/v1/chat/completions"
-
-            # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
-            if not model.startswith("http") and model.count("/") == 1:
-                # If it's a ID on the Hub => use it
-                model_id = model
-            else:
-                # Otherwise, we use a random string
-                model_id = "tgi"
-
-            try:
-                data = await self.post(
-                    model=model_url,
-                    json=dict(
-                        model=model_id,
-                        messages=messages,
-                        frequency_penalty=frequency_penalty,
-                        logit_bias=logit_bias,
-                        logprobs=logprobs,
-                        max_tokens=max_tokens,
-                        n=n,
-                        presence_penalty=presence_penalty,
-                        response_format=response_format,
-                        seed=seed,
-                        stop=stop,
-                        temperature=temperature,
-                        tool_choice=tool_choice,
-                        tool_prompt=tool_prompt,
-                        tools=tools,
-                        top_logprobs=top_logprobs,
-                        top_p=top_p,
-                        stream=stream,
-                    ),
-                    stream=stream,
-                )
-            except _import_aiohttp().ClientResponseError as e:
-                if e.status in (400, 404, 500):
-                    # Let's consider the server is not a chat completion server.
-                    # Then we call again `chat_completion` which will render the chat template client side.
-                    # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
-                    _set_as_non_chat_completion_server(model)
-                    logger.warning(
-                        f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
-                    )
-                    return await self.chat_completion(
-                        messages=messages,
-                        model=model,
-                        stream=stream,
-                        max_tokens=max_tokens,
-                        seed=seed,
-                        stop=stop,
-                        temperature=temperature,
-                        top_p=top_p,
-                    )
-                raise
-
-            if stream:
-                return _async_stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
-
-            return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
-
-        # At this point, we know the server is not a chat completion server.
-        # It means it's a transformers-backed server for which we can send a list of messages directly to the
-        # `text-generation` pipeline. We won't receive a detailed response but only the generated text.
         if stream:
-            raise ValueError(
-                "Streaming token is not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. Please pass `stream=False` as input."
-            )
-        if tool_choice is not None or tool_prompt is not None or tools is not None:
-            warnings.warn(
-                "Tools are not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. The provided tool parameters will be ignored."
-            )
-        if response_format is not None:
-            warnings.warn(
-                "Response format is not supported by the model. This is due to the model not been served by a "
-                "Text-Generation-Inference server. The provided response format will be ignored."
-            )
-
-        # generate response
-        text_generation_output = await self.text_generation(
-            prompt=messages,  # type: ignore # Not correct type but works implicitly
-            model=model,
-            stream=False,
-            details=False,
-            max_new_tokens=max_tokens,
-            seed=seed,
-            stop_sequences=stop,
-            temperature=temperature,
-            top_p=top_p,
-        )
+            return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]
 
-        # Format as a ChatCompletionOutput with dummy values for fields we can't provide
-        return ChatCompletionOutput(
-            id="dummy",
-            model="dummy",
-            system_fingerprint="dummy",
-            usage=None,  # type: ignore # set to `None` as we don't want to provide false information
-            created=int(time.time()),
-            choices=[
-                ChatCompletionOutputComplete(
-                    finish_reason="unk",  # type: ignore # set to `unk` as we don't want to provide false information
-                    index=0,
-                    message=ChatCompletionOutputMessage(
-                        content=text_generation_output,
-                        role="assistant",
-                    ),
-                )
-            ],
-        )
+        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
 
     async def conversational(
         self,
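
The async client mirrors the change. A minimal streaming sketch (the endpoint URL is a placeholder; `delta.content` may be None on the final chunk):

    import asyncio
    from huggingface_hub import AsyncInferenceClient

    async def main():
        client = AsyncInferenceClient(base_url="http://localhost:8080")  # placeholder URL
        stream = await client.chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=16,
            stream=True,
        )
        async for chunk in stream:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())
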
@@ -2282,7 +2207,12 @@ class AsyncInferenceClient:
         if stream:
             return _async_stream_text_generation_response(bytes_output, details)  # type: ignore
 
-        data = _bytes_to_dict(bytes_output)[0]  # type: ignore[arg-type]
+        data = _bytes_to_dict(bytes_output)  # type: ignore[arg-type]
+
+        # Data can be a single element (dict) or a list of dicts, in which case we take the first element.
+        if isinstance(data, list):
+            data = data[0]
+
         return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
 
     async def text_to_image(
huggingface_hub-0.24.0rc0.dist-info/METADATA → huggingface_hub-0.24.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: huggingface-hub
-Version: 0.24.0rc0
+Version: 0.24.1
 Summary: Client library to download and publish models, datasets and other repos on the huggingface.co hub
 Home-page: https://github.com/huggingface/huggingface_hub
 Author: Hugging Face, Inc.
huggingface_hub-0.24.0rc0.dist-info/RECORD → huggingface_hub-0.24.1.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
-huggingface_hub/__init__.py,sha256=Kd7XPNFlbXWrx5Pzhcvl4MqKFYd2ZGGf3_MF2tSvUsc,33901
+huggingface_hub/__init__.py,sha256=Uf3KJ-RqdzyayY4T0Yxr1X26y2w-Mrm9vLUKilotLI8,33897
 huggingface_hub/_commit_api.py,sha256=Yj1ft_WbsnqjSbiYHgdqGmLTF6BTA4E8kAGYW89t2sQ,31057
 huggingface_hub/_commit_scheduler.py,sha256=nlJS_vnLb8i92NLrRwJX8Mg9QZ7f3kfLbLlQuEd5YjU,13647
 huggingface_hub/_inference_endpoints.py,sha256=th6vlJ2vUg314x7uMLzQHfy4AuX5mFlJqNobVIz5yOY,15944
@@ -15,7 +15,7 @@ huggingface_hub/constants.py,sha256=BG3n2gl4JbxMw_JRvNTFyMcNnZIPzvT3KXSH-jm2J08,
 huggingface_hub/errors.py,sha256=IM0lNbExLzaYEs0HrrPvY4-kyj6DiP2Szu7Jy9slHOE,2083
 huggingface_hub/fastai_utils.py,sha256=5I7zAfgHJU_mZnxnf9wgWTHrCRu_EAV8VTangDVfE_o,16676
 huggingface_hub/file_download.py,sha256=Lf1RhCMb9HkXPUy90O_zUc-fonmFTwE2xadbZpVoKrM,84243
-huggingface_hub/hf_api.py,sha256=kFN02B2AFJEhK04PvMDZdWRKhiAz9zD3JZdDdPZJgjY,406833
+huggingface_hub/hf_api.py,sha256=YK4EcYD7vvGOjzAO_7pSrr2len7u4xa7yvwn6CojdIA,406692
 huggingface_hub/hf_file_system.py,sha256=HlYbWFhMrPWNqGUQfQrZR6H70QK0PgsxRvO4FantCNc,39160
 huggingface_hub/hub_mixin.py,sha256=bm5hZGeOHBSUBfiAXJv8cU05nAZr65TxnkUJLWLwAEg,37308
 huggingface_hub/inference_api.py,sha256=UXOKu_Ez2I3hDsjguqCcCrj03WFDndehpngYiIAucdg,8331
@@ -37,12 +37,12 @@ huggingface_hub/commands/tag.py,sha256=gCoR8G95lhHBzyVytTxT7MnqTmjKYtStDnHXcysOJ
 huggingface_hub/commands/upload.py,sha256=Mr69qO60otqCVw0sVSBPykUTkL9HO-pkCyulSD2mROM,13622
 huggingface_hub/commands/user.py,sha256=QApZJOCQEHADhjunM3hlQ72uqHsearCiCE4SdpzGdcc,6893
 huggingface_hub/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-huggingface_hub/inference/_client.py,sha256=m1GX7Yd2VngZR9-RuFqudqEM3dPtKUIYCDphsHMR5Lw,132602
-huggingface_hub/inference/_common.py,sha256=3xbeCOjLgSPRJcbtxKnv1DNXr_TOMivOeQyvg-Ma1HU,16306
+huggingface_hub/inference/_client.py,sha256=6oJjWgDIGqKK52DU7VR2fQkqYqf2UGQbHWMEqVszaZU,129014
+huggingface_hub/inference/_common.py,sha256=EEF8T9jtfLvqhIwwDM0vt8S54yObExoBncJIiHvEew8,14882
 huggingface_hub/inference/_templating.py,sha256=LCy-U_25R-l5dhcEHsyRwiOrgvKQHXkdSmynWCfsPjI,3991
 huggingface_hub/inference/_types.py,sha256=C73l5-RO8P1UMBHF8OAO9CRUq7Xdv33pcADoJsGMPSU,1782
 huggingface_hub/inference/_generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-huggingface_hub/inference/_generated/_async_client.py,sha256=adlz58-FvC0-4X9VitsWkkHeD7vnZ_HAqqx33PkciYQ,136329
+huggingface_hub/inference/_generated/_async_client.py,sha256=X7dHAHJdDbAvho3_tsyOjH0spC_aJUCNeG_UclHSo_Q,132715
 huggingface_hub/inference/_generated/types/__init__.py,sha256=uEsA0z8Gcu34q0gNAZVcqHFqJT5BPrhnM9qS_LQgN0Q,5215
 huggingface_hub/inference/_generated/types/audio_classification.py,sha256=wk4kUTLQZoXWLpiUOpKRHRRE-JYqqJlzGVe62VACR-0,1347
 huggingface_hub/inference/_generated/types/audio_to_audio.py,sha256=n7GeCepzt254yoSLsdjrI1j4fzYgjWzxoaKE5gZJc48,881
@@ -107,9 +107,9 @@ huggingface_hub/utils/insecure_hashlib.py,sha256=OjxlvtSQHpbLp9PWSrXBDJ0wHjxCBU-
 huggingface_hub/utils/logging.py,sha256=Cp03s0uEl3kDM9XHQW9a8GAoExODQ-e7kEtgMt-_To8,4728
 huggingface_hub/utils/sha.py,sha256=OFnNGCba0sNcT2gUwaVCJnldxlltrHHe0DS_PCpV3C4,2134
 huggingface_hub/utils/tqdm.py,sha256=jQiVYwRG78HK4_54u0vTtz6Kt9IMGiHy3ixbIn3h2TU,9368
-huggingface_hub-0.24.0rc0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-huggingface_hub-0.24.0rc0.dist-info/METADATA,sha256=ELk2xmUxcdKGyHyj6vzBFxzinEdBPihANDN6klqCEng,13186
-huggingface_hub-0.24.0rc0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-huggingface_hub-0.24.0rc0.dist-info/entry_points.txt,sha256=Y3Z2L02rBG7va_iE6RPXolIgwOdwUFONyRN3kXMxZ0g,131
-huggingface_hub-0.24.0rc0.dist-info/top_level.txt,sha256=8KzlQJAY4miUvjAssOAJodqKOw3harNzuiwGQ9qLSSk,16
-huggingface_hub-0.24.0rc0.dist-info/RECORD,,
+huggingface_hub-0.24.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+huggingface_hub-0.24.1.dist-info/METADATA,sha256=PbJAesxB3sZZtDX1HgX0keQBQBdkp66KoK2XD_U0Ga8,13183
+huggingface_hub-0.24.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+huggingface_hub-0.24.1.dist-info/entry_points.txt,sha256=Y3Z2L02rBG7va_iE6RPXolIgwOdwUFONyRN3kXMxZ0g,131
+huggingface_hub-0.24.1.dist-info/top_level.txt,sha256=8KzlQJAY4miUvjAssOAJodqKOw3harNzuiwGQ9qLSSk,16
+huggingface_hub-0.24.1.dist-info/RECORD,,