snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. snowflake/cortex/_sentiment.py +7 -4
  2. snowflake/ml/_internal/env_utils.py +6 -0
  3. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  4. snowflake/ml/_internal/telemetry.py +1 -0
  5. snowflake/ml/_internal/utils/identifier.py +1 -1
  6. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  7. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  8. snowflake/ml/dataset/__init__.py +2 -1
  9. snowflake/ml/dataset/dataset.py +4 -3
  10. snowflake/ml/dataset/dataset_reader.py +5 -8
  11. snowflake/ml/feature_store/__init__.py +6 -0
  12. snowflake/ml/feature_store/access_manager.py +283 -0
  13. snowflake/ml/feature_store/feature_store.py +160 -100
  14. snowflake/ml/feature_store/feature_view.py +30 -19
  15. snowflake/ml/fileset/embedded_stage_fs.py +15 -12
  16. snowflake/ml/fileset/snowfs.py +2 -30
  17. snowflake/ml/fileset/stage_fs.py +25 -7
  18. snowflake/ml/model/_client/model/model_impl.py +46 -39
  19. snowflake/ml/model/_client/model/model_version_impl.py +24 -2
  20. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  21. snowflake/ml/model/_client/ops/model_ops.py +174 -16
  22. snowflake/ml/model/_client/sql/_base.py +34 -0
  23. snowflake/ml/model/_client/sql/model.py +32 -39
  24. snowflake/ml/model/_client/sql/model_version.py +111 -42
  25. snowflake/ml/model/_client/sql/stage.py +6 -32
  26. snowflake/ml/model/_client/sql/tag.py +32 -56
  27. snowflake/ml/model/_model_composer/model_composer.py +8 -4
  28. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  29. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  30. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  31. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
  32. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
  33. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
  34. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
  35. snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
  36. snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
  37. snowflake/ml/modeling/cluster/birch.py +8 -1
  38. snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
  39. snowflake/ml/modeling/cluster/dbscan.py +8 -1
  40. snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
  41. snowflake/ml/modeling/cluster/k_means.py +8 -1
  42. snowflake/ml/modeling/cluster/mean_shift.py +8 -1
  43. snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
  44. snowflake/ml/modeling/cluster/optics.py +8 -1
  45. snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
  46. snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
  47. snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
  48. snowflake/ml/modeling/compose/column_transformer.py +8 -1
  49. snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
  50. snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
  51. snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
  52. snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
  53. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
  54. snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
  55. snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
  56. snowflake/ml/modeling/covariance/oas.py +8 -1
  57. snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
  58. snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
  59. snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
  60. snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
  61. snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
  62. snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
  63. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
  64. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
  65. snowflake/ml/modeling/decomposition/pca.py +8 -1
  66. snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
  67. snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
  68. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
  69. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
  70. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
  71. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
  72. snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
  73. snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
  74. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
  75. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
  76. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
  77. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
  79. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
  80. snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
  81. snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
  82. snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
  83. snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
  84. snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
  85. snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
  86. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
  87. snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
  88. snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
  89. snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
  90. snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
  91. snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
  92. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
  93. snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
  94. snowflake/ml/modeling/framework/base.py +4 -3
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
  97. snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
  98. snowflake/ml/modeling/impute/knn_imputer.py +8 -1
  99. snowflake/ml/modeling/impute/missing_indicator.py +8 -1
  100. snowflake/ml/modeling/impute/simple_imputer.py +21 -2
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
  115. snowflake/ml/modeling/linear_model/lars.py +8 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +8 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +8 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +8 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
  144. snowflake/ml/modeling/manifold/isomap.py +8 -1
  145. snowflake/ml/modeling/manifold/mds.py +8 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
  147. snowflake/ml/modeling/manifold/tsne.py +8 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
  170. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  171. snowflake/ml/modeling/pipeline/pipeline.py +27 -7
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +8 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +8 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +8 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +8 -1
  179. snowflake/ml/modeling/svm/svc.py +8 -1
  180. snowflake/ml/modeling/svm/svr.py +8 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
  189. snowflake/ml/registry/_manager/model_manager.py +95 -8
  190. snowflake/ml/registry/registry.py +10 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
  193. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
  194. snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
  195. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,8 @@ from snowflake.ml._internal.utils import (
9
9
  query_result_checker,
10
10
  sql_identifier,
11
11
  )
