snowflake-ml-python 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. snowflake/ml/_internal/env_utils.py +2 -1
  2. snowflake/ml/_internal/file_utils.py +29 -7
  3. snowflake/ml/_internal/telemetry.py +5 -8
  4. snowflake/ml/_internal/utils/uri.py +7 -2
  5. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
  6. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
  7. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
  8. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
  9. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
  10. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
  11. snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
  12. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
  13. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
  14. snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
  15. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
  16. snowflake/ml/model/_deploy_client/warehouse/deploy.py +24 -6
  17. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +5 -2
  18. snowflake/ml/model/_deployer.py +14 -27
  19. snowflake/ml/model/_env.py +4 -4
  20. snowflake/ml/model/_handlers/custom.py +14 -2
  21. snowflake/ml/model/_handlers/pytorch.py +186 -0
  22. snowflake/ml/model/_handlers/sklearn.py +14 -9
  23. snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
  24. snowflake/ml/model/_handlers/torchscript.py +180 -0
  25. snowflake/ml/model/_handlers/xgboost.py +19 -9
  26. snowflake/ml/model/_model.py +3 -2
  27. snowflake/ml/model/_model_meta.py +12 -7
  28. snowflake/ml/model/model_signature.py +446 -66
  29. snowflake/ml/model/type_hints.py +23 -4
  30. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -26
  31. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -26
  32. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -26
  33. snowflake/ml/modeling/cluster/birch.py +51 -26
  34. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -26
  35. snowflake/ml/modeling/cluster/dbscan.py +51 -26
  36. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -26
  37. snowflake/ml/modeling/cluster/k_means.py +51 -26
  38. snowflake/ml/modeling/cluster/mean_shift.py +51 -26
  39. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -26
  40. snowflake/ml/modeling/cluster/optics.py +51 -26
  41. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -26
  42. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -26
  43. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -26
  44. snowflake/ml/modeling/compose/column_transformer.py +51 -26
  45. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -26
  46. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -26
  47. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -26
  48. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -26
  49. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -26
  50. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -26
  51. snowflake/ml/modeling/covariance/min_cov_det.py +51 -26
  52. snowflake/ml/modeling/covariance/oas.py +51 -26
  53. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -26
  54. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -26
  55. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -26
  56. snowflake/ml/modeling/decomposition/fast_ica.py +51 -26
  57. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -26
  58. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -26
  59. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -26
  60. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -26
  61. snowflake/ml/modeling/decomposition/pca.py +51 -26
  62. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -26
  63. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -26
  64. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -26
  65. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -26
  66. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -26
  67. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -26
  68. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -26
  69. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -26
  70. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -26
  71. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -26
  72. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -26
  73. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -26
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -26
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -26
  76. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -26
  77. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -26
  78. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -26
  79. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -26
  80. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -26
  81. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -26
  82. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -26
  83. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -26
  84. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -26
  85. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -26
  86. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -26
  87. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -26
  88. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -26
  89. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -26
  90. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -26
  91. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -26
  92. snowflake/ml/modeling/impute/iterative_imputer.py +51 -26
  93. snowflake/ml/modeling/impute/knn_imputer.py +51 -26
  94. snowflake/ml/modeling/impute/missing_indicator.py +51 -26
  95. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -26
  96. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -26
  97. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -26
  98. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -26
  99. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -26
  100. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -26
  101. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -26
  102. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -26
  103. snowflake/ml/modeling/linear_model/ard_regression.py +51 -26
  104. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -26
  105. snowflake/ml/modeling/linear_model/elastic_net.py +51 -26
  106. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -26
  107. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -26
  108. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -26
  109. snowflake/ml/modeling/linear_model/lars.py +51 -26
  110. snowflake/ml/modeling/linear_model/lars_cv.py +51 -26
  111. snowflake/ml/modeling/linear_model/lasso.py +51 -26
  112. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -26
  113. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -26
  114. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -26
  115. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -26
  116. snowflake/ml/modeling/linear_model/linear_regression.py +51 -26
  117. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -26
  118. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -26
  119. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -26
  120. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -26
  121. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -26
  122. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -26
  123. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -26
  124. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -26
  125. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -26
  126. snowflake/ml/modeling/linear_model/perceptron.py +51 -26
  127. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -26
  128. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -26
  129. snowflake/ml/modeling/linear_model/ridge.py +51 -26
  130. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -26
  131. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -26
  132. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -26
  133. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -26
  134. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -26
  135. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -26
  136. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -26
  137. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -26
  138. snowflake/ml/modeling/manifold/isomap.py +51 -26
  139. snowflake/ml/modeling/manifold/mds.py +51 -26
  140. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -26
  141. snowflake/ml/modeling/manifold/tsne.py +51 -26
  142. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -26
  143. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -26
  144. snowflake/ml/modeling/model_selection/grid_search_cv.py +51 -26
  145. snowflake/ml/modeling/model_selection/randomized_search_cv.py +51 -26
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -26
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -26
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -26
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -26
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -26
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -26
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -26
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -26
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -26
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -26
  156. snowflake/ml/modeling/neighbors/kernel_density.py +51 -26
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -26
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -26
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -26
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -26
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -26
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -26
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -26
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -26
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -26
  166. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
  167. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -26
  168. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -26
  169. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -26
  170. snowflake/ml/modeling/svm/linear_svc.py +51 -26
  171. snowflake/ml/modeling/svm/linear_svr.py +51 -26
  172. snowflake/ml/modeling/svm/nu_svc.py +51 -26
  173. snowflake/ml/modeling/svm/nu_svr.py +51 -26
  174. snowflake/ml/modeling/svm/svc.py +51 -26
  175. snowflake/ml/modeling/svm/svr.py +51 -26
  176. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -26
  177. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -26
  178. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -26
  179. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -26
  180. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -26
  181. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -26
  182. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -26
  183. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -26
  184. snowflake/ml/registry/model_registry.py +74 -56
  185. snowflake/ml/version.py +1 -1
  186. {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +27 -8
  187. snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
  188. snowflake_ml_python-1.0.2.dist-info/RECORD +0 -246
  189. {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -54,12 +54,15 @@ sys.path.insert(0, os.path.join(extracted_model_dir_path, "{code_dir_name}"))
54
54
  from snowflake.ml.model import _model
55
55
  model, meta = _model._load_model_for_deploy(extracted_model_dir_path)
56
56
 
57
+ features = meta.signatures["{target_method}"].inputs
58
+ input_cols = [feature.name for feature in features]
59
+ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
60
+
57
61
  # TODO(halu): Wire `max_batch_size`.
58
62
  # TODO(halu): Avoid per batch async detection branching.
59
63
  @vectorized(input=pd.DataFrame, max_batch_size=10)
60
64
  def infer(df):
61
- input_cols = [spec.name for spec in meta.signatures["{target_method}"].inputs]
62
- input_df = pd.io.json.json_normalize(df[0])
65
+ input_df = pd.io.json.json_normalize(df[0]).astype(dtype=dtype_map)
63
66
  if inspect.iscoroutinefunction(model.{target_method}):
64
67
  predictions_df = anyio.run(model.{target_method}, input_df[input_cols])
65
68
  else:
@@ -1,9 +1,7 @@
1
- import json
2
1
  import traceback
3
2
  from enum import Enum
4
3
  from typing import Optional, TypedDict, Union, overload
5
4
 
6
- import numpy as np
7
5
  import pandas as pd
8
6
  from typing_extensions import Required
9
7
 
@@ -184,7 +182,6 @@ def predict(
184
182
 
185
183
  Raises:
186
184
  ValueError: Raised when the input is too large to use keep_order option.
187
- NotImplementedError: FeatureGroupSpec is not supported.
188
185
 
189
186
  Returns:
190
187
  The output dataframe.
@@ -199,19 +196,19 @@ def predict(
199
196
  # Validate and prepare input
200
197
  if not isinstance(X, SnowparkDataFrame):
201
198
  df = model_signature._convert_and_validate_local_data(X, sig.inputs)
202
- s_df = session.create_dataframe(df)
199
+ s_df = model_signature._SnowparkDataFrameHandler.convert_from_df(session, df, keep_order=keep_order)
203
200
  else:
204
201
  model_signature._validate_snowpark_data(X, sig.inputs)
205
202
  s_df = X
206
203
 
207
- if keep_order:
208
- # ID is UINT64 type, this we should limit.
209
- if s_df.count() > 2**64:
210
- raise ValueError("Unable to keep order of a DataFrame with more than 2 ** 64 rows.")
211
- s_df = s_df.with_column(
212
- infer_template._KEEP_ORDER_COL_NAME,
213
- F.monotonically_increasing_id(),
214
- )
204
+ if keep_order:
205
+ # ID is UINT64 type, this we should limit.
206
+ if s_df.count() > 2**64:
207
+ raise ValueError("Unable to keep order of a DataFrame with more than 2 ** 64 rows.")
208
+ s_df = s_df.with_column(
209
+ infer_template._KEEP_ORDER_COL_NAME,
210
+ F.monotonically_increasing_id(),
211
+ )
215
212
 
216
213
  # Infer and get intermediate result
217
214
  input_cols = []
@@ -223,7 +220,9 @@ def predict(
223
220
  F.col(col_name),
224
221
  ]
225
222
  )
226
- output_obj = F.call_udf(deployment["name"], F.object_construct(*input_cols)) # type:ignore[arg-type]
223
+ output_obj = F.call_udf(
224
+ identifier.get_inferred_name(deployment["name"]), F.object_construct(*input_cols) # type:ignore[arg-type]
225
+ )
227
226
  if output_with_input_features:
228
227
  df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
229
228
  else:
@@ -243,24 +242,12 @@ def predict(
243
242
  output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type()))
244
243
 
245
244
  df_res = df_res.with_columns(
246
- [identifier.quote_name_without_upper_casing(output_feature.name) for output_feature in sig.outputs],
245
+ [identifier.get_inferred_name(output_feature.name) for output_feature in sig.outputs],
247
246
  output_cols,
248
247
  ).drop(INTERMEDIATE_OBJ_NAME)
249
248
 
250
249
  # Get final result
251
250
  if not isinstance(X, SnowparkDataFrame):
252
- dtype_map = {}
253
- for feature in sig.outputs:
254
- if isinstance(feature, model_signature.FeatureGroupSpec):
255
- raise NotImplementedError("FeatureGroupSpec is not supported.")
256
- assert isinstance(feature, model_signature.FeatureSpec), "Invalid feature kind."
257
- dtype_map[feature.name] = feature.as_dtype()
258
- df_local = df_res.to_pandas()
259
- # This is because Array and object will generate variant type and requires an additional loads to
260
- # get correct data otherwise it would be string.
261
- for col_name in [col_name for col_name, col_dtype in dtype_map.items() if col_dtype == np.object0]:
262
- df_local[col_name] = df_local[col_name].map(json.loads)
263
- df_local = df_local.astype(dtype=dtype_map)
264
- return pd.DataFrame(df_local)
251
+ return model_signature._SnowparkDataFrameHandler.convert_to_df(df_res, features=sig.outputs)
265
252
  else:
266
253
  return df_res
@@ -36,7 +36,7 @@ def save_conda_env_file(
36
36
  for chan, reqs in deps.items():
37
37
  env["dependencies"].extend([f"{chan}::{str(req)}" if chan else str(req) for req in reqs])
38
38
 
39
- with open(path, "w") as f:
39
+ with open(path, "w", encoding="utf-8") as f:
40
40
  yaml.safe_dump(env, stream=f, default_flow_style=False)
41
41
 
42
42
  return path
@@ -54,7 +54,7 @@ def save_requirements_file(dir_path: str, pip_deps: List[requirements.Requiremen
54
54
  """
55
55
  requirements = "\n".join(map(str, pip_deps))
56
56
  path = os.path.join(dir_path, _REQUIREMENTS_FILE_NAME)
57
- with open(path, "w") as out:
57
+ with open(path, "w", encoding="utf-8") as out:
58
58
  out.write(requirements)
59
59
 
60
60
  return path
@@ -69,7 +69,7 @@ def load_conda_env_file(path: str) -> Tuple[DefaultDict[str, List[requirements.R
69
69
  Returns:
70
70
  A tuple of Dict of conda dependencies after validated and a string 'major.minor.patchlevel' of python version.
71
71
  """
72
- with open(path) as f:
72
+ with open(path, encoding="utf-8") as f:
73
73
  env = yaml.safe_load(stream=f)
74
74
 
75
75
  assert isinstance(env, dict)
@@ -99,7 +99,7 @@ def load_requirements_file(path: str) -> List[requirements.Requirement]:
99
99
  Returns:
100
100
  List of dependencies string after validated.
101
101
  """
102
- with open(path) as f:
102
+ with open(path, encoding="utf-8") as f:
103
103
  reqs = f.readlines()
104
104
 
105
105
  return env_utils.validate_pip_requirement_string_list(reqs)
@@ -1,16 +1,19 @@
1
1
  import inspect
2
2
  import os
3
+ import pathlib
3
4
  import sys
4
5
  from typing import TYPE_CHECKING, Dict, Optional
5
6
 
6
7
  import anyio
7
8
  import cloudpickle
9
+ import pandas as pd
8
10
  from typing_extensions import TypeGuard, Unpack
9
11
 
10
12
  from snowflake.ml._internal import file_utils, type_utils
11
13
  from snowflake.ml.model import (
12
14
  _model_handler,
13
15
  _model_meta as model_meta_api,
16
+ model_signature,
14
17
  type_hints as model_types,
15
18
  )
16
19
  from snowflake.ml.model._handlers import _base
@@ -55,6 +58,10 @@ class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
55
58
  target_method = getattr(model, target_method_name, None)
56
59
  assert callable(target_method) and inspect.ismethod(target_method)
57
60
  target_method = target_method.__func__
61
+
62
+ if not isinstance(sample_input, pd.DataFrame):
63
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
64
+
58
65
  if inspect.iscoroutinefunction(target_method):
59
66
  with anyio.start_blocking_portal() as portal:
60
67
  predictions_df = portal.call(target_method, model, sample_input)
@@ -102,7 +109,9 @@ class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
102
109
  model_type=_CustomModelHandler.handler_type,
103
110
  path=_CustomModelHandler.MODEL_BLOB_FILE,
104
111
  artifacts={
105
- name: os.path.join(_CustomModelHandler.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(uri)))
112
+ name: pathlib.Path(
113
+ os.path.join(_CustomModelHandler.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(path=uri)))
114
+ ).as_posix()
106
115
  for name, uri in model.context.artifacts.items()
107
116
  },
108
117
  )
@@ -129,7 +138,10 @@ class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
129
138
  assert issubclass(ModelClass, custom_model.CustomModel)
130
139
 
131
140
  artifacts_meta = model_blob_metadata.artifacts
132
- artifacts = {name: os.path.join(model_blob_path, rel_path) for name, rel_path in artifacts_meta.items()}
141
+ artifacts = {
142
+ name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
143
+ for name, rel_path in artifacts_meta.items()
144
+ }
133
145
  models: Dict[str, model_types.SupportedModelType] = dict()
134
146
  for sub_model_name, _ref in m.context.model_refs.items():
135
147
  model_type = model_meta.models[sub_model_name].model_type
@@ -0,0 +1,186 @@
1
+ import os
2
+ import sys
3
+ from typing import TYPE_CHECKING, Callable, Optional, Type, cast
4
+
5
+ import cloudpickle
6
+ import pandas as pd
7
+ from typing_extensions import TypeGuard, Unpack
8
+
9
+ from snowflake.ml._internal import type_utils
10
+ from snowflake.ml.model import (
11
+ _model_meta as model_meta_api,
12
+ custom_model,
13
+ model_signature,
14
+ type_hints as model_types,
15
+ )
16
+ from snowflake.ml.model._handlers import _base
17
+
18
+ if TYPE_CHECKING:
19
+ import torch
20
+
21
+
22
+ class _PyTorchHandler(_base._ModelHandler["torch.nn.Module"]):
23
+ """Handler for PyTorch based model.
24
+
25
+ Currently torch.nn.Module based classes are supported.
26
+ """
27
+
28
+ handler_type = "pytorch"
29
+ MODEL_BLOB_FILE = "model.pt"
30
+ DEFAULT_TARGET_METHODS = ["forward"]
31
+
32
+ @staticmethod
33
+ def can_handle(
34
+ model: model_types.SupportedModelType,
35
+ ) -> TypeGuard["torch.nn.Module"]:
36
+ return type_utils.LazyType("torch.nn.Module").isinstance(model) and not type_utils.LazyType(
37
+ "torch.jit.ScriptModule"
38
+ ).isinstance(model)
39
+
40
+ @staticmethod
41
+ def cast_model(
42
+ model: model_types.SupportedModelType,
43
+ ) -> "torch.nn.Module":
44
+ import torch
45
+
46
+ assert isinstance(model, torch.nn.Module)
47
+
48
+ return cast(torch.nn.Module, model)
49
+
50
+ @staticmethod
51
+ def _save_model(
52
+ name: str,
53
+ model: "torch.nn.Module",
54
+ model_meta: model_meta_api.ModelMetadata,
55
+ model_blobs_dir_path: str,
56
+ sample_input: Optional[model_types.SupportedDataType] = None,
57
+ is_sub_model: Optional[bool] = False,
58
+ **kwargs: Unpack[model_types.PyTorchSaveOptions],
59
+ ) -> None:
60
+ import torch
61
+
62
+ assert isinstance(model, torch.nn.Module)
63
+
64
+ if not is_sub_model:
65
+ target_methods = model_meta_api._get_target_methods(
66
+ model=model,
67
+ target_methods=kwargs.pop("target_methods", None),
68
+ default_target_methods=_PyTorchHandler.DEFAULT_TARGET_METHODS,
69
+ )
70
+
71
+ def get_prediction(
72
+ target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
73
+ ) -> model_types.SupportedLocalDataType:
74
+ if not model_signature._SeqOfPyTorchTensorHandler.can_handle(sample_input):
75
+ sample_input = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(
76
+ model_signature._convert_local_data_to_df(sample_input)
77
+ )
78
+
79
+ model.eval()
80
+ target_method = getattr(model, target_method_name, None)
81
+ assert callable(target_method)
82
+ with torch.no_grad():
83
+ predictions_df = target_method(sample_input)
84
+ return predictions_df
85
+
86
+ model_meta = model_meta_api._validate_signature(
87
+ model=model,
88
+ model_meta=model_meta,
89
+ target_methods=target_methods,
90
+ sample_input=sample_input,
91
+ get_prediction_fn=get_prediction,
92
+ )
93
+
94
+ # Torch.save using pickle will not pickle the model definition if defined in the top level of a module.
95
+ # Make sure that the module where the model is defined get pickled by value as well.
96
+ cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
97
+ model_blob_path = os.path.join(model_blobs_dir_path, name)
98
+ os.makedirs(model_blob_path, exist_ok=True)
99
+ with open(os.path.join(model_blob_path, _PyTorchHandler.MODEL_BLOB_FILE), "wb") as f:
100
+ torch.save(model, f, pickle_module=cloudpickle)
101
+ base_meta = model_meta_api._ModelBlobMetadata(
102
+ name=name, model_type=_PyTorchHandler.handler_type, path=_PyTorchHandler.MODEL_BLOB_FILE
103
+ )
104
+ model_meta.models[name] = base_meta
105
+ model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])
106
+
107
+ @staticmethod
108
+ def _load_model(
109
+ name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
110
+ ) -> "torch.nn.Module":
111
+ import torch
112
+
113
+ model_blob_path = os.path.join(model_blobs_dir_path, name)
114
+ if not hasattr(model_meta, "models"):
115
+ raise ValueError("Ill model metadata found.")
116
+ model_blobs_metadata = model_meta.models
117
+ if name not in model_blobs_metadata:
118
+ raise ValueError(f"Blob of model {name} does not exist.")
119
+ model_blob_metadata = model_blobs_metadata[name]
120
+ model_blob_filename = model_blob_metadata.path
121
+ with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
122
+ m = torch.load(f)
123
+ assert isinstance(m, torch.nn.Module)
124
+ return m
125
+
126
+ @staticmethod
127
+ def _load_as_custom_model(
128
+ name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
129
+ ) -> custom_model.CustomModel:
130
+ """Create a custom model class wrap for unified interface when being deployed. The predict method will be
131
+ re-targeted based on target_method metadata.
132
+
133
+ Args:
134
+ name: Name of the model.
135
+ model_meta: The model metadata.
136
+ model_blobs_dir_path: Directory path to the whole model.
137
+
138
+ Returns:
139
+ The model object as a custom model.
140
+ """
141
+ import torch
142
+
143
+ from snowflake.ml.model import custom_model
144
+
145
+ def _create_custom_model(
146
+ raw_model: "torch.nn.Module",
147
+ model_meta: model_meta_api.ModelMetadata,
148
+ ) -> Type[custom_model.CustomModel]:
149
+ def fn_factory(
150
+ raw_model: "torch.nn.Module",
151
+ signature: model_signature.ModelSignature,
152
+ target_method: str,
153
+ ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
154
+ @custom_model.inference_api
155
+ def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
156
+ if X.isnull().any(axis=None):
157
+ raise ValueError("Tensor cannot handle null values.")
158
+
159
+ raw_model.eval()
160
+ t = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
161
+
162
+ with torch.no_grad():
163
+ res = getattr(raw_model, target_method)(t)
164
+ return model_signature._rename_pandas_df(
165
+ data=model_signature._SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
166
+ )
167
+
168
+ return fn
169
+
170
+ type_method_dict = {}
171
+ for target_method_name, sig in model_meta.signatures.items():
172
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
173
+
174
+ _PyTorchModel = type(
175
+ "_PyTorchModel",
176
+ (custom_model.CustomModel,),
177
+ type_method_dict,
178
+ )
179
+
180
+ return _PyTorchModel
181
+
182
+ raw_model = _PyTorchHandler._load_model(name, model_meta, model_blobs_dir_path)
183
+ _PyTorchModel = _create_custom_model(raw_model, model_meta)
184
+ pytorch_model = _PyTorchModel(custom_model.ModelContext())
185
+
186
+ return pytorch_model
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Optional, Sequence, Type, Union, cast
2
+ from typing import TYPE_CHECKING, Callable, Optional, Type, Union, cast
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import (
11
11
  _model_meta as model_meta_api,
12
12
  custom_model,
13
+ model_signature,
13
14
  type_hints as model_types,
14
15
  )
15
16
  from snowflake.ml.model._handlers import _base
@@ -80,6 +81,9 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
80
81
  def get_prediction(
81
82
  target_method_name: str, sample_input: model_types.SupportedLocalDataType
82
83
  ) -> model_types.SupportedLocalDataType:
84
+ if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
85
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
86
+
83
87
  target_method = getattr(model, target_method_name, None)
84
88
  assert callable(target_method)
85
89
  predictions_df = target_method(sample_input)
@@ -101,7 +105,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
101
105
  name=name, model_type=_SKLModelHandler.handler_type, path=_SKLModelHandler.MODEL_BLOB_FILE
102
106
  )
103
107
  model_meta.models[name] = base_meta
104
- model_meta._include_if_absent([("scikit-learn", "scikit-learn")])
108
+ model_meta._include_if_absent([model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn")])
105
109
 
106
110
  @staticmethod
107
111
  def _load_model(
@@ -147,7 +151,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
147
151
  ) -> Type[custom_model.CustomModel]:
148
152
  def fn_factory(
149
153
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
150
- output_col_names: Sequence[str],
154
+ signature: model_signature.ModelSignature,
151
155
  target_method: str,
152
156
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
153
157
  @custom_model.inference_api
@@ -156,17 +160,18 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
156
160
 
157
161
  if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
158
162
  # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
159
- # return a list of ndarrays. We need to concatenate them.
160
- res = np.concatenate(res, axis=1)
161
- return pd.DataFrame(res, columns=output_col_names)
163
+ # return a list of ndarrays. We need to deal them seperately
164
+ df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
165
+ else:
166
+ df = pd.DataFrame(res)
167
+
168
+ return model_signature._rename_pandas_df(df, signature.outputs)
162
169
 
163
170
  return fn
164
171
 
165
172
  type_method_dict = {}
166
173
  for target_method_name, sig in model_meta.signatures.items():
167
- type_method_dict[target_method_name] = fn_factory(
168
- raw_model, [spec.name for spec in sig.outputs], target_method_name
169
- )
174
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
170
175
 
171
176
  _SKLModel = type(
172
177
  "_SKLModel",
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Optional, Sequence, Type, cast
2
+ from typing import TYPE_CHECKING, Callable, Optional, Type, cast
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import (
11
11
  _model_meta as model_meta_api,
12
12
  custom_model,
13
+ model_signature,
13
14
  type_hints as model_types,
14
15
  )
15
16
  from snowflake.ml.model._handlers import _base
@@ -81,6 +82,9 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
81
82
  def get_prediction(
82
83
  target_method_name: str, sample_input: model_types.SupportedLocalDataType
83
84
  ) -> model_types.SupportedLocalDataType:
85
+ if not isinstance(sample_input, (pd.DataFrame,)):
86
+ sample_input = model_signature._convert_local_data_to_df(sample_input)
87
+
84
88
  target_method = getattr(model, target_method_name, None)
85
89
  assert callable(target_method)
86
90
  predictions_df = target_method(sample_input)
@@ -106,7 +110,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
106
110
  model_dependencies = model._get_dependencies()
107
111
  for dep in model_dependencies:
108
112
  pkg_name = dep.split("==")[0]
109
- _include_if_absent_pkgs.append((pkg_name, pkg_name))
113
+ _include_if_absent_pkgs.append(model_meta_api.Dependency(conda_name=pkg_name, pip_name=pkg_name))
110
114
  model_meta._include_if_absent(_include_if_absent_pkgs)
111
115
 
112
116
  @staticmethod
@@ -150,7 +154,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
150
154
  ) -> Type[custom_model.CustomModel]:
151
155
  def fn_factory(
152
156
  raw_model: "BaseEstimator",
153
- output_col_names: Sequence[str],
157
+ signature: model_signature.ModelSignature,
154
158
  target_method: str,
155
159
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
156
160
  @custom_model.inference_api
@@ -159,17 +163,18 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
159
163
 
160
164
  if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
161
165
  # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
162
- # return a list of ndarrays. We need to concatenate them.
163
- res = np.concatenate(res, axis=1)
164
- return pd.DataFrame(res, columns=output_col_names)
166
+ # return a list of ndarrays. We need to deal them seperately
167
+ df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
168
+ else:
169
+ df = pd.DataFrame(res)
170
+
171
+ return model_signature._rename_pandas_df(df, signature.outputs)
165
172
 
166
173
  return fn
167
174
 
168
175
  type_method_dict = {}
169
176
  for target_method_name, sig in model_meta.signatures.items():
170
- type_method_dict[target_method_name] = fn_factory(
171
- raw_model, [spec.name for spec in sig.outputs], target_method_name
172
- )
177
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
173
178
 
174
179
  _SnowMLModel = type(
175
180
  "_SnowMLModel",