snowflake-ml-python 1.5.3__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +224 -21
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_summarize.py +0 -1
  6. snowflake/cortex/_translate.py +0 -1
  7. snowflake/cortex/_util.py +12 -85
  8. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  9. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  10. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  11. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  12. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  13. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  14. snowflake/ml/_internal/telemetry.py +26 -0
  15. snowflake/ml/_internal/utils/identifier.py +14 -0
  16. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  17. snowflake/ml/dataset/dataset.py +39 -20
  18. snowflake/ml/feature_store/feature_store.py +440 -243
  19. snowflake/ml/feature_store/feature_view.py +61 -9
  20. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  21. snowflake/ml/fileset/fileset.py +2 -2
  22. snowflake/ml/fileset/snowfs.py +4 -15
  23. snowflake/ml/fileset/stage_fs.py +6 -8
  24. snowflake/ml/lineage/__init__.py +3 -0
  25. snowflake/ml/lineage/lineage_node.py +139 -0
  26. snowflake/ml/model/_client/model/model_impl.py +47 -14
  27. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  28. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +45 -2
  31. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  32. snowflake/ml/model/_model_composer/model_composer.py +5 -4
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  34. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  35. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  36. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -2
  37. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  38. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  39. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  41. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  42. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  43. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  45. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  46. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  53. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  54. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  56. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  57. snowflake/ml/model/_packager/model_packager.py +9 -4
  58. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  59. snowflake/ml/model/custom_model.py +22 -2
  60. snowflake/ml/model/type_hints.py +73 -4
  61. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -0
  62. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  63. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  64. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  65. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  66. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  67. snowflake/ml/modeling/cluster/birch.py +4 -2
  68. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  69. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  70. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  71. snowflake/ml/modeling/cluster/k_means.py +4 -2
  72. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  74. snowflake/ml/modeling/cluster/optics.py +4 -2
  75. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  76. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  77. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  78. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  79. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  80. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  81. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  82. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  83. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  84. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  85. snowflake/ml/modeling/covariance/oas.py +4 -2
  86. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  87. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  88. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  89. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  90. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  91. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  92. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  93. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  94. snowflake/ml/modeling/decomposition/pca.py +4 -2
  95. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  96. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  97. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  98. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  99. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  100. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  101. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  102. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  103. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  104. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  105. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  106. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  107. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  108. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  109. snowflake/ml/modeling/manifold/isomap.py +4 -2
  110. snowflake/ml/modeling/manifold/mds.py +4 -2
  111. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  112. snowflake/ml/modeling/manifold/tsne.py +4 -2
  113. snowflake/ml/modeling/metrics/ranking.py +3 -0
  114. snowflake/ml/modeling/metrics/regression.py +3 -0
  115. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  116. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  117. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  118. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  119. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  120. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  121. snowflake/ml/modeling/pipeline/pipeline.py +1 -0
  122. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  123. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  124. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  125. snowflake/ml/registry/_manager/model_manager.py +16 -3
  126. snowflake/ml/version.py +1 -1
  127. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +35 -7
  128. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/RECORD +131 -127
  129. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  130. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  131. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ from typing import (
10
10
  Dict,
11
11
  Iterable,
12
12
  List,
13
+ Mapping,
13
14
  Optional,
14
15
  Tuple,
15
16
  TypeVar,
@@ -92,6 +93,31 @@ def get_statement_params(
92
93
  )
93
94
 
94
95
 
96
def add_statement_params_custom_tags(
    statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
) -> Dict[str, Any]:
    """
    Merge custom_tags into the custom tags carried by existing statement_params.

    Keys in custom_tags overwrite same-named keys already present in the existing custom tags.
    If statement_params is None or empty, an empty dictionary is returned, as there is nothing
    to attach the tags to and the information cannot be effectively tracked.

    Neither statement_params nor the custom tags dictionary nested inside it is modified;
    a new dictionary is always returned.

    Args:
        statement_params: Existing statement_params dictionary.
        custom_tags: Key/value pairs to add as custom tags.

    Returns:
        New statement_params dictionary with all original keys and an updated custom_tags field.
    """
    if not statement_params:
        return {}

    custom_tags_key = TelemetryField.KEY_CUSTOM_TAGS.value

    # Build fresh dictionaries rather than pop()/update() on the inputs, so the caller's
    # statement_params still carries its own (unchanged) custom tags after this call.
    # A missing or None entry is treated as "no existing tags".
    merged_custom_tags: Dict[str, Any] = {
        **(statement_params.get(custom_tags_key) or {}),
        **custom_tags,
    }

    # NOTE: This can be done with | operator after upgrade from py3.8
    other_params = {key: value for key, value in statement_params.items() if key != custom_tags_key}
    return {
        **other_params,
        custom_tags_key: merged_custom_tags,
    }
119
+
120
+
95
121
  # TODO: we can merge this with get_statement_params after code clean up
96
122
  def get_statement_params_full_func_name(frame: Optional[types.FrameType], class_name: Optional[str] = None) -> str:
97
123
  """
@@ -165,6 +165,20 @@ def parse_schema_level_object_identifier(
165
165
  )
166
166
 
167
167
 
168
def is_fully_qualified_name(name: str) -> bool:
    """
    Tell whether a name has the fully qualified form '<db>.<schema>.<object_name>'.

    The name is parsed with parse_schema_level_object_identifier; it counts as fully
    qualified when the first three parsed components are all present (not None) and
    the fourth parsed component is empty.

    Args:
        name: The name to be checked.

    Returns:
        bool: True if the name is fully qualified, False otherwise.
    """
    parsed = parse_schema_level_object_identifier(name)

    # Any missing leading component means the name is only partially qualified.
    for position in range(3):
        if parsed[position] is None:
            return False

    # A non-empty fourth component means there is extra content beyond the three parts.
    return not parsed[3]
180
+
181
+
168
182
  def get_schema_level_object_identifier(
169
183
  db: Optional[str],
170
184
  schema: Optional[str],
@@ -1,22 +1,27 @@
1
1
  import logging
2
2
  import warnings
3
+ from typing import List, Optional
3
4
 
4
5
  from snowflake import snowpark
6
+ from snowflake.ml._internal.utils import sql_identifier
5
7
  from snowflake.snowpark import functions, types
6
8
 
7
9
 
8
- def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
10
+ def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
9
11
  """Cast columns in the dataframe to types that are compatible with tensor.
10
12
 
11
13
  It assists FileSet.make() in performing implicit data casting.
12
14
 
13
15
  Args:
14
16
  df: A snowpark dataframe.
17
+ ignore_columns: Columns to exclude from casting. These columns will be propagated unchanged.
15
18
 
16
19
  Returns:
17
20
  A snowpark dataframe whose data type has been casted.
18
21
  """
19
22
 
23
+ ignore_cols_set = {sql_identifier.SqlIdentifier(c).identifier() for c in ignore_columns} if ignore_columns else {}
24
+
20
25
  fields = df.schema.fields
21
26
  selected_cols = []
22
27
  for field in fields:
@@ -40,7 +45,9 @@ def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
40
45
  dest = field.datatype
41
46
  selected_cols.append(functions.cast(functions.col(src), dest).alias(src))
42
47
  else:
43
- if field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
48
+ if field.column_identifier.name in ignore_cols_set:
49
+ pass
50
+ elif field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
44
51
  logging.warning(
45
52
  "A Column with DATE or TIMESTAMP data type detected. "
46
53
  "It might not be able to get converted to tensors. "
@@ -90,7 +97,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
90
97
  " is being automatically converted to DoubleType in the Snowpark DataFrame. "
91
98
  "This automatic conversion may lead to potential precision loss and rounding errors. "
92
99
  "If you wish to prevent this conversion, you should manually perform "
93
- "the necessary data type conversion."
100
+ "the necessary data type conversion.",
101
+ UserWarning,
102
+ stacklevel=2,
94
103
  )
95
104
  else:
96
105
  # IntegerType default as NUMBER(38, 0), but
@@ -102,7 +111,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
102
111
  " is being automatically converted to LongType in the Snowpark DataFrame. "
103
112
  "This automatic conversion may lead to potential precision loss and rounding errors. "
104
113
  "If you wish to prevent this conversion, you should manually perform "
105
- "the necessary data type conversion."
114
+ "the necessary data type conversion.",
115
+ UserWarning,
116
+ stacklevel=2,
106
117
  )
107
118
  selected_cols.append(functions.cast(functions.col(src), dest_dtype).alias(src))
108
119
  # TODO: add more type handling or error message
@@ -19,6 +19,7 @@ from snowflake.ml._internal.utils import (
19
19
  snowpark_dataframe_utils,
20
20
  )
21
21
  from snowflake.ml.dataset import dataset_metadata, dataset_reader
22
+ from snowflake.ml.lineage import lineage_node
22
23
  from snowflake.snowpark import exceptions as snowpark_exceptions, functions
23
24
 
24
25
  _PROJECT = "Dataset"
@@ -125,7 +126,7 @@ class DatasetVersion:
125
126
  return f"{self.__class__.__name__}(dataset='{self._parent.fully_qualified_name}', version='{self.name}')"
126
127
 
127
128
 
128
- class Dataset:
129
+ class Dataset(lineage_node.LineageNode):
129
130
  """Represents a Snowflake Dataset which is organized into versions."""
130
131
 
131
132
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -138,18 +139,31 @@ class Dataset:
138
139
  selected_version: Optional[str] = None,
139
140
  ) -> None:
140
141
  """Initialize a lazily evaluated Dataset object"""
141
- self._session = session
142
142
  self._db = database
143
143
  self._schema = schema
144
144
  self._name = name
145
- self._fully_qualified_name = identifier.get_schema_level_object_identifier(database, schema, name)
145
+
146
+ super().__init__(
147
+ session,
148
+ identifier.get_schema_level_object_identifier(database, schema, name),
149
+ domain="dataset",
150
+ version=selected_version,
151
+ )
146
152
 
147
153
  self._version = DatasetVersion(self, selected_version) if selected_version else None
148
154
  self._reader: Optional[dataset_reader.DatasetReader] = None
149
155
 
156
+ def __repr__(self) -> str:
157
+ return (
158
+ f"{self.__class__.__name__}(\n"
159
+ f" name='{self._lineage_node_name}',\n"
160
+ f" version='{self._version._version if self._version else None}',\n"
161
+ f")"
162
+ )
163
+
150
164
  @property
151
165
  def fully_qualified_name(self) -> str:
152
- return self._fully_qualified_name
166
+ return self._lineage_node_name
153
167
 
154
168
  @property
155
169
  def selected_version(self) -> Optional[DatasetVersion]:
@@ -168,7 +182,7 @@ class Dataset:
168
182
  self._session,
169
183
  [
170
184
  data_source.DataSource(
171
- fully_qualified_name=self._fully_qualified_name,
185
+ fully_qualified_name=self._lineage_node_name,
172
186
  version=v.name,
173
187
  url=v.url(),
174
188
  exclude_cols=(v.label_cols + v.exclude_cols),
@@ -230,9 +244,8 @@ class Dataset:
230
244
  try:
231
245
  session.sql(query).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
232
246
  return Dataset(session, db, schema, ds_name)
233
- except snowpark_exceptions.SnowparkClientException as e:
234
- # Snowpark wraps the Python Connector error code in the head of the error message.
235
- if e.message.startswith(dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS):
247
+ except snowpark_exceptions.SnowparkSQLException as e:
248
+ if e.sql_error_code == dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS:
236
249
  raise snowml_exceptions.SnowflakeMLException(
237
250
  error_code=error_codes.OBJECT_ALREADY_EXISTS,
238
251
  original_exception=dataset_errors.DatasetExistError(
@@ -296,7 +309,7 @@ class Dataset:
296
309
  Raises:
297
310
  SnowflakeMLException: The Dataset no longer exists.
298
311
  SnowflakeMLException: The specified Dataset version already exists.
299
- snowpark_exceptions.SnowparkClientException: An error occurred during Dataset creation.
312
+ snowpark_exceptions.SnowparkSQLException: An error occurred during Dataset creation.
300
313
 
301
314
  Note: During the generation of stage files, data casting will occur. The casting rules are as follows::
302
315
  - Data casting:
@@ -321,7 +334,8 @@ class Dataset:
321
334
  - DateType(DATE): Not supported. A warning will be logged.
322
335
  - VariantType(VARIANT): Not supported. A warning will be logged.
323
336
  """
324
- casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe)
337
+ cast_ignore_cols = (exclude_cols or []) + (label_cols or [])
338
+ casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe, ignore_columns=cast_ignore_cols)
325
339
 
