snowflake-ml-python 1.5.3__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (131)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +224 -21
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_summarize.py +0 -1
  6. snowflake/cortex/_translate.py +0 -1
  7. snowflake/cortex/_util.py +12 -85
  8. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  9. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  10. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  11. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  12. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  13. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  14. snowflake/ml/_internal/telemetry.py +26 -0
  15. snowflake/ml/_internal/utils/identifier.py +14 -0
  16. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  17. snowflake/ml/dataset/dataset.py +39 -20
  18. snowflake/ml/feature_store/feature_store.py +440 -243
  19. snowflake/ml/feature_store/feature_view.py +61 -9
  20. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  21. snowflake/ml/fileset/fileset.py +2 -2
  22. snowflake/ml/fileset/snowfs.py +4 -15
  23. snowflake/ml/fileset/stage_fs.py +6 -8
  24. snowflake/ml/lineage/__init__.py +3 -0
  25. snowflake/ml/lineage/lineage_node.py +139 -0
  26. snowflake/ml/model/_client/model/model_impl.py +47 -14
  27. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  28. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +45 -2
  31. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  32. snowflake/ml/model/_model_composer/model_composer.py +5 -4
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  34. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  35. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  36. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -2
  37. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  38. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  39. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  41. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  42. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  43. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  45. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  46. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  53. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  54. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  56. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  57. snowflake/ml/model/_packager/model_packager.py +9 -4
  58. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  59. snowflake/ml/model/custom_model.py +22 -2
  60. snowflake/ml/model/type_hints.py +73 -4
  61. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -0
  62. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  63. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  64. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  65. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  66. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  67. snowflake/ml/modeling/cluster/birch.py +4 -2
  68. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  69. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  70. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  71. snowflake/ml/modeling/cluster/k_means.py +4 -2
  72. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  74. snowflake/ml/modeling/cluster/optics.py +4 -2
  75. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  76. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  77. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  78. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  79. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  80. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  81. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  82. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  83. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  84. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  85. snowflake/ml/modeling/covariance/oas.py +4 -2
  86. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  87. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  88. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  89. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  90. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  91. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  92. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  93. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  94. snowflake/ml/modeling/decomposition/pca.py +4 -2
  95. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  96. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  97. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  98. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  99. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  100. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  101. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  102. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  103. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  104. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  105. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  106. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  107. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  108. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  109. snowflake/ml/modeling/manifold/isomap.py +4 -2
  110. snowflake/ml/modeling/manifold/mds.py +4 -2
  111. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  112. snowflake/ml/modeling/manifold/tsne.py +4 -2
  113. snowflake/ml/modeling/metrics/ranking.py +3 -0
  114. snowflake/ml/modeling/metrics/regression.py +3 -0
  115. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  116. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  117. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  118. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  119. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  120. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  121. snowflake/ml/modeling/pipeline/pipeline.py +1 -0
  122. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  123. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  124. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  125. snowflake/ml/registry/_manager/model_manager.py +16 -3
  126. snowflake/ml/version.py +1 -1
  127. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +35 -7
  128. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/RECORD +131 -127
  129. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  130. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  131. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -2,17 +2,25 @@ from __future__ import annotations
2
2
 
3
3
  import json
4
4
  import re
5
+ import warnings
5
6
  from collections import OrderedDict
6
7
  from dataclasses import asdict, dataclass
7
8
  from enum import Enum
8
9
  from typing import Any, Dict, List, Optional
9
10
 
11
+ from snowflake.ml._internal.exceptions import (
12
+ error_codes,
13
+ exceptions as snowml_exceptions,
14
+ )
15
+ from snowflake.ml._internal.utils import identifier
10
16
  from snowflake.ml._internal.utils.identifier import concat_names
11
17
  from snowflake.ml._internal.utils.sql_identifier import (
12
18
  SqlIdentifier,
13
19
  to_sql_identifiers,
14
20
  )
21
+ from snowflake.ml.feature_store import feature_store
15
22
  from snowflake.ml.feature_store.entity import Entity
23
+ from snowflake.ml.lineage import lineage_node
16
24
  from snowflake.snowpark import DataFrame, Session
