snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +240 -16
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_sse_client.py +81 -0
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +34 -10
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  16. snowflake/ml/_internal/telemetry.py +26 -0
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/dataset/dataset.py +54 -32
  20. snowflake/ml/dataset/dataset_factory.py +3 -4
  21. snowflake/ml/feature_store/feature_store.py +440 -243
  22. snowflake/ml/feature_store/feature_view.py +61 -9
  23. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  24. snowflake/ml/fileset/fileset.py +2 -2
  25. snowflake/ml/fileset/snowfs.py +4 -15
  26. snowflake/ml/fileset/stage_fs.py +6 -8
  27. snowflake/ml/lineage/__init__.py +3 -0
  28. snowflake/ml/lineage/lineage_node.py +139 -0
  29. snowflake/ml/model/_client/model/model_impl.py +47 -14
  30. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  31. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  32. snowflake/ml/model/_client/sql/model.py +1 -0
  33. snowflake/ml/model/_client/sql/model_version.py +47 -4
  34. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  35. snowflake/ml/model/_model_composer/model_composer.py +7 -6
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  37. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  38. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
  40. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  41. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  42. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  43. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  45. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  46. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  53. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  56. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  57. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  58. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  59. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  60. snowflake/ml/model/_packager/model_packager.py +9 -4
  61. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  62. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  63. snowflake/ml/model/_signatures/core.py +13 -1
  64. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  65. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  66. snowflake/ml/model/custom_model.py +22 -2
  67. snowflake/ml/model/model_signature.py +2 -0
  68. snowflake/ml/model/type_hints.py +74 -4
  69. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  70. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
  71. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  72. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
  73. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
  74. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
  75. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  76. snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
  77. snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
  78. snowflake/ml/modeling/cluster/birch.py +5 -3
  79. snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
  80. snowflake/ml/modeling/cluster/dbscan.py +5 -3
  81. snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
  82. snowflake/ml/modeling/cluster/k_means.py +5 -3
  83. snowflake/ml/modeling/cluster/mean_shift.py +5 -3
  84. snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
  85. snowflake/ml/modeling/cluster/optics.py +5 -3
  86. snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
  87. snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
  88. snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
  89. snowflake/ml/modeling/compose/column_transformer.py +5 -3
  90. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  91. snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
  92. snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
  93. snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
  94. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
  95. snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
  96. snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
  97. snowflake/ml/modeling/covariance/oas.py +5 -3
  98. snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
  99. snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
  100. snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
  101. snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
  102. snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
  103. snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
  104. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
  105. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
  106. snowflake/ml/modeling/decomposition/pca.py +5 -3
  107. snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
  108. snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
  109. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  110. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  111. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  112. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  113. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  114. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  115. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  116. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  117. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  118. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  119. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  120. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  121. snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
  122. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  123. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  124. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  125. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  126. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  127. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  128. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  129. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  130. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  131. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  132. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  133. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  134. snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
  135. snowflake/ml/modeling/framework/base.py +3 -8
  136. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  137. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  138. snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
  139. snowflake/ml/modeling/impute/knn_imputer.py +5 -3
  140. snowflake/ml/modeling/impute/missing_indicator.py +5 -3
  141. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  142. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
  143. snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
  144. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
  145. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
  146. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
  147. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  148. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  149. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  151. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  152. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  153. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  154. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  155. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  156. snowflake/ml/modeling/linear_model/lars.py +1 -1
  157. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  158. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  159. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  160. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  161. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  162. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  163. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  164. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  165. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  166. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  167. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  168. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  169. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  170. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  171. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  172. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  173. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  174. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  175. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  176. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  177. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  178. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  179. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  180. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  181. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
  182. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  183. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  184. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  185. snowflake/ml/modeling/manifold/isomap.py +5 -3
  186. snowflake/ml/modeling/manifold/mds.py +5 -3
  187. snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
  188. snowflake/ml/modeling/manifold/tsne.py +5 -3
  189. snowflake/ml/modeling/metrics/ranking.py +3 -0
  190. snowflake/ml/modeling/metrics/regression.py +3 -0
  191. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
  192. snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
  193. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  194. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  195. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  196. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  197. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  198. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  199. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  200. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  201. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  202. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  203. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  204. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  205. snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
  206. snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
  207. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  208. snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
  209. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  210. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  211. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  212. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
  213. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  214. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  215. snowflake/ml/modeling/pipeline/pipeline.py +6 -0
  216. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  217. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  218. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  219. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  220. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  221. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  222. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
  223. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
  224. snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
  225. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  226. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  227. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  228. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  229. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  230. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  231. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  232. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  233. snowflake/ml/modeling/svm/svc.py +1 -1
  234. snowflake/ml/modeling/svm/svr.py +1 -1
  235. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  236. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  237. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  238. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  239. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  240. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  241. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  242. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  243. snowflake/ml/registry/_manager/model_manager.py +16 -3
  244. snowflake/ml/version.py +1 -1
  245. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
  246. snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
  247. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  248. snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
  249. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  250. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
