snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
@@ -1,25 +1,19 @@
1
1
  import logging
2
- import warnings
3
2
  from typing import Any, Optional, Union
4
3
 
5
- from packaging import version
6
-
7
4
  from snowflake import snowpark
8
5
  from snowflake.ml._internal import telemetry
9
6
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
7
  from snowflake.ml._internal.utils import sql_identifier
11
8
  from snowflake.ml.model._client.model import inference_engine_utils
12
9
  from snowflake.ml.model._client.ops import service_ops
10
+ from snowflake.ml.model.models import huggingface
13
11
  from snowflake.snowpark import async_job, session
14
12
 
15
13
  logger = logging.getLogger(__name__)
16
14
 
17
15
 
18
- _TELEMETRY_PROJECT = "MLOps"
19
- _TELEMETRY_SUBPROJECT = "ModelManagement"
20
-
21
-
22
- class HuggingFacePipelineModel:
16
+ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
23
17
  def __init__(
24
18
  self,
25
19
  task: Optional[str] = None,
@@ -65,208 +59,25 @@ class HuggingFacePipelineModel:
65
59
 
66
60
  Return:
67
61
  A wrapper over transformers [`Pipeline`].
68
-
69
- Raises:
70
- RuntimeError: Raised when the input argument cannot determine the pipeline.
71
- ValueError: Raised when the pipeline contains remote code but trust_remote_code is not set or False.
72
- ValueError: Raised when having conflicting arguments.
73
62
  """
74
- import transformers
75
-
76
- config = kwargs.get("config", None)
77
- tokenizer = kwargs.get("tokenizer", None)
78
- framework = kwargs.get("framework", None)
79
- feature_extractor = kwargs.get("feature_extractor", None)
80
-
81
- _can_download_snapshot = False
82
- if download_snapshot:
83
- try:
84
- import huggingface_hub as hf_hub
85
-
86
- _can_download_snapshot = True
87
- except ImportError:
88
- pass
89
-
90
- # ==== Start pipeline logic from transformers ====
91
- if model_kwargs is None:
92
- model_kwargs = {}
93
-
94
- use_auth_token = model_kwargs.pop("use_auth_token", None)
95
- if use_auth_token is not None:
96
- warnings.warn(
97
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
98
- FutureWarning,
99
- stacklevel=2,
100
- )
101
- if token is not None:
102
- raise ValueError(
103
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
104
- )
105
- token = use_auth_token
106
-
107
- hub_kwargs = {
108
- "revision": revision,
109
- "token": token,
110
- "trust_remote_code": trust_remote_code,
111
- "_commit_hash": None,
112
- }
113
-
114
- # Backward compatibility since HF interface change.
115
- if version.parse(transformers.__version__) < version.parse("4.32.0"):
116
- # Backward compatibility since HF interface change.
117
- hub_kwargs["use_auth_token"] = hub_kwargs["token"]
118
- del hub_kwargs["token"]
119
-
120
- if task is None and model is None:
121
- raise RuntimeError(
122
- "Impossible to instantiate a pipeline without either a task or a model being specified. "
123
- )
124
-
125
- if model is None and tokenizer is not None:
126
- raise RuntimeError(
127
- "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided"
128
- " tokenizer may not be compatible with the default model. Please provide an identifier to a pretrained"
129
- " model when providing tokenizer."
130
- )
131
- if model is None and feature_extractor is not None:
132
- raise RuntimeError(
133
- "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the "
134
- "provided feature_extractor may not be compatible with the default model. Please provide an identifier"
135
- " to a pretrained model when providing feature_extractor."
136
- )
137
-
138
- # ==== End pipeline logic from transformers ====
139
-
140
- # We only support string as model argument.
141
-
142
- if model is not None and not isinstance(model, str):
143
- raise RuntimeError(
144
- "Impossible to use non-string model as input for HuggingFacePipelineModel. Use transformers.Pipeline"
145
- " object if required."
146
- )
147
-
148
- # ==== Start pipeline logic (Config) from transformers ====
149
-
150
- # Config is the primordial information item.
151
- # Instantiate config if needed
152
- config_obj = None
153
-
154
- if not _can_download_snapshot:
155
- if isinstance(config, str):
156
- config_obj = transformers.AutoConfig.from_pretrained(
157
- config, _from_pipeline=task, **hub_kwargs, **model_kwargs
158
- )
159
- hub_kwargs["_commit_hash"] = config_obj._commit_hash
160
- elif config is None and isinstance(model, str):
161
- config_obj = transformers.AutoConfig.from_pretrained(
162
- model, _from_pipeline=task, **hub_kwargs, **model_kwargs
163
- )
164
- hub_kwargs["_commit_hash"] = config_obj._commit_hash
165
- # We only support string as config argument.
166
- elif config is not None and not isinstance(config, str):
167
- raise RuntimeError(
168
- "Impossible to use non-string config as input for HuggingFacePipelineModel. "
169
- "Use transformers.Pipeline object if required."
170
- )
171
-
172
- # ==== Start pipeline logic (Task) from transformers ====
173
-
174
- custom_tasks = {}
175
- if config_obj is not None and len(getattr(config_obj, "custom_pipelines", {})) > 0:
176
- custom_tasks = config_obj.custom_pipelines
177
- if task is None and trust_remote_code is not False:
178
- if len(custom_tasks) == 1:
179
- task = list(custom_tasks.keys())[0]
180
- else:
181
- raise RuntimeError(
182
- "We can't infer the task automatically for this model as there are multiple tasks available. "
183
- f"Pick one in {', '.join(custom_tasks.keys())}"
184
- )
185
-
186
- if task is None and model is not None:
187
- task = transformers.pipelines.get_task(model, token)
188
-
189
- # Retrieve the task
190
- if task in custom_tasks:
191
- normalized_task = task
192
- targeted_task, task_options = transformers.pipelines.clean_custom_task(custom_tasks[task])
193
- if not trust_remote_code:
194
- raise ValueError(
195
- "Loading this pipeline requires you to execute the code in the pipeline file in that"
196
- " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
197
- " set the option `trust_remote_code=True` to remove this error."
198
- )
199
- else:
200
- (
201
- normalized_task,
202
- targeted_task,
203
- task_options,
204
- ) = transformers.pipelines.check_task(task)
205
-
206
- # ==== Start pipeline logic (Model) from transformers ====
207
-
208
- # Use default model/config/tokenizer for the task if no model is provided
209
- if model is None:
210
- # At that point framework might still be undetermined
211
- (
212
- model,
213
- default_revision,
214
- ) = transformers.pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
215
- revision = revision if revision is not None else default_revision
216
- warnings.warn(
217
- f"No model was supplied, defaulted to {model} and revision"
218
- f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
219
- "Using a pipeline without specifying a model name and revision in production is not recommended.",
220
- stacklevel=2,
221
- )
222
- if not _can_download_snapshot and config is None and isinstance(model, str):
223
- config_obj = transformers.AutoConfig.from_pretrained(
224
- model, _from_pipeline=task, **hub_kwargs, **model_kwargs
225
- )
226
- hub_kwargs["_commit_hash"] = config_obj._commit_hash
227
-
228
- if kwargs.get("device_map", None) is not None:
229
- if "device_map" in model_kwargs:
230
- raise ValueError(
231
- 'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
232
- " arguments might conflict, use only one.)"
233
- )
234
- if kwargs.get("device", None) is not None:
235
- warnings.warn(
236
- "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
237
- " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
238
- stacklevel=2,
239
- )
240
-
241
- repo_snapshot_dir: Optional[str] = None
242
- if _can_download_snapshot:
243
- try:
244
-
245
- repo_snapshot_dir = hf_hub.snapshot_download(
246
- repo_id=model,
247
- revision=revision,
248
- token=token,
249
- allow_patterns=allow_patterns,
250
- ignore_patterns=ignore_patterns,
251
- )
252
- except ImportError:
253
- logger.info("huggingface_hub package is not installed, skipping snapshot download")
254
-
255
- # ==== End pipeline logic from transformers ====
256
-
257
- self.task = normalized_task
258
- self.model = model
259
- self.revision = revision
63
+ logger.warning("HuggingFacePipelineModel is deprecated. Please use TransformersPipeline instead.")
64
+ super().__init__(
65
+ task=task,
66
+ model=model,
67
+ revision=revision,
68
+ token_or_secret=token,
69
+ trust_remote_code=trust_remote_code,
70
+ model_kwargs=model_kwargs,
71
+ compute_pool_for_log=None,
72
+ allow_patterns=allow_patterns,
73
+ ignore_patterns=ignore_patterns,
74
+ **kwargs,
75
+ )
260
76
  self.token = token
261
- self.trust_remote_code = trust_remote_code
262
- self.model_kwargs = model_kwargs
263
- self.tokenizer = tokenizer
264
- self.repo_snapshot_dir = repo_snapshot_dir
265
- self.__dict__.update(kwargs)
266
77
 
267
78
  @telemetry.send_api_usage_telemetry(
268
- project=_TELEMETRY_PROJECT,
269
- subproject=_TELEMETRY_SUBPROJECT,
79
+ project=huggingface._TELEMETRY_PROJECT,
80
+ subproject=huggingface._TELEMETRY_SUBPROJECT,
270
81
  func_params_to_log=[
271
82
  "service_name",
272
83
  "image_build_compute_pool",
@@ -345,8 +156,8 @@ class HuggingFacePipelineModel:
345
156
  .. # noqa: DAR003
346
157
  """
347
158
  statement_params = telemetry.get_statement_params(
348
- project=_TELEMETRY_PROJECT,
349
- subproject=_TELEMETRY_SUBPROJECT,
159
+ project=huggingface._TELEMETRY_PROJECT,
160
+ subproject=huggingface._TELEMETRY_SUBPROJECT,
350
161
  )
351
162
 
352
163
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
@@ -13,6 +13,11 @@ from typing import (
13
13
  import numpy.typing as npt
14
14
  from typing_extensions import NotRequired
15
15
 
16
+ from snowflake.ml.model.code_path import CodePath
17
+ from snowflake.ml.model.compute_pool import (
18
+ DEFAULT_CPU_COMPUTE_POOL,
19
+ DEFAULT_GPU_COMPUTE_POOL,
20
+ )
16
21
  from snowflake.ml.model.target_platform import TargetPlatform
17
22
  from snowflake.ml.model.task import Task
18
23
  from snowflake.ml.model.volatility import Volatility
@@ -362,6 +367,7 @@ ModelLoadOption = Union[
362
367
 
363
368
 
364
369
  SupportedTargetPlatformType = Union[TargetPlatform, str]
370
+ CodePathLike = Union[str, CodePath]
365
371
 
366
372
 
367
373
  class ProgressStatus(Protocol):
@@ -380,4 +386,4 @@ class ProgressStatus(Protocol):
380
386
  ...
381
387
 
382
388
 
383
- __all__ = ["TargetPlatform", "Task"]
389
+ __all__ = ["TargetPlatform", "Task", "DEFAULT_CPU_COMPUTE_POOL", "DEFAULT_GPU_COMPUTE_POOL"]
@@ -365,8 +365,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
365
365
 
366
366
  required_deps = dependencies + [
367
367
  "snowflake-snowpark-python<2",
368
- "fastparquet<2023.11",
369
- "pyarrow<14",
368
+ "fastparquet<2024.3",
369
+ "pyarrow<18",
370
370
  "cachetools<6",
371
371
  ]
372
372
 
@@ -92,6 +92,9 @@ class ModelMonitorSQLClient:
92
92
  baseline: Optional[sql_identifier.SqlIdentifier] = None,
93
93
  segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
94
94
  custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
95
+ timestamp_custom_metric_database: Optional[sql_identifier.SqlIdentifier] = None,
96
+ timestamp_custom_metric_schema: Optional[sql_identifier.SqlIdentifier] = None,
97
+ timestamp_custom_metric_table: Optional[sql_identifier.SqlIdentifier] = None,
95
98
  statement_params: Optional[dict[str, Any]] = None,
96
99
  ) -> None:
97
100
  baseline_sql = ""
@@ -106,6 +109,14 @@ class ModelMonitorSQLClient:
106
109
  if custom_metric_columns:
107
110
  custom_metric_columns_sql = f"CUSTOM_METRIC_COLUMNS={_build_sql_list_from_columns(custom_metric_columns)}"
108
111
 
112
+ timestamp_custom_metric_table_sql = ""
113
+ if timestamp_custom_metric_table:
114
+ timestamp_custom_metric_table_sql = (
115
+ f"TIMESTAMP_CUSTOM_METRIC_TABLE="
116
+ f"{self._infer_qualified_schema(timestamp_custom_metric_database, timestamp_custom_metric_schema)}."
117
+ f"{timestamp_custom_metric_table}"
118
+ )
119
+
109
120
  query_result_checker.SqlResultValidator(
110
121
  self._sql_client._session,
111
122
  f"""
@@ -126,6 +137,7 @@ class ModelMonitorSQLClient:
126
137
  AGGREGATION_WINDOW='{aggregation_window}'
127
138
  {segment_columns_sql}
128
139
  {custom_metric_columns_sql}
140
+ {timestamp_custom_metric_table_sql}
129
141
  {baseline_sql}""",
130
142
  statement_params=statement_params,
131
143
  ).has_column("status").has_dimensions(1, 1).validate()
@@ -100,6 +100,15 @@ class ModelMonitorManager:
100
100
  if source_config.baseline
101
101
  else (None, None, None)
102
102
  )
103
+ (
104
+ timestamp_custom_metric_database_name_id,
105
+ timestamp_custom_metric_schema_name_id,
106
+ timestamp_custom_metric_table_name_id,
107
+ ) = (
108
+ sql_identifier.parse_fully_qualified_name(source_config.timestamp_custom_metric_table)
109
+ if source_config.timestamp_custom_metric_table
110
+ else (None, None, None)
111
+ )
103
112
  model_database_name_id, model_schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(
104
113
  model_monitor_config.model_version.fully_qualified_model_name
105
114
  )
@@ -155,6 +164,9 @@ class ModelMonitorManager:
155
164
  baseline_database=baseline_database_name_id,
156
165
  baseline_schema=baseline_schema_name_id,
157
166
  baseline=baseline_name_id,
167
+ timestamp_custom_metric_database=timestamp_custom_metric_database_name_id,
168
+ timestamp_custom_metric_schema=timestamp_custom_metric_schema_name_id,
169
+ timestamp_custom_metric_table=timestamp_custom_metric_table_name_id,
158
170
  statement_params=self.statement_params,
159
171
  )
