snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. snowflake/cortex/__init__.py +4 -1
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +281 -21
  4. snowflake/cortex/_extract_answer.py +0 -1
  5. snowflake/cortex/_sentiment.py +0 -1
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +12 -85
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  16. snowflake/ml/_internal/telemetry.py +38 -2
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  20. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  21. snowflake/ml/data/data_connector.py +133 -0
  22. snowflake/ml/data/data_ingestor.py +28 -0
  23. snowflake/ml/data/data_source.py +23 -0
  24. snowflake/ml/dataset/dataset.py +39 -32
  25. snowflake/ml/dataset/dataset_reader.py +18 -118
  26. snowflake/ml/feature_store/access_manager.py +7 -1
  27. snowflake/ml/feature_store/entity.py +19 -2
  28. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  29. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  30. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  31. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  32. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  34. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  35. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  36. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  37. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  38. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  39. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  40. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  43. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  44. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  45. snowflake/ml/feature_store/feature_store.py +987 -264
  46. snowflake/ml/feature_store/feature_view.py +228 -13
  47. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  48. snowflake/ml/fileset/fileset.py +2 -2
  49. snowflake/ml/fileset/snowfs.py +4 -15
  50. snowflake/ml/fileset/stage_fs.py +24 -18
  51. snowflake/ml/lineage/__init__.py +3 -0
  52. snowflake/ml/lineage/lineage_node.py +139 -0
  53. snowflake/ml/model/_client/model/model_impl.py +47 -14
  54. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  55. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  56. snowflake/ml/model/_client/sql/model.py +1 -0
  57. snowflake/ml/model/_client/sql/model_version.py +45 -2
  58. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  59. snowflake/ml/model/_model_composer/model_composer.py +15 -17
  60. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
  61. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  62. snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
  63. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  64. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
  65. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
  66. snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
  67. snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
  68. snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
  69. snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
  70. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  71. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  72. snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
  73. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  74. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  75. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  76. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  77. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  78. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  79. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  80. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  81. snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
  82. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  83. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  84. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  85. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  86. snowflake/ml/model/_packager/model_packager.py +9 -4
  87. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  88. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  89. snowflake/ml/model/custom_model.py +22 -2
  90. snowflake/ml/model/model_signature.py +4 -4
  91. snowflake/ml/model/type_hints.py +77 -4
  92. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
  93. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  94. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  95. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  96. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  97. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  98. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  99. snowflake/ml/modeling/cluster/birch.py +4 -2
  100. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  101. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  102. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  103. snowflake/ml/modeling/cluster/k_means.py +4 -2
  104. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  105. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  106. snowflake/ml/modeling/cluster/optics.py +4 -2
  107. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  108. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  109. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  110. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  111. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  112. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  113. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  114. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  115. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  116. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  117. snowflake/ml/modeling/covariance/oas.py +4 -2
  118. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  119. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  120. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  121. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  122. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  123. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  124. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  125. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  126. snowflake/ml/modeling/decomposition/pca.py +4 -2
  127. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  128. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  129. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  130. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  131. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  132. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  133. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  134. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  135. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  136. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  137. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  138. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  139. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  140. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  141. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  142. snowflake/ml/modeling/manifold/isomap.py +4 -2
  143. snowflake/ml/modeling/manifold/mds.py +4 -2
  144. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  145. snowflake/ml/modeling/manifold/tsne.py +4 -2
  146. snowflake/ml/modeling/metrics/ranking.py +3 -0
  147. snowflake/ml/modeling/metrics/regression.py +3 -0
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  150. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  151. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  152. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  153. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  154. snowflake/ml/modeling/pipeline/pipeline.py +5 -4
  155. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  156. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  157. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  158. snowflake/ml/registry/_manager/model_manager.py +16 -3
  159. snowflake/ml/registry/registry.py +100 -13
  160. snowflake/ml/version.py +1 -1
  161. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
  162. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
  163. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  164. snowflake/ml/_internal/lineage/data_source.py +0 -10
  165. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,139 @@
