snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +118 -4
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback/lightgbm.py +55 -0
  9. snowflake/ml/experiment/callback/xgboost.py +63 -0
  10. snowflake/ml/experiment/utils.py +14 -0
  11. snowflake/ml/jobs/_utils/constants.py +15 -4
  12. snowflake/ml/jobs/_utils/payload_utils.py +159 -52
  13. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  14. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
  15. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  16. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  17. snowflake/ml/jobs/_utils/types.py +64 -4
  18. snowflake/ml/jobs/job.py +22 -6
  19. snowflake/ml/jobs/manager.py +5 -3
  20. snowflake/ml/model/_client/model/model_version_impl.py +56 -48
  21. snowflake/ml/model/_client/ops/service_ops.py +194 -14
  22. snowflake/ml/model/_client/sql/service.py +1 -38
  23. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  24. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  25. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  26. snowflake/ml/model/_signatures/utils.py +4 -0
  27. snowflake/ml/model/event_handler.py +87 -18
  28. snowflake/ml/model/model_signature.py +2 -0
  29. snowflake/ml/model/models/huggingface_pipeline.py +71 -49
  30. snowflake/ml/model/type_hints.py +26 -1
  31. snowflake/ml/registry/_manager/model_manager.py +30 -35
  32. snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
  33. snowflake/ml/registry/registry.py +0 -19
  34. snowflake/ml/version.py +1 -1
  35. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
  36. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
  37. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
  39. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/models/huggingface_pipeline.py CHANGED
@@ -299,6 +299,7 @@ class HuggingFacePipelineModel:
         Raises:
             ValueError: if database and schema name is not provided and session doesn't have a
                 database and schema name.
+            exceptions.SnowparkSQLException: if service already exists.
 
         Returns:
             The service ID or an async job object.
@@ -327,7 +328,6 @@ class HuggingFacePipelineModel:
         version_name = name_generator.generate()[1]
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
-        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
 
         service_operator = service_ops.ServiceOperator(
             session=session,
@@ -336,51 +336,73 @@ class HuggingFacePipelineModel:
         )
         logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
 
-        return service_operator.create_service(
-            database_name=database_name_id,
-            schema_name=schema_name_id,
-            model_name=model_name_id,
-            version_name=sql_identifier.SqlIdentifier(version_name),
-            service_database_name=service_db_id,
-            service_schema_name=service_schema_id,
-            service_name=service_id,
-            image_build_compute_pool_name=(
-                sql_identifier.SqlIdentifier(image_build_compute_pool)
-                if image_build_compute_pool
-                else sql_identifier.SqlIdentifier(service_compute_pool)
-            ),
-            service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-            image_repo_database_name=image_repo_db_id,
-            image_repo_schema_name=image_repo_schema_id,
-            image_repo_name=image_repo_id,
-            ingress_enabled=ingress_enabled,
-            max_instances=max_instances,
-            cpu_requests=cpu_requests,
-            memory_requests=memory_requests,
-            gpu_requests=gpu_requests,
-            num_workers=num_workers,
-            max_batch_rows=max_batch_rows,
-            force_rebuild=force_rebuild,
-            build_external_access_integrations=(
-                None
-                if build_external_access_integrations is None
-                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
-            ),
-            block=block,
-            statement_params=statement_params,
-            # hf model
-            hf_model_args=service_ops.HFModelArgs(
-                hf_model_name=self.model,
-                hf_task=self.task,
-                hf_tokenizer=self.tokenizer,
-                hf_revision=self.revision,
-                hf_token=self.token,
-                hf_trust_remote_code=bool(self.trust_remote_code),
-                hf_model_kwargs=self.model_kwargs,
-                pip_requirements=pip_requirements,
-                conda_dependencies=conda_dependencies,
-                comment=comment,
-                # TODO: remove warehouse in the next release
-                warehouse=session.get_current_warehouse(),
-            ),
-        )
+        from snowflake.ml.model import event_handler
+        from snowflake.snowpark import exceptions
+
+        hf_event_handler = event_handler.ModelEventHandler()
+        with hf_event_handler.status("Creating HuggingFace model service", total=6, block=block) as status:
+            try:
+                result = service_operator.create_service(
+                    database_name=database_name_id,
+                    schema_name=schema_name_id,
+                    model_name=model_name_id,
+                    version_name=sql_identifier.SqlIdentifier(version_name),
+                    service_database_name=service_db_id,
+                    service_schema_name=service_schema_id,
+                    service_name=service_id,
+                    image_build_compute_pool_name=(
+                        sql_identifier.SqlIdentifier(image_build_compute_pool)
+                        if image_build_compute_pool
+                        else sql_identifier.SqlIdentifier(service_compute_pool)
+                    ),
+                    service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+                    image_repo=image_repo,
+                    ingress_enabled=ingress_enabled,
+                    max_instances=max_instances,
+                    cpu_requests=cpu_requests,
+                    memory_requests=memory_requests,
+                    gpu_requests=gpu_requests,
+                    num_workers=num_workers,
+                    max_batch_rows=max_batch_rows,
+                    force_rebuild=force_rebuild,
+                    build_external_access_integrations=(
+                        None
+                        if build_external_access_integrations is None
+                        else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+                    ),
+                    block=block,
+                    progress_status=status,
+                    statement_params=statement_params,
+                    # hf model
+                    hf_model_args=service_ops.HFModelArgs(
+                        hf_model_name=self.model,
+                        hf_task=self.task,
+                        hf_tokenizer=self.tokenizer,
+                        hf_revision=self.revision,
+                        hf_token=self.token,
+                        hf_trust_remote_code=bool(self.trust_remote_code),
+                        hf_model_kwargs=self.model_kwargs,
+                        pip_requirements=pip_requirements,
+                        conda_dependencies=conda_dependencies,
+                        comment=comment,
+                        # TODO: remove warehouse in the next release
+                        warehouse=session.get_current_warehouse(),
+                    ),
+                )
+                status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
+                return result
+            except exceptions.SnowparkSQLException as e:
+                # Check if the error is because the service already exists
+                if "already exists" in str(e).lower() or "100132" in str(
+                    e
+                ): # 100132 is Snowflake error code for object already exists
+                    # Update progress to show service already exists (preserve exception behavior)
+                    status.update("service already exists")
+                    status.complete()  # Complete progress to full state
+                    status.update(label="Service already exists", state="error", expanded=False)
+                    # Re-raise the exception to preserve existing API behavior
+                    raise
+                else:
+                    # Re-raise other SQL exceptions
+                    status.update(label="Service creation failed", state="error", expanded=False)
+                    raise
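
Note: the new error path above classifies a SnowparkSQLException as "service already exists" by matching the message text or the Snowflake error code 100132, and then re-raises it unchanged. Below is a minimal, self-contained sketch of that check under the same assumptions; the helper name and the sample messages are illustrative, not taken from the package.

# Sketch of the "already exists" classification used above; names and messages are hypothetical.
def is_already_exists_error(exc: Exception) -> bool:
    text = str(exc)
    # 100132 is the error code the diff treats as "object already exists".
    return "already exists" in text.lower() or "100132" in text


if __name__ == "__main__":
    assert is_already_exists_error(RuntimeError("Service MY_SERVICE already exists."))
    assert not is_already_exists_error(RuntimeError("Compute pool MY_POOL is suspended."))

Callers can apply the same check to a caught exceptions.SnowparkSQLException to treat a re-deploy of an existing service as non-fatal, since the exception itself is still raised.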
snowflake/ml/model/type_hints.py CHANGED
@@ -1,5 +1,14 @@
 # mypy: disable-error-code="import"
-from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    Protocol,
+    Sequence,
+    TypedDict,
+    TypeVar,
+    Union,
+)
 
 import numpy.typing as npt
 from typing_extensions import NotRequired
@@ -326,4 +335,20 @@ ModelLoadOption = Union[
 SupportedTargetPlatformType = Union[TargetPlatform, str]
 
 
+class ProgressStatus(Protocol):
+    """Protocol for tracking progress during long-running operations."""
+
+    def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
+        """Update the progress status with a new message."""
+        ...
+
+    def increment(self) -> None:
+        """Increment the progress by one step."""
+        ...
+
+    def complete(self) -> None:
+        """Complete the progress bar to full state."""
+        ...
+
+
 __all__ = ["TargetPlatform", "Task"]
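
Note: because ProgressStatus is a typing.Protocol, any object with structurally matching update/increment/complete methods satisfies it; no snowflake.ml base class is involved. Below is a hedged sketch of a tqdm-backed implementation, assuming tqdm is installed; the class name and wiring are illustrative only.

from typing import Any

from tqdm import tqdm


class TqdmProgressStatus:
    """Adapts a tqdm bar to the ProgressStatus protocol shown above (illustrative only)."""

    def __init__(self, label: str, total: int) -> None:
        self._bar = tqdm(total=total, desc=label)

    def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
        self._bar.set_description(message)  # show the latest stage message next to the bar

    def increment(self) -> None:
        self._bar.update(1)  # advance by one step

    def complete(self) -> None:
        self._bar.n = self._bar.total  # jump to 100%, then close
        self._bar.refresh()
        self._bar.close()

An object shaped like this can be passed wherever type_hints.ProgressStatus is expected, for example the progress_status parameter that ModelManager.log_model now requires (next file).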
snowflake/ml/registry/_manager/model_manager.py CHANGED
@@ -14,6 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._packager.model_meta import model_meta
+from snowflake.ml.registry._manager import model_parameter_reconciler
 from snowflake.snowpark import exceptions as snowpark_exceptions, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -46,6 +47,7 @@ class ModelManager:
         *,
         model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
         model_name: str,
+        progress_status: type_hints.ProgressStatus,
         version_name: Optional[str] = None,
         comment: Optional[str] = None,
         metrics: Optional[dict[str, Any]] = None,
@@ -64,7 +66,6 @@ class ModelManager:
         experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
-        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:
 
         database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -158,6 +159,7 @@ class ModelManager:
         *,
         model_name: str,
         version_name: str,
+        progress_status: type_hints.ProgressStatus,
         comment: Optional[str] = None,
         metrics: Optional[dict[str, Any]] = None,
         conda_dependencies: Optional[list[str]] = None,
@@ -175,7 +177,6 @@ class ModelManager:
         experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
-        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -250,27 +251,27 @@ class ModelManager:
         )
         platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
 
-        if artifact_repository_map:
-            for channel, artifact_repository_name in artifact_repository_map.items():
-                db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
+        reconciler = model_parameter_reconciler.ModelParameterReconciler(
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+            conda_dependencies=conda_dependencies,
+            pip_requirements=pip_requirements,
+            target_platforms=target_platforms,
+            artifact_repository_map=artifact_repository_map,
+            options=options,
+        )
 
-                artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
-                    db_id,
-                    schema_id,
-                    repo_id,
-                    self._database_name,
-                    self._schema_name,
-                )
+        model_params = reconciler.reconcile()
+
+        # Use reconciled parameters
+        artifact_repository_map = model_params.artifact_repository_map
+        save_location = model_params.save_location
 
         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
-        if progress_status:
-            progress_status.update("packaging model...")
-            progress_status.increment()
-
-        # Extract save_location from options if present
-        save_location = None
-        if options and "save_location" in options:
-            save_location = options.get("save_location")
+        progress_status.update("packaging model...")
+        progress_status.increment()
+
+        if save_location:
             logger.info(f"Model will be saved to local directory: {save_location}")
 
         mc = model_composer.ModelComposer(
@@ -280,9 +281,8 @@
             save_location=save_location,
         )
 
-        if progress_status:
-            progress_status.update("creating model manifest...")
-            progress_status.increment()
+        progress_status.update("creating model manifest...")
+        progress_status.increment()
 
         model_metadata: model_meta.ModelMetadata = mc.save(
             name=model_name_id.resolved(),
@@ -303,9 +303,8 @@
             experiment_info=experiment_info,
         )
 
-        if progress_status:
-            progress_status.update("uploading model files...")
-            progress_status.increment()
+        progress_status.update("uploading model files...")
+        progress_status.increment()
         statement_params = telemetry.add_statement_params_custom_tags(
             statement_params, model_metadata.telemetry_metadata()
         )
@@ -313,10 +312,8 @@
             statement_params, {"model_version_name": version_name_id}
        )
 
-        logger.info("Start creating MODEL object for you in the Snowflake.")
-        if progress_status:
-            progress_status.update("creating model object in Snowflake...")
-            progress_status.increment()
+        progress_status.update("creating model object in Snowflake...")
+        progress_status.increment()
 
         self._model_ops.create_from_stage(
             composed_model=mc,
@@ -343,9 +340,8 @@
             version_name=version_name_id,
         )
 
-        if progress_status:
-            progress_status.update("setting model metadata...")
-            progress_status.increment()
+        progress_status.update("setting model metadata...")
+        progress_status.increment()
 
         if comment:
             mv.comment = comment
@@ -360,8 +356,7 @@
             statement_params=statement_params,
         )
 
-        if progress_status:
-            progress_status.update("model logged successfully!")
+        progress_status.update("model logged successfully!")
 
         return mv
 
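
Note: progress_status is now a required argument, and the hunk above drives it through six fixed stages (packaging, manifest, upload, model object creation, metadata, success), matching the total=6 used by the registry's event handler. A recording stand-in like the sketch below, with hypothetical names, is enough to exercise this path without any UI.

from typing import Any


class RecordingProgressStatus:
    """Collects stage messages instead of rendering them (illustrative test double)."""

    def __init__(self) -> None:
        self.messages: list[str] = []
        self.steps = 0

    def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
        self.messages.append(message)  # keep the stage messages in order

    def increment(self) -> None:
        self.steps += 1

    def complete(self) -> None:
        self.steps = max(self.steps, 6)  # assumes the six-step flow shown above


if __name__ == "__main__":
    status = RecordingProgressStatus()
    status.update("packaging model...")
    status.increment()
    assert status.messages == ["packaging model..."] and status.steps == 1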
snowflake/ml/registry/_manager/model_parameter_reconciler.py ADDED
@@ -0,0 +1,105 @@
+import warnings
+from dataclasses import dataclass
+from typing import Optional
+
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model import type_hints as model_types
+
+
+@dataclass
+class ReconciledParameters:
+    """Holds the reconciled and validated parameters after processing."""
+
+    conda_dependencies: Optional[list[str]] = None
+    pip_requirements: Optional[list[str]] = None
+    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None
+    artifact_repository_map: Optional[dict[str, str]] = None
+    options: Optional[model_types.ModelSaveOption] = None
+    save_location: Optional[str] = None
+
+
+class ModelParameterReconciler:
+    """Centralizes all complex log_model parameter validation, transformation, and reconciliation logic."""
+
+    def __init__(
+        self,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+        conda_dependencies: Optional[list[str]] = None,
+        pip_requirements: Optional[list[str]] = None,
+        target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
+        artifact_repository_map: Optional[dict[str, str]] = None,
+        options: Optional[model_types.ModelSaveOption] = None,
+    ) -> None:
+        self._database_name = database_name
+        self._schema_name = schema_name
+        self._conda_dependencies = conda_dependencies
+        self._pip_requirements = pip_requirements
+        self._target_platforms = target_platforms
+        self._artifact_repository_map = artifact_repository_map
+        self._options = options
+
+    def reconcile(self) -> ReconciledParameters:
+        """Perform all parameter reconciliation and return clean parameters."""
+        reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
+        reconciled_save_location = self._extract_save_location()
+
+        self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)
+
+        return ReconciledParameters(
+            conda_dependencies=self._conda_dependencies,
+            pip_requirements=self._pip_requirements,
+            target_platforms=self._target_platforms,
+            artifact_repository_map=reconciled_artifact_repository_map,
+            options=self._options,
+            save_location=reconciled_save_location,
+        )
+
+    def _reconcile_artifact_repository_map(self) -> Optional[dict[str, str]]:
+        """Transform artifact_repository_map to use fully qualified names."""
+        if not self._artifact_repository_map:
+            return None
+
+        transformed_map = {}
+
+        for channel, artifact_repository_name in self._artifact_repository_map.items():
+            db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
+
+            transformed_map[channel] = sql_identifier.get_fully_qualified_name(
+                db_id,
+                schema_id,
+                repo_id,
+                self._database_name,
+                self._schema_name,
+            )
+
+        return transformed_map
+
+    def _extract_save_location(self) -> Optional[str]:
+        """Extract save_location from options."""
+        if self._options and "save_location" in self._options:
+            return self._options.get("save_location")
+
+        return None
+
+    def _validate_pip_requirements_warehouse_compatibility(
+        self, artifact_repository_map: Optional[dict[str, str]]
+    ) -> None:
+        """Validate pip_requirements compatibility with warehouse deployment."""
+        if self._pip_requirements and not artifact_repository_map and self._targets_warehouse(self._target_platforms):
+            warnings.warn(
+                "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
+                "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
+                "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
+                category=UserWarning,
+                stacklevel=1,
+            )
+
+    @staticmethod
+    def _targets_warehouse(target_platforms: Optional[list[model_types.SupportedTargetPlatformType]]) -> bool:
+        """Returns True if warehouse is a target platform (None defaults to True)."""
+        return (
+            target_platforms is None
+            or model_types.TargetPlatform.WAREHOUSE in target_platforms
+            or "WAREHOUSE" in target_platforms
+        )
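
Note: a minimal usage sketch of the new reconciler follows, assuming snowflake-ml-python 1.10.0 is installed; no Snowflake connection is needed because reconciliation only rewrites identifiers and reads options. All identifier and requirement values are made up for illustration.

from snowflake.ml._internal.utils import sql_identifier
from snowflake.ml.registry._manager import model_parameter_reconciler

reconciler = model_parameter_reconciler.ModelParameterReconciler(
    database_name=sql_identifier.SqlIdentifier("MY_DB"),
    schema_name=sql_identifier.SqlIdentifier("MY_SCHEMA"),
    pip_requirements=["xgboost==2.1.1"],
    # Partially qualified repository name; the reconciler fills in the default database/schema.
    artifact_repository_map={"pip": "PYPI_REPO"},
    options={"save_location": "/tmp/my_model"},
)

params = reconciler.reconcile()
print(params.artifact_repository_map)  # now fully qualified, e.g. MY_DB.MY_SCHEMA.PYPI_REPO
print(params.save_location)            # /tmp/my_model, extracted from options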
snowflake/ml/registry/registry.py CHANGED
@@ -1,4 +1,3 @@
-import warnings
 from types import ModuleType
 from typing import Any, Optional, Union, overload
 
@@ -442,15 +441,6 @@ class Registry:
         if task is not type_hints.Task.UNKNOWN:
             raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
 
-        if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
-            warnings.warn(
-                "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
-                "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
-                "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
-                category=UserWarning,
-                stacklevel=1,
-            )
-
         registry_event_handler = event_handler.ModelEventHandler()
         with registry_event_handler.status("Logging model", total=6) as status:
             # Step 1: Validation and setup
@@ -662,12 +652,3 @@ class Registry:
         if not self.enable_monitoring:
             raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
         self._model_monitor_manager.delete_monitor(name)
-
-    @staticmethod
-    def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
-        """Returns True if warehouse is a target platform (None defaults to True)."""
-        return (
-            target_platforms is None
-            or type_hints.TargetPlatform.WAREHOUSE in target_platforms
-            or "WAREHOUSE" in target_platforms
-        )
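
Note: the pip_requirements/warehouse warning and the _targets_warehouse helper removed here now live in ModelParameterReconciler (previous file), so Registry.log_model keeps the same caller-visible behavior. The hedged sketch below, with made-up values, shows the condition under which the relocated warning still fires.

import warnings

from snowflake.ml._internal.utils import sql_identifier
from snowflake.ml.registry._manager import model_parameter_reconciler

reconciler = model_parameter_reconciler.ModelParameterReconciler(
    database_name=sql_identifier.SqlIdentifier("MY_DB"),
    schema_name=sql_identifier.SqlIdentifier("MY_SCHEMA"),
    pip_requirements=["scikit-learn"],
    target_platforms=None,  # None counts as targeting the warehouse
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    reconciler.reconcile()  # no artifact_repository_map, so a UserWarning is emitted

assert any("artifact_repository_map" in str(w.message) for w in caught)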
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.9.1"
+VERSION = "1.10.0"