huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- huggingface_hub/__init__.py +160 -46
- huggingface_hub/_commit_api.py +277 -71
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +33 -22
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +241 -81
- huggingface_hub/_space_api.py +18 -10
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +196 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +15 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +83 -59
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +99 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +606 -346
- huggingface_hub/hf_api.py +2445 -1132
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +61 -66
- huggingface_hub/inference/_client.py +501 -630
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +536 -722
- huggingface_hub/inference/_generated/types/__init__.py +6 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
- huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +11 -11
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +149 -20
- huggingface_hub/inference/_providers/_common.py +160 -37
- huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
- huggingface_hub/inference/_providers/cerebras.py +6 -0
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +32 -0
- huggingface_hub/inference/_providers/fal_ai.py +231 -22
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +143 -33
- huggingface_hub/inference/_providers/hyperbolic.py +9 -5
- huggingface_hub/inference/_providers/nebius.py +47 -5
- huggingface_hub/inference/_providers/novita.py +48 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +25 -0
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +46 -9
- huggingface_hub/inference/_providers/sambanova.py +37 -1
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +34 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +79 -59
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +27 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +399 -237
- huggingface_hub/utils/_pagination.py +6 -6
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +74 -22
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +13 -11
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +235 -0
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +33 -4
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -428
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -299
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
- huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_common.py

@@ -18,30 +18,16 @@ import base64
 import io
 import json
 import logging
-from contextlib import contextmanager
+import mimetypes
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    AsyncIterable,
-    BinaryIO,
-    ContextManager,
-    Dict,
-    Generator,
-    Iterable,
-    List,
-    Literal,
-    NoReturn,
-    Optional,
-    Union,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, AsyncIterable, BinaryIO, Iterable, Literal, NoReturn, Optional, Union, overload
 
-from requests import HTTPError
+import httpx
 
 from huggingface_hub.errors import (
     GenerationError,
+    HfHubHTTPError,
     IncompleteGenerationError,
     OverloadedError,
     TextGenerationError,
@@ -49,21 +35,19 @@ from huggingface_hub.errors import (
     ValidationError,
 )
 
-from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available
+from ..utils import get_session, is_numpy_available, is_pillow_available
 from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput
 
 
 if TYPE_CHECKING:
-    from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image
 
 # TYPES
 UrlT = str
 PathT = Union[str, Path]
-BinaryT = Union[bytes, BinaryIO]
-ContentT = Union[BinaryT, PathT, UrlT]
+ContentT = Union[bytes, BinaryIO, PathT, UrlT, "Image", bytearray, memoryview]
 
-# Use to set a Accept: image/png header
+# Use to set an Accept: image/png header
 TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
 
 logger = logging.getLogger(__name__)
@@ -74,52 +58,37 @@ class RequestParameters:
     url: str
     task: str
     model: Optional[str]
-    json: Optional[Union[str, Dict, List]]
-    data: Optional[BinaryT]
-    headers: Dict[str, Any]
+    json: Optional[Union[str, dict, list]]
+    data: Optional[bytes]
+    headers: dict[str, Any]
 
 
-
-@dataclass
-class ModelStatus:
+class MimeBytes(bytes):
     """
-    This Dataclass represents the model status in the Hugging Face Inference API.
-
-    Args:
-        loaded (`bool`):
-            If the model is currently loaded into Hugging Face's InferenceAPI. Models
-            are loaded on-demand, leading to the user's first request taking longer.
-            If a model is loaded, you can be assured that it is in a healthy state.
-        state (`str`):
-            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
-            If a model's state is 'Loadable', it's not too big and has a supported
-            backend. Loadable models are automatically loaded when the user first
-            requests inference on the endpoint. This means it is transparent for the
-            user to load a model, except that the first call takes longer to complete.
-        compute_type (`Dict`):
-            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
-            replicas.
-        framework (`str`):
-            The name of the framework that the model was built with, such as 'transformers'
-            or 'text-generation-inference'.
+    A bytes object with a mime type.
+    To be returned by `_prepare_payload_open_as_mime_bytes` in subclasses.
+
+    Example:
+    ```python
+    >>> b = MimeBytes(b"hello", "text/plain")
+    >>> isinstance(b, bytes)
+    True
+    >>> b.mime_type
+    'text/plain'
+    ```
     """
 
-    loaded: bool
-    state: str
-    compute_type: Dict
-    framework: str
-
-
-## IMPORT UTILS
+    mime_type: Optional[str]
 
+    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
+        obj = super().__new__(cls, data)
+        obj.mime_type = mime_type
+        if isinstance(data, MimeBytes) and mime_type is None:
+            obj.mime_type = data.mime_type
+        return obj
 
-def _import_aiohttp():
-    # Make sure `aiohttp` is installed on the machine.
-    if not is_aiohttp_available():
-        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
-    import aiohttp
 
-    return aiohttp
+## IMPORT UTILS
 
 
 def _import_numpy():
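A note on the new `MimeBytes` type added above: it subclasses `bytes`, so existing payload code keeps working unchanged, while an optional MIME type rides along with the data. A minimal standalone sketch (the class body is copied from the hunk above; the usage lines are illustrative, not library API):

```python
from typing import Optional

class MimeBytes(bytes):
    """Copy of the class added in the hunk above, for illustration."""

    mime_type: Optional[str]

    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
        obj = super().__new__(cls, data)
        obj.mime_type = mime_type
        if isinstance(data, MimeBytes) and mime_type is None:
            obj.mime_type = data.mime_type  # inherit the MIME type of a wrapped MimeBytes
        return obj

png = MimeBytes(b"\x89PNG\r\n", "image/png")
rewrapped = MimeBytes(png)                 # no explicit MIME type given...
assert rewrapped.mime_type == "image/png"  # ...so it is taken from the wrapped value
assert isinstance(rewrapped, bytes)        # behaves as plain bytes everywhere else
```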
@@ -147,32 +116,49 @@ def _import_pil_image():
 
 
 @overload
-def _open_as_binary(
-    content: ContentT,
-) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"
+def _open_as_mime_bytes(content: ContentT) -> MimeBytes: ...  # means "if input is not None, output is not None"
 
 
 @overload
-def _open_as_binary(
-    content: Literal[None],
-) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"
-
+def _open_as_mime_bytes(content: Literal[None]) -> Literal[None]: ...  # means "if input is None, output is None"
 
-@contextmanager  # type: ignore
-def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
-    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.
 
-    Do nothing if `content` is None.
+def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
+    """Open `content` as a binary file, either from a URL, a local path, raw bytes, or a PIL Image.
 
-    TODO: handle a PIL.Image as input
-    TODO: handle base64 as input
+    Do nothing if `content` is None.
     """
+    # If content is None, yield None
+    if content is None:
+        return None
+
+    # If content is bytes, return it
+    if isinstance(content, bytes):
+        return MimeBytes(content)
+
+    # If content is raw binary data (bytearray, memoryview)
+    if isinstance(content, (bytearray, memoryview)):
+        return MimeBytes(bytes(content))
+
+    # If content is a binary file-like object
+    if hasattr(content, "read"):  # duck-typing instead of isinstance(content, BinaryIO)
+        logger.debug("Reading content from BinaryIO")
+        data = content.read()
+        mime_type = mimetypes.guess_type(content.name)[0] if hasattr(content, "name") else None
+        if isinstance(data, str):
+            raise TypeError("Expected binary stream (bytes), but got text stream")
+        return MimeBytes(data, mime_type=mime_type)
+
     # If content is a string => must be either a URL or a path
     if isinstance(content, str):
        if content.startswith("https://") or content.startswith("http://"):
             logger.debug(f"Downloading content from {content}")
-            yield get_session().get(content).content  # TODO: retrieve as stream and pipe to post request ?
-            return
+            response = get_session().get(content)
+            mime_type = response.headers.get("Content-Type")
+            if mime_type is None:
+                mime_type = mimetypes.guess_type(content)[0]
+            return MimeBytes(response.content, mime_type=mime_type)
+
         content = Path(content)
         if not content.exists():
             raise FileNotFoundError(
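Two details of `_open_as_mime_bytes` worth noting: file-like inputs are detected by duck-typing on `.read`, and the MIME type is guessed from the object's `.name` when it has one. Both building blocks are standard-library behavior, shown here in isolation:

```python
import io
import mimetypes

# MIME guessing used for paths and named file handles:
print(mimetypes.guess_type("photo.jpeg")[0])   # image/jpeg
print(mimetypes.guess_type("archive.xyz")[0])  # None -> MimeBytes.mime_type stays None

# Duck-typing used instead of isinstance(content, BinaryIO):
buf = io.BytesIO(b"\x00\x01")
print(hasattr(buf, "read"))  # True  -> treated as a binary stream
print(hasattr(buf, "name"))  # False -> no name to guess a MIME type from
```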
@@ -183,18 +169,47 @@ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT],
     # If content is a Path => open it
     if isinstance(content, Path):
         logger.debug(f"Opening content from {content}")
-        with content.open("rb") as f:
-            yield f
-    # Otherwise: already a file-like object or None
-    else:
-        yield content
+        return MimeBytes(content.read_bytes(), mime_type=mimetypes.guess_type(content)[0])
+
+    # If content is a PIL Image => convert to bytes
+    if is_pillow_available():
+        from PIL import Image
+
+        if isinstance(content, Image.Image):
+            logger.debug("Converting PIL Image to bytes")
+            buffer = io.BytesIO()
+            format = content.format or "PNG"
+            content.save(buffer, format=format)
+            return MimeBytes(buffer.getvalue(), mime_type=f"image/{format.lower()}")
+
+    # If nothing matched, raise error
+    raise TypeError(
+        f"Unsupported content type: {type(content)}. "
+        "Expected one of: bytes, bytearray, BinaryIO, memoryview, Path, str (URL or file path), or PIL.Image.Image."
+    )
 
 
 def _b64_encode(content: ContentT) -> str:
     """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
-    with _open_as_binary(content) as data:
-        data_as_bytes = data if isinstance(data, bytes) else data.read()
-        return base64.b64encode(data_as_bytes).decode()
+    raw_bytes = _open_as_mime_bytes(content)
+    return base64.b64encode(raw_bytes).decode()
+
+
+def _as_url(content: ContentT, default_mime_type: str) -> str:
+    if isinstance(content, str) and content.startswith(("http://", "https://", "data:")):
+        return content
+
+    # Convert content to bytes
+    raw_bytes = _open_as_mime_bytes(content)
+
+    # Get MIME type
+    mime_type = raw_bytes.mime_type or default_mime_type
+
+    # Encode content to base64
+    encoded_data = base64.b64encode(raw_bytes).decode()
+
+    # Build data URL
+    return f"data:{mime_type};base64,{encoded_data}"
 
 
 def _b64_to_image(encoded_image: str) -> "Image":
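The new `_as_url` helper returns URLs and `data:` strings unchanged and inlines everything else as a base64 data URL. The construction is small enough to sketch in isolation (hypothetical helper name, same logic as the hunk above):

```python
import base64

def to_data_url(raw: bytes, mime_type: str = "application/octet-stream") -> str:
    # Same shape as `_as_url`: base64-encode the payload and prefix the MIME type.
    return f"data:{mime_type};base64,{base64.b64encode(raw).decode()}"

print(to_data_url(b"hello", "text/plain"))  # data:text/plain;base64,aGVsbG8=
```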
@@ -203,7 +218,7 @@ def _b64_to_image(encoded_image: str) -> "Image":
     return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
 
 
-def _bytes_to_list(content: bytes) -> List:
+def _bytes_to_list(content: bytes) -> list:
     """Parse bytes from a Response object into a Python list.
 
     Expects the response body to be JSON-encoded data.
@@ -214,7 +229,7 @@ def _bytes_to_list(content: bytes) -> List:
     return json.loads(content.decode())
 
 
-def _bytes_to_dict(content: bytes) -> Dict:
+def _bytes_to_dict(content: bytes) -> dict:
     """Parse bytes from a Response object into a Python dictionary.
 
     Expects the response body to be JSON-encoded data.
@@ -234,24 +249,21 @@ def _bytes_to_image(content: bytes) -> "Image":
     return Image.open(io.BytesIO(content))
 
 
-def _as_dict(response: Union[bytes, Dict]) -> Dict:
+def _as_dict(response: Union[bytes, dict]) -> dict:
     return json.loads(response) if isinstance(response, bytes) else response
 
 
-## PAYLOAD UTILS
-
-
 ## STREAMING UTILS
 
 
 def _stream_text_generation_response(
-    bytes_output_as_lines: Iterable[bytes], details: bool
+    output_lines: Iterable[str], details: bool
 ) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
     """Used in `InferenceClient.text_generation`."""
     # Parse ServerSentEvents
-    for byte_payload in bytes_output_as_lines:
+    for line in output_lines:
         try:
-            output = _format_text_generation_stream_output(byte_payload, details)
+            output = _format_text_generation_stream_output(line, details)
         except StopIteration:
             break
         if output is not None:
@@ -259,13 +271,13 @@ def _stream_text_generation_response(
 
 
 async def _async_stream_text_generation_response(
-    bytes_output_as_lines: AsyncIterable[bytes], details: bool
+    output_lines: AsyncIterable[str], details: bool
 ) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
     """Used in `AsyncInferenceClient.text_generation`."""
     # Parse ServerSentEvents
-    async for byte_payload in bytes_output_as_lines:
+    async for line in output_lines:
         try:
-            output = _format_text_generation_stream_output(byte_payload, details)
+            output = _format_text_generation_stream_output(line, details)
         except StopIteration:
             break
         if output is not None:
@@ -273,17 +285,17 @@ async def _async_stream_text_generation_response(
 
 
 def _format_text_generation_stream_output(
-    byte_payload: bytes, details: bool
+    line: str, details: bool
 ) -> Optional[Union[str, TextGenerationStreamOutput]]:
-    if not byte_payload.startswith(b"data:"):
+    if not line.startswith("data:"):
         return None  # empty line
 
-    if byte_payload.strip() == b"data: [DONE]":
+    if line.strip() == "data: [DONE]":
         raise StopIteration("[DONE] signal received.")
 
     # Decode payload
-    payload = byte_payload.decode("utf-8")
-    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+    payload = line.lstrip("data:").rstrip("/n")
+    json_payload = json.loads(payload)
 
     # Either an error as being returned
     if json_payload.get("error") is not None:
@@ -295,12 +307,12 @@ def _format_text_generation_stream_output(
 
 
 def _stream_chat_completion_response(
-    bytes_lines: Iterable[bytes],
+    lines: Iterable[str],
 ) -> Iterable[ChatCompletionStreamOutput]:
     """Used in `InferenceClient.chat_completion` if model is served with TGI."""
-    for item in bytes_lines:
+    for line in lines:
         try:
-            output = _format_chat_completion_stream_output(item)
+            output = _format_chat_completion_stream_output(line)
         except StopIteration:
             break
         if output is not None:
|
|
|
308
320
|
|
|
309
321
|
|
|
310
322
|
async def _async_stream_chat_completion_response(
|
|
311
|
-
|
|
323
|
+
lines: AsyncIterable[str],
|
|
312
324
|
) -> AsyncIterable[ChatCompletionStreamOutput]:
|
|
313
325
|
"""Used in `AsyncInferenceClient.chat_completion`."""
|
|
314
|
-
async for
|
|
326
|
+
async for line in lines:
|
|
315
327
|
try:
|
|
316
|
-
output = _format_chat_completion_stream_output(
|
|
328
|
+
output = _format_chat_completion_stream_output(line)
|
|
317
329
|
except StopIteration:
|
|
318
330
|
break
|
|
319
331
|
if output is not None:
|
|
@@ -321,17 +333,16 @@ async def _async_stream_chat_completion_response(
|
|
|
321
333
|
|
|
322
334
|
|
|
323
335
|
def _format_chat_completion_stream_output(
|
|
324
|
-
|
|
336
|
+
line: str,
|
|
325
337
|
) -> Optional[ChatCompletionStreamOutput]:
|
|
326
|
-
if not
|
|
338
|
+
if not line.startswith("data:"):
|
|
327
339
|
return None # empty line
|
|
328
340
|
|
|
329
|
-
if
|
|
341
|
+
if line.strip() == "data: [DONE]":
|
|
330
342
|
raise StopIteration("[DONE] signal received.")
|
|
331
343
|
|
|
332
344
|
# Decode payload
|
|
333
|
-
|
|
334
|
-
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
|
345
|
+
json_payload = json.loads(line.lstrip("data:").strip())
|
|
335
346
|
|
|
336
347
|
# Either an error as being returned
|
|
337
348
|
if json_payload.get("error") is not None:
|
|
@@ -341,10 +352,9 @@ def _format_chat_completion_stream_output(
|
|
|
341
352
|
return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
|
|
342
353
|
|
|
343
354
|
|
|
344
|
-
async def _async_yield_from(client:
|
|
345
|
-
async for
|
|
346
|
-
yield
|
|
347
|
-
await client.close()
|
|
355
|
+
async def _async_yield_from(client: httpx.AsyncClient, response: httpx.Response) -> AsyncIterable[str]:
|
|
356
|
+
async for line in response.aiter_lines():
|
|
357
|
+
yield line.strip()
|
|
348
358
|
|
|
349
359
|
|
|
350
360
|
# "TGI servers" are servers running with the `text-generation-inference` backend.
|
|
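`_async_yield_from` now iterates `httpx.Response.aiter_lines()` instead of aiohttp's `response.content`, and no longer closes the client itself. A sketch of how such a helper is typically driven (the URL is a placeholder):

```python
import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient() as client:  # client lifetime owned by the caller now
        async with client.stream("GET", "https://example.com/sse") as response:
            async for line in response.aiter_lines():  # same call as in the diff
                print(line.strip())

asyncio.run(main())
```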
@@ -354,7 +364,7 @@ async def _async_yield_from(client: "ClientSession", response: "ClientResponse")
 #
 # Both approaches have very similar APIs, but not exactly the same. What we do first in
 # the `text_generation` method is to assume the model is served via TGI. If we realize
-# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
+# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fall back to the
 # default API with a warning message. When that's the case, We remember the unsupported
 # attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
 #
@@ -365,14 +375,14 @@ async def _async_yield_from(client: "ClientSession", response: "ClientResponse")
 # For more details, see https://github.com/huggingface/text-generation-inference and
 # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
 
-_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {}
+_UNSUPPORTED_TEXT_GENERATION_KWARGS: dict[Optional[str], list[str]] = {}
 
 
-def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None:
+def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: list[str]) -> None:
     _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs)
 
 
-def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
+def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> list[str]:
     return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])
 
 
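The `_UNSUPPORTED_TEXT_GENERATION_KWARGS` registry keeps one list per model and relies on `dict.setdefault` so the first write creates the entry. The pattern in isolation (the kwarg names are made-up examples):

```python
registry: dict[str | None, list[str]] = {}

def remember(model: str | None, kwargs: list[str]) -> None:
    # First call creates the list, later calls extend it in place.
    registry.setdefault(model, []).extend(kwargs)

remember("some-model", ["watermark"])
remember("some-model", ["decoder_input_details"])
print(registry["some-model"])  # ['watermark', 'decoder_input_details']
```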
@@ -383,7 +393,7 @@ def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
 # ----------------------
 
 
-def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
+def raise_text_generation_error(http_error: HfHubHTTPError) -> NoReturn:
     """
     Try to parse text-generation-inference error message and raise HTTPError in any case.
 
@@ -392,6 +402,8 @@ def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
         The HTTPError that have been raised.
     """
     # Try to parse a Text Generation Inference error
+    if http_error.response is None:
+        raise http_error
 
     try:
         # Hacky way to retrieve payload in case of aiohttp error