snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. snowflake/ml/_internal/env_utils.py +16 -0
  2. snowflake/ml/_internal/platform_capabilities.py +36 -0
  3. snowflake/ml/_internal/telemetry.py +56 -7
  4. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  5. snowflake/ml/data/data_connector.py +103 -1
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  7. snowflake/ml/experiment/_entities/run.py +15 -0
  8. snowflake/ml/experiment/callback/keras.py +25 -2
  9. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  10. snowflake/ml/experiment/callback/xgboost.py +25 -2
  11. snowflake/ml/experiment/experiment_tracking.py +123 -13
  12. snowflake/ml/experiment/utils.py +6 -0
  13. snowflake/ml/feature_store/access_manager.py +1 -0
  14. snowflake/ml/feature_store/feature_store.py +1 -1
  15. snowflake/ml/feature_store/feature_view.py +34 -24
  16. snowflake/ml/jobs/_interop/protocols.py +3 -0
  17. snowflake/ml/jobs/_utils/feature_flags.py +1 -0
  18. snowflake/ml/jobs/_utils/payload_utils.py +360 -357
  19. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  20. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  21. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  22. snowflake/ml/jobs/_utils/spec_utils.py +2 -406
  23. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  24. snowflake/ml/jobs/_utils/types.py +14 -7
  25. snowflake/ml/jobs/job.py +8 -9
  26. snowflake/ml/jobs/manager.py +64 -129
  27. snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
  28. snowflake/ml/model/_client/model/model_version_impl.py +109 -28
  29. snowflake/ml/model/_client/ops/model_ops.py +32 -6
  30. snowflake/ml/model/_client/ops/service_ops.py +9 -4
  31. snowflake/ml/model/_client/sql/service.py +69 -2
  32. snowflake/ml/model/_packager/model_handler.py +8 -2
  33. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  34. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  35. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  36. snowflake/ml/model/_signatures/core.py +305 -8
  37. snowflake/ml/model/_signatures/utils.py +13 -4
  38. snowflake/ml/model/compute_pool.py +2 -0
  39. snowflake/ml/model/models/huggingface.py +285 -0
  40. snowflake/ml/model/models/huggingface_pipeline.py +25 -215
  41. snowflake/ml/model/type_hints.py +5 -1
  42. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  43. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  44. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  45. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  46. snowflake/ml/utils/html_utils.py +67 -1
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
  49. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
  50. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0
snowflake/ml/experiment/experiment_tracking.py

@@ -1,6 +1,7 @@
 import functools
 import json
 import sys
+import warnings
 from typing import Any, Optional, Union
 from urllib.parse import quote
 
@@ -27,6 +28,13 @@ class ExperimentTracking:
     Class to manage experiments in Snowflake.
     """
 
+    _instance = None
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "ExperimentTracking":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
     def __init__(
         self,
         session: snowpark.Session,
@@ -36,6 +44,7 @@ class ExperimentTracking:
     ) -> None:
         """
         Initializes experiment tracking within a pre-created schema.
+        This is a singleton class, so if an instance already exists, it will not reinitialize.
 
         Args:
             session: The Snowpark Session to connect with Snowflake.
@@ -47,6 +56,21 @@ class ExperimentTracking:
         Raises:
             ValueError: If no database is provided and no active database exists in the session.
         """
+        if hasattr(self, "_initialized"):
+            warnings.warn(
+                "ExperimentTracking is a singleton class. Reusing the existing instance, which has the setting:\n"
+                f"    Database: {self._database_name}, Schema: {self._schema_name}\n"
+                "To change the database or schema, use the database_name and schema_name arguments to set_experiment.",
+                UserWarning,
+                stacklevel=2,
+            )
+            return
+
+        # Declare types for mypy
+        self._database_name: sql_identifier.SqlIdentifier
+        self._schema_name: sql_identifier.SqlIdentifier
+        self._sql_client: sql_client.ExperimentTrackingSQLClient
+
         if database_name:
            self._database_name = sql_identifier.SqlIdentifier(database_name)
         elif session_db := session.get_current_database():
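
Taken together with the `__new__` override above, the `_initialized` guard makes `ExperimentTracking` behave as a process-wide singleton: a second construction returns the existing instance and warns instead of reinitializing. A minimal sketch of the resulting behavior, assuming an active Snowpark `session` and that the class is importable from `snowflake.ml.experiment` (import path assumed)::

    from snowflake.ml.experiment import ExperimentTracking

    exp1 = ExperimentTracking(session, database_name="ML_DB", schema_name="EXPERIMENTS")
    # Second construction returns the same object; __init__ sees _initialized,
    # emits a UserWarning, and skips re-initialization, so OTHER_DB is ignored.
    exp2 = ExperimentTracking(session, database_name="OTHER_DB")
    assert exp1 is exp2
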
@@ -78,6 +102,8 @@ class ExperimentTracking:
         # The run in context
         self._run: Optional[entities.Run] = None
 
+        self._initialized = True
+
     def __getstate__(self) -> dict[str, Any]:
         parent_state = (
             super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
@@ -116,19 +142,40 @@ class ExperimentTracking:
     def set_experiment(
         self,
         experiment_name: str,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
     ) -> entities.Experiment:
         """
         Set the experiment in context. Creates a new experiment if it doesn't exist.
 
         Args:
             experiment_name: The name of the experiment.
+            database_name: The name of the database. If None, reuse the current database. Defaults to None.
+            schema_name: The name of the schema. If None, the behavior depends on whether `database_name` is specified.
+                If `database_name` is specified, the schema is set to "PUBLIC".
+                If `database_name` is not specified, reuse the current schema. Defaults to None.
 
         Returns:
             Experiment: The experiment that was set.
         """
+        if database_name is not None:
+            if schema_name is None:
+                schema_name = "PUBLIC"
+        database_name = (
+            sql_identifier.SqlIdentifier(database_name) if database_name is not None else self._database_name
+        )
+        schema_name = sql_identifier.SqlIdentifier(schema_name) if schema_name is not None else self._schema_name
+
         experiment_name = sql_identifier.SqlIdentifier(experiment_name)
-        if self._experiment and self._experiment.name == experiment_name:
+        if (
+            self._experiment
+            and self._experiment.name == experiment_name
+            and self._database_name == database_name
+            and self._schema_name == schema_name
+        ):
             return self._experiment
+
+        self._update_database_and_schema(database_name, schema_name)
         self._sql_client.create_experiment(
             experiment_name=experiment_name,
             creation_mode=sql_client_utils.CreationMode(if_not_exists=True),
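
Per the new docstring, passing `database_name` without `schema_name` re-points the experiment at that database's PUBLIC schema, while passing neither reuses the instance's current context. A usage sketch (the `exp` instance and names are illustrative)::

    exp.set_experiment("MY_EXPERIMENT")  # reuse current database and schema
    exp.set_experiment("MY_EXPERIMENT", database_name="ANALYTICS_DB")  # targets ANALYTICS_DB.PUBLIC
    exp.set_experiment("MY_EXPERIMENT", database_name="ANALYTICS_DB", schema_name="RESEARCH")
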
@@ -140,15 +187,42 @@ class ExperimentTracking:
     def delete_experiment(
         self,
         experiment_name: str,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
     ) -> None:
         """
         Delete an experiment.
 
         Args:
             experiment_name: The name of the experiment.
+            database_name: The name of the database. If None, reuse the current database.
+                Must be specified if `schema_name` is specified. Defaults to None.
+            schema_name: The name of the schema. If None, reuse the current schema.
+                Must be specified if `database_name` is specified. Defaults to None.
+
+        Raises:
+            ValueError: If database_name is specified but schema_name is not.
         """
-        self._sql_client.drop_experiment(experiment_name=sql_identifier.SqlIdentifier(experiment_name))
-        if self._experiment and self._experiment.name == experiment_name:
+        if (database_name is None) ^ (schema_name is None):  # if only one of database_name and schema_name is set
+            raise ValueError(
+                "If one of database_name and schema_name is specified, the other one must also be specified."
+            )
+        database_name = (
+            sql_identifier.SqlIdentifier(database_name) if database_name is not None else self._database_name
+        )
+        schema_name = sql_identifier.SqlIdentifier(schema_name) if schema_name is not None else self._schema_name
+
+        self._sql_client.drop_experiment(
+            database_name=database_name,
+            schema_name=schema_name,
+            experiment_name=sql_identifier.SqlIdentifier(experiment_name),
+        )
+        if (
+            self._experiment
+            and self._experiment.name == experiment_name
+            and self._database_name == database_name
+            and self._schema_name == schema_name
+        ):
             self._experiment = None
             self._run = None
 
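Note that the XOR check raises whenever exactly one of the pair is given, in either direction, which is slightly broader than the `Raises:` docstring suggests. Illustrative calls (instance and names hypothetical)::

    exp.delete_experiment("OLD_EXP")  # OK: uses current database and schema
    exp.delete_experiment("OLD_EXP", database_name="ANALYTICS_DB", schema_name="RESEARCH")  # OK: both given
    exp.delete_experiment("OLD_EXP", database_name="ANALYTICS_DB")  # ValueError: schema_name missing
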
@@ -283,16 +357,26 @@ class ExperimentTracking:
         Args:
             metrics: Dictionary containing metric keys and float values.
             step: The step of the metrics. Defaults to 0.
+
+        Raises:
+            snowpark.exceptions.SnowparkSQLException: If logging metrics fails due to Snowflake SQL errors,
+                except for run metadata size limit errors which will issue a warning instead of raising.
         """
         run = self._get_or_start_run()
         metrics_list = []
         for key, value in metrics.items():
             metrics_list.append(entities.Metric(key, value, step))
-        self._sql_client.modify_run_add_metrics(
-            experiment_name=run.experiment_name,
-            run_name=run.name,
-            metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
-        )
+        try:
+            self._sql_client.modify_run_add_metrics(
+                experiment_name=run.experiment_name,
+                run_name=run.name,
+                metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
+            )
+        except snowpark.exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == 400003:  # EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED
+                run._warn_about_run_metadata_size(e.message)
+            else:
+                raise
 
     def log_param(
         self,
@@ -318,16 +402,26 @@ class ExperimentTracking:
         Args:
             params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
                 to string.
+
+        Raises:
+            snowpark.exceptions.SnowparkSQLException: If logging parameters fails due to Snowflake SQL errors,
+                except for run metadata size limit errors which will issue a warning instead of raising.
         """
         run = self._get_or_start_run()
         params_list = []
         for key, value in params.items():
             params_list.append(entities.Param(key, str(value)))
-        self._sql_client.modify_run_add_params(
-            experiment_name=run.experiment_name,
-            run_name=run.name,
-            params=json.dumps([param.to_dict() for param in params_list]),
-        )
+        try:
+            self._sql_client.modify_run_add_params(
+                experiment_name=run.experiment_name,
+                run_name=run.name,
+                params=json.dumps([param.to_dict() for param in params_list]),
+            )
+        except snowpark.exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == 400003:  # EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED
+                run._warn_about_run_metadata_size(e.message)
+            else:
+                raise
 
     def log_artifact(
         self,
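
Both `log_metrics` and `log_params` now share this pattern: an oversized run payload degrades to a warning instead of an exception, while every other SQL error still propagates. From the caller's side (metric and parameter names illustrative)::

    exp.log_params({"lr": 3e-4, "optimizer": "adamw"})
    exp.log_metrics({"loss": 0.042, "accuracy": 0.97}, step=10)
    # If the run metadata exceeds the backend limit (SQL error code 400003),
    # the call warns and returns; any other SnowparkSQLException is re-raised.
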
@@ -431,6 +525,22 @@ class ExperimentTracking:
                 return sql_identifier.SqlIdentifier(run_name)
         raise RuntimeError("Random run name generation failed.")
 
+    def _update_database_and_schema(
+        self, database_name: sql_identifier.SqlIdentifier, schema_name: sql_identifier.SqlIdentifier
+    ) -> None:
+        self._database_name = database_name
+        self._schema_name = schema_name
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session=self._session,
+            database_name=database_name,
+            schema_name=schema_name,
+        )
+        self._registry = registry.Registry(
+            session=self._session,
+            database_name=database_name,
+            schema_name=schema_name,
+        )
+
     def _print_urls(
         self,
         experiment_name: sql_identifier.SqlIdentifier,
snowflake/ml/experiment/utils.py

@@ -1,3 +1,4 @@
+import numbers
 from typing import Any, Union
 
 
@@ -12,3 +13,8 @@ def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str
         else:
             flat_params[new_prefix] = value
     return flat_params
+
+
+def is_integer(value: Any) -> bool:
+    """Check if the given value is an integer, excluding booleans."""
+    return isinstance(value, numbers.Integral) and not isinstance(value, bool)
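
The `numbers.Integral` check also accepts NumPy integer scalars, which register with the `numbers` ABCs, while the explicit `bool` exclusion is needed because `bool` subclasses `int`. A quick check (import path per the file listing above)::

    import numpy as np

    from snowflake.ml.experiment.utils import is_integer

    assert is_integer(3)
    assert is_integer(np.int64(3))  # NumPy ints register as numbers.Integral
    assert not is_integer(True)     # bool subclasses int, explicitly excluded
    assert not is_integer(3.0)
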
snowflake/ml/feature_store/access_manager.py

@@ -202,6 +202,7 @@ def _configure_role_hierarchy(
     session.sql(f"GRANT ROLE {producer_role} TO ROLE {session.get_current_role()}").collect()
 
     if consumer_role is not None:
+        # Create CONSUMER and grant it to PRODUCER to build hierarchy
         consumer_role = SqlIdentifier(consumer_role)
         session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
         session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()
snowflake/ml/feature_store/feature_store.py

@@ -1200,7 +1200,7 @@ class FeatureStore:
                 {self._config.database}.INFORMATION_SCHEMA.DYNAMIC_TABLE_REFRESH_HISTORY (RESULT_LIMIT => 10000)
             )
             WHERE NAME = '{fv_resolved_name}'
-            AND SCHEMA_NAME = '{self._config.schema}'
+            AND SCHEMA_NAME = '{self._config.schema.resolved()}'
             """
         )
 
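This one-line fix matters for case-sensitive (quoted) schema names: interpolating the `SqlIdentifier` directly renders the quoted form, while INFORMATION_SCHEMA stores the resolved, unquoted name, so the old filter could silently match nothing. A sketch of the distinction, assuming the `resolved()` semantics implied by this change::

    schema = sql_identifier.SqlIdentifier('"my_schema"')
    str(schema)        # '"my_schema"' -- quoted form, never equals SCHEMA_NAME
    schema.resolved()  # 'my_schema'   -- the name as INFORMATION_SCHEMA stores it
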
snowflake/ml/feature_store/feature_view.py

@@ -218,38 +218,48 @@ class FeatureView(lineage_node.LineageNode):
         """
         Create a FeatureView instance.
 
+        # noqa: DAR101
+
         Args:
-            name: name of the FeatureView. NOTE: following Snowflake identifier rule
-            entities: entities that the FeatureView is associated with.
-            feature_df: Snowpark DataFrame containing data source and all feature feature_df logics.
-                Final projection of the DataFrame should contain feature names, join keys and timestamp(if applicable).
+            name: The name of the FeatureView. This must follow Snowflake identifier rules.
+            entities: The entities that the FeatureView is associated with.
+            feature_df: The Snowpark DataFrame containing data source and all feature feature_df logic.
+                The final projection of the DataFrame should contain feature names, join keys and timestamp if
+                applicable.
             timestamp_col: name of the timestamp column for point-in-time lookup when consuming the
                 feature values.
-            refresh_freq: Time unit defining how often the new feature data should be generated.
-                Valid args are { <num> { seconds | minutes | hours | days } | DOWNSTREAM | <cron expr> <time zone>}.
-                NOTE: Currently minimum refresh frequency is 1 minute.
-                NOTE: If refresh_freq is in cron expression format, there must be a valid time zone as well.
-                    E.g. * * * * * UTC
-                NOTE: If refresh_freq is not provided, then FeatureView will be registered as View on Snowflake backend
-                    and there won't be extra storage cost.
-            desc: description of the FeatureView.
-            warehouse: warehouse to refresh feature view. Not needed for static feature view (refresh_freq is None).
-                For managed feature view, this warehouse will overwrite the default warehouse of Feature Store if it is
-                specified, otherwise the default warehouse will be used.
+            refresh_freq: Time unit defining how often the new feature data should be generated, in the format
+                ``{ <num> { seconds | minutes | hours | days } | DOWNSTREAM | <cron expr> <time zone>}``.
+
+                The minimum refresh frequency is 1 minute.
+
+                When using a ``cron`` format, you must provide a time zone.
+
+                When you don't provide a refresh value, the ``FeatureView`` is registered as a ``View`` on the Snowflake
+                backend. There are no extra storage costs incurred for this view.
+            desc: Description of the FeatureView.
+            warehouse: The warehouse used to refresh this feature view. Not needed when ``refresh_freq`` is ``None``.
+                This warehouse will overwrite the default warehouse of Feature Store if specified, otherwise the default
+                warehouse will be used.
             initialize: Specifies the behavior of the initial refresh of feature view. This property cannot be altered
                 after you register the feature view. It supports ON_CREATE (default) or ON_SCHEDULE. ON_CREATE refreshes
                 the feature view synchronously at creation. ON_SCHEDULE refreshes the feature view at the next scheduled
                 refresh. It is only effective when refresh_freq is not None.
             refresh_mode: The refresh mode of managed feature view. The value can be 'AUTO', 'FULL' or 'INCREMENTAL'.
-                For managed feature view, the default value is 'AUTO'. For static feature view it has no effect.
-                Check https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table for for details.
-            cluster_by: Columns to cluster the feature view by.
-                - Defaults to the join keys from entities.
-                - If `timestamp_col` is provided, it is added to the default clustering keys.
-            online_config: Optional configuration for online storage. If provided with enable=True,
-                online storage will be enabled. Defaults to None (no online storage).
-                NOTE: this feature is currently in Public Preview.
-            _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
+                For managed feature view, the default value is 'AUTO'. For static feature view it has no effect. For
+                more information, see
+                `CREATE DYNAMIC TABLE <https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table>`__.
+            cluster_by: Columns to cluster the feature view by. If ``timestamp_col`` is provided, it is added to the
+                default clustering keys. Default is to use the join keys from entities in the view.
+            online_config: Configuration for online storage. If provided with ``enable=True``,
+                online storage will be enabled. Defaults to ``None`` (no online storage).
+
+                .. note::
+                    This feature is currently in preview.
+            _kwargs: Reserved kwargs for system generated args.
+
+                .. caution::
+                    Use of additional keywords is prohibited.
 
         Example::
 
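The rewritten docstring is easier to act on; a construction sketch matching its rules (the entity, DataFrame, and names are illustrative, and the import path follows the public `snowflake.ml.feature_store` package)::

    from snowflake.ml.feature_store import FeatureView

    fv = FeatureView(
        name="USER_FEATURES",
        entities=[user_entity],        # an Entity defined elsewhere
        feature_df=features_df,        # Snowpark DataFrame: join keys + features
        timestamp_col="EVENT_TS",
        refresh_freq="0 * * * * UTC",  # cron form requires a time zone
        refresh_mode="INCREMENTAL",
        desc="Hourly-refreshed user features",
    )
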
snowflake/ml/jobs/_interop/protocols.py

@@ -266,6 +266,9 @@ class PandasDataFrameProtocol(SerializationProtocol):
 
         # TODO: Support partitioned writes for large datasets
         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+        # stage mount v2 has a bug where it creates an empty file when creating a new file
+        with data_utils.open_stream(result_path, "wb", session=session) as stream:
+            stream.write(b"")  # Dummy write to create the file
         with data_utils.open_stream(result_path, "wb", session=session) as stream:
             obj.to_parquet(stream)
 
snowflake/ml/jobs/_utils/feature_flags.py

@@ -31,6 +31,7 @@ def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
 class FeatureFlags(Enum):
     USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
     ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS"
+    ENABLE_STAGE_MOUNT_V2 = "MLRS_ENABLE_STAGE_MOUNT_V2"
 
     def is_enabled(self, default: bool = False) -> bool:
         """Check if the feature flag is enabled.