snowflake-ml-python 1.5.3__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131):
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +224 -21
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_summarize.py +0 -1
  6. snowflake/cortex/_translate.py +0 -1
  7. snowflake/cortex/_util.py +12 -85
  8. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  9. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  10. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  11. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  12. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  13. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  14. snowflake/ml/_internal/telemetry.py +26 -0
  15. snowflake/ml/_internal/utils/identifier.py +14 -0
  16. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  17. snowflake/ml/dataset/dataset.py +39 -20
  18. snowflake/ml/feature_store/feature_store.py +440 -243
  19. snowflake/ml/feature_store/feature_view.py +61 -9
  20. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  21. snowflake/ml/fileset/fileset.py +2 -2
  22. snowflake/ml/fileset/snowfs.py +4 -15
  23. snowflake/ml/fileset/stage_fs.py +6 -8
  24. snowflake/ml/lineage/__init__.py +3 -0
  25. snowflake/ml/lineage/lineage_node.py +139 -0
  26. snowflake/ml/model/_client/model/model_impl.py +47 -14
  27. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  28. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +45 -2
  31. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  32. snowflake/ml/model/_model_composer/model_composer.py +5 -4
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  34. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  35. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  36. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -2
  37. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  38. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  39. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  41. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  42. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  43. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  45. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  46. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  53. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  54. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  56. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  57. snowflake/ml/model/_packager/model_packager.py +9 -4
  58. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  59. snowflake/ml/model/custom_model.py +22 -2
  60. snowflake/ml/model/type_hints.py +73 -4
  61. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -0
  62. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  63. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  64. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  65. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  66. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  67. snowflake/ml/modeling/cluster/birch.py +4 -2
  68. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  69. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  70. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  71. snowflake/ml/modeling/cluster/k_means.py +4 -2
  72. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  74. snowflake/ml/modeling/cluster/optics.py +4 -2
  75. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  76. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  77. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  78. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  79. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  80. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  81. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  82. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  83. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  84. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  85. snowflake/ml/modeling/covariance/oas.py +4 -2
  86. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  87. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  88. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  89. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  90. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  91. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  92. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  93. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  94. snowflake/ml/modeling/decomposition/pca.py +4 -2
  95. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  96. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  97. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  98. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  99. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  100. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  101. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  102. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  103. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  104. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  105. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  106. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  107. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  108. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  109. snowflake/ml/modeling/manifold/isomap.py +4 -2
  110. snowflake/ml/modeling/manifold/mds.py +4 -2
  111. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  112. snowflake/ml/modeling/manifold/tsne.py +4 -2
  113. snowflake/ml/modeling/metrics/ranking.py +3 -0
  114. snowflake/ml/modeling/metrics/regression.py +3 -0
  115. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  116. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  117. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  118. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  119. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  120. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  121. snowflake/ml/modeling/pipeline/pipeline.py +1 -0
  122. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  123. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  124. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  125. snowflake/ml/registry/_manager/model_manager.py +16 -3
  126. snowflake/ml/version.py +1 -1
  127. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +35 -7
  128. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/RECORD +131 -127
  129. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  130. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  131. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,13 @@ import pandas as pd
8
8
 
9
9
  from snowflake.ml._internal import telemetry
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
+ from snowflake.ml.lineage import lineage_node
11
12
  from snowflake.ml.model import type_hints as model_types
12
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops
13
14
  from snowflake.ml.model._model_composer import model_composer
14
15
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
15
16
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
16
- from snowflake.snowpark import dataframe
17
+ from snowflake.snowpark import Session, dataframe
17
18
 
18
19
  _TELEMETRY_PROJECT = "MLOps"
19
20
  _TELEMETRY_SUBPROJECT = "ModelManagement"
@@ -24,7 +25,7 @@ class ExportMode(enum.Enum):
24
25
  FULL = "full"
25
26
 
26
27
 
27
- class ModelVersion:
28
+ class ModelVersion(lineage_node.LineageNode):
28
29
  """Model Version Object representing a specific version of the model that could be run."""
29
30
 
30
31
  _model_ops: model_ops.ModelOperator
