snowflake_ml_python-1.5.3-py3-none-any.whl → snowflake_ml_python-1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/__init__.py +4 -1
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +281 -21
  4. snowflake/cortex/_extract_answer.py +0 -1
  5. snowflake/cortex/_sentiment.py +0 -1
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +12 -85
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  16. snowflake/ml/_internal/telemetry.py +38 -2
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  20. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  21. snowflake/ml/data/data_connector.py +133 -0
  22. snowflake/ml/data/data_ingestor.py +28 -0
  23. snowflake/ml/data/data_source.py +23 -0
  24. snowflake/ml/dataset/dataset.py +39 -32
  25. snowflake/ml/dataset/dataset_reader.py +18 -118
  26. snowflake/ml/feature_store/access_manager.py +7 -1
  27. snowflake/ml/feature_store/entity.py +19 -2
  28. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  29. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  30. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  31. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  32. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  34. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  35. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  36. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  37. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  38. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  39. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  40. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  43. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  44. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  45. snowflake/ml/feature_store/feature_store.py +987 -264
  46. snowflake/ml/feature_store/feature_view.py +228 -13
  47. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  48. snowflake/ml/fileset/fileset.py +2 -2
  49. snowflake/ml/fileset/snowfs.py +4 -15
  50. snowflake/ml/fileset/stage_fs.py +24 -18
  51. snowflake/ml/lineage/__init__.py +3 -0
  52. snowflake/ml/lineage/lineage_node.py +139 -0
  53. snowflake/ml/model/_client/model/model_impl.py +47 -14
  54. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  55. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  56. snowflake/ml/model/_client/sql/model.py +1 -0
  57. snowflake/ml/model/_client/sql/model_version.py +45 -2
  58. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  59. snowflake/ml/model/_model_composer/model_composer.py +15 -17
  60. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
  61. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  62. snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
  63. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  64. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
  65. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
  66. snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
  67. snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
  68. snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
  69. snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
  70. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  71. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  72. snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
  73. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  74. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  75. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  76. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  77. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  78. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  79. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  80. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  81. snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
  82. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  83. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  84. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  85. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  86. snowflake/ml/model/_packager/model_packager.py +9 -4
  87. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  88. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  89. snowflake/ml/model/custom_model.py +22 -2
  90. snowflake/ml/model/model_signature.py +4 -4
  91. snowflake/ml/model/type_hints.py +77 -4
  92. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
  93. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  94. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  95. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  96. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  97. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  98. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  99. snowflake/ml/modeling/cluster/birch.py +4 -2
  100. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  101. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  102. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  103. snowflake/ml/modeling/cluster/k_means.py +4 -2
  104. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  105. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  106. snowflake/ml/modeling/cluster/optics.py +4 -2
  107. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  108. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  109. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  110. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  111. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  112. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  113. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  114. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  115. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  116. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  117. snowflake/ml/modeling/covariance/oas.py +4 -2
  118. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  119. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  120. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  121. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  122. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  123. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  124. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  125. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  126. snowflake/ml/modeling/decomposition/pca.py +4 -2
  127. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  128. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  129. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  130. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  131. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  132. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  133. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  134. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  135. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  136. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  137. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  138. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  139. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  140. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  141. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  142. snowflake/ml/modeling/manifold/isomap.py +4 -2
  143. snowflake/ml/modeling/manifold/mds.py +4 -2
  144. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  145. snowflake/ml/modeling/manifold/tsne.py +4 -2
  146. snowflake/ml/modeling/metrics/ranking.py +3 -0
  147. snowflake/ml/modeling/metrics/regression.py +3 -0
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  150. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  151. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  152. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  153. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  154. snowflake/ml/modeling/pipeline/pipeline.py +5 -4
  155. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  156. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  157. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  158. snowflake/ml/registry/_manager/model_manager.py +16 -3
  159. snowflake/ml/registry/registry.py +100 -13
  160. snowflake/ml/version.py +1 -1
  161. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
  162. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
  163. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  164. snowflake/ml/_internal/lineage/data_source.py +0 -10
  165. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,22 @@
1
1
  import collections
2
2
  import copy
3
3
  import pathlib