17
17
  from snowflake.ml.model._packager.model_meta import (
18
18
  model_blob_meta,
19
19
  model_meta as model_meta_api,
20
+ model_meta_schema,
20
21
  )
21
22
 
22
23
 
@@ -68,6 +69,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
68
69
  predictions_df = target_method(model, sample_input_data)
69
70
  return predictions_df
70
71
 
72
+ for func_name in model._get_partitioned_infer_methods():
73
+ function_properties = model_meta.function_properties.get(func_name, {})
74
+ function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
75
+ model_meta.function_properties[func_name] = function_properties
76
+
71
77
  if not is_sub_model:
72
78
  model_meta = handlers_utils.validate_signature(
73
79
  model=model,
@@ -101,14 +107,16 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
101
107
 
102
108
  # Make sure that the module where the model is defined get pickled by value as well.
103
109
  cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
104
- picked_obj = (model.__class__, model.context)
110
+ pickled_obj = (model.__class__, model.context)
105
111
  with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
106
- cloudpickle.dump(picked_obj, f)
112
+ cloudpickle.dump(pickled_obj, f)
113
+ # model meta will be saved by the context manager
107
114
  model_meta.models[name] = model_blob_meta.ModelBlobMeta(
108
115
  name=name,
109
116
  model_type=cls.HANDLER_TYPE,
110
117
  path=cls.MODELE_BLOB_FILE_OR_DIR,
111
118
  handler_version=cls.HANDLER_VERSION,
119
+ function_properties=model_meta.function_properties,
112
120
  artifacts={
113
121
  name: pathlib.Path(
114
122
  os.path.join(cls.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(path=uri)))
@@ -128,7 +136,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
128
136
  name: str,
129
137
  model_meta: model_meta_api.ModelMetadata,
130
138
  model_blobs_dir_path: str,
131
- **kwargs: Unpack[model_types.ModelLoadOption],
139
+ **kwargs: Unpack[model_types.CustomModelLoadOption],
132
140
  ) -> "custom_model.CustomModel":
133
141
  model_blob_path = os.path.join(model_blobs_dir_path, name)
134
142
 
@@ -175,6 +183,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
175
183
  cls,
176
184
  raw_model: custom_model.CustomModel,
177
185
  model_meta: model_meta_api.ModelMetadata,
178
- **kwargs: Unpack[model_types.ModelLoadOption],
186
+ **kwargs: Unpack[model_types.CustomModelLoadOption],
179
187
  ) -> custom_model.CustomModel:
180
188
  return raw_model
@@ -3,6 +3,7 @@ import os
3
3
  import warnings
4
4
  from typing import (
5
5
  TYPE_CHECKING,
6
+ Any,
6
7
  Callable,
7
8
  Dict,
8
9
  List,
@@ -250,9 +251,18 @@ class HuggingFacePipelineHandler(
250
251
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
251
252
 
252
253
  @staticmethod
253
- def _get_device_config() -> Dict[str, str]:
254
- device_config = {}
255
- device_config["device_map"] = "auto"
254
+ def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
255
+ device_config: Dict[str, Any] = {}
256
+ if (
257
+ kwargs.get("use_gpu", False)
258
+ and kwargs.get("device_map", None) is None
259
+ and kwargs.get("device", None) is None
260
+ ):
261
+ device_config["device_map"] = "auto"
262
+ elif kwargs.get("device_map", None) is not None:
263
+ device_config["device_map"] = kwargs["device_map"]
264
+ elif kwargs.get("device", None) is not None:
265
+ device_config["device"] = kwargs["device"]
256
266
 
257
267
  return device_config
258
268
 
@@ -262,7 +272,7 @@ class HuggingFacePipelineHandler(
262
272
  name: str,
263
273
  model_meta: model_meta_api.ModelMetadata,
264
274
  model_blobs_dir_path: str,
265
- **kwargs: Unpack[model_types.ModelLoadOption],
275
+ **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
266
276
  ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]:
267
277
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
268
278
  # We need to redirect the some folders to a writable location in the sandbox.
@@ -292,10 +302,7 @@ class HuggingFacePipelineHandler(
292
302
  ) as f:
293
303
  pipeline_params = cloudpickle.load(f)
294
304
 
295
- if kwargs.get("use_gpu", False):
296
- device_config = cls._get_device_config()
297
- else:
298
- device_config = {}
305
+ device_config = cls._get_device_config(**kwargs)
299
306
 
300
307
  m = transformers.pipeline(
301
308
  model_blob_options["task"],
@@ -310,12 +317,8 @@ class HuggingFacePipelineHandler(
310
317
  with open(model_blob_file_or_dir_path, "rb") as f:
311
318
  m = cloudpickle.load(f)
312
319
  assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
313
- if (
314
- getattr(m, "device", None) is None
315
- and getattr(m, "device_map", None) is None
316
- and kwargs.get("use_gpu", False)
317
- ):
318
- m.__dict__.update(cls._get_device_config())
320
+ if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
321
+ m.__dict__.update(cls._get_device_config(**kwargs))
319
322
 
320
323
  if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
321
324
  m.__dict__.update(torch_dtype="auto")
@@ -326,7 +329,7 @@ class HuggingFacePipelineHandler(
326
329
  cls,
327
330
  raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
328
331
  model_meta: model_meta_api.ModelMetadata,
329
- **kwargs: Unpack[model_types.ModelLoadOption],
332
+ **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
330
333
  ) -> custom_model.CustomModel:
331
334
  import transformers
332
335
 
@@ -139,7 +139,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
139
139
  name: str,
140
140
  model_meta: model_meta_api.ModelMetadata,
141
141
  model_blobs_dir_path: str,
142
- **kwargs: Unpack[model_types.ModelLoadOption],
142
+ **kwargs: Unpack[model_types.LGBMModelLoadOptions],
143
143
  ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
144
144
  import lightgbm
145
145
 
@@ -169,7 +169,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
169
169
  cls,
170
170
  raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
171
171
  model_meta: model_meta_api.ModelMetadata,
172
- **kwargs: Unpack[model_types.ModelLoadOption],
172
+ **kwargs: Unpack[model_types.LGBMModelLoadOptions],
173
173
  ) -> custom_model.CustomModel:
174
174
  import lightgbm
175
175
 
@@ -118,7 +118,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
118
118
  name: str,
119
119
  model_meta: model_meta_api.ModelMetadata,
120
120
  model_blobs_dir_path: str,
121
- **kwargs: Unpack[model_types.ModelLoadOption],
121
+ **kwargs: Unpack[model_types.LLMLoadOptions],
122
122
  ) -> llm.LLM:
123
123
  model_blob_path = os.path.join(model_blobs_dir_path, name)
124
124
  if not hasattr(model_meta, "models"):
@@ -143,7 +143,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
143
143
  cls,
144
144
  raw_model: llm.LLM,
145
145
  model_meta: model_meta_api.ModelMetadata,
146
- **kwargs: Unpack[model_types.ModelLoadOption],
146
+ **kwargs: Unpack[model_types.LLMLoadOptions],
147
147
  ) -> custom_model.CustomModel:
148
148
  import gc
149
149
  import tempfile
@@ -160,7 +160,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
160
160
  name: str,
161
161
  model_meta: model_meta_api.ModelMetadata,
162
162
  model_blobs_dir_path: str,
163
- **kwargs: Unpack[model_types.ModelLoadOption],
163
+ **kwargs: Unpack[model_types.MLFlowLoadOptions],
164
164
  ) -> "mlflow.pyfunc.PyFuncModel":
165
165
  import mlflow
166
166
 
@@ -194,7 +194,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
194
194
  cls,
195
195
  raw_model: "mlflow.pyfunc.PyFuncModel",
196
196
  model_meta: model_meta_api.ModelMetadata,
197
- **kwargs: Unpack[model_types.ModelLoadOption],
197
+ **kwargs: Unpack[model_types.MLFlowLoadOptions],
198
198
  ) -> custom_model.CustomModel:
199
199
  from snowflake.ml.model import custom_model
200
200
 
@@ -137,7 +137,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
137
137
  name: str,
138
138
  model_meta: model_meta_api.ModelMetadata,
139
139
  model_blobs_dir_path: str,
140
- **kwargs: Unpack[model_types.ModelLoadOption],
140
+ **kwargs: Unpack[model_types.PyTorchLoadOptions],
141
141
  ) -> "torch.nn.Module":
142
142
  import torch
143
143
 
@@ -156,7 +156,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
156
156
  cls,
157
157
  raw_model: "torch.nn.Module",
158
158
  model_meta: model_meta_api.ModelMetadata,
159
- **kwargs: Unpack[model_types.ModelLoadOption],
159
+ **kwargs: Unpack[model_types.PyTorchLoadOptions],
160
160
  ) -> custom_model.CustomModel:
161
161
  import torch
162
162
 
@@ -126,7 +126,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
126
126
  name: str,
127
127
  model_meta: model_meta_api.ModelMetadata,
128
128
  model_blobs_dir_path: str,
129
- **kwargs: Unpack[model_types.ModelLoadOption], # use_gpu
129
+ **kwargs: Unpack[model_types.SentenceTransformersLoadOptions], # use_gpu
130
130
  ) -> "sentence_transformers.SentenceTransformer":