@@ -48,6 +49,15 @@ class ModelVersion:
48
49
  self._model_name = model_name
49
50
  self._version_name = version_name
50
51
  self._functions = self._get_functions()
52
+ super(cls, cls).__init__(
53
+ self,
54
+ session=model_ops._session,
55
+ name=model_ops._model_client.fully_qualified_object_name(
56
+ database_name=None, schema_name=None, object_name=model_name
57
+ ),
58
+ domain="model",
59
+ version=version_name,
60
+ )
51
61
  return self
52
62
 
53
63
  def __eq__(self, __value: object) -> bool:
@@ -59,6 +69,11 @@ class ModelVersion:
59
69
  and self._version_name == __value._version_name
60
70
  )
61
71
 
72
+ def __repr__(self) -> str:
73
+ return (
74
+ f"{self.__class__.__name__}(\n" f" name='{self.model_name}',\n" f" version='{self._version_name}',\n" f")"
75
+ )
76
+
62
77
  @property
63
78
  def model_name(self) -> str:
64
79
  """Return the name of the model to which the model version belongs, usable as a reference in SQL."""
@@ -198,6 +213,52 @@ class ModelVersion:
198
213
  statement_params=statement_params,
199
214
  )
200
215
 
216
+ @telemetry.send_api_usage_telemetry(
217
+ project=_TELEMETRY_PROJECT,
218
+ subproject=_TELEMETRY_SUBPROJECT,
219
+ )
220
+ def set_alias(self, alias_name: str) -> None:
221
+ """Set alias to a model version.
222
+
223
+ Args:
224
+ alias_name: Alias to the model version.
225
+ """
226
+ statement_params = telemetry.get_statement_params(
227
+ project=_TELEMETRY_PROJECT,
228
+ subproject=_TELEMETRY_SUBPROJECT,
229
+ )
230
+ alias_name = sql_identifier.SqlIdentifier(alias_name)
231
+ self._model_ops.set_alias(
232
+ alias_name=alias_name,
233
+ database_name=None,
234
+ schema_name=None,
235
+ model_name=self._model_name,
236
+ version_name=self._version_name,
237
+ statement_params=statement_params,
238
+ )
239
+
240
+ @telemetry.send_api_usage_telemetry(
241
+ project=_TELEMETRY_PROJECT,
242
+ subproject=_TELEMETRY_SUBPROJECT,
243
+ )
244
+ def unset_alias(self, version_or_alias: str) -> None:
245
+ """unset alias to a model version.
246
+
247
+ Args:
248
+ version_or_alias: The name of the version or alias to a version.
249
+ """
250
+ statement_params = telemetry.get_statement_params(
251
+ project=_TELEMETRY_PROJECT,
252
+ subproject=_TELEMETRY_SUBPROJECT,
253
+ )
254
+ self._model_ops.unset_alias(
255
+ version_or_alias_name=sql_identifier.SqlIdentifier(version_or_alias),
256
+ database_name=None,
257
+ schema_name=None,
258
+ model_name=self._model_name,
259
+ statement_params=statement_params,
260
+ )
261
+
201
262
  @telemetry.send_api_usage_telemetry(
202
263
  project=_TELEMETRY_PROJECT,
203
264
  subproject=_TELEMETRY_SUBPROJECT,
@@ -451,3 +512,22 @@ class ModelVersion:
451
512
  f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}"
452
513
  )
453
514
  return pk.model
515
+
516
+ @staticmethod
517
+ def _load_from_lineage_node(session: Session, name: str, version: str) -> "ModelVersion":
518
+ database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(name)
519
+ if not database_name_id or not schema_name_id:
520
+ raise ValueError("name should be fully qualifed.")
521
+
522
+ return ModelVersion._ref(
523
+ model_ops.ModelOperator(
524
+ session,
525
+ database_name=database_name_id,
526
+ schema_name=schema_name_id,
527
+ ),
528
+ model_name=model_name_id,
529
+ version_name=sql_identifier.SqlIdentifier(version),
530
+ )
531
+
532
+
533
+ lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -1,11 +1,12 @@
1
1
  import os
2
2
  import pathlib
3
3
  import tempfile
4
+ import warnings
4
5
  from typing import Any, Dict, List, Literal, Optional, Union, cast
5
6
 
6
7
  import yaml
7
8
 
8
- from snowflake.ml._internal.utils import identifier, sql_identifier
9
+ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
9
10
  from snowflake.ml.model import model_signature, type_hints
10
11
  from snowflake.ml.model._client.ops import metadata_ops
11
12
  from snowflake.ml.model._client.sql import (
@@ -311,6 +312,42 @@ class ModelOperator:
311
312
  statement_params=statement_params,
312
313
  )
313
314
 
315
+ def set_alias(
316
+ self,
317
+ *,
318
+ alias_name: sql_identifier.SqlIdentifier,
319
+ database_name: Optional[sql_identifier.SqlIdentifier],
320
+ schema_name: Optional[sql_identifier.SqlIdentifier],
321
+ model_name: sql_identifier.SqlIdentifier,
322
+ version_name: sql_identifier.SqlIdentifier,
323
+ statement_params: Optional[Dict[str, Any]] = None,
324
+ ) -> None:
325
+ self._model_version_client.set_alias(
326
+ alias_name=alias_name,
327
+ database_name=database_name,
328
+ schema_name=schema_name,
329
+ model_name=model_name,
330
+ version_name=version_name,
331
+ statement_params=statement_params,
332
+ )
333
+
334
+ def unset_alias(
335
+ self,
336
+ *,
337
+ version_or_alias_name: sql_identifier.SqlIdentifier,
338
+ database_name: Optional[sql_identifier.SqlIdentifier],
339
+ schema_name: Optional[sql_identifier.SqlIdentifier],
340
+ model_name: sql_identifier.SqlIdentifier,
341
+ statement_params: Optional[Dict[str, Any]] = None,
342
+ ) -> None:
343
+ self._model_version_client.unset_alias(
344
+ database_name=database_name,
345
+ schema_name=schema_name,
346
+ model_name=model_name,
347
+ version_or_alias_name=version_or_alias_name,
348
+ statement_params=statement_params,
349
+ )
350
+
314
351
  def set_default_version(
315
352
  self,
316
353
  *,
@@ -354,6 +391,28 @@ class ModelOperator:
354
391
  res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
355
392
  )
356
393
 
394
+ def get_version_by_alias(
395
+ self,
396
+ *,
397
+ database_name: Optional[sql_identifier.SqlIdentifier],
398
+ schema_name: Optional[sql_identifier.SqlIdentifier],
399
+ model_name: sql_identifier.SqlIdentifier,
400
+ alias_name: sql_identifier.SqlIdentifier,
401
+ statement_params: Optional[Dict[str, Any]] = None,
402
+ ) -> Optional[sql_identifier.SqlIdentifier]:
403
+ res = self._model_client.show_versions(
404
+ database_name=database_name,
405
+ schema_name=schema_name,
406
+ model_name=model_name,
407
+ statement_params=statement_params,
408
+ )
409
+ for r in res:
410
+ if alias_name in r[self._model_client.MODEL_VERSION_ALIASES_COL_NAME]:
411
+ return sql_identifier.SqlIdentifier(
412
+ r[self._model_client.MODEL_VERSION_NAME_COL_NAME], case_sensitive=True
413
+ )
414
+ return None
415
+
357
416
  def get_tag_value(
358
417
  self,
359
418
  *,
@@ -625,10 +684,23 @@ class ModelOperator:
625
684
  )
626
685
 
627
686
  if keep_order:
628
- df_res = df_res.sort(
629
- "_ID",
630
- ascending=True,
631
- )
687
+ # if it's a partitioned table function, _ID will be null and we won't be able to sort.
688
+ if df_res.select("_ID").limit(1).collect()[0][0] is None:
689
+ warnings.warn(
690
+ formatting.unwrap(
691
+ """
692
+ When invoking partitioned inference methods, ordering of rows in output dataframe will differ
693
+ from that of input dataframe.
694
+ """
695
+ ),
696
+ category=UserWarning,
697
+ stacklevel=1,
698
+ )
699
+ else:
700
+ df_res = df_res.sort(
701
+ "_ID",
702
+ ascending=True,
703
+ )
632
704
 
