huggingface-hub 0.34.5__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +19 -1
- huggingface_hub/_jobs_api.py +159 -2
- huggingface_hub/_tensorboard_logger.py +9 -10
- huggingface_hub/cli/auth.py +1 -1
- huggingface_hub/cli/cache.py +3 -9
- huggingface_hub/cli/jobs.py +551 -1
- huggingface_hub/cli/repo.py +6 -4
- huggingface_hub/commands/delete_cache.py +2 -2
- huggingface_hub/commands/scan_cache.py +1 -1
- huggingface_hub/commands/user.py +1 -1
- huggingface_hub/hf_api.py +522 -78
- huggingface_hub/hf_file_system.py +3 -1
- huggingface_hub/hub_mixin.py +5 -3
- huggingface_hub/inference/_client.py +18 -181
- huggingface_hub/inference/_common.py +72 -70
- huggingface_hub/inference/_generated/_async_client.py +35 -201
- huggingface_hub/inference/_generated/types/chat_completion.py +2 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
- huggingface_hub/inference/_mcp/cli.py +1 -1
- huggingface_hub/inference/_mcp/constants.py +1 -1
- huggingface_hub/inference/_mcp/mcp_client.py +28 -11
- huggingface_hub/inference/_mcp/types.py +3 -0
- huggingface_hub/inference/_mcp/utils.py +7 -3
- huggingface_hub/inference/_providers/__init__.py +5 -0
- huggingface_hub/inference/_providers/_common.py +28 -4
- huggingface_hub/inference/_providers/black_forest_labs.py +1 -1
- huggingface_hub/inference/_providers/fal_ai.py +2 -2
- huggingface_hub/inference/_providers/hf_inference.py +15 -7
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +1 -1
- huggingface_hub/repocard.py +2 -1
- huggingface_hub/utils/_git_credential.py +1 -1
- huggingface_hub/utils/_typing.py +24 -4
- huggingface_hub/utils/_xet_progress_reporting.py +31 -10
- {huggingface_hub-0.34.5.dist-info → huggingface_hub-0.35.0.dist-info}/METADATA +7 -4
- {huggingface_hub-0.34.5.dist-info → huggingface_hub-0.35.0.dist-info}/RECORD +40 -39
- {huggingface_hub-0.34.5.dist-info → huggingface_hub-0.35.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.34.5.dist-info → huggingface_hub-0.35.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.34.5.dist-info → huggingface_hub-0.35.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.34.5.dist-info → huggingface_hub-0.35.0.dist-info}/top_level.txt +0 -0

huggingface_hub/hf_file_system.py CHANGED

@@ -896,7 +896,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
                 repo_type=resolve_remote_path.repo_type,
                 endpoint=self.endpoint,
             ),
-            temp_file=outfile,
+            temp_file=outfile,  # type: ignore[arg-type]
             displayed_filename=rpath,
             expected_size=expected_size,
             resume_size=0,
@@ -1069,6 +1069,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
             )
             hf_raise_for_status(self.response)
         try:
+            self.response.raw.decode_content = True
             out = self.response.raw.read(*read_args)
         except Exception:
             self.response.close()
@@ -1091,6 +1092,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
             )
             hf_raise_for_status(self.response)
         try:
+            self.response.raw.decode_content = True
             out = self.response.raw.read(*read_args)
         except Exception:
             self.response.close()
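
The two `decode_content = True` additions matter when the Hub serves a response with `Content-Encoding: gzip`: `requests` only decompresses automatically through `response.content` / `iter_content()`, not through direct `response.raw.read()` calls. A minimal sketch of the behavior the fix relies on (the URL is a placeholder, not taken from this diff):

```python
import requests

# Placeholder URL; any endpoint that returns a gzip-encoded body will do.
response = requests.get("https://example.com/some-file", stream=True)

# Without this flag, response.raw.read() can return the still-compressed bytes
# whenever the server sent `Content-Encoding: gzip`.
response.raw.decode_content = True
data = response.raw.read()
```
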
huggingface_hub/hub_mixin.py CHANGED

@@ -266,12 +266,14 @@ class ModelHubMixin:
         if pipeline_tag is not None:
             info.model_card_data.pipeline_tag = pipeline_tag
         if tags is not None:
+            normalized_tags = list(tags)
             if info.model_card_data.tags is not None:
-                info.model_card_data.tags.extend(tags)
+                info.model_card_data.tags.extend(normalized_tags)
             else:
-                info.model_card_data.tags = tags
+                info.model_card_data.tags = normalized_tags

-        info.model_card_data.tags = sorted(set(info.model_card_data.tags))
+        if info.model_card_data.tags is not None:
+            info.model_card_data.tags = sorted(set(info.model_card_data.tags))

         # Handle encoders/decoders for args
         cls._hub_mixin_coders = coders or {}
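
Net effect of the `hub_mixin.py` change: the `tags` passed to `ModelHubMixin` subclasses may now be any iterable (it is materialized with `list(tags)`), and the merged tag list is only deduplicated and sorted when it is not `None`. A standalone sketch of that normalization logic (variable names here are illustrative, not from the library):

```python
# Tags supplied by a subclass; any iterable now works, not just a list.
new_tags = ("pytorch", "custom-tag", "pytorch")
# Tags already present on the model card.
card_tags = ["model_hub_mixin"]

normalized_tags = list(new_tags)
if card_tags is not None:
    card_tags.extend(normalized_tags)
else:
    card_tags = normalized_tags

if card_tags is not None:
    card_tags = sorted(set(card_tags))

print(card_tags)  # ['custom-tag', 'model_hub_mixin', 'pytorch']
```
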

huggingface_hub/inference/_client.py CHANGED

@@ -45,7 +45,6 @@ from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,
-    ModelStatus,
     RequestParameters,
     _b64_encode,
     _b64_to_image,
@@ -54,7 +53,6 @@ from huggingface_hub.inference._common import (
     _bytes_to_list,
     _get_unsupported_text_generation_kwargs,
     _import_numpy,
-    _open_as_binary,
     _set_unsupported_text_generation_kwargs,
     _stream_chat_completion_response,
     _stream_text_generation_response,
@@ -105,7 +103,6 @@ from huggingface_hub.inference._generated.types import (
 from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
 from huggingface_hub.utils._auth import get_token
-from huggingface_hub.utils._deprecation import _deprecate_method


 if TYPE_CHECKING:
@@ -133,7 +130,7 @@ class InferenceClient:
         Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
         arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
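
The docstring update reflects the new `publicai` provider added in this release (see `huggingface_hub/inference/_providers/publicai.py` in the file list above). A minimal usage sketch; the model id is a placeholder and assumes the provider serves it:

```python
from huggingface_hub import InferenceClient

# "publicai" is the provider newly registered in 0.35.0; the model id below is illustrative.
client = InferenceClient(provider="publicai", model="some-org/some-chat-model")
response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```
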
@@ -259,21 +256,20 @@ class InferenceClient:
         if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers:
             request_parameters.headers["Accept"] = "image/png"

-        with _open_as_binary(request_parameters.data) as data_as_binary:
-            try:
-                response = get_session().post(
-                    request_parameters.url,
-                    json=request_parameters.json,
-                    data=data_as_binary,
-                    headers=request_parameters.headers,
-                    cookies=self.cookies,
-                    timeout=self.timeout,
-                    stream=stream,
-                    proxies=self.proxies,
-                )
-            except TimeoutError as error:
-                # Convert any `TimeoutError` to a `InferenceTimeoutError`
-                raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore
+        try:
+            response = get_session().post(
+                request_parameters.url,
+                json=request_parameters.json,
+                data=request_parameters.data,
+                headers=request_parameters.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+                stream=stream,
+                proxies=self.proxies,
+            )
+        except TimeoutError as error:
+            # Convert any `TimeoutError` to a `InferenceTimeoutError`
+            raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error  # type: ignore

         try:
             hf_raise_for_status(response)
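
Unchanged by this refactor: a low-level `TimeoutError` still surfaces to callers as `InferenceTimeoutError`. A hedged sketch of how client code typically guards against it (model id and timeout value are illustrative):

```python
from huggingface_hub import InferenceClient
from huggingface_hub.errors import InferenceTimeoutError

client = InferenceClient(model="gpt2", timeout=5)  # illustrative model and timeout

try:
    output = client.text_generation("The answer to life is", max_new_tokens=16)
except InferenceTimeoutError:
    # Raised when the request exceeds the configured timeout.
    output = None
```
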
@@ -1462,8 +1458,8 @@ class InferenceClient:
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
-        output = ImageToTextOutput.parse_obj(response)
-        return output[0] if isinstance(output, list) else output
+        output_list: List[ImageToTextOutput] = ImageToTextOutput.parse_obj_as_list(response)
+        return output_list[0]

     def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
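
`image_to_text` now always parses the provider response as a list of `ImageToTextOutput` and returns the first element, so callers keep receiving a single object. A short usage sketch (the local file path is a placeholder):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
# "cat.png" is a placeholder path to any local image.
result = client.image_to_text("cat.png")
print(result.generated_text)  # ImageToTextOutput exposes the caption as `generated_text`
```
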
@@ -3273,101 +3269,6 @@ class InferenceClient:
         response = self._inner_post(request_parameters)
         return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider."
-        ),
-    )
-    def list_deployed_models(
-        self, frameworks: Union[None, str, Literal["all"], List[str]] = None
-    ) -> Dict[str, List[str]]:
-        """
-        List models deployed on the HF Serverless Inference API service.
-
-        This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
-        are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
-        specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
-        in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
-        frameworks are checked, the more time it will take.
-
-        <Tip warning={true}>
-
-        This endpoint method does not return a live list of all models available for the HF Inference API service.
-        It searches over a cached list of models that were recently available and the list may not be up to date.
-        If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        <Tip>
-
-        This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to
-        check its availability, you can directly use [`~InferenceClient.get_model_status`].
-
-        </Tip>
-
-        Args:
-            frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
-                The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
-                "all", all available frameworks will be tested. It is also possible to provide a single framework or a
-                custom set of frameworks to check.
-
-        Returns:
-            `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
-
-        Example:
-        ```python
-        >>> from huggingface_hub import InferenceClient
-        >>> client = InferenceClient()
-
-        # Discover zero-shot-classification models currently deployed
-        >>> models = client.list_deployed_models()
-        >>> models["zero-shot-classification"]
-        ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
-
-        # List from only 1 framework
-        >>> client.list_deployed_models("text-generation-inference")
-        {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.")
-
-        # Resolve which frameworks to check
-        if frameworks is None:
-            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
-        elif frameworks == "all":
-            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
-        elif isinstance(frameworks, str):
-            frameworks = [frameworks]
-        frameworks = list(set(frameworks))
-
-        # Fetch them iteratively
-        models_by_task: Dict[str, List[str]] = {}
-
-        def _unpack_response(framework: str, items: List[Dict]) -> None:
-            for model in items:
-                if framework == "sentence-transformers":
-                    # Model running with the `sentence-transformers` framework can work with both tasks even if not
-                    # branded as such in the API response
-                    models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
-                    models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
-                else:
-                    models_by_task.setdefault(model["task"], []).append(model["model_id"])
-
-        for framework in frameworks:
-            response = get_session().get(
-                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
-            )
-            hf_raise_for_status(response)
-            _unpack_response(framework, response.json())
-
-        # Sort alphabetically for discoverability and return
-        for task, models in models_by_task.items():
-            models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
-        return models_by_task
-
     def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         """
         Get information about the deployed endpoint.
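
The deprecation message above points to the replacement for the removed helper: listing warm models per provider via `HfApi.list_models`. A short sketch (the provider name and limit are illustrative):

```python
from huggingface_hub import HfApi

api = HfApi()
# List models currently warm on a given inference provider ("together" used as an example).
for model in api.list_models(inference_provider="together", limit=5):
    print(model.id)
```
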
@@ -3431,7 +3332,6 @@ class InferenceClient:
         Check the health of the deployed endpoint.

         Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
-        For Inference API, please use [`InferenceClient.get_model_status`] instead.

         Args:
             model (`str`, *optional*):
@@ -3455,75 +3355,12 @@ class InferenceClient:
         if model is None:
             raise ValueError("Model id not provided.")
         if not model.startswith(("http://", "https://")):
-            raise ValueError(
-                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
-            )
+            raise ValueError("Model must be an Inference Endpoint URL.")
         url = model.rstrip("/") + "/health"

         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         return response.status_code == 200

-    @_deprecate_method(
-        version="0.35.0",
-        message=(
-            "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)."
-            " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers."
-        ),
-    )
-    def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
-        """
-        Get the status of a model hosted on the HF Inference API.
-
-        <Tip>
-
-        This endpoint is mostly useful when you already know which model you want to use and want to check its
-        availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
-
-        </Tip>
-
-        Args:
-            model (`str`, *optional*):
-                Identifier of the model for witch the status gonna be checked. If model is not provided,
-                the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the
-                identifier cannot be a URL.
-
-
-        Returns:
-            [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
-            about the state of the model: load, state, compute type and framework.
-
-        Example:
-        ```py
-        >>> from huggingface_hub import InferenceClient
-        >>> client = InferenceClient()
-        >>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
-        ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
-        ```
-        """
-        if self.provider != "hf-inference":
-            raise ValueError(f"Getting model status is not supported on '{self.provider}'.")
-
-        model = model or self.model
-        if model is None:
-            raise ValueError("Model id not provided.")
-        if model.startswith("https://"):
-            raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
-
-        response = get_session().get(url, headers=build_hf_headers(token=self.token))
-        hf_raise_for_status(response)
-        response_data = response.json()
-
-        if "error" in response_data:
-            raise ValueError(response_data["error"])
-
-        return ModelStatus(
-            loaded=response_data["loaded"],
-            state=response_data["state"],
-            compute_type=response_data["compute_type"],
-            framework=response_data["framework"],
-        )
-
     @property
     def chat(self) -> "ProxyClientChat":
         return ProxyClientChat(self)
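
Per the deprecation message, model-status lookups are now expected to go through `HfApi.model_info`. A hedged sketch; the `expand=["inference", "inferenceProviderMapping"]` keys and the corresponding `ModelInfo` attributes are my assumption of how the availability data is exposed, not something stated in this diff:

```python
from huggingface_hub import HfApi

api = HfApi()
# Assumption: these expand keys request the inference-availability fields.
info = api.model_info(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    expand=["inference", "inferenceProviderMapping"],
)
print(info.inference)                   # assumed to be e.g. "warm" or "cold"
print(info.inference_provider_mapping)  # assumed mapping of providers currently serving the model
```
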

huggingface_hub/inference/_common.py CHANGED

@@ -19,7 +19,6 @@ import io
 import json
 import logging
 import mimetypes
-from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
 from typing import (
@@ -27,9 +26,7 @@ from typing import (
     Any,
     AsyncIterable,
     BinaryIO,
-    ContextManager,
     Dict,
-    Generator,
     Iterable,
     List,
     Literal,
@@ -61,8 +58,7 @@ if TYPE_CHECKING:
 # TYPES
 UrlT = str
 PathT = Union[str, Path]
-
-ContentT = Union[BinaryT, PathT, UrlT, "Image"]
+ContentT = Union[bytes, BinaryIO, PathT, UrlT, "Image", bytearray, memoryview]

 # Use to set a Accept: image/png header
 TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
@@ -76,39 +72,33 @@ class RequestParameters:
     task: str
     model: Optional[str]
     json: Optional[Union[str, Dict, List]]
-    data: Optional[ContentT]
+    data: Optional[bytes]
     headers: Dict[str, Any]


-
-@dataclass
-class ModelStatus:
+class MimeBytes(bytes):
     """
-    This Dataclass represents the model status in the Hugging Face Inference API.
-
-    Args:
-        loaded (`bool`):
-            If the model is currently loaded into Hugging Face's Inference API. Models
-            are loaded on-demand, leading to the user's first request taking longer.
-            If a model is loaded, you can be assured that it is in a healthy state.
-        state (`str`):
-            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
-            If a model's state is 'Loadable', it's not too big and has a supported
-            backend. Loadable models are automatically loaded when the user first
-            requests inference on the endpoint. This means it is transparent for the
-            user to load a model, except that the first call takes longer to complete.
-        compute_type (`Dict`):
-            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
-            replicas.
-        framework (`str`):
-            The name of the framework that the model was built with, such as 'transformers'
-            or 'text-generation-inference'.
+    A bytes object with a mime type.
+    To be returned by `_prepare_payload_open_as_mime_bytes` in subclasses.
+
+    Example:
+    ```python
+    >>> b = MimeBytes(b"hello", "text/plain")
+    >>> isinstance(b, bytes)
+    True
+    >>> b.mime_type
+    'text/plain'
+    ```
     """

-    loaded: bool
-    state: str
-    compute_type: Dict
-    framework: str
+    mime_type: Optional[str]
+
+    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
+        obj = super().__new__(cls, data)
+        obj.mime_type = mime_type
+        if isinstance(data, MimeBytes) and mime_type is None:
+            obj.mime_type = data.mime_type
+        return obj


 ## IMPORT UTILS
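
One detail of `MimeBytes.__new__` worth calling out: re-wrapping an existing `MimeBytes` without passing a new `mime_type` carries the original mime type over. A small sketch of that behavior, using the class exactly as added above:

```python
from huggingface_hub.inference._common import MimeBytes

png_payload = MimeBytes(b"\x89PNG...", mime_type="image/png")  # illustrative payload

# Re-wrapping without an explicit mime_type keeps "image/png".
copied = MimeBytes(png_payload)
assert copied.mime_type == "image/png"

# Passing a mime_type explicitly overrides it.
relabeled = MimeBytes(png_payload, mime_type="application/octet-stream")
assert relabeled.mime_type == "application/octet-stream"
```
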
@@ -148,31 +138,49 @@ def _import_pil_image():


 @overload
-def _open_as_binary(
-    content: ContentT,
-) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"
+def _open_as_mime_bytes(content: ContentT) -> MimeBytes: ...  # means "if input is not None, output is not None"


 @overload
-def _open_as_binary(
-    content: Literal[None],
-) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"
+def _open_as_mime_bytes(content: Literal[None]) -> Literal[None]: ...  # means "if input is None, output is None"


-@contextmanager
-def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
+def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
     """Open `content` as a binary file, either from a URL, a local path, raw bytes, or a PIL Image.

     Do nothing if `content` is None.
-
-    TODO: handle base64 as input
     """
+    # If content is None, yield None
+    if content is None:
+        return None
+
+    # If content is bytes, return it
+    if isinstance(content, bytes):
+        return MimeBytes(content)
+
+    # If content is raw binary data (bytearray, memoryview)
+    if isinstance(content, (bytearray, memoryview)):
+        return MimeBytes(bytes(content))
+
+    # If content is a binary file-like object
+    if hasattr(content, "read"):  # duck-typing instead of isinstance(content, BinaryIO)
+        logger.debug("Reading content from BinaryIO")
+        data = content.read()
+        mime_type = mimetypes.guess_type(content.name)[0] if hasattr(content, "name") else None
+        if isinstance(data, str):
+            raise TypeError("Expected binary stream (bytes), but got text stream")
+        return MimeBytes(data, mime_type=mime_type)
+
     # If content is a string => must be either a URL or a path
     if isinstance(content, str):
         if content.startswith("https://") or content.startswith("http://"):
             logger.debug(f"Downloading content from {content}")
-            yield get_session().get(content).content
-            return
+            response = get_session().get(content)
+            mime_type = response.headers.get("Content-Type")
+            if mime_type is None:
+                mime_type = mimetypes.guess_type(content)[0]
+            return MimeBytes(response.content, mime_type=mime_type)
+
         content = Path(content)
         if not content.exists():
             raise FileNotFoundError(
@@ -183,9 +191,7 @@ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
     # If content is a Path => open it
     if isinstance(content, Path):
         logger.debug(f"Opening content from {content}")
-        with content.open("rb") as f:
-            yield f
-            return
+        return MimeBytes(content.read_bytes(), mime_type=mimetypes.guess_type(content)[0])

     # If content is a PIL Image => convert to bytes
     if is_pillow_available():
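
Taken together, the rewritten helper resolves everything to in-memory bytes plus a best-effort MIME type in a single call, instead of a context manager yielding an open stream. A quick sketch of calling it on a local file (the path is a placeholder, and the function is a private helper, so this is for illustration only):

```python
from pathlib import Path
from huggingface_hub.inference._common import _open_as_mime_bytes

payload = _open_as_mime_bytes(Path("sample.png"))  # placeholder path to an existing file
print(isinstance(payload, bytes))  # True — MimeBytes subclasses bytes
print(payload.mime_type)           # 'image/png', guessed from the ".png" extension
```
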
@@ -194,38 +200,37 @@ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
         if isinstance(content, Image.Image):
             logger.debug("Converting PIL Image to bytes")
             buffer = io.BytesIO()
-            content.save(buffer, format="PNG")
-            yield buffer.getvalue()
-            return
+            format = content.format or "PNG"
+            content.save(buffer, format=format)
+            return MimeBytes(buffer.getvalue(), mime_type=f"image/{format.lower()}")

-    # Otherwise: already a file-like object or None
-    yield content
+    # If nothing matched, raise error
+    raise TypeError(
+        f"Unsupported content type: {type(content)}. "
+        "Expected one of: bytes, bytearray, BinaryIO, memoryview, Path, str (URL or file path), or PIL.Image.Image."
+    )


 def _b64_encode(content: ContentT) -> str:
     """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
-    with _open_as_binary(content) as data:
-        data_as_bytes = data if isinstance(data, bytes) else data.read()
-        return base64.b64encode(data_as_bytes).decode()
+    raw_bytes = _open_as_mime_bytes(content)
+    return base64.b64encode(raw_bytes).decode()


 def _as_url(content: ContentT, default_mime_type: str) -> str:
-    if isinstance(content, str) and content.startswith(("http://", "https://")):
+    if isinstance(content, str) and content.startswith(("http://", "https://", "data:")):
         return content

-    # Otherwise, build a data URL
-    mime_type = None
-    if isinstance(content, (str, Path)):
-        mime_type = mimetypes.guess_type(content, strict=False)[0]
-    elif is_pillow_available():
-        from PIL import Image
+    # Convert content to bytes
+    raw_bytes = _open_as_mime_bytes(content)

-        if isinstance(content, Image.Image):
-
-            mime_type = f"image/{(content.format or 'PNG').lower()}"
+    # Get MIME type
+    mime_type = raw_bytes.mime_type or default_mime_type

-    mime_type = mime_type or default_mime_type
-    encoded_data = _b64_encode(content)
+    # Encode content to base64
+    encoded_data = base64.b64encode(raw_bytes).decode()
+
+    # Build data URL
     return f"data:{mime_type};base64,{encoded_data}"


@@ -270,9 +275,6 @@ def _as_dict(response: Union[bytes, Dict]) -> Dict:
     return json.loads(response) if isinstance(response, bytes) else response


-## PAYLOAD UTILS
-
-
 ## STREAMING UTILS


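
End result of the `_as_url` rewrite: any supported content is converted to a base64 `data:` URL whose MIME type comes from `MimeBytes` when available, falling back to the caller-supplied default. A sketch of the equivalent construction (the payload bytes and MIME type are illustrative):

```python
import base64

raw_bytes = b"\x89PNG\r\n\x1a\n..."  # illustrative image payload
mime_type = "image/png"              # would come from MimeBytes.mime_type, else the default

encoded_data = base64.b64encode(raw_bytes).decode()
data_url = f"data:{mime_type};base64,{encoded_data}"
print(data_url[:30])  # "data:image/png;base64,..."
```
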