snowflake-ml-python 1.17.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. snowflake/ml/_internal/telemetry.py +3 -2
  2. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  3. snowflake/ml/experiment/callback/keras.py +3 -0
  4. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  5. snowflake/ml/experiment/callback/xgboost.py +3 -0
  6. snowflake/ml/experiment/experiment_tracking.py +19 -7
  7. snowflake/ml/feature_store/feature_store.py +236 -61
  8. snowflake/ml/jobs/_utils/constants.py +12 -1
  9. snowflake/ml/jobs/_utils/payload_utils.py +7 -1
  10. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  11. snowflake/ml/jobs/_utils/types.py +5 -0
  12. snowflake/ml/jobs/job.py +16 -2
  13. snowflake/ml/jobs/manager.py +12 -1
  14. snowflake/ml/model/__init__.py +19 -0
  15. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  16. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  17. snowflake/ml/model/_client/model/model_version_impl.py +129 -11
  18. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  19. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  20. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  22. snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  24. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  25. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
  26. snowflake/ml/model/type_hints.py +16 -0
  27. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  28. snowflake/ml/version.py +1 -1
  29. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +25 -1
  30. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +33 -32
  31. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  32. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  33. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
@@ -73,6 +73,7 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
73
73
  class TelemetryProject(enum.Enum):
74
74
  MLOPS = "MLOps"
75
75
  MODELING = "ModelDevelopment"
76
+ EXPERIMENT_TRACKING = "ExperimentTracking"
76
77
  # TODO: Update with remaining projects.
77
78
 
78
79
 