17
25
  from snowflake.snowpark.types import (
18
26
  DateType,
@@ -67,6 +75,7 @@ class FeatureViewVersion(str):
67
75
 
68
76
 
69
77
  class FeatureViewStatus(Enum):
78
+ MASKED = "MASKED" # for shared feature views where scheduling state is not available
70
79
  DRAFT = "DRAFT"
71
80
  STATIC = "STATIC"
72
81
  RUNNING = "RUNNING" # This can be deprecated after BCR 2024_02 gets fully deployed
@@ -107,7 +116,7 @@ class FeatureViewSlice:
107
116
  return cls(**json_dict)
108
117
 
109
118
 
110
- class FeatureView:
119
+ class FeatureView(lineage_node.LineageNode):
111
120
  """
112
121
  A FeatureView instance encapsulates a logical group of features.
113
122
  """
@@ -243,6 +252,16 @@ class FeatureView:
243
252
  def desc(self) -> str:
244
253
  return self._desc
245
254
 
255
+ @desc.setter
256
+ def desc(self, new_value: str) -> None:
257
+ warnings.warn(
258
+ "You must call register_feature_view() to make it effective. "
259
+ "Or use update_feature_view(desc=<new_value>).",
260
+ stacklevel=2,
261
+ category=UserWarning,
262
+ )
263
+ self._desc = new_value
264
+
246
265
  @property
247
266
  def query(self) -> str:
248
267
  return self._query
@@ -269,10 +288,12 @@ class FeatureView:
269
288
 
270
289
  @refresh_freq.setter
271
290
  def refresh_freq(self, new_value: str) -> None:
272
- if self.status == FeatureViewStatus.DRAFT or self.status == FeatureViewStatus.STATIC:
273
- raise RuntimeError(
274
- f"Feature view {self.name}/{self.version} must be registered and non-static to update refresh_freq."
275
- )
291
+ warnings.warn(
292
+ "You must call register_feature_view() to make it effective. "
293
+ "Or use update_feature_view(refresh_freq=<new_value>).",
294
+ stacklevel=2,
295
+ category=UserWarning,
296
+ )
276
297
  self._refresh_freq = new_value
277
298
 
278
299
  @property
@@ -289,10 +310,12 @@ class FeatureView:
289
310
 
290
311
  @warehouse.setter
291
312
  def warehouse(self, new_value: str) -> None:
292
- if self.status == FeatureViewStatus.DRAFT or self.status == FeatureViewStatus.STATIC:
293
- raise RuntimeError(
294
- f"Feature view {self.name}/{self.version} must be registered and non-static to update warehouse."
295
- )
313
+ warnings.warn(
314
+ "You must call register_feature_view() to make it effective. "
315
+ "Or use update_feature_view(warehouse=<new_value>).",
316
+ stacklevel=2,
317
+ category=UserWarning,
318
+ )
296
319
  self._warehouse = SqlIdentifier(new_value)
297
320
 
298
321
  @property
@@ -406,6 +429,11 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
406
429
  feature_desc_dict[k.identifier()] = v
407
430
  fv_dict["_feature_desc"] = feature_desc_dict
408
431
 
432
+ lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
433
+
434
+ for key in lineage_node_keys:
435
+ fv_dict.pop(key)
436
+
409
437
  return fv_dict
410
438
 
411
439
  def to_df(self, session: Session) -> DataFrame:
@@ -449,6 +477,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
449
477
  refresh_mode_reason=json_dict["_refresh_mode_reason"],
450
478
  owner=json_dict["_owner"],
451
479
  infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)),
480
+ session=session,
452
481
  )
453
482
 
454
483
  @staticmethod
@@ -463,6 +492,21 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
463
492
  )
464
493
  )
465
494
 
495
+ @staticmethod
496
+ def _load_from_lineage_node(session: Session, name: str, version: str) -> FeatureView:
497
+ db_name, feature_store_name, feature_view_name, _ = identifier.parse_schema_level_object_identifier(name)
498
+
499
+ session_warehouse = session.get_current_warehouse()
500
+
501
+ if not session_warehouse:
502
+ raise snowml_exceptions.SnowflakeMLException(
503
+ error_code=error_codes.NOT_FOUND,
504
+ original_exception=ValueError("No active warehouse selected in the current session"),
505
+ )
506
+
507
+ fs = feature_store.FeatureStore(session, db_name, feature_store_name, session_warehouse)
508
+ return fs.get_feature_view(feature_view_name, version) # type: ignore[no-any-return]
509
+
466
510
  @staticmethod