4
+ import warnings
4
5
  from typing import List, Optional, cast
5
6
 
6
7
  import yaml
7
8
 
8
- from snowflake.ml._internal.lineage import data_source
9
+ from snowflake.ml.data import data_source
9
10
  from snowflake.ml.model import type_hints
10
11
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
11
12
  from snowflake.ml.model._model_composer.model_method import (
12
13
  function_generator,
13
14
  model_method,
14
15
  )
15
- from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
16
- from snowflake.snowpark import Session
16
+ from snowflake.ml.model._packager.model_meta import (
17
+ model_meta as model_meta_api,
18
+ model_meta_schema,
19
+ )
17
20
 
18
21
 
19
22
  class ModelManifest:
@@ -33,9 +36,8 @@ class ModelManifest:
33
36
 
34
37
  def save(
35
38
  self,
36
- session: Session,
37
39
  model_meta: model_meta_api.ModelMetadata,
38
- model_file_rel_path: pathlib.PurePosixPath,
40
+ model_rel_path: pathlib.PurePosixPath,
39
41
  options: Optional[type_hints.ModelSaveOption] = None,
40
42
  data_sources: Optional[List[data_source.DataSource]] = None,
41
43
  ) -> None:
@@ -44,10 +46,10 @@ class ModelManifest:
44
46
 
45
47
  runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
46
48
  runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
47
- runtime_to_use.imports.append(model_file_rel_path)
49
+ runtime_to_use.imports.append(str(model_rel_path) + "/")
48
50
  runtime_dict = runtime_to_use.save(self.workspace_path)
49
51
 
50
- self.function_generator = function_generator.FunctionGenerator(model_file_rel_path=model_file_rel_path)
52
+ self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
51
53
  self.methods: List[model_method.ModelMethod] = []
52
54
  for target_method in model_meta.signatures.keys():
53
55
  method = model_method.ModelMethod(
@@ -55,6 +57,9 @@ class ModelManifest:
55
57
  target_method=target_method,
56
58
  runtime_name=self._DEFAULT_RUNTIME_NAME,
57
59
  function_generator=self.function_generator,
60
+ is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
61
+ model_meta_schema.FunctionProperties.PARTITIONED.value, False
62
+ ),
58
63
  options=model_method.get_model_method_options_from_options(options, target_method),
59
64
  )
60
65
 
@@ -69,6 +74,16 @@ class ModelManifest:
69
74
  "In this case, set case_sensitive as True for those methods to distinguish them."
70
75
  )
71
76
 
77
+ dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
78
+ if options.get("include_pip_dependencies"):
79
+ warnings.warn(
80
+ "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
81
+ "be warehouse-compabible. The model may need to be run in SPCS.",
82
+ category=UserWarning,
83
+ stacklevel=1,
84
+ )
85
+ dependencies["pip"] = runtime_dict["dependencies"]["pip"]
86
+
72
87
  manifest_dict = model_manifest_schema.ModelManifestDict(
73
88
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
74
89
  runtimes={
@@ -76,9 +91,7 @@ class ModelManifest:
76
91
  language="PYTHON",
77
92
  version=runtime_to_use.runtime_env.python_version,
78
93
  imports=runtime_dict["imports"],
79
- dependencies=model_manifest_schema.ModelRuntimeDependenciesDict(
80
- conda=runtime_dict["dependencies"]["conda"]
81
- ),
94
+ dependencies=dependencies,
82
95
  )
83
96
  },
84
97
  methods=[
@@ -121,12 +134,13 @@ class ModelManifest:
121
134
  result = []
122
135
  if data_sources:
123
136
  for source in data_sources:
124
- result.append(
125
- model_manifest_schema.LineageSourceDict(
126
- # Currently, we only support lineage from Dataset.
127
- type=model_manifest_schema.LineageSourceTypes.DATASET.value,
128
- entity=source.fully_qualified_name,
129
- version=source.version,
137
+ if isinstance(source, data_source.DatasetInfo):
138
+ result.append(
139
+ model_manifest_schema.LineageSourceDict(
140
+ # Currently, we only support lineage from Dataset.
141
+ type=model_manifest_schema.LineageSourceTypes.DATASET.value,
142
+ entity=source.fully_qualified_name,
143
+ version=source.version,
144
+ )
130
145
  )
131
- )
132
146
  return result
@@ -18,7 +18,8 @@ class ModelMethodFunctionTypes(enum.Enum):
18
18
 
19
19
 
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
- conda: Required[str]
21
+ conda: NotRequired[str]
22
+ pip: NotRequired[str]
22
23
 
23
24
 
24
25
  class ModelRuntimeDict(TypedDict):
@@ -3,7 +3,14 @@ from typing import Optional, TypedDict
3
3
 
4
4
  from typing_extensions import NotRequired
5
5
 
6
+ from snowflake.ml._internal.exceptions import (
7
+ error_codes,
8
+ exceptions as snowml_exceptions,
9
+ )
6
10
  from snowflake.ml.model import type_hints
11
+ from snowflake.ml.model._model_composer.model_manifest.model_manifest_schema import (
12
+ ModelMethodFunctionTypes,
13
+ )
7
14
 
8
15
 
9
16
  class FunctionGenerateOptions(TypedDict):
@@ -26,15 +33,16 @@ class FunctionGenerator:
26
33
 
27
34
  def __init__(
28
35
  self,
29
- model_file_rel_path: pathlib.PurePosixPath,
36
+ model_dir_rel_path: pathlib.PurePosixPath,
30
37
  ) -> None:
31
- self.model_file_rel_path = model_file_rel_path
38
+ self.model_dir_rel_path = model_dir_rel_path
32
39
 
33
40
  def generate(
34
41
  self,
35
42
  function_file_path: pathlib.Path,
36
43
  target_method: str,
37
44
  function_type: str,
45
+ is_partitioned_function: bool = False,
38
46
  options: Optional[FunctionGenerateOptions] = None,
39
47
  ) -> None:
40
48
  import importlib_resources
@@ -42,7 +50,15 @@ class FunctionGenerator:
42
50
  if options is None:
43
51
  options = {}
44
52
 
45
- template_filename = f"infer_{function_type.lower()}.py_template"
53
+ if is_partitioned_function:
54
+ if function_type != ModelMethodFunctionTypes.TABLE_FUNCTION.value:
55
+ raise snowml_exceptions.SnowflakeMLException(
56
+ error_code=error_codes.INVALID_DATA,
57
+ original_exception=ValueError("Partitioned inference api functions must have type TABLE_FUNCTION."),
58
+ )
59
+ template_filename = "infer_partitioned.py_template"
60
+ else:
61
+ template_filename = f"infer_{function_type.lower()}.py_template"
46
62
 
47
63
  function_template = (
48
64
  importlib_resources.files("snowflake.ml.model._model_composer.model_method")
@@ -51,7 +67,7 @@ class FunctionGenerator:
51
67
  )
52
68
 
53
69
  udf_code = function_template.format(
54
- model_file_name=self.model_file_rel_path.name,
70
+ model_dir_name=self.model_dir_rel_path.name,
55
71
  target_method=target_method,
56
72
  max_batch_size=options.get("max_batch_size", None),
57
73
  function_name=FunctionGenerator.FUNCTION_NAME,
@@ -1,12 +1,7 @@
1
- import fcntl
2
1
  import functools
3
2
  import inspect
4
3
  import os
5
4
  import sys
6
- import threading
7
- import zipfile
8
- from types import TracebackType
9
- from typing import Optional, Type
10
5
 
11
6
  import anyio
12
7
  import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
15
10
  from snowflake.ml.model._packager import model_packager
16
11
 
17
12
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
13
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
14
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
15
  TARGET_METHOD = "{target_method}"
35
16
  MAX_BATCH_SIZE = {max_batch_size}
36
17
 
37
-
38
18
  # Retrieve the model
39
19
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
20
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
21
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
22
 
52
23
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
24
+ pk = model_packager.ModelPackager(model_dir_path)
54
25
  pk.load(as_custom_model=True)
55
26
  assert pk.model, "model is not loaded"
56
27
  assert pk.meta, "model metadata is not loaded"
@@ -0,0 +1,55 @@
1
+ import fcntl
2
+ import functools
3
+ import inspect
4
+ import os
5
+ import sys
6
+ import threading
7
+ import zipfile
8
+ from types import TracebackType
9
+ from typing import Optional, Type
10
+
11
+ import anyio
12
+ import pandas as pd
13
+ from _snowflake import vectorized
14
+
15
+ from snowflake.ml.model._packager import model_packager
16
+
17
+
18
+ # User-defined parameters
19
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
20
+ TARGET_METHOD = "{target_method}"
21
+ MAX_BATCH_SIZE = {max_batch_size}
22
+
23
+ # Retrieve the model
24
+ IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
25
+ import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
26
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
27
+
28
+ # Load the model
29
+ pk = model_packager.ModelPackager(model_dir_path)
30
+ pk.load(as_custom_model=True)
31
+ assert pk.model, "model is not loaded"
32
+ assert pk.meta, "model metadata is not loaded"
33
+
34
+ # Determine the actual runner
35
+ model = pk.model
36
+ meta = pk.meta
37
+ func = getattr(model, TARGET_METHOD)
38
+ if inspect.iscoroutinefunction(func):
39
+ runner = functools.partial(anyio.run, func)
40
+ else:
41
+ runner = functools.partial(func)
42
+
43
+ # Determine preprocess parameters
44
+ features = meta.signatures[TARGET_METHOD].inputs
45
+ input_cols = [feature.name for feature in features]
46
+ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
47
+
48
+
49
+ # Actual table function
50
+ class {function_name}:
51
+ @vectorized(input=pd.DataFrame)
52
+ def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
53
+ df.columns = input_cols
54
+ input_df = df.astype(dtype=dtype_map)
55
+ return runner(input_df[input_cols])
@@ -1,12 +1,7 @@
1
- import fcntl
2
1
  import functools
3
2
  import inspect
4
3
  import os
5
4
  import sys
6
- import threading
7
- import zipfile
8
- from types import TracebackType
9
- from typing import Optional, Type
10
5
 
11
6
  import anyio
12
7
  import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
15
10
  from snowflake.ml.model._packager import model_packager
16
11
 
17
12
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
13
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
14
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
15
  TARGET_METHOD = "{target_method}"
35
16
  MAX_BATCH_SIZE = {max_batch_size}
36
17
 
37
-
38
18
  # Retrieve the model
39
19
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
20
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
21
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
22
 
52
23
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
24
+ pk = model_packager.ModelPackager(model_dir_path)
54
25
  pk.load(as_custom_model=True)
55
26
  assert pk.model, "model is not loaded"
56
27
  assert pk.meta, "model metadata is not loaded"
@@ -72,8 +43,8 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
72
43
 
73
44
  # Actual table function
74
45
  class {function_name}:
75
- @vectorized(input=pd.DataFrame)
76
- def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
46
+ @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
47
+ def process(self, df: pd.DataFrame) -> pd.DataFrame:
77
48
  df.columns = input_cols
78
49
  input_df = df.astype(dtype=dtype_map)
79
50
  return runner(input_df[input_cols])
@@ -26,13 +26,14 @@ class ModelMethodOptions(TypedDict):
26
26
  def get_model_method_options_from_options(
27
27
  options: type_hints.ModelSaveOption, target_method: str
28
28
  ) -> ModelMethodOptions:
29
+ default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
30
+ if options.get("enable_explainability", False) and target_method.startswith("explain"):
31
+ default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
29
32
  method_option = options.get("method_options", {}).get(target_method, {})
30
- global_function_type = options.get("function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value)
33
+ global_function_type = options.get("function_type", default_function_type)
31
34
  function_type = method_option.get("function_type", global_function_type)
32
35
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
33
- raise NotImplementedError
34
-
35
- # TODO(TH): enforce minimum snowflake version
36
+ raise NotImplementedError(f"Function type {function_type} is not supported.")
36
37
 
37
38
  return ModelMethodOptions(
38
39
  case_sensitive=method_option.get("case_sensitive", False),
@@ -47,10 +48,9 @@ class ModelMethod:
47
48
  Attributes:
48
49
  model_meta: Model Metadata.
49
50
  target_method: Original target method name to call with the model.
50
- method_name: The actual method name registered in manifest and used in SQL.
51
-
52
- function_generator: Function file generator.
53
51
  runtime_name: Name of the Model Runtime to run the method.
52
+ function_generator: Function file generator.
53
+ is_partitioned_function: Whether the model method function is partitioned.
54
54
 
55
55
  options: Model Method Options.
56
56
  """
@@ -63,11 +63,13 @@ class ModelMethod:
63
63
  target_method: str,
64
64
  runtime_name: str,
65
65
  function_generator: function_generator.FunctionGenerator,
66
+ is_partitioned_function: bool = False,
66
67
  options: Optional[ModelMethodOptions] = None,
67
68
  ) -> None:
68
69
  self.model_meta = model_meta
69
70
  self.target_method = target_method
70
71
  self.function_generator = function_generator
72
+ self.is_partitioned_function = is_partitioned_function
71
73
  self.runtime_name = runtime_name
72
74
  self.options = options or {}
73
75
  try:
@@ -111,6 +113,7 @@ class ModelMethod:
111
113
  workspace_path / ModelMethod.FUNCTIONS_DIR_REL_PATH / f"{self.target_method}.py",
112
114
  self.target_method,
113
115
  self.function_type,
116
+ self.is_partitioned_function,
114
117
  options=options,
115
118
  )
116
119
  input_list = [
@@ -1,4 +1,5 @@
1
1
  from abc import abstractmethod
2
+ from enum import Enum
2
3
  from typing import Dict, Generic, Optional, Protocol, Type, final
3
4
 
4
5
  from typing_extensions import TypeGuard, Unpack
@@ -8,6 +9,15 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
8
9
  from snowflake.ml.model._packager.model_meta import model_meta
9
10
 
10
11
 
12
+ class ModelObjective(Enum):
13
+ # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
14
+ UNKNOWN = "unknown"
15
+ BINARY_CLASSIFICATION = "binary_classification"
16
+ MULTI_CLASSIFICATION = "multi_classification"
17
+ REGRESSION = "regression"
18
+ RANKING = "ranking"
19
+
20
+
11
21
  class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
12
22
  HANDLER_TYPE: model_types.SupportedModelHandlerType
13
23
  HANDLER_VERSION: str
@@ -16,7 +26,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
16
26
 
17
27
  @classmethod
18
28
  @abstractmethod
19
- def can_handle(cls, model: model_types.SupportedDataType) -> TypeGuard[model_types._ModelType]:
29
+ def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard[model_types._ModelType]:
20
30
  """Whether this handler could support the type of the `model`.
21
31
 
22
32
  Args:
@@ -75,7 +85,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
75
85
  name: str,
76
86
  model_meta: model_meta.ModelMetadata,
77
87
  model_blobs_dir_path: str,
78
- **kwargs: Unpack[model_types.ModelLoadOption],
88
+ **kwargs: Unpack[model_types.BaseModelLoadOption],
79
89
  ) -> model_types._ModelType:
80
90
  """Load the model into memory.
81
91
 
@@ -96,7 +106,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
96
106
  cls,
97
107
  raw_model: model_types._ModelType,
98
108
  model_meta: model_meta.ModelMetadata,
99
- **kwargs: Unpack[model_types.ModelLoadOption],
109
+ **kwargs: Unpack[model_types.BaseModelLoadOption],
100
110
  ) -> custom_model.CustomModel:
101
111
  """Create a custom model class wrap for unified interface when being deployed. The predict method will be
102
112
  re-targeted based on target_method metadata.
@@ -1,4 +1,9 @@
1
- from typing import Callable, Iterable, Optional, Sequence, cast
1
+ import json
2
+ from typing import Any, Callable, Iterable, Optional, Sequence, cast
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+ import pandas as pd
2
7
 
3
8
  from snowflake.ml.model import model_signature, type_hints as model_types
4
9
  from snowflake.ml.model._packager.model_meta import model_meta
@@ -36,6 +41,25 @@ def validate_signature(
36
41
  predictions_df = get_prediction_fn(target_method, local_sample_input)
37
42
  sig = model_signature.infer_signature(local_sample_input, predictions_df)
38
43
  model_meta.signatures[target_method] = sig
44
+
45
+ return model_meta
46
+
47
+
48
+ def add_explain_method_signature(
49
+ model_meta: model_meta.ModelMetadata,
50
+ explain_method: str,
51
+ target_method: str,
52
+ output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
53
+ ) -> model_meta.ModelMetadata:
54
+ if target_method not in model_meta.signatures:
55
+ raise ValueError(f"Signature for target method {target_method} is missing")
56
+ inputs = model_meta.signatures[target_method].inputs
57
+ model_meta.signatures[explain_method] = model_signature.ModelSignature(
58
+ inputs=inputs,
59
+ outputs=[
60
+ model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
61
+ ],
62
+ )
39
63
  return model_meta
40
64
 
41
65
 
@@ -55,3 +79,37 @@ def validate_target_methods(model: model_types.SupportedModelType, target_method
55
79
  for method_name in target_methods:
56
80
  if not _is_callable(model, method_name):
57
81
  raise ValueError(f"Target method {method_name} is not callable or does not exist in the model.")
82
+
83
+
84
+ def get_num_classes_if_exists(model: model_types.SupportedModelType) -> int:
85
+ num_classes = getattr(model, "classes_", [])
86
+ return len(num_classes)
87
+
88
+
89
+ def convert_explanations_to_2D_df(
90
+ model: model_types.SupportedModelType, explanations: npt.NDArray[Any]
91
+ ) -> pd.DataFrame:
92
+ if explanations.ndim != 3:
93
+ return pd.DataFrame(explanations)
94
+
95
+ if hasattr(model, "classes_"):
96
+ classes_list = [cl for cl in model.classes_] # type:ignore[union-attr]
97
+ len_classes = len(classes_list)
98
+ if explanations.shape[2] != len_classes:
99
+ raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
100
+ else:
101
+ classes_list = [i for i in range(explanations.shape[2])]
102
+ exp_2d = []
103
+ # TODO (SNOW-1549044): Optimize this
104
+ for row in explanations:
105
+ col_list = []
106
+ for column in row:
107
+ class_explanations = {}
108
+ for cl, cl_exp in zip(classes_list, column):
109
+ if isinstance(cl, (int, np.integer)):
110
+ cl = int(cl)
111
+ class_explanations[cl] = cl_exp
112
+ col_list.append(json.dumps(class_explanations))
113
+ exp_2d.append(col_list)
114
+
115
+ return pd.DataFrame(exp_2d)
@@ -33,6 +33,22 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
33
33
  MODELE_BLOB_FILE_OR_DIR = "model.bin"
34
34
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
35
35
 
36
+ @classmethod
37
+ def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
38
+ import catboost
39
+
40
+ if isinstance(model, catboost.CatBoostClassifier):
41
+ num_classes = handlers_utils.get_num_classes_if_exists(model)
42
+ if num_classes == 2:
43
+ return _base.ModelObjective.BINARY_CLASSIFICATION
44
+ return _base.ModelObjective.MULTI_CLASSIFICATION
45
+ if isinstance(model, catboost.CatBoostRanker):
46
+ return _base.ModelObjective.RANKING
47
+ if isinstance(model, catboost.CatBoostRegressor):
48
+ return _base.ModelObjective.REGRESSION
49
+ # TODO: Find out model type from the generic Catboost Model
50
+ return _base.ModelObjective.UNKNOWN
51
+
36
52
  @classmethod
37
53
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
38
54
  return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
@@ -89,6 +105,16 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
89
105
  sample_input_data=sample_input_data,
90
106
  get_prediction_fn=get_prediction,
91
107
  )
108
+ if kwargs.get("enable_explainability", False):
109
+ output_type = model_signature.DataType.DOUBLE
110
+ if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
111
+ output_type = model_signature.DataType.STRING
112
+ model_meta = handlers_utils.add_explain_method_signature(
113
+ model_meta=model_meta,
114
+ explain_method="explain",
115
+ target_method="predict",
116
+ output_return_type=output_type,
117
+ )
92
118
 
93
119
  model_blob_path = os.path.join(model_blobs_dir_path, name)
94
120
  os.makedirs(model_blob_path, exist_ok=True)
@@ -112,6 +138,11 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
112
138
  ],
113
139
  check_local_version=True,
114
140
  )
141
+ if kwargs.get("enable_explainability", False):
142
+ model_meta.env.include_if_absent(
143
+ [model_env.ModelDependency(requirement="shap", pip_name="shap")],
144
+ check_local_version=True,
145
+ )
115
146
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
116
147
 
117
148
  return None
@@ -122,7 +153,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
122
153
  name: str,
123
154
  model_meta: model_meta_api.ModelMetadata,
124
155
  model_blobs_dir_path: str,
125
- **kwargs: Unpack[model_types.ModelLoadOption],
156
+ **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
126
157
  ) -> "catboost.CatBoost":
127
158
  import catboost
128
159
 
@@ -157,7 +188,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
157
188
  cls,
158
189
  raw_model: "catboost.CatBoost",
159
190
  model_meta: model_meta_api.ModelMetadata,
160
- **kwargs: Unpack[model_types.ModelLoadOption],
191
+ **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
161
192
  ) -> custom_model.CustomModel:
162
193
  import catboost
163
194
 
@@ -186,6 +217,17 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
186
217
 
187
218
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
188
219
 
220
+ @custom_model.inference_api
221
+ def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
222
+ import shap
223
+
224
+ explainer = shap.TreeExplainer(raw_model)
225
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
226
+ return model_signature_utils.rename_pandas_df(df, signature.outputs)
227
+
228
+ if target_method == "explain":
229
+ return explain_fn
230
+
189
231
  return fn
190
232
 
191
233
  type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
@@ -17,6 +17,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
17
17
  from snowflake.ml.model._packager.model_meta import (
18
18
  model_blob_meta,
19
19
  model_meta as model_meta_api,
20
+ model_meta_schema,
20
21
  )
21
22
 
22
23
 
@@ -68,6 +69,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
68
69
  predictions_df = target_method(model, sample_input_data)
69
70
  return predictions_df
70
71
 
72
+ for func_name in model._get_partitioned_infer_methods():
73
+ function_properties = model_meta.function_properties.get(func_name, {})
74
+ function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
75
+ model_meta.function_properties[func_name] = function_properties
76
+
71
77
  if not is_sub_model:
72
78
  model_meta = handlers_utils.validate_signature(
73
79
  model=model,
@@ -101,14 +107,16 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
101
107
 
102
108
  # Make sure that the module where the model is defined get pickled by value as well.
103
109
  cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
104
- picked_obj = (model.__class__, model.context)
110
+ pickled_obj = (model.__class__, model.context)
105
111
  with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
106
- cloudpickle.dump(picked_obj, f)
112
+ cloudpickle.dump(pickled_obj, f)
113
+ # model meta will be saved by the context manager
107
114
  model_meta.models[name] = model_blob_meta.ModelBlobMeta(
108
115
  name=name,
109
116
  model_type=cls.HANDLER_TYPE,
110
117
  path=cls.MODELE_BLOB_FILE_OR_DIR,
111
118
  handler_version=cls.HANDLER_VERSION,
119
+ function_properties=model_meta.function_properties,
112
120
  artifacts={
113
121
  name: pathlib.Path(
114
122
  os.path.join(cls.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(path=uri)))
@@ -128,7 +136,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
128
136
  name: str,
129
137
  model_meta: model_meta_api.ModelMetadata,
130
138
  model_blobs_dir_path: str,
131
- **kwargs: Unpack[model_types.ModelLoadOption],
139
+ **kwargs: Unpack[model_types.CustomModelLoadOption],
132
140
  ) -> "custom_model.CustomModel":
133
141
  model_blob_path = os.path.join(model_blobs_dir_path, name)
134
142
 
@@ -175,6 +183,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
175
183
  cls,
176
184
  raw_model: custom_model.CustomModel,
177
185
  model_meta: model_meta_api.ModelMetadata,
178
- **kwargs: Unpack[model_types.ModelLoadOption],
186
+ **kwargs: Unpack[model_types.CustomModelLoadOption],
179
187
  ) -> custom_model.CustomModel:
180
188
  return raw_model