@@ -464,14 +465,14 @@ def send_api_usage_telemetry(
464
465
 
465
466
  # noqa: DAR402
466
467
  """
467
- start_time = time.perf_counter()
468
-
469
468
  if subproject is not None and subproject_extractor is not None:
470
469
  raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
471
470
 
472
471
  def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
473
472
  @functools.wraps(func)
474
473
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
474
+ start_time = time.perf_counter()
475
+
475
476
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
476
477
 
477
478
  api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
@@ -1,17 +1,17 @@
1
1
  from typing import Optional
2
2
 
3
+ from snowflake.ml._internal import telemetry
3
4
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
5
  from snowflake.ml.experiment._client import artifact
5
6
  from snowflake.ml.model._client.sql import _base
6
7
  from snowflake.ml.utils import sql_client
7
8
  from snowflake.snowpark import file_operation, row, session
8
9
 
10
+ RUN_NAME_COL_NAME = "name"
11
+ RUN_METADATA_COL_NAME = "metadata"
9
12
 
10
- class ExperimentTrackingSQLClient(_base._BaseSQLClient):
11
-
12
- RUN_NAME_COL_NAME = "name"
13
- RUN_METADATA_COL_NAME = "metadata"
14
13
 
14
+ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
15
15
  def __init__(
16
16
  self,
17
17
  session: session.Session,
@@ -28,6 +28,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
28
28
  """
29
29
  super().__init__(session, database_name=database_name, schema_name=schema_name)
30
30
 
31
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
31
32
  def create_experiment(
32
33
  self,
33
34
  experiment_name: sql_identifier.SqlIdentifier,
@@ -39,24 +40,21 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
39
40
  self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
40
41
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
41
42
 
43
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
42
44
  def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
43
45
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
44
46
  query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
45
47
  expected_rows=1, expected_cols=1
46
48
  ).validate()
47
49
 
48
- def add_run(
49
- self,
50
- *,
51
- experiment_name: sql_identifier.SqlIdentifier,
52
- run_name: sql_identifier.SqlIdentifier,
53
- live: bool = True,
54
- ) -> None:
50
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
51
+ def add_run(self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier) -> None:
55
52
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
56
53
  query_result_checker.SqlResultValidator(
57
- self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
54
+ self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD RUN {run_name}"
58
55
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
59
56
 
57
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
60
58
  def commit_run(
61
59
  self,
62
60
  *,
@@ -68,6 +66,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
68
66
  self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
69
67
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
70
68
 
69
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
71
70
  def drop_run(
72
71
  self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
73
72
  ) -> None:
@@ -76,6 +75,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
76
75
  self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
77
76
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
78
77
 
78
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
79
79
  def modify_run_add_metrics(
80
80
  self,
81
81
  *,
@@ -89,6 +89,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
89
89
  f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
90
90
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
91
91
 
92
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
92
93
  def modify_run_add_params(
93
94
  self,
94
95
  *,
@@ -102,6 +103,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
102
103
  f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
103
104
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
104
105
 
106
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
105
107
  def put_artifact(
106
108
  self,
107
109
  *,
@@ -118,6 +120,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
118
120
  auto_compress=auto_compress,
119
121
  )[0]
120
122
 
123
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
121
124
  def list_artifacts(
122
125
  self,
123
126
  *,
@@ -142,6 +145,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
142
145
  for result in results
143
146
  ]
144
147
 
148
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
145
149
  def get_artifact(
146
150
  self,
147
151
  *,
@@ -155,6 +159,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
155
159
  target_directory=target_path,
156
160
  )[0]
157
161
 
162
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
158
163
  def show_runs_in_experiment(
159
164
  self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
160
165
  ) -> list[row.Row]:
@@ -20,6 +20,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
20
20
  log_params: bool = True,
21
21
  log_every_n_epochs: int = 1,
22
22
  model_name: Optional[str] = None,
23
+ version_name: Optional[str] = None,
23
24
  model_signature: Optional["ModelSignature"] = None,
24
25
  ) -> None:
25
26
  self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
30
31
  raise ValueError("`log_every_n_epochs` must be positive.")
31
32
  self.log_every_n_epochs = log_every_n_epochs
32
33
  self.model_name = model_name
34
+ self.version_name = version_name
33
35
  self.model_signature = model_signature
34
36
 
35
37
  def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
@@ -59,5 +61,6 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
59
61
  self._experiment_tracking.log_model( # type: ignore[call-arg]
60
62
  model=self.model,
61
63
  model_name=model_name,
64
+ version_name=self.version_name,
62
65
  signatures={"predict": self.model_signature},
63
66
  )
@@ -17,6 +17,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
17
17
  log_params: bool = True,
18
18
  log_every_n_epochs: int = 1,
19
19
  model_name: Optional[str] = None,
20
+ version_name: Optional[str] = None,
20
21
  model_signature: Optional["ModelSignature"] = None,
21
22
  ) -> None:
22
23
  self._experiment_tracking = experiment_tracking
@@ -27,6 +28,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
27
28
  raise ValueError("`log_every_n_epochs` must be positive.")
28
29
  self.log_every_n_epochs = log_every_n_epochs
29
30
  self.model_name = model_name
31
+ self.version_name = version_name
30
32
  self.model_signature = model_signature
31
33
 
32
34
  super().__init__(eval_result={})
@@ -50,6 +52,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
50
52
  self._experiment_tracking.log_model( # type: ignore[call-arg]
51
53
  model=env.model,
52
54
  model_name=model_name,
55
+ version_name=self.version_name,
53
56
  signatures={"predict": self.model_signature},
54
57
  )
55
58
  else:
@@ -20,6 +20,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
20
20
  log_params: bool = True,
21
21
  log_every_n_epochs: int = 1,
22
22
  model_name: Optional[str] = None,
23
+ version_name: Optional[str] = None,
23
24
  model_signature: Optional["ModelSignature"] = None,
24
25
  ) -> None:
25
26
  self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
30
31
  raise ValueError("`log_every_n_epochs` must be positive.")
31
32
  self.log_every_n_epochs = log_every_n_epochs
32
33
  self.model_name = model_name
34
+ self.version_name = version_name
33
35
  self.model_signature = model_signature
34
36
 
35
37
  def before_training(self, model: xgb.Booster) -> xgb.Booster:
@@ -61,6 +63,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
61
63
  self._experiment_tracking.log_model( # type: ignore[call-arg]
62
64
  model=model,
63
65
  model_name=model_name,
66
+ version_name=self.version_name,
64
67
  signatures={"predict": self.model_signature},
65
68
  )
66
69
 
@@ -193,23 +193,35 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
193
193
  run_name: Optional[str] = None,
194
194
  ) -> entities.Run:
195
195
  """
196
- Start a new run.
196
+ Start a new run. If a run name of an existing run is provided, resumes the run if it is running.
197
197
 
198
198
  Args:
199
199
  run_name: The name of the run. If None, a default name will be generated.
200
200
 
201
201
  Returns:
202
- Run: The run that was started.
202
+ Run: The run that was started or resumed.
203
203
 
204
204
  Raises:
205
- RuntimeError: If a run is already active.
205
+ RuntimeError: If a run is already active. If a run with the same name exists but is not running.
206
206
  """
207
207
  if self._run:
208
208
  raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
209
209
  experiment = self._get_or_set_experiment()
210
- run_name = (
211
- sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
212
- )
210
+
211
+ if run_name is None:
212
+ run_name = self._generate_run_name(experiment)
213
+ elif runs := self._sql_client.show_runs_in_experiment(experiment_name=experiment.name, like=run_name):
214
+ if "RUNNING" != json.loads(runs[0][sql_client.RUN_METADATA_COL_NAME])["status"]:
215
+ raise RuntimeError(f"Run {run_name} exists but cannot be resumed as it is no longer running.")
216
+ else:
217
+ self._run = entities.Run(
218
+ experiment_tracking=self,
219
+ experiment_name=experiment.name,
220
+ run_name=sql_identifier.SqlIdentifier(run_name),
221
+ )
222
+ return self._run
223
+
224
+ run_name = sql_identifier.SqlIdentifier(run_name)
213
225
  self._sql_client.add_run(
214
226
  experiment_name=experiment.name,
215
227
  run_name=run_name,
@@ -444,7 +456,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
444
456
  def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
445
457
  generator = hrid_generator.HRID16()
446
458
  existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
447
- existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
459
+ existing_run_names = [row[sql_client.RUN_NAME_COL_NAME] for row in existing_runs]
448
460
  for _ in range(1000):
449
461
  run_name = generator.generate()[1]
450
462
  if run_name not in existing_run_names: