snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/telemetry.py +3 -2
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  5. snowflake/ml/experiment/callback/keras.py +3 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  7. snowflake/ml/experiment/callback/xgboost.py +3 -0
  8. snowflake/ml/experiment/experiment_tracking.py +19 -7
  9. snowflake/ml/feature_store/feature_store.py +236 -61
  10. snowflake/ml/jobs/__init__.py +4 -0
  11. snowflake/ml/jobs/_interop/__init__.py +0 -0
  12. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  13. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  14. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  15. snowflake/ml/jobs/_interop/legacy.py +225 -0
  16. snowflake/ml/jobs/_interop/protocols.py +471 -0
  17. snowflake/ml/jobs/_interop/results.py +51 -0
  18. snowflake/ml/jobs/_interop/utils.py +144 -0
  19. snowflake/ml/jobs/_utils/constants.py +16 -2
  20. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  21. snowflake/ml/jobs/_utils/payload_utils.py +8 -2
  22. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  23. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  24. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  25. snowflake/ml/jobs/_utils/types.py +15 -0
  26. snowflake/ml/jobs/job.py +186 -40
  27. snowflake/ml/jobs/manager.py +48 -39
  28. snowflake/ml/model/__init__.py +19 -0
  29. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  30. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  31. snowflake/ml/model/_client/model/model_version_impl.py +168 -18
  32. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  33. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  36. snowflake/ml/model/_client/sql/model_version.py +3 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
  39. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  40. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  41. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  42. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  43. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  44. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  46. snowflake/ml/model/type_hints.py +16 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  48. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  49. snowflake/ml/version.py +1 -1
  50. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
  51. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
  52. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  53. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ aerial
1
2
  afraid
2
3
  ancient
3
4
  angry
@@ -26,7 +27,6 @@ dull
26
27
  empty
27
28
  evil
28
29
  fast
29
- fat
30
30
  fluffy
31
31
  foolish
32
32
  fresh
@@ -57,10 +57,10 @@ lovely
57
57
  lucky
58
58
  massive
59
59
  mean
60
+ metallic
60
61
  mighty
61
62
  modern
62
63
  moody
63
- nasty
64
64
  neat
65
65
  nervous
66
66
  new
@@ -85,7 +85,6 @@ rotten
85
85
  rude
86
86
  selfish
87
87
  serious
88
- shaggy
89
88
  sharp
90
89
  short
91
90
  shy
@@ -96,14 +95,15 @@ slippery
96
95
  smart
97
96
  smooth
98
97
  soft
98
+ solid
99
99
  sour
100
100
  spicy
101
101
  splendid
102
102
  spotty
103
+ squishy
103
104
  stale
104
105
  strange
105
106
  strong
106
- stupid
107
107
  sweet
108
108
  swift
109
109
  tall
@@ -116,7 +116,6 @@ tidy
116
116
  tiny
117
117
  tough
118
118
  tricky
119
- ugly
120
119
  warm
121
120
  weak
122
121
  wet
@@ -124,5 +123,6 @@ wicked
124
123
  wise
125
124
  witty
126
125
  wonderful
126
+ wooden
127
127
  yellow
128
128
  young
@@ -1,10 +1,9 @@
1
1
  anaconda
2
2
  ant
3
- ape
4
- baboon
5
3
  badger
6
4
  bat
7
5
  bear
6
+ beetle
8
7
  bird
9
8
  bobcat
10
9
  bulldog
@@ -73,7 +72,6 @@ lobster
73
72
  mayfly
74
73
  mamba
75
74
  mole
76
- monkey
77
75
  moose
78
76
  moth
79
77
  mouse
@@ -114,6 +112,7 @@ swan
114
112
  termite
115
113
  tiger
116
114
  treefrog
115
+ tuna
117
116
  turkey
118
117
  turtle
119
118
  vampirebat
@@ -126,3 +125,4 @@ worm
126
125
  yak
127
126
  yeti
128
127
  zebra
128
+ zebrafish
@@ -73,6 +73,7 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
73
73
  class TelemetryProject(enum.Enum):
74
74
  MLOPS = "MLOps"
75
75
  MODELING = "ModelDevelopment"
76
+ EXPERIMENT_TRACKING = "ExperimentTracking"
76
77
  # TODO: Update with remaining projects.
77
78
 
78
79
 