131
131
  import sentence_transformers
132
132
 
@@ -154,7 +154,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
154
154
  cls,
155
155
  raw_model: "sentence_transformers.SentenceTransformer",
156
156
  model_meta: model_meta_api.ModelMetadata,
157
- **kwargs: Unpack[model_types.ModelLoadOption],
157
+ **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
158
158
  ) -> custom_model.CustomModel:
159
159
  import sentence_transformers
160
160
 
@@ -133,7 +133,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
133
133
  name: str,
134
134
  model_meta: model_meta_api.ModelMetadata,
135
135
  model_blobs_dir_path: str,
136
- **kwargs: Unpack[model_types.ModelLoadOption],
136
+ **kwargs: Unpack[model_types.SKLModelLoadOptions],
137
137
  ) -> Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]:
138
138
  model_blob_path = os.path.join(model_blobs_dir_path, name)
139
139
  model_blobs_metadata = model_meta.models
@@ -153,7 +153,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
153
153
  cls,
154
154
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
155
155
  model_meta: model_meta_api.ModelMetadata,
156
- **kwargs: Unpack[model_types.ModelLoadOption],
156
+ **kwargs: Unpack[model_types.SKLModelLoadOptions],
157
157
  ) -> custom_model.CustomModel:
158
158
  from snowflake.ml.model import custom_model
159
159
 
@@ -127,7 +127,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
127
127
  name: str,
128
128
  model_meta: model_meta_api.ModelMetadata,
129
129
  model_blobs_dir_path: str,
130
- **kwargs: Unpack[model_types.ModelLoadOption],
130
+ **kwargs: Unpack[model_types.SNOWModelLoadOptions],
131
131
  ) -> "BaseEstimator":
132
132
  model_blob_path = os.path.join(model_blobs_dir_path, name)
133
133
  model_blobs_metadata = model_meta.models
@@ -146,7 +146,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
146
146
  cls,
147
147
  raw_model: "BaseEstimator",
148
148
  model_meta: model_meta_api.ModelMetadata,
149
- **kwargs: Unpack[model_types.ModelLoadOption],
149
+ **kwargs: Unpack[model_types.SNOWModelLoadOptions],
150
150
  ) -> custom_model.CustomModel:
151
151
  from snowflake.ml.model import custom_model
