snowflake-ml-python 1.4.1-py3-none-any.whl → 1.5.0-py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0

snowflake/ml/model/_client/model/model_version_impl.py

@@ -1,17 +1,29 @@
+import enum
+import pathlib
+import tempfile
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 import pandas as pd

 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model import type_hints as model_types
 from snowflake.ml.model._client.ops import metadata_ops, model_ops
+from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+from snowflake.ml.model._packager.model_handlers import snowmlmodel
 from snowflake.snowpark import dataframe

 _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"


+class ExportMode(enum.Enum):
+    MODEL = "model"
+    FULL = "full"
+
+
 class ModelVersion:
     """Model Version Object representing a specific version of the model that could be run."""

@@ -240,6 +252,7 @@ class ModelVersion:
         X: Union[pd.DataFrame, dataframe.DataFrame],
         *,
         function_name: Optional[str] = None,
+        partition_column: Optional[str] = None,
         strict_input_validation: bool = False,
     ) -> Union[pd.DataFrame, dataframe.DataFrame]:
         """Invoke a method in a model version object.
@@ -248,12 +261,14 @@
             X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
             function_name: The function name to run. It is the name used to call a function in SQL.
                 Defaults to None. It can only be None if there is only 1 method.
+            partition_column: The partition column name to partition by.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.

         Raises:
             ValueError: When no method with the corresponding name is available.
             ValueError: When there are more than 1 target methods available in the model but no function name specified.
+            ValueError: When the partition column is not a valid Snowflake identifier.

         Returns:
             The prediction data. It would be the same type dataframe as your input.
@@ -263,6 +278,10 @@
             subproject=_TELEMETRY_SUBPROJECT,
         )

+        if partition_column is not None:
+            # Partition column must be a valid identifier
+            partition_column = sql_identifier.SqlIdentifier(partition_column)
+
         functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
         if function_name:
             req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
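The new partition_column argument threads from ModelVersion.run() down to the table-function invocation shown later in this diff. A minimal usage sketch, assuming reg is an existing snowflake.ml.registry.Registry bound to an active session, and that MY_MODEL, V1, PREDICT and REGION are placeholder names:

    # Hypothetical call: partition the input by REGION and keep that column in the output.
    mv = reg.get_model("MY_MODEL").version("V1")
    result_df = mv.run(
        input_df,                    # pandas or Snowpark DataFrame
        function_name="PREDICT",
        partition_column="REGION",   # must be a valid Snowflake identifier
    )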
@@ -287,10 +306,126 @@
             target_function_info = functions[0]
         return self._model_ops.invoke_method(
             method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
+            method_function_type=target_function_info["target_method_function_type"],
             signature=target_function_info["signature"],
             X=X,
             model_name=self._model_name,
             version_name=self._version_name,
             strict_input_validation=strict_input_validation,
+            partition_column=partition_column,
+            statement_params=statement_params,
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
+    )
+    def export(self, target_path: str, *, export_mode: ExportMode = ExportMode.MODEL) -> None:
+        """Export model files to a local directory.
+
+        Args:
+            target_path: Path to a local directory to export files to. A directory will be created if does not exist.
+            export_mode: The mode to export the model. Defaults to ExportMode.MODEL.
+                ExportMode.MODEL: All model files including environment to load the model and model weights.
+                ExportMode.FULL: Additional files to run the model in Warehouse, besides all files in MODEL mode,
+
+        Raises:
+            ValueError: Raised when the target path is a file or an non-empty folder.
+        """
+        target_local_path = pathlib.Path(target_path)
+        if target_local_path.is_file() or any(target_local_path.iterdir()):
+            raise ValueError(f"Target path {target_local_path} is a file or an non-empty folder.")
+
+        target_local_path.mkdir(parents=False, exist_ok=True)
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        self._model_ops.download_files(
+            model_name=self._model_name,
+            version_name=self._version_name,
+            target_path=target_local_path,
+            mode=export_mode.value,
+            statement_params=statement_params,
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["force", "options"]
+    )
+    def load(
+        self,
+        *,
+        force: bool = False,
+        options: Optional[model_types.ModelLoadOption] = None,
+    ) -> model_types.SupportedModelType:
+        """Load the underlying original Python object back from a model.
+        This operation requires to have the exact the same environment as the one when logging the model, otherwise,
+        the model might be not functional or some other problems might occur.
+
+        Args:
+            force: Bypass the best-effort environment validation. Defaults to False.
+            options: Options to specify when loading the model, check `snowflake.ml.model.type_hints` for available
+                options. Defaults to None.
+
+        Raises:
+            ValueError: Raised when the best-effort environment validation fails.
+
+        Returns:
+            The original Python object loaded from the model object.
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        if not force:
+            with tempfile.TemporaryDirectory() as tmp_workspace_for_validation:
+                ws_path_for_validation = pathlib.Path(tmp_workspace_for_validation)
+                self._model_ops.download_files(
+                    model_name=self._model_name,
+                    version_name=self._version_name,
+                    target_path=ws_path_for_validation,
+                    mode="minimal",
+                    statement_params=statement_params,
+                )
+                pk_for_validation = model_composer.ModelComposer.load(
+                    ws_path_for_validation, meta_only=True, options=options
+                )
+                assert pk_for_validation.meta, (
+                    "Unable to load model metadata for validation. "
+                    f"model_name={self._model_name}, version_name={self._version_name}"
+                )
+
+                validation_errors = pk_for_validation.meta.env.validate_with_local_env(
+                    check_snowpark_ml_version=(
+                        pk_for_validation.meta.model_type == snowmlmodel.SnowMLModelHandler.HANDLER_TYPE
+                    )
+                )
+                if validation_errors:
+                    raise ValueError(
+                        f"Unable to load this model due to following validation errors: {validation_errors}. "
+                        "Make sure your local environment is the same as that when you logged the model, "
+                        "or if you believe it should work, specify `force=True` to bypass this check."
+                    )
+
+        warnings.warn(
+            "Loading model requires to have the exact the same environment as the one when "
+            "logging the model, otherwise, the model might be not functional or "
+            "some other problems might occur.",
+            category=RuntimeWarning,
+            stacklevel=2,
+        )
+
+        # We need the folder to be existed.
+        workspace = pathlib.Path(tempfile.mkdtemp())
+        self._model_ops.download_files(
+            model_name=self._model_name,
+            version_name=self._version_name,
+            target_path=workspace,
+            mode="model",
             statement_params=statement_params,
         )
+        pk = model_composer.ModelComposer.load(workspace, meta_only=False, options=options)
+        assert pk.model, (
+            "Unable to load model. "
+            f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}"
+        )
+        return pk.model
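The export() and load() methods added above can be exercised roughly as follows; a sketch assuming mv is a ModelVersion obtained from the registry and /tmp/exported_model is an empty placeholder directory:

    from snowflake.ml.model._client.model.model_version_impl import ExportMode

    # Download the packaged model files locally; ExportMode.FULL would also pull the
    # extra files needed to run the model in a Warehouse.
    mv.export("/tmp/exported_model", export_mode=ExportMode.MODEL)

    # Recover the original Python object. Without force=True this first downloads the
    # "minimal" file set and validates the logged environment against the local one,
    # raising ValueError on mismatch.
    model_obj = mv.load()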

snowflake/ml/model/_client/ops/model_ops.py

@@ -1,7 +1,7 @@
+import os
 import pathlib
 import tempfile
-from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Union, cast

 import yaml

@@ -19,7 +19,9 @@ from snowflake.ml.model._model_composer.model_manifest import (
     model_manifest,
     model_manifest_schema,
 )
+from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_meta import model_meta
+from snowflake.ml.model._packager.model_runtime import model_runtime
 from snowflake.ml.model._signatures import snowpark_handler
 from snowflake.snowpark import dataframe, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -337,16 +339,6 @@ class ModelOperator:
         mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
         return mm.load()

-    @contextmanager
-    def _enable_model_details(
-        self,
-        *,
-        statement_params: Optional[Dict[str, Any]] = None,
-    ) -> Generator[None, None, None]:
-        self._model_client.config_model_details(enable=True, statement_params=statement_params)
-        yield
-        self._model_client.config_model_details(enable=False, statement_params=statement_params)
-
     @staticmethod
     def _match_model_spec_with_sql_functions(
         sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
@@ -374,64 +366,63 @@ class ModelOperator:
         version_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> List[model_manifest_schema.ModelFunctionInfo]:
-        with self._enable_model_details(statement_params=statement_params):
-            raw_model_spec_res = self._model_client.show_versions(
-                model_name=model_name,
-                version_name=version_name,
-                check_model_details=True,
-                statement_params=statement_params,
-            )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
-            model_spec_dict = yaml.safe_load(raw_model_spec_res)
-            model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
-            show_functions_res = self._model_version_client.show_functions(
-                model_name=model_name,
-                version_name=version_name,
-                statement_params=statement_params,
+        raw_model_spec_res = self._model_client.show_versions(
+            model_name=model_name,
+            version_name=version_name,
+            check_model_details=True,
+            statement_params={**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True},
+        )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
+        model_spec_dict = yaml.safe_load(raw_model_spec_res)
+        model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
+        show_functions_res = self._model_version_client.show_functions(
+            model_name=model_name,
+            version_name=version_name,
+            statement_params=statement_params,
+        )
+        function_names_and_types = []
+        for r in show_functions_res:
+            function_name = sql_identifier.SqlIdentifier(
+                r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
             )
-            function_names_and_types = []
-            for r in show_functions_res:
-                function_name = sql_identifier.SqlIdentifier(
-                    r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
-                )

-                function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
-                try:
-                    return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
-                except KeyError:
-                    pass
-                else:
-                    if "TABLE" in return_type:
-                        function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
-
-                function_names_and_types.append((function_name, function_type))
-
-            signatures = model_spec["signatures"]
-            function_names = [name for name, _ in function_names_and_types]
-            function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
-                function_names, list(signatures.keys())
-            )
+            function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
+            try:
+                return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
+            except KeyError:
+                pass
+            else:
+                if "TABLE" in return_type:
+                    function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value

-            return [
-                model_manifest_schema.ModelFunctionInfo(
-                    name=function_name.identifier(),
-                    target_method=function_name_mapping[function_name],
-                    target_method_function_type=function_type,
-                    signature=model_signature.ModelSignature.from_dict(
-                        signatures[function_name_mapping[function_name]]
-                    ),
-                )
-                for function_name, function_type in function_names_and_types
-            ]
+            function_names_and_types.append((function_name, function_type))
+
+        signatures = model_spec["signatures"]
+        function_names = [name for name, _ in function_names_and_types]
+        function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
+            function_names, list(signatures.keys())
+        )
+
+        return [
+            model_manifest_schema.ModelFunctionInfo(
+                name=function_name.identifier(),
+                target_method=function_name_mapping[function_name],
+                target_method_function_type=function_type,
+                signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
+            )
+            for function_name, function_type in function_names_and_types
+        ]

     def invoke_method(
         self,
         *,
         method_name: sql_identifier.SqlIdentifier,
+        method_function_type: str,
         signature: model_signature.ModelSignature,
         X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
         strict_input_validation: bool = False,
+        partition_column: Optional[sql_identifier.SqlIdentifier] = None,
         statement_params: Optional[Dict[str, str]] = None,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
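One consequence of the rewrite above: the SHOW_MODEL_DETAILS flag now rides along as a statement parameter on the single SHOW VERSIONS query instead of being toggled with ALTER SESSION SET/UNSET, so no session state is mutated. A small self-contained sketch of the merge idiom used (QUERY_TAG is just a hypothetical caller-supplied parameter):

    # The per-statement flag is merged into whatever statement_params the caller passed,
    # without mutating the caller's dict.
    statement_params = {"QUERY_TAG": "my_app"}
    merged = {**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True}
    print(merged)
    # {'QUERY_TAG': 'my_app', 'SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL': True}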
@@ -469,15 +460,27 @@
             if output_name in original_cols:
                 original_cols.remove(output_name)

-        df_res = self._model_version_client.invoke_method(
-            method_name=method_name,
-            input_df=s_df,
-            input_args=input_args,
-            returns=returns,
-            model_name=model_name,
-            version_name=version_name,
-            statement_params=statement_params,
-        )
+        if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
+            df_res = self._model_version_client.invoke_function_method(
+                method_name=method_name,
+                input_df=s_df,
+                input_args=input_args,
+                returns=returns,
+                model_name=model_name,
+                version_name=version_name,
+                statement_params=statement_params,
+            )
+        elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
+            df_res = self._model_version_client.invoke_table_function_method(
+                method_name=method_name,
+                input_df=s_df,
+                input_args=input_args,
+                partition_column=partition_column,
+                returns=returns,
+                model_name=model_name,
+                version_name=version_name,
+                statement_params=statement_params,
+            )

         if keep_order:
             df_res = df_res.sort(
@@ -486,7 +489,11 @@
             )

         if not output_with_input_features:
-            df_res = df_res.drop(*original_cols)
+            cols_to_drop = original_cols
+            if partition_column is not None:
+                # don't drop partition column
+                cols_to_drop.remove(partition_column.identifier())
+            df_res = df_res.drop(*cols_to_drop)

         # Get final result
         if not isinstance(X, dataframe.DataFrame):
@@ -512,3 +519,66 @@
             model_name=model_name,
             statement_params=statement_params,
         )
+
+    def rename(
+        self,
+        *,
+        model_name: sql_identifier.SqlIdentifier,
+        new_model_db: Optional[sql_identifier.SqlIdentifier],
+        new_model_schema: Optional[sql_identifier.SqlIdentifier],
+        new_model_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self._model_client.rename(
+            model_name=model_name,
+            new_model_db=new_model_db,
+            new_model_schema=new_model_schema,
+            new_model_name=new_model_name,
+            statement_params=statement_params,
+        )
+
+    # Map indicating in different modes, the path to list and download.
+    # The boolean value indicates if it is a directory,
+    MODEL_FILE_DOWNLOAD_PATTERN = {
+        "minimal": {
+            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
+            / model_meta.MODEL_METADATA_FILE: False,
+            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_env._DEFAULT_ENV_DIR: True,
+            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
+            / model_runtime.ModelRuntime.RUNTIME_DIR_REL_PATH: True,
+        },
+        "model": {pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH): True},
+        "full": {pathlib.PurePosixPath(os.curdir): True},
+    }
+
+    def download_files(
+        self,
+        *,
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        target_path: pathlib.Path,
+        mode: Literal["full", "model", "minimal"] = "model",
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
+            list_file_res = self._model_version_client.list_file(
+                model_name=model_name,
+                version_name=version_name,
+                file_path=remote_rel_path,
+                is_dir=is_dir,
+                statement_params=statement_params,
+            )
+            file_list = [
+                pathlib.PurePosixPath(*pathlib.PurePosixPath(row.name).parts[2:])  # versions/<version_name>/...
+                for row in list_file_res
+            ]
+            for stage_file_path in file_list:
+                local_file_dir = target_path / stage_file_path.parent
+                local_file_dir.mkdir(parents=True, exist_ok=True)
+                self._model_version_client.get_file(
+                    model_name=model_name,
+                    version_name=version_name,
+                    file_path=stage_file_path,
+                    target_path=local_file_dir,
+                    statement_params=statement_params,
+                )
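As a rough illustration of how the new download plan drives export() and load(), a sketch assuming ops is a constructed ModelOperator; the concrete paths come from the package constants referenced in the dictionary above:

    # "minimal": metadata plus environment/runtime definitions, used by load() to
    #            validate the local environment before downloading the full model.
    # "model":   the whole packaged model directory (ExportMode.MODEL and load()).
    # "full":    everything under the version root (ExportMode.FULL).
    for mode in ("minimal", "model", "full"):
        for rel_path, is_dir in ops.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
            print(f"{mode}: {rel_path} ({'directory' if is_dir else 'file'})")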

snowflake/ml/model/_client/sql/model.py

@@ -121,21 +121,23 @@ class ModelSQLClient:
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

-    def config_model_details(
+    def rename(
         self,
         *,
-        enable: bool,
+        model_name: sql_identifier.SqlIdentifier,
+        new_model_db: Optional[sql_identifier.SqlIdentifier],
+        new_model_schema: Optional[sql_identifier.SqlIdentifier],
+        new_model_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
     ) -> None:
-        if enable:
-            query_result_checker.SqlResultValidator(
-                self._session,
-                "ALTER SESSION SET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL=true",
-                statement_params=statement_params,
-            ).has_dimensions(expected_rows=1, expected_cols=1).validate()
-        else:
-            query_result_checker.SqlResultValidator(
-                self._session,
-                "ALTER SESSION UNSET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL",
-                statement_params=statement_params,
-            ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+        # Use registry's database and schema if a non fully qualified new model name is provided.
+        new_fully_qualified_name = identifier.get_schema_level_object_identifier(
+            new_model_db.identifier() if new_model_db else self._database_name.identifier(),
+            new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
+            new_model_name.identifier(),
+        )
+        query_result_checker.SqlResultValidator(
+            self._session,
+            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} RENAME TO {new_fully_qualified_name}",
+            statement_params=statement_params,
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
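For reference, a sketch of what the new ModelSQLClient.rename() executes; client is assumed to be a ModelSQLClient, the identifiers below are placeholders, and the database and schema fall back to the client's own when the corresponding arguments are None:

    # Roughly the statement issued for fully qualified inputs:
    #   ALTER MODEL MY_DB.MY_SCHEMA.MY_MODEL RENAME TO OTHER_DB.OTHER_SCHEMA.NEW_MODEL
    client.rename(
        model_name=sql_identifier.SqlIdentifier("MY_MODEL"),
        new_model_db=sql_identifier.SqlIdentifier("OTHER_DB"),
        new_model_schema=sql_identifier.SqlIdentifier("OTHER_SCHEMA"),
        new_model_name=sql_identifier.SqlIdentifier("NEW_MODEL"),
    )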

snowflake/ml/model/_client/sql/model_version.py

@@ -96,6 +96,38 @@ class ModelVersionSQLClient:
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

+    def list_file(
+        self,
+        *,
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        file_path: pathlib.PurePosixPath,
+        is_dir: bool = False,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> List[row.Row]:
+        # Workaround for snowURL bug.
+        trailing_slash = "/" if is_dir else ""
+
+        stage_location = (
+            pathlib.PurePosixPath(
+                self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
+            ).as_posix()
+            + trailing_slash
+        )
+        stage_location_url = ParseResult(
+            scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
+        ).geturl()
+
+        return (
+            query_result_checker.SqlResultValidator(
+                self._session,
+                f"List {_normalize_url_for_sql(stage_location_url)}",
+                statement_params=statement_params,
+            )
+            .has_column("name")
+            .validate()
+        )
+
     def get_file(
         self,
         *,
@@ -162,7 +194,7 @@ class ModelVersionSQLClient:
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

-    def invoke_method(
+    def invoke_function_method(
         self,
         *,
         model_name: sql_identifier.SqlIdentifier,
@@ -232,6 +264,82 @@

         return output_df

+    def invoke_table_function_method(
+        self,
+        *,
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        method_name: sql_identifier.SqlIdentifier,
+        input_df: dataframe.DataFrame,
+        input_args: List[sql_identifier.SqlIdentifier],
+        returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
+        partition_column: Optional[sql_identifier.SqlIdentifier],
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> dataframe.DataFrame:
+        with_statements = []
+        if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
+            INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
+            with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
+        else:
+            tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
+            INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
+                self._database_name.identifier(),
+                self._schema_name.identifier(),
+                tmp_table_name,
+            )
+            input_df.write.save_as_table(  # type: ignore[call-overload]
+                table_name=INTERMEDIATE_TABLE_NAME,
+                mode="errorifexists",
+                table_type="temporary",
+                statement_params=statement_params,
+            )
+
+        module_version_alias = "MODEL_VERSION_ALIAS"
+        with_statements.append(
+            f"{module_version_alias} AS "
+            f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
+        )
+
+        partition_by = partition_column.identifier() if partition_column is not None else "1"
+
+        args_sql_list = []
+        for input_arg_value in input_args:
+            args_sql_list.append(input_arg_value)
+
+        args_sql = ", ".join(args_sql_list)
+
+        sql = textwrap.dedent(
+            f"""WITH {','.join(with_statements)}
+                SELECT *,
+                FROM {INTERMEDIATE_TABLE_NAME},
+                    TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
+                        OVER (PARTITION BY {partition_by}))"""
+        )
+
+        output_df = self._session.sql(sql)
+
+        # Prepare the output
+        output_cols = []
+        output_names = []
+
+        for output_name, output_type, output_col_name in returns:
+            output_cols.append(F.col(output_name).astype(output_type))
+            output_names.append(output_col_name)
+
+        if partition_column is not None:
+            output_cols.append(F.col(partition_column.identifier()))
+            output_names.append(partition_column)
+
+        output_df = output_df.with_columns(
+            col_names=output_names,
+            values=output_cols,
+        )
+
+        if statement_params:
+            output_df._statement_params = statement_params  # type: ignore[assignment]
+
+        return output_df
+
     def set_metadata(
         self,
         metadata_dict: Dict[str, Any],
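To make the generated query shape concrete, for a single-query input DataFrame the f-string above renders to roughly the following; the model, version, method and column names are placeholders, and PARTITION BY falls back to the literal 1 when no partition column is given:

    # Approximate SQL emitted by invoke_table_function_method():
    sql_sketch = """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT AS (SELECT ...),
        MODEL_VERSION_ALIAS AS MODEL MY_DB.MY_SCHEMA.MY_MODEL VERSION V1
        SELECT *,
        FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT,
            TABLE(MODEL_VERSION_ALIAS!PREDICT(COL1, COL2)
                OVER (PARTITION BY REGION))"""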

snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py

@@ -37,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
         session: snowpark.Session,
         artifact_stage_location: str,
         compute_pool: str,
+        job_name: str,
         external_access_integrations: List[str],
     ) -> None:
         """Initialization
@@ -49,6 +50,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
             artifact_stage_location: Spec file and future deployment related artifacts will be stored under
                 {stage}/models/{model_id}
             compute_pool: The compute pool used to run docker image build workload.
+            job_name: job_name to use.
             external_access_integrations: EAIs for network connection.
         """
         self.context_dir = context_dir
@@ -58,6 +60,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
         self.artifact_stage_location = artifact_stage_location
         self.compute_pool = compute_pool
         self.external_access_integrations = external_access_integrations
+        self.job_name = job_name
         self.client = snowservice_client.SnowServiceClient(session)

         assert artifact_stage_location.startswith(
@@ -203,8 +206,9 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
         )

     def _launch_kaniko_job(self, spec_stage_location: str) -> None:
-        logger.debug("Submitting job for building docker image with kaniko")
+        logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko")
         self.client.create_job(
+            job_name=self.job_name,
             compute_pool=self.compute_pool,
             spec_stage_location=spec_stage_location,
             external_access_integrations=self.external_access_integrations,

snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template

@@ -30,6 +30,7 @@ USER mambauser

 # Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time.
 ARG MAMBA_DOCKERFILE_ACTIVATE=1
+ARG MAMBA_NO_LOW_SPEED_LIMIT=1

 # Bitsandbytes uses this ENVVAR to determine CUDA library location
 ENV CONDA_PREFIX=/opt/conda