633
705
  if not output_with_input_features:
634
706
  cols_to_drop = original_cols
@@ -14,6 +14,7 @@ class ModelSQLClient(_base._BaseSQLClient):
14
14
  MODEL_VERSION_COMMENT_COL_NAME = "comment"
15
15
  MODEL_VERSION_METADATA_COL_NAME = "metadata"
16
16
  MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
17
+ MODEL_VERSION_ALIASES_COL_NAME = "aliases"
17
18
 
18
19
  def show_models(
19
20
  self,
@@ -134,6 +134,43 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
134
134
  statement_params=statement_params,
135
135
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
136
136
 
137
+ def set_alias(
138
+ self,
139
+ *,
140
+ database_name: Optional[sql_identifier.SqlIdentifier],
141
+ schema_name: Optional[sql_identifier.SqlIdentifier],
142
+ model_name: sql_identifier.SqlIdentifier,
143
+ version_name: sql_identifier.SqlIdentifier,
144
+ alias_name: sql_identifier.SqlIdentifier,
145
+ statement_params: Optional[Dict[str, Any]] = None,
146
+ ) -> None:
147
+ query_result_checker.SqlResultValidator(
148
+ self._session,
149
+ (
150
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
151
+ f"VERSION {version_name.identifier()} SET ALIAS = {alias_name.identifier()}"
152
+ ),
153
+ statement_params=statement_params,
154
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
155
+
156
+ def unset_alias(
157
+ self,
158
+ *,
159
+ database_name: Optional[sql_identifier.SqlIdentifier],
160
+ schema_name: Optional[sql_identifier.SqlIdentifier],
161
+ model_name: sql_identifier.SqlIdentifier,
162
+ version_or_alias_name: sql_identifier.SqlIdentifier,
163
+ statement_params: Optional[Dict[str, Any]] = None,
164
+ ) -> None:
165
+ query_result_checker.SqlResultValidator(
166
+ self._session,
167
+ (
168
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
169
+ f"VERSION {version_or_alias_name.identifier()} UNSET ALIAS"
170
+ ),
171
+ statement_params=statement_params,
172
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
173
+
137
174
  def list_file(
138
175
  self,
139
176
  *,
@@ -383,9 +420,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
383
420
  # Prepare the output
384
421
  output_cols = []
385
422
  output_names = []
423
+ cols_to_drop = []
386
424
 
387
425
  for output_name, output_type, output_col_name in returns:
388
- output_cols.append(F.col(output_name).astype(output_type))
426
+ output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
427
+ if output_identifier != output_col_name:
428
+ cols_to_drop.append(output_identifier)
429
+ output_cols.append(F.col(output_identifier).astype(output_type))
389
430
  output_names.append(output_col_name)
390
431
 
391
432
  if partition_column is not None:
@@ -396,10 +437,12 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
396
437
  col_names=output_names,
397
438
  values=output_cols,
398
439
  )
399
-
400
440
  if statement_params:
401
441
  output_df._statement_params = statement_params # type: ignore[assignment]
402
442
 
443
+ if cols_to_drop:
444
+ output_df = output_df.drop(cols_to_drop)
445
+
403
446
  return output_df
404
447
 
405
448
  def set_metadata(
@@ -101,7 +101,6 @@ def _run_setup() -> None:
101
101
  logger.info(f"Loading model from {extracted_dir} into memory")
102
102
 
103
103
  sys.path.insert(0, os.path.join(extracted_dir, _MODEL_CODE_DIR))
104
- from snowflake.ml.model import type_hints as model_types
105
104
 
106
105
  # TODO (Server-side Model Rollout):
107
106
  # Keep try block only
@@ -114,7 +113,7 @@ def _run_setup() -> None:
114
113
  pk.load(
115
114
  as_custom_model=True,
116
115
  meta_only=False,
117
- options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
116
+ options={"use_gpu": use_gpu},
118
117
  )
119
118
  _LOADED_MODEL = pk.model
120
119
  _LOADED_META = pk.meta
@@ -132,7 +131,7 @@ def _run_setup() -> None:
132
131
  _LOADED_MODEL, meta_LOADED_META = model_api._load(
133
132
  local_dir_path=extracted_dir,
134
133
  as_custom_model=True,
135
- options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
134
+ options={"use_gpu": use_gpu},
136
135
  )
137
136
  _MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED
138
137
  logger.info("Successfully loaded model into memory")
@@ -15,6 +15,7 @@ from snowflake.ml._internal.lineage import data_source, lineage_utils
15
15
  from snowflake.ml.model import model_signature, type_hints as model_types
16
16
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
17
17
  from snowflake.ml.model._packager import model_packager
18
+ from snowflake.ml.model._packager.model_meta import model_meta
18
19
  from snowflake.snowpark import Session
19
20
  from snowflake.snowpark._internal import utils as snowpark_utils
20
21
 
@@ -90,7 +91,7 @@ class ModelComposer:
90
91
  ext_modules: Optional[List[ModuleType]] = None,
91
92
  code_paths: Optional[List[str]] = None,
92
93
  options: Optional[model_types.ModelSaveOption] = None,
93
- ) -> None:
94
+ ) -> model_meta.ModelMetadata:
94
95
  if not options:
95
96
  options = model_types.BaseModelSaveOption()
96
97
 
@@ -106,7 +107,7 @@ class ModelComposer:
106
107
  )
107
108
  options["embed_local_ml_library"] = True
108
109
 
109
- self.packager.save(
110
+ model_metadata: model_meta.ModelMetadata = self.packager.save(
110
111
  name=name,
111
112
  model=model,
112
113
  signatures=signatures,
@@ -119,7 +120,6 @@ class ModelComposer:
119
120
  code_paths=code_paths,
120
121
  options=options,
121
122
  )
122
-
123
123
  assert self.packager.meta is not None
124
124
 
125
125
  if not options.get("_legacy_save", False):
@@ -133,7 +133,7 @@ class ModelComposer:
133
133
 
134
134
  self.manifest.save(
135
135
  session=self.session,
136
- model_meta=self.packager.meta,
136
+ model_meta=model_metadata,
137
137
  model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
138
138
  options=options,
139
139
  data_sources=self._get_data_sources(model, sample_input_data),
@@ -145,6 +145,7 @@ class ModelComposer:
145
145
  stage_path=self.stage_path,
146
146
  statement_params=self._statement_params,
147
147
  )
148
+ return model_metadata
148
149
 
149
150
  @deprecated("Only used by PrPr model registry. Use static method version of load instead.")
150
151
  def legacy_load(
@@ -12,7 +12,10 @@ from snowflake.ml.model._model_composer.model_method import (
12
12
  function_generator,
13
13
  model_method,
14
14
  )
15
- from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
15
+ from snowflake.ml.model._packager.model_meta import (
16
+ model_meta as model_meta_api,
17
+ model_meta_schema,
18
+ )
16
19
  from snowflake.snowpark import Session
17
20
 
18
21
 
@@ -55,6 +58,9 @@ class ModelManifest:
55
58
  target_method=target_method,
56
59
  runtime_name=self._DEFAULT_RUNTIME_NAME,
57
60
  function_generator=self.function_generator,
61
+ is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
62
+ model_meta_schema.FunctionProperties.PARTITIONED.value, False
63
+ ),
58
64
  options=model_method.get_model_method_options_from_options(options, target_method),
59
65
  )
60
66
 
@@ -3,7 +3,14 @@ from typing import Optional, TypedDict
3
3
 
4
4
  from typing_extensions import NotRequired
5
5
 
6
+ from snowflake.ml._internal.exceptions import (
7
+ error_codes,
8
+ exceptions as snowml_exceptions,
9
+ )
6
10
  from snowflake.ml.model import type_hints
11
+ from snowflake.ml.model._model_composer.model_manifest.model_manifest_schema import (
12
+ ModelMethodFunctionTypes,
13
+ )
7
14
 
8
15
 
9
16
  class FunctionGenerateOptions(TypedDict):
@@ -35,6 +42,7 @@ class FunctionGenerator:
35
42
  function_file_path: pathlib.Path,
36
43
  target_method: str,
37
44
  function_type: str,
45
+ is_partitioned_function: bool = False,
38
46
  options: Optional[FunctionGenerateOptions] = None,
39
47
  ) -> None:
