snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/code_path.py
@@ -0,0 +1,104 @@
+"""CodePath class for selective code packaging in model registry."""
+
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+_ERR_ROOT_NOT_FOUND = "CodePath: root '{root}' does not exist (resolved to {resolved})."
+_ERR_WILDCARDS_NOT_SUPPORTED = "CodePath: Wildcards are not supported in filter. Got '{filter}'. Use exact paths only."
+_ERR_FILTER_MUST_BE_RELATIVE = "CodePath: filter must be a relative path, got absolute path '{filter}'."
+_ERR_FILTER_HOME_PATH = "CodePath: filter must be a relative path, got home directory path '{filter}'."
+_ERR_FILTER_ON_FILE_ROOT = (
+    "CodePath: cannot apply filter to a file root. " "Root '{root}' is a file. Use filter only with directory roots."
+)
+_ERR_FILTER_ESCAPES_ROOT = (
+    "CodePath: filter '{filter}' escapes root directory '{root}'. " "Relative paths must stay within root."
+)
+_ERR_FILTER_NOT_FOUND = "CodePath: filter '{filter}' under root '{root}' does not exist (resolved to {resolved})."
+
+
+@dataclass(frozen=True)
+class CodePath:
+    """Specifies a code path with optional filtering for selective inclusion.
+
+    Args:
+        root: The root directory or file path (absolute or relative to cwd).
+        filter: Optional relative path under root to select a subdirectory or file.
+            The filter also determines the destination path under code/.
+
+    Examples:
+        CodePath("project/src/")  # Copy entire src/ to code/src/
+        CodePath("project/src/", filter="utils")  # Copy utils/ to code/utils/
+        CodePath("project/src/", filter="lib/helpers")  # Copy to code/lib/helpers/
+    """
+
+    root: str
+    filter: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        if self.filter == "":
+            object.__setattr__(self, "filter", None)
+
+    def __repr__(self) -> str:
+        if self.filter:
+            return f"CodePath({self.root!r}, filter={self.filter!r})"
+        return f"CodePath({self.root!r})"
+
+    def _validate_filter(self) -> Optional[str]:
+        """Validate and normalize filter, returning normalized filter or None.
+
+        Returns:
+            Normalized filter path, or None if no filter is set.
+
+        Raises:
+            ValueError: If filter contains wildcards or is an absolute path.
+        """
+        if self.filter is None:
+            return None
+
+        if any(c in self.filter for c in ["*", "?", "[", "]"]):
+            raise ValueError(_ERR_WILDCARDS_NOT_SUPPORTED.format(filter=self.filter))
+
+        if self.filter.startswith("~"):
+            raise ValueError(_ERR_FILTER_HOME_PATH.format(filter=self.filter))
+
+        filter_normalized = os.path.normpath(self.filter)
+
+        if os.path.isabs(filter_normalized):
+            raise ValueError(_ERR_FILTER_MUST_BE_RELATIVE.format(filter=self.filter))
+
+        return filter_normalized
+
+    def _resolve(self) -> tuple[str, str]:
+        """Resolve the source path and destination path.
+
+        Returns:
+            Tuple of (source_path, destination_relative_path)
+
+        Raises:
+            FileNotFoundError: If root or filter path does not exist.
+            ValueError: If filter is invalid (wildcards, absolute, escapes root, or applied to file).
+        """
+        filter_normalized = self._validate_filter()
+        root_normalized = os.path.normpath(os.path.abspath(self.root))
+
+        if filter_normalized is None:
+            if not os.path.exists(root_normalized):
+                raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+            return root_normalized, os.path.basename(root_normalized)
+
+        if not os.path.exists(root_normalized):
+            raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+
+        if os.path.isfile(root_normalized):
+            raise ValueError(_ERR_FILTER_ON_FILE_ROOT.format(root=self.root))
+
+        source = os.path.normpath(os.path.join(root_normalized, filter_normalized))
+
+        if not (source.startswith(root_normalized + os.sep) or source == root_normalized):
+            raise ValueError(_ERR_FILTER_ESCAPES_ROOT.format(filter=self.filter, root=self.root))
+
+        if not os.path.exists(source):
+            raise FileNotFoundError(_ERR_FILTER_NOT_FOUND.format(filter=self.filter, root=self.root, resolved=source))
+
+        return source, filter_normalized
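To illustrate the new class: a minimal sketch of how CodePath maps a root/filter pair to a source path and a destination under code/. The project/src layout below is hypothetical, and _resolve is an internal helper used during packaging rather than public API.

import os
from snowflake.ml.model.code_path import CodePath

# Hypothetical layout: project/src/lib/helpers/io.py
os.makedirs("project/src/lib/helpers", exist_ok=True)
open("project/src/lib/helpers/io.py", "w").close()

source, dest = CodePath("project/src", filter="lib/helpers")._resolve()
# source -> absolute path of project/src/lib/helpers
# dest   -> "lib/helpers", i.e. packaged as code/lib/helpers/

try:
    # A filter that walks out of root is rejected.
    CodePath("project/src", filter="../secrets")._resolve()
except ValueError as err:
    print(err)  # CodePath: filter '../secrets' escapes root directory 'project/src'. ...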
snowflake/ml/model/compute_pool.py
@@ -0,0 +1,2 @@
+DEFAULT_CPU_COMPUTE_POOL = "SYSTEM_COMPUTE_POOL_CPU"
+DEFAULT_GPU_COMPUTE_POOL = "SYSTEM_COMPUTE_POOL_GPU"
snowflake/ml/model/custom_model.py
@@ -4,10 +4,13 @@ from typing import Any, Callable, Coroutine, Generator, Optional, Union
 
 import anyio
 import pandas as pd
-from typing_extensions import deprecated
+from typing_extensions import Concatenate, ParamSpec, deprecated
 
 from snowflake.ml.model import type_hints as model_types
 
+# Captures additional keyword-only parameters for inference methods
+InferenceParams = ParamSpec("InferenceParams")
+
 
 class MethodRef:
     """Represents a method invocation of an instance of `ModelRef`.
@@ -217,7 +220,7 @@ class CustomModel:
 
     def _get_infer_methods(
         self,
-    ) -> Generator[Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame], None, None]:
+    ) -> Generator[Callable[..., pd.DataFrame], None, None]:
         """Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
         for cls_method_str in dir(self):
             cls_method = getattr(self, cls_method_str)
@@ -240,7 +243,7 @@ class CustomModel:
         return rv
 
 
-def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
+def _validate_predict_function(func: Callable[..., pd.DataFrame]) -> None:
     """Validate the user provided predict method.
 
     Args:
@@ -248,19 +251,22 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.D
 
     Raises:
         TypeError: Raised when the method is not a callable object.
-        TypeError: Raised when the method does not have 2 arguments (self and X).
+        TypeError: Raised when the method does not have at least 2 arguments (self and X).
         TypeError: Raised when the method does not have typing annotation.
         TypeError: Raised when the method's input (X) does not have type pd.DataFrame.
         TypeError: Raised when the method's output does not have type pd.DataFrame.
+        TypeError: Raised when additional parameters are not keyword-only with defaults.
     """
     if not callable(func):
         raise TypeError("Predict method is not callable.")
 
     func_signature = inspect.signature(func)
-    if len(func_signature.parameters) != 2:
-        raise TypeError("Predict method should have exact 2 arguments.")
+    func_signature_params = list(func_signature.parameters.values())
+
+    if len(func_signature_params) < 2:
+        raise TypeError("Predict method should have at least 2 arguments.")
 
-    input_annotation = list(func_signature.parameters.values())[1].annotation
+    input_annotation = func_signature_params[1].annotation
     output_annotation = func_signature.return_annotation
 
     if input_annotation == inspect.Parameter.empty or output_annotation == inspect.Signature.empty:
@@ -275,17 +281,53 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.D
     ):
         raise TypeError("Output for predict method should have type pandas.DataFrame.")
 
+    # Validate additional parameters (beyond self and input) are keyword-only with defaults
+    for func_signature_param in func_signature_params[2:]:
+        _validate_parameter(func_signature_param)
+
+
+def _validate_parameter(param: inspect.Parameter) -> None:
+    """Validate a parameter."""
+    if param.kind != inspect.Parameter.KEYWORD_ONLY:
+        raise TypeError(f"Parameter '{param.name}' must be keyword-only (defined after '*' in signature).")
+    if param.default == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a default value.")
+    if param.annotation == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a type annotation.")
+
+    # Validate annotation is a supported type
+    supported_types = {int, float, str, bool}
+    if param.annotation not in supported_types:
+        raise TypeError(
+            f"Parameter '{param.name}' has unsupported type annotation '{param.annotation}'. "
+            f"Supported types are: int, float, str, bool"
+        )
+
+
+def get_method_parameters(func: Callable[..., Any]) -> list[tuple[str, Any, Any]]:
+    """Extract keyword-only parameters with defaults from an inference method.
+
+    Args:
+        func: The inference method.
+
+    Returns:
+        A list of tuples (name, type annotation, default value) for each keyword-only parameter.
+    """
+    func_signature = inspect.signature(func)
+    params = list(func_signature.parameters.values())
+    return [(param.name, param.annotation, param.default) for param in params[2:]]
+
 
 def inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     return func
 
 
 def partitioned_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
@@ -296,8 +338,8 @@ def partitioned_api(
     " Use snowflake.ml.custom_model.partitioned_api instead."
 )
 def partitioned_inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
    return func
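Together these hunks relax the old two-argument rule: an inference method may now declare extra parameters, provided each is keyword-only, annotated with int, float, str, or bool, and carries a default. A minimal sketch of a model that passes the new validation (the ScalerModel class and its data are illustrative):

import pandas as pd
from snowflake.ml.model import custom_model

class ScalerModel(custom_model.CustomModel):
    @custom_model.inference_api
    def predict(self, X: pd.DataFrame, *, scale: float = 1.0) -> pd.DataFrame:
        # `scale` is keyword-only, annotated with a supported type, and has a
        # default, so _validate_predict_function/_validate_parameter accept it.
        return X * scale

# get_method_parameters reports everything after (self, X):
print(custom_model.get_method_parameters(ScalerModel.predict))
# [('scale', <class 'float'>, 1.0)]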
snowflake/ml/model/model_signature.py
@@ -34,6 +34,9 @@ DataType = core.DataType
 BaseFeatureSpec = core.BaseFeatureSpec
 FeatureSpec = core.FeatureSpec
 FeatureGroupSpec = core.FeatureGroupSpec
+BaseParamSpec = core.BaseParamSpec
+ParamSpec = core.ParamSpec
+ParamGroupSpec = core.ParamGroupSpec
 ModelSignature = core.ModelSignature
 
 
@@ -711,6 +714,7 @@ def infer_signature(
     output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,
     output_data_limit: Optional[int] = 100,
+    params: Optional[Sequence[core.BaseParamSpec]] = None,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -740,12 +744,20 @@ def infer_signature(
         output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
             100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
             are used.
+        params: Optional sequence of parameter specifications to include in the signature. Parameters define
+            optional configuration values that can be passed to model inference. Defaults to None.
+
+    Raises:
+        SnowflakeMLException: ValueError: Raised when input data contains columns matching parameter names.
 
     Returns:
         A model signature inferred from the given input and output sample data.
+
+    # noqa: DAR402
     """
     inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
+
     outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
-    return core.ModelSignature(inputs, outputs)
+    return core.ModelSignature(inputs, outputs, params=params)
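For illustration, a sketch of passing the new params argument. The ParamSpec constructor is defined in snowflake/ml/model/_signatures/core.py, whose +390-line change is not shown in this diff, so the name/dtype/default keywords below are assumptions rather than confirmed API:

import pandas as pd
from snowflake.ml.model import model_signature

input_df = pd.DataFrame({"text": ["a", "b"]})
output_df = pd.DataFrame({"label": [0, 1]})

# Assumed ParamSpec arguments -- verify against _signatures/core.py.
temperature = model_signature.ParamSpec(name="temperature", dtype=model_signature.DataType.FLOAT, default=1.0)

sig = model_signature.infer_signature(
    input_data=input_df,
    output_data=output_df,
    params=[temperature],  # raises if an input column name collides with a parameter name
)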
snowflake/ml/model/models/huggingface.py
@@ -0,0 +1,285 @@
+import logging
+import warnings
+from typing import Any, Optional, Union
+
+from packaging import version
+
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model.compute_pool import DEFAULT_CPU_COMPUTE_POOL
+
+logger = logging.getLogger(__name__)
+
+
+_TELEMETRY_PROJECT = "MLOps"
+_TELEMETRY_SUBPROJECT = "ModelManagement"
+
+
+class TransformersPipeline:
+    def __init__(
+        self,
+        task: Optional[str] = None,
+        model: Optional[str] = None,
+        *,
+        revision: Optional[str] = None,
+        token_or_secret: Optional[str] = None,
+        trust_remote_code: Optional[bool] = None,
+        model_kwargs: Optional[dict[str, Any]] = None,
+        compute_pool_for_log: Optional[str] = DEFAULT_CPU_COMPUTE_POOL,
+        # repo snapshot download args
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Utility factory method to build a wrapper over transformers [`Pipeline`].
+        When deployed, this wrapper creates the real pipeline object and loads the tokenizers and models.
+
+        For the pipelines docs, please refer to:
+        https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline
+
+        Args:
+            task: The task the pipeline will be used for. If None, it is inferred from the model.
+                For available tasks, please refer to the Transformers documentation. Defaults to None.
+            model: The model that will be used by the pipeline to make predictions. This can only be a model
+                identifier currently. If not provided, the default for the `task` will be loaded. Defaults to None.
+            revision: When passing a task name or a string model identifier: the specific model version to use. It can
+                be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and
+                other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. Defaults to
+                None.
+            token_or_secret: The token to use as HTTP bearer authorization for remote files. Defaults to None.
+                The value can be a raw token or a secret. If a secret is provided, it must be a fully qualified
+                secret name.
+            trust_remote_code: Whether or not to allow custom code defined on the Hub in its own modeling,
+                configuration, tokenization or even pipeline files. This option should only be set to `True` for
+                repositories you trust and in which you have read the code, as it will execute code present on the
+                Hub. Defaults to None.
+            model_kwargs: Additional dictionary of keyword arguments passed along to the model's
+                `from_pretrained(...)` call. Defaults to None.
+            compute_pool_for_log: The compute pool to use for logging the model. Defaults to
+                DEFAULT_CPU_COMPUTE_POOL. If a string is provided, it is used as the compute pool name; this override
+                allows logging the model when no system compute pool is available. If None is passed and
+                `huggingface_hub` is installed, the model artifacts are downloaded from the HuggingFace repository;
+                otherwise, only the metadata is logged to Snowflake.
+            allow_patterns: If provided, only files matching at least one pattern are downloaded.
+            ignore_patterns: If provided, files matching any of the patterns are not downloaded.
+            kwargs: Additional keyword arguments passed along to the specific pipeline init (see the documentation for
+                the corresponding pipeline class for possible values).
+
+        Raises:
+            RuntimeError: Raised when the input arguments cannot determine the pipeline.
+            ValueError: Raised when the pipeline contains remote code but trust_remote_code is not set or False.
+            ValueError: Raised when conflicting arguments are given.
+
+        .. # noqa: DAR003
+        """
+        import transformers
+
+        config = kwargs.get("config", None)
+        tokenizer = kwargs.get("tokenizer", None)
+        framework = kwargs.get("framework", None)
+        feature_extractor = kwargs.get("feature_extractor", None)
+
+        self.secret_identifier: Optional[str] = None
+        uses_secret = False
+        if token_or_secret is not None and isinstance(token_or_secret, str):
+            db, schema, secret_name = sql_identifier.parse_fully_qualified_name(token_or_secret)
+            if db is not None and schema is not None and secret_name is not None:
+                self.secret_identifier = sql_identifier.get_fully_qualified_name(
+                    db=db,
+                    schema=schema,
+                    object=secret_name,
+                )
+                uses_secret = True
+            else:
+                logger.info("The token_or_secret is not a fully qualified secret name. It will be used as is.")
+
+        can_download_snapshot = False
+        if compute_pool_for_log is None:
+            try:
+                import huggingface_hub as hf_hub
+
+                can_download_snapshot = True
+            except ImportError:
+                pass
+
+        if compute_pool_for_log is None and not can_download_snapshot:
+            logger.info(
+                "The model will be logged with metadata only. No model artifacts will be downloaded. "
+                "During deployment, the model artifacts will be downloaded from the HuggingFace repository."
+            )
+
+        # ==== Start pipeline logic from transformers ====
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        use_auth_token = model_kwargs.pop("use_auth_token", None)
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if token_or_secret is not None:
+                raise ValueError(
+                    "`token_or_secret` and `use_auth_token` are both specified. "
+                    "Please set only the argument `token_or_secret`."
+                )
+            token_or_secret = use_auth_token
+
+        hub_kwargs = {
+            "revision": revision,
+            "token": token_or_secret,
+            "trust_remote_code": trust_remote_code,
+            "_commit_hash": None,
+        }
+
+        # Backward compatibility since HF interface change.
+        if version.parse(transformers.__version__) < version.parse("4.32.0"):
+            # Backward compatibility since HF interface change.
+            hub_kwargs["use_auth_token"] = hub_kwargs["token"]
+            del hub_kwargs["token"]
+
+        if task is None and model is None:
+            raise RuntimeError(
+                "Impossible to instantiate a pipeline without either a task or a model being specified. "
+            )
+
+        if model is None and tokenizer is not None:
+            raise RuntimeError(
+                "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided"
+                " tokenizer may not be compatible with the default model. Please provide an identifier to a pretrained"
+                " model when providing tokenizer."
+            )
+        if model is None and feature_extractor is not None:
+            raise RuntimeError(
+                "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the "
+                "provided feature_extractor may not be compatible with the default model. Please provide an identifier"
+                " to a pretrained model when providing feature_extractor."
+            )
+
+        # ==== End pipeline logic from transformers ====
+
+        # We only support string as model argument.
+
+        if model is not None and not isinstance(model, str):
+            raise RuntimeError(f"Impossible to use non-string model as input for class {self.__class__.__name__}.")
+
+        # ==== Start pipeline logic (Config) from transformers ====
+
+        # Config is the primordial information item.
+        # Instantiate config if needed
+        config_obj = None
+
+        if not can_download_snapshot:
+            if isinstance(config, str):
+                config_obj = transformers.AutoConfig.from_pretrained(
+                    config, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+                hub_kwargs["_commit_hash"] = config_obj._commit_hash
+            elif config is None and isinstance(model, str):
+                config_obj = transformers.AutoConfig.from_pretrained(
+                    model, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+                hub_kwargs["_commit_hash"] = config_obj._commit_hash
+            # We only support string as config argument.
+            elif config is not None and not isinstance(config, str):
+                raise RuntimeError(
+                    f"Impossible to use non-string config as input for class {self.__class__.__name__}. "
+                    "Use transformers.Pipeline object if required."
+                )
+
+        # ==== Start pipeline logic (Task) from transformers ====
+
+        custom_tasks = {}
+        if config_obj is not None and len(getattr(config_obj, "custom_pipelines", {})) > 0:
+            custom_tasks = config_obj.custom_pipelines
+            if task is None and trust_remote_code is not False:
+                if len(custom_tasks) == 1:
+                    task = list(custom_tasks.keys())[0]
+                else:
+                    raise RuntimeError(
+                        "We can't infer the task automatically for this model as there are multiple tasks available. "
+                        f"Pick one in {', '.join(custom_tasks.keys())}"
+                    )
+
+        if task is None and model is not None:
+            task = transformers.pipelines.get_task(model, token_or_secret)
+
+        # Retrieve the task
+        if task in custom_tasks:
+            normalized_task = task
+            targeted_task, task_options = transformers.pipelines.clean_custom_task(custom_tasks[task])
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
+                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+                    " set the option `trust_remote_code=True` to remove this error."
+                )
+        else:
+            (
+                normalized_task,
+                targeted_task,
+                task_options,
+            ) = transformers.pipelines.check_task(task)
+
+        # ==== Start pipeline logic (Model) from transformers ====
+
+        # Use default model/config/tokenizer for the task if no model is provided
+        if model is None:
+            # At that point framework might still be undetermined
+            (
+                model,
+                default_revision,
+            ) = transformers.pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
+            revision = revision if revision is not None else default_revision
+            warnings.warn(
+                f"No model was supplied, defaulted to {model} and revision"
+                f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
+                "Using a pipeline without specifying a model name and revision in production is not recommended.",
+                stacklevel=2,
+            )
+            if not can_download_snapshot and config is None and isinstance(model, str):
+                config_obj = transformers.AutoConfig.from_pretrained(
+                    model, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+                hub_kwargs["_commit_hash"] = config_obj._commit_hash
+
+        if kwargs.get("device_map", None) is not None:
+            if "device_map" in model_kwargs:
+                raise ValueError(
+                    'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
+                    " arguments might conflict, use only one.)"
+                )
+            if kwargs.get("device", None) is not None:
+                warnings.warn(
+                    "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
+                    " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
+                    stacklevel=2,
+                )
+
+        repo_snapshot_dir: Optional[str] = None
+        if can_download_snapshot and not uses_secret:
+            try:
+                repo_snapshot_dir = hf_hub.snapshot_download(
+                    repo_id=model,
+                    revision=revision,
+                    token=token_or_secret,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            except ImportError:
+                logger.info("huggingface_hub package is not installed, skipping snapshot download")
+
+        # ==== End pipeline logic from transformers ====
+
+        self.model = model
+        self.task = normalized_task
+        self.revision = revision
+        self.token_or_secret = token_or_secret
+        self.trust_remote_code = trust_remote_code
+        self.model_kwargs = model_kwargs
+        self.tokenizer = tokenizer
+
+        self.repo_snapshot_dir = repo_snapshot_dir
+        self.compute_pool_for_log = compute_pool_for_log
+        self.__dict__.update(kwargs)
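A short usage sketch of the new wrapper. The module path follows the file list above; the model identifier is illustrative, and transformers must be installed for task resolution:

from snowflake.ml.model.models.huggingface import TransformersPipeline

# With the default compute_pool_for_log, no weights are downloaded at log time;
# artifacts are fetched from the HuggingFace repository during deployment.
pipe = TransformersPipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)

# Passing compute_pool_for_log=None with huggingface_hub installed instead snapshots
# the repository locally, honoring allow_patterns/ignore_patterns.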