160
172
  return model_monitor.ModelMonitor._ref(
@@ -39,6 +39,11 @@ class ModelMonitorSourceConfig:
39
39
  custom_metric_columns: Optional[list[str]] = None
40
40
  """List of columns in the source containing custom metrics."""
41
41
 
42
+ timestamp_custom_metric_table: Optional[str] = None
43
+ """Optional name of a table containing timestamp-based custom metrics.
44
+ Can be specified unqualified or fully qualified as database.schema.table.
45
+ """
46
+
42
47
 
43
48
  @dataclass
44
49
  class ModelMonitorConfig:
@@ -1,8 +1,10 @@
1
+ import json
1
2
  import logging
2
3
  from types import ModuleType
3
4
  from typing import TYPE_CHECKING, Any, Optional, Union
4
5
 
5
6
  import pandas as pd
7
+ import yaml
6
8
 
7
9
  from snowflake.ml._internal import platform_capabilities, telemetry
8
10
  from snowflake.ml._internal.exceptions import error_codes, exceptions
@@ -11,8 +13,13 @@ from snowflake.ml._internal.utils import sql_identifier
11
13
  from snowflake.ml.model import model_signature, task, type_hints
12
14
  from snowflake.ml.model._client.model import model_impl, model_version_impl
13
15
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
16
+ from snowflake.ml.model._client.service import (
17
+ import_model_spec_schema,
18
+ model_deployment_spec_schema,
19
+ )
14
20
  from snowflake.ml.model._model_composer import model_composer
15
21
  from snowflake.ml.model._packager.model_meta import model_meta
22
+ from snowflake.ml.model.models import huggingface
16
23
  from snowflake.ml.registry._manager import model_parameter_reconciler
17
24
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
18
25
  from snowflake.snowpark._internal import utils as snowpark_utils
@@ -59,7 +66,7 @@ class ModelManager:
59
66
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
60
67
  sample_input_data: Optional[type_hints.SupportedDataType] = None,
61
68
  user_files: Optional[dict[str, list[str]]] = None,
62
- code_paths: Optional[list[str]] = None,
69
+ code_paths: Optional[list[type_hints.CodePathLike]] = None,
63
70
  ext_modules: Optional[list[ModuleType]] = None,
64
71
  task: type_hints.Task = task.Task.UNKNOWN,
65
72
  experiment_info: Optional["ExperimentInfo"] = None,
@@ -170,7 +177,7 @@ class ModelManager:
170
177
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
171
178
  sample_input_data: Optional[type_hints.SupportedDataType] = None,
172
179
  user_files: Optional[dict[str, list[str]]] = None,
173
- code_paths: Optional[list[str]] = None,
180
+ code_paths: Optional[list[type_hints.CodePathLike]] = None,
174
181
  ext_modules: Optional[list[ModuleType]] = None,
175
182
  task: type_hints.Task = task.Task.UNKNOWN,
176
183
  experiment_info: Optional["ExperimentInfo"] = None,
@@ -180,6 +187,31 @@ class ModelManager:
180
187
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
181
188
  version_name_id = sql_identifier.SqlIdentifier(version_name)
182
189
 
190
+ # Check if model is HuggingFace TransformersPipeline with no repo_snapshot_dir
191
+ # If so, use remote logging via SYSTEM$IMPORT_MODEL
192
+ if (
193
+ isinstance(model, huggingface.TransformersPipeline)
194
+ and model.compute_pool_for_log is not None
195
+ and (not hasattr(model, "repo_snapshot_dir") or model.repo_snapshot_dir is None)
196
+ ):
197
+ logger.info("HuggingFace model has compute_pool_for_log, using remote logging")
198
+ return self._remote_log_huggingface_model(
199
+ model=model,
200
+ model_name=model_name,
201
+ version_name=version_name,
202
+ database_name_id=database_name_id,
203
+ schema_name_id=schema_name_id,
204
+ model_name_id=model_name_id,
205
+ version_name_id=version_name_id,
206
+ comment=comment,
207
+ conda_dependencies=conda_dependencies,
208
+ pip_requirements=pip_requirements,
209
+ target_platforms=target_platforms,
210
+ options=options,
211
+ statement_params=statement_params,
212
+ progress_status=progress_status,
213
+ )
214
+
183
215
  # TODO(SNOW-2091317): Remove this when the snowpark enables file PUT operation for snowurls
184
216
  use_live_commit = (
185
217
  not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
@@ -298,19 +330,11 @@ class ModelManager:
298
330
  use_live_commit=use_live_commit,
299
331
  )
300
332
 
301
- mv = model_version_impl.ModelVersion._ref(
302
- model_ops=model_ops.ModelOperator(
303
- self._model_ops._session,
304
- database_name=database_name_id or self._database_name,
305
- schema_name=schema_name_id or self._schema_name,
306
- ),
307
- service_ops=service_ops.ServiceOperator(
308
- self._service_ops._session,
309
- database_name=database_name_id or self._database_name,
310
- schema_name=schema_name_id or self._schema_name,
311
- ),
312
- model_name=model_name_id,
313
- version_name=version_name_id,
333
+ mv = self._create_model_version_ref(
334
+ database_name_id=database_name_id,
335
+ schema_name_id=schema_name_id,
336
+ model_name_id=model_name_id,
337
+ version_name_id=version_name_id,
314
338
  )
315
339
 
316
340
  progress_status.update("setting model metadata...")
@@ -333,6 +357,73 @@ class ModelManager:
333
357
 
334
358
  return mv
335
359
 
360
+ def _remote_log_huggingface_model(
361
+ self,
362
+ model: huggingface.TransformersPipeline,
363
+ model_name: str,
364
+ version_name: str,
365
+ database_name_id: Optional[sql_identifier.SqlIdentifier],
366
+ schema_name_id: Optional[sql_identifier.SqlIdentifier],
367
+ model_name_id: sql_identifier.SqlIdentifier,
368
+ version_name_id: sql_identifier.SqlIdentifier,
369
+ comment: Optional[str],
370
+ conda_dependencies: Optional[list[str]],
371
+ pip_requirements: Optional[list[str]],
372
+ target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]],
373
+ options: Optional[type_hints.ModelSaveOption],
374
+ statement_params: Optional[dict[str, Any]],
375
+ progress_status: type_hints.ProgressStatus,
376
+ ) -> model_version_impl.ModelVersion:
377
+ """Log HuggingFace model remotely using SYSTEM$IMPORT_MODEL."""
378
+ if not isinstance(model, huggingface.TransformersPipeline):
379
+ raise ValueError(
380
+ f"Model must be a TransformersPipeline object. The provided model is a {type(model)} object"
381
+ )
382
+ progress_status.update("preparing remote model logging...")
383
+ progress_status.increment()
384
+
385
+ # Get compute pool from options or use default
386
+ compute_pool = model.compute_pool_for_log
387
+ if compute_pool is None:
388
+ raise ValueError("compute_pool_for_log is required for remote logging")
389
+
390
+ # Construct fully qualified model name
391
+ db_name = database_name_id.identifier() if database_name_id else self._database_name.identifier()
392
+ schema_name = schema_name_id.identifier() if schema_name_id else self._schema_name.identifier()
393
+ fq_model_name = f"{db_name}.{schema_name}.{model_name_id.identifier()}"
394
+
395
+ # Build YAML spec for import model
396
+ yaml_content = self._build_import_model_yaml_spec(
397
+ model=model,
398
+ fq_model_name=fq_model_name,
399
+ version_name=version_name,
400
+ compute_pool=compute_pool,
401
+ comment=comment,
402
+ conda_dependencies=conda_dependencies,
403
+ pip_requirements=pip_requirements,
404
+ target_platforms=target_platforms,
405
+ )
406
+
407
+ progress_status.update("Remotely logging the model...")
408
+ progress_status.increment()
409
+
410
+ self._model_ops.run_import_model_query(
411
+ database_name=db_name,
412
+ schema_name=schema_name,
413
+ yaml_content=yaml_content,
414
+ statement_params=statement_params,
415
+ )
416
+ progress_status.update("Remotely logged the model")
417
+ progress_status.increment()
418
+
419
+ # Return ModelVersion object
420
+ return self._create_model_version_ref(
421
+ database_name_id=database_name_id,
422
+ schema_name_id=schema_name_id,
423
+ model_name_id=model_name_id,
424
+ version_name_id=version_name_id,
425
+ )
426
+
336
427
  def get_model(
337
428
  self,
338
429
  model_name: str,
@@ -408,6 +499,130 @@ class ModelManager:
408
499
  statement_params=statement_params,
409
500
  )
