snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. snowflake/cortex/__init__.py +4 -1
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +281 -21
  4. snowflake/cortex/_extract_answer.py +0 -1
  5. snowflake/cortex/_sentiment.py +0 -1
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +12 -85
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  16. snowflake/ml/_internal/telemetry.py +38 -2
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  20. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  21. snowflake/ml/data/data_connector.py +133 -0
  22. snowflake/ml/data/data_ingestor.py +28 -0
  23. snowflake/ml/data/data_source.py +23 -0
  24. snowflake/ml/dataset/dataset.py +39 -32
  25. snowflake/ml/dataset/dataset_reader.py +18 -118
  26. snowflake/ml/feature_store/access_manager.py +7 -1
  27. snowflake/ml/feature_store/entity.py +19 -2
  28. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  29. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  30. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  31. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  32. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  34. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  35. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  36. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  37. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  38. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  39. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  40. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  43. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  44. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  45. snowflake/ml/feature_store/feature_store.py +987 -264
  46. snowflake/ml/feature_store/feature_view.py +228 -13
  47. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  48. snowflake/ml/fileset/fileset.py +2 -2
  49. snowflake/ml/fileset/snowfs.py +4 -15
  50. snowflake/ml/fileset/stage_fs.py +24 -18
  51. snowflake/ml/lineage/__init__.py +3 -0
  52. snowflake/ml/lineage/lineage_node.py +139 -0
  53. snowflake/ml/model/_client/model/model_impl.py +47 -14
  54. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  55. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  56. snowflake/ml/model/_client/sql/model.py +1 -0
  57. snowflake/ml/model/_client/sql/model_version.py +45 -2
  58. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  59. snowflake/ml/model/_model_composer/model_composer.py +15 -17
  60. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
  61. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  62. snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
  63. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  64. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
  65. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
  66. snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
  67. snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
  68. snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
  69. snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
  70. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  71. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  72. snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
  73. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  74. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  75. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  76. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  77. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  78. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  79. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  80. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  81. snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
  82. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  83. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  84. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  85. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  86. snowflake/ml/model/_packager/model_packager.py +9 -4
  87. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  88. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  89. snowflake/ml/model/custom_model.py +22 -2
  90. snowflake/ml/model/model_signature.py +4 -4
  91. snowflake/ml/model/type_hints.py +77 -4
  92. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
  93. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  94. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  95. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  96. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  97. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  98. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  99. snowflake/ml/modeling/cluster/birch.py +4 -2
  100. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  101. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  102. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  103. snowflake/ml/modeling/cluster/k_means.py +4 -2
  104. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  105. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  106. snowflake/ml/modeling/cluster/optics.py +4 -2
  107. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  108. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  109. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  110. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  111. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  112. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  113. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  114. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  115. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  116. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  117. snowflake/ml/modeling/covariance/oas.py +4 -2
  118. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  119. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  120. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  121. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  122. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  123. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  124. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  125. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  126. snowflake/ml/modeling/decomposition/pca.py +4 -2
  127. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  128. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  129. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  130. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  131. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  132. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  133. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  134. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  135. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  136. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  137. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  138. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  139. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  140. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  141. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  142. snowflake/ml/modeling/manifold/isomap.py +4 -2
  143. snowflake/ml/modeling/manifold/mds.py +4 -2
  144. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  145. snowflake/ml/modeling/manifold/tsne.py +4 -2
  146. snowflake/ml/modeling/metrics/ranking.py +3 -0
  147. snowflake/ml/modeling/metrics/regression.py +3 -0
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  150. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  151. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  152. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  153. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  154. snowflake/ml/modeling/pipeline/pipeline.py +5 -4
  155. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  156. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  157. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  158. snowflake/ml/registry/_manager/model_manager.py +16 -3
  159. snowflake/ml/registry/registry.py +100 -13
  160. snowflake/ml/version.py +1 -1
  161. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
  162. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
  163. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  164. snowflake/ml/_internal/lineage/data_source.py +0 -10
  165. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,7 @@ class ModelRuntime:
35
35
  self,
36
36
  name: str,
37
37
  env: model_env.ModelEnv,
38
- imports: Optional[List[pathlib.PurePosixPath]] = None,
38
+ imports: Optional[List[str]] = None,
39
39
  is_gpu: bool = False,
40
40
  loading_from_file: bool = False,
41
41
  ) -> None:
@@ -75,7 +75,7 @@ class ModelRuntime:
75
75
  snowpark_ml_lib_path = runtime_base_path / "snowflake-ml-python.zip"
