snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,319 @@
1
+ import functools
2
+ import json
3
+ import sys
4
+ from typing import Any, Optional, Union
5
+ from urllib.parse import quote
6
+
7
+ import snowflake.snowpark._internal.utils as snowpark_utils
8
+ from snowflake.ml import model, registry
9
+ from snowflake.ml._internal.human_readable_id import hrid_generator
10
+ from snowflake.ml._internal.utils import sql_identifier
11
+ from snowflake.ml.experiment import (
12
+ _entities as entities,
13
+ _experiment_info as experiment_info,
14
+ )
15
+ from snowflake.ml.experiment._client import experiment_tracking_sql_client as sql_client
16
+ from snowflake.ml.model import type_hints
17
+ from snowflake.ml.utils import sql_client as sql_client_utils
18
+ from snowflake.snowpark import session
19
+
20
+ DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
21
+
22
+
23
class ExperimentTracking:
    """
    Class to manage experiments in Snowflake.

    Holds at most one experiment and one run "in context". Logging methods
    (log_metric(s), log_param(s), log_model) lazily create a DEFAULT experiment
    and an auto-named run when none is active.
    """

    @snowpark_utils.private_preview(version="1.9.1")
    def __init__(
        self,
        session: session.Session,
        *,
        database_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ) -> None:
        """
        Initializes experiment tracking within a pre-created schema.

        Args:
            session: The Snowpark Session to connect with Snowflake.
            database_name: The name of the database. If None, the current database of the session
                will be used. Defaults to None.
            schema_name: The name of the schema. If None, the current schema of the session
                will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.

        Raises:
            ValueError: If no database is provided and no active database exists in the session.
        """
        if database_name:
            self._database_name = sql_identifier.SqlIdentifier(database_name)
        elif session_db := session.get_current_database():
            self._database_name = sql_identifier.SqlIdentifier(session_db)
        else:
            raise ValueError("You need to provide a database to use experiment tracking.")

        if schema_name:
            self._schema_name = sql_identifier.SqlIdentifier(schema_name)
        elif session_schema := session.get_current_schema():
            self._schema_name = sql_identifier.SqlIdentifier(session_schema)
        else:
            # No schema given and none active on the session: fall back to PUBLIC.
            self._schema_name = sql_identifier.SqlIdentifier("PUBLIC")

        self._sql_client = sql_client.ExperimentTrackingSQLClient(
            session,
            database_name=self._database_name,
            schema_name=self._schema_name,
        )
        # Registry in the same database/schema, used by log_model to attach
        # models to the current run.
        self._registry = registry.Registry(
            session=session,
            database_name=self._database_name,
            schema_name=self._schema_name,
        )

        # The experiment in context
        self._experiment: Optional[entities.Experiment] = None
        # The run in context
        self._run: Optional[entities.Run] = None

    def set_experiment(
        self,
        experiment_name: str,
    ) -> entities.Experiment:
        """
        Set the experiment in context. Creates a new experiment if it doesn't exist.

        Args:
            experiment_name: The name of the experiment.

        Returns:
            Experiment: The experiment that was set.
        """
        experiment_name = sql_identifier.SqlIdentifier(experiment_name)
        if self._experiment and self._experiment.name == experiment_name:
            # Already in context: keep the current run alive.
            return self._experiment
        self._sql_client.create_experiment(
            experiment_name=experiment_name,
            creation_mode=sql_client_utils.CreationMode(if_not_exists=True),
        )
        self._experiment = entities.Experiment(experiment_name=experiment_name)
        # Switching experiments invalidates any run from the previous one.
        self._run = None
        return self._experiment

    def delete_experiment(
        self,
        experiment_name: str,
    ) -> None:
        """
        Delete an experiment.

        Args:
            experiment_name: The name of the experiment.
        """
        # Normalize once so the context-clearing comparison below uses identifier
        # semantics, consistent with set_experiment/end_run. Comparing the raw
        # string would leave a stale context for case-insensitively equal names.
        experiment_name = sql_identifier.SqlIdentifier(experiment_name)
        self._sql_client.drop_experiment(experiment_name=experiment_name)
        if self._experiment and self._experiment.name == experiment_name:
            self._experiment = None
            self._run = None

    @functools.wraps(registry.Registry.log_model)
    def log_model(
        self,
        model: Union[type_hints.SupportedModelType, model.ModelVersion],
        *,
        model_name: str,
        **kwargs: Any,
    ) -> model.ModelVersion:
        # Log the model through the registry while patching in the current
        # run's experiment info so the model is linked to this run.
        run = self._get_or_start_run()
        with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
            return self._registry.log_model(model, model_name=model_name, **kwargs)

    def start_run(
        self,
        run_name: Optional[str] = None,
    ) -> entities.Run:
        """
        Start a new run.

        Args:
            run_name: The name of the run. If None, a default name will be generated.

        Returns:
            Run: The run that was started.

        Raises:
            RuntimeError: If a run is already active.
        """
        if self._run:
            raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
        experiment = self._get_or_set_experiment()
        run_name = (
            sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
        )
        self._sql_client.add_run(
            experiment_name=experiment.name,
            run_name=run_name,
        )
        self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
        return self._run

    def end_run(self, run_name: Optional[str] = None) -> None:
        """
        End the current run if no run name is provided. Otherwise, the specified run is ended.

        Args:
            run_name: The name of the run to be ended. If None, the current run is ended.

        Raises:
            RuntimeError: If no run is active.
        """
        if not self._experiment:
            raise RuntimeError("No experiment set. Please set an experiment before ending a run.")
        experiment_name = self._experiment.name

        if run_name:
            run_name = sql_identifier.SqlIdentifier(run_name)
        elif self._run:
            run_name = self._run.name
        else:
            raise RuntimeError("No run is active. Please start a run before ending it.")

        self._sql_client.commit_run(
            experiment_name=experiment_name,
            run_name=run_name,
        )
        # Only clear the context if the ended run is the one in context.
        if self._run and run_name == self._run.name:
            self._run = None
        self._print_urls(experiment_name=experiment_name, run_name=run_name)

    def delete_run(
        self,
        run_name: str,
    ) -> None:
        """
        Delete a run.

        Args:
            run_name: The name of the run to be deleted.

        Raises:
            RuntimeError: If no experiment is set.
        """
        if not self._experiment:
            raise RuntimeError("No experiment set. Please set an experiment before deleting a run.")
        # Normalize once so both the DROP and the context-clearing comparison
        # use identifier semantics, consistent with end_run.
        run_name = sql_identifier.SqlIdentifier(run_name)
        self._sql_client.drop_run(
            experiment_name=self._experiment.name,
            run_name=run_name,
        )
        if self._run and self._run.name == run_name:
            self._run = None

    def log_metric(
        self,
        key: str,
        value: float,
        step: int = 0,
    ) -> None:
        """
        Log a metric under the current run. If no run is active, this method will create a new run.

        Args:
            key: The name of the metric.
            value: The value of the metric.
            step: The step of the metric. Defaults to 0.
        """
        self.log_metrics(metrics={key: value}, step=step)

    def log_metrics(
        self,
        metrics: dict[str, float],
        step: int = 0,
    ) -> None:
        """
        Log metrics under the current run. If no run is active, this method will create a new run.

        Args:
            metrics: Dictionary containing metric keys and float values.
            step: The step of the metrics. Defaults to 0.
        """
        run = self._get_or_start_run()
        metadata = run._get_metadata()
        for key, value in metrics.items():
            metadata.set_metric(key, value, step)
        # Persist the whole (updated) metadata blob in one round trip.
        self._sql_client.modify_run(
            experiment_name=run.experiment_name,
            run_name=run.name,
            run_metadata=json.dumps(metadata.to_dict()),
        )

    def log_param(
        self,
        key: str,
        value: Any,
    ) -> None:
        """
        Log a parameter under the current run. If no run is active, this method will create a new run.

        Args:
            key: The name of the parameter.
            value: The value of the parameter. Values can be of any type, but will be converted to string.
        """
        self.log_params({key: value})

    def log_params(
        self,
        params: dict[str, Any],
    ) -> None:
        """
        Log parameters under the current run. If no run is active, this method will create a new run.

        Args:
            params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
                to string.
        """
        run = self._get_or_start_run()
        metadata = run._get_metadata()
        for key, value in params.items():
            metadata.set_param(key, value)
        # Persist the whole (updated) metadata blob in one round trip.
        self._sql_client.modify_run(
            experiment_name=run.experiment_name,
            run_name=run.name,
            run_metadata=json.dumps(metadata.to_dict()),
        )

    def _get_or_set_experiment(self) -> entities.Experiment:
        # Return the experiment in context, creating/setting DEFAULT if absent.
        if self._experiment:
            return self._experiment
        return self.set_experiment(experiment_name=DEFAULT_EXPERIMENT_NAME)

    def _get_or_start_run(self) -> entities.Run:
        # Return the run in context, starting a fresh one if absent.
        if self._run:
            return self._run
        return self.start_run()

    def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
        """Generate a human-readable run name not already used in the experiment.

        Raises:
            RuntimeError: If a unique name could not be found after 1000 attempts.
        """
        generator = hrid_generator.HRID16()
        existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
        existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
        # Bounded retry loop guards against the (unlikely) exhaustion of the
        # HRID namespace within this experiment.
        for _ in range(1000):
            run_name = generator.generate()[1]
            if run_name not in existing_run_names:
                return sql_identifier.SqlIdentifier(run_name)
        raise RuntimeError("Random run name generation failed.")

    def _print_urls(
        self,
        experiment_name: sql_identifier.SqlIdentifier,
        run_name: sql_identifier.SqlIdentifier,
        scheme: str = "https",
        host: str = "app.snowflake.com",
    ) -> None:
        """Write Snowsight deep-link URLs for the experiment and run to stdout."""
        # Identifiers may contain characters that need percent-encoding in URLs.
        experiment_url = (
            f"{scheme}://{host}/_deeplink/#/experiments"
            f"/databases/{quote(str(self._database_name))}"
            f"/schemas/{quote(str(self._schema_name))}"
            f"/experiments/{quote(str(experiment_name))}"
        )
        run_url = experiment_url + f"/runs/{quote(str(run_name))}"
        sys.stdout.write(f"🏃 View run {run_name} at: {run_url}\n")
        sys.stdout.write(f"🧪 View experiment at: {experiment_url}\n")
@@ -6,10 +6,23 @@ DEFAULT_CONTAINER_NAME = "main"
6
6
  PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
7
7
  RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
8
8
  MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
9
+ TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
9
10
  RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
10
11
  MEMORY_VOLUME_NAME = "dshm"
11
12
  STAGE_VOLUME_NAME = "stage-volume"
12
- STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
13
+ # Base mount path
14
+ STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
15
+
16
+ # Stage subdirectory paths
17
+ APP_STAGE_SUBPATH = "app"
18
+ SYSTEM_STAGE_SUBPATH = "system"
19
+ OUTPUT_STAGE_SUBPATH = "output"
20
+
21
+ # Complete mount paths (automatically generated from base + subpath)
22
+ APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
23
+ SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
24
+ OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
25
+
13
26
 
14
27
  # Default container image information
15
28
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
@@ -46,9 +59,7 @@ ENABLE_HEALTH_CHECKS = "false"
46
59
  JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
47
60
  JOB_POLL_MAX_DELAY_SECONDS = 30
48
61
 
49
- # Magic attributes
50
- IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
51
- RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
62
+ RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
52
63
 
53
64
  # Log start and end messages
54
65
  LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"