152
152
 
@@ -138,7 +138,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
138
138
  name: str,
139
139
  model_meta: model_meta_api.ModelMetadata,
140
140
  model_blobs_dir_path: str,
141
- **kwargs: Unpack[model_types.ModelLoadOption],
141
+ **kwargs: Unpack[model_types.TensorflowLoadOptions],
142
142
  ) -> "tensorflow.Module":
143
143
  import tensorflow
144
144
 
@@ -156,7 +156,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
156
156
  cls,
157
157
  raw_model: "tensorflow.Module",
158
158
  model_meta: model_meta_api.ModelMetadata,
159
- **kwargs: Unpack[model_types.ModelLoadOption],
159
+ **kwargs: Unpack[model_types.TensorflowLoadOptions],
160
160
  ) -> custom_model.CustomModel:
161
161
  import tensorflow
162
162
 
@@ -128,7 +128,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
128
128
  name: str,
129
129
  model_meta: model_meta_api.ModelMetadata,
130
130
  model_blobs_dir_path: str,
131
- **kwargs: Unpack[model_types.ModelLoadOption],
131
+ **kwargs: Unpack[model_types.TorchScriptLoadOptions],
132
132
  ) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
133
133
  import torch
134
134
 
@@ -152,7 +152,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
152
152
  cls,
153
153
  raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
154
154
  model_meta: model_meta_api.ModelMetadata,
155
- **kwargs: Unpack[model_types.ModelLoadOption],
155
+ **kwargs: Unpack[model_types.TorchScriptLoadOptions],
156
156
  ) -> custom_model.CustomModel:
157
157
  from snowflake.ml.model import custom_model
158
158
 
@@ -141,7 +141,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
141
141
  name: str,
142
142
  model_meta: model_meta_api.ModelMetadata,
143
143
  model_blobs_dir_path: str,
144
- **kwargs: Unpack[model_types.ModelLoadOption],
144
+ **kwargs: Unpack[model_types.XGBModelLoadOptions],
145
145
  ) -> Union["xgboost.Booster", "xgboost.XGBModel"]:
146
146
  import xgboost
147
147
 
@@ -175,7 +175,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
175
175
  cls,
176
176
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
177
177
  model_meta: model_meta_api.ModelMetadata,
178
- **kwargs: Unpack[model_types.ModelLoadOption],
178
+ **kwargs: Unpack[model_types.XGBModelLoadOptions],
179
179
  ) -> custom_model.CustomModel:
180
180
  import xgboost
181
181
 
@@ -6,6 +6,6 @@ REQUIREMENTS = [
6
6
  "packaging>=20.9,<24",
7
7
  "pandas>=1.0.0,<3",
8
8
  "pyyaml>=6.0,<7",
9
- "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
9
+ "snowflake-snowpark-python>=1.17.0,<2",
10
10
  "typing-extensions>=4.1.0,<5"
11
11
  ]
@@ -23,6 +23,7 @@ class ModelBlobMeta:
23
23
  self.model_type = kwargs["model_type"]
24
24
  self.path = kwargs["path"]
25
25
  self.handler_version = kwargs["handler_version"]
26
+ self.function_properties = kwargs.get("function_properties", {})
26
27
 
27
28
  self.artifacts: Dict[str, str] = {}
28
29
  artifacts = kwargs.get("artifacts", None)
@@ -39,6 +40,7 @@ class ModelBlobMeta:
39
40
  model_type=self.model_type,
40
41
  path=self.path,
41
42
  handler_version=self.handler_version,
43
+ function_properties=self.function_properties,
42
44
  artifacts=self.artifacts,
43
45
  options=self.options,
44
46
  )
@@ -7,11 +7,12 @@ import zipfile
7
7
  from contextlib import contextmanager
8
8
  from datetime import datetime
9
9
  from types import ModuleType
10
- from typing import Any, Dict, Generator, List, Optional
10
+ from typing import Any, Dict, Generator, List, Optional, TypedDict
11
11
 
12
12
  import cloudpickle
13
13
  import yaml
14
14
  from packaging import requirements, version
15
+ from typing_extensions import Required
15
16
 
16
17
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
17
18
  from snowflake.ml.model import model_signature, type_hints as model_types