76
76
  file_utils.zip_python_package(str(snowpark_ml_lib_path), "snowflake.ml")
77
77
  snowpark_ml_lib_rel_path = pathlib.PurePosixPath(snowpark_ml_lib_path.relative_to(packager_path).as_posix())
78
- self.imports.append(snowpark_ml_lib_rel_path)
78
+ self.imports.append(str(snowpark_ml_lib_rel_path))
79
79
 
80
80
  self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
81
81
  self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
@@ -108,6 +108,4 @@ class ModelRuntime:
108
108
  warnings.simplefilter("ignore")
109
109
  env.load_from_conda_file(packager_path / conda_env_rel_path)
110
110
  env.load_from_pip_file(packager_path / pip_requirements_rel_path)
111
- return ModelRuntime(
112
- name=name, env=env, imports=list(map(pathlib.PurePosixPath, loaded_dict["imports"])), loading_from_file=True
113
- )
111
+ return ModelRuntime(name=name, env=env, imports=loaded_dict["imports"], loading_from_file=True)
@@ -1,6 +1,6 @@
1
1
  import functools
2
2
  import inspect
3
- from typing import Any, Callable, Coroutine, Dict, Generator, Optional
3
+ from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional
4
4
 
5
5
  import anyio
6
6
  import pandas as pd
@@ -168,7 +168,7 @@ class CustomModel:
168
168
  def _get_infer_methods(
169
169
  self,
170
170
  ) -> Generator[Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame], None, None]:
171
- """Returns all methods in CLS with DECORATOR as the outermost decorator."""
171
+ """Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
172
172
  for cls_method_str in dir(self):
173
173
  cls_method = getattr(self, cls_method_str)
174
174
  if getattr(cls_method, "_is_inference_api", False):
@@ -177,6 +177,18 @@ class CustomModel:
177
177
  else:
178
178
  raise TypeError("A non-method inference API function is not supported.")
179
179
 
180
+ def _get_partitioned_infer_methods(self) -> List[str]:
181
+ """Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
182
+ rv = []
183
+ for cls_method_str in dir(self):
184
+ cls_method = getattr(self, cls_method_str)
185
+ if getattr(cls_method, "_is_partitioned_inference_api", False):
186
+ if inspect.ismethod(cls_method):
187
+ rv.append(cls_method_str)
188
+ else:
189
+ raise TypeError("A non-method inference API function is not supported.")
190
+ return rv
191
+
180
192
 
181
193
  def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
