snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215):
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -4,8 +4,13 @@ from typing import Any, Dict, List, Literal, TypedDict
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
7
+ from snowflake.ml.model import model_signature
8
+
7
9
  MODEL_MANIFEST_VERSION = "1.0"
8
10
 
11
+ MANIFEST_CLIENT_DATA_KEY_NAME = "snowpark_ml_data"
12
+ MANIFEST_CLIENT_DATA_SCHEMA_VERSION = "2024-02-01"
13
+
9
14
 
10
15
  class ModelRuntimeDependenciesDict(TypedDict):
11
16
  conda: Required[str]
@@ -38,6 +43,31 @@ class ModelFunctionMethodDict(TypedDict):
38
43
  ModelMethodDict = ModelFunctionMethodDict
39
44
 
40
45
 
46
+ class ModelFunctionInfo(TypedDict):
47
+ """Function information.
48
+
49
+ Attributes:
50
+ name: Name of the function to be called via SQL.
51
+ target_method: actual target method name to be called.
52
+ signature: The signature of the model method.
53
+ """
54
+
55
+ name: Required[str]
56
+ target_method: Required[str]
57
+ signature: Required[model_signature.ModelSignature]
58
+
59
+
60
+ class ModelFunctionInfoDict(TypedDict):
61
+ name: Required[str]
62
+ target_method: Required[str]
63
+ signature: Required[Dict[str, Any]]
64
+
65
+
66
+ class SnowparkMLDataDict(TypedDict):
67
+ schema_version: Required[str]
68
+ functions: Required[List[ModelFunctionInfoDict]]
69
+
70
+
41
71
  class ModelManifestDict(TypedDict):
42
72
  manifest_version: Required[str]
43
73
  runtimes: Required[Dict[str, ModelRuntimeDict]]
@@ -1,7 +1,6 @@
1
1
  import pathlib
2
2
  from typing import Optional, TypedDict
3
3
 
4
- import importlib_resources
5
4
  from typing_extensions import NotRequired
6
5
 
7
6
  from snowflake.ml.model import type_hints
@@ -33,6 +32,8 @@ class FunctionGenerator:
33
32
  target_method: str,
34
33
  options: Optional[FunctionGenerateOptions] = None,
35
34
  ) -> None:
35
+ import importlib_resources
36
+
36
37
  if options is None:
37
38
  options = {}
38
39
  function_template = (
@@ -1 +1,10 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'anyio>=3.5.0,<4', 'numpy>=1.23,<2', 'packaging>=20.9,<24', 'pandas>=1.0.0,<2', 'pyyaml>=6.0,<7', 'snowflake-snowpark-python>=1.8.0,<2', 'typing-extensions>=4.1.0,<5']
1
+ REQUIREMENTS = [
2
+ "absl-py>=0.15,<2",
3
+ "anyio>=3.5.0,<4",
4
+ "numpy>=1.23,<2",
5
+ "packaging>=20.9,<24",
6
+ "pandas>=1.0.0,<2",
7
+ "pyyaml>=6.0,<7",
8
+ "snowflake-snowpark-python>=1.8.0,<2",
9
+ "typing-extensions>=4.1.0,<5"
10
+ ]
@@ -44,12 +44,17 @@ class ModelRuntime:
44
44
  if self.runtime_env._snowpark_ml_version.local:
45
45
  self.embed_local_ml_library = True
46
46
  else:
