snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. snowflake/ml/_internal/env_utils.py +2 -1
  2. snowflake/ml/_internal/file_utils.py +35 -40
  3. snowflake/ml/_internal/telemetry.py +5 -8
  4. snowflake/ml/_internal/utils/identifier.py +74 -7
  5. snowflake/ml/_internal/utils/uri.py +7 -2
  6. snowflake/ml/model/_core_requirements.py +1 -1
  7. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
  8. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
  9. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
  10. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
  11. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
  12. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
  13. snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
  14. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
  15. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
  16. snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
  17. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
  18. snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
  19. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
  20. snowflake/ml/model/_deployer.py +14 -27
  21. snowflake/ml/model/_env.py +4 -4
  22. snowflake/ml/model/_handlers/_base.py +3 -1
  23. snowflake/ml/model/_handlers/custom.py +14 -2
  24. snowflake/ml/model/_handlers/pytorch.py +186 -0
  25. snowflake/ml/model/_handlers/sklearn.py +14 -8
  26. snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
  27. snowflake/ml/model/_handlers/torchscript.py +180 -0
  28. snowflake/ml/model/_handlers/xgboost.py +19 -9
  29. snowflake/ml/model/_model.py +27 -21
  30. snowflake/ml/model/_model_meta.py +33 -19
  31. snowflake/ml/model/model_signature.py +446 -66
  32. snowflake/ml/model/type_hints.py +28 -15
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
  36. snowflake/ml/modeling/cluster/birch.py +79 -43
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
  38. snowflake/ml/modeling/cluster/dbscan.py +79 -43
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
  40. snowflake/ml/modeling/cluster/k_means.py +79 -43
  41. snowflake/ml/modeling/cluster/mean_shift.py +79 -43
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
  43. snowflake/ml/modeling/cluster/optics.py +79 -43
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
  47. snowflake/ml/modeling/compose/column_transformer.py +79 -43
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
  54. snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
  55. snowflake/ml/modeling/covariance/oas.py +79 -43
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
  59. snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
  64. snowflake/ml/modeling/decomposition/pca.py +79 -43
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
  95. snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
  96. snowflake/ml/modeling/impute/knn_imputer.py +79 -43
  97. snowflake/ml/modeling/impute/missing_indicator.py +79 -43
  98. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
  99. snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
  100. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
  101. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
  102. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
  103. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
  104. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
  105. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
  106. snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
  107. snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
  108. snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
  109. snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
  110. snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
  111. snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
  112. snowflake/ml/modeling/linear_model/lars.py +79 -43
  113. snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
  114. snowflake/ml/modeling/linear_model/lasso.py +79 -43
  115. snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
  116. snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
  117. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
  118. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
  119. snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
  120. snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
  121. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
  123. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
  124. snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
  125. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
  126. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
  127. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
  128. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
  129. snowflake/ml/modeling/linear_model/perceptron.py +79 -43
  130. snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
  131. snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
  132. snowflake/ml/modeling/linear_model/ridge.py +79 -43
  133. snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
  134. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
  135. snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
  136. snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
  137. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
  138. snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
  139. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
  140. snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
  141. snowflake/ml/modeling/manifold/isomap.py +79 -43
  142. snowflake/ml/modeling/manifold/mds.py +79 -43
  143. snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
  144. snowflake/ml/modeling/manifold/tsne.py +79 -43
  145. snowflake/ml/modeling/metrics/classification.py +6 -1
  146. snowflake/ml/modeling/metrics/regression.py +517 -9
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
  161. snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
  171. snowflake/ml/modeling/pipeline/pipeline.py +24 -0
  172. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
  173. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
  174. snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
  175. snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
  176. snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
  177. snowflake/ml/modeling/svm/linear_svc.py +79 -43
  178. snowflake/ml/modeling/svm/linear_svr.py +79 -43
  179. snowflake/ml/modeling/svm/nu_svc.py +79 -43
  180. snowflake/ml/modeling/svm/nu_svr.py +79 -43
  181. snowflake/ml/modeling/svm/svc.py +79 -43
  182. snowflake/ml/modeling/svm/svr.py +79 -43
  183. snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
  184. snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
  185. snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
  186. snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
  187. snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
  188. snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
  189. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
  190. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
  191. snowflake/ml/registry/model_registry.py +123 -121
  192. snowflake/ml/version.py +1 -1
  193. {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
  194. snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
  195. snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
  196. {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,186 @@
1
+ import os
2
+ import sys
3
+ from typing import TYPE_CHECKING, Callable, Optional, Type, cast
4
+
5
+ import cloudpickle
6
+ import pandas as pd
7
+ from typing_extensions import TypeGuard, Unpack
8
+
9
+ from snowflake.ml._internal import type_utils
10
+ from snowflake.ml.model import (
11
+ _model_meta as model_meta_api,
12
+ custom_model,
13
+ model_signature,
14
+ type_hints as model_types,
15
+ )
16
+ from snowflake.ml.model._handlers import _base
17
+
18
+ if TYPE_CHECKING:
19
+ import torch
20
+
21
+
22
class _PyTorchHandler(_base._ModelHandler["torch.nn.Module"]):
    """Handler for PyTorch based model.

    Currently torch.nn.Module based classes are supported. Instances of torch.jit.ScriptModule are rejected by
    `can_handle` even though they are also torch.nn.Module subclasses.
    """

    handler_type = "pytorch"
    # File name of the serialized model inside the model's blob directory.
    MODEL_BLOB_FILE = "model.pt"
    # Methods used as inference targets when the caller does not pass `target_methods`.
    DEFAULT_TARGET_METHODS = ["forward"]

    @staticmethod
    def can_handle(
        model: model_types.SupportedModelType,
    ) -> TypeGuard["torch.nn.Module"]:
        """Return True for torch.nn.Module instances that are not torch.jit.ScriptModule instances.

        The checks go through `type_utils.LazyType` with dotted type names instead of importing torch here.
        """
        return type_utils.LazyType("torch.nn.Module").isinstance(model) and not type_utils.LazyType(
            "torch.jit.ScriptModule"
        ).isinstance(model)

    @staticmethod
    def cast_model(
        model: model_types.SupportedModelType,
    ) -> "torch.nn.Module":
        """Narrow `model` to torch.nn.Module, asserting that it is one."""
        import torch

        assert isinstance(model, torch.nn.Module)

        return cast(torch.nn.Module, model)

    @staticmethod
    def _save_model(
        name: str,
        model: "torch.nn.Module",
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.PyTorchSaveOptions],
    ) -> None:
        """Serialize a PyTorch model into `model_blobs_dir_path/name` and record it in the model metadata.

        Args:
            name: Name of the model; also the sub-directory of `model_blobs_dir_path` the blob is written to.
            model: The torch.nn.Module to save.
            model_meta: The model metadata to update with the blob entry and the torch dependency.
            model_blobs_dir_path: Directory holding the model blobs.
            sample_input: Optional sample input handed to signature validation.
            is_sub_model: When True, target-method resolution and signature validation are skipped.
            **kwargs: PyTorch save options; `target_methods` is consumed here.
        """
        import torch

        assert isinstance(model, torch.nn.Module)

        if not is_sub_model:
            target_methods = model_meta_api._get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=_PyTorchHandler.DEFAULT_TARGET_METHODS,
            )

            def get_prediction(
                target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
            ) -> model_types.SupportedLocalDataType:
                # Inputs that are not already a sequence of tensors are normalized through a DataFrame first.
                if not model_signature._SeqOfPyTorchTensorHandler.can_handle(sample_input):
                    sample_input = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(
                        model_signature._convert_local_data_to_df(sample_input)
                    )

                model.eval()
                target_method = getattr(model, target_method_name, None)
                assert callable(target_method)
                with torch.no_grad():
                    predictions_df = target_method(sample_input)
                return predictions_df

            model_meta = model_meta_api._validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=target_methods,
                sample_input=sample_input,
                get_prediction_fn=get_prediction,
            )

        # Torch.save using pickle will not pickle the model definition if defined in the top level of a module.
        # Make sure that the module where the model is defined gets pickled by value as well.
        # cloudpickle keeps this registry process-wide, so the registration is undone after saving unless the
        # module was already registered before this call; otherwise it would leak into every later pickle.
        model_module = sys.modules[model.__module__]
        newly_registered = model_module.__name__ not in cloudpickle.list_registry_pickle_by_value()
        if newly_registered:
            cloudpickle.register_pickle_by_value(model_module)
        try:
            model_blob_path = os.path.join(model_blobs_dir_path, name)
            os.makedirs(model_blob_path, exist_ok=True)
            with open(os.path.join(model_blob_path, _PyTorchHandler.MODEL_BLOB_FILE), "wb") as f:
                torch.save(model, f, pickle_module=cloudpickle)
        finally:
            if newly_registered:
                cloudpickle.unregister_pickle_by_value(model_module)

        base_meta = model_meta_api._ModelBlobMetadata(
            name=name, model_type=_PyTorchHandler.handler_type, path=_PyTorchHandler.MODEL_BLOB_FILE
        )
        model_meta.models[name] = base_meta
        model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])

    @staticmethod
    def _load_model(
        name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
    ) -> "torch.nn.Module":
        """Load the torch.nn.Module saved under `model_blobs_dir_path/name`.

        Args:
            name: Name of the model whose blob should be loaded.
            model_meta: The model metadata holding the blob entry for `name`.
            model_blobs_dir_path: Directory holding the model blobs.

        Returns:
            The deserialized torch.nn.Module.

        Raises:
            ValueError: If the metadata has no `models` attribute or no blob entry for `name`.
        """
        import torch

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        if not hasattr(model_meta, "models"):
            raise ValueError("Ill model metadata found.")
        model_blobs_metadata = model_meta.models
        if name not in model_blobs_metadata:
            raise ValueError(f"Blob of model {name} does not exist.")
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        # NOTE(review): torch.load unpickles arbitrary objects; only load model artifacts from trusted sources.
        with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
            m = torch.load(f)
        assert isinstance(m, torch.nn.Module)
        return m

    @staticmethod
    def _load_as_custom_model(
        name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
    ) -> custom_model.CustomModel:
        """Create a custom model class wrap for unified interface when being deployed. The predict method will be
        re-targeted based on target_method metadata.

        Args:
            name: Name of the model.
            model_meta: The model metadata.
            model_blobs_dir_path: Directory path to the whole model.

        Returns:
            The model object as a custom model.
        """
        import torch

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: "torch.nn.Module",
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            # Build one inference method per signature entry, each dispatching to the same-named model method.
            def fn_factory(
                raw_model: "torch.nn.Module",
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                    # The tensor conversion does not support nulls, so reject them up front.
                    if X.isnull().any(axis=None):
                        raise ValueError("Tensor cannot handle null values.")

                    raw_model.eval()
                    t = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)

                    with torch.no_grad():
                        res = getattr(raw_model, target_method)(t)
                    # Output columns are renamed to the feature names declared in the signature.
                    return model_signature._rename_pandas_df(
                        data=model_signature._SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
                    )

                return fn

            type_method_dict = {}
            for target_method_name, sig in model_meta.signatures.items():
                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)

            _PyTorchModel = type(
                "_PyTorchModel",
                (custom_model.CustomModel,),
                type_method_dict,
            )

            return _PyTorchModel

        raw_model = _PyTorchHandler._load_model(name, model_meta, model_blobs_dir_path)
        _PyTorchModel = _create_custom_model(raw_model, model_meta)
        pytorch_model = _PyTorchModel(custom_model.ModelContext())

        return pytorch_model
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Optional, Sequence, Type, Union, cast
2
+ from typing import TYPE_CHECKING, Callable, Optional, Type, Union, cast
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import (
11
11
  _model_meta as model_meta_api,
12
12
  custom_model,
13
+ model_signature,
13
14
  type_hints as model_types,
14
15
  )
15
16
  from snowflake.ml.model._handlers import _base
@@ -80,6 +81,9 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
80
81
  def get_prediction(
81
82
  target_method_name: str, sample_input: model_types.SupportedLocalDataType
82
83
  ) -> model_types.SupportedLocalDataType:
84
+ if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
85
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
86
+
83
87
  target_method = getattr(model, target_method_name, None)
84
88
  assert callable(target_method)
85
89
  predictions_df = target_method(sample_input)
@@ -101,6 +105,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
101
105
  name=name, model_type=_SKLModelHandler.handler_type, path=_SKLModelHandler.MODEL_BLOB_FILE
102
106
  )
103
107
  model_meta.models[name] = base_meta
108
+ model_meta._include_if_absent([model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn")])
104
109
 
105
110
  @staticmethod
106
111
  def _load_model(
@@ -146,7 +151,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
146
151
  ) -> Type[custom_model.CustomModel]:
147
152
  def fn_factory(
148
153
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
149
- output_col_names: Sequence[str],
154
+ signature: model_signature.ModelSignature,
150
155
  target_method: str,
151
156
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
152
157
  @custom_model.inference_api
@@ -155,17 +160,18 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
155
160
 
156
161
  if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
157
162
  # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
158
- # return a list of ndarrays. We need to concatenate them.
159
- res = np.concatenate(res, axis=1)
160
- return pd.DataFrame(res, columns=output_col_names)
163
+ # return a list of ndarrays. We need to deal them seperately
164
+ df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
165
+ else:
166
+ df = pd.DataFrame(res)
167
+
168
+ return model_signature._rename_pandas_df(df, signature.outputs)
161
169
 
162
170
  return fn
163
171
 
164
172
  type_method_dict = {}
165
173
  for target_method_name, sig in model_meta.signatures.items():
166
- type_method_dict[target_method_name] = fn_factory(
167
- raw_model, [spec.name for spec in sig.outputs], target_method_name
168
- )
174
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
169
175
 
170
176
  _SKLModel = type(
171
177
  "_SKLModel",
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Optional, Sequence, Type, cast
2
+ from typing import TYPE_CHECKING, Callable, Optional, Type, cast
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import (
11
11
  _model_meta as model_meta_api,
12
12
  custom_model,
13
+ model_signature,
13
14
  type_hints as model_types,
14
15
  )
15
16
  from snowflake.ml.model._handlers import _base
@@ -81,6 +82,9 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
81
82
  def get_prediction(
82
83
  target_method_name: str, sample_input: model_types.SupportedLocalDataType
83
84
  ) -> model_types.SupportedLocalDataType:
85
+ if not isinstance(sample_input, (pd.DataFrame,)):
86
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
87
+
84
88
  target_method = getattr(model, target_method_name, None)
85
89
  assert callable(target_method)
86
90
  predictions_df = target_method(sample_input)
@@ -106,7 +110,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
106
110
  model_dependencies = model._get_dependencies()
107
111
  for dep in model_dependencies:
108
112
  pkg_name = dep.split("==")[0]
109
- _include_if_absent_pkgs.append((pkg_name, pkg_name))
113
+ _include_if_absent_pkgs.append(model_meta_api.Dependency(conda_name=pkg_name, pip_name=pkg_name))
110
114
  model_meta._include_if_absent(_include_if_absent_pkgs)
111
115
 
112
116
  @staticmethod
@@ -150,7 +154,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
150
154
  ) -> Type[custom_model.CustomModel]:
151
155
  def fn_factory(
152
156
  raw_model: "BaseEstimator",
153
- output_col_names: Sequence[str],
157
+ signature: model_signature.ModelSignature,
154
158
  target_method: str,
155
159
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
156
160
  @custom_model.inference_api
@@ -159,17 +163,18 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
159
163
 
160
164
  if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
161
165
  # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
162
- # return a list of ndarrays. We need to concatenate them.
163
- res = np.concatenate(res, axis=1)
164
- return pd.DataFrame(res, columns=output_col_names)
166
+ # return a list of ndarrays. We need to deal them seperately
167
+ df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
168
+ else:
169
+ df = pd.DataFrame(res)
170
+
171
+ return model_signature._rename_pandas_df(df, signature.outputs)
165
172
 
166
173
  return fn
167
174
 
168
175
  type_method_dict = {}
169
176
  for target_method_name, sig in model_meta.signatures.items():
170
- type_method_dict[target_method_name] = fn_factory(
171
- raw_model, [spec.name for spec in sig.outputs], target_method_name
172
- )
177
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
173
178
 
174
179
  _SnowMLModel = type(
175
180
  "_SnowMLModel",
@@ -0,0 +1,180 @@
1
+ import os
2
+ from typing import TYPE_CHECKING, Callable, Optional, Type, cast
3
+
4
+ import pandas as pd
5
+ from typing_extensions import TypeGuard, Unpack
6
+
7
+ from snowflake.ml._internal import type_utils
8
+ from snowflake.ml.model import (
9
+ _model_meta as model_meta_api,
10
+ custom_model,
11
+ model_signature,
12
+ type_hints as model_types,
13
+ )
14
+ from snowflake.ml.model._handlers import _base
15
+
16
+ if TYPE_CHECKING:
17
+ import torch
18
+
19
+
20
class _TorchScriptHandler(_base._ModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
    """Model handler for TorchScript (PyTorch JIT) models.

    Only instances of torch.jit.ScriptModule are handled.
    """

    handler_type = "torchscript"
    # File name of the serialized TorchScript archive inside the model's blob directory.
    MODEL_BLOB_FILE = "model.pt"
    # Methods used as inference targets when the caller does not pass `target_methods`.
    DEFAULT_TARGET_METHODS = ["forward"]

    @staticmethod
    def can_handle(
        model: model_types.SupportedModelType,
    ) -> TypeGuard["torch.jit.ScriptModule"]:  # type:ignore[name-defined]
        """Tell whether `model` is a torch.jit.ScriptModule, checked via a lazily resolved dotted type name."""
        script_module_type = type_utils.LazyType("torch.jit.ScriptModule")
        return script_module_type.isinstance(model)

    @staticmethod
    def cast_model(
        model: model_types.SupportedModelType,
    ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
        """Narrow `model` to torch.jit.ScriptModule, asserting that it is one."""
        import torch

        assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
        return cast(torch.jit.ScriptModule, model)  # type:ignore[name-defined]

    @staticmethod
    def _save_model(
        name: str,
        model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.TorchScriptSaveOptions],
    ) -> None:
        """Write a TorchScript model to `model_blobs_dir_path/name` and register it in the model metadata.

        Args:
            name: Model name; doubles as the blob sub-directory.
            model: The torch.jit.ScriptModule to persist.
            model_meta: Metadata that receives the blob entry and the torch dependency.
            model_blobs_dir_path: Root directory for model blobs.
            sample_input: Optional sample input passed to signature validation.
            is_sub_model: If True, skip resolving target methods and validating signatures.
            **kwargs: TorchScript save options; `target_methods` is popped and consumed here.
        """
        import torch

        assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]

        if not is_sub_model:
            requested_methods = kwargs.pop("target_methods", None)
            resolved_methods = model_meta_api._get_target_methods(
                model=model,
                target_methods=requested_methods,
                default_target_methods=_TorchScriptHandler.DEFAULT_TARGET_METHODS,
            )

            def predict_on_sample(
                target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
            ) -> model_types.SupportedLocalDataType:
                tensor_handler = model_signature._SeqOfPyTorchTensorHandler
                # Anything that is not already a sequence of tensors is routed through a DataFrame.
                if not tensor_handler.can_handle(sample_input):
                    as_df = model_signature._convert_local_data_to_df(sample_input)
                    sample_input = tensor_handler.convert_from_df(as_df)

                model.eval()
                method = getattr(model, target_method_name, None)
                assert callable(method)
                with torch.no_grad():
                    return method(sample_input)

            model_meta = model_meta_api._validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=resolved_methods,
                sample_input=sample_input,
                get_prediction_fn=predict_on_sample,
            )

        blob_dir = os.path.join(model_blobs_dir_path, name)
        os.makedirs(blob_dir, exist_ok=True)
        blob_file = os.path.join(blob_dir, _TorchScriptHandler.MODEL_BLOB_FILE)
        with open(blob_file, "wb") as f:
            torch.jit.save(model, f)  # type:ignore[attr-defined]

        model_meta.models[name] = model_meta_api._ModelBlobMetadata(
            name=name, model_type=_TorchScriptHandler.handler_type, path=_TorchScriptHandler.MODEL_BLOB_FILE
        )
        model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])

    @staticmethod
    def _load_model(
        name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
    ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
        """Read back the TorchScript model stored for `name`.

        Args:
            name: Model name whose blob entry is looked up.
            model_meta: Metadata containing the blob entries.
            model_blobs_dir_path: Root directory for model blobs.

        Returns:
            The loaded torch.jit.ScriptModule.

        Raises:
            ValueError: When the metadata lacks `models` or has no entry for `name`.
        """
        import torch

        if not hasattr(model_meta, "models"):
            raise ValueError("Ill model metadata found.")
        blobs = model_meta.models
        if name not in blobs:
            raise ValueError(f"Blob of model {name} does not exist.")

        blob_file = os.path.join(model_blobs_dir_path, name, blobs[name].path)
        with open(blob_file, "rb") as f:
            loaded = torch.jit.load(f)  # type:ignore[attr-defined]
        assert isinstance(loaded, torch.jit.ScriptModule)  # type:ignore[attr-defined]
        return loaded

    @staticmethod
    def _load_as_custom_model(
        name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
    ) -> custom_model.CustomModel:
        """Wrap the stored TorchScript model in a CustomModel so deployment sees a uniform interface.

        One inference method is generated per entry in `model_meta.signatures`; each forwards to the model method of
        the same name.

        Args:
            name: Model name.
            model_meta: The model metadata.
            model_blobs_dir_path: Root directory for model blobs.

        Returns:
            An instance of the generated CustomModel subclass.
        """
        from snowflake.ml.model import custom_model

        def _build_method(
            raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
            signature: model_signature.ModelSignature,
            target_method: str,
        ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
            @custom_model.inference_api
            def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                # The tensor conversion cannot take nulls, so refuse them before touching the model.
                if X.isnull().any(axis=None):
                    raise ValueError("Tensor cannot handle null values.")

                import torch

                raw_model.eval()
                tensors = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
                with torch.no_grad():
                    output = getattr(raw_model, target_method)(tensors)
                output_df = model_signature._SeqOfPyTorchTensorHandler.convert_to_df(output)
                # Output columns take the feature names declared in the signature.
                return model_signature._rename_pandas_df(data=output_df, features=signature.outputs)

            return fn

        raw_model = _TorchScriptHandler._load_model(name, model_meta, model_blobs_dir_path)
        methods = {
            method_name: _build_method(raw_model, sig, method_name)
            for method_name, sig in model_meta.signatures.items()
        }
        _TorchScriptModel = type("_TorchScriptModel", (custom_model.CustomModel,), methods)
        return _TorchScriptModel(custom_model.ModelContext())
@@ -1,6 +1,6 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
- from typing import TYPE_CHECKING, Callable, Optional, Sequence, Type, Union
3
+ from typing import TYPE_CHECKING, Callable, Optional, Type, Union
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import (
11
11
  _model_meta as model_meta_api,
12
12
  custom_model,
13
+ model_signature,
13
14
  type_hints as model_types,
14
15
  )
15
16
  from snowflake.ml.model._handlers import _base
@@ -72,6 +73,9 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
72
73
  def get_prediction(
73
74
  target_method_name: str, sample_input: model_types.SupportedLocalDataType
74
75
  ) -> model_types.SupportedLocalDataType:
76
+ if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
77
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
78
+
75
79
  target_method = getattr(model, target_method_name, None)
76
80
  assert callable(target_method)
77
81
  predictions_df = target_method(sample_input)
@@ -95,7 +99,12 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
95
99
  options={"xgb_estimator_type": model.__class__.__name__},
96
100
  )
97
101
  model_meta.models[name] = base_meta
98
- model_meta._include_if_absent([("xgboost", "xgboost")])
102
+ model_meta._include_if_absent(
103
+ [
104
+ model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn"),
105
+ model_meta_api.Dependency(conda_name="xgboost", pip_name="xgboost"),
106
+ ]
107
+ )
99
108
 
100
109
  @staticmethod
101
110
  def _load_model(
@@ -143,7 +152,7 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
143
152
  ) -> Type[custom_model.CustomModel]:
144
153
  def fn_factory(
145
154
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
146
- output_col_names: Sequence[str],
155
+ signature: model_signature.ModelSignature,
147
156
  target_method: str,
148
157
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
149
158
  @custom_model.inference_api
@@ -152,17 +161,18 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
152
161
 
153
162
  if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
154
163
  # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
155
- # return a list of ndarrays. We need to concatenate them.
156
- res = np.concatenate(res, axis=1)
157
- return pd.DataFrame(res, columns=output_col_names)
164
+ # return a list of ndarrays. We need to deal them seperately
165
+ df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
166
+ else:
167
+ df = pd.DataFrame(res)
168
+
169
+ return model_signature._rename_pandas_df(df, signature.outputs)
158
170
 
159
171
  return fn
160
172
 
161
173
  type_method_dict = {}
162
174
  for target_method_name, sig in model_meta.signatures.items():
163
- type_method_dict[target_method_name] = fn_factory(
164
- raw_model, [spec.name for spec in sig.outputs], target_method_name
165
- )
175
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
166
176
 
167
177
  _XGBModel = type(
168
178
  "_XGBModel",
@@ -1,10 +1,11 @@
1
1
  import os
2
+ import posixpath
2
3
  import tempfile
3
4
  import warnings
4
5
  from types import ModuleType
5
- from typing import Dict, List, Literal, Optional, Tuple, Union, overload
6
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union, overload
6
7
 
7
- from snowflake.ml._internal import file_utils
8
+ from snowflake.ml._internal import file_utils, type_utils
8
9
  from snowflake.ml.model import (
9
10
  _env,
10
11
  _model_handler,
@@ -13,9 +14,11 @@ from snowflake.ml.model import (
13
14
  model_signature,
14
15
  type_hints as model_types,
15
16
  )
16
- from snowflake.ml.modeling.framework import base
17
17
  from snowflake.snowpark import FileOperation, Session
18
18
 
19
+ if TYPE_CHECKING:
20
+ from snowflake.ml.modeling.framework import base
21
+
19
22
  MODEL_BLOBS_DIR = "models"
20
23
 
21
24
 
@@ -23,7 +26,7 @@ MODEL_BLOBS_DIR = "models"
23
26
  def save_model(
24
27
  *,
25
28
  name: str,
26
- model: base.BaseEstimator,
29
+ model: "base.BaseEstimator",
27
30
  model_dir_path: str,
28
31
  metadata: Optional[Dict[str, str]] = None,
29
32
  conda_dependencies: Optional[List[str]] = None,
@@ -135,7 +138,7 @@ def save_model(
135
138
  def save_model(
136
139
  *,
137
140
  name: str,
138
- model: base.BaseEstimator,
141
+ model: "base.BaseEstimator",
139
142
  session: Session,
140
143
  model_stage_file_path: str,
141
144
  metadata: Optional[Dict[str, str]] = None,
@@ -322,9 +325,11 @@ def save_model(
322
325
  + f"{'None' if model_stage_file_path is None else 'specified'} at the same time."
323
326
  )
324
327
 
325
- if ((signatures is None) and (sample_input is None) and not isinstance(model, base.BaseEstimator)) or (
326
- (signatures is not None) and (sample_input is not None)
327
- ):
328
+ if (
329
+ (signatures is None)
330
+ and (sample_input is None)
331
+ and not type_utils.LazyType("snowflake.ml.modeling.framework.base.BaseEstimator").isinstance(model)
332
+ ) or ((signatures is not None) and (sample_input is not None)):
328
333
  raise ValueError(
329
334
  "Signatures and sample_input both cannot be "
330
335
  + f"{'None for local model' if signatures is None else 'specified'} at the same time."
@@ -360,8 +365,8 @@ def save_model(
360
365
  )
361
366
 
362
367
  assert session and model_stage_file_path
363
- if os.path.splitext(model_stage_file_path)[1] != ".zip":
364
- raise ValueError("Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
368
+ if posixpath.splitext(model_stage_file_path)[1] != ".zip":
369
+ raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
365
370
 
366
371
  with tempfile.TemporaryDirectory() as temp_local_model_dir_path:
367
372
  meta = _save(
@@ -397,15 +402,15 @@ def _save(
397
402
  name: str,
398
403
  model: model_types.SupportedModelType,
399
404
  local_dir_path: str,
400
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
401
- sample_input: Optional[model_types.SupportedDataType] = None,
402
- metadata: Optional[Dict[str, str]] = None,
403
- conda_dependencies: Optional[List[str]] = None,
404
- pip_requirements: Optional[List[str]] = None,
405
- python_version: Optional[str] = None,
406
- ext_modules: Optional[List[ModuleType]] = None,
407
- code_paths: Optional[List[str]] = None,
408
- options: Optional[model_types.ModelSaveOption] = None,
405
+ signatures: Optional[Dict[str, model_signature.ModelSignature]],
406
+ sample_input: Optional[model_types.SupportedDataType],
407
+ metadata: Optional[Dict[str, str]],
408
+ conda_dependencies: Optional[List[str]],
409
+ pip_requirements: Optional[List[str]],
410
+ python_version: Optional[str],
411
+ ext_modules: Optional[List[ModuleType]],
412
+ code_paths: Optional[List[str]],
413
+ options: model_types.ModelSaveOption,
409
414
  ) -> _model_meta.ModelMetadata:
410
415
  local_dir_path = os.path.normpath(local_dir_path)
411
416
 
@@ -423,6 +428,7 @@ def _save(
423
428
  conda_dependencies=conda_dependencies,
424
429
  pip_requirements=pip_requirements,
425
430
  python_version=python_version,
431
+ **options,
426
432
  ) as meta:
427
433
  model_blobs_path = os.path.join(local_dir_path, MODEL_BLOBS_DIR)
428
434
  os.makedirs(model_blobs_path, exist_ok=True)
@@ -538,8 +544,8 @@ def load_model(
538
544
  return _load(local_dir_path=model_dir_path, meta_only=meta_only)
539
545
 
540
546
  assert session and model_stage_file_path
541
- if os.path.splitext(model_stage_file_path)[1] != ".zip":
542
- raise ValueError("Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
547
+ if posixpath.splitext(model_stage_file_path)[1] != ".zip":
548
+ raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
543
549
 
544
550
  fo = FileOperation(session=session)
545
551
  zf = fo.get_stream(model_stage_file_path)