1
+ import json
2
+ from datetime import datetime
3
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, Type, Union
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal.utils import identifier
8
+
9
+ if TYPE_CHECKING:
10
+ from snowflake.ml import dataset
11
+ from snowflake.ml.feature_store import feature_view
12
+ from snowflake.ml.model._client.model import model_version_impl
13
+
14
+ _PROJECT = "LINEAGE"
15
+ DOMAIN_LINEAGE_REGISTRY: Dict[str, Type["LineageNode"]] = {}
16
+
17
+
18
+ class LineageNode:
19
+ """
20
+ Represents a node in a lineage graph and serves as the base class for all machine learning objects.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ session: snowpark.Session,
26
+ name: str,
27
+ domain: Union[Literal["feature_view", "dataset", "model", "table", "view"]],
28
+ version: Optional[str] = None,
29
+ status: Optional[Literal["ACTIVE", "DELETED", "MASKED"]] = None,
30
+ created_on: Optional[datetime] = None,
31
+ ) -> None:
32
+ """
33
+ Initializes a LineageNode instance.
34
+
35
+ Args:
36
+ session : The Snowflake session object.
37
+ name : Fully qualified name of the lineage node, which is in the format '<db>.<schema>.<object_name>'.
38
+ domain : The domain of the lineage node.
39
+ version : The version of the lineage node, if applies.
40
+ status : The status of the lineage node. Possible values are:
41
+ - 'MASKED': The user does not have the privilege to view the node.
42
+ - 'DELETED': The node has been deleted.
43
+ - 'ACTIVE': The node is currently active.
44
+ created_on : The creation time of the lineage node.
45
+
46
+ Raises:
47
+ ValueError: If the name is not fully qualified.
48
+ """
49
+ if name and not identifier.is_fully_qualified_name(name):
50
+ raise ValueError("name should be fully qualifed.")
51
+
52
+ self._lineage_node_name = name
53
+ self._lineage_node_domain = domain
54
+ self._lineage_node_version = version
55
+ self._lineage_node_status = status
56
+ self._lineage_node_created_on = created_on
57
+ self._session = session
58
+
59
+ def __repr__(self) -> str:
60
+ return (
61
+ f"{self.__class__.__name__}(\n"
62
+ f" name='{self._lineage_node_name}',\n"
63
+ f" version='{self._lineage_node_version}',\n"
64
+ f" domain='{self._lineage_node_domain}',\n"
65
+ f" status='{self._lineage_node_status}',\n"
66
+ f" created_on='{self._lineage_node_created_on}'\n"
67
+ f")"
68
+ )
69
+
70
+ @staticmethod
71
+ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "LineageNode":
72
+ """
73
+ Loads the concrete object.
74
+
75
+ Args:
76
+ session : The Snowflake session object.
77
+ name : Fully qualified name of the object.
78
+ version : The version of object.
79
+
80
+ Raises:
81
+ NotImplementedError: If the derived class does not implement this method.
82
+ """
83
+ raise NotImplementedError()
84
+
85
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
86
+ @snowpark._internal.utils.private_preview(version="1.5.3")
87
+ def lineage(
88
+ self,
89
+ direction: Literal["upstream", "downstream"] = "downstream",
90
+ domain_filter: Optional[Set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
91
+ ) -> List[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
92
+ """
93
+ Retrieves the lineage nodes connected to this node.
94
+
95
+ Args:
96
+ direction : The direction to trace lineage. Defaults to "downstream".
97
+ domain_filter : Set of domains to filter nodes. Defaults to None.
98
+
99
+ Returns:
100
+ List[LineageNode]: A list of connected lineage nodes.
101
+ """
102
+ df = self._session.lineage.trace(
103
+ self._lineage_node_name,
104
+ self._lineage_node_domain.upper(),
105
+ object_version=self._lineage_node_version,
106
+ direction=direction,
107
+ distance=1,
108
+ )
109
+ if domain_filter is not None:
110
+ domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
111
+
112
+ lineage_nodes: List["LineageNode"] = []
113
+ for row in df.collect():
114
+ lineage_object = (
115
+ json.loads(row["TARGET_OBJECT"])
116
+ if direction.lower() == "downstream"
117
+ else json.loads(row["SOURCE_OBJECT"])
118
+ )
119
+ domain = lineage_object["domain"].lower()
120
+ if domain_filter is None or domain in domain_filter:
121
+ if domain in DOMAIN_LINEAGE_REGISTRY and lineage_object["status"] == "ACTIVE":
122
+ lineage_nodes.append(
123
+ DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node(
124
+ self._session, lineage_object["name"], lineage_object.get("version")
125
+ )
126
+ )
127
+ else:
128
+ lineage_nodes.append(
129
+ LineageNode(
130
+ name=lineage_object["name"],
131
+ version=lineage_object.get("version"),
132
+ domain=domain,
133
+ status=lineage_object["status"],
134
+ created_on=datetime.strptime(lineage_object["createdOn"], "%Y-%m-%dT%H:%M:%SZ"),
135
+ session=self._session,
136
+ )
137
+ )
138
+
139
+ return lineage_nodes
@@ -9,6 +9,10 @@ from snowflake.ml.model._client.ops import model_ops
9
9
 
10
10
  _TELEMETRY_PROJECT = "MLOps"
11
11
  _TELEMETRY_SUBPROJECT = "ModelManagement"
12
+ SYSTEM_VERSION_ALIAS_DEFAULT = "DEFAULT"
13
+ SYSTEM_VERSION_ALIAS_FIRST = "FIRST"
14
+ SYSTEM_VERSION_ALIAS_LAST = "LAST"
15
+ SYSTEM_VERSION_ALIASES = (SYSTEM_VERSION_ALIAS_DEFAULT, SYSTEM_VERSION_ALIAS_FIRST, SYSTEM_VERSION_ALIAS_LAST)
12
16
 
13
17
 
14
18
  class Model:
@@ -144,12 +148,28 @@ class Model:
144
148
  project=_TELEMETRY_PROJECT,
145
149
  subproject=_TELEMETRY_SUBPROJECT,
146
150
  )
147
- def version(self, version_name: str) -> model_version_impl.ModelVersion:
151
+ def first(self) -> model_version_impl.ModelVersion:
152
+ """The first version of the model."""
153
+ return self.version(SYSTEM_VERSION_ALIAS_FIRST)
154
+
155
+ @telemetry.send_api_usage_telemetry(
156
+ project=_TELEMETRY_PROJECT,
157
+ subproject=_TELEMETRY_SUBPROJECT,
158
+ )
159
+ def last(self) -> model_version_impl.ModelVersion:
160
+ """The latest version of the model."""
161
+ return self.version(SYSTEM_VERSION_ALIAS_LAST)
162
+
163
+ @telemetry.send_api_usage_telemetry(
164
+ project=_TELEMETRY_PROJECT,
165
+ subproject=_TELEMETRY_SUBPROJECT,
166
+ )
167
+ def version(self, version_or_alias: str) -> model_version_impl.ModelVersion:
148
168
  """
