huggingface-hub 0.24.6__py3-none-any.whl → 0.25.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +21 -1
- huggingface_hub/_commit_api.py +4 -4
- huggingface_hub/_inference_endpoints.py +13 -1
- huggingface_hub/_local_folder.py +191 -4
- huggingface_hub/_login.py +6 -6
- huggingface_hub/_snapshot_download.py +8 -17
- huggingface_hub/_space_api.py +5 -0
- huggingface_hub/_tensorboard_logger.py +29 -13
- huggingface_hub/_upload_large_folder.py +573 -0
- huggingface_hub/_webhooks_server.py +1 -1
- huggingface_hub/commands/_cli_utils.py +5 -0
- huggingface_hub/commands/download.py +8 -0
- huggingface_hub/commands/huggingface_cli.py +6 -1
- huggingface_hub/commands/lfs.py +2 -1
- huggingface_hub/commands/repo_files.py +2 -2
- huggingface_hub/commands/scan_cache.py +99 -57
- huggingface_hub/commands/tag.py +1 -1
- huggingface_hub/commands/upload.py +2 -1
- huggingface_hub/commands/upload_large_folder.py +129 -0
- huggingface_hub/commands/version.py +37 -0
- huggingface_hub/community.py +2 -2
- huggingface_hub/errors.py +218 -1
- huggingface_hub/fastai_utils.py +2 -3
- huggingface_hub/file_download.py +63 -63
- huggingface_hub/hf_api.py +758 -314
- huggingface_hub/hf_file_system.py +15 -23
- huggingface_hub/hub_mixin.py +27 -25
- huggingface_hub/inference/_client.py +78 -127
- huggingface_hub/inference/_generated/_async_client.py +169 -144
- huggingface_hub/inference/_generated/types/base.py +0 -9
- huggingface_hub/inference/_templating.py +2 -3
- huggingface_hub/inference_api.py +2 -2
- huggingface_hub/keras_mixin.py +2 -2
- huggingface_hub/lfs.py +7 -98
- huggingface_hub/repocard.py +6 -5
- huggingface_hub/repository.py +5 -5
- huggingface_hub/serialization/_torch.py +64 -11
- huggingface_hub/utils/__init__.py +13 -14
- huggingface_hub/utils/_cache_manager.py +97 -14
- huggingface_hub/utils/_fixes.py +18 -2
- huggingface_hub/utils/_http.py +228 -2
- huggingface_hub/utils/_lfs.py +110 -0
- huggingface_hub/utils/_runtime.py +7 -1
- huggingface_hub/utils/_token.py +3 -2
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/METADATA +2 -2
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/RECORD +50 -48
- huggingface_hub/inference/_types.py +0 -52
- huggingface_hub/utils/_errors.py +0 -397
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.24.6.dist-info → huggingface_hub-0.25.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/hf_file_system.py CHANGED

@@ -15,23 +15,13 @@ from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
 from fsspec.utils import isfilelike
 from requests import Response

+from . import constants
 from ._commit_api import CommitOperationCopy, CommitOperationDelete
-from .constants import (
-    DEFAULT_REVISION,
-    ENDPOINT,
-    HF_HUB_DOWNLOAD_TIMEOUT,
-    HF_HUB_ETAG_TIMEOUT,
-    REPO_TYPE_MODEL,
-    REPO_TYPES_MAPPING,
-    REPO_TYPES_URL_PREFIXES,
-)
+from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from .file_download import hf_hub_url, http_get
 from .hf_api import HfApi, LastCommitInfo, RepoFile
 from .utils import (
-    EntryNotFoundError,
     HFValidationError,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
     hf_raise_for_status,
     http_backoff,
 )
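The exception classes dropped from the `.utils` import above are now pulled from the dedicated `huggingface_hub.errors` module (the new `errors.py`, +218 lines in the file list). A minimal sketch of catching them from that location; the repo ID is a placeholder, and the legacy `huggingface_hub.utils` re-exports appear to remain available for backward compatibility:

```python
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError

api = HfApi()
try:
    # Placeholder repo ID: replace with a real one to try this out.
    info = api.repo_info("some-user/some-repo", revision="main")
    print(info.sha)
except RepositoryNotFoundError:
    print("Repository not found (or private and the token lacks access).")
except RevisionNotFoundError:
    print("Repository exists, but the requested revision does not.")
```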
@@ -61,10 +51,10 @@ class HfFileSystemResolvedPath:
     _raw_revision: Optional[str] = field(default=None, repr=False)

     def unresolve(self) -> str:
-        repo_path = REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
+        repo_path = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
         if self._raw_revision:
             return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/")
-        elif self.revision != DEFAULT_REVISION:
+        elif self.revision != constants.DEFAULT_REVISION:
             return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/")
         else:
             return f"{repo_path}/{self.path_in_repo}".rstrip("/")

@@ -113,7 +103,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
         **storage_options,
     ):
         super().__init__(*args, **storage_options)
-        self.endpoint = endpoint or ENDPOINT
+        self.endpoint = endpoint or constants.ENDPOINT
         self.token = token
         self._api = HfApi(endpoint=endpoint, token=token)
         # Maps (repo_type, repo_id, revision) to a 2-tuple with:

@@ -128,7 +118,9 @@ class HfFileSystem(fsspec.AbstractFileSystem):
     ) -> Tuple[bool, Optional[Exception]]:
         if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
             try:
-                self._api.repo_info(repo_id, revision=revision, repo_type=repo_type, timeout=HF_HUB_ETAG_TIMEOUT)
+                self._api.repo_info(
+                    repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT
+                )
             except (RepositoryNotFoundError, HFValidationError) as e:
                 self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                 self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e

@@ -158,14 +150,14 @@ class HfFileSystem(fsspec.AbstractFileSystem):
         if not path:
             # can't list repositories at root
             raise NotImplementedError("Access to repositories lists is not implemented.")
-        elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
+        elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values():
             if "/" not in path:
                 # can't list repositories at the repository type level
                 raise NotImplementedError("Access to repositories lists is not implemented.")
             repo_type, path = path.split("/", 1)
-            repo_type = REPO_TYPES_MAPPING[repo_type]
+            repo_type = constants.REPO_TYPES_MAPPING[repo_type]
         else:
-            repo_type = REPO_TYPE_MODEL
+            repo_type = constants.REPO_TYPE_MODEL
         if path.count("/") > 0:
             if "@" in path:
                 repo_id, revision_in_path = path.split("@", 1)

@@ -213,7 +205,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
             if not repo_and_revision_exist:
                 raise NotImplementedError("Access to repositories lists is not implemented.")

-        revision = revision if revision is not None else DEFAULT_REVISION
+        revision = revision if revision is not None else constants.DEFAULT_REVISION
         return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path)

     def invalidate_cache(self, path: Optional[str] = None) -> None:

@@ -723,7 +715,7 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
             url,
             headers=headers,
             retry_on_status_codes=(502, 503, 504),
-            timeout=HF_HUB_DOWNLOAD_TIMEOUT,
+            timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
         )
         hf_raise_for_status(r)
         return r.content

@@ -823,7 +815,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
                 headers=self.fs._api._build_hf_headers(),
                 retry_on_status_codes=(502, 503, 504),
                 stream=True,
-                timeout=HF_HUB_DOWNLOAD_TIMEOUT,
+                timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
             )
             hf_raise_for_status(self.response)
             try:

@@ -845,7 +837,7 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
                 headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
                 retry_on_status_codes=(502, 503, 504),
                 stream=True,
-                timeout=HF_HUB_DOWNLOAD_TIMEOUT,
+                timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
             )
             hf_raise_for_status(self.response)
             try:
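Across this file the pattern is the same: values that used to be imported directly from `huggingface_hub.constants` are now read through the module namespace (`constants.ENDPOINT`, `constants.DEFAULT_REVISION`, ...). A small, generic illustration of the difference, not library code; the helper functions are made up for the example:

```python
from huggingface_hub import constants
from huggingface_hub.constants import ENDPOINT  # value is copied at import time


def url_from_copied_value(path: str) -> str:
    # Bound once at import time: a later patch of huggingface_hub.constants.ENDPOINT
    # (for example in a test) is not visible here.
    return f"{ENDPOINT}/{path}"


def url_from_module_lookup(path: str) -> str:
    # Looked up on the module at call time, so patching constants.ENDPOINT takes effect.
    return f"{constants.ENDPOINT}/{path}"


constants.ENDPOINT = "https://hub-ci.huggingface.co"  # simulate a test-time monkeypatch
print(url_from_copied_value("api/models"))    # still the original endpoint
print(url_from_module_lookup("api/models"))   # the patched endpoint
```

Reading settings such as `HF_HUB_DOWNLOAD_TIMEOUT` through the module presumably makes them patchable at runtime without re-importing every caller.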
huggingface_hub/hub_mixin.py CHANGED

@@ -17,13 +17,12 @@ from typing import (
     Union,
 )

-from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
+from . import constants
+from .errors import EntryNotFoundError, HfHubHTTPError
 from .file_download import hf_hub_download
 from .hf_api import HfApi
 from .repocard import ModelCard, ModelCardData
 from .utils import (
-    EntryNotFoundError,
-    HfHubHTTPError,
     SoftTemporaryDirectory,
     is_jsonable,
     is_safetensors_available,

@@ -42,6 +41,8 @@ if is_torch_available():
     import torch  # type: ignore

 if is_safetensors_available():
+    import packaging.version
+    import safetensors
     from safetensors.torch import load_model as load_model_as_safetensor
     from safetensors.torch import save_model as save_model_as_safetensor

@@ -417,7 +418,7 @@ class ModelHubMixin:
         # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
         # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
         # an existing config.json if it was not saved by `_save_pretrained`.
-        config_path = save_directory / CONFIG_NAME
+        config_path = save_directory / constants.CONFIG_NAME
         config_path.unlink(missing_ok=True)

         # save model weights/files (framework-specific)

@@ -505,15 +506,15 @@ class ModelHubMixin:
         model_id = str(pretrained_model_name_or_path)
         config_file: Optional[str] = None
         if os.path.isdir(model_id):
-            if CONFIG_NAME in os.listdir(model_id):
-                config_file = os.path.join(model_id, CONFIG_NAME)
+            if constants.CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, constants.CONFIG_NAME)
             else:
-                logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+                logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
         else:
             try:
                 config_file = hf_hub_download(
                     repo_id=model_id,
-                    filename=CONFIG_NAME,
+                    filename=constants.CONFIG_NAME,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,

@@ -523,7 +524,7 @@ class ModelHubMixin:
                     local_files_only=local_files_only,
                 )
             except HfHubHTTPError as e:
-                logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
+                logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

         # Read config
         config = None

@@ -767,7 +768,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
+        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))

     @classmethod
     def _from_pretrained(

@@ -789,13 +790,13 @@ class PyTorchModelHubMixin(ModelHubMixin):
         model = cls(**model_kwargs)
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
-            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
+            model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
             return cls._load_as_safetensor(model, model_file, map_location, strict)
         else:
             try:
                 model_file = hf_hub_download(
                     repo_id=model_id,
-                    filename=SAFETENSORS_SINGLE_FILE,
+                    filename=constants.SAFETENSORS_SINGLE_FILE,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,

@@ -808,7 +809,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
             except EntryNotFoundError:
                 model_file = hf_hub_download(
                     repo_id=model_id,
-                    filename=PYTORCH_WEIGHTS_NAME,
+                    filename=constants.PYTORCH_WEIGHTS_NAME,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
@@ -821,24 +822,25 @@ class PyTorchModelHubMixin(ModelHubMixin):

     @classmethod
     def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        state_dict = torch.load(model_file, map_location=torch.device(map_location))
+        state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
         model.load_state_dict(state_dict, strict=strict)  # type: ignore
         model.eval()  # type: ignore
         return model

     @classmethod
     def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        [11 removed lines not captured in this diff view]
+        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
+            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
+            if map_location != "cpu":
+                logger.warning(
+                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
+                    " This means that the model is loaded on 'cpu' first and then copied to the device."
+                    " This leads to a slower loading time."
+                    " Please update safetensors to version 0.4.3 or above for improved performance."
+                )
+                model.to(map_location)  # type: ignore [attr-defined]
+        else:
+            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
         return model

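Two loading behaviours change here: `_load_as_pickle` now calls `torch.load(..., weights_only=True)`, and `_load_as_safetensor` loads directly onto the target device when safetensors >= 0.4.3 is installed instead of loading on CPU and copying. A minimal sketch of the hardened pickle path only; `TinyNet` and the checkpoint path are illustrative and not part of huggingface_hub:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in model for the example."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)


def load_pickle_checkpoint(model: nn.Module, path: str, map_location: str = "cpu") -> nn.Module:
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code at load time (requires torch >= 1.13).
    state_dict = torch.load(path, map_location=torch.device(map_location), weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


# Usage sketch: load_pickle_checkpoint(TinyNet(), "pytorch_model.bin")
```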
huggingface_hub/inference/_client.py CHANGED

@@ -53,7 +53,7 @@ from requests import HTTPError
 from requests.structures import CaseInsensitiveDict

 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
-from huggingface_hub.errors import InferenceTimeoutError
+from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
     ContentT,

@@ -100,11 +100,7 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._types import (
-    ConversationalOutput,  # soon to be removed
-)
 from huggingface_hub.utils import (
-    BadRequestError,
     build_hf_headers,
     get_session,
     hf_raise_for_status,
@@ -135,7 +131,9 @@ class InferenceClient:
            or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
            automatically selected for the task.
            Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
-           arguments are mutually exclusive
+           arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
+           path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
+           documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
            Pass `token=False` if you don't want to send your token to the server.

@@ -149,6 +147,8 @@ class InferenceClient:
            Values in this dictionary will override the default values.
        cookies (`Dict[str, str]`, `optional`):
            Additional cookies to send to the server.
+       proxies (`Any`, `optional`):
+           Proxies to use for the request.
        base_url (`str`, `optional`):
            Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
            follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
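These docstring additions describe the OpenAI-style entry point: pass `base_url` pointing at an OpenAI-compatible server (for example a local text-generation-inference instance) and `chat_completion` takes care of the `/v1/chat/completions` route. A short usage sketch; the localhost address is an assumption about where such a server is running:

```python
from huggingface_hub import InferenceClient

# Assumes an OpenAI-compatible server (e.g. text-generation-inference) listens here.
client = InferenceClient(base_url="http://localhost:8080")

response = client.chat_completion(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```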
@@ -175,7 +175,8 @@ class InferenceClient:
             raise ValueError(
                 "Received both `model` and `base_url` arguments. Please provide only one of them."
                 " `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
-                "
+                " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url."
+                " When passing a URL as `model`, the client will not append any suffix path to it."
             )
         if token is not None and api_key is not None:
             raise ValueError(
@@ -809,133 +810,66 @@ class InferenceClient:
        '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
        ```
        """
-
-       # `self.xxx` takes precedence over the method argument only in `chat_completion`
-       # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-       # server, we need to handle it differently
-       model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-       is_url = model.startswith(("http://", "https://"))
-
-       # First, resolve the model chat completions URL
-       if model == self.base_url:
-           # base_url passed => add server route
-           model_url = model.rstrip("/")
-           if not model_url.endswith("/v1"):
-               model_url += "/v1"
-           model_url += "/chat/completions"
-       elif is_url:
-           # model is a URL => use it directly
-           model_url = model
-       else:
-           # model is a model ID => resolve it + add server route
-           model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
+       model_url = self._resolve_chat_completion_url(model)

        # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
        # If it's a ID on the Hub => use it. Otherwise, we use a random string.
-       model_id = model
-       [21 removed lines not captured in this diff view]
-           stream=stream,
-       ),
+       model_id = model or self.model or "tgi"
+       if model_id.startswith(("http://", "https://")):
+           model_id = "tgi"  # dummy value
+
+       payload = dict(
+           model=model_id,
+           messages=messages,
+           frequency_penalty=frequency_penalty,
+           logit_bias=logit_bias,
+           logprobs=logprobs,
+           max_tokens=max_tokens,
+           n=n,
+           presence_penalty=presence_penalty,
+           response_format=response_format,
+           seed=seed,
+           stop=stop,
+           temperature=temperature,
+           tool_choice=tool_choice,
+           tool_prompt=tool_prompt,
+           tools=tools,
+           top_logprobs=top_logprobs,
+           top_p=top_p,
            stream=stream,
        )
+       payload = {key: value for key, value in payload.items() if value is not None}
+       data = self.post(model=model_url, json=payload, stream=stream)

        if stream:
            return _stream_chat_completion_response(data)  # type: ignore[arg-type]

        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

-   def conversational(
-
-
-
-       past_user_inputs: Optional[List[str]] = None,
-       *,
-       parameters: Optional[Dict[str, Any]] = None,
-       model: Optional[str] = None,
-   ) -> ConversationalOutput:
-       """
-       Generate conversational responses based on the given input text (i.e. chat with the API).
+   def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+       # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+       # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
+       model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")

-
+       # Resolve URL if it's a model ID
+       model_url = (
+           model_id_or_url
+           if model_id_or_url.startswith(("http://", "https://"))
+           else self._resolve_url(model_id_or_url, task="text-generation")
+       )

-
-
+       # Strip trailing /
+       model_url = model_url.rstrip("/")

-
+       # Append /chat/completions if not already present
+       if model_url.endswith("/v1"):
+           model_url += "/chat/completions"

-
-
-
-           generated_responses (`List[str]`, *optional*):
-               A list of strings corresponding to the earlier replies from the model. Defaults to None.
-           past_user_inputs (`List[str]`, *optional*):
-               A list of strings corresponding to the earlier replies from the user. Should be the same length as
-               `generated_responses`. Defaults to None.
-           parameters (`Dict[str, Any]`, *optional*):
-               Additional parameters for the conversational task. Defaults to None. For more details about the available
-               parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
-           model (`str`, *optional*):
-               The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-               a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
-               Defaults to None.
+       # Append /v1/chat/completions if not already present
+       if not model_url.endswith("/chat/completions"):
+           model_url += "/v1/chat/completions"

-
-           `Dict`: The generated conversational output.
-
-       Raises:
-           [`InferenceTimeoutError`]:
-               If the model is unavailable or the request times out.
-           `HTTPError`:
-               If the request fails with an HTTP error status code other than HTTP 503.
-
-       Example:
-       ```py
-       >>> from huggingface_hub import InferenceClient
-       >>> client = InferenceClient()
-       >>> output = client.conversational("Hi, who are you?")
-       >>> output
-       {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
-       >>> client.conversational(
-       ...     "Wow, that's scary!",
-       ...     generated_responses=output["conversation"]["generated_responses"],
-       ...     past_user_inputs=output["conversation"]["past_user_inputs"],
-       ... )
-       ```
-       """
-       warnings.warn(
-           "'InferenceClient.conversational' is deprecated and will be removed starting from huggingface_hub>=0.25. "
-           "Please use the more appropriate 'InferenceClient.chat_completion' API instead.",
-           FutureWarning,
-       )
-       payload: Dict[str, Any] = {"inputs": {"text": text}}
-       if generated_responses is not None:
-           payload["inputs"]["generated_responses"] = generated_responses
-       if past_user_inputs is not None:
-           payload["inputs"]["past_user_inputs"] = past_user_inputs
-       if parameters is not None:
-           payload["parameters"] = parameters
-       response = self.post(json=payload, model=model, task="conversational")
-       return _bytes_to_dict(response)  # type: ignore
+       return model_url

    def document_question_answering(
        self,
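The new `_resolve_chat_completion_url` helper centralizes routing that was previously inlined in `chat_completion`. The sketch below restates only the suffix rules (model-ID resolution through `_resolve_url` is omitted), to show which inputs end up on which route; it is a simplified illustration, not the library's own function:

```python
def append_chat_completion_route(url: str) -> str:
    """Simplified restatement of the suffix handling in `_resolve_chat_completion_url`."""
    url = url.rstrip("/")
    if url.endswith("/v1"):
        return url + "/chat/completions"
    if not url.endswith("/chat/completions"):
        return url + "/v1/chat/completions"
    return url


assert append_chat_completion_route("http://localhost:8080") == "http://localhost:8080/v1/chat/completions"
assert append_chat_completion_route("http://localhost:8080/v1") == "http://localhost:8080/v1/chat/completions"
assert append_chat_completion_route("http://localhost:8080/v1/chat/completions") == (
    "http://localhost:8080/v1/chat/completions"
)
```

A bare base URL and one ending in `/v1` therefore resolve to the same route, while a URL that already targets `/chat/completions` is left untouched.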
@@ -1727,7 +1661,8 @@
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-       stop_sequences: Optional[List[str]] = None,
+       stop: Optional[List[str]] = None,
+       stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_n_tokens: Optional[int] = None,

@@ -1756,7 +1691,8 @@
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-       stop_sequences: Optional[List[str]] = None,
+       stop: Optional[List[str]] = None,
+       stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_n_tokens: Optional[int] = None,

@@ -1785,7 +1721,8 @@
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-       stop_sequences: Optional[List[str]] = None,
+       stop: Optional[List[str]] = None,
+       stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_n_tokens: Optional[int] = None,

@@ -1814,7 +1751,8 @@
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-       stop_sequences: Optional[List[str]] = None,
+       stop: Optional[List[str]] = None,
+       stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_n_tokens: Optional[int] = None,

@@ -1843,7 +1781,8 @@
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-       stop_sequences: Optional[List[str]] = None,
+       stop: Optional[List[str]] = None,
+       stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_n_tokens: Optional[int] = None,

@@ -1871,7 +1810,8 @@
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = False,  # Manual default value
        seed: Optional[int] = None,
-       stop_sequences: Optional[List[str]] = None,
+       stop: Optional[List[str]] = None,
+       stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_n_tokens: Optional[int] = None,
@@ -1936,8 +1876,10 @@
                Whether to prepend the prompt to the generated text
            seed (`int`, *optional*):
                Random sampling seed
+           stop (`List[str]`, *optional*):
+               Stop generating tokens if a member of `stop` is generated.
            stop_sequences (`List[str]`, *optional*):
-
+               Deprecated argument. Use `stop` instead.
            temperature (`float`, *optional*):
                The value used to module the logits distribution.
            top_n_tokens (`int`, *optional*):
@@ -2081,6 +2023,15 @@
            )
            decoder_input_details = False

+       if stop_sequences is not None:
+           warnings.warn(
+               "`stop_sequences` is a deprecated argument for `text_generation` task"
+               " and will be removed in version '0.28.0'. Use `stop` instead.",
+               FutureWarning,
+           )
+           if stop is None:
+               stop = stop_sequences  # use deprecated arg if provided
+
        # Build payload
        parameters = {
            "adapter_id": adapter_id,
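This hunk is the runtime side of the rename: `stop_sequences` keeps working but raises a `FutureWarning` and is folded into `stop` when the latter is unset. The same pattern reduced to a standalone sketch; the function name is made up for the example:

```python
import warnings
from typing import List, Optional


def resolve_stop_argument(
    stop: Optional[List[str]] = None,
    stop_sequences: Optional[List[str]] = None,
) -> List[str]:
    """Map the deprecated `stop_sequences` argument onto `stop` (simplified sketch)."""
    if stop_sequences is not None:
        warnings.warn(
            "`stop_sequences` is deprecated. Use `stop` instead.",
            FutureWarning,
        )
        if stop is None:
            stop = stop_sequences
    return stop if stop is not None else []


print(resolve_stop_argument(stop_sequences=["\n\n"]))  # ['\n\n'] plus a FutureWarning
```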
@@ -2094,7 +2045,7 @@
            "repetition_penalty": repetition_penalty,
            "return_full_text": return_full_text,
            "seed": seed,
-           "stop": stop_sequences if stop_sequences is not None else [],
+           "stop": stop if stop is not None else [],
            "temperature": temperature,
            "top_k": top_k,
            "top_n_tokens": top_n_tokens,
@@ -2164,7 +2115,7 @@
                repetition_penalty=repetition_penalty,
                return_full_text=return_full_text,
                seed=seed,
-               stop_sequences=stop_sequences,
+               stop=stop,
                temperature=temperature,
                top_k=top_k,
                top_n_tokens=top_n_tokens,