@@ -464,14 +465,14 @@ def send_api_usage_telemetry(
464
465
 
465
466
  # noqa: DAR402
466
467
  """
467
- start_time = time.perf_counter()
468
-
469
468
  if subproject is not None and subproject_extractor is not None:
470
469
  raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
471
470
 
472
471
  def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
473
472
  @functools.wraps(func)
474
473
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
474
+ start_time = time.perf_counter()
475
+
475
476
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
476
477
 
477
478
  api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
@@ -1,17 +1,17 @@
1
1
  from typing import Optional
2
2
 
3
+ from snowflake.ml._internal import telemetry
3
4
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
5
  from snowflake.ml.experiment._client import artifact
5
6
  from snowflake.ml.model._client.sql import _base
6
7
  from snowflake.ml.utils import sql_client
7
8
  from snowflake.snowpark import file_operation, row, session
8
9
 
10
+ RUN_NAME_COL_NAME = "name"
11
+ RUN_METADATA_COL_NAME = "metadata"
9
12
 
10
- class ExperimentTrackingSQLClient(_base._BaseSQLClient):
11
-
12
- RUN_NAME_COL_NAME = "name"
13
- RUN_METADATA_COL_NAME = "metadata"
14
13
 
14
+ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
15
15
  def __init__(
16
16
  self,
17
17
  session: session.Session,
@@ -28,6 +28,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
28
28
  """
29
29
  super().__init__(session, database_name=database_name, schema_name=schema_name)
30
30
 
31
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
31
32
  def create_experiment(
32
33
  self,
33
34
  experiment_name: sql_identifier.SqlIdentifier,
@@ -39,24 +40,21 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
39
40
  self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
40
41
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
41
42
 
43
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
42
44
  def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
43
45
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
44
46
  query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
45
47
  expected_rows=1, expected_cols=1
46
48
  ).validate()
47
49
 
48
- def add_run(
49
- self,
50
- *,
51
- experiment_name: sql_identifier.SqlIdentifier,
52
- run_name: sql_identifier.SqlIdentifier,
53
- live: bool = True,
54
- ) -> None:
50
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
51
+ def add_run(self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier) -> None:
55
52
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
56
53
  query_result_checker.SqlResultValidator(
57
- self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
54
+ self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD RUN {run_name}"
58
55
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
59
56
 
57
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
60
58
  def commit_run(
61
59
  self,
62
60
  *,
@@ -68,6 +66,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
68
66
  self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
69
67
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
70
68
 
69
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
71
70
  def drop_run(
72
71
  self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
73
72
  ) -> None:
@@ -76,6 +75,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
76
75
  self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
77
76
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
78
77
 
78
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
79
79
  def modify_run_add_metrics(
80
80
  self,
81
81
  *,
@@ -89,6 +89,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
89
89
  f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
90
90
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
91
91
 
92
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
92
93
  def modify_run_add_params(
93
94
  self,
94
95
  *,
@@ -102,6 +103,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
102
103
  f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
103
104
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
104
105
 
106
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
105
107
  def put_artifact(
106
108
  self,
107
109
  *,
@@ -118,6 +120,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
118
120
  auto_compress=auto_compress,
119
121
  )[0]
120
122
 
123
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
121
124
  def list_artifacts(
122
125
  self,
123
126
  *,
@@ -142,6 +145,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
142
145
  for result in results
143
146
  ]
144
147
 
148
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
145
149
  def get_artifact(
146
150
  self,
147
151
  *,
@@ -155,6 +159,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
155
159
  target_directory=target_path,
156
160
  )[0]
157
161
 
162
+ @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
158
163
  def show_runs_in_experiment(
159
164
  self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
160
165
  ) -> list[row.Row]:
@@ -20,6 +20,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
20
20
  log_params: bool = True,
21
21
  log_every_n_epochs: int = 1,
22
22
  model_name: Optional[str] = None,
23
+ version_name: Optional[str] = None,
23
24
  model_signature: Optional["ModelSignature"] = None,
24
25
  ) -> None:
25
26
  self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
30
31
  raise ValueError("`log_every_n_epochs` must be positive.")
31
32
  self.log_every_n_epochs = log_every_n_epochs
32
33
  self.model_name = model_name
34
+ self.version_name = version_name
33
35
  self.model_signature = model_signature
34
36
 
35
37
  def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
@@ -59,5 +61,6 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
59
61
  self._experiment_tracking.log_model( # type: ignore[call-arg]
60
62
  model=self.model,
61
63
  model_name=model_name,
64
+ version_name=self.version_name,
62
65
  signatures={"predict": self.model_signature},
63
66
  )
@@ -17,6 +17,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
17
17
  log_params: bool = True,
18
18
  log_every_n_epochs: int = 1,
19
19
  model_name: Optional[str] = None,
20
+ version_name: Optional[str] = None,
20
21
  model_signature: Optional["ModelSignature"] = None,
21
22
  ) -> None:
22
23
  self._experiment_tracking = experiment_tracking
@@ -27,6 +28,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
27
28
  raise ValueError("`log_every_n_epochs` must be positive.")
28
29
  self.log_every_n_epochs = log_every_n_epochs
29
30
  self.model_name = model_name
31
+ self.version_name = version_name
30
32
  self.model_signature = model_signature
31
33
 
32
34
  super().__init__(eval_result={})
@@ -50,6 +52,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
50
52
  self._experiment_tracking.log_model( # type: ignore[call-arg]
51
53
  model=env.model,
52
54
  model_name=model_name,
55
+ version_name=self.version_name,
53
56
  signatures={"predict": self.model_signature},
54
57
  )
55
58
  else:
@@ -20,6 +20,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
20
20
  log_params: bool = True,
21
21
  log_every_n_epochs: int = 1,
22
22
  model_name: Optional[str] = None,
23
+ version_name: Optional[str] = None,
23
24
  model_signature: Optional["ModelSignature"] = None,
24
25
  ) -> None:
25
26
  self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
30
31
  raise ValueError("`log_every_n_epochs` must be positive.")
31
32
  self.log_every_n_epochs = log_every_n_epochs
32
33
  self.model_name = model_name
34
+ self.version_name = version_name
33
35
  self.model_signature = model_signature
34
36
 
35
37
  def before_training(self, model: xgb.Booster) -> xgb.Booster:
@@ -61,6 +63,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
61
63
  self._experiment_tracking.log_model( # type: ignore[call-arg]
62
64
  model=model,
63
65
  model_name=model_name,
66
+ version_name=self.version_name,
64
67
  signatures={"predict": self.model_signature},
65
68
  )
66
69
 
@@ -193,23 +193,35 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
193
193
  run_name: Optional[str] = None,
194
194
  ) -> entities.Run:
195
195
  """
196
- Start a new run.
196
+ Start a new run. If a run name of an existing run is provided, resumes the run if it is running.
197
197
 
198
198
  Args:
199
199
  run_name: The name of the run. If None, a default name will be generated.
200
200
 
201
201
  Returns:
202
- Run: The run that was started.
202
+ Run: The run that was started or resumed.
203
203
 
204
204
  Raises:
205
- RuntimeError: If a run is already active.
205
+ RuntimeError: If a run is already active. If a run with the same name exists but is not running.
206
206
  """
207
207
  if self._run:
208
208
  raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
209
209
  experiment = self._get_or_set_experiment()
210
- run_name = (
211
- sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
212
- )
210
+
211
+ if run_name is None:
212
+ run_name = self._generate_run_name(experiment)
213
+ elif runs := self._sql_client.show_runs_in_experiment(experiment_name=experiment.name, like=run_name):
214
+ if "RUNNING" != json.loads(runs[0][sql_client.RUN_METADATA_COL_NAME])["status"]:
215
+ raise RuntimeError(f"Run {run_name} exists but cannot be resumed as it is no longer running.")
216
+ else:
217
+ self._run = entities.Run(
218
+ experiment_tracking=self,
219
+ experiment_name=experiment.name,
220
+ run_name=sql_identifier.SqlIdentifier(run_name),
221
+ )
222
+ return self._run
223
+
224
+ run_name = sql_identifier.SqlIdentifier(run_name)
213
225
  self._sql_client.add_run(
214
226
  experiment_name=experiment.name,
215
227
  run_name=run_name,
@@ -444,7 +456,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
444
456
  def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
445
457
  generator = hrid_generator.HRID16()
446
458
  existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
447
- existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
459
+ existing_run_names = [row[sql_client.RUN_NAME_COL_NAME] for row in existing_runs]
448
460
  for _ in range(1000):
449
461
  run_name = generator.generate()[1]
450
462
  if run_name not in existing_run_names: