snowflake-ml-python 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. snowflake/ml/_internal/env_utils.py +2 -1
  2. snowflake/ml/_internal/file_utils.py +29 -7
  3. snowflake/ml/_internal/telemetry.py +5 -8
  4. snowflake/ml/_internal/utils/uri.py +7 -2
  5. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
  6. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
  7. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
  8. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
  9. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
  10. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
  11. snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
  12. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
  13. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
  14. snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
  15. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
  16. snowflake/ml/model/_deploy_client/warehouse/deploy.py +24 -6
  17. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +5 -2
  18. snowflake/ml/model/_deployer.py +14 -27
  19. snowflake/ml/model/_env.py +4 -4
  20. snowflake/ml/model/_handlers/custom.py +14 -2
  21. snowflake/ml/model/_handlers/pytorch.py +186 -0
  22. snowflake/ml/model/_handlers/sklearn.py +14 -9
  23. snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
  24. snowflake/ml/model/_handlers/torchscript.py +180 -0
  25. snowflake/ml/model/_handlers/xgboost.py +19 -9
  26. snowflake/ml/model/_model.py +3 -2
  27. snowflake/ml/model/_model_meta.py +12 -7
  28. snowflake/ml/model/model_signature.py +446 -66
  29. snowflake/ml/model/type_hints.py +23 -4
  30. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -26
  31. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -26
  32. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -26
  33. snowflake/ml/modeling/cluster/birch.py +51 -26
  34. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -26
  35. snowflake/ml/modeling/cluster/dbscan.py +51 -26
  36. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -26
  37. snowflake/ml/modeling/cluster/k_means.py +51 -26
  38. snowflake/ml/modeling/cluster/mean_shift.py +51 -26
  39. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -26
  40. snowflake/ml/modeling/cluster/optics.py +51 -26
  41. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -26
  42. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -26
  43. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -26
  44. snowflake/ml/modeling/compose/column_transformer.py +51 -26
  45. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -26
  46. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -26
  47. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -26
  48. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -26
  49. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -26
  50. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -26
  51. snowflake/ml/modeling/covariance/min_cov_det.py +51 -26
  52. snowflake/ml/modeling/covariance/oas.py +51 -26
  53. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -26
  54. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -26
  55. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -26
  56. snowflake/ml/modeling/decomposition/fast_ica.py +51 -26
  57. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -26
  58. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -26
  59. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -26
  60. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -26
  61. snowflake/ml/modeling/decomposition/pca.py +51 -26
  62. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -26
  63. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -26
  64. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -26
  65. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -26
  66. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -26
  67. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -26
  68. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -26
  69. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -26
  70. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -26
  71. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -26
  72. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -26
  73. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -26
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -26
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -26
  76. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -26
  77. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -26
  78. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -26
  79. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -26
  80. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -26
  81. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -26
  82. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -26
  83. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -26
  84. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -26
  85. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -26
  86. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -26
  87. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -26
  88. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -26
  89. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -26
  90. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -26
  91. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -26
  92. snowflake/ml/modeling/impute/iterative_imputer.py +51 -26
  93. snowflake/ml/modeling/impute/knn_imputer.py +51 -26
  94. snowflake/ml/modeling/impute/missing_indicator.py +51 -26
  95. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -26
  96. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -26
  97. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -26
  98. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -26
  99. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -26
  100. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -26
  101. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -26
  102. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -26
  103. snowflake/ml/modeling/linear_model/ard_regression.py +51 -26
  104. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -26
  105. snowflake/ml/modeling/linear_model/elastic_net.py +51 -26
  106. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -26
  107. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -26
  108. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -26
  109. snowflake/ml/modeling/linear_model/lars.py +51 -26
  110. snowflake/ml/modeling/linear_model/lars_cv.py +51 -26
  111. snowflake/ml/modeling/linear_model/lasso.py +51 -26
  112. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -26
  113. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -26
  114. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -26
  115. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -26
  116. snowflake/ml/modeling/linear_model/linear_regression.py +51 -26
  117. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -26
  118. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -26
  119. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -26
  120. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -26
  121. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -26
  122. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -26
  123. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -26
  124. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -26
  125. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -26
  126. snowflake/ml/modeling/linear_model/perceptron.py +51 -26
  127. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -26
  128. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -26
  129. snowflake/ml/modeling/linear_model/ridge.py +51 -26
  130. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -26
  131. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -26
  132. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -26
  133. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -26
  134. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -26
  135. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -26
  136. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -26
  137. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -26
  138. snowflake/ml/modeling/manifold/isomap.py +51 -26
  139. snowflake/ml/modeling/manifold/mds.py +51 -26
  140. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -26
  141. snowflake/ml/modeling/manifold/tsne.py +51 -26
  142. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -26
  143. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -26
  144. snowflake/ml/modeling/model_selection/grid_search_cv.py +51 -26
  145. snowflake/ml/modeling/model_selection/randomized_search_cv.py +51 -26
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -26
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -26
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -26
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -26
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -26
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -26
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -26
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -26
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -26
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -26
  156. snowflake/ml/modeling/neighbors/kernel_density.py +51 -26
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -26
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -26
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -26
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -26
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -26
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -26
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -26
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -26
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -26
  166. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
  167. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -26
  168. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -26
  169. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -26
  170. snowflake/ml/modeling/svm/linear_svc.py +51 -26
  171. snowflake/ml/modeling/svm/linear_svr.py +51 -26
  172. snowflake/ml/modeling/svm/nu_svc.py +51 -26
  173. snowflake/ml/modeling/svm/nu_svr.py +51 -26
  174. snowflake/ml/modeling/svm/svc.py +51 -26
  175. snowflake/ml/modeling/svm/svr.py +51 -26
  176. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -26
  177. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -26
  178. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -26
  179. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -26
  180. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -26
  181. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -26
  182. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -26
  183. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -26
  184. snowflake/ml/registry/model_registry.py +74 -56
  185. snowflake/ml/version.py +1 -1
  186. {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +27 -8
  187. snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
  188. snowflake_ml_python-1.0.2.dist-info/RECORD +0 -246
  189. {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,180 @@
1
+ import os
2
+ from typing import TYPE_CHECKING, Callable, Optional, Type, cast
3
+
4
+ import pandas as pd
5
+ from typing_extensions import TypeGuard, Unpack
6
+
7
+ from snowflake.ml._internal import type_utils
8
+ from snowflake.ml.model import (
9
+ _model_meta as model_meta_api,
10
+ custom_model,
11
+ model_signature,
12
+ type_hints as model_types,
13
+ )
14
+ from snowflake.ml.model._handlers import _base
15
+
16
+ if TYPE_CHECKING:
17
+ import torch
18
+
19
+
20
+ class _TorchScriptHandler(_base._ModelHandler["torch.jit.ScriptModule"]): # type:ignore[name-defined]
21
+ """Handler for PyTorch JIT based model.
22
+
23
+ Currently torch.jit.ScriptModule based classes are supported.
24
+ """
25
+
26
+ handler_type = "torchscript"
27
+ MODEL_BLOB_FILE = "model.pt"
28
+ DEFAULT_TARGET_METHODS = ["forward"]
29
+
30
+ @staticmethod
31
+ def can_handle(
32
+ model: model_types.SupportedModelType,
33
+ ) -> TypeGuard["torch.jit.ScriptModule"]: # type:ignore[name-defined]
34
+ return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
35
+
36
+ @staticmethod
37
+ def cast_model(
38
+ model: model_types.SupportedModelType,
39
+ ) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
40
+ import torch
41
+
42
+ assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
43
+
44
+ return cast(torch.jit.ScriptModule, model) # type:ignore[name-defined]
45
+
46
+ @staticmethod
47
+ def _save_model(
48
+ name: str,
49
+ model: "torch.jit.ScriptModule", # type:ignore[name-defined]
50
+ model_meta: model_meta_api.ModelMetadata,
51
+ model_blobs_dir_path: str,
52
+ sample_input: Optional[model_types.SupportedDataType] = None,
53
+ is_sub_model: Optional[bool] = False,
54
+ **kwargs: Unpack[model_types.TorchScriptSaveOptions],
55
+ ) -> None:
56
+ import torch
57
+
58
+ assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
59
+
60
+ if not is_sub_model:
61
+ target_methods = model_meta_api._get_target_methods(
62
+ model=model,
63
+ target_methods=kwargs.pop("target_methods", None),
64
+ default_target_methods=_TorchScriptHandler.DEFAULT_TARGET_METHODS,
65
+ )
66
+
67
+ def get_prediction(
68
+ target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
69
+ ) -> model_types.SupportedLocalDataType:
70
+ if not model_signature._SeqOfPyTorchTensorHandler.can_handle(sample_input):
71
+ sample_input = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(
72
+ model_signature._convert_local_data_to_df(sample_input)
73
+ )
74
+
75
+ model.eval()
76
+ target_method = getattr(model, target_method_name, None)
77
+ assert callable(target_method)
78
+ with torch.no_grad():
79
+ predictions_df = target_method(sample_input)
80
+ return predictions_df
81
+
82
+ model_meta = model_meta_api._validate_signature(
83
+ model=model,
84
+ model_meta=model_meta,
85
+ target_methods=target_methods,
86
+ sample_input=sample_input,
87
+ get_prediction_fn=get_prediction,
88
+ )
89
+
90
+ model_blob_path = os.path.join(model_blobs_dir_path, name)
91
+ os.makedirs(model_blob_path, exist_ok=True)
92
+ with open(os.path.join(model_blob_path, _TorchScriptHandler.MODEL_BLOB_FILE), "wb") as f:
93
+ torch.jit.save(model, f) # type:ignore[attr-defined]
94
+ base_meta = model_meta_api._ModelBlobMetadata(
95
+ name=name, model_type=_TorchScriptHandler.handler_type, path=_TorchScriptHandler.MODEL_BLOB_FILE
96
+ )
97
+ model_meta.models[name] = base_meta
98
+ model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])
99
+
100
+ @staticmethod
101
+ def _load_model(
102
+ name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
103
+ ) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
104
+ import torch
105
+
106
+ model_blob_path = os.path.join(model_blobs_dir_path, name)
107
+ if not hasattr(model_meta, "models"):
108
+ raise ValueError("Ill model metadata found.")
109
+ model_blobs_metadata = model_meta.models
110
+ if name not in model_blobs_metadata:
111
+ raise ValueError(f"Blob of model {name} does not exist.")
112
+ model_blob_metadata = model_blobs_metadata[name]
113
+ model_blob_filename = model_blob_metadata.path
114
+ with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
115
+ m = torch.jit.load(f) # type:ignore[attr-defined]
116
+ assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
117
+ return m
118
+
119
+ @staticmethod
120
+ def _load_as_custom_model(
121
+ name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
122
+ ) -> custom_model.CustomModel:
123
+ """Create a custom model class wrap for unified interface when being deployed. The predict method will be
124
+ re-targeted based on target_method metadata.
125
+
126
+ Args:
127
+ name: Name of the model.
128
+ model_meta: The model metadata.
129
+ model_blobs_dir_path: Directory path to the whole model.
130
+
131
+ Returns:
132
+ The model object as a custom model.
133
+ """
134
+ from snowflake.ml.model import custom_model
135
+
136
+ def _create_custom_model(
137
+ raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
138
+ model_meta: model_meta_api.ModelMetadata,
139
+ ) -> Type[custom_model.CustomModel]:
140
+ def fn_factory(
141
+ raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
142
+ signature: model_signature.ModelSignature,
143
+ target_method: str,
144
+ ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
145
+ @custom_model.inference_api
146
+ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
147
+ if X.isnull().any(axis=None):
148
+ raise ValueError("Tensor cannot handle null values.")
149
+
150
+ import torch
151
+
152
+ raw_model.eval()
153
+
154
+ t = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
155
+
156
+ with torch.no_grad():
157
+ res = getattr(raw_model, target_method)(t)
158
+ return model_signature._rename_pandas_df(
159
+ data=model_signature._SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
160
+ )
161
+
162
+ return fn
163
+
164
+ type_method_dict = {}
165
+ for target_method_name, sig in model_meta.signatures.items():
166
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
167
+
168
+ _TorchScriptModel = type(
169
+ "_TorchScriptModel",
170
+ (custom_model.CustomModel,),
171
+ type_method_dict,
172
+ )
173
+
174
+ return _TorchScriptModel
175
+
176
+ raw_model = _TorchScriptHandler._load_model(name, model_meta, model_blobs_dir_path)
177
+ _TorchScriptModel = _create_custom_model(raw_model, model_meta)
178
+ torchscript_model = _TorchScriptModel(custom_model.ModelContext())
179
+
180
+ return torchscript_model
@@ -1,6 +1,6 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
- from typing import TYPE_CHECKING, Callable, Optional, Sequence, Type, Union
3
+ from typing import TYPE_CHECKING, Callable, Optional, Type, Union
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import (
11
11
  _model_meta as model_meta_api,
12
12
  custom_model,
13
+ model_signature,
13
14
  type_hints as model_types,
14
15
  )
15
16
  from snowflake.ml.model._handlers import _base
@@ -72,6 +73,9 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
72
73
  def get_prediction(
73
74
  target_method_name: str, sample_input: model_types.SupportedLocalDataType
74
75
  ) -> model_types.SupportedLocalDataType:
76
+ if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
77
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
78
+
75
79
  target_method = getattr(model, target_method_name, None)
76
80
  assert callable(target_method)
77
81
  predictions_df = target_method(sample_input)
@@ -95,7 +99,12 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
95
99
  options={"xgb_estimator_type": model.__class__.__name__},
96
100
  )
97
101
  model_meta.models[name] = base_meta
98
- model_meta._include_if_absent([("scikit-learn", "scikit-learn"), ("xgboost", "xgboost")])
102
+ model_meta._include_if_absent(
103
+ [
104
+ model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn"),
105
+ model_meta_api.Dependency(conda_name="xgboost", pip_name="xgboost"),
106
+ ]
107
+ )
99
108
 
100
109
  @staticmethod
101
110
  def _load_model(
@@ -143,7 +152,7 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
143
152
  ) -> Type[custom_model.CustomModel]:
144
153
  def fn_factory(
145
154
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
146
- output_col_names: Sequence[str],
155
+ signature: model_signature.ModelSignature,
147
156
  target_method: str,
148
157
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
149
158
  @custom_model.inference_api
@@ -152,17 +161,18 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
152
161
 
153
162
  if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
154
163
  # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
155
- # return a list of ndarrays. We need to concatenate them.
156
- res = np.concatenate(res, axis=1)
157
- return pd.DataFrame(res, columns=output_col_names)
164
+ # return a list of ndarrays. We need to deal them seperately
165
+ df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
166
+ else:
167
+ df = pd.DataFrame(res)
168
+
169
+ return model_signature._rename_pandas_df(df, signature.outputs)
158
170
 
159
171
  return fn
160
172
 
161
173
  type_method_dict = {}
162
174
  for target_method_name, sig in model_meta.signatures.items():
163
- type_method_dict[target_method_name] = fn_factory(
164
- raw_model, [spec.name for spec in sig.outputs], target_method_name
165
- )
175
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
166
176
 
167
177
  _XGBModel = type(
168
178
  "_XGBModel",
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import posixpath
2
3
  import tempfile
3
4
  import warnings
4
5
  from types import ModuleType
@@ -364,7 +365,7 @@ def save_model(
364
365
  )
365
366
 
366
367
  assert session and model_stage_file_path
367
- if os.path.splitext(model_stage_file_path)[1] != ".zip":
368
+ if posixpath.splitext(model_stage_file_path)[1] != ".zip":
368
369
  raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
369
370
 
370
371
  with tempfile.TemporaryDirectory() as temp_local_model_dir_path:
@@ -543,7 +544,7 @@ def load_model(
543
544
  return _load(local_dir_path=model_dir_path, meta_only=meta_only)
544
545
 
545
546
  assert session and model_stage_file_path
546
- if os.path.splitext(model_stage_file_path)[1] != ".zip":
547
+ if posixpath.splitext(model_stage_file_path)[1] != ".zip":
547
548
  raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
548
549
 
549
550
  fo = FileOperation(session=session)
@@ -3,10 +3,11 @@ import importlib
3
3
  import os
4
4
  import sys
5
5
  import warnings
6
+ from collections import namedtuple
6
7
  from contextlib import contextmanager
7
8
  from datetime import datetime
8
9
  from types import ModuleType
9
- from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, cast
10
+ from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, cast
10
11
 
11
12
  import cloudpickle
12
13
  import yaml
@@ -24,6 +25,8 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame
24
25
  MODEL_METADATA_VERSION = 1
25
26
  _BASIC_DEPENDENCIES = _core_requirements.REQUIREMENTS
26
27
 
28
+ Dependency = namedtuple("Dependency", ["conda_name", "pip_name"])
29
+
27
30
 
28
31
  @dataclasses.dataclass
29
32
  class _ModelBlobMetadata:
@@ -214,9 +217,11 @@ class ModelMetadata:
214
217
  pip_requirements if pip_requirements else []
215
218
  )
216
219
  if "local_ml_library_version" in kwargs:
217
- self._include_if_absent([(dep, dep) for dep in _BASIC_DEPENDENCIES])
220
+ self._include_if_absent([Dependency(conda_name=dep, pip_name=dep) for dep in _BASIC_DEPENDENCIES])
218
221
  else:
219
- self._include_if_absent([(dep, dep) for dep in _BASIC_DEPENDENCIES + [env_utils._SNOWML_PKG_NAME]])
222
+ self._include_if_absent(
223
+ [Dependency(conda_name=dep, pip_name=dep) for dep in _BASIC_DEPENDENCIES + [env_utils._SNOWML_PKG_NAME]]
224
+ )
220
225
 
221
226
  self.__dict__.update(kwargs)
222
227
 
@@ -234,7 +239,7 @@ class ModelMetadata:
234
239
  for req in reqs
235
240
  )
236
241
 
237
- def _include_if_absent(self, pkgs: List[Tuple[str, str]]) -> None:
242
+ def _include_if_absent(self, pkgs: List[Dependency]) -> None:
238
243
  conda_reqs_str, pip_reqs_str = tuple(zip(*pkgs))
239
244
  pip_reqs = env_utils.validate_pip_requirement_string_list(list(pip_reqs_str))
240
245
  conda_reqs = env_utils.validate_conda_dependency_string_list(list(conda_reqs_str))
@@ -327,7 +332,7 @@ class ModelMetadata:
327
332
  path: The path of the directory to write a yaml file in it.
328
333
  """
329
334
  model_yaml_path = os.path.join(path, ModelMetadata.MODEL_METADATA_FILE)
330
- with open(model_yaml_path, "w") as out:
335
+ with open(model_yaml_path, "w", encoding="utf-8") as out:
331
336
  yaml.safe_dump({**self.to_dict(), "version": MODEL_METADATA_VERSION}, stream=out, default_flow_style=False)
332
337
 
333
338
  env_dir_path = os.path.join(path, ModelMetadata.ENV_DIR)
@@ -350,7 +355,7 @@ class ModelMetadata:
350
355
  Loaded model metadata object.
351
356
  """
352
357
  model_yaml_path = os.path.join(path, ModelMetadata.MODEL_METADATA_FILE)
353
- with open(model_yaml_path) as f:
358
+ with open(model_yaml_path, encoding="utf-8") as f:
354
359
  loaded_mata = yaml.safe_load(f.read())
355
360
 
356
361
  loaded_mata_version = loaded_mata.pop("version", None)
@@ -392,7 +397,7 @@ def _validate_signature(
392
397
  if isinstance(sample_input, SnowparkDataFrame):
393
398
  # Added because of Any from missing stubs.
394
399
  trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
395
- local_sample_input = trunc_sample_input.to_pandas()
400
+ local_sample_input = model_signature._SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
396
401
  else:
397
402
  local_sample_input = trunc_sample_input
398
403
  for target_method in target_methods: