snowflake-ml-python 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/utils/identifier.py +3 -1
  3. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  4. snowflake/ml/feature_store/feature_store.py +151 -78
  5. snowflake/ml/feature_store/feature_view.py +12 -24
  6. snowflake/ml/fileset/sfcfs.py +56 -50
  7. snowflake/ml/fileset/stage_fs.py +48 -13
  8. snowflake/ml/model/_client/model/model_version_impl.py +2 -50
  9. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  10. snowflake/ml/model/_client/sql/model.py +23 -2
  11. snowflake/ml/model/_client/sql/model_version.py +22 -1
  12. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  13. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  14. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  15. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  16. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  18. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  19. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  20. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  21. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  22. snowflake/ml/model/_packager/model_packager.py +2 -2
  23. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  24. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  25. snowflake/ml/model/type_hints.py +21 -2
  26. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  27. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  29. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  30. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  31. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  32. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  33. snowflake/ml/modeling/cluster/birch.py +195 -123
  34. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  35. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  36. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  37. snowflake/ml/modeling/cluster/k_means.py +195 -123
  38. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  39. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  40. snowflake/ml/modeling/cluster/optics.py +195 -123
  41. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  42. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  43. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  44. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  45. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  46. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  47. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  48. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  49. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  50. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  51. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  52. snowflake/ml/modeling/covariance/oas.py +195 -123
  53. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  54. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  55. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  56. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  57. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  58. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  59. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  60. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  61. snowflake/ml/modeling/decomposition/pca.py +195 -123
  62. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  63. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  64. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  65. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  66. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  67. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  68. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  69. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  70. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  71. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  72. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  73. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  76. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  77. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  78. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  79. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  80. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  81. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  82. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  83. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  84. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  85. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  86. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  87. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  88. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  89. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  90. snowflake/ml/modeling/framework/_utils.py +8 -1
  91. snowflake/ml/modeling/framework/base.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  94. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  95. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  96. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +195 -123
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +195 -123
  105. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  107. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  111. snowflake/ml/modeling/linear_model/lars.py +195 -123
  112. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  113. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  118. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  128. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  131. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  140. snowflake/ml/modeling/manifold/isomap.py +195 -123
  141. snowflake/ml/modeling/manifold/mds.py +195 -123
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  143. snowflake/ml/modeling/manifold/tsne.py +195 -123
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  146. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  147. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  148. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  149. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  150. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  151. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  152. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  153. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  154. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  155. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  156. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  157. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  158. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  159. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  160. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  161. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  162. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  163. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  164. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  165. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  166. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  167. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  168. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  169. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  170. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  171. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  172. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  173. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  174. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  175. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  176. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  178. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  179. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  180. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  181. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  182. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  183. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  184. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  185. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  186. snowflake/ml/modeling/svm/svc.py +195 -123
  187. snowflake/ml/modeling/svm/svr.py +195 -123
  188. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  189. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  190. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  191. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  192. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  193. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  194. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  195. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  196. snowflake/ml/registry/registry.py +1 -1
  197. snowflake/ml/version.py +1 -1
  198. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +68 -57
  199. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +202 -200
  200. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  201. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  202. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  203. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import collections
2
+ import copy
2
3
  import pathlib
3
- from typing import Any, Dict, List, Optional, cast
4
+ from typing import List, Optional, cast
4
5
 
5
6
  import yaml
6
7
 
@@ -10,7 +11,6 @@ from snowflake.ml.model._model_composer.model_method import (
10
11
  function_generator,
11
12
  model_method,
12
13
  )
13
- from snowflake.ml.model._model_composer.model_runtime import model_runtime
14
14
  from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
15
15
  from snowflake.snowpark import Session
16
16
 
@@ -39,21 +39,19 @@ class ModelManifest:
39
39
  ) -> None:
40
40
  if options is None:
41
41
  options = {}
42
- self.runtimes = [
43
- model_runtime.ModelRuntime(
44
- session=session,
45
- name=ModelManifest._DEFAULT_RUNTIME_NAME,
46
- model_meta=model_meta,
47
- imports=[model_file_rel_path],
48
- )
49
- ]
42
+
43
+ runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
44
+ runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
45
+ runtime_to_use.imports.append(model_file_rel_path)
46
+ runtime_dict = runtime_to_use.save(self.workspace_path)
47
+
50
48
  self.function_generator = function_generator.FunctionGenerator(model_file_rel_path=model_file_rel_path)
51
49
  self.methods: List[model_method.ModelMethod] = []
52
50
  for target_method in model_meta.signatures.keys():
53
51
  method = model_method.ModelMethod(
54
52
  model_meta=model_meta,
55
53
  target_method=target_method,
56
- runtime_name=self.runtimes[0].name,
54
+ runtime_name=self._DEFAULT_RUNTIME_NAME,
57
55
  function_generator=self.function_generator,
58
56
  options=model_method.get_model_method_options_from_options(options, target_method),
59
57
  )
@@ -71,7 +69,16 @@ class ModelManifest:
71
69
 
72
70
  manifest_dict = model_manifest_schema.ModelManifestDict(
73
71
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
74
- runtimes={runtime.name: runtime.save(self.workspace_path) for runtime in self.runtimes},
72
+ runtimes={
73
+ self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
74
+ language="PYTHON",
75
+ version=runtime_to_use.runtime_env.python_version,
76
+ imports=runtime_dict["imports"],
77
+ dependencies=model_manifest_schema.ModelRuntimeDependenciesDict(
78
+ conda=runtime_dict["dependencies"]["conda"]
79
+ ),
80
+ )
81
+ },
75
82
  methods=[
76
83
  method.save(
77
84
  self.workspace_path,
@@ -83,8 +90,6 @@ class ModelManifest:
83
90
  ],
84
91
  )
85
92
 
86
- manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta)
87
-
88
93
  with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
89
94
  # Anchors are not supported in the server, avoid that.
90
95
  yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
@@ -103,43 +108,3 @@ class ModelManifest:
103
108
  res = cast(model_manifest_schema.ModelManifestDict, raw_input)
104
109
 
105
110
  return res
106
-
107
- def generate_user_data_with_client_data(self, model_meta: model_meta_api.ModelMetadata) -> Dict[str, Any]:
108
- client_data = model_manifest_schema.SnowparkMLDataDict(
109
- schema_version=model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION,
110
- functions=[
111
- model_manifest_schema.ModelFunctionInfoDict(
112
- name=method.method_name.identifier(),
113
- target_method=method.target_method,
114
- signature=model_meta.signatures[method.target_method].to_dict(),
115
- )
116
- for method in self.methods
117
- ],
118
- )
119
- return {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: client_data}
120
-
121
- @staticmethod
122
- def parse_client_data_from_user_data(raw_user_data: Dict[str, Any]) -> model_manifest_schema.SnowparkMLDataDict:
123
- raw_client_data = raw_user_data.get(model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME, {})
124
- if not isinstance(raw_client_data, dict) or "schema_version" not in raw_client_data:
125
- raise ValueError(f"Ill-formatted client data {raw_client_data} in user data found.")
126
- loaded_client_data_schema_version = raw_client_data["schema_version"]
127
- if (
128
- not isinstance(loaded_client_data_schema_version, str)
129
- or loaded_client_data_schema_version != model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION
130
- ):
131
- raise ValueError(f"Unsupported client data schema version {loaded_client_data_schema_version} confronted.")
132
-
133
- return_functions_info: List[model_manifest_schema.ModelFunctionInfoDict] = []
134
- loaded_functions_info = raw_client_data.get("functions", [])
135
- for func in loaded_functions_info:
136
- fi = model_manifest_schema.ModelFunctionInfoDict(
137
- name=func["name"],
138
- target_method=func["target_method"],
139
- signature=func["signature"],
140
- )
141
- return_functions_info.append(fi)
142
-
143
- return model_manifest_schema.SnowparkMLDataDict(
144
- schema_version=loaded_client_data_schema_version, functions=return_functions_info
145
- )
@@ -1,5 +1,5 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
-
2
+ import enum
3
3
  from typing import Any, Dict, List, Literal, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
@@ -12,6 +12,11 @@ MANIFEST_CLIENT_DATA_KEY_NAME = "snowpark_ml_data"
12
12
  MANIFEST_CLIENT_DATA_SCHEMA_VERSION = "2024-02-01"
13
13
 
14
14
 
15
+ class ModelMethodFunctionTypes(enum.Enum):
16
+ FUNCTION = "FUNCTION"
17
+ TABLE_FUNCTION = "TABLE_FUNCTION"
18
+
19
+
15
20
  class ModelRuntimeDependenciesDict(TypedDict):
16
21
  conda: Required[str]
17
22
 
@@ -49,11 +54,13 @@ class ModelFunctionInfo(TypedDict):
49
54
  Attributes:
50
55
  name: Name of the function to be called via SQL.
51
56
  target_method: actual target method name to be called.
57
+ target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
52
58
  signature: The signature of the model method.
53
59
  """
54
60
 
55
61
  name: Required[str]
56
62
  target_method: Required[str]
63
+ target_method_function_type: Required[str]
57
64
  signature: Required[model_signature.ModelSignature]
58
65
 
59
66
 
@@ -1,5 +1,4 @@
1
1
  import collections
2
- import enum
3
2
  import pathlib
4
3
  from typing import List, Optional, TypedDict, Union
5
4
 
@@ -13,11 +12,6 @@ from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
13
12
  from snowflake.snowpark._internal import type_utils
14
13
 
15
14
 
16
- class ModelMethodFunctionTypes(enum.Enum):
17
- FUNCTION = "FUNCTION"
18
- TABLE_FUNCTION = "TABLE_FUNCTION"
19
-
20
-
21
15
  class ModelMethodOptions(TypedDict):
22
16
  """Options when creating model method.
23
17
 
@@ -33,9 +27,9 @@ def get_model_method_options_from_options(
33
27
  options: type_hints.ModelSaveOption, target_method: str
34
28
  ) -> ModelMethodOptions:
35
29
  method_option = options.get("method_options", {}).get(target_method, {})
36
- global_function_type = options.get("function_type", ModelMethodFunctionTypes.FUNCTION.value)
30
+ global_function_type = options.get("function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value)
37
31
  function_type = method_option.get("function_type", global_function_type)
38
- if function_type not in [function_type.value for function_type in ModelMethodFunctionTypes]:
32
+ if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
39
33
  raise NotImplementedError
40
34
 
41
35
  # TODO(TH): enforce minimum snowflake version
@@ -89,7 +83,9 @@ class ModelMethod:
89
83
  if self.target_method not in self.model_meta.signatures.keys():
90
84
  raise ValueError(f"Target method {self.target_method} is not available in the signatures of the model.")
91
85
 
92
- self.function_type = self.options.get("function_type", ModelMethodFunctionTypes.FUNCTION.value)
86
+ self.function_type = self.options.get(
87
+ "function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
88
+ )
93
89
 
94
90
  @staticmethod
95
91
  def _get_method_arg_from_feature(
@@ -134,7 +130,7 @@ class ModelMethod:
134
130
  List[model_manifest_schema.ModelMethodSignatureField],
135
131
  List[model_manifest_schema.ModelMethodSignatureFieldWithName],
136
132
  ]
137
- if self.function_type == ModelMethodFunctionTypes.TABLE_FUNCTION.value:
133
+ if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
138
134
  outputs = [
139
135
  ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False))
140
136
  for ft in self.model_meta.signatures[self.target_method].outputs
@@ -0,0 +1,206 @@
1
+ import os
2
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from typing_extensions import TypeGuard, Unpack
7
+
8
+ from snowflake.ml._internal import type_utils
9
+ from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
10
+ from snowflake.ml.model._packager.model_env import model_env
11
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
12
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
+ from snowflake.ml.model._packager.model_meta import (
14
+ model_blob_meta,
15
+ model_meta as model_meta_api,
16
+ model_meta_schema,
17
+ )
18
+ from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
19
+
20
+ if TYPE_CHECKING:
21
+ import catboost
22
+
23
+
24
+ @final
25
+ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
26
+ """Handler for CatBoost based model."""
27
+
28
+ HANDLER_TYPE = "catboost"
29
+ HANDLER_VERSION = "2024-03-21"
30
+ _MIN_SNOWPARK_ML_VERSION = "1.3.1"
31
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
32
+
33
+ MODELE_BLOB_FILE_OR_DIR = "model.bin"
34
+ DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
35
+
36
+ @classmethod
37
+ def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
38
+ return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
39
+ (hasattr(model, method) and callable(getattr(model, method, None))) for method in cls.DEFAULT_TARGET_METHODS
40
+ )
41
+
42
+ @classmethod
43
+ def cast_model(
44
+ cls,
45
+ model: model_types.SupportedModelType,
46
+ ) -> "catboost.CatBoost":
47
+ import catboost
48
+
49
+ assert isinstance(model, catboost.CatBoost)
50
+
51
+ return model
52
+
53
+ @classmethod
54
+ def save_model(
55
+ cls,
56
+ name: str,
57
+ model: "catboost.CatBoost",
58
+ model_meta: model_meta_api.ModelMetadata,
59
+ model_blobs_dir_path: str,
60
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
61
+ is_sub_model: Optional[bool] = False,
62
+ **kwargs: Unpack[model_types.CatBoostModelSaveOptions],
63
+ ) -> None:
64
+ import catboost
65
+
66
+ assert isinstance(model, catboost.CatBoost)
67
+
68
+ if not is_sub_model:
69
+ target_methods = handlers_utils.get_target_methods(
70
+ model=model,
71
+ target_methods=kwargs.pop("target_methods", None),
72
+ default_target_methods=cls.DEFAULT_TARGET_METHODS,
73
+ )
74
+
75
+ def get_prediction(
76
+ target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
77
+ ) -> model_types.SupportedLocalDataType:
78
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
79
+ sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
80
+ target_method = getattr(model, target_method_name, None)
81
+ assert callable(target_method)
82
+ predictions_df = target_method(sample_input_data)
83
+ return predictions_df
84
+
85
+ model_meta = handlers_utils.validate_signature(
86
+ model=model,
87
+ model_meta=model_meta,
88
+ target_methods=target_methods,
89
+ sample_input_data=sample_input_data,
90
+ get_prediction_fn=get_prediction,
91
+ )
92
+
93
+ model_blob_path = os.path.join(model_blobs_dir_path, name)
94
+ os.makedirs(model_blob_path, exist_ok=True)
95
+ model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
96
+
97
+ model.save_model(model_save_path)
98
+
99
+ base_meta = model_blob_meta.ModelBlobMeta(
100
+ name=name,
101
+ model_type=cls.HANDLER_TYPE,
102
+ handler_version=cls.HANDLER_VERSION,
103
+ path=cls.MODELE_BLOB_FILE_OR_DIR,
104
+ options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
105
+ )
106
+ model_meta.models[name] = base_meta
107
+ model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
108
+
109
+ model_meta.env.include_if_absent(
110
+ [
111
+ model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
112
+ ],
113
+ check_local_version=True,
114
+ )
115
+ model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
116
+
117
+ return None
118
+
119
+ @classmethod
120
+ def load_model(
121
+ cls,
122
+ name: str,
123
+ model_meta: model_meta_api.ModelMetadata,
124
+ model_blobs_dir_path: str,
125
+ **kwargs: Unpack[model_types.ModelLoadOption],
126
+ ) -> "catboost.CatBoost":
127
+ import catboost
128
+
129
+ model_blob_path = os.path.join(model_blobs_dir_path, name)
130
+ model_blobs_metadata = model_meta.models
131
+ model_blob_metadata = model_blobs_metadata[name]
132
+ model_blob_filename = model_blob_metadata.path
133
+ model_blob_file_path = os.path.join(model_blob_path, model_blob_filename)
134
+
135
+ model_blob_options = cast(model_meta_schema.CatBoostModelBlobOptions, model_blob_metadata.options)
136
+ if "catboost_estimator_type" not in model_blob_options:
137
+ raise ValueError("Missing field `catboost_estimator_type` in model blob metadata for type `catboost`")
138
+
139
+ catboost_estimator_type = model_blob_options["catboost_estimator_type"]
140
+ if not hasattr(catboost, catboost_estimator_type):
141
+ raise ValueError("Type of CatBoost estimator is not supported.")
142
+
143
+ assert os.path.isfile(model_blob_file_path) # saved model is a file
144
+ model = getattr(catboost, catboost_estimator_type)()
145
+ model.load_model(model_blob_file_path)
146
+ assert isinstance(model, getattr(catboost, catboost_estimator_type))
147
+
148
+ if kwargs.get("use_gpu", False):
149
+ assert type(kwargs.get("use_gpu", False)) == bool
150
+ gpu_params = {"task_type": "GPU"}
151
+ model.__dict__.update(gpu_params)
152
+
153
+ return model
154
+
155
+ @classmethod
156
+ def convert_as_custom_model(
157
+ cls,
158
+ raw_model: "catboost.CatBoost",
159
+ model_meta: model_meta_api.ModelMetadata,
160
+ **kwargs: Unpack[model_types.ModelLoadOption],
161
+ ) -> custom_model.CustomModel:
162
+ import catboost
163
+
164
+ from snowflake.ml.model import custom_model
165
+
166
+ def _create_custom_model(
167
+ raw_model: "catboost.CatBoost",
168
+ model_meta: model_meta_api.ModelMetadata,
169
+ ) -> Type[custom_model.CustomModel]:
170
+ def fn_factory(
171
+ raw_model: "catboost.CatBoost",
172
+ signature: model_signature.ModelSignature,
173
+ target_method: str,
174
+ ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
175
+ @custom_model.inference_api
176
+ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
177
+
178
+ res = getattr(raw_model, target_method)(X)
179
+
180
+ if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
181
+ # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
182
+ # return a list of ndarrays. We need to deal them separately
183
+ df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
184
+ else:
185
+ df = pd.DataFrame(res)
186
+
187
+ return model_signature_utils.rename_pandas_df(df, signature.outputs)
188
+
189
+ return fn
190
+
191
+ type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
192
+ for target_method_name, sig in model_meta.signatures.items():
193
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
194
+
195
+ _CatBoostModel = type(
196
+ "_CatBoostModel",
197
+ (custom_model.CustomModel,),
198
+ type_method_dict,
199
+ )
200
+
201
+ return _CatBoostModel
202
+
203
+ _CatBoostModel = _create_custom_model(raw_model, model_meta)
204
+ catboost_model = _CatBoostModel(custom_model.ModelContext())
205
+
206
+ return catboost_model
@@ -0,0 +1,218 @@
1
+ import os
2
+ from typing import (
3
+ TYPE_CHECKING,
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Optional,
8
+ Type,
9
+ Union,
10
+ cast,
11
+ final,
12
+ )
13
+
14
+ import cloudpickle
15
+ import numpy as np
16
+ import pandas as pd
17
+ from typing_extensions import TypeGuard, Unpack
18
+
19
+ from snowflake.ml._internal import type_utils
20
+ from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
21
+ from snowflake.ml.model._packager.model_env import model_env
22
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
23
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
24
+ from snowflake.ml.model._packager.model_meta import (
25
+ model_blob_meta,
26
+ model_meta as model_meta_api,
27
+ model_meta_schema,
28
+ )
29
+ from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
30
+
31
+ if TYPE_CHECKING:
32
+ import lightgbm
33
+
34
+
35
@final
class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgbm.LGBMModel"]]):
    """Handler for LightGBM based model.

    Handles both ``lightgbm.Booster`` and ``lightgbm.LGBMModel`` objects. Models are
    persisted with cloudpickle as a single file inside the model blob directory.
    """

    HANDLER_TYPE = "lightgbm"
    HANDLER_VERSION = "2024-03-19"
    _MIN_SNOWPARK_ML_VERSION = "1.3.1"
    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

    # File name of the pickled model inside the blob directory.
    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
    # Methods looked up on the model when the caller does not specify target methods.
    DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]

    @classmethod
    def can_handle(
        cls, model: model_types.SupportedModelType
    ) -> TypeGuard[Union["lightgbm.Booster", "lightgbm.LGBMModel"]]:
        """Return True when ``model`` is a LightGBM Booster/LGBMModel exposing a default target method.

        Args:
            model: Candidate model object.

        Returns:
            True if the object is an instance of ``lightgbm.Booster`` or ``lightgbm.LGBMModel``
            and at least one name in ``DEFAULT_TARGET_METHODS`` is a callable attribute of it.
        """
        return (
            type_utils.LazyType("lightgbm.Booster").isinstance(model)
            or type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
        ) and any(
            (hasattr(model, method) and callable(getattr(model, method, None))) for method in cls.DEFAULT_TARGET_METHODS
        )

    @classmethod
    def cast_model(
        cls,
        model: model_types.SupportedModelType,
    ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
        """Narrow ``model`` to a LightGBM type, asserting that it really is one.

        Args:
            model: Object previously accepted by ``can_handle``.

        Returns:
            The same object, typed as a LightGBM Booster or LGBMModel.
        """
        import lightgbm

        assert isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))

        return model

    @classmethod
    def save_model(
        cls,
        name: str,
        model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.LGBMModelSaveOptions],
    ) -> None:
        """Pickle the model into the blob directory and record it in ``model_meta``.

        Args:
            name: Name of the model; also the sub-directory created under ``model_blobs_dir_path``.
            model: LightGBM Booster or LGBMModel to save.
            model_meta: Metadata object updated with the blob entry, the minimum
                snowpark-ml version and the runtime dependencies.
            model_blobs_dir_path: Directory under which the model blob directory is created.
            sample_input_data: Optional sample input passed to signature validation.
            is_sub_model: When falsy, target methods are resolved and signatures validated.
            **kwargs: Save options; ``target_methods`` is consumed here.
        """
        import lightgbm

        assert isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))

        if not is_sub_model:
            target_methods = handlers_utils.get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=cls.DEFAULT_TARGET_METHODS,
            )

            def get_prediction(
                target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
            ) -> model_types.SupportedLocalDataType:
                # DataFrames and ndarrays are passed to the model untouched; anything else is
                # converted to a DataFrame first.
                if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
                    sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
                target_method = getattr(model, target_method_name, None)
                assert callable(target_method)
                predictions_df = target_method(sample_input_data)
                return predictions_df

            model_meta = handlers_utils.validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=target_methods,
                sample_input_data=sample_input_data,
                get_prediction_fn=get_prediction,
            )

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)

        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
        with open(model_save_path, "wb") as f:
            cloudpickle.dump(model, f)

        # The estimator class name is stored so that load_model can verify the unpickled type.
        base_meta = model_blob_meta.ModelBlobMeta(
            name=name,
            model_type=cls.HANDLER_TYPE,
            handler_version=cls.HANDLER_VERSION,
            path=cls.MODELE_BLOB_FILE_OR_DIR,
            options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
        )
        model_meta.models[name] = base_meta
        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

        model_meta.env.include_if_absent(
            [
                model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
                model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
            ],
            check_local_version=True,
        )

        return None

    @classmethod
    def load_model(
        cls,
        name: str,
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
        """Load a pickled LightGBM model described by ``model_meta``.

        Args:
            name: Name of the model; key into ``model_meta.models`` and blob sub-directory.
            model_meta: Metadata holding the blob path and blob options.
            model_blobs_dir_path: Directory containing the model blob directory.
            **kwargs: Load options; not used by this method.

        Returns:
            The unpickled LightGBM model.

        Raises:
            ValueError: If the blob options lack ``lightgbm_estimator_type`` or the recorded
                type is not an attribute of the ``lightgbm`` module.
        """
        import lightgbm

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        model_blobs_metadata = model_meta.models
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        model_blob_file_path = os.path.join(model_blob_path, model_blob_filename)

        model_blob_options = cast(model_meta_schema.LightGBMModelBlobOptions, model_blob_metadata.options)
        if "lightgbm_estimator_type" not in model_blob_options:
            raise ValueError("Missing field `lightgbm_estimator_type` in model blob metadata for type `lightgbm`")

        lightgbm_estimator_type = model_blob_options["lightgbm_estimator_type"]
        if not hasattr(lightgbm, lightgbm_estimator_type):
            raise ValueError("Type of LightGBM estimator is not supported.")

        assert os.path.isfile(model_blob_file_path)  # saved model is a file
        with open(model_blob_file_path, "rb") as f:
            # NOTE(security): unpickling executes code embedded in the blob; only load model
            # blobs that come from a trusted source.
            model = cloudpickle.load(f)
        assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))

        return model

    @classmethod
    def convert_as_custom_model(
        cls,
        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
        model_meta: model_meta_api.ModelMetadata,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> custom_model.CustomModel:
        """Wrap a LightGBM model in a dynamically generated ``CustomModel``.

        One inference method is generated for every entry in ``model_meta.signatures``; each
        calls the same-named method on ``raw_model`` and renames the output columns according
        to that signature's outputs.

        Args:
            raw_model: LightGBM Booster or LGBMModel whose methods are exposed.
            model_meta: Model metadata; its ``signatures`` mapping decides which methods exist.
            **kwargs: Load options; not used by this method.

        Returns:
            An instance of the generated ``_LightGBMModel`` class, constructed with an empty
            ``ModelContext``.
        """
        import lightgbm

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            def fn_factory(
                raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:

                    res = getattr(raw_model, target_method)(X)

                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
                        # In case of multi-output estimators, predict_proba(), decision_function(), etc.
                        # return a list of ndarrays. We need to deal with them separately.
                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
                    else:
                        df = pd.DataFrame(res)

                    return model_signature_utils.rename_pandas_df(df, signature.outputs)

                return fn

            # Namespace of the generated class: the raw model plus one method per signature.
            type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
            for target_method_name, sig in model_meta.signatures.items():
                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)

            _LightGBMModel = type(
                "_LightGBMModel",
                (custom_model.CustomModel,),
                type_method_dict,
            )

            return _LightGBMModel

        _LightGBMModel = _create_custom_model(raw_model, model_meta)
        lightgbm_model = _LightGBMModel(custom_model.ModelContext())

        return lightgbm_model
@@ -47,6 +47,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
47
47
  or type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model)
48
48
  )
49
49
  and (not type_utils.LazyType("xgboost.XGBModel").isinstance(model)) # XGBModel is actually a BaseEstimator
50
+ and (
51
+ not type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
52
+ ) # LGBMModel is actually a BaseEstimator
50
53
  and any(
51
54
  (hasattr(model, method) and callable(getattr(model, method, None)))
52
55
  for method in cls.DEFAULT_TARGET_METHODS
@@ -4,7 +4,7 @@ REQUIREMENTS = [
4
4
  "cloudpickle>=2.0.0",
5
5
  "numpy>=1.23,<2",
6
6
  "packaging>=20.9,<24",
7
- "pandas>=1.0.0,<2",
7
+ "pandas>=1.0.0,<3",
8
8
  "pyyaml>=6.0,<7",
9
9
  "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
10
10
  "typing-extensions>=4.1.0,<5"