326
340
  if shuffle:
327
341
  casted_df = casted_df.order_by(functions.random())
@@ -367,19 +381,19 @@ class Dataset:
367
381
 
368
382
  return Dataset(self._session, self._db, self._schema, self._name, version)
369
383
 
370
- except snowpark_exceptions.SnowparkClientException as e:
371
- if e.message.startswith(dataset_errors.ERRNO_DATASET_NOT_EXIST):
384
+ except snowpark_exceptions.SnowparkSQLException as e:
385
+ if e.sql_error_code == dataset_errors.ERRNO_DATASET_NOT_EXIST:
372
386
  raise snowml_exceptions.SnowflakeMLException(
373
387
  error_code=error_codes.NOT_FOUND,
374
388
  original_exception=dataset_errors.DatasetNotExistError(
375
389
  dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name)
376
390
  ),
377
391
  ) from e
378
- elif (
379
- e.message.startswith(dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS)
380
- or e.message.startswith(dataset_errors.ERRNO_VERSION_ALREADY_EXISTS)
381
- or e.message.startswith(dataset_errors.ERRNO_FILES_ALREADY_EXISTING)
382
- ):
392
+ elif e.sql_error_code in {
393
+ dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS,
394
+ dataset_errors.ERRNO_VERSION_ALREADY_EXISTS,
395
+ dataset_errors.ERRNO_FILES_ALREADY_EXISTING,
396
+ }:
383
397
  raise snowml_exceptions.SnowflakeMLException(
384
398
  error_code=error_codes.OBJECT_ALREADY_EXISTS,
385
399
  original_exception=dataset_errors.DatasetExistError(
@@ -435,9 +449,8 @@ class Dataset:
435
449
  .has_column(_DATASET_VERSION_NAME_COL, allow_empty=True)
436
450
  .validate()
437
451
  )
438
- except snowpark_exceptions.SnowparkClientException as e:
439
- # Snowpark wraps the Python Connector error code in the head of the error message.
440
- if e.message.startswith(dataset_errors.ERRNO_OBJECT_NOT_EXIST):
452
+ except snowpark_exceptions.SnowparkSQLException as e:
453
+ if e.sql_error_code == dataset_errors.ERRNO_OBJECT_NOT_EXIST:
441
454
  raise snowml_exceptions.SnowflakeMLException(
442
455
  error_code=error_codes.NOT_FOUND,
443
456
  original_exception=dataset_errors.DatasetNotExistError(
@@ -459,6 +472,12 @@ class Dataset:
459
472
  ),
460
473
  )
461
474
 
475
+ @staticmethod
476
+ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "Dataset":
477
+ return Dataset.load(session, name).select_version(version)
478
+
479
+
480
+ lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
462
481
 
463
482
  # Utility methods
464
483