182
194
  """Validate the user provided predict method.
@@ -219,3 +231,11 @@ def inference_api(
219
231
  ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
220
232
  func.__dict__["_is_inference_api"] = True
221
233
  return func
234
+
235
+
236
+ def partitioned_inference_api(
237
+ func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]
238
+ ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
239
+ func.__dict__["_is_inference_api"] = True
240
+ func.__dict__["_is_partitioned_inference_api"] = True
241
+ return func
@@ -232,7 +232,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
232
232
  ),
233
233
  )
234
234
  else:
235
- if isinstance(data_col[0], list):
235
+ if isinstance(data_col.iloc[0], list):
236
236
  if not ft_shape:
237
237
  raise snowml_exceptions.SnowflakeMLException(
238
238
  error_code=error_codes.INVALID_DATA,
@@ -266,7 +266,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
266
266
  ),
267
267
  )
268
268
 
269
- elif isinstance(data_col[0], np.ndarray):
269
+ elif isinstance(data_col.iloc[0], np.ndarray):
270
270
  if not ft_shape:
271
271
  raise snowml_exceptions.SnowflakeMLException(
272
272
  error_code=error_codes.INVALID_DATA,
@@ -297,7 +297,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
297
297
  ),
298
298
  )
299
299
 
300
- elif isinstance(data_col[0], str):
300
+ elif isinstance(data_col.iloc[0], str):
301
301
  if ft_shape is not None:
302
302
  raise snowml_exceptions.SnowflakeMLException(
303
303
  error_code=error_codes.INVALID_DATA,
@@ -316,7 +316,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
316
316
  ),
317
317
  )
318
318
 
319
- elif isinstance(data_col[0], bytes):
319
+ elif isinstance(data_col.iloc[0], bytes):
320
320
  if ft_shape is not None:
321
321
  raise snowml_exceptions.SnowflakeMLException(
322
322
  error_code=error_codes.INVALID_DATA,
@@ -232,11 +232,13 @@ class BaseModelSaveOption(TypedDict):
232
232
  _legacy_save: NotRequired[bool]
233
233
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
234
234
  method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
235
+ include_pip_dependencies: NotRequired[bool]
235
236
 
236
237
 
237
238
  class CatBoostModelSaveOptions(BaseModelSaveOption):
238
239
  target_methods: NotRequired[Sequence[str]]
239
240
  cuda_version: NotRequired[str]
241
+ enable_explainability: NotRequired[bool]
240
242
 
241
243
 
242
244
  class CustomModelSaveOption(BaseModelSaveOption):
@@ -250,10 +252,12 @@ class SKLModelSaveOptions(BaseModelSaveOption):
250
252
  class XGBModelSaveOptions(BaseModelSaveOption):
251
253
  target_methods: NotRequired[Sequence[str]]
252
254
  cuda_version: NotRequired[str]
255
+ enable_explainability: NotRequired[bool]
253
256
 
254
257
 
255
258
  class LGBMModelSaveOptions(BaseModelSaveOption):
256
259
  target_methods: NotRequired[Sequence[str]]
260
+ enable_explainability: NotRequired[bool]
257
261
 
258
262
 
259
263
  class SNOWModelSaveOptions(BaseModelSaveOption):
@@ -313,15 +317,84 @@ ModelSaveOption = Union[
313
317
  ]
314
318
 
315
319
 
316
- class ModelLoadOption(TypedDict):
317
- """Options for loading the model.
320
+ class BaseModelLoadOption(TypedDict):
321
+ """Options for loading the model."""
322
+
323
+ ...
324
+
325
+
326
+ class CatBoostModelLoadOptions(BaseModelLoadOption):
327
+ use_gpu: NotRequired[bool]
328
+
329
+
330
+ class CustomModelLoadOption(BaseModelLoadOption):
331
+ ...
332
+
333
+
334
+ class SKLModelLoadOptions(BaseModelLoadOption):
335
+ ...
336
+
337
+
338
+ class XGBModelLoadOptions(BaseModelLoadOption):
339
+ use_gpu: NotRequired[bool]
340
+
341
+
342
+ class LGBMModelLoadOptions(BaseModelLoadOption):
343
+ ...
344
+
345
+
346
+ class SNOWModelLoadOptions(BaseModelLoadOption):
347
+ ...
318
348
 
319
- use_gpu: Enable GPU-specific loading logic.
320
- """
321
349
 
350
+ class PyTorchLoadOptions(BaseModelLoadOption):
322
351
  use_gpu: NotRequired[bool]
323
352
 
324
353
 
354
+ class TorchScriptLoadOptions(BaseModelLoadOption):
355
+ use_gpu: NotRequired[bool]
356
+
357
+
358
+ class TensorflowLoadOptions(BaseModelLoadOption):
359
+ ...
360
+
361
+
362
+ class MLFlowLoadOptions(BaseModelLoadOption):
363
+ ...
364
+
365
+
366
+ class HuggingFaceLoadOptions(BaseModelLoadOption):
367
+ use_gpu: NotRequired[bool]
368
+ device_map: NotRequired[str]
369
+ device: NotRequired[Union[str, int]]
370
+
371
+
372
+ class SentenceTransformersLoadOptions(BaseModelLoadOption):
373
+ use_gpu: NotRequired[bool]
374
+
375
+
376
+ class LLMLoadOptions(BaseModelLoadOption):
377
+ ...
378
+
379
+
380
+ ModelLoadOption = Union[
381
+ BaseModelLoadOption,
382
+ CatBoostModelLoadOptions,
383
+ CustomModelLoadOption,
384
+ LGBMModelLoadOptions,
385
+ SKLModelLoadOptions,
386
+ XGBModelLoadOptions,
387
+ SNOWModelLoadOptions,
388
+ PyTorchLoadOptions,
389
+ TorchScriptLoadOptions,
390
+ TensorflowLoadOptions,
391
+ MLFlowLoadOptions,
392
+ HuggingFaceLoadOptions,
393
+ SentenceTransformersLoadOptions,
394
+ LLMLoadOptions,
395
+ ]
396
+
397
+
325
398
  class SnowparkContainerServiceDeployDetails(TypedDict):
326
399
  """
327
400
  Attributes:
@@ -41,7 +41,7 @@ cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snow
41
41
 
42
42
  _PROJECT = "ModelDevelopment"
43
43
  DEFAULT_UDTF_NJOBS = 3
