snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/sql/tag.py

@@ -1,52 +1,25 @@
  from typing import Any, Dict, List, Optional

- from snowflake.ml._internal.utils import (
-     identifier,
-     query_result_checker,
-     sql_identifier,
- )
- from snowflake.snowpark import row, session
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
+ from snowflake.ml.model._client.sql import _base
+ from snowflake.snowpark import row


- class ModuleTagSQLClient:
-     def __init__(
-         self,
-         session: session.Session,
-         *,
-         database_name: sql_identifier.SqlIdentifier,
-         schema_name: sql_identifier.SqlIdentifier,
-     ) -> None:
-         self._session = session
-         self._database_name = database_name
-         self._schema_name = schema_name
-
-     def __eq__(self, __value: object) -> bool:
-         if not isinstance(__value, ModuleTagSQLClient):
-             return False
-         return self._database_name == __value._database_name and self._schema_name == __value._schema_name
-
-     def fully_qualified_module_name(
-         self,
-         module_name: sql_identifier.SqlIdentifier,
-     ) -> str:
-         return identifier.get_schema_level_object_identifier(
-             self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
-         )
-
+ class ModuleTagSQLClient(_base._BaseSQLClient):
      def set_tag_on_model(
          self,
-         model_name: sql_identifier.SqlIdentifier,
          *,
-         tag_database_name: sql_identifier.SqlIdentifier,
-         tag_schema_name: sql_identifier.SqlIdentifier,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         tag_database_name: Optional[sql_identifier.SqlIdentifier],
+         tag_schema_name: Optional[sql_identifier.SqlIdentifier],
          tag_name: sql_identifier.SqlIdentifier,
          tag_value: str,
          statement_params: Optional[Dict[str, Any]] = None,
      ) -> None:
-         fq_model_name = self.fully_qualified_module_name(model_name)
-         fq_tag_name = identifier.get_schema_level_object_identifier(
-             tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
-         )
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
          query_result_checker.SqlResultValidator(
              self._session,
              f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
@@ -55,17 +28,17 @@ class ModuleTagSQLClient:

      def unset_tag_on_model(
          self,
-         model_name: sql_identifier.SqlIdentifier,
          *,
-         tag_database_name: sql_identifier.SqlIdentifier,
-         tag_schema_name: sql_identifier.SqlIdentifier,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         tag_database_name: Optional[sql_identifier.SqlIdentifier],
+         tag_schema_name: Optional[sql_identifier.SqlIdentifier],
          tag_name: sql_identifier.SqlIdentifier,
          statement_params: Optional[Dict[str, Any]] = None,
      ) -> None:
-         fq_model_name = self.fully_qualified_module_name(model_name)
-         fq_tag_name = identifier.get_schema_level_object_identifier(
-             tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
-         )
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
          query_result_checker.SqlResultValidator(
              self._session,
              f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
@@ -74,21 +47,21 @@ class ModuleTagSQLClient:

      def get_tag_value(
          self,
-         module_name: sql_identifier.SqlIdentifier,
          *,
-         tag_database_name: sql_identifier.SqlIdentifier,
-         tag_schema_name: sql_identifier.SqlIdentifier,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         tag_database_name: Optional[sql_identifier.SqlIdentifier],
+         tag_schema_name: Optional[sql_identifier.SqlIdentifier],
          tag_name: sql_identifier.SqlIdentifier,
          statement_params: Optional[Dict[str, Any]] = None,
      ) -> row.Row:
-         fq_module_name = self.fully_qualified_module_name(module_name)
-         fq_tag_name = identifier.get_schema_level_object_identifier(
-             tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
-         )
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
          return (
              query_result_checker.SqlResultValidator(
                  self._session,
-                 f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_module_name}$$, 'MODULE') AS TAG_VALUE",
+                 f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_model_name}$$, 'MODULE') AS TAG_VALUE",
                  statement_params=statement_params,
              )
              .has_dimensions(expected_rows=1, expected_cols=1)
@@ -98,16 +71,19 @@ class ModuleTagSQLClient:

      def get_tag_list(
          self,
-         module_name: sql_identifier.SqlIdentifier,
          *,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
          statement_params: Optional[Dict[str, Any]] = None,
      ) -> List[row.Row]:
-         fq_module_name = self.fully_qualified_module_name(module_name)
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         actual_database_name = database_name or self._database_name
          return (
              query_result_checker.SqlResultValidator(
                  self._session,
                  f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
-                 FROM TABLE({self._database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_module_name}$$, 'MODULE'))""",
+                 FROM TABLE({actual_database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_model_name}$$, 'MODULE'))""",
                  statement_params=statement_params,
              )
              .has_column("TAG_DATABASE", allow_empty=True)
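For orientation, a rough sketch of how the reworked tag client is called after this change. The session and identifier values are illustrative, and the inherited `_BaseSQLClient` constructor is assumed to take the session plus default database/schema, mirroring the removed `ModuleTagSQLClient.__init__` (that base class is not shown in this hunk):

    from snowflake.ml._internal.utils import sql_identifier
    from snowflake.ml.model._client.sql import tag

    # `session` is an existing snowflake.snowpark.Session (assumed).
    client = tag.ModuleTagSQLClient(
        session,
        database_name=sql_identifier.SqlIdentifier("MY_DB"),
        schema_name=sql_identifier.SqlIdentifier("MY_SCHEMA"),
    )
    # database_name/schema_name arguments may now be None to fall back to the client defaults.
    client.set_tag_on_model(
        database_name=None,
        schema_name=None,
        model_name=sql_identifier.SqlIdentifier("MY_MODEL"),
        tag_database_name=None,
        tag_schema_name=None,
        tag_name=sql_identifier.SqlIdentifier("MY_TAG"),
        tag_value="v1",
    )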
snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py

@@ -37,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
          session: snowpark.Session,
          artifact_stage_location: str,
          compute_pool: str,
+         job_name: str,
          external_access_integrations: List[str],
      ) -> None:
          """Initialization
@@ -49,6 +50,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
              artifact_stage_location: Spec file and future deployment related artifacts will be stored under
                  {stage}/models/{model_id}
              compute_pool: The compute pool used to run docker image build workload.
+             job_name: job_name to use.
              external_access_integrations: EAIs for network connection.
          """
          self.context_dir = context_dir
@@ -58,6 +60,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
          self.artifact_stage_location = artifact_stage_location
          self.compute_pool = compute_pool
          self.external_access_integrations = external_access_integrations
+         self.job_name = job_name
          self.client = snowservice_client.SnowServiceClient(session)

          assert artifact_stage_location.startswith(
@@ -203,8 +206,9 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
          )

      def _launch_kaniko_job(self, spec_stage_location: str) -> None:
-         logger.debug("Submitting job for building docker image with kaniko")
+         logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko")
          self.client.create_job(
+             job_name=self.job_name,
              compute_pool=self.compute_pool,
              spec_stage_location=spec_stage_location,
              external_access_integrations=self.external_access_integrations,
snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template

@@ -30,6 +30,7 @@ USER mambauser

  # Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time.
  ARG MAMBA_DOCKERFILE_ACTIVATE=1
+ ARG MAMBA_NO_LOW_SPEED_LIMIT=1

  # Bitsandbytes uses this ENVVAR to determine CUDA library location
  ENV CONDA_PREFIX=/opt/conda
snowflake/ml/model/_deploy_client/snowservice/deploy.py

@@ -346,6 +346,7 @@ class SnowServiceDeployment:
          (db, schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name)

          self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}")
+         self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}")
          # Spec file and future deployment related artifacts will be stored under {stage}/models/{model_id}
          self._model_artifact_stage_location = posixpath.join(deployment_stage_path, "models", self.id)
          self.debug_dir: Optional[str] = None
@@ -468,6 +469,7 @@ class SnowServiceDeployment:
                  session=self.session,
                  artifact_stage_location=self._model_artifact_stage_location,
                  compute_pool=self.options.compute_pool,
+                 job_name=self._job_name,
                  external_access_integrations=self.options.external_access_integrations,
              )
          else:
snowflake/ml/model/_deploy_client/utils/constants.py

@@ -17,11 +17,6 @@ class ResourceStatus(Enum):
      INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.


- RESOURCE_TO_STATUS_FUNCTION_MAPPING = {
-     ResourceType.SERVICE: "SYSTEM$GET_SERVICE_STATUS",
-     ResourceType.JOB: "SYSTEM$GET_JOB_STATUS",
- }
-
  PREDICT = "predict"
  STAGE = "stage"
  COMPUTE_POOL = "compute_pool"
snowflake/ml/model/_deploy_client/utils/snowservice_client.py

@@ -70,13 +70,16 @@ class SnowServiceClient:
          logger.debug(f"Create service with SQL: \n {sql}")
          self.session.sql(sql).collect()

-     def create_job(self, compute_pool: str, spec_stage_location: str, external_access_integrations: List[str]) -> None:
+     def create_job(
+         self, job_name: str, compute_pool: str, spec_stage_location: str, external_access_integrations: List[str]
+     ) -> None:
          """Execute the job creation SQL command. Note that the job creation is synchronous, hence we execute it in a
          async way so that we can query the log in the meantime.

          Upon job failure, full job container log will be logged.

          Args:
+             job_name: name of the job
              compute_pool: name of the compute pool
              spec_stage_location: path to the stage location where the spec is located at.
              external_access_integrations: EAIs for network connection.
@@ -84,19 +87,18 @@ class SnowServiceClient:
          stage, path = uri.get_stage_and_path(spec_stage_location)
          sql = textwrap.dedent(
              f"""
-             EXECUTE SERVICE
+             EXECUTE JOB SERVICE
              IN COMPUTE POOL {compute_pool}
              FROM {stage}
-             SPEC = '{path}'
+             SPECIFICATION_FILE = '{path}'
+             NAME = {job_name}
              EXTERNAL_ACCESS_INTEGRATIONS = ({', '.join(external_access_integrations)})
              """
          )
          logger.debug(f"Create job with SQL: \n {sql}")
-         cur = self.session._conn._conn.cursor()
-         cur.execute_async(sql)
-         job_id = cur._sfqid
+         self.session.sql(sql).collect_nowait()
          self.block_until_resource_is_ready(
-             resource_name=str(job_id),
+             resource_name=job_name,
              resource_type=constants.ResourceType.JOB,
              container_name=constants.KANIKO_CONTAINER_NAME,
              max_retries=240,
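The new job-creation path drops the raw connector cursor in favor of Snowpark's asynchronous query API. A minimal sketch of the same pattern outside the package, assuming an existing Snowpark `session` (pool, stage, file, and job names are illustrative):

    # Submit the job without blocking; collect_nowait() returns a snowflake.snowpark.AsyncJob.
    async_job = session.sql(
        """
        EXECUTE JOB SERVICE
        IN COMPUTE POOL MY_COMPUTE_POOL
        FROM @my_stage
        SPECIFICATION_FILE = 'build_spec.yaml'
        NAME = my_build_job
        """
    ).collect_nowait()
    # ... poll job status / stream logs in the meantime ...
    rows = async_job.result()  # block only when the final result is actually needed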
snowflake/ml/model/_deploy_client/utils/snowservice_client.py (continued)

@@ -182,10 +184,7 @@ class SnowServiceClient:
          """
          assert resource_type == constants.ResourceType.SERVICE or resource_type == constants.ResourceType.JOB
          query_command = ""
-         if resource_type == constants.ResourceType.SERVICE:
-             query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
-         elif resource_type == constants.ResourceType.JOB:
-             query_command = f"CALL SYSTEM$GET_JOB_LOGS('{resource_name}', '{container_name}')"
+         query_command = f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
          logger.warning(
              f"Best-effort log streaming from SPCS will be enabled when python logging level is set to INFO."
              f"Alternatively, you can also query the logs by running the query '{query_command}'"
@@ -201,7 +200,7 @@ class SnowServiceClient:
              )
              lsp.process_new_logs(resource_log, log_level=logging.INFO)

-             status = self.get_resource_status(resource_name=resource_name, resource_type=resource_type)
+             status = self.get_resource_status(resource_name=resource_name)

              if resource_type == constants.ResourceType.JOB and status == constants.ResourceStatus.DONE:
                  return
@@ -246,52 +245,24 @@ class SnowServiceClient:
      def get_resource_log(
          self, resource_name: str, resource_type: constants.ResourceType, container_name: str
      ) -> Optional[str]:
-         if resource_type == constants.ResourceType.SERVICE:
-             try:
-                 row = self.session.sql(
-                     f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
-                 ).collect()
-                 return str(row[0]["SYSTEM$GET_SERVICE_LOGS"])
-             except Exception:
-                 return None
-         elif resource_type == constants.ResourceType.JOB:
-             try:
-                 row = self.session.sql(f"CALL SYSTEM$GET_JOB_LOGS('{resource_name}', '{container_name}')").collect()
-                 return str(row[0]["SYSTEM$GET_JOB_LOGS"])
-             except Exception:
-                 return None
-         else:
-             raise snowml_exceptions.SnowflakeMLException(
-                 error_code=error_codes.NOT_IMPLEMENTED,
-                 original_exception=NotImplementedError(
-                     f"{resource_type.name} is not yet supported in get_resource_log function"
-                 ),
-             )
-
-     def get_resource_status(
-         self, resource_name: str, resource_type: constants.ResourceType
-     ) -> Optional[constants.ResourceStatus]:
+         try:
+             row = self.session.sql(
+                 f"CALL SYSTEM$GET_SERVICE_LOGS('{resource_name}', '0', '{container_name}')"
+             ).collect()
+             return str(row[0]["SYSTEM$GET_SERVICE_LOGS"])
+         except Exception:
+             return None
+
+     def get_resource_status(self, resource_name: str) -> Optional[constants.ResourceStatus]:
          """Get resource status.

          Args:
              resource_name: Name of the resource.
-             resource_type: Type of the resource.
-
-         Raises:
-             SnowflakeMLException: If resource type does not have a corresponding system function for querying status.
-             SnowflakeMLException: If corresponding status call failed.

          Returns:
              Optional[constants.ResourceStatus]: The status of the resource, or None if the resource status is empty.
          """
-         if resource_type not in constants.RESOURCE_TO_STATUS_FUNCTION_MAPPING:
-             raise snowml_exceptions.SnowflakeMLException(
-                 error_code=error_codes.INVALID_ARGUMENT,
-                 original_exception=ValueError(
-                     f"Status querying is not supported for resources of type '{resource_type}'."
-                 ),
-             )
-         status_func = constants.RESOURCE_TO_STATUS_FUNCTION_MAPPING[resource_type]
+         status_func = "SYSTEM$GET_SERVICE_STATUS"
          try:
              row = self.session.sql(f"CALL {status_func}('{resource_name}');").collect()
          except Exception:
snowflake/ml/model/_model_composer/model_composer.py

@@ -8,8 +8,10 @@ from typing import Any, Dict, List, Optional

  from absl import logging
  from packaging import requirements
+ from typing_extensions import deprecated

  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
+ from snowflake.ml._internal.lineage import data_source, lineage_utils
  from snowflake.ml.model import model_signature, type_hints as model_types
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
  from snowflake.ml.model._packager import model_packager
@@ -134,6 +136,7 @@ class ModelComposer:
              model_meta=self.packager.meta,
              model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
              options=options,
+             data_sources=self._get_data_sources(model),
          )

          file_utils.upload_directory_to_stage(
@@ -143,7 +146,8 @@ class ModelComposer:
              statement_params=self._statement_params,
          )

-     def load(
+     @deprecated("Only used by PrPr model registry. Use static method version of load instead.")
+     def legacy_load(
          self,
          *,
          meta_only: bool = False,
@@ -163,3 +167,20 @@ class ModelComposer:
          with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
              zf.extractall(path=self._packager_workspace_path)
          self.packager.load(meta_only=meta_only, options=options)
+
+     @staticmethod
+     def load(
+         workspace_path: pathlib.Path,
+         *,
+         meta_only: bool = False,
+         options: Optional[model_types.ModelLoadOption] = None,
+     ) -> model_packager.ModelPackager:
+         mp = model_packager.ModelPackager(str(workspace_path / ModelComposer.MODEL_DIR_REL_PATH))
+         mp.load(meta_only=meta_only, options=options)
+         return mp
+
+     def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]:
+         data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
+         if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
+             return data_sources
+         return None
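A rough usage sketch of the reworked loader, assuming a local workspace directory that already contains an extracted model (the path is made up; the old instance-method behavior now lives in `legacy_load`):

    import pathlib

    from snowflake.ml.model._model_composer import model_composer

    workspace_path = pathlib.Path("/tmp/my_model_workspace")  # hypothetical extracted workspace
    # The 1.5.x static load returns a ModelPackager instead of mutating the composer in place.
    packager = model_composer.ModelComposer.load(workspace_path, meta_only=True)
    meta = packager.meta  # model metadata only; weights are skipped with meta_only=True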
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py

@@ -5,6 +5,7 @@ from typing import List, Optional, cast

  import yaml

+ from snowflake.ml._internal.lineage import data_source
  from snowflake.ml.model import type_hints
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
  from snowflake.ml.model._model_composer.model_method import (
@@ -36,6 +37,7 @@ class ModelManifest:
          model_meta: model_meta_api.ModelMetadata,
          model_file_rel_path: pathlib.PurePosixPath,
          options: Optional[type_hints.ModelSaveOption] = None,
+         data_sources: Optional[List[data_source.DataSource]] = None,
      ) -> None:
          if options is None:
              options = {}
@@ -90,6 +92,10 @@ class ModelManifest:
              ],
          )

+         lineage_sources = self._extract_lineage_info(data_sources)
+         if lineage_sources:
+             manifest_dict["lineage_sources"] = lineage_sources
+
          with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
              # Anchors are not supported in the server, avoid that.
              yaml.SafeDumper.ignore_aliases = lambda *args: True  # type: ignore[method-assign]
@@ -108,3 +114,19 @@ class ModelManifest:
          res = cast(model_manifest_schema.ModelManifestDict, raw_input)

          return res
+
+     def _extract_lineage_info(
+         self, data_sources: Optional[List[data_source.DataSource]]
+     ) -> List[model_manifest_schema.LineageSourceDict]:
+         result = []
+         if data_sources:
+             for source in data_sources:
+                 result.append(
+                     model_manifest_schema.LineageSourceDict(
+                         # Currently, we only support lineage from Dataset.
+                         type=model_manifest_schema.LineageSourceTypes.DATASET.value,
+                         entity=source.fully_qualified_name,
+                         version=source.version,
+                     )
+                 )
+         return result
snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py

@@ -75,8 +75,19 @@ class SnowparkMLDataDict(TypedDict):
      functions: Required[List[ModelFunctionInfoDict]]


+ class LineageSourceTypes(enum.Enum):
+     DATASET = "DATASET"
+
+
+ class LineageSourceDict(TypedDict):
+     type: Required[str]
+     entity: Required[str]
+     version: NotRequired[str]
+
+
  class ModelManifestDict(TypedDict):
      manifest_version: Required[str]
      runtimes: Required[Dict[str, ModelRuntimeDict]]
      methods: Required[List[ModelMethodDict]]
      user_data: NotRequired[Dict[str, Any]]
+     lineage_sources: NotRequired[List[LineageSourceDict]]
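For illustration only, a lineage entry built against this schema would look roughly like the following; the dataset name and version are made up:

    from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema

    # Hypothetical entry recorded when a model version was trained from a Dataset.
    example_source = model_manifest_schema.LineageSourceDict(
        type=model_manifest_schema.LineageSourceTypes.DATASET.value,
        entity="MY_DB.MY_SCHEMA.MY_DATASET",  # fully qualified dataset name (illustrative)
        version="v1",                         # dataset version (illustrative)
    )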
snowflake/ml/model/_packager/model_env/model_env.py

@@ -284,6 +284,7 @@ class ModelEnv:
                      " This may prevent model deploying to Snowflake Warehouse."
                  ),
                  category=UserWarning,
+                 stacklevel=2,
              )
          if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
              warnings.warn(
@@ -292,6 +293,7 @@ class ModelEnv:
                      " This may prevent model deploying to Snowflake Warehouse."
                  ),
                  category=UserWarning,
+                 stacklevel=2,
              )
              self._conda_dependencies[channel] = []

@@ -307,6 +309,7 @@ class ModelEnv:
                      " This may be unintentional."
                  ),
                  category=UserWarning,
+                 stacklevel=2,
              )

          if pip_requirements_list:
@@ -316,6 +319,7 @@ class ModelEnv:
                      " This may prevent model deploying to Snowflake Warehouse."
                  ),
                  category=UserWarning,
+                 stacklevel=2,
              )
              for pip_dependency in pip_requirements_list:
                  if any(
@@ -338,6 +342,7 @@ class ModelEnv:
                      " This may prevent model deploying to Snowflake Warehouse."
                  ),
                  category=UserWarning,
+                 stacklevel=2,
              )
              for pip_dependency in pip_requirements_list:
                  if any(
@@ -372,3 +377,39 @@ class ModelEnv:
              "cuda_version": self.cuda_version,
              "snowpark_ml_version": self.snowpark_ml_version,
          }
+
+     def validate_with_local_env(
+         self, check_snowpark_ml_version: bool = False
+     ) -> List[env_utils.IncorrectLocalEnvironmentError]:
+         errors = []
+         try:
+             env_utils.validate_py_runtime_version(str(self._python_version))
+         except env_utils.IncorrectLocalEnvironmentError as e:
+             errors.append(e)
+
+         for conda_reqs in self._conda_dependencies.values():
+             for conda_req in conda_reqs:
+                 try:
+                     env_utils.validate_local_installed_version_of_pip_package(
+                         env_utils.try_convert_conda_requirement_to_pip(conda_req)
+                     )
+                 except env_utils.IncorrectLocalEnvironmentError as e:
+                     errors.append(e)
+
+         for pip_req in self._pip_requirements:
+             try:
+                 env_utils.validate_local_installed_version_of_pip_package(pip_req)
+             except env_utils.IncorrectLocalEnvironmentError as e:
+                 errors.append(e)
+
+         if check_snowpark_ml_version:
+             # For Modeling model
+             if self._snowpark_ml_version.base_version != snowml_env.VERSION:
+                 errors.append(
+                     env_utils.IncorrectLocalEnvironmentError(
+                         f"The local installed version of Snowpark ML library is {snowml_env.VERSION} "
+                         f"which differs from required version {self.snowpark_ml_version}."
+                     )
+                 )
+
+         return errors
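A rough sketch of how a caller might surface these environment mismatches, assuming `env` is a `ModelEnv` instance (illustrative only; the method collects errors rather than raising):

    import warnings

    # `env` is assumed to be a model_env.ModelEnv built from a saved model's metadata.
    mismatches = env.validate_with_local_env(check_snowpark_ml_version=True)
    for err in mismatches:
        # Each entry is an env_utils.IncorrectLocalEnvironmentError describing one mismatch.
        warnings.warn(str(err), category=UserWarning, stacklevel=2)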
snowflake/ml/model/_packager/model_handlers/mlflow.py

@@ -1,4 +1,5 @@
  import os
+ import pathlib
  import tempfile
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final

@@ -45,7 +46,7 @@ def _parse_mlflow_env(model_uri: str, env: model_env.ModelEnv) -> model_env.Mode
      if not os.path.exists(conda_env_file_path):
          raise ValueError("Cannot load MLFlow model dependencies.")

-     env.load_from_conda_file(conda_env_file_path)
+     env.load_from_conda_file(pathlib.Path(conda_env_file_path))

      return env

snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -320,11 +320,7 @@ class ModelMetadata:

          with open(model_yaml_path, "w", encoding="utf-8") as out:
              yaml.SafeDumper.ignore_aliases = lambda *args: True  # type: ignore[method-assign]
-             yaml.safe_dump(
-                 model_dict,
-                 stream=out,
-                 default_flow_style=False,
-             )
+             yaml.safe_dump(model_dict, stream=out, default_flow_style=False)

      @staticmethod
      def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadataDict:
snowflake/ml/model/_packager/model_packager.py

@@ -4,7 +4,6 @@ from typing import Dict, List, Optional

  from absl import logging

- from snowflake.ml._internal import env_utils
  from snowflake.ml._internal.exceptions import (
      error_codes,
      exceptions as snowml_exceptions,
@@ -129,8 +128,6 @@ class ModelPackager:

          model_meta.load_code_path(self.local_dir_path)

-         env_utils.validate_py_runtime_version(self.meta.env.python_version)
-
          handler = model_handler.load_handler(self.meta.model_type)
          if handler is None:
              raise snowml_exceptions.SnowflakeMLException(
snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py

@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple

  import pandas as pd

+ from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
+

  class PandasModelTrainer:
      """
@@ -72,11 +74,61 @@ class PandasModelTrainer:
              Tuple[pd.DataFrame, object]: [predicted dataset, estimator]
          """
          assert hasattr(self.estimator, "fit_predict")  # make type checker happy
-         args = {"X": self.dataset[self.input_cols]}
-         result = self.estimator.fit_predict(**args)
+         result = self.estimator.fit_predict(X=self.dataset[self.input_cols])
          result_df = pd.DataFrame(data=result, columns=expected_output_cols_list)
          if drop_input_cols:
              result_df = result_df
          else:
-             result_df = pd.concat([self.dataset, result_df], axis=1)
+             # in case the output column name overlap with the input column names,
+             # remove the ones in input column names
+             remove_dataset_col_name_exist_in_output_col = list(
+                 set(self.dataset.columns) - set(expected_output_cols_list)
+             )
+             result_df = pd.concat([self.dataset[remove_dataset_col_name_exist_in_output_col], result_df], axis=1)
+         return (result_df, self.estimator)
+
+     def train_fit_transform(
+         self,
+         expected_output_cols_list: List[str],
+         drop_input_cols: Optional[bool] = False,
+     ) -> Tuple[pd.DataFrame, object]:
+         """Trains the model using specified features and target columns from the dataset.
+         This API is different from fit itself because it would also provide the transform
+         output.
+
+         Args:
+             expected_output_cols_list (List[str]): The output columns
+                 name as a list. Defaults to None.
+             drop_input_cols (Optional[bool]): Boolean to determine whether to
+                 drop the input columns from the output dataset.
+
+         Returns:
+             Tuple[pd.DataFrame, object]: [transformed dataset, estimator]
+         """
+         assert hasattr(self.estimator, "fit")  # make type checker happy
+         assert hasattr(self.estimator, "fit_transform")  # make type checker happy
+
+         argspec = inspect.getfullargspec(self.estimator.fit)
+         args = {"X": self.dataset[self.input_cols]}
+         if self.label_cols:
+             label_arg_name = "Y" if "Y" in argspec.args else "y"
+             args[label_arg_name] = self.dataset[self.label_cols].squeeze()
+
+         if self.sample_weight_col is not None and "sample_weight" in argspec.args:
+             args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
+
+         inference_res = self.estimator.fit_transform(**args)
+
+         transformed_numpy_array, output_cols = handle_inference_result(
+             inference_res=inference_res, output_cols=expected_output_cols_list, inference_method="fit_transform"
+         )
+
+         result_df = pd.DataFrame(data=transformed_numpy_array, columns=output_cols)
+         if drop_input_cols:
+             result_df = result_df
+         else:
+             # in case the output column name overlap with the input column names,
+             # remove the ones in input column names
+             remove_dataset_col_name_exist_in_output_col = list(set(self.dataset.columns) - set(output_cols))
+             result_df = pd.concat([self.dataset[remove_dataset_col_name_exist_in_output_col], result_df], axis=1)
          return (result_df, self.estimator)
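As an illustration of the column-overlap handling above, here is the same idea in standalone pandas (not the package API; the frame and column names are made up):

    import pandas as pd

    df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
    output = pd.DataFrame({"B": [9, 8]})  # output column "B" collides with an input column

    # Keep only the input columns that are not re-emitted as outputs before concatenating,
    # so the result has a single, unambiguous "B" column.
    keep_cols = list(set(df.columns) - set(output.columns))
    result = pd.concat([df[keep_cols], output], axis=1)
    print(result.columns.tolist())  # e.g. ['A', 'B']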