410
501
 
502
+ def _create_model_version_ref(
503
+ self,
504
+ database_name_id: Optional[sql_identifier.SqlIdentifier],
505
+ schema_name_id: Optional[sql_identifier.SqlIdentifier],
506
+ model_name_id: sql_identifier.SqlIdentifier,
507
+ version_name_id: sql_identifier.SqlIdentifier,
508
+ ) -> model_version_impl.ModelVersion:
509
+ """Create a ModelVersion reference object.
510
+
511
+ Args:
512
+ database_name_id: Database name identifier, falls back to instance database if None.
513
+ schema_name_id: Schema name identifier, falls back to instance schema if None.
514
+ model_name_id: Model name identifier.
515
+ version_name_id: Version name identifier.
516
+
517
+ Returns:
518
+ ModelVersion reference object.
519
+ """
520
+ return model_version_impl.ModelVersion._ref(
521
+ model_ops=model_ops.ModelOperator(
522
+ self._model_ops._session,
523
+ database_name=database_name_id or self._database_name,
524
+ schema_name=schema_name_id or self._schema_name,
525
+ ),
526
+ service_ops=service_ops.ServiceOperator(
527
+ self._service_ops._session,
528
+ database_name=database_name_id or self._database_name,
529
+ schema_name=schema_name_id or self._schema_name,
530
+ ),
531
+ model_name=model_name_id,
532
+ version_name=version_name_id,
533
+ )
534
+
535
+ def _build_import_model_yaml_spec(
536
+ self,
537
+ model: huggingface.TransformersPipeline,
538
+ fq_model_name: str,
539
+ version_name: str,
540
+ compute_pool: str,
541
+ comment: Optional[str],
542
+ conda_dependencies: Optional[list[str]],
543
+ pip_requirements: Optional[list[str]],
544
+ target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]],
545
+ ) -> str:
546
+ """Build YAML spec for SYSTEM$IMPORT_MODEL.
547
+
548
+ Args:
549
+ model: HuggingFace TransformersPipeline model.
550
+ fq_model_name: Fully qualified model name.
551
+ version_name: Model version name.
552
+ compute_pool: Compute pool name.
553
+ comment: Optional comment for the model.
554
+ conda_dependencies: Optional conda dependencies.
555
+ pip_requirements: Optional pip requirements.
556
+ target_platforms: Optional target platforms.
557
+
558
+ Returns:
559
+ YAML string representing the import model spec.
560
+ """
561
+ # Convert target_platforms to list of strings
562
+ target_platforms_list = self._convert_target_platforms_to_list(target_platforms)
563
+
564
+ # Build HuggingFaceModel spec
565
+ hf_model = model_deployment_spec_schema.HuggingFaceModel(
566
+ hf_model_name=model.model,
567
+ task=model.task,
568
+ tokenizer=getattr(model, "tokenizer", None),
569
+ token_secret_object=model.secret_identifier,
570
+ trust_remote_code=model.trust_remote_code if model.trust_remote_code is not None else False,
571
+ revision=model.revision,
572
+ hf_model_kwargs=json.dumps(model.model_kwargs) if model.model_kwargs else "{}",
573
+ )
574
+
575
+ # Build LogModelArgs
576
+ log_model_args = model_deployment_spec_schema.LogModelArgs(
577
+ pip_requirements=pip_requirements,
578
+ conda_dependencies=conda_dependencies,
579
+ target_platforms=target_platforms_list,
580
+ comment=comment,
581
+ )
582
+
583
+ # Build ModelSpec
584
+ model_spec = import_model_spec_schema.ModelSpec(
585
+ name=import_model_spec_schema.ModelName(
586
+ model_name=fq_model_name,
587
+ version_name=version_name,
588
+ ),
589
+ hf_model=hf_model,
590
+ log_model_args=log_model_args,
591
+ )
592
+
593
+ # Build ImportModelSpec
594
+ import_spec = import_model_spec_schema.ImportModelSpec(
595
+ compute_pool=compute_pool,
596
+ models=[model_spec],
597
+ )
598
+
599
+ # Convert to YAML
600
+ return yaml.safe_dump(import_spec.model_dump(exclude_none=True))
601
+
602
+ def _convert_target_platforms_to_list(
603
+ self, target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]
604
+ ) -> Optional[list[str]]:
605
+ """Convert target_platforms to list of strings.
606
+
607
+ Args:
608
+ target_platforms: List of target platforms (enums or strings).
609
+
610
+ Returns:
611
+ List of platform strings, or None if input is None.
612
+ """
613
+ if not target_platforms:
614
+ return None
615
+
616
+ target_platforms_list = []
617
+ for tp in target_platforms:
618
+ if hasattr(tp, "value"):
619
+ # It's an enum, get the value
620
+ target_platforms_list.append(tp.value)
621
+ else:
622
+ # It's already a string
623
+ target_platforms_list.append(str(tp))
624
+ return target_platforms_list
625
+
411
626
  def _parse_fully_qualified_name(
412
627
  self, model_name: str
413
628
  ) -> tuple[