12
- from snowflake.snowpark import dataframe, functions as F, row, session, types as spt
12
+ from snowflake.ml.model._client.sql import _base
13
+ from snowflake.snowpark import dataframe, functions as F, row, types as spt
13
14
  from snowflake.snowpark._internal import utils as snowpark_utils
14
15
 
15
16
 
@@ -20,44 +21,51 @@ def _normalize_url_for_sql(url: str) -> str:
20
21
  return f"'{url}'"
21
22
 
22
23
 
23
- class ModelVersionSQLClient:
24
+ class ModelVersionSQLClient(_base._BaseSQLClient):
24
25
  FUNCTION_NAME_COL_NAME = "name"
25
26
  FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
26
27
 
27
- def __init__(
28
+ def create_from_stage(
28
29
  self,
29
- session: session.Session,
30
30
  *,
31
- database_name: sql_identifier.SqlIdentifier,
32
- schema_name: sql_identifier.SqlIdentifier,
31
+ database_name: Optional[sql_identifier.SqlIdentifier],
32
+ schema_name: Optional[sql_identifier.SqlIdentifier],
33
+ model_name: sql_identifier.SqlIdentifier,
34
+ version_name: sql_identifier.SqlIdentifier,
35
+ stage_path: str,
36
+ statement_params: Optional[Dict[str, Any]] = None,
33
37
  ) -> None:
34
- self._session = session
35
- self._database_name = database_name
36
- self._schema_name = schema_name
37
-
38
- def __eq__(self, __value: object) -> bool:
39
- if not isinstance(__value, ModelVersionSQLClient):
40
- return False
41
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
42
-
43
- def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
44
- return identifier.get_schema_level_object_identifier(
45
- self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
46
- )
38
+ query_result_checker.SqlResultValidator(
39
+ self._session,
40
+ (
41
+ f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
42
+ f" WITH VERSION {version_name.identifier()} FROM {stage_path}"
43
+ ),
44
+ statement_params=statement_params,
45
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
47
46
 
48
- def create_from_stage(
47
+ def create_from_model_version(
49
48
  self,
50
49
  *,
50
+ source_database_name: Optional[sql_identifier.SqlIdentifier],
51
+ source_schema_name: Optional[sql_identifier.SqlIdentifier],
52
+ source_model_name: sql_identifier.SqlIdentifier,
53
+ source_version_name: sql_identifier.SqlIdentifier,
54
+ database_name: Optional[sql_identifier.SqlIdentifier],
55
+ schema_name: Optional[sql_identifier.SqlIdentifier],
51
56
  model_name: sql_identifier.SqlIdentifier,
52
57
  version_name: sql_identifier.SqlIdentifier,
53
- stage_path: str,
54
58
  statement_params: Optional[Dict[str, Any]] = None,
55
59
  ) -> None:
60
+ fq_source_model_name = self.fully_qualified_object_name(
61
+ source_database_name, source_schema_name, source_model_name
62
+ )
63
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
56
64
  query_result_checker.SqlResultValidator(
57
65
  self._session,
58
66
  (
59
- f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
60
- f" FROM {stage_path}"
67
+ f"CREATE MODEL {fq_model_name} WITH VERSION {version_name} FROM MODEL {fq_source_model_name}"
68
+ f" VERSION {source_version_name}"
61
69
  ),
62
70
  statement_params=statement_params,
63
71
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -66,6 +74,8 @@ class ModelVersionSQLClient:
66
74
  def add_version_from_stage(
67
75
  self,
68
76
  *,
77
+ database_name: Optional[sql_identifier.SqlIdentifier],
78
+ schema_name: Optional[sql_identifier.SqlIdentifier],
69
79
  model_name: sql_identifier.SqlIdentifier,
70
80
  version_name: sql_identifier.SqlIdentifier,
71
81
  stage_path: str,
@@ -74,8 +84,34 @@ class ModelVersionSQLClient:
74
84
  query_result_checker.SqlResultValidator(
75
85
  self._session,
76
86
  (
77
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
78
- f" FROM {stage_path}"
87
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
88
+ f" ADD VERSION {version_name.identifier()} FROM {stage_path}"
89
+ ),
90
+ statement_params=statement_params,
91
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
92
+
93
+ def add_version_from_model_version(
94
+ self,
95
+ *,
96
+ source_database_name: Optional[sql_identifier.SqlIdentifier],
97
+ source_schema_name: Optional[sql_identifier.SqlIdentifier],
98
+ source_model_name: sql_identifier.SqlIdentifier,
99
+ source_version_name: sql_identifier.SqlIdentifier,
100
+ database_name: Optional[sql_identifier.SqlIdentifier],
101
+ schema_name: Optional[sql_identifier.SqlIdentifier],
102
+ model_name: sql_identifier.SqlIdentifier,
103
+ version_name: sql_identifier.SqlIdentifier,
104
+ statement_params: Optional[Dict[str, Any]] = None,
105
+ ) -> None:
106
+ fq_source_model_name = self.fully_qualified_object_name(
107
+ source_database_name, source_schema_name, source_model_name
108
+ )
109
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
110
+ query_result_checker.SqlResultValidator(
111
+ self._session,
112
+ (
113
+ f"ALTER MODEL {fq_model_name} ADD VERSION {version_name} FROM MODEL {fq_source_model_name}"
114
+ f" VERSION {source_version_name}"
79
115
  ),
80
116
  statement_params=statement_params,
81
117
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -83,6 +119,8 @@ class ModelVersionSQLClient:
83
119
  def set_default_version(
84
120
  self,
85
121
  *,
122
+ database_name: Optional[sql_identifier.SqlIdentifier],
123
+ schema_name: Optional[sql_identifier.SqlIdentifier],
86
124
  model_name: sql_identifier.SqlIdentifier,
87
125
  version_name: sql_identifier.SqlIdentifier,
88
126
  statement_params: Optional[Dict[str, Any]] = None,
@@ -90,7 +128,7 @@ class ModelVersionSQLClient:
90
128
  query_result_checker.SqlResultValidator(
91
129
  self._session,
92
130
  (
93
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
131
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
94
132
  f"SET DEFAULT_VERSION = {version_name.identifier()}"
95
133
  ),
96
134
  statement_params=statement_params,
@@ -99,6 +137,8 @@ class ModelVersionSQLClient:
99
137
  def list_file(
100
138
  self,
101
139
  *,
140
+ database_name: Optional[sql_identifier.SqlIdentifier],
141
+ schema_name: Optional[sql_identifier.SqlIdentifier],
102
142
  model_name: sql_identifier.SqlIdentifier,
103
143
  version_name: sql_identifier.SqlIdentifier,
104
144
  file_path: pathlib.PurePosixPath,
@@ -110,7 +150,10 @@ class ModelVersionSQLClient:
110
150
 
111
151
  stage_location = (
112
152
  pathlib.PurePosixPath(
113
- self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
153
+ self.fully_qualified_object_name(database_name, schema_name, model_name),
154
+ "versions",
155
+ version_name.resolved(),
156
+ file_path,
114
157
  ).as_posix()
115
158
  + trailing_slash
116
159
  )
@@ -124,13 +167,15 @@ class ModelVersionSQLClient:
124
167
  f"List {_normalize_url_for_sql(stage_location_url)}",
125
168
  statement_params=statement_params,
126
169
  )
127
- .has_column("name")
170
+ .has_column("name", allow_empty=True)
128
171
  .validate()
129
172
  )
130
173
 
131
174
  def get_file(
132
175
  self,
133
176
  *,
177
+ database_name: Optional[sql_identifier.SqlIdentifier],
178
+ schema_name: Optional[sql_identifier.SqlIdentifier],
134
179
  model_name: sql_identifier.SqlIdentifier,
135
180
  version_name: sql_identifier.SqlIdentifier,
136
181
  file_path: pathlib.PurePosixPath,
@@ -138,7 +183,10 @@ class ModelVersionSQLClient:
138
183
  statement_params: Optional[Dict[str, Any]] = None,
139
184
  ) -> pathlib.Path:
140
185
  stage_location = pathlib.PurePosixPath(
141
- self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
186
+ self.fully_qualified_object_name(database_name, schema_name, model_name),
187
+ "versions",
188
+ version_name.resolved(),
189
+ file_path,
142
190
  ).as_posix()
143
191
  stage_location_url = ParseResult(
144
192
  scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
@@ -149,7 +197,7 @@ class ModelVersionSQLClient:
149
197
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
150
198
  options = {"parallel": 10}
151
199
  cursor = self._session._conn._cursor
152
- cursor._download(stage_location_url, str(target_path), options) # type: ignore[attr-defined]
200
+ cursor._download(stage_location_url, str(target_path), options) # type: ignore[union-attr]
153
201
  cursor.fetchall()
154
202
  else:
155
203
  query_result_checker.SqlResultValidator(
@@ -162,6 +210,8 @@ class ModelVersionSQLClient:
162
210
  def show_functions(
163
211
  self,
164
212
  *,
213
+ database_name: Optional[sql_identifier.SqlIdentifier],
214
+ schema_name: Optional[sql_identifier.SqlIdentifier],
165
215
  model_name: sql_identifier.SqlIdentifier,
166
216
  version_name: sql_identifier.SqlIdentifier,
167
217
  statement_params: Optional[Dict[str, Any]] = None,
@@ -169,7 +219,7 @@ class ModelVersionSQLClient:
169
219
  res = query_result_checker.SqlResultValidator(
170
220
  self._session,
171
221
  (
172
- f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_model_name(model_name)}"
222
+ f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
173
223
  f" VERSION {version_name.identifier()}"
174
224
  ),
175
225
  statement_params=statement_params,
@@ -180,15 +230,17 @@ class ModelVersionSQLClient:
180
230
  def set_comment(
181
231
  self,
182
232
  *,
183
- comment: str,
233
+ database_name: Optional[sql_identifier.SqlIdentifier],
234
+ schema_name: Optional[sql_identifier.SqlIdentifier],
184
235
  model_name: sql_identifier.SqlIdentifier,
185
236
  version_name: sql_identifier.SqlIdentifier,
237
+ comment: str,
186
238
  statement_params: Optional[Dict[str, Any]] = None,
187
239
  ) -> None:
188
240
  query_result_checker.SqlResultValidator(
189
241
  self._session,
190
242
  (
191
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
243
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
192
244
  f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
193
245
  ),
194
246
  statement_params=statement_params,
@@ -197,6 +249,8 @@ class ModelVersionSQLClient:
197
249
  def invoke_function_method(
198
250
  self,
199
251
  *,
252
+ database_name: Optional[sql_identifier.SqlIdentifier],
253
+ schema_name: Optional[sql_identifier.SqlIdentifier],
200
254
  model_name: sql_identifier.SqlIdentifier,
201
255
  version_name: sql_identifier.SqlIdentifier,
202
256
  method_name: sql_identifier.SqlIdentifier,
@@ -210,10 +264,12 @@ class ModelVersionSQLClient:
210
264
  INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
211
265
  with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
212
266
  else:
267
+ actual_database_name = database_name or self._database_name
268
+ actual_schema_name = schema_name or self._schema_name
213
269
  tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
214
270
  INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
215
- self._database_name.identifier(),
216
- self._schema_name.identifier(),
271
+ actual_database_name.identifier(),
272
+ actual_schema_name.identifier(),
217
273
  tmp_table_name,
218
274
  )
219
275
  input_df.write.save_as_table( # type: ignore[call-overload]
@@ -228,7 +284,8 @@ class ModelVersionSQLClient:
228
284
  module_version_alias = "MODEL_VERSION_ALIAS"
229
285
  with_statements.append(
230
286
  f"{module_version_alias} AS "
231
- f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
287
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
288
+ f" VERSION {version_name.identifier()}"
232
289
  )
233
290
 
234
291
  args_sql_list = []
@@ -267,6 +324,8 @@ class ModelVersionSQLClient:
267
324
  def invoke_table_function_method(
268
325
  self,
269
326
  *,
327
+ database_name: Optional[sql_identifier.SqlIdentifier],
328
+ schema_name: Optional[sql_identifier.SqlIdentifier],
270
329
  model_name: sql_identifier.SqlIdentifier,
271
330
  version_name: sql_identifier.SqlIdentifier,
272
331
  method_name: sql_identifier.SqlIdentifier,
@@ -281,10 +340,12 @@ class ModelVersionSQLClient:
281
340
  INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
282
341
  with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
283
342
  else:
343
+ actual_database_name = database_name or self._database_name
344
+ actual_schema_name = schema_name or self._schema_name
284
345
  tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
285
346
  INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
286
- self._database_name.identifier(),
287
- self._schema_name.identifier(),
347
+ actual_database_name.identifier(),
348
+ actual_schema_name.identifier(),
288
349
  tmp_table_name,
289
350
  )
290
351
  input_df.write.save_as_table( # type: ignore[call-overload]
@@ -297,7 +358,8 @@ class ModelVersionSQLClient:
297
358
  module_version_alias = "MODEL_VERSION_ALIAS"
298
359
  with_statements.append(
299
360
  f"{module_version_alias} AS "
300
- f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
361
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
362
+ f" VERSION {version_name.identifier()}"
301
363
  )
302
364
 
303
365
  partition_by = partition_column.identifier() if partition_column is not None else "1"
@@ -344,6 +406,8 @@ class ModelVersionSQLClient:
344
406
  self,
345
407
  metadata_dict: Dict[str, Any],
346
408
  *,
409
+ database_name: Optional[sql_identifier.SqlIdentifier],
410
+ schema_name: Optional[sql_identifier.SqlIdentifier],
347
411
  model_name: sql_identifier.SqlIdentifier,
348
412
  version_name: sql_identifier.SqlIdentifier,
349
413
  statement_params: Optional[Dict[str, Any]] = None,
@@ -352,8 +416,8 @@ class ModelVersionSQLClient:
352
416
  query_result_checker.SqlResultValidator(
353
417
  self._session,
354
418
  (
355
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
356
- f" SET METADATA=$${json_metadata}$$"
419
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
420
+ f" MODIFY VERSION {version_name.identifier()} SET METADATA=$${json_metadata}$$"
357
421
  ),
358
422
  statement_params=statement_params,
359
423
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -361,12 +425,17 @@ class ModelVersionSQLClient:
361
425
  def drop_version(
362
426
  self,
363
427
  *,
428
+ database_name: Optional[sql_identifier.SqlIdentifier],
429
+ schema_name: Optional[sql_identifier.SqlIdentifier],
364
430
  model_name: sql_identifier.SqlIdentifier,
365
431
  version_name: sql_identifier.SqlIdentifier,
366
432
  statement_params: Optional[Dict[str, Any]] = None,
367
433
  ) -> None:
368
434
  query_result_checker.SqlResultValidator(
369
435
  self._session,
370
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} DROP VERSION {version_name.identifier()}",
436
+ (
437
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
438
+ f" DROP VERSION {version_name.identifier()}"
439
+ ),
371
440
  statement_params=statement_params,
372
441
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,46 +1,20 @@
1
1
  from typing import Any, Dict, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
9
5
 
10
6
 
11
- class StageSQLClient:
12
- def __init__(
13
- self,
14
- session: session.Session,
15
- *,
16
- database_name: sql_identifier.SqlIdentifier,
17
- schema_name: sql_identifier.SqlIdentifier,
18
- ) -> None:
19
- self._session = session
20
- self._database_name = database_name
21
- self._schema_name = schema_name
22
-
23
- def __eq__(self, __value: object) -> bool:
24
- if not isinstance(__value, StageSQLClient):
25
- return False
26
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
27
-
28
- def fully_qualified_stage_name(
29
- self,
30
- stage_name: sql_identifier.SqlIdentifier,
31
- ) -> str:
32
- return identifier.get_schema_level_object_identifier(
33
- self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier()
34
- )
35
-
7
+ class StageSQLClient(_base._BaseSQLClient):
36
8
  def create_tmp_stage(
37
9
  self,
38
10
  *,
11
+ database_name: Optional[sql_identifier.SqlIdentifier],
12
+ schema_name: Optional[sql_identifier.SqlIdentifier],
39
13
  stage_name: sql_identifier.SqlIdentifier,
40
14
  statement_params: Optional[Dict[str, Any]] = None,
41
15
  ) -> None:
42
16
  query_result_checker.SqlResultValidator(
43
17
  self._session,
44
- f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}",
18
+ f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
45
19
  statement_params=statement_params,
46
20
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,52 +1,25 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import row, session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import row
9
6
 
10
7
 
11
- class ModuleTagSQLClient:
12
- def __init__(
13
- self,
14
- session: session.Session,
15
- *,
16
- database_name: sql_identifier.SqlIdentifier,
17
- schema_name: sql_identifier.SqlIdentifier,
18
- ) -> None:
19
- self._session = session
20
- self._database_name = database_name
21
- self._schema_name = schema_name
22
-
23
- def __eq__(self, __value: object) -> bool:
24
- if not isinstance(__value, ModuleTagSQLClient):
25
- return False
26
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
27
-
28
- def fully_qualified_module_name(
29
- self,
30
- module_name: sql_identifier.SqlIdentifier,
31
- ) -> str:
32
- return identifier.get_schema_level_object_identifier(
33
- self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
34
- )
35
-
8
+ class ModuleTagSQLClient(_base._BaseSQLClient):
36
9
  def set_tag_on_model(
37
10
  self,
38
- model_name: sql_identifier.SqlIdentifier,
39
11
  *,
40
- tag_database_name: sql_identifier.SqlIdentifier,
41
- tag_schema_name: sql_identifier.SqlIdentifier,
12
+ database_name: Optional[sql_identifier.SqlIdentifier],
13
+ schema_name: Optional[sql_identifier.SqlIdentifier],
14
+ model_name: sql_identifier.SqlIdentifier,
15
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
16
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
42
17
  tag_name: sql_identifier.SqlIdentifier,
43
18
  tag_value: str,
44
19
  statement_params: Optional[Dict[str, Any]] = None,
45
20
  ) -> None:
46
- fq_model_name = self.fully_qualified_module_name(model_name)
47
- fq_tag_name = identifier.get_schema_level_object_identifier(
48
- tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
49
- )
21
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
22
+ fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
50
23
  query_result_checker.SqlResultValidator(
51
24
  self._session,
52
25
  f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = $${tag_value}$$",
@@ -55,17 +28,17 @@ class ModuleTagSQLClient:
55
28
 
56
29
  def unset_tag_on_model(
57
30
  self,
58
- model_name: sql_identifier.SqlIdentifier,
59
31
  *,
60
- tag_database_name: sql_identifier.SqlIdentifier,
61
- tag_schema_name: sql_identifier.SqlIdentifier,
32
+ database_name: Optional[sql_identifier.SqlIdentifier],
33
+ schema_name: Optional[sql_identifier.SqlIdentifier],
34
+ model_name: sql_identifier.SqlIdentifier,
35
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
36
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
62
37
  tag_name: sql_identifier.SqlIdentifier,
63
38
  statement_params: Optional[Dict[str, Any]] = None,
64
39
  ) -> None:
65
- fq_model_name = self.fully_qualified_module_name(model_name)
66
- fq_tag_name = identifier.get_schema_level_object_identifier(
67
- tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
68
- )
40
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
41
+ fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
69
42
  query_result_checker.SqlResultValidator(
70
43
  self._session,
71
44
  f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
@@ -74,21 +47,21 @@ class ModuleTagSQLClient:
74
47
 
75
48
  def get_tag_value(
76
49
  self,
77
- module_name: sql_identifier.SqlIdentifier,
78
50
  *,
79
- tag_database_name: sql_identifier.SqlIdentifier,
80
- tag_schema_name: sql_identifier.SqlIdentifier,
51
+ database_name: Optional[sql_identifier.SqlIdentifier],
52
+ schema_name: Optional[sql_identifier.SqlIdentifier],
53
+ model_name: sql_identifier.SqlIdentifier,
54
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
55
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
81
56
  tag_name: sql_identifier.SqlIdentifier,
82
57
  statement_params: Optional[Dict[str, Any]] = None,
83
58
  ) -> row.Row:
84
- fq_module_name = self.fully_qualified_module_name(module_name)
85
- fq_tag_name = identifier.get_schema_level_object_identifier(
86
- tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
87
- )
59
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
60
+ fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
88
61
  return (
89
62
  query_result_checker.SqlResultValidator(
90
63
  self._session,
91
- f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_module_name}$$, 'MODULE') AS TAG_VALUE",
64
+ f"SELECT SYSTEM$GET_TAG($${fq_tag_name}$$, $${fq_model_name}$$, 'MODULE') AS TAG_VALUE",
92
65
  statement_params=statement_params,
93
66
  )
94
67
  .has_dimensions(expected_rows=1, expected_cols=1)
@@ -98,16 +71,19 @@ class ModuleTagSQLClient:
98
71
 
99
72
  def get_tag_list(
100
73
  self,
101
- module_name: sql_identifier.SqlIdentifier,
102
74
  *,
75
+ database_name: Optional[sql_identifier.SqlIdentifier],
76
+ schema_name: Optional[sql_identifier.SqlIdentifier],
77
+ model_name: sql_identifier.SqlIdentifier,
103
78
  statement_params: Optional[Dict[str, Any]] = None,
104
79
  ) -> List[row.Row]:
105
- fq_module_name = self.fully_qualified_module_name(module_name)
80
+ fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
81
+ actual_database_name = database_name or self._database_name
106
82
  return (
107
83
  query_result_checker.SqlResultValidator(
108
84
  self._session,
109
85
  f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
110
- FROM TABLE({self._database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_module_name}$$, 'MODULE'))""",
86
+ FROM TABLE({actual_database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES($${fq_model_name}$$, 'MODULE'))""",
111
87
  statement_params=statement_params,
112
88
  )
113
89
  .has_column("TAG_DATABASE", allow_empty=True)
@@ -11,7 +11,7 @@ from packaging import requirements
11
11
  from typing_extensions import deprecated
12
12
 
13
13
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
- from snowflake.ml._internal.lineage import data_source
14
+ from snowflake.ml._internal.lineage import data_source, lineage_utils
15
15
  from snowflake.ml.model import model_signature, type_hints as model_types
16
16
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
17
17
  from snowflake.ml.model._packager import model_packager
@@ -136,7 +136,7 @@ class ModelComposer:
136
136
  model_meta=self.packager.meta,
137
137
  model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
138
138
  options=options,
139
- data_sources=self._get_data_sources(model),
139
+ data_sources=self._get_data_sources(model, sample_input_data),
140
140
  )
141
141
 
142
142
  file_utils.upload_directory_to_stage(
@@ -179,8 +179,12 @@ class ModelComposer:
179
179
  mp.load(meta_only=meta_only, options=options)
180
180
  return mp
181
181
 
182
- def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]:
183
- data_sources = getattr(model, "_data_sources", None)
182
+ def _get_data_sources(
183
+ self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
184
+ ) -> Optional[List[data_source.DataSource]]:
185
+ data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
186
+ if not data_sources and sample_input_data is not None:
187
+ data_sources = getattr(sample_input_data, lineage_utils.DATA_SOURCES_ATTR, None)
184
188
  if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
185
189
  return data_sources
186
190
  return None
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import pathlib
2
3
  import tempfile
3
4
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
5
 
@@ -45,7 +46,7 @@ def _parse_mlflow_env(model_uri: str, env: model_env.ModelEnv) -> model_env.Mode
45
46
  if not os.path.exists(conda_env_file_path):
46
47
  raise ValueError("Cannot load MLFlow model dependencies.")
47
48
 
48
- env.load_from_conda_file(conda_env_file_path)
49
+ env.load_from_conda_file(pathlib.Path(conda_env_file_path))
49
50
 
50
51
  return env
51
52
 
@@ -281,9 +281,7 @@ class ModelMetadata:
281
281
  "cpu": model_runtime.ModelRuntime("cpu", self.env),
282
282
  }
283
283
  if self.env.cuda_version:
284
- runtimes.update(
285
- {"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True, server_availability_source="conda")}
286
- )
284
+ runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True)})
287
285
  return runtimes
288
286
 
289
287
  def save(self, model_dir_path: str) -> None:
@@ -1,11 +1,11 @@
1
1
  import copy
2
2
  import pathlib
3
3
  import warnings
4
- from typing import List, Literal, Optional
4
+ from typing import List, Optional
5
5
 
6
6
  from packaging import requirements
7
7
 
8
- from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
8
+ from snowflake.ml._internal import env_utils, file_utils
9
9
  from snowflake.ml.model._packager.model_env import model_env
10
10
  from snowflake.ml.model._packager.model_meta import model_meta_schema
11
11
  from snowflake.ml.model._packager.model_runtime import (
@@ -37,7 +37,6 @@ class ModelRuntime:
37
37
  env: model_env.ModelEnv,
38
38
  imports: Optional[List[pathlib.PurePosixPath]] = None,
39
39
  is_gpu: bool = False,
40
- server_availability_source: Literal["snowflake", "conda"] = "snowflake",
41
40
  loading_from_file: bool = False,
42
41
  ) -> None:
43
42
  self.name = name
@@ -48,30 +47,7 @@ class ModelRuntime:
48
47
  return
49
48
 
50
49
  snowml_pkg_spec = f"{env_utils.SNOWPARK_ML_PKG_NAME}=={self.runtime_env.snowpark_ml_version}"
51
- if self.runtime_env._snowpark_ml_version.local:
52
- self.embed_local_ml_library = True
53
- else:
54
- if server_availability_source == "snowflake":
55
- snowml_server_availability = (
56
- len(
57
- env_utils.get_matched_package_versions_in_information_schema_with_active_session(
58
- reqs=[requirements.Requirement(snowml_pkg_spec)],
59
- python_version=snowml_env.PYTHON_VERSION,
60
- ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
61
- )
62
- >= 1
63
- )
64
- else:
65
- snowml_server_availability = (
66
- len(
67
- env_utils.get_matched_package_versions_in_snowflake_conda_channel(
68
- req=requirements.Requirement(snowml_pkg_spec),
69
- python_version=snowml_env.PYTHON_VERSION,
70
- )
71
- )
72
- >= 1
73
- )
74
- self.embed_local_ml_library = not snowml_server_availability
50
+ self.embed_local_ml_library = self.runtime_env._snowpark_ml_version.local
75
51
 
76
52
  additional_package = (
77
53
  _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES if self.embed_local_ml_library else [snowml_pkg_spec]