@@ -47,6 +48,7 @@ def create_model_metadata(
47
48
  name: str,
48
49
  model_type: model_types.SupportedModelHandlerType,
49
50
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
51
+ function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
50
52
  metadata: Optional[Dict[str, str]] = None,
51
53
  code_paths: Optional[List[str]] = None,
52
54
  ext_modules: Optional[List[ModuleType]] = None,
@@ -64,6 +66,7 @@ def create_model_metadata(
64
66
  model_type: Type of the model.
65
67
  signatures: Signatures of the model. If None, it will be inferred after the model meta is created.
66
68
  Defaults to None.
69
+ function_properties: Dict mapping function names to a dict of properties, mapping property key to value.
67
70
  metadata: User provided key-value metadata of the model. Defaults to None.
68
71
  code_paths: List of paths to additional codes that needs to be packed with. Defaults to None.
69
72
  ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
@@ -127,6 +130,7 @@ def create_model_metadata(
127
130
  metadata=metadata,
128
131
  model_type=model_type,
129
132
  signatures=signatures,
133
+ function_properties=function_properties,
130
134
  )
131
135
 
132
136
  code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
@@ -215,6 +219,12 @@ def load_code_path(model_dir_path: str) -> None:
215
219
  sys.path.insert(0, code_path)
216
220
 
217
221
 
222
+ class ModelMetadataTelemetryDict(TypedDict):
223
+ model_name: Required[str]
224
+ framework_type: Required[model_types.SupportedModelHandlerType]
225
+ number_of_functions: Required[int]
226
+
227
+
218
228
  class ModelMetadata:
219
229
  """Model metadata for Snowflake native model packaged model.
220
230
 
@@ -224,10 +234,18 @@ class ModelMetadata:
224
234
  env: ModelEnv object containing all environment related object
225
235
  models: Dict of model blob metadata
226
236
  signatures: A dict mapping from target function name to input and output signatures.
237
+ function_properties: A dict mapping function names to dict mapping function property key to value.
227
238
  metadata: User provided key-value metadata of the model. Defaults to None.
228
239
  creation_timestamp: Unix timestamp when the model metadata is created.
229
240
  """
230
241
 
242
+ def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
243
+ return ModelMetadataTelemetryDict(
244
+ model_name=self.name,
245
+ framework_type=self.model_type,
246
+ number_of_functions=len(self.signatures.keys()),
247
+ )
248
+
231
249
  def __init__(
232
250
  self,
233
251
  *,
@@ -236,6 +254,7 @@ class ModelMetadata:
236
254
  model_type: model_types.SupportedModelHandlerType,
237
255
  runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
238
256
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
257
+ function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
239
258
  metadata: Optional[Dict[str, str]] = None,
240
259
  creation_timestamp: Optional[str] = None,
241
260
  min_snowpark_ml_version: Optional[str] = None,
@@ -246,6 +265,7 @@ class ModelMetadata:
246
265
  self.signatures: Dict[str, model_signature.ModelSignature] = dict()
247
266
  if signatures:
248
267
  self.signatures = signatures
268
+ self.function_properties = function_properties or {}
249
269
  self.metadata = metadata
250
270
  self.model_type = model_type
251
271
  self.env = env
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into model.yml
2
2
  # Changing this file should lead to a change of the schema version.
3
-
3
+ from enum import Enum
4
4
  from typing import Any, Dict, List, Optional, TypedDict, Union
5
5
 
6
6
  from typing_extensions import NotRequired, Required
@@ -11,6 +11,10 @@ MODEL_METADATA_VERSION = "2023-12-01"
11
11
  MODEL_METADATA_MIN_SNOWPARK_ML_VERSION = "1.0.12"
12
12
 
13
13
 
14
+ class FunctionProperties(Enum):
15
+ PARTITIONED = "PARTITIONED"
16
+
17
+
14
18
  class ModelRuntimeDependenciesDict(TypedDict):
15
19
  conda: Required[str]
16
20
  pip: Required[str]
@@ -72,6 +76,7 @@ class ModelBlobMetadataDict(TypedDict):
72
76
  model_type: Required[type_hints.SupportedModelHandlerType]
73
77
  path: Required[str]
74
78
  handler_version: Required[str]
79
+ function_properties: NotRequired[Dict[str, Dict[str, Any]]]
75
80
  artifacts: NotRequired[Dict[str, str]]
76
81
  options: NotRequired[ModelBlobOptions]
77
82
 
@@ -47,12 +47,12 @@ class ModelPackager:
47
47
  ext_modules: Optional[List[ModuleType]] = None,
48
48
  code_paths: Optional[List[str]] = None,
49
49
  options: Optional[model_types.ModelSaveOption] = None,
50
- ) -> None:
50
+ ) -> model_meta.ModelMetadata:
51
51
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
52
52
  raise snowml_exceptions.SnowflakeMLException(
53
53
  error_code=error_codes.INVALID_ARGUMENT,
54
54
  original_exception=ValueError(
55
- "Signatures and sample_input_data both cannot be None at the same time for this kind of model."
55
+ "Either of `signatures` or `sample_input_data` must be provided for this kind of model."
56
56
  ),
57
57
  )
58
58
 
@@ -103,6 +103,7 @@ class ModelPackager:
103
103
 
104
104
  self.model = model
105
105
  self.meta = meta
106
+ return meta
106
107
 
107
108
  def load(
108
109
  self,
@@ -110,7 +111,7 @@ class ModelPackager:
110
111
  meta_only: bool = False,
111
112
  as_custom_model: bool = False,
112
113
  options: Optional[model_types.ModelLoadOption] = None,
113
- ) -> None:
114
+ ) -> model_meta.ModelMetadata:
114
115
  """Load the model into memory from directory. Used internal only.
115
116
 
116
117
  Args:
@@ -120,11 +121,14 @@ class ModelPackager:
120
121
 
121
122
  Raises:
122
123
  SnowflakeMLException: Raised if model is not native format.
124
+
125
+ Returns:
126
+ Metadata of loaded model.
123
127
  """