149
- Get a model version object given a version name in the model.
169
+ Get a model version object given a version name or version alias in the model.
150
170
 
151
171
  Args:
152
- version_name: The name of the version.
172
+ version_or_alias: The name of the version or alias to a version.
153
173
 
154
174
  Raises:
155
175
  ValueError: When the requested version does not exist.
@@ -161,23 +181,36 @@ class Model:
161
181
  project=_TELEMETRY_PROJECT,
162
182
  subproject=_TELEMETRY_SUBPROJECT,
163
183
  )
164
- version_id = sql_identifier.SqlIdentifier(version_name)
165
- if self._model_ops.validate_existence(
184
+
185
+ # check with system alias or with user defined alias
186
+ version_id = self._model_ops.get_version_by_alias(
166
187
  database_name=None,
167
188
  schema_name=None,
168
189
  model_name=self._model_name,
169
- version_name=version_id,
190
+ alias_name=sql_identifier.SqlIdentifier(version_or_alias),
170
191
  statement_params=statement_params,
171
- ):
172
- return model_version_impl.ModelVersion._ref(
173
- self._model_ops,
192
+ )
193
+
194
+ # version_id is still None implies version_or_alias is not an alias. So it must be a version name.
195
+ if version_id is None:
196
+ version_id = sql_identifier.SqlIdentifier(version_or_alias)
197
+ if not self._model_ops.validate_existence(
198
+ database_name=None,
199
+ schema_name=None,
174
200
  model_name=self._model_name,
175
201
  version_name=version_id,
176
- )
177
- else:
178
- raise ValueError(
179
- f"Unable to find version with name {version_id.identifier()} in model {self.fully_qualified_name}"
180
- )
202
+ statement_params=statement_params,
203
+ ):
204
+ raise ValueError(
205
+ f"Unable to find version or alias with name {version_id.identifier()} "
206
+ f"in model {self.fully_qualified_name}"
207
+ )
208
+
209
+ return model_version_impl.ModelVersion._ref(
210
+ self._model_ops,
211
+ model_name=self._model_name,
212
+ version_name=version_id,
213
+ )
181
214
 
182
215
  @telemetry.send_api_usage_telemetry(
183
216
  project=_TELEMETRY_PROJECT,
@@ -8,12 +8,13 @@ import pandas as pd
8
8
 
9
9
  from snowflake.ml._internal import telemetry
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
+ from snowflake.ml.lineage import lineage_node
11
12
  from snowflake.ml.model import type_hints as model_types
12
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops
13
14
  from snowflake.ml.model._model_composer import model_composer
14
15
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
15
16
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
16
- from snowflake.snowpark import dataframe
17
+ from snowflake.snowpark import Session, dataframe
17
18
 
18
19
  _TELEMETRY_PROJECT = "MLOps"
19
20
  _TELEMETRY_SUBPROJECT = "ModelManagement"
@@ -24,7 +25,7 @@ class ExportMode(enum.Enum):
24
25
  FULL = "full"
25
26
 
26
27
 
27
- class ModelVersion:
28
+ class ModelVersion(lineage_node.LineageNode):
28
29
  """Model Version Object representing a specific version of the model that could be run."""
29
30
 
30
31
  _model_ops: model_ops.ModelOperator
@@ -48,6 +49,15 @@ class ModelVersion:
48
49
  self._model_name = model_name
49
50
  self._version_name = version_name
50
51
  self._functions = self._get_functions()
52
+ super(cls, cls).__init__(
53
+ self,
54
+ session=model_ops._session,
55
+ name=model_ops._model_client.fully_qualified_object_name(
56
+ database_name=None, schema_name=None, object_name=model_name
57
+ ),
58
+ domain="model",
59
+ version=version_name,
60
+ )
51
61
  return self
52
62
 
53
63
  def __eq__(self, __value: object) -> bool:
@@ -59,6 +69,11 @@ class ModelVersion:
59
69
  and self._version_name == __value._version_name
60
70
  )
61
71
 
72
+ def __repr__(self) -> str:
73
+ return (
74
+ f"{self.__class__.__name__}(\n" f" name='{self.model_name}',\n" f" version='{self._version_name}',\n" f")"
75
+ )
76
+
62
77
  @property
63
78
  def model_name(self) -> str:
64
79
  """Return the name of the model to which the model version belongs, usable as a reference in SQL."""
@@ -198,6 +213,52 @@ class ModelVersion:
198
213
  statement_params=statement_params,
199
214
  )
200
215
 
216
+ @telemetry.send_api_usage_telemetry(
217
+ project=_TELEMETRY_PROJECT,
218
+ subproject=_TELEMETRY_SUBPROJECT,
219
+ )
220
+ def set_alias(self, alias_name: str) -> None:
221
+ """Set alias to a model version.
222
+
223
+ Args:
224
+ alias_name: Alias to the model version.
225
+ """
226
+ statement_params = telemetry.get_statement_params(
227
+ project=_TELEMETRY_PROJECT,
228
+ subproject=_TELEMETRY_SUBPROJECT,
229
+ )
230
+ alias_name = sql_identifier.SqlIdentifier(alias_name)
231
+ self._model_ops.set_alias(
232
+ alias_name=alias_name,
233
+ database_name=None,
234
+ schema_name=None,
235
+ model_name=self._model_name,
236
+ version_name=self._version_name,
237
+ statement_params=statement_params,
238
+ )
239
+
240
+ @telemetry.send_api_usage_telemetry(
241
+ project=_TELEMETRY_PROJECT,
242
+ subproject=_TELEMETRY_SUBPROJECT,
243
+ )
244
+ def unset_alias(self, version_or_alias: str) -> None:
245
+ """unset alias to a model version.
246
+
247
+ Args:
248
+ version_or_alias: The name of the version or alias to a version.
249
+ """
250
+ statement_params = telemetry.get_statement_params(
251
+ project=_TELEMETRY_PROJECT,
252
+ subproject=_TELEMETRY_SUBPROJECT,
253
+ )
254
+ self._model_ops.unset_alias(
255
+ version_or_alias_name=sql_identifier.SqlIdentifier(version_or_alias),
256
+ database_name=None,
257
+ schema_name=None,
258
+ model_name=self._model_name,
259
+ statement_params=statement_params,
260
+ )
261
+
201
262
  @telemetry.send_api_usage_telemetry(
202
263
  project=_TELEMETRY_PROJECT,
203
264
  subproject=_TELEMETRY_SUBPROJECT,
@@ -451,3 +512,22 @@ class ModelVersion:
451
512
  f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}"
452
513
  )
453
514
  return pk.model
515
+
516
+ @staticmethod
517
+ def _load_from_lineage_node(session: Session, name: str, version: str) -> "ModelVersion":
518
+ database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(name)
519
+ if not database_name_id or not schema_name_id:
520
+ raise ValueError("name should be fully qualifed.")
521
+
522
+ return ModelVersion._ref(
523
+ model_ops.ModelOperator(
524
+ session,
525
+ database_name=database_name_id,
526
+ schema_name=schema_name_id,
527
+ ),
528
+ model_name=model_name_id,
529
+ version_name=sql_identifier.SqlIdentifier(version),
530
+ )
531
+
532
+
533
+ lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -1,11 +1,12 @@
1
1
  import os
2
2
  import pathlib
3
3
  import tempfile
4
+ import warnings
4
5
  from typing import Any, Dict, List, Literal, Optional, Union, cast
5
6
 
6
7
  import yaml
7
8
 
8
- from snowflake.ml._internal.utils import identifier, sql_identifier
9
+ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
9
10
  from snowflake.ml.model import model_signature, type_hints
10
11
  from snowflake.ml.model._client.ops import metadata_ops
11
12
  from snowflake.ml.model._client.sql import (
@@ -311,6 +312,42 @@ class ModelOperator:
311
312
  statement_params=statement_params,
312
313
  )
313
314
 
315
+ def set_alias(
316
+ self,
317
+ *,
318
+ alias_name: sql_identifier.SqlIdentifier,
319
+ database_name: Optional[sql_identifier.SqlIdentifier],
320
+ schema_name: Optional[sql_identifier.SqlIdentifier],
321
+ model_name: sql_identifier.SqlIdentifier,
322
+ version_name: sql_identifier.SqlIdentifier,
323
+ statement_params: Optional[Dict[str, Any]] = None,
324
+ ) -> None:
325
+ self._model_version_client.set_alias(
326
+ alias_name=alias_name,
327
+ database_name=database_name,
328
+ schema_name=schema_name,
329
+ model_name=model_name,
330
+ version_name=version_name,
331
+ statement_params=statement_params,
332
+ )
333
+
334
+ def unset_alias(
335
+ self,
336
+ *,
337
+ version_or_alias_name: sql_identifier.SqlIdentifier,
338
+ database_name: Optional[sql_identifier.SqlIdentifier],
339
+ schema_name: Optional[sql_identifier.SqlIdentifier],
340
+ model_name: sql_identifier.SqlIdentifier,
341
+ statement_params: Optional[Dict[str, Any]] = None,
342
+ ) -> None:
343
+ self._model_version_client.unset_alias(
344
+ database_name=database_name,
345
+ schema_name=schema_name,
346
+ model_name=model_name,
347
+ version_or_alias_name=version_or_alias_name,
348
+ statement_params=statement_params,
349
+ )
350
+
314
351
  def set_default_version(
315
352
  self,
316
353
  *,
@@ -354,6 +391,28 @@ class ModelOperator:
354
391
  res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
355
392
  )
356
393
 
394
+ def get_version_by_alias(
395
+ self,
396
+ *,
397
+ database_name: Optional[sql_identifier.SqlIdentifier],
398
+ schema_name: Optional[sql_identifier.SqlIdentifier],
399
+ model_name: sql_identifier.SqlIdentifier,
400
+ alias_name: sql_identifier.SqlIdentifier,
401
+ statement_params: Optional[Dict[str, Any]] = None,
402
+ ) -> Optional[sql_identifier.SqlIdentifier]:
403
+ res = self._model_client.show_versions(
404
+ database_name=database_name,
405
+ schema_name=schema_name,
406
+ model_name=model_name,
407
+ statement_params=statement_params,
408
+ )
409
+ for r in res:
410
+ if alias_name in r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]:
411
+ return sql_identifier.SqlIdentifier(
412
+ r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
413
+ )
414
+ return None
415
+
357
416
  def get_tag_value(
358
417
  self,
359
418
  *,
@@ -625,10 +684,23 @@ class ModelOperator:
625
684
  )
626
685
 
627
686
  if keep_order:
