snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
1
+ from typing import Optional
2
+
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
5
+ from snowflake.ml.utils import sql_client
6
+ from snowflake.snowpark import row, session
7
+
8
+
9
class ExperimentTrackingSQLClient(_base._BaseSQLClient):
    """Snowflake SQL client that issues the EXPERIMENT / RUN statements for experiment tracking.

    Every mutating statement is checked to return exactly one row with one column.
    """

    # Keys used to read rows returned by `show_runs_in_experiment`.
    RUN_NAME_COL_NAME = "name"
    RUN_METADATA_COL_NAME = "metadata"

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Snowflake SQL Client to manage experiment tracking.

        Args:
            session: Active snowpark session.
            database_name: Name of the Database where experiment tracking resources are provisioned.
            schema_name: Name of the Schema where experiment tracking resources are provisioned.
        """
        super().__init__(session, database_name=database_name, schema_name=schema_name)

    def create_experiment(
        self,
        experiment_name: sql_identifier.SqlIdentifier,
        creation_mode: sql_client.CreationMode,
    ) -> None:
        """Create an experiment in the client's database and schema.

        Args:
            experiment_name: Name of the experiment to create.
            creation_mode: When ``creation_mode.if_not_exists`` is truthy, ``IF NOT EXISTS`` is added
                to the statement.
        """
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        if_not_exists_sql = "IF NOT EXISTS" if creation_mode.if_not_exists else ""
        query_result_checker.SqlResultValidator(
            self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
        """Drop the named experiment.

        Args:
            experiment_name: Name of the experiment to drop.
        """
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
            expected_rows=1, expected_cols=1
        ).validate()

    def add_run(
        self,
        *,
        experiment_name: sql_identifier.SqlIdentifier,
        run_name: sql_identifier.SqlIdentifier,
        live: bool = True,
    ) -> None:
        """Add a run to an experiment.

        Args:
            experiment_name: Name of the experiment that receives the run.
            run_name: Name of the run to add.
            live: When True the ``LIVE`` keyword is included in the statement.
        """
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        query_result_checker.SqlResultValidator(
            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def commit_run(
        self,
        *,
        experiment_name: sql_identifier.SqlIdentifier,
        run_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Commit a run of an experiment.

        Args:
            experiment_name: Name of the experiment that owns the run.
            run_name: Name of the run to commit.
        """
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        query_result_checker.SqlResultValidator(
            self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def drop_run(
        self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
    ) -> None:
        """Drop a run from an experiment.

        Args:
            experiment_name: Name of the experiment that owns the run.
            run_name: Name of the run to drop.
        """
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        query_result_checker.SqlResultValidator(
            self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def modify_run(
        self,
        *,
        experiment_name: sql_identifier.SqlIdentifier,
        run_name: sql_identifier.SqlIdentifier,
        run_metadata: str,
    ) -> None:
        """Replace the metadata stored on a run.

        Args:
            experiment_name: Name of the experiment that owns the run.
            run_name: Name of the run to modify.
            run_metadata: Metadata text, embedded verbatim in a dollar-quoted SQL string constant.

        Raises:
            ValueError: If ``run_metadata`` contains ``$$``, which would terminate the dollar-quoted
                constant early and let the remainder be parsed as SQL.
        """
        # A dollar-quoted constant has no escape mechanism, so the delimiter itself must be rejected.
        if "$$" in run_metadata:
            raise ValueError(
                "run_metadata must not contain '$$' because it is embedded in a dollar-quoted SQL string."
            )
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        query_result_checker.SqlResultValidator(
            self._session,
            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def show_runs_in_experiment(
        self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
    ) -> list[row.Row]:
        """List the runs of an experiment.

        Args:
            experiment_name: Name of the experiment to inspect.
            like: Optional pattern for the ``LIKE`` clause; omitted from the statement when falsy.

        Returns:
            The rows produced by the ``SHOW RUNS`` statement.
        """
        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
        if like:
            # Double embedded single quotes so the pattern cannot close the string literal early.
            escaped_like = like.replace("'", "''")
            like_clause = f"LIKE '{escaped_like}'"
        else:
            like_clause = ""
        return query_result_checker.SqlResultValidator(
            self._session, f"SHOW RUNS {like_clause} IN EXPERIMENT {experiment_fqn}"
        ).validate()
@@ -0,0 +1,4 @@
1
+ from snowflake.ml.experiment._entities.experiment import Experiment
2
+ from snowflake.ml.experiment._entities.run import Run
3
+
4
+ __all__ = ["Experiment", "Run"]
@@ -0,0 +1,10 @@
1
+ from snowflake.ml._internal.utils import sql_identifier
2
+
3
+
4
class Experiment:
    """Lightweight handle for an experiment, identified only by its name."""

    def __init__(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
        """Store the experiment's identifier.

        Args:
            experiment_name: Identifier of the experiment; exposed unchanged as ``name``.
        """
        self.name = experiment_name
@@ -0,0 +1,62 @@
1
+ import json
2
+ import types
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ from snowflake.ml._internal.utils import sql_identifier
6
+ from snowflake.ml.experiment import _experiment_info as experiment_info
7
+ from snowflake.ml.experiment._client import experiment_tracking_sql_client
8
+ from snowflake.ml.experiment._entities import run_metadata
9
+
10
+ if TYPE_CHECKING:
11
+ from snowflake.ml.experiment import experiment_tracking
12
+
13
+
14
class Run:
    """A single run of an experiment, usable as a context manager.

    While the context is active, the run's ``ExperimentInfoPatcher`` is active as well.
    """

    def __init__(
        self,
        experiment_tracking: "experiment_tracking.ExperimentTracking",
        *,
        experiment_name: sql_identifier.SqlIdentifier,
        run_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Bind the run to its tracker and prepare the patcher carrying its experiment info.

        Args:
            experiment_tracking: Tracker that owns this run.
            experiment_name: Identifier of the experiment the run belongs to.
            run_name: Identifier of the run.
        """
        self._experiment_tracking = experiment_tracking
        self.experiment_name = experiment_name
        self.name = run_name

        # Built last: _get_experiment_info() reads the three attributes assigned above.
        info = self._get_experiment_info()
        self._patcher = experiment_info.ExperimentInfoPatcher(experiment_info=info)

    def __enter__(self) -> "Run":
        """Activate the patcher and return this run."""
        self._patcher.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> None:
        """Deactivate the patcher, then end the run if the tracker still considers it current."""
        self._patcher.__exit__(exc_type, exc_value, traceback)
        tracker = self._experiment_tracking
        if tracker._run is self:
            tracker.end_run()

    def _get_metadata(
        self,
    ) -> run_metadata.RunMetadata:
        """Fetch this run's metadata through the tracker's SQL client and parse it.

        Returns:
            The parsed metadata of the first row returned for this run's name.

        Raises:
            RuntimeError: If no row is returned.
        """
        # NOTE(review): the name is passed as a LIKE pattern, so wildcard characters in it could
        # match other runs and the first row may not be this exact run -- confirm.
        matching_rows = self._experiment_tracking._sql_client.show_runs_in_experiment(
            experiment_name=self.experiment_name, like=str(self.name)
        )
        if not matching_rows:
            raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
        metadata_column = experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME
        raw_metadata = matching_rows[0][metadata_column]
        return run_metadata.RunMetadata.from_dict(json.loads(raw_metadata))

    def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
        """Build the ExperimentInfo (fully qualified experiment name plus run name) for this run."""
        tracker = self._experiment_tracking
        experiment_fqn = tracker._sql_client.fully_qualified_object_name(
            tracker._database_name, tracker._schema_name, self.experiment_name
        )
        return experiment_info.ExperimentInfo(
            fully_qualified_name=experiment_fqn,
            run_name=self.name.identifier(),
        )
@@ -0,0 +1,68 @@
1
+ import dataclasses
2
+ import enum
3
+ import typing
4
+
5
+
6
class RunStatus(str, enum.Enum):
    """Status values stored in a run's metadata."""

    UNKNOWN = "UNKNOWN"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"


@dataclasses.dataclass
class Metric:
    """One metric observation: a named value recorded at a step."""

    name: str
    value: float
    step: int


@dataclasses.dataclass
class Param:
    """One named parameter; the value is kept as a string."""

    name: str
    value: str


@dataclasses.dataclass
class RunMetadata:
    """Status, metrics and parameters of a run, convertible to and from a plain dict."""

    status: RunStatus
    metrics: list[Metric]
    parameters: list[Param]

    @classmethod
    def from_dict(
        cls,
        metadata: dict,  # type: ignore[type-arg]
    ) -> "RunMetadata":
        """Build an instance from a metadata mapping.

        Keys that are missing or explicitly ``None`` fall back to ``UNKNOWN`` status and empty lists.

        Args:
            metadata: Mapping with optional ``status``, ``metrics`` and ``parameters`` entries.

        Returns:
            An instance of the class this method is called on.
        """
        # `or` (rather than a .get default) also covers keys present with a null value.
        return cls(
            status=RunStatus(metadata.get("status") or RunStatus.UNKNOWN.value),
            metrics=[Metric(**m) for m in metadata.get("metrics") or []],
            parameters=[Param(**p) for p in metadata.get("parameters") or []],
        )

    def to_dict(self) -> dict:  # type: ignore[type-arg]
        """Return the metadata as nested plain dicts and lists."""
        return dataclasses.asdict(self)

    def set_metric(
        self,
        key: str,
        value: float,
        step: int,
    ) -> None:
        """Record a metric value, overwriting the entry with the same name and step if one exists.

        Args:
            key: Metric name.
            value: Metric value.
            step: Step the value belongs to.
        """
        for metric in self.metrics:
            if metric.name == key and metric.step == step:
                metric.value = value
                break
        else:
            # No entry for this (name, step) pair yet.
            self.metrics.append(Metric(name=key, value=value, step=step))

    def set_param(
        self,
        key: str,
        value: typing.Any,
    ) -> None:
        """Record a parameter, overwriting the entry with the same name if one exists.

        Args:
            key: Parameter name.
            value: Parameter value; stored as ``str(value)``.
        """
        for parameter in self.parameters:
            if parameter.name == key:
                parameter.value = str(value)
                break
        else:
            # No entry with this name yet.
            self.parameters.append(Param(name=key, value=str(value)))
@@ -0,0 +1,63 @@
1
+ import dataclasses
2
+ import functools
3
+ import types
4
+ from typing import Callable, Optional
5
+
6
+ from snowflake.ml import model
7
+ from snowflake.ml.registry._manager import model_manager
8
+
9
+
10
@dataclasses.dataclass(frozen=True)
class ExperimentInfo:
    """Serializable information identifying an Experiment and one of its runs."""

    # Fully qualified name of the experiment.
    fully_qualified_name: str
    # Name of the run within that experiment.
    run_name: str
16
+
17
+
18
class ExperimentInfoPatcher:
    """Context manager that patches ModelManager.log_model to include experiment information.

    This class maintains a stack of active experiment contexts and ensures that
    log_model calls are automatically tagged with the appropriate experiment info.
    The stack and the saved original method are class attributes shared by all instances.
    """

    # Store original method at class definition time to avoid recursive patching
    _original_log_model: Callable[..., model.ModelVersion] = model_manager.ModelManager.log_model

    # Stack of active experiment_info contexts for nested experiment support
    _experiment_info_stack: list[ExperimentInfo] = []

    def __init__(self, experiment_info: ExperimentInfo) -> None:
        """Remember the experiment info that is pushed while this context is active.

        Args:
            experiment_info: Info passed to ``log_model`` while this context is the innermost one.
        """
        self._experiment_info = experiment_info

    def __enter__(self) -> "ExperimentInfoPatcher":
        """Push this context's info, installing the ``log_model`` wrapper if no context is active yet."""
        stack = ExperimentInfoPatcher._experiment_info_stack

        # Only patch ModelManager.log_model if we're the first patcher to avoid nested patching
        if not stack:

            @functools.wraps(ExperimentInfoPatcher._original_log_model)
            def patched(*args, **kwargs) -> model.ModelVersion:  # type: ignore[no-untyped-def]
                # Use the most recent (top of stack) experiment_info for nested contexts
                current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
                return ExperimentInfoPatcher._original_log_model(
                    *args, **kwargs, experiment_info=current_experiment_info
                )

            model_manager.ModelManager.log_model = patched  # type: ignore[method-assign]

        stack.append(self._experiment_info)
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> None:
        """Remove this context's info and restore ``log_model`` once no context remains.

        The entry removed is this patcher's own (its most recent occurrence, matched by identity),
        so contexts that exit out of order do not discard the info of one that is still active.
        """
        stack = ExperimentInfoPatcher._experiment_info_stack

        for index in reversed(range(len(stack))):
            if stack[index] is self._experiment_info:
                del stack[index]
                break
        else:
            # Own entry not present (exit without a matching enter): keep the previous
            # behaviour of popping the top, including IndexError on an empty stack.
            stack.pop()

        # Restore original method when no patches are active to clean up properly
        if not stack:
            model_manager.ModelManager.log_model = (  # type: ignore[method-assign]
                ExperimentInfoPatcher._original_log_model
            )
@@ -0,0 +1,121 @@
1
+ import json
2
+ from typing import TYPE_CHECKING, Any, Optional, Union
3
+ from warnings import warn
4
+
5
+ import lightgbm as lgb
6
+ import xgboost as xgb
7
+
8
+ from snowflake.ml.model.model_signature import ModelSignature
9
+
10
+ if TYPE_CHECKING:
11
+ from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
12
+
13
+
14
class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
    """XGBoost training callback that forwards params, metrics and the model to experiment tracking."""

    def __init__(
        self,
        experiment_tracking: "ExperimentTracking",
        log_model: bool = True,
        log_metrics: bool = True,
        log_params: bool = True,
        model_name: Optional[str] = None,
        model_signature: Optional[ModelSignature] = None,
    ) -> None:
        """Store the tracker and the logging switches.

        Args:
            experiment_tracking: Tracker that receives params, metrics and the model.
            log_model: Whether to log the model after training.
            log_metrics: Whether to log evaluation metrics after each iteration.
            log_params: Whether to log the booster configuration before training.
            model_name: Name for the logged model; when falsy, the experiment name plus ``"_model"`` is used.
            model_signature: Signature registered for ``predict``; without it the model is not logged.
        """
        self._experiment_tracking = experiment_tracking
        self.log_model = log_model
        self.log_metrics = log_metrics
        self.log_params = log_params
        self.model_name = model_name
        self.model_signature = model_signature

    def before_training(self, model: xgb.Booster) -> xgb.Booster:
        """Log the booster's configuration as flat, dot-separated params (if enabled); return the model."""

        def _flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
            # Recursively walk dicts and lists; list elements are keyed by their index.
            entries = params.items() if isinstance(params, dict) else enumerate(params)
            flattened: dict[str, Any] = {}
            for name, child in entries:
                path = f"{prefix}.{name}" if prefix else str(name)
                if not isinstance(child, (dict, list)):
                    flattened[path] = child
                    continue
                flattened.update(_flatten_nested_params(child, path))
            return flattened

        if not self.log_params:
            return model

        config = json.loads(model.save_config())
        self._experiment_tracking.log_params(_flatten_nested_params(config))
        return model

    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
        """Log the latest value of every evaluation metric for this epoch (if enabled).

        Returns:
            Always False.
        """
        if not self.log_metrics:
            return False

        for dataset_name, metrics in evals_log.items():
            for metric_name, history in metrics.items():
                # Keys look like "<dataset>:<metric>"; only the newest entry is logged.
                self._experiment_tracking.log_metric(
                    key=":".join((dataset_name, metric_name)), value=history[-1], step=epoch
                )
        return False

    def after_training(self, model: xgb.Booster) -> xgb.Booster:
        """Log the trained model (if enabled and a signature was given); return the model."""
        if not self.log_model:
            return model

        if not self.model_signature:
            warn(
                "Model will not be logged because model signature is missing. "
                "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
            )
            return model

        model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
        self._experiment_tracking.log_model(  # type: ignore[call-arg]
            model=model,
            model_name=model_name,
            signatures={"predict": self.model_signature},
        )
        return model
75
+
76
+
77
class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
    """LightGBM callback that forwards params, metrics and the model to experiment tracking."""

    def __init__(
        self,
        experiment_tracking: "ExperimentTracking",
        log_model: bool = True,
        log_metrics: bool = True,
        log_params: bool = True,
        model_name: Optional[str] = None,
        model_signature: Optional[ModelSignature] = None,
    ) -> None:
        """Store the tracker and the logging switches, then initialise the base callback.

        Args:
            experiment_tracking: Tracker that receives params, metrics and the model.
            log_model: Whether to log the model at the last iteration.
            log_metrics: Whether to log evaluation metrics at each iteration.
            log_params: Whether to log the training params at the first iteration.
            model_name: Name for the logged model; when falsy, the experiment name plus ``"_model"`` is used.
            model_signature: Signature registered for ``predict``; without it the model is not logged.
        """
        self._experiment_tracking = experiment_tracking
        self.log_model = log_model
        self.log_metrics = log_metrics
        self.log_params = log_params
        self.model_name = model_name
        self.model_signature = model_signature

        super().__init__(eval_result={})

    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
        """Handle one training iteration: log params, metrics and the model, each if enabled."""
        tracker = self._experiment_tracking

        # Log params only at the first iteration
        if self.log_params and env.iteration == env.begin_iteration:
            tracker.log_params(env.params)

        if self.log_metrics:
            # Let the base callback update eval_result, then log the newest value of each metric.
            super().__call__(env)
            for dataset_name, metrics in self.eval_result.items():
                for metric_name, history in metrics.items():
                    tracker.log_metric(
                        key=":".join((dataset_name, metric_name)), value=history[-1], step=env.iteration
                    )

        # Log model only at the last iteration
        # NOTE(review): if training ends before end_iteration - 1 is reached, this branch never
        # runs and the model is not logged -- confirm whether that is intended.
        if self.log_model and env.iteration == env.end_iteration - 1:
            if not self.model_signature:
                warn(
                    "Model will not be logged because model signature is missing. To autolog the model, "
                    "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
                )
                return

            model_name = self.model_name or tracker._get_or_set_experiment().name + "_model"
            tracker.log_model(  # type: ignore[call-arg]
                model=env.model,
                model_name=model_name,
                signatures={"predict": self.model_signature},
            )