467
511
  def _construct_feature_view(
468
512
  name: str,
@@ -481,6 +525,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
481
525
  refresh_mode_reason: Optional[str],
482
526
  owner: Optional[str],
483
527
  infer_schema_df: Optional[DataFrame],
528
+ session: Session,
484
529
  ) -> FeatureView:
485
530
  fv = FeatureView(
486
531
  name=name,
@@ -500,4 +545,11 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
500
545
  fv._refresh_mode_reason = refresh_mode_reason
501
546
  fv._owner = owner
502
547
  fv.attach_feature_desc(feature_descs)
548
+
549
+ lineage_node.LineageNode.__init__(
550
+ fv, session=session, name=f"{fv.database}.{fv._schema}.{name}", domain="feature_view", version=version
551
+ )
503
552
  return fv
553
+
554
+
555
+ lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
@@ -11,11 +11,17 @@ from snowflake.ml._internal.exceptions import (
11
11
  fileset_errors,
12
12
  )
13
13
  from snowflake.ml._internal.utils import identifier
14
+ from snowflake.ml.fileset import stage_fs
14
15
  from snowflake.snowpark import exceptions as snowpark_exceptions
15
16
 
16
- from . import stage_fs
17
-
18
- _SNOWURL_PATH_RE = re.compile(r"versions/(?P<version>[^/]+)(?:/+(?P<filepath>.*))?")
17
+ PROTOCOL_NAME = "snow"
18
+ _SNOWURL_ENTITY_PATTERN = (
19
+ f"(?:{PROTOCOL_NAME}://)?"
20
+ r"(?<!@)(?P<domain>\w+)/"
21
+ rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/"
22
+ )
23
+ _SNOWURL_VERSION_PATTERN = r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/+(?P<relpath>.*))?)?)"
24
+ _SNOWURL_PATH_RE = re.compile(f"(?:{_SNOWURL_ENTITY_PATTERN})?" + _SNOWURL_VERSION_PATTERN)
19
25
 
20
26
 
21
27
  class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
@@ -76,8 +82,8 @@ class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
76
82
  versions_dict = defaultdict(list)
77
83
  for file in files:
78
84
  match = _SNOWURL_PATH_RE.fullmatch(file)
79
- assert match is not None and match.group("filepath") is not None
80
- versions_dict[match.group("version")].append(match.group("filepath"))
85
+ assert match is not None and match.group("relpath") is not None
86
+ versions_dict[match.group("version")].append(match.group("relpath"))
81
87
  try:
82
88
  async_jobs: List[snowpark.AsyncJob] = []
83
89
  for version, version_files in versions_dict.items():
@@ -98,10 +104,8 @@ class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
98
104
  (r["NAME"], r["URL"]) for job in async_jobs for r in stage_fs._resolve_async_job(job)
99
105
  ]
100
106
  return presigned_urls
101
- except snowpark_exceptions.SnowparkClientException as e:
102
- if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST) or e.message.startswith(
103
- fileset_errors.ERRNO_STAGE_NOT_EXIST
104
- ):
107
+ except snowpark_exceptions.SnowparkSQLException as e:
108
+ if e.sql_error_code in {fileset_errors.ERRNO_DOMAIN_NOT_EXIST, fileset_errors.ERRNO_STAGE_NOT_EXIST}:
105
109
  raise snowml_exceptions.SnowflakeMLException(
106
110
  error_code=error_codes.SNOWML_NOT_FOUND,
107
111
  original_exception=fileset_errors.StageNotFoundError(
@@ -118,7 +122,7 @@ class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
118
122
  def _parent(cls, path: str) -> str:
119
123
  """Get parent of specified path up to minimally valid root path.
120
124
 
121
- For SnowURL, the minimum valid path is snow://<domain>/<entity>/versions/<version>
125
+ For SnowURL, the minimum valid relative path is versions/<version>
122
126
 
123
127
  Args:
124
128
  path: File or directory path
@@ -128,22 +132,22 @@ class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
128
132
 
129
133
  Examples:
130
134
  ----
131
- >>> fs._parent("snow://dataset/my_ds/versions/my_version/file.ext")
132
- "snow://dataset/my_ds/versions/my_version/"
133
- >>> fs._parent("snow://dataset/my_ds/versions/my_version/subdir/file.ext")
134
- "snow://dataset/my_ds/versions/my_version/subdir/"
135
- >>> fs._parent("snow://dataset/my_ds/versions/my_version/")
136
- "snow://dataset/my_ds/versions/my_version/"
137
- >>> fs._parent("snow://dataset/my_ds/versions/my_version")
138
- "snow://dataset/my_ds/versions/my_version"
135
+ >>> fs._parent("versions/my_version/file.ext")
136
+ "versions/my_version"
137
+ >>> fs._parent("versions/my_version/subdir/file.ext")
138
+ "versions/my_version/subdir"
139
+ >>> fs._parent("versions/my_version/")
140
+ "versions/my_version"
141
+ >>> fs._parent("versions/my_version")
142
+ "versions/my_version"
139
143
  """
140
144
  path_match = _SNOWURL_PATH_RE.fullmatch(path)
141
145
  if not path_match:
142
146
  return super()._parent(path) # type: ignore[no-any-return]
143
- filepath: str = path_match.group("filepath") or ""
144
- root: str = path[: path_match.start("filepath")] if filepath else path
147
+ filepath: str = path_match.group("relpath") or ""
148
+ root: str = path[: path_match.start("relpath")] if filepath else path
145
149
  if "/" in filepath:
146
150
  parent = filepath.rsplit("/", 1)[0]
147
151
  return root + parent
148
152
  else:
149
- return root
153
+ return root.rstrip("/")
@@ -256,9 +256,9 @@ class FileSet:
256
256
  api_calls=[snowpark.DataFrameWriter.copy_into_location],
257
257
  ),
258
258
  )
259
- except snowpark_exceptions.SnowparkClientException as e:
259
+ except snowpark_exceptions.SnowparkSQLException as e:
260
260
  # Snowpark wraps the Python Connector error code in the head of the error message.
261
- if e.message.startswith(fileset_errors.ERRNO_FILE_EXIST_IN_STAGE):
261
+ if e.sql_error_code == fileset_errors.ERRNO_FILE_EXIST_IN_STAGE:
262
262
  raise fileset_errors.FileSetExistError(fileset_error_messages.FILESET_ALREADY_EXISTS.format(name))
263
263
  else:
264
264
  raise fileset_errors.FileSetError(str(e))
@@ -14,18 +14,10 @@ from snowflake.ml._internal.exceptions import (
14
14
  from snowflake.ml._internal.utils import identifier
15
15
  from snowflake.ml.fileset import embedded_stage_fs, sfcfs
16
16
 
17
- PROTOCOL_NAME = "snow"
18
-
19
17
  _SFFileEntityPath = collections.namedtuple(
20
18
  "_SFFileEntityPath", ["domain", "name", "filepath", "version", "relative_path"]
21
19
  )
22
- _PROJECT = "FileSet"
23
- _SNOWURL_PATTERN = re.compile(
24
- f"({PROTOCOL_NAME}://)?"
25
- r"(?<!@)(?P<domain>\w+)/"
26
- rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/"
27
- r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)"
28
- )
20
+ _SNOWURL_PATTERN = re.compile(embedded_stage_fs._SNOWURL_ENTITY_PATTERN + embedded_stage_fs._SNOWURL_VERSION_PATTERN)
29
21
 
30
22
 
31
23
  class SnowFileSystem(sfcfs.SFFileSystem):
@@ -38,7 +30,7 @@ class SnowFileSystem(sfcfs.SFFileSystem):
38
30
  See `sfcfs.SFFileSystem` documentation for example usage patterns.
39
31
  """
40
32
 
41
- protocol = PROTOCOL_NAME
33
+ protocol = embedded_stage_fs.PROTOCOL_NAME
42
34
  _IS_BUGGED_VERSION = None
43
35
 
44
36
  def __init__(
@@ -75,10 +67,7 @@ class SnowFileSystem(sfcfs.SFFileSystem):
75
67
  """Convert the relative path in a stage to an absolute path starts with the location of the stage."""
76
68
  # Strip protocol from absolute path, since backend needs snow:// prefix to resolve correctly
77
69
  # but fsspec logic strips protocol when doing any searching and globbing
78
- stage_name = stage_fs.stage_name
79
- protocol = f"{PROTOCOL_NAME}://"
80
- if stage_name.startswith(protocol):
81
- stage_name = stage_name[len(protocol) :]
70
+ stage_name: str = self._strip_protocol(stage_fs.stage_name)
82
71
  abs_path = stage_name + "/" + path
83
72
  return abs_path
84
73
 
@@ -128,4 +117,4 @@ class SnowFileSystem(sfcfs.SFFileSystem):
128
117
  )
129
118
 
130
119
 
131
- fsspec.register_implementation(PROTOCOL_NAME, SnowFileSystem)
120
+ fsspec.register_implementation(SnowFileSystem.protocol, SnowFileSystem)
@@ -170,8 +170,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
170
170
  path = path.lstrip("/")
171
171
  async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
172
172
  objects: List[snowpark.Row] = _resolve_async_job(async_job)
173
- except snowpark_exceptions.SnowparkClientException as e:
174
- if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST):
173
+ except snowpark_exceptions.SnowparkSQLException as e:
174
+ if e.sql_error_code == fileset_errors.ERRNO_DOMAIN_NOT_EXIST:
175
175
  raise snowml_exceptions.SnowflakeMLException(
176
176
  error_code=error_codes.SNOWML_NOT_FOUND,
177
177
  original_exception=fileset_errors.StageNotFoundError(
@@ -387,10 +387,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
387
387
  api_calls=[snowpark.DataFrame.collect],
388
388
  ),
389
389
  )
390
- except snowpark_exceptions.SnowparkClientException as e:
391
- if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST) or e.message.startswith(
392
- fileset_errors.ERRNO_STAGE_NOT_EXIST
393
- ):
390
+ except snowpark_exceptions.SnowparkSQLException as e:
391
+ if e.sql_error_code in {fileset_errors.ERRNO_DOMAIN_NOT_EXIST, fileset_errors.ERRNO_STAGE_NOT_EXIST}:
394
392
  raise snowml_exceptions.SnowflakeMLException(
395
393
  error_code=error_codes.SNOWML_NOT_FOUND,
396
394
  original_exception=fileset_errors.StageNotFoundError(
@@ -406,9 +404,9 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
406
404
 
407
405
 
408
406
  def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code: int) -> bool:
409
- # Snowpark writes error code to message instead of populating e.error_code
407
+ # Snowpark writes error code to message instead of populating e.sql_error_code
410
408
  error_code_str = str(error_code)
411
- return ex.error_code == error_code_str or error_code_str in ex.message
409
+ return ex.sql_error_code == error_code_str or error_code_str in ex.message
412
410
 
413
411
 
414
412
  @snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
@@ -0,0 +1,3 @@
1
+ from .lineage_node import LineageNode
2
+
3
+ __all__ = ["LineageNode"]
@@ -0,0 +1,139 @@
1
+ import json
2
+ from datetime import datetime
3
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, Type, Union
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal.utils import identifier
8
+
9
+ if TYPE_CHECKING:
10
+ from snowflake.ml import dataset
11
+ from snowflake.ml.feature_store import feature_view
12
+ from snowflake.ml.model._client.model import model_version_impl
13
+
14
+ _PROJECT = "LINEAGE"
15
+ DOMAIN_LINEAGE_REGISTRY: Dict[str, Type["LineageNode"]] = {}
16
+
17
+
18
+ class LineageNode:
19
+ """
20
+ Represents a node in a lineage graph and serves as the base class for all machine learning objects.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ session: snowpark.Session,
26
+ name: str,
27
+ domain: Union[Literal["feature_view", "dataset", "model", "table", "view"]],
28
+ version: Optional[str] = None,
29
+ status: Optional[Literal["ACTIVE", "DELETED", "MASKED"]] = None,
30
+ created_on: Optional[datetime] = None,
31
+ ) -> None:
32
+ """
33
+ Initializes a LineageNode instance.
34
+
35
+ Args:
36
+ session : The Snowflake session object.
37
+ name : Fully qualified name of the lineage node, which is in the format '<db>.<schema>.<object_name>'.
38
+ domain : The domain of the lineage node.
39
+ version : The version of the lineage node, if applies.
40
+ status : The status of the lineage node. Possible values are:
41
+ - 'MASKED': The user does not have the privilege to view the node.
42
+ - 'DELETED': The node has been deleted.
43
+ - 'ACTIVE': The node is currently active.
44
+ created_on : The creation time of the lineage node.
45
+
46
+ Raises:
47
+ ValueError: If the name is not fully qualified.
48
+ """
49
+ if name and not identifier.is_fully_qualified_name(name):
50
+ raise ValueError("name should be fully qualifed.")
51
+
52
+ self._lineage_node_name = name
53
+ self._lineage_node_domain = domain
54
+ self._lineage_node_version = version
55
+ self._lineage_node_status = status
56
+ self._lineage_node_created_on = created_on
57
+ self._session = session
58
+
59
+ def __repr__(self) -> str:
60
+ return (
61
+ f"{self.__class__.__name__}(\n"
62
+ f" name='{self._lineage_node_name}',\n"
63
+ f" version='{self._lineage_node_version}',\n"
64
+ f" domain='{self._lineage_node_domain}',\n"
65
+ f" status='{self._lineage_node_status}',\n"
66
+ f" created_on='{self._lineage_node_created_on}'\n"
67
+ f")"
68
+ )
69
+
70
+ @staticmethod
71
+ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "LineageNode":
72
+ """
73
+ Loads the concrete object.
74
+
75
+ Args:
76
+ session : The Snowflake session object.
77
+ name : Fully qualified name of the object.
78
+ version : The version of object.
79
+
80
+ Raises:
81
+ NotImplementedError: If the derived class does not implement this method.
82
+ """
83
+ raise NotImplementedError()
84
+
85
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
86
+ @snowpark._internal.utils.private_preview(version="1.5.3")
87
+ def lineage(
88
+ self,
89
+ direction: Literal["upstream", "downstream"] = "downstream",
90
+ domain_filter: Optional[Set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
91
+ ) -> List[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
92
+ """
93
+ Retrieves the lineage nodes connected to this node.
94
+
95
+ Args:
96
+ direction : The direction to trace lineage. Defaults to "downstream".
97
+ domain_filter : Set of domains to filter nodes. Defaults to None.
98
+
99
+ Returns:
100
+ List[LineageNode]: A list of connected lineage nodes.
101
+ """
102
+ df = self._session.lineage.trace(
103
+ self._lineage_node_name,
104
+ self._lineage_node_domain.upper(),
105
+ object_version=self._lineage_node_version,
106
+ direction=direction,
107
+ distance=1,
108
+ )
109
+ if domain_filter is not None:
110
+ domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
111
+
112
+ lineage_nodes: List["LineageNode"] = []
113
+ for row in df.collect():
114
+ lineage_object = (
115
+ json.loads(row["TARGET_OBJECT"])
116
+ if direction.lower() == "downstream"
117
+ else json.loads(row["SOURCE_OBJECT"])
118
+ )
119
+ domain = lineage_object["domain"].lower()
120
+ if domain_filter is None or domain in domain_filter:
121
+ if domain in DOMAIN_LINEAGE_REGISTRY:
122
+ lineage_nodes.append(
123
+ DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node(
124
+ self._session, lineage_object["name"], lineage_object.get("version")
125
+ )
126
+ )
127
+ else:
128
+ lineage_nodes.append(
129
+ LineageNode(
130
+ name=lineage_object["name"],
131
+ version=lineage_object.get("version"),
132
+ domain=domain,
133
+ status=lineage_object["status"],
134
+ created_on=datetime.strptime(lineage_object["createdOn"], "%Y-%m-%dT%H:%M:%SZ"),
135
+ session=self._session,
136
+ )
137
+ )
138
+
139
+ return lineage_nodes
@@ -9,6 +9,10 @@ from snowflake.ml.model._client.ops import model_ops
9
9
 
10
10
  _TELEMETRY_PROJECT = "MLOps"
11
11
  _TELEMETRY_SUBPROJECT = "ModelManagement"
12
+ SYSTEM_VERSION_ALIAS_DEFAULT = "DEFAULT"
13
+ SYSTEM_VERSION_ALIAS_FIRST = "FIRST"
14
+ SYSTEM_VERSION_ALIAS_LAST = "LAST"
15
+ SYSTEM_VERSION_ALIASES = (SYSTEM_VERSION_ALIAS_DEFAULT, SYSTEM_VERSION_ALIAS_FIRST, SYSTEM_VERSION_ALIAS_LAST)
12
16
 
13
17
 
14
18
  class Model:
@@ -144,12 +148,28 @@ class Model:
144
148
  project=_TELEMETRY_PROJECT,
145
149
  subproject=_TELEMETRY_SUBPROJECT,
146
150
  )
147
- def version(self, version_name: str) -> model_version_impl.ModelVersion:
151
+ def first(self) -> model_version_impl.ModelVersion:
152
+ """The first version of the model."""
153
+ return self.version(SYSTEM_VERSION_ALIAS_FIRST)
154
+
155
+ @telemetry.send_api_usage_telemetry(
156
+ project=_TELEMETRY_PROJECT,
157
+ subproject=_TELEMETRY_SUBPROJECT,
158
+ )
159
+ def last(self) -> model_version_impl.ModelVersion:
160
+ """The latest version of the model."""
161
+ return self.version(SYSTEM_VERSION_ALIAS_LAST)
162
+
163
+ @telemetry.send_api_usage_telemetry(
164
+ project=_TELEMETRY_PROJECT,
165
+ subproject=_TELEMETRY_SUBPROJECT,
166
+ )
167
+ def version(self, version_or_alias: str) -> model_version_impl.ModelVersion:
148
168
  """
149
- Get a model version object given a version name in the model.
169
+ Get a model version object given a version name or version alias in the model.
150
170
 
151
171
  Args:
152
- version_name: The name of the version.
172
+ version_or_alias: The name of the version or alias to a version.
153
173
 
154
174
  Raises:
155
175
  ValueError: When the requested version does not exist.
@@ -161,23 +181,36 @@ class Model:
161
181
  project=_TELEMETRY_PROJECT,
162
182
  subproject=_TELEMETRY_SUBPROJECT,
163
183
  )
164
- version_id = sql_identifier.SqlIdentifier(version_name)
165
- if self._model_ops.validate_existence(
184
+
185
+ # check with system alias or with user defined alias
186
+ version_id = self._model_ops.get_version_by_alias(
166
187
  database_name=None,
167
188
  schema_name=None,
168
189
  model_name=self._model_name,
169
- version_name=version_id,
190
+ alias_name=sql_identifier.SqlIdentifier(version_or_alias),
170
191
  statement_params=statement_params,
171
- ):
172
- return model_version_impl.ModelVersion._ref(
173
- self._model_ops,
192
+ )
193
+
194
+ # version_id is still None implies version_or_alias is not an alias. So it must be a version name.
195
+ if version_id is None:
196
+ version_id = sql_identifier.SqlIdentifier(version_or_alias)
197
+ if not self._model_ops.validate_existence(
198
+ database_name=None,
199
+ schema_name=None,
174
200
  model_name=self._model_name,
175
201
  version_name=version_id,
176
- )
177
- else:
178
- raise ValueError(
179
- f"Unable to find version with name {version_id.identifier()} in model {self.fully_qualified_name}"
180
- )
202
+ statement_params=statement_params,
203
+ ):
204
+ raise ValueError(
205
+ f"Unable to find version or alias with name {version_id.identifier()} "
206
+ f"in model {self.fully_qualified_name}"
207
+ )
208
+
209
+ return model_version_impl.ModelVersion._ref(
210
+ self._model_ops,
211
+ model_name=self._model_name,
212
+ version_name=version_id,
213
+ )
181
214
 
182
215
  @telemetry.send_api_usage_telemetry(
183
216
  project=_TELEMETRY_PROJECT,