124
128
 
125
129
  self.meta = model_meta.ModelMetadata.load(self.local_dir_path)
126
130
  if meta_only:
127
- return
131
+ return self.meta
128
132
 
129
133
  model_meta.load_code_path(self.local_dir_path)
130
134
 
@@ -146,3 +150,4 @@ class ModelPackager:
146
150
  assert isinstance(m, custom_model.CustomModel)
147
151
 
148
152
  self.model = m
153
+ return self.meta
@@ -5,6 +5,6 @@ REQUIREMENTS = [
5
5
  "packaging>=20.9,<24",
6
6
  "pandas>=1.0.0,<3",
7
7
  "pyyaml>=6.0,<7",
8
- "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
8
+ "snowflake-snowpark-python>=1.17.0,<2",
9
9
  "typing-extensions>=4.1.0,<5"
10
10
  ]
@@ -1,3 +1,4 @@
1
+ import datetime
1
2
  from collections import abc
2
3
  from typing import Literal, Sequence
3
4
 
@@ -24,7 +25,7 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
24
25
  # String is a Sequence but we take them as an whole
25
26
  if isinstance(element, abc.Sequence) and not isinstance(element, str):
26
27
  can_handle = ListOfBuiltinHandler.can_handle(element)
27
- elif not isinstance(element, (int, float, bool, str)):
28
+ elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
28
29
  can_handle = False
29
30
  break
30
31
  return can_handle
@@ -53,6 +53,8 @@ class DataType(Enum):
53
53
  STRING = ("string", spt.StringType, np.str_)
54
54
  BYTES = ("bytes", spt.BinaryType, np.bytes_)
55
55
 
56
+ TIMESTAMP_NTZ = ("datetime64[ns]", spt.TimestampType, "datetime64[ns]")
57
+
56
58
  def as_snowpark_type(self) -> spt.DataType:
57
59
  """Convert to corresponding Snowpark Type.
58
60
 
@@ -78,6 +80,13 @@ class DataType(Enum):
78
80
  Corresponding DataType.
79
81
  """
80
82
  np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
83
+
84
+ # Add datetime types:
85
+ datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
86
+
87
+ for res in datetime_res:
88
+ np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
89
+
81
90
  for potential_type in np_to_snowml_type_mapping.keys():