40
48
  import importlib_resources
@@ -42,7 +50,15 @@ class FunctionGenerator:
42
50
  if options is None:
43
51
  options = {}
44
52
 
45
- template_filename = f"infer_{function_type.lower()}.py_template"
53
+ if is_partitioned_function:
54
+ if function_type != ModelMethodFunctionTypes.TABLE_FUNCTION.value:
55
+ raise snowml_exceptions.SnowflakeMLException(
56
+ error_code=error_codes.INVALID_DATA,
57
+ original_exception=ValueError("Partitioned inference api functions must have type TABLE_FUNCTION."),
58
+ )
59
+ template_filename = "infer_partitioned.py_template"
60
+ else:
61
+ template_filename = f"infer_{function_type.lower()}.py_template"
46
62
 
47
63
  function_template = (
48
64
  importlib_resources.files("snowflake.ml.model._model_composer.model_method")
@@ -0,0 +1,79 @@
1
+ import fcntl
2
+ import functools
3
+ import inspect
4
+ import os
5
+ import sys
6
+ import threading
7
+ import zipfile
8
+ from types import TracebackType
9
+ from typing import Optional, Type
10
+
11
+ import anyio
12
+ import pandas as pd
13
+ from _snowflake import vectorized
14
+
15
+ from snowflake.ml.model._packager import model_packager
16
+
17
+
18
+ class FileLock:
19
+ def __enter__(self) -> None:
20
+ self._lock = threading.Lock()
21
+ self._lock.acquire()
22
+ self._fd = open("/tmp/lockfile.LOCK", "w+")
23
+ fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
+
25
+ def __exit__(
26
+ self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
+ ) -> None:
28
+ self._fd.close()
29
+ self._lock.release()
30
+
31
+
32
+ # User-defined parameters
33
+ MODEL_FILE_NAME = "{model_file_name}"
34
+ TARGET_METHOD = "{target_method}"
35
+ MAX_BATCH_SIZE = {max_batch_size}
36
+
37
+
38
+ # Retrieve the model
39
+ IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
+ import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
+
42
+ model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
+ zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
+ extracted = "/tmp/models"
45
+ extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
+
47
+ with FileLock():
48
+ if not os.path.isdir(extracted_model_dir_path):
49
+ with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
+ myzip.extractall(extracted_model_dir_path)
51
+
52
+ # Load the model
53
+ pk = model_packager.ModelPackager(extracted_model_dir_path)
54
+ pk.load(as_custom_model=True)
55
+ assert pk.model, "model is not loaded"
56
+ assert pk.meta, "model metadata is not loaded"
57
+
58
+ # Determine the actual runner
59
+ model = pk.model
60
+ meta = pk.meta
61
+ func = getattr(model, TARGET_METHOD)
62
+ if inspect.iscoroutinefunction(func):
63
+ runner = functools.partial(anyio.run, func)
64
+ else:
65
+ runner = functools.partial(func)
66
+
67
+ # Determine preprocess parameters
68
+ features = meta.signatures[TARGET_METHOD].inputs
69
+ input_cols = [feature.name for feature in features]
70
+ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
71
+
72
+
73
+ # Actual table function
74
+ class {function_name}:
75
+ @vectorized(input=pd.DataFrame)
76
+ def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
77
+ df.columns = input_cols
78
+ input_df = df.astype(dtype=dtype_map)
79
+ return runner(input_df[input_cols])
@@ -72,8 +72,8 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
72
72
 
73
73
  # Actual table function
74
74
  class {function_name}:
75
- @vectorized(input=pd.DataFrame)
76
- def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
75
+ @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
76
+ def process(self, df: pd.DataFrame) -> pd.DataFrame:
77
77
  df.columns = input_cols
78
78
  input_df = df.astype(dtype=dtype_map)
79
79
  return runner(input_df[input_cols])
@@ -32,8 +32,6 @@ def get_model_method_options_from_options(
32
32
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
33
33
  raise NotImplementedError
34
34
 
35
- # TODO(TH): enforce minimum snowflake version
36
-
37
35
  return ModelMethodOptions(
38
36
  case_sensitive=method_option.get("case_sensitive", False),
39
37
  function_type=function_type,
@@ -47,10 +45,9 @@ class ModelMethod:
47
45
  Attributes:
48
46
  model_meta: Model Metadata.
49
47
  target_method: Original target method name to call with the model.
50
- method_name: The actual method name registered in manifest and used in SQL.
51
-
52
- function_generator: Function file generator.
53
48
  runtime_name: Name of the Model Runtime to run the method.
49
+ function_generator: Function file generator.
50
+ is_partitioned_function: Whether the model method function is partitioned.
54
51
 
55
52
  options: Model Method Options.
56
53
  """
@@ -63,11 +60,13 @@ class ModelMethod:
63
60
  target_method: str,
64
61
  runtime_name: str,
65
62
  function_generator: function_generator.FunctionGenerator,
63
+ is_partitioned_function: bool = False,
66
64
  options: Optional[ModelMethodOptions] = None,
67
65
  ) -> None:
68
66
  self.model_meta = model_meta
69
67
  self.target_method = target_method
70
68
  self.function_generator = function_generator
69
+ self.is_partitioned_function = is_partitioned_function
71
70
  self.runtime_name = runtime_name
72
71
  self.options = options or {}
73
72
  try:
@@ -111,6 +110,7 @@ class ModelMethod:
111
110
  workspace_path / ModelMethod.FUNCTIONS_DIR_REL_PATH / f"{self.target_method}.py",
112
111
  self.target_method,
113
112
  self.function_type,
113
+ self.is_partitioned_function,
114
114
  options=options,
115
115
  )
116
116
  input_list = [
@@ -75,7 +75,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
75
75
  name: str,
76
76
  model_meta: model_meta.ModelMetadata,
77
77
  model_blobs_dir_path: str,
78
- **kwargs: Unpack[model_types.ModelLoadOption],
78
+ **kwargs: Unpack[model_types.BaseModelLoadOption],
79
79
  ) -> model_types._ModelType:
80
80
  """Load the model into memory.
81
81
 
@@ -96,7 +96,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
96
96
  cls,
97
97
  raw_model: model_types._ModelType,
98
98
  model_meta: model_meta.ModelMetadata,
99
- **kwargs: Unpack[model_types.ModelLoadOption],
99
+ **kwargs: Unpack[model_types.BaseModelLoadOption],
100
100
  ) -> custom_model.CustomModel:
101
101
  """Create a custom model class wrap for unified interface when being deployed. The predict method will be
102
102
  re-targeted based on target_method metadata.
@@ -36,6 +36,7 @@ def validate_signature(
36
36
  predictions_df = get_prediction_fn(target_method, local_sample_input)
37
37
  sig = model_signature.infer_signature(local_sample_input, predictions_df)
38
38
  model_meta.signatures[target_method] = sig
39
+
39
40
  return model_meta
40
41
 
41
42
 
@@ -122,7 +122,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
122
122
  name: str,
123
123
  model_meta: model_meta_api.ModelMetadata,
124
124
  model_blobs_dir_path: str,
125
- **kwargs: Unpack[model_types.ModelLoadOption],
125
+ **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
126
126
  ) -> "catboost.CatBoost":
127
127
  import catboost
128
128
 
@@ -157,7 +157,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
157
157
  cls,
158
158
  raw_model: "catboost.CatBoost",
159
159
  model_meta: model_meta_api.ModelMetadata,
160
- **kwargs: Unpack[model_types.ModelLoadOption],
160
+ **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
161
161
  ) -> custom_model.CustomModel:
162
162
  import catboost
163
163