huggingface-hub 0.12.0rc0__py3-none-any.whl → 0.13.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. huggingface_hub/__init__.py +166 -126
  2. huggingface_hub/_commit_api.py +25 -51
  3. huggingface_hub/_login.py +4 -13
  4. huggingface_hub/_snapshot_download.py +45 -23
  5. huggingface_hub/_space_api.py +7 -0
  6. huggingface_hub/commands/delete_cache.py +13 -39
  7. huggingface_hub/commands/env.py +1 -3
  8. huggingface_hub/commands/huggingface_cli.py +1 -3
  9. huggingface_hub/commands/lfs.py +4 -8
  10. huggingface_hub/commands/scan_cache.py +5 -16
  11. huggingface_hub/commands/user.py +27 -45
  12. huggingface_hub/community.py +4 -4
  13. huggingface_hub/constants.py +22 -19
  14. huggingface_hub/fastai_utils.py +14 -23
  15. huggingface_hub/file_download.py +210 -121
  16. huggingface_hub/hf_api.py +500 -255
  17. huggingface_hub/hub_mixin.py +181 -176
  18. huggingface_hub/inference_api.py +4 -10
  19. huggingface_hub/keras_mixin.py +39 -71
  20. huggingface_hub/lfs.py +8 -24
  21. huggingface_hub/repocard.py +33 -48
  22. huggingface_hub/repocard_data.py +141 -30
  23. huggingface_hub/repository.py +41 -112
  24. huggingface_hub/templates/modelcard_template.md +39 -34
  25. huggingface_hub/utils/__init__.py +1 -0
  26. huggingface_hub/utils/_cache_assets.py +1 -4
  27. huggingface_hub/utils/_cache_manager.py +17 -39
  28. huggingface_hub/utils/_deprecation.py +8 -12
  29. huggingface_hub/utils/_errors.py +10 -57
  30. huggingface_hub/utils/_fixes.py +2 -6
  31. huggingface_hub/utils/_git_credential.py +5 -16
  32. huggingface_hub/utils/_headers.py +22 -11
  33. huggingface_hub/utils/_http.py +1 -4
  34. huggingface_hub/utils/_paths.py +5 -12
  35. huggingface_hub/utils/_runtime.py +2 -1
  36. huggingface_hub/utils/_telemetry.py +120 -0
  37. huggingface_hub/utils/_validators.py +5 -13
  38. huggingface_hub/utils/endpoint_helpers.py +1 -3
  39. huggingface_hub/utils/logging.py +10 -8
  40. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/METADATA +7 -14
  41. huggingface_hub-0.13.0rc0.dist-info/RECORD +56 -0
  42. huggingface_hub/py.typed +0 -0
  43. huggingface_hub-0.12.0rc0.dist-info/RECORD +0 -56
  44. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/LICENSE +0 -0
  45. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/WHEEL +0 -0
  46. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/entry_points.txt +0 -0
  47. {huggingface_hub-0.12.0rc0.dist-info → huggingface_hub-0.13.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/hub_mixin.py
@@ -1,7 +1,8 @@
 import json
 import os
+import warnings
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, TypeVar, Union
 
 import requests
 
@@ -9,6 +10,7 @@ from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from .file_download import hf_hub_download, is_torch_available
 from .hf_api import HfApi
 from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
+from .utils._deprecation import _deprecate_positional_args
 
 
 if is_torch_available():
@@ -16,149 +18,131 @@ if is_torch_available():
 
 logger = logging.get_logger(__name__)
 
+# Generic variable that is either ModelHubMixin or a subclass thereof
+T = TypeVar("T", bound="ModelHubMixin")
+
 
 class ModelHubMixin:
     """
-    A generic Hub mixin for machine learning models. Define your own mixin for
-    any framework by inheriting from this class and overwriting the
-    [`_from_pretrained`] and [`_save_pretrained`] methods to define custom logic
-    for saving and loading your classes. See [`PyTorchModelHubMixin`] for an
-    example.
+    A generic mixin to integrate ANY machine learning framework with the Hub.
+
+    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
+    have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
+    of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
     """
 
+    @_deprecate_positional_args(version="0.16")
     def save_pretrained(
         self,
         save_directory: Union[str, Path],
+        *,
         config: Optional[dict] = None,
+        repo_id: Optional[str] = None,
         push_to_hub: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[str]:
         """
         Save weights in local directory.
 
-        Parameters:
+        Args:
             save_directory (`str` or `Path`):
-                Specify directory in which you want to save weights.
+                Path to directory in which the model weights and configuration will be saved.
             config (`dict`, *optional*):
-                Specify config (must be dict) in case you want to save
-                it.
+                Model configuration specified as a key/value dictionary.
             push_to_hub (`bool`, *optional*, defaults to `False`):
-                Whether or not to push your model to the Hugging Face Hub after
-                saving it. You can specify the repository you want to push to with
-                `repo_id` (will default to the name of `save_directory` in your
-                namespace).
+                Whether or not to push your model to the Huggingface Hub after saving it.
+            repo_id (`str`, *optional*):
+                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
+                not provided.
             kwargs:
-                Additional key word arguments passed along to the
-                [`~utils.PushToHubMixin.push_to_hub`] method.
+                Additional key word arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
         """
-        os.makedirs(save_directory, exist_ok=True)
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
 
         # saving model weights/files
         self._save_pretrained(save_directory)
 
         # saving config
         if isinstance(config, dict):
-            path = os.path.join(save_directory, CONFIG_NAME)
-            with open(path, "w") as f:
-                json.dump(config, f)
+            (save_directory / CONFIG_NAME).write_text(json.dumps(config))
 
         if push_to_hub:
             kwargs = kwargs.copy()  # soft-copy to avoid mutating input
             if config is not None:  # kwarg for `push_to_hub`
                 kwargs["config"] = config
+            if repo_id is None:
+                repo_id = save_directory.name  # Defaults to `save_directory` name
+            return self.push_to_hub(repo_id=repo_id, **kwargs)
+        return None
 
-            if kwargs.get("repo_id") is None:
-                # Repo name defaults to `save_directory` name
-                kwargs["repo_id"] = Path(save_directory).name
-
-            return self.push_to_hub(**kwargs)
-
-    def _save_pretrained(self, save_directory: Union[str, Path]):
+    def _save_pretrained(self, save_directory: Path) -> None:
         """
         Overwrite this method in subclass to define how to save your model.
+        Check out our [integration guide](../guides/integrations) for instructions.
+
+        Args:
+            save_directory (`str` or `Path`):
+                Path to directory in which the model weights and configuration will be saved.
         """
         raise NotImplementedError
 
     @classmethod
     @validate_hf_hub_args
+    @_deprecate_positional_args(version="0.16")
     def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str,
+        cls: Type[T],
+        pretrained_model_name_or_path: Union[str, Path],
+        *,
         force_download: bool = False,
         resume_download: bool = False,
         proxies: Optional[Dict] = None,
         token: Optional[Union[str, bool]] = None,
-        cache_dir: Optional[str] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         local_files_only: bool = False,
+        revision: Optional[str] = None,
         **model_kwargs,
-    ):
-        r"""
-        Download and instantiate a model from the Hugging Face Hub.
-
-        Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
-                Can be either:
-                    - A string, the `model id` of a pretrained model
-                      hosted inside a model repo on huggingface.co.
-                      Valid model ids can be located at the root-level,
-                      like `bert-base-uncased`, or namespaced under a
-                      user or organization name, like
-                      `dbmdz/bert-base-german-cased`.
-                    - You can add `revision` by appending `@` at the end
-                      of model_id simply like this:
-                      `dbmdz/bert-base-german-cased@main` Revision is
-                      the specific model version to use. It can be a
-                      branch name, a tag name, or a commit id, since we
-                      use a git-based system for storing models and
-                      other artifacts on huggingface.co, so `revision`
-                      can be any identifier allowed by git.
-                    - A path to a `directory` containing model weights
-                      saved using
-                      [`~transformers.PreTrainedModel.save_pretrained`],
-                      e.g., `./my_model_directory/`.
-                    - `None` if you are both providing the configuration
-                      and state dictionary (resp. with keyword arguments
-                      `config` and `state_dict`).
-            force_download (`bool`, *optional*, defaults to `False`):
-                Whether to force the (re-)download of the model weights
-                and configuration files, overriding the cached versions
-                if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether to delete incompletely received files. Will
-                attempt to resume the download if such a file exists.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or
-                endpoint, e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are
-                used on each request.
-            token (`str` or `bool`, *optional*):
-                The token to use as HTTP bearer authorization for remote
-                files. If `True`, will use the token generated when
-                running `transformers-cli login` (stored in
-                `~/.huggingface`).
-            cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory in which a downloaded pretrained
-                model configuration should be cached if the standard
-                cache should not be used.
-            local_files_only(`bool`, *optional*, defaults to `False`):
-                Whether to only look at local files (i.e., do not try to
-                download the model).
-            model_kwargs (`Dict`, *optional*):
-                model_kwargs will be passed to the model during
-                initialization
-
-        <Tip>
-
-        Passing `token=True` is required when you want to use a
-        private model.
-
-        </Tip>
+    ) -> T:
+        """
+        Download a model from the Huggingface Hub and instantiate it.
+
+        Args:
+            pretrained_model_name_or_path (`str`, `Path`):
+                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
+                - Or a path to a `directory` containing model weights saved using
+                    [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
+            revision (`str`, *optional*):
+                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
+                Defaults to the latest commit on `main` branch.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
+                the existing cache.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+            model_kwargs (`Dict`, *optional*):
+                Additional kwargs to pass to the model during initialization.
         """
-
         model_id = pretrained_model_name_or_path
 
-        revision = None
-        if len(model_id.split("@")) == 2:
+        if isinstance(model_id, str) and len(model_id.split("@")) == 2:
+            warnings.warn(
+                (
+                    "Passing a revision using 'namespace/model_id@revision' pattern is"
+                    " deprecated and will be removed in version v0.16. Please pass"
+                    " 'revision=...' as argument."
+                ),
+                FutureWarning,
+            )
             model_id, revision = model_id.split("@")
 
         config_file: Optional[str] = None
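Taken together, this hunk makes every argument after `save_directory` / `pretrained_model_name_or_path` keyword-only (positional calls keep working until v0.16 thanks to `_deprecate_positional_args`) and replaces the `model_id@revision` pattern with an explicit `revision` argument. A minimal sketch of what that means at the call site; `MyModel` is a hypothetical subclass introduced only for illustration:

```python
from pathlib import Path

from huggingface_hub import ModelHubMixin


class MyModel(ModelHubMixin):  # hypothetical subclass, for illustration only
    def _save_pretrained(self, save_directory: Path) -> None:
        (save_directory / "weights.txt").write_text("fake weights")


model = MyModel()

# 0.13 style: everything after `save_directory` is passed by keyword.
model.save_pretrained("my-model-dir", config={"num_layers": 2})

# Positional calls like `model.save_pretrained("my-model-dir", {"num_layers": 2})`
# still work but emit a FutureWarning until removal in v0.16.

# Likewise, `MyModel.from_pretrained("user/my-model@v1.0")` is deprecated in
# favor of `MyModel.from_pretrained("user/my-model", revision="v1.0")`.
```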
@@ -167,10 +151,10 @@
                 config_file = os.path.join(model_id, CONFIG_NAME)
             else:
                 logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
-        else:
+        elif isinstance(model_id, str):
             try:
                 config_file = hf_hub_download(
-                    repo_id=model_id,
+                    repo_id=str(model_id),
                     filename=CONFIG_NAME,
                     revision=revision,
                     cache_dir=cache_dir,
@@ -181,7 +165,7 @@
                     local_files_only=local_files_only,
                 )
             except requests.exceptions.RequestException:
-                logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub")
+                logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub.")
 
         if config_file is not None:
             with open(config_file, "r", encoding="utf-8") as f:
@@ -189,32 +173,65 @@
             model_kwargs.update({"config": config})
 
         return cls._from_pretrained(
-            model_id,
-            revision,
-            cache_dir,
-            force_download,
-            proxies,
-            resume_download,
-            local_files_only,
-            token,
+            model_id=model_id,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            resume_download=resume_download,
+            local_files_only=local_files_only,
+            token=token,
             **model_kwargs,
         )
 
     @classmethod
+    @_deprecate_positional_args(version="0.16")
     def _from_pretrained(
-        cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        token,
+        cls: Type[T],
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: bool,
+        local_files_only: bool,
+        token: Optional[Union[str, bool]],
         **model_kwargs,
-    ):
-        """Overwrite this method in subclass to define how to load your model from
-        pretrained"""
+    ) -> T:
+        """Overwrite this method in subclass to define how to load your model from pretrained.
+
+        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
+        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
+        method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
+        parameter to set on which device the model should be loaded.
+
+        Check out our [integration guide](../guides/integrations) for more instructions.
+
+        Args:
+            model_id (`str`):
+                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
+            revision (`str`, *optional*):
+                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
+                latest commit on `main` branch.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
+                the existing cache.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+            model_kwargs:
+                Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
+        """
         raise NotImplementedError
 
     @validate_hf_hub_args
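Because `from_pretrained` now calls `_from_pretrained` with keyword arguments only, a custom integration should declare the matching keyword-only signature and can forward most of those arguments to `hf_hub_download`, as the new docstring suggests. A definition-only sketch under those assumptions (the `JsonModel` class and its `weights.json` format are made up for illustration):

```python
import json
from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub import ModelHubMixin, hf_hub_download


class JsonModel(ModelHubMixin):  # hypothetical framework integration
    def __init__(self, weights: dict):
        self.weights = weights

    def _save_pretrained(self, save_directory: Path) -> None:
        (save_directory / "weights.json").write_text(json.dumps(self.weights))

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> "JsonModel":
        # Most arguments can be forwarded to `hf_hub_download` unchanged.
        weights_file = hf_hub_download(
            repo_id=model_id,
            filename="weights.json",
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
        )
        return cls(weights=json.loads(Path(weights_file).read_text()))
```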
@@ -231,15 +248,18 @@
         create_pr: Optional[bool] = None,
         allow_patterns: Optional[Union[List[str], str]] = None,
         ignore_patterns: Optional[Union[List[str], str]] = None,
+        delete_patterns: Optional[Union[List[str], str]] = None,
     ) -> str:
         """
         Upload model checkpoint to the Hub.
 
-        Use `allow_patterns` and `ignore_patterns` to precisely filter which files
-        should be pushed to the hub. See [`upload_folder`] reference for more details.
+        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
+        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
+        details.
 
-        Parameters:
-            repo_id (`str`, *optional*):
+
+        Args:
+            repo_id (`str`):
                 Repository name to which push.
             config (`dict`, *optional*):
                 Configuration object to be saved alongside the model weights.
@@ -250,32 +270,24 @@
             api_endpoint (`str`, *optional*):
                 The API endpoint to use when pushing the model to the hub.
             token (`str`, *optional*):
-                The token to use as HTTP bearer authorization for remote files.
-                If not set, will use the token set when logging in with
-                `transformers-cli login` (stored in `~/.huggingface`).
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
             branch (`str`, *optional*):
-                The git branch on which to push the model. This defaults to
-                the default branch as specified in your repository, which
-                defaults to `"main"`.
+                The git branch on which to push the model. This defaults to `"main"`.
             create_pr (`boolean`, *optional*):
-                Whether or not to create a Pull Request from `branch` with that commit.
-                Defaults to `False`.
+                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
             allow_patterns (`List[str]` or `str`, *optional*):
                 If provided, only files matching at least one pattern are pushed.
             ignore_patterns (`List[str]` or `str`, *optional*):
                 If provided, files matching any of the patterns are not pushed.
+            delete_patterns (`List[str]` or `str`, *optional*):
+                If provided, remote files matching any of the patterns will be deleted from the repo.
 
         Returns:
             The url of the commit of your model in the given repository.
         """
-        api = HfApi(endpoint=api_endpoint)
-        api.create_repo(
-            repo_id=repo_id,
-            repo_type="model",
-            token=token,
-            private=private,
-            exist_ok=True,
-        )
+        api = HfApi(endpoint=api_endpoint, token=token)
+        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
 
         # Push the files to the repo in a single commit
         with SoftTemporaryDirectory() as tmp:
@@ -284,23 +296,21 @@
             return api.upload_folder(
                 repo_id=repo_id,
                 repo_type="model",
-                token=token,
                 folder_path=saved_path,
                 commit_message=commit_message,
                 revision=branch,
                 create_pr=create_pr,
                 allow_patterns=allow_patterns,
                 ignore_patterns=ignore_patterns,
+                delete_patterns=delete_patterns,
             )
 
 
 class PyTorchModelHubMixin(ModelHubMixin):
     """
-    Implementation of [`ModelHubMixin`] to provide model Hub upload/download
-    capabilities to PyTorch models. The model is set in evaluation mode by
-    default using `model.eval()` (dropout modules are deactivated). To train
-    the model, you should first set it back in training mode with
-    `model.train()`.
+    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
+    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
+    you should first set it back in training mode with `model.train()`.
 
     Example:
 
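Back in `push_to_hub` above, the new `delete_patterns` parameter complements `allow_patterns`/`ignore_patterns`: the upload and the matching remote deletions land in a single commit. Note also that `token` is now passed once to the `HfApi` constructor and that the `repo_id` returned by `create_repo` is reused. A hedged sketch, assuming `model` is a `ModelHubMixin` instance and you are logged in via `huggingface-cli login` (the repo name and patterns are illustrative):

```python
# Pushes matching local files and deletes stale remote ONNX files in one commit.
commit_url = model.push_to_hub(
    repo_id="username/my-awesome-model",
    commit_message="Replace ONNX export with updated weights",
    allow_patterns=["*.bin", "*.json"],  # only these local files are uploaded
    delete_patterns="*.onnx",            # matching remote files are removed
)
print(commit_url)
```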
@@ -318,47 +328,42 @@ class PyTorchModelHubMixin(ModelHubMixin):
 
     ...     def forward(self, x):
     ...         return self.linear(x + self.param)
-
-
     >>> model = MyModel()
-    >>> # Save model weights to local directory
+
+    # Save model weights to local directory
     >>> model.save_pretrained("my-awesome-model")
-    >>> # Push model weights to the Hub
+
+    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")
-    >>> # Download and initialize weights from the Hub
+
+    # Download and initialize weights from the Hub
     >>> model = MyModel.from_pretrained("username/my-awesome-model")
     ```
     """
 
-    def _save_pretrained(self, save_directory):
-        """
-        Overwrite this method if you wish to save specific layers instead of the
-        complete model.
-        """
-        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
-        model_to_save = self.module if hasattr(self, "module") else self
-        torch.save(model_to_save.state_dict(), path)
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save weights from a Pytorch model to a local directory."""
+        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
+        torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
 
     @classmethod
+    @_deprecate_positional_args(version="0.16")
     def _from_pretrained(
         cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        token,
-        map_location="cpu",
-        strict=False,
+        *,
+        model_id: str,
+        revision: str,
+        cache_dir: str,
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: bool,
+        local_files_only: bool,
+        token: Union[str, bool, None],
+        map_location: str = "cpu",
+        strict: bool = False,
         **model_kwargs,
     ):
-        """
-        Overwrite this method to initialize your model in a different way.
-        """
-        map_location = torch.device(map_location)
-
+        """Load Pytorch pretrained weights and return the loaded model."""
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
             model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
@@ -376,8 +381,8 @@ class PyTorchModelHubMixin(ModelHubMixin):
             )
         model = cls(**model_kwargs)
 
-        state_dict = torch.load(model_file, map_location=map_location)
-        model.load_state_dict(state_dict, strict=strict)
-        model.eval()
+        state_dict = torch.load(model_file, map_location=torch.device(map_location))
+        model.load_state_dict(state_dict, strict=strict)  # type: ignore
+        model.eval()  # type: ignore
 
         return model
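As the updated docstrings point out, extra keyword arguments passed to `from_pretrained` travel through `**model_kwargs` into `_from_pretrained`, which is how the PyTorch mixin exposes `map_location` and `strict` at the call site. A sketch, assuming `torch` is installed and `username/my-awesome-model` stands in for a real repo:

```python
import torch

from huggingface_hub import PyTorchModelHubMixin


class MyModel(torch.nn.Module, PyTorchModelHubMixin):  # hypothetical model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)


# `map_location` and `strict` are forwarded to PyTorchModelHubMixin._from_pretrained;
# the string is wrapped in torch.device() before torch.load is called.
model = MyModel.from_pretrained(
    "username/my-awesome-model",
    map_location="cpu",
    strict=False,  # tolerate missing/unexpected keys in the state dict
)
```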
huggingface_hub/inference_api.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union
 
 import requests
 
+from .constants import INFERENCE_ENDPOINT
 from .hf_api import HfApi
 from .utils import build_hf_headers, is_pillow_available, logging, validate_hf_hub_args
 
@@ -10,8 +11,6 @@ from .utils import build_hf_headers, is_pillow_available, logging, validate_hf_hub_args
 logger = logging.get_logger(__name__)
 
 
-ENDPOINT = "https://api-inference.huggingface.co"
-
 ALL_TASKS = [
     # NLP
     "text-classification",
@@ -142,14 +141,11 @@
         assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
         self.task = model_info.pipeline_tag
 
-        self.api_url = f"{ENDPOINT}/pipeline/{self.task}/{repo_id}"
+        self.api_url = f"{INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
 
     def __repr__(self):
         # Do not add headers to repr to avoid leaking token.
-        return (
-            f"InferenceAPI(api_url='{self.api_url}', task='{self.task}',"
-            f" options={self.options})"
-        )
+        return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"
 
     def __call__(
         self,
@@ -183,9 +179,7 @@
             payload["parameters"] = params
 
         # Make API call
-        response = requests.post(
-            self.api_url, headers=self.headers, json=payload, data=data
-        )
+        response = requests.post(self.api_url, headers=self.headers, json=payload, data=data)
 
         # Let the user handle the response
         if raw_response:
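For `inference_api.py` the change is purely internal: the hardcoded module-level `ENDPOINT` moved to `constants.INFERENCE_ENDPOINT`, and two multi-line statements were reflowed onto single lines. Caller-side usage is unchanged; a small sketch (public model, so no token is needed):

```python
from huggingface_hub import InferenceApi

# Resolves to https://api-inference.huggingface.co/pipeline/fill-mask/bert-base-uncased
inference = InferenceApi(repo_id="bert-base-uncased")
result = inference(inputs="The goal of life is [MASK].")
print(result)
```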