47
- snowml_server_availability = env_utils.validate_requirements_in_information_schema(
48
- session=session,
49
- reqs=[requirements.Requirement(snowml_pkg_spec)],
50
- python_version=snowml_env.PYTHON_VERSION,
47
+ snowml_server_availability = (
48
+ len(
49
+ env_utils.get_matched_package_versions_in_information_schema(
50
+ session=session,
51
+ reqs=[requirements.Requirement(snowml_pkg_spec)],
52
+ python_version=snowml_env.PYTHON_VERSION,
53
+ ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
54
+ )
55
+ >= 1
51
56
  )
52
- self.embed_local_ml_library = snowml_server_availability is None
57
+ self.embed_local_ml_library = not snowml_server_availability
53
58
 
54
59
  if self.embed_local_ml_library:
55
60
  self.runtime_env.include_if_absent(
@@ -57,7 +62,6 @@ class ModelRuntime:
57
62
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
58
63
  for dep in _UDF_INFERENCE_DEPENDENCIES
59
64
  ],
60
- check_local_version=True,
61
65
  )
62
66
  else:
63
67
  self.runtime_env.include_if_absent(
@@ -65,7 +69,6 @@ class ModelRuntime:
65
69
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
66
70
  for dep in _UDF_INFERENCE_DEPENDENCIES + [snowml_pkg_spec]
67
71
  ],
68
- check_local_version=True,
69
72
  )
70
73
 
71
74
  def save(self, workspace_path: pathlib.Path) -> model_manifest_schema.ModelRuntimeDict:
@@ -59,7 +59,7 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
59
59
  return (
60
60
  [model_env.ModelDependency(requirement="tokenizers>=0.13.3", pip_name="tokenizers")]
61
61
  if spcs_only
62
- else [model_env.ModelDependency(requirement="tokenizers<=0.13.2", pip_name="tokenizers")]
62
+ else [model_env.ModelDependency(requirement="tokenizers", pip_name="tokenizers")]
63
63
  )
64
64
 
65
65
  return []
@@ -1,6 +1,16 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ Optional,
9
+ Type,
10
+ Union,
11
+ cast,
12
+ final,
13
+ )
4
14
 
5
15
  import numpy as np
6
16
  import pandas as pd
@@ -150,6 +160,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
150
160
  m.load_model(os.path.join(model_blob_path, model_blob_filename))
151
161
 
152
162
  if kwargs.get("use_gpu", False):
163
+ assert type(kwargs.get("use_gpu", False)) == bool
153
164
  gpu_params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
154
165
  if isinstance(m, xgboost.Booster):
155
166
  m.set_param(gpu_params)
@@ -197,7 +208,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
197
208
 
198
209
  return fn
199
210
 
200
- type_method_dict = {}
211
+ type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
201
212
  for target_method_name, sig in model_meta.signatures.items():
202
213
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
203
214
 
@@ -1 +1,11 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'anyio>=3.5.0,<4', 'cloudpickle>=2.0.0', 'numpy>=1.23,<2', 'packaging>=20.9,<24', 'pandas>=1.0.0,<2', 'pyyaml>=6.0,<7', 'snowflake-snowpark-python>=1.8.0,<2', 'typing-extensions>=4.1.0,<5']
1
+ REQUIREMENTS = [
2
+ "absl-py>=0.15,<2",
3
+ "anyio>=3.5.0,<4",
4
+ "cloudpickle>=2.0.0",
5
+ "numpy>=1.23,<2",
6
+ "packaging>=20.9,<24",
7
+ "pandas>=1.0.0,<2",
8
+ "pyyaml>=6.0,<7",
9
+ "snowflake-snowpark-python>=1.8.0,<2",
10
+ "typing-extensions>=4.1.0,<5"
11
+ ]
@@ -0,0 +1,3 @@
1
+ REQUIREMENTS = [
2
+ "cloudpickle>=2.0.0"
3
+ ]
@@ -18,6 +18,7 @@ from snowflake.ml.model import model_signature, type_hints as model_types
18
18
  from snowflake.ml.model._packager.model_env import model_env
19
19
  from snowflake.ml.model._packager.model_meta import (
20
20
  _core_requirements,
21
+ _packaging_requirements,
21
22
  model_blob_meta,
22
23
  model_meta_schema,
23
24
  )
@@ -26,7 +27,8 @@ from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
26
27
  MODEL_METADATA_FILE = "model.yaml"
27
28
  MODEL_CODE_DIR = "code"
28
29
 
29
- _PACKAGING_CORE_DEPENDENCIES = _core_requirements.REQUIREMENTS
30
+ _PACKAGING_CORE_DEPENDENCIES = _core_requirements.REQUIREMENTS # Legacy Model only
31
+ _PACKAGING_REQUIREMENTS = _packaging_requirements.REQUIREMENTS # New Model only
30
32
  _SNOWFLAKE_PKG_NAME = "snowflake"
31
33
  _SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml"
32
34
 
@@ -73,6 +75,8 @@ def create_model_metadata(
73
75
  model_dir_path = os.path.normpath(model_dir_path)
74
76
  embed_local_ml_library = kwargs.pop("embed_local_ml_library", False)
75
77
  legacy_save = kwargs.pop("_legacy_save", False)
78
+ relax_version = kwargs.pop("relax_version", False)
79
+
76
80
  if embed_local_ml_library:
77
81
  # Use the last one which is loaded first, that is mean, it is loaded from site-packages.
78
82
  # We could make sure that user does not overwrite our library with their code follow the same naming.
@@ -94,6 +98,8 @@ def create_model_metadata(
94
98
  pip_requirements=pip_requirements,
95
99
  python_version=python_version,
96
100
  embed_local_ml_library=embed_local_ml_library,
101
+ legacy_save=legacy_save,
102
+ relax_version=relax_version,
97
103
  )
98
104
 
99
105
  if embed_local_ml_library:
@@ -146,6 +152,8 @@ def _create_env_for_model_metadata(
146
152
  pip_requirements: Optional[List[str]] = None,
147
153
  python_version: Optional[str] = None,
148
154
  embed_local_ml_library: bool = False,
155
+ legacy_save: bool = False,
156
+ relax_version: bool = False,
149
157
  ) -> model_env.ModelEnv:
150
158
  env = model_env.ModelEnv()
151
159
 
@@ -154,11 +162,14 @@ def _create_env_for_model_metadata(
154
162
  env.pip_requirements = pip_requirements # type: ignore[assignment]
155
163
  env.python_version = python_version # type: ignore[assignment]
156
164
  env.snowpark_ml_version = snowml_env.VERSION
165
+
166
+ requirements_to_add = _PACKAGING_CORE_DEPENDENCIES if legacy_save else _PACKAGING_REQUIREMENTS
167
+
157
168
  if embed_local_ml_library:
158
169
  env.include_if_absent(
159
170
  [
160
171
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
161
- for dep in _PACKAGING_CORE_DEPENDENCIES
172
+ for dep in requirements_to_add
162
173
  ],
163
174
  check_local_version=True,
164
175
  )
@@ -166,11 +177,14 @@ def _create_env_for_model_metadata(
166
177
  env.include_if_absent(
167
178
  [
168
179
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
169
- for dep in _PACKAGING_CORE_DEPENDENCIES + [env_utils.SNOWPARK_ML_PKG_NAME]
180
+ for dep in requirements_to_add + [env_utils.SNOWPARK_ML_PKG_NAME]
170
181
  ],
171
182
  check_local_version=True,
172
183
  )
173
184
 
185
+ if relax_version:
186
+ env.relax_version()
187
+
174
188
  return env
175
189
 
176
190
 
@@ -146,7 +146,8 @@ class DataType(Enum):
146
146
  " is being automatically converted to INT64 in the Snowpark DataFrame. "
147
147
  "This automatic conversion may lead to potential precision loss and rounding errors. "
148
148
  "If you wish to prevent this conversion, you should manually perform "
149
- "the necessary data type conversion."
149
+ "the necessary data type conversion.",
150
+ stacklevel=2,
150
151
  )
151
152
  return DataType.INT64
152
153
  else:
@@ -155,7 +156,8 @@ class DataType(Enum):
155
156
  " is being automatically converted to DOUBLE in the Snowpark DataFrame. "
156
157
  "This automatic conversion may lead to potential precision loss and rounding errors. "
157
158
  "If you wish to prevent this conversion, you should manually perform "
158
- "the necessary data type conversion."
159
+ "the necessary data type conversion.",
160
+ stacklevel=2,
159
161
  )
160
162
  return DataType.DOUBLE
161
163
  raise snowml_exceptions.SnowflakeMLException(
@@ -202,23 +204,24 @@ class FeatureSpec(BaseFeatureSpec):
202
204
  dtype: DataType,
203
205
  shape: Optional[Tuple[int, ...]] = None,
204
206
  ) -> None:
205
- """Initialize a feature.
207
+ """
208
+ Initialize a feature.
206
209
 
207
210
  Args:
208
211
  name: Name of the feature.
209
212
  dtype: Type of the elements in the feature.
210
- shape: Used to represent scalar feature, 1-d feature list or n-d tensor.
211
- -1 is used to represent variable length.Defaults to None.
213
+ shape: Used to represent scalar feature, 1-d feature list,
214
+ or n-d tensor. Use -1 to represent variable length. Defaults to None.
212
215
 
213
- E.g.
214
- None: scalar
215
- (2,): 1d list with fixed len of 2.
216
- (-1,): 1d list with variable length. Used for ragged tensor representation.
217
- (d1, d2, d3): 3d tensor.
216
+ Examples:
217
+ - None: scalar
218
+ - (2,): 1d list with a fixed length of 2.
219
+ - (-1,): 1d list with variable length, used for ragged tensor representation.
220
+ - (d1, d2, d3): 3d tensor.
218
221
 
219
222
  Raises:
220
- SnowflakeMLException: TypeError: Raised when the dtype input type is incorrect.
221
- SnowflakeMLException: TypeError: Raised when the shape input type is incorrect.
223
+ SnowflakeMLException: TypeError: When the dtype input type is incorrect.
224
+ SnowflakeMLException: TypeError: When the shape input type is incorrect.
222
225
  """
223
226
  super().__init__(name=name)
224
227
 
@@ -408,13 +411,13 @@ class ModelSignature:
408
411
  """Signature of a model that specifies the input and output of a model."""
409
412
 
410
413
  def __init__(self, inputs: Sequence[BaseFeatureSpec], outputs: Sequence[BaseFeatureSpec]) -> None:
411
- """Initialize a model signature
414
+ """Initialize a model signature.
412
415
 
413
416
  Args:
414
- inputs: A sequence of feature specifications and feature group specifications that will compose the
415
- input of the model.
416
- outputs: A sequence of feature specifications and feature group specifications that will compose the
417
- output of the model.
417
+ inputs: A sequence of feature specifications and feature group specifications that will compose
418
+ the input of the model.
419
+ outputs: A sequence of feature specifications and feature group specifications that will compose
420
+ the output of the model.
418
421
  """
419
422
  self._inputs = inputs
420
423
  self._outputs = outputs
@@ -9,15 +9,16 @@ from snowflake.ml.model import type_hints as model_types
9
9
 
10
10
 
11
11
  class MethodRef:
12
- """Represents an method invocation of an instance of `ModelRef`.
12
+ """Represents a method invocation of an instance of `ModelRef`.
13
+
14
+ This allows us to:
15
+ 1) Customize the place of actual execution of the method (inline, thread/process pool, or remote).
16
+ 2) Enrich the way of execution (sync versus async).
13
17
 
14
- This allows us to
15
- 1) Customize the place of actual execution of the method(inline, thread/process pool or remote).
16
- 2) Enrich the way of execution(sync versus async).
17
18
  Example:
18
- If you have a SKL model, you would normally invoke by `skl_ref.predict(df)` which has sync API.
19
- Within inference graph, you could invoke `await skl_ref.predict.async_run(df)` which automatically
20
- will be run on thread with async interface.
19
+ If you have an SKL model, you would normally invoke it by `skl_ref.predict(df)`, which has a synchronous API.
20
+ Within the inference graph, you could invoke `await skl_ref.predict.async_run(df)`, which will automatically
21
+ run on a thread with an asynchronous interface.
21
22
  """
22
23
 
23
24
  def __init__(self, model_ref: "ModelRef", method_name: str) -> None:
@@ -27,11 +28,11 @@ class MethodRef:
27
28
  return self._func(*args, **kwargs)
28
29
 
29
30
  async def async_run(self, *args: Any, **kwargs: Any) -> Any:
30
- """Run the method in a async way. If the method is defined as async, this will simply run it. If not, this will
31
- be run in a separate thread.
31
+ """Run the method in an asynchronous way. If the method is defined as async, this will simply run it.
32
+ If not, this will be run in a separate thread.
32
33
 
33
34
  Args:
34
- *args: Arguments of the original method,
35
+ *args: Arguments of the original method.
35
36
  **kwargs: Keyword arguments of the original method.
36
37
 
37
38
  Returns:
@@ -43,19 +44,20 @@ class MethodRef:
43
44
 
44
45
 
45
46
  class ModelRef:
46
- """Represents an model in the inference graph. Method could be directly called using this reference object as if
47
- with the original model object.
47
+ """
48
+ Represents a model in the inference graph. Methods can be directly called using this reference object
49
+ as if with the original model object.
48
50
 
49
- This enables us to separate physical and logical representation of a model which
50
- will allows us to deeply understand the graph and perform optimization at entire
51
- graph level.
51
+ This enables us to separate the physical and logical representation of a model, allowing for a deep understanding
52
+ of the graph and enabling optimization at the entire graph level.
52
53
  """
53
54
 
54
55
  def __init__(self, name: str, model: model_types.SupportedModelType) -> None:
55
- """Initialize the ModelRef.
56
+ """
57
+ Initialize the ModelRef.
56
58
 
57
59
  Args:
58
- name: The name of a model to refer it.
60
+ name: The name of the model to refer to.
59
61
  model: The model object.
60
62
  """
61
63
  self._model = model
@@ -91,11 +93,12 @@ class ModelRef:
91
93
 
92
94
 
93
95
  class ModelContext:
94
- """Context for a custom model showing path to artifacts and mapping between model name and object reference.
96
+ """
97
+ Context for a custom model showing paths to artifacts and mapping between model name and object reference.
95
98
 
96
99
  Attributes:
97
- artifacts: A dict mapping name of the artifact to its path.
98
- model_refs: A dict mapping name of the sub-model to its ModelRef object.
100
+ artifacts: A dictionary mapping the name of the artifact to its path.
101
+ model_refs: A dictionary mapping the name of the sub-model to its ModelRef object.
99
102
  """
100
103
 
101
104
  def __init__(
@@ -104,11 +107,11 @@ class ModelContext:
104
107
  artifacts: Optional[Dict[str, str]] = None,
105
108
  models: Optional[Dict[str, model_types.SupportedModelType]] = None,
106
109
  ) -> None:
107
- """Initialize the model context
110
+ """Initialize the model context.
108
111
 
109
112
  Args:
110
- artifacts: A dict mapping name of the artifact to its currently available path. Defaults to None.
111
- models: A dict mapping name of the sub-model to the corresponding model object. Defaults to None.
113
+ artifacts: A dictionary mapping the name of the artifact to its currently available path. Defaults to None.
114
+ models: A dictionary mapping the name of the sub-model to the corresponding model object. Defaults to None.
112
115
  """
113
116
  self.artifacts: Dict[str, str] = artifacts if artifacts else dict()
114
117
  self.model_refs: Dict[str, ModelRef] = (
@@ -116,7 +119,8 @@ class ModelContext:
116
119
  )
117
120
 
118
121
  def path(self, key: str) -> str:
119
- """Get the actual path to a specific artifact.
122
+ """Get the actual path to a specific artifact. This could be used when defining a Custom Model to retrieve
123
+ artifacts.
120
124
 
121
125
  Args:
122
126
  key: The name of the artifact.
@@ -127,14 +131,13 @@ class ModelContext:
127
131
  return self.artifacts[key]
128
132
 
129
133
  def model_ref(self, name: str) -> ModelRef:
130
- """Get a ModelRef object of a sub-model containing the name and model object, while able to call its method
131
- directly as well.
134
+ """Get a ModelRef object of a sub-model containing the name and model object, allowing direct method calls.
132
135
 
133
136
  Args:
134
137
  name: The name of the sub-model.
135
138
 
136
139
  Returns:
137
- The ModelRef object to the sub-model.
140
+ The ModelRef object representing the sub-model.
138
141
  """
139
142
  return self.model_refs[name]
140
143
 
@@ -570,32 +570,31 @@ def infer_signature(
570
570
  input_feature_names: Optional[List[str]] = None,
571
571
  output_feature_names: Optional[List[str]] = None,
572
572
  ) -> core.ModelSignature:
573
- """Infer model signature from given input and output sample data.
573
+ """
574
+ Infer model signature from given input and output sample data.
575
+
576
+ Currently supports inferring model signatures from the following data types:
574
577
 
575
- Currently, we support infer the model signature from example input/output data in the following cases:
576
- - Pandas data frame whose column could have types of supported data types,
577
- list (including list of supported data types, list of numpy array of supported data types, and nested list),
578
- and numpy array of supported data types.
578
+ - Pandas DataFrame with columns of supported data types, lists (including nested lists) of supported data types,
579
+ and NumPy arrays of supported data types.
579
580
  - Does not support DataFrame with CategoricalIndex column index.
580
- - Does not support DataFrame with column of variant length list or numpy array.
581
- - Numpy array of supported data types.
582
- - List of Numpy array of supported data types.
583
- - List of supported data types, or nested list of supported data types.
584
- - Does not support list of list of variant length list.
581
+ - NumPy arrays of supported data types.
582
+ - Lists of NumPy arrays of supported data types.
583
+ - Lists of supported data types or nested lists of supported data types.
584
+
585
+ When inferring the signature, a ValueError indicates that the data is insufficient or invalid.
585
586
 
586
- When a ValueError is raised when inferring the signature, it indicates that the data is ill and it is impossible to
587
- create a signature reflecting that.
588
- When a NotImplementedError is raised, it indicates that it might be possible to create a signature reflecting the
589
- provided data, however, we could not infer it.
587
+ When it might be possible to create a signature reflecting the provided data, but it could not be inferred,
588
+ a NotImplementedError is raised.
590
589
 
591
590
  Args:
592
591
  input_data: Sample input data for the model.
593
592
  output_data: Sample output data for the model.
594
- input_feature_names: Name for input features. Defaults to None.
595
- output_feature_names: Name for output features. Defaults to None.
593
+ input_feature_names: Names for input features. Defaults to None.
594
+ output_feature_names: Names for output features. Defaults to None.
596
595
 
597
596
  Returns:
598
- A model signature.
597
+ A model signature inferred from the given input and output sample data.
599
598
  """
600
599
  inputs = _infer_signature(input_data, role="input")
601
600
  inputs = utils.rename_features(inputs, input_feature_names)
@@ -198,9 +198,12 @@ class BaseModelSaveOption(TypedDict):
198
198
  """Options for saving the model.
199
199
 
200
200
  embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
201
+ relax_version: Whether or not to relax the version constraints of the dependencies if unresolvable. It detects any
202
+ ==x.y.z in specifiers and replaces it with >=x.y, <(x+1). Defaults to False.
201
203
  """
202
204
 
203
205
  embed_local_ml_library: NotRequired[bool]
206
+ relax_version: NotRequired[bool]
204
207
  _legacy_save: NotRequired[bool]
205
208
  method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
206
209