44
- ENABLE_EFFICIENT_MEMORY_USAGE = False
44
+ ENABLE_EFFICIENT_MEMORY_USAGE = True
45
45
  _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
46
46
 
47
47
 
@@ -377,6 +377,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
377
377
  anonymous=True,
378
378
  imports=imports, # type: ignore[arg-type]
379
379
  statement_params=sproc_statement_params,
380
+ execute_as="caller",
380
381
  )
381
382
  def _distributed_search(
382
383
  session: Session,
@@ -782,6 +783,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
782
783
  anonymous=True,
783
784
  imports=imports, # type: ignore[arg-type]
784
785
  statement_params=sproc_statement_params,
786
+ execute_as="caller",
785
787
  )
786
788
  def _distributed_search(
787
789
  session: Session,
@@ -83,7 +83,19 @@ def _load_data_into_udf() -> Tuple[
83
83
  with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
84
84
  fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
85
85
 
86
- # convert dataframe to numpy would save memory consumption
86
+ # Convert dataframe to numpy would save memory consumption
87
+ # Except for Pipeline, we need to keep the dataframe for the column names
88
+ from sklearn.pipeline import Pipeline
89
+ if isinstance(base_estimator, Pipeline):
90
+ return (
91
+ df[CONSTANTS['input_cols']],
92
+ df[CONSTANTS['label_cols']].squeeze(),
93
+ indices,
94
+ params_to_evaluate,
95
+ base_estimator,
96
+ fit_and_score_kwargs,
97
+ CONSTANTS
98
+ )
87
99
  return (
88
100
  df[CONSTANTS['input_cols']].to_numpy(),
89
101
  df[CONSTANTS['label_cols']].squeeze().to_numpy(),
@@ -286,6 +286,7 @@ class SnowparkTransformHandlers:
286
286
  session=session,
287
287
  statement_params=statement_params,
288
288
  anonymous=True,
289
+ execute_as="caller",
289
290
  )
290
291
  def score_wrapper_sproc(
291
292
  session: Session,
@@ -207,6 +207,7 @@ class SnowparkModelTrainer:
207
207
  session=self.session,
208
208
  statement_params=statement_params,
209
209
  anonymous=True,
210
+ execute_as="caller",
210
211
  )
211
212
 
212
213
  return fit_wrapper_sproc
@@ -236,6 +237,7 @@ class SnowparkModelTrainer:
236
237
  replace=True,
237
238
  session=self.session,
238
239
  statement_params=statement_params,
240
+ execute_as="caller",
239
241
  )
240
242
 
241
243
  self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined]
@@ -493,6 +495,7 @@ class SnowparkModelTrainer:
493
495
  session=self.session,
494
496
  statement_params=statement_params,
495
497
  anonymous=True,
498
+ execute_as="caller",
496
499
  )
497
500
 
498
501
  return fit_predict_wrapper_sproc
@@ -524,6 +527,7 @@ class SnowparkModelTrainer:
524
527
  replace=True,
525
528
  session=self.session,
526
529
  statement_params=statement_params,
530
+ execute_as="caller",
527
531
  )
528
532
 
529
533
  self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
@@ -550,6 +554,7 @@ class SnowparkModelTrainer:
550
554
  session=self.session,
551
555
  statement_params=statement_params,
552
556
  anonymous=True,
557
+ execute_as="caller",
553
558
  )
554
559
  return fit_transform_wrapper_sproc
555
560
 
@@ -580,6 +585,7 @@ class SnowparkModelTrainer:
580
585
  replace=True,
581
586
  session=self.session,
582
587
  statement_params=statement_params,
588
+ execute_as="caller",
583
589
  )
584
590
 
585
591
  self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
@@ -303,6 +303,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
303
303
  statement_params=statement_params,
304
304
  anonymous=True,
305
305
  imports=list(import_file_paths),
306
+ execute_as="caller",
306
307
  ) # type: ignore[misc]
307
308
  def fit_wrapper_sproc(
308
309
  session: Session,
@@ -76,8 +76,10 @@ class AffinityPropagation(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class AgglomerativeClustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class Birch(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class BisectingKMeans(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class DBSCAN(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class FeatureAgglomeration(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class KMeans(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class MeanShift(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class MiniBatchKMeans(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class OPTICS(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class SpectralBiclustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class SpectralClustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class SpectralCoclustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class ColumnTransformer(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class EllipticEnvelope(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class EmpiricalCovariance(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class GraphicalLasso(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class GraphicalLassoCV(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class LedoitWolf(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class MinCovDet(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -76,8 +76,10 @@ class OAS(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must