82
91
  if np.can_cast(np_type, potential_type, casting="no"):
83
92
  # This is used since the same dtype might represented in different ways.
@@ -247,9 +256,12 @@ class FeatureSpec(BaseFeatureSpec):
247
256
  result_type = spt.ArrayType(result_type)
248
257
  return result_type
249
258
 
250
- def as_dtype(self) -> npt.DTypeLike:
259
+ def as_dtype(self) -> Union[npt.DTypeLike, str]:
251
260
  """Convert to corresponding local Type."""
252
261
  if not self._shape:
262
+ # scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
263
+ if "datetime64" in self._dtype._value:
264
+ return self._dtype._value
253
265
  return self._dtype._numpy_type
254
266
  return np.object_
255
267
 
@@ -147,6 +147,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
147
147
  specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
148
148
  elif isinstance(data[df_col].iloc[0], bytes):
149
149
  specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
150
+ elif isinstance(data[df_col].iloc[0], np.datetime64):
151
+ specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
150
152
  else:
151
153
  specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(df_col_dtype), name=ft_name))
152
154
  return specs
@@ -107,6 +107,9 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
107
107
  if not features:
108
108
  features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input")
109
109
  # Role will be no effect on the column index. That is to say, the feature name is the actual column name.
110
+ if keep_order:
111
+ df = df.reset_index(drop=True)
112
+ df[infer_template._KEEP_ORDER_COL_NAME] = df.index
110
113
  sp_df = session.create_dataframe(df)
111
114
  column_names = []
112
115
  columns = []
@@ -122,7 +125,4 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
122
125
 
123
126
  sp_df = sp_df.with_columns(column_names, columns)
124
127
 
125
- if keep_order:
126
- sp_df = sp_df.with_column(infer_template._KEEP_ORDER_COL_NAME, F.monotonically_increasing_id())
127
-
128
128
  return sp_df
@@ -1,6 +1,6 @@
1
1
  import functools
2
2
  import inspect
3
- from typing import Any, Callable, Coroutine, Dict, Generator, Optional
3
+ from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional
4
4
 
5
5
  import anyio
6
6
  import pandas as pd
@@ -168,7 +168,7 @@ class CustomModel:
168
168
  def _get_infer_methods(
169
169
  self,
170
170
  ) -> Generator[Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame], None, None]:
171
- """Returns all methods in CLS with DECORATOR as the outermost decorator."""
171
+ """Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
172
172
  for cls_method_str in dir(self):
173
173
  cls_method = getattr(self, cls_method_str)
174
174
  if getattr(cls_method, "_is_inference_api", False):
@@ -177,6 +177,18 @@ class CustomModel:
177
177
  else:
178
178
  raise TypeError("A non-method inference API function is not supported.")
179
179
 
180
+ def _get_partitioned_infer_methods(self) -> List[str]:
181
+ """Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
182
+ rv = []
183
+ for cls_method_str in dir(self):
184
+ cls_method = getattr(self, cls_method_str)
185
+ if getattr(cls_method, "_is_partitioned_inference_api", False):
186
+ if inspect.ismethod(cls_method):
187
+ rv.append(cls_method_str)
188
+ else:
189
+ raise TypeError("A non-method inference API function is not supported.")
190
+ return rv
191
+
180
192
 
181
193
  def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
182
194
  """Validate the user provided predict method.
@@ -219,3 +231,11 @@ def inference_api(
219
231
  ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
220
232
  func.__dict__["_is_inference_api"] = True
221
233
  return func
234
+
235
+
236
+ def partitioned_inference_api(
237
+ func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]
238
+ ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
239
+ func.__dict__["_is_inference_api"] = True
240
+ func.__dict__["_is_partitioned_inference_api"] = True
241
+ return func
@@ -168,6 +168,8 @@ def _validate_numpy_array(
168
168
  max_v <= np.finfo(feature_type._numpy_type).max # type: ignore[arg-type]
169
169
  and min_v >= np.finfo(feature_type._numpy_type).min # type: ignore[arg-type]
170
170
  )
171
+ elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
172
+ return np.issubdtype(arr.dtype, np.datetime64)
171
173
  else:
172
174
  return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")
173
175