628
- df_res = df_res.sort(
629
- "_ID",
630
- ascending=True,
631
- )
687
+ # if it's a partitioned table function, _ID will be null and we won't be able to sort.
688
+ if df_res.select("_ID").limit(1).collect()[0][0] is None:
689
+ warnings.warn(
690
+ formatting.unwrap(
691
+ """
692
+ When invoking partitioned inference methods, ordering of rows in output dataframe will differ
693
+ from that of input dataframe.
694
+ """
695
+ ),
696
+ category=UserWarning,
697
+ stacklevel=1,
698
+ )
699
+ else:
700
+ df_res = df_res.sort(
701
+ "_ID",
702
+ ascending=True,
703
+ )
632
704
 
633
705
  if not output_with_input_features:
634
706
  cols_to_drop = original_cols
@@ -14,6 +14,7 @@ class ModelSQLClient(_base._BaseSQLClient):
14
14
  MODEL_VERSION_COMMENT_COL_NAME = "comment"
15
15
  MODEL_VERSION_METADATA_COL_NAME = "metadata"
16
16
  MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
17
+ MODEL_VERSION_ALIASES_COL_NAME = "aliases"
17
18
 
18
19
  def show_models(
19
20
  self,
@@ -134,6 +134,43 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
134
134
  statement_params=statement_params,
135
135
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
136
136
 
137
+ def set_alias(
138
+ self,
139
+ *,
140
+ database_name: Optional[sql_identifier.SqlIdentifier],
141
+ schema_name: Optional[sql_identifier.SqlIdentifier],
142
+ model_name: sql_identifier.SqlIdentifier,
143
+ version_name: sql_identifier.SqlIdentifier,
144
+ alias_name: sql_identifier.SqlIdentifier,
145
+ statement_params: Optional[Dict[str, Any]] = None,
146
+ ) -> None:
147
+ query_result_checker.SqlResultValidator(
148
+ self._session,
149
+ (
150
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
151
+ f"VERSION {version_name.identifier()} SET ALIAS = {alias_name.identifier()}"
152
+ ),
153
+ statement_params=statement_params,
154
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
155
+
156
+ def unset_alias(
157
+ self,
158
+ *,
159
+ database_name: Optional[sql_identifier.SqlIdentifier],
160
+ schema_name: Optional[sql_identifier.SqlIdentifier],
161
+ model_name: sql_identifier.SqlIdentifier,
162
+ version_or_alias_name: sql_identifier.SqlIdentifier,
163
+ statement_params: Optional[Dict[str, Any]] = None,
164
+ ) -> None:
165
+ query_result_checker.SqlResultValidator(
166
+ self._session,
167
+ (
168
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
169
+ f"VERSION {version_or_alias_name.identifier()} UNSET ALIAS"
170
+ ),
171
+ statement_params=statement_params,
172
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
173
+
137
174
  def list_file(
138
175
  self,
139
176
  *,
@@ -383,9 +420,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
383
420
  # Prepare the output
384
421
  output_cols = []
385
422
  output_names = []
423
+ cols_to_drop = []
386
424
 
387
425
  for output_name, output_type, output_col_name in returns:
388
- output_cols.append(F.col(output_name).astype(output_type))
426
+ output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
427
+ if output_identifier != output_col_name:
428
+ cols_to_drop.append(output_identifier)
429
+ output_cols.append(F.col(output_identifier).astype(output_type))
389
430
  output_names.append(output_col_name)
390
431
 
391
432
  if partition_column is not None:
@@ -396,10 +437,12 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
396
437
  col_names=output_names,
397
438
  values=output_cols,
398
439
  )
399
-
400
440
  if statement_params:
401
441
  output_df._statement_params = statement_params # type: ignore[assignment]
402
442
 
443
+ if cols_to_drop:
444
+ output_df = output_df.drop(cols_to_drop)
445
+
403
446
  return output_df
404
447
 
405
448
  def set_metadata(
@@ -85,9 +85,8 @@ def _run_setup() -> None:
85
85
 
86
86
  TARGET_METHOD = os.getenv("TARGET_METHOD")
87
87
 
88
- _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", None)
89
-
90
- _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) if _concurrent_requests_max_env else None
88
+ _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1")
89
+ _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env)
91
90
 
92
91
  with tempfile.TemporaryDirectory() as tmp_dir:
93
92
  if zipfile.is_zipfile(model_zip_stage_path):
@@ -101,7 +100,6 @@ def _run_setup() -> None:
101
100
  logger.info(f"Loading model from {extracted_dir} into memory")
102
101
 
103
102
  sys.path.insert(0, os.path.join(extracted_dir, _MODEL_CODE_DIR))
104
- from snowflake.ml.model import type_hints as model_types
105
103
 
106
104
  # TODO (Server-side Model Rollout):
107
105
  # Keep try block only
@@ -114,7 +112,7 @@ def _run_setup() -> None:
114
112
  pk.load(
115
113
  as_custom_model=True,
116
114
  meta_only=False,
117
- options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
115
+ options={"use_gpu": use_gpu},
118
116
  )
119
117
  _LOADED_MODEL = pk.model
120
118
  _LOADED_META = pk.meta
@@ -132,7 +130,7 @@ def _run_setup() -> None:
132
130
  _LOADED_MODEL, meta_LOADED_META = model_api._load(
133
131
  local_dir_path=extracted_dir,
134
132
  as_custom_model=True,
135
- options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
133
+ options={"use_gpu": use_gpu},
136
134
  )
137
135
  _MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED
138
136
  logger.info("Successfully loaded model into memory")
@@ -11,10 +11,12 @@ from packaging import requirements
11
11
  from typing_extensions import deprecated
12
12
 
13
13
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
- from snowflake.ml._internal.lineage import data_source, lineage_utils
14
+ from snowflake.ml._internal.lineage import lineage_utils
15
+ from snowflake.ml.data import data_source
15
16
  from snowflake.ml.model import model_signature, type_hints as model_types
16
17
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
17
18
  from snowflake.ml.model._packager import model_packager
19
+ from snowflake.ml.model._packager.model_meta import model_meta
18
20
  from snowflake.snowpark import Session
19
21
  from snowflake.snowpark._internal import utils as snowpark_utils
20
22
 
@@ -90,7 +92,7 @@ class ModelComposer:
90
92
  ext_modules: Optional[List[ModuleType]] = None,
91
93
  code_paths: Optional[List[str]] = None,
92
94
  options: Optional[model_types.ModelSaveOption] = None,
93
- ) -> None:
95
+ ) -> model_meta.ModelMetadata:
94
96
  if not options:
95
97
  options = model_types.BaseModelSaveOption()
96
98
 
@@ -106,7 +108,7 @@ class ModelComposer:
106
108
  )
107
109
  options["embed_local_ml_library"] = True
108
110
 
109
- self.packager.save(
111
+ model_metadata: model_meta.ModelMetadata = self.packager.save(
110
112
  name=name,
111
113
  model=model,
112
114
  signatures=signatures,
@@ -119,7 +121,6 @@ class ModelComposer:
119
121
  code_paths=code_paths,
120
122
  options=options,
121
123
  )
122
-
123
124
  assert self.packager.meta is not None
124
125
 
125
126
  if not options.get("_legacy_save", False):
@@ -128,16 +129,14 @@ class ModelComposer:
128
129
  file_utils.copytree(
129
130
  str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
130
131
  )
131
-
132
- file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
133
-
134
- self.manifest.save(
135
- session=self.session,
136
- model_meta=self.packager.meta,
137
- model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
138
- options=options,
139
- data_sources=self._get_data_sources(model, sample_input_data),
140
- )
132
+ self.manifest.save(
133
+ model_meta=self.packager.meta,
134
+ model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
135
+ options=options,
136
+ data_sources=self._get_data_sources(model, sample_input_data),
137
+ )
138
+ else:
139
+ file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
141
140
 
142
141
  file_utils.upload_directory_to_stage(
143
142
  self.session,
@@ -145,6 +144,7 @@ class ModelComposer:
145
144
  stage_path=self.stage_path,
146
145
  statement_params=self._statement_params,
147
146
  )
147
+ return model_metadata
148
148
 
149
149
  @deprecated("Only used by PrPr model registry. Use static method version of load instead.")
150
150
  def legacy_load(
@@ -185,6 +185,4 @@ class ModelComposer:
185
185
  data_sources = lineage_utils.get_data_sources(model)
186
186
  if not data_sources and sample_input_data is not None:
187
187
  data_sources = lineage_utils.get_data_sources(sample_input_data)
188
- if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
189
- return data_sources
190
- return None
188
+ return data_sources