snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -604,7 +604,7 @@ class FeatureStore:
         logger.info(f"Registered FeatureView {feature_view.name}/{version} successfully.")
         return self.get_feature_view(feature_view.name, str(version))
 
-    @dispatch_decorator()
+    @overload
     def update_feature_view(
         self,
         name: str,
@@ -613,13 +613,37 @@ class FeatureStore:
         refresh_freq: Optional[str] = None,
         warehouse: Optional[str] = None,
         desc: Optional[str] = None,
+    ) -> FeatureView:
+        ...
+
+    @overload
+    def update_feature_view(
+        self,
+        name: FeatureView,
+        version: Optional[str] = None,
+        *,
+        refresh_freq: Optional[str] = None,
+        warehouse: Optional[str] = None,
+        desc: Optional[str] = None,
+    ) -> FeatureView:
+        ...
+
+    @dispatch_decorator()  # type: ignore[misc]
+    def update_feature_view(
+        self,
+        name: Union[FeatureView, str],
+        version: Optional[str] = None,
+        *,
+        refresh_freq: Optional[str] = None,
+        warehouse: Optional[str] = None,
+        desc: Optional[str] = None,
     ) -> FeatureView:
         """Update a registered feature view.
             Check feature_view.py for which fields are allowed to be updated after registration.
 
         Args:
-            name: name of the FeatureView to be updated.
-            version: version of the FeatureView to be updated.
+            name: FeatureView object or name to suspend.
+            version: Optional version of feature view. Must set when argument feature_view is a str.
             refresh_freq: updated refresh frequency.
             warehouse: updated warehouse.
             desc: description of feature view.
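A usage sketch for the new overloads, assuming an existing FeatureStore fs with a feature view registered as "MY_FV" at version "1.0" (the names, version, and refresh settings here are illustrative, not taken from this diff):

# Update by name; the version string is required in this form.
fv = fs.update_feature_view("MY_FV", "1.0", refresh_freq="1 day", desc="daily refresh")

# Or pass the FeatureView object itself and omit the version.
registered_fv = fs.get_feature_view("MY_FV", "1.0")
fv = fs.update_feature_view(registered_fv, refresh_freq="1 day", desc="daily refresh")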
@@ -661,7 +685,7 @@ class FeatureStore:
             SnowflakeMLException: [RuntimeError] If FeatureView is not managed and refresh_freq is defined.
             SnowflakeMLException: [RuntimeError] Failed to update feature view.
         """
-        feature_view = self.get_feature_view(name=name, version=version)
+        feature_view = self._validate_feature_view_name_and_version_input(name, version)
         new_desc = desc if desc is not None else feature_view.desc
 
         if feature_view.status == FeatureViewStatus.STATIC:
@@ -696,7 +720,7 @@ class FeatureStore:
                     f"Update feature view {feature_view.name}/{feature_view.version} failed: {e}"
                 ),
             ) from e
-        return self.get_feature_view(name=name, version=version)
+        return self.get_feature_view(name=feature_view.name, version=str(feature_view.version))
 
     @overload
     def read_feature_view(self, feature_view: str, version: str) -> DataFrame:
@@ -2121,7 +2145,7 @@ class FeatureStore:
         if "." not in name:
             return f"{self._config.full_schema_path}.{name}"
 
-        db_name, schema_name, object_name, _ = identifier.parse_schema_level_object_identifier(name)
+        db_name, schema_name, object_name = identifier.parse_schema_level_object_identifier(name)
         return "{}.{}.{}".format(
             db_name or self._config.database,
             schema_name or self._config.schema,
@@ -2186,11 +2210,7 @@ class FeatureStore:
         if len(fv_maps.keys()) == 0:
             return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA)
 
-        filters = (
-            [lambda d: d["entityName"].startswith(feature_view_name.resolved())]  # type: ignore[union-attr]
-            if feature_view_name
-            else None
-        )
+        filters = [lambda d: d["entityName"].startswith(feature_view_name.resolved())] if feature_view_name else None
         res = self._lookup_tagged_objects(self._get_entity_name(entity_name), filters)
 
         output_values: List[List[Any]] = []
@@ -2281,16 +2301,20 @@ class FeatureStore:
                 timestamp_col=timestamp_col,
                 desc=desc,
                 version=version,
-                status=FeatureViewStatus(row["scheduling_state"])
-                if len(row["scheduling_state"]) > 0
-                else FeatureViewStatus.MASKED,
+                status=(
+                    FeatureViewStatus(row["scheduling_state"])
+                    if len(row["scheduling_state"]) > 0
+                    else FeatureViewStatus.MASKED
+                ),
                 feature_descs=self._fetch_column_descs("DYNAMIC TABLE", fv_name),
                 refresh_freq=row["target_lag"],
                 database=self._config.database.identifier(),
                 schema=self._config.schema.identifier(),
-                warehouse=SqlIdentifier(row["warehouse"], case_sensitive=True).identifier()
-                if len(row["warehouse"]) > 0
-                else None,
+                warehouse=(
+                    SqlIdentifier(row["warehouse"], case_sensitive=True).identifier()
+                    if len(row["warehouse"]) > 0
+                    else None
+                ),
                 refresh_mode=row["refresh_mode"],
                 refresh_mode_reason=row["refresh_mode_reason"],
                 owner=row["owner"],
@@ -706,7 +706,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             >>> ).attach_feature_desc({"AGE": "my age", "TITLE": '"my title"'})
             >>> fv = fs.register_feature_view(draft_fv, '1.0')
             <BLANKLINE>
-            fv.to_df().show()
+            >>> fv.to_df().show()
             ----------------------------------------------------------------...
             |"NAME"  |"ENTITIES"  |"TIMESTAMP_COL"  |"DESC"  |
             ----------------------------------------------------------------...
@@ -801,7 +801,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
 
     @staticmethod
     def _load_from_lineage_node(session: Session, name: str, version: str) -> FeatureView:
-        db_name, feature_store_name, feature_view_name, _ = identifier.parse_schema_level_object_identifier(name)
+        db_name, feature_store_name, feature_view_name = identifier.parse_schema_level_object_identifier(name)
 
         session_warehouse = session.get_current_warehouse()
 
@@ -35,7 +35,7 @@ class SFEmbeddedStageFileSystem(stage_fs.SFStageFileSystem):
         **kwargs: Any,
     ) -> None:
 
-        (db, schema, object_name, _) = identifier.parse_schema_level_object_identifier(name)
+        (db, schema, object_name) = identifier.parse_schema_level_object_identifier(name)
         self._name = name  # TODO: Require or resolve FQN
         self._domain = domain
 
@@ -538,7 +538,7 @@ def _validate_target_stage_loc(snowpark_session: snowpark.Session, target_stage_
             original_exception=fileset_errors.FileSetLocationError('FileSet location should start with "@".'),
         )
     try:
-        db, schema, stage, _ = identifier.parse_schema_level_object_identifier(target_stage_loc[1:])
+        db, schema, stage, _ = identifier.parse_snowflake_stage_path(target_stage_loc[1:])
         if db is None or schema is None:
             raise ValueError("The stage path should be in the form '@<database>.<schema>.<stage>/*'")
         df_stages = snowpark_session.sql(f"Show stages like '{stage}' in SCHEMA {db}.{schema}")
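Across these hunks the identifier helpers split responsibilities: parse_schema_level_object_identifier now unpacks into three parts (database, schema, object), while stage locations that may carry a trailing path go through parse_snowflake_stage_path, which still yields a fourth component. A hedged sketch, with illustrative values and the return shapes inferred only from the call sites above:

from snowflake.ml._internal.utils import identifier

db, schema, obj = identifier.parse_schema_level_object_identifier("MY_DB.MY_SCHEMA.MY_OBJECT")
db, schema, stage, path = identifier.parse_snowflake_stage_path("MY_DB.MY_SCHEMA.MY_STAGE/some/file.csv")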
@@ -15,6 +15,7 @@ from snowflake.ml._internal.exceptions import (
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.fileset import stage_fs
 from snowflake.ml.utils import connection_params
+from snowflake.snowpark import context, exceptions as snowpark_exceptions
 
 PROTOCOL_NAME = "sfc"
 
@@ -84,7 +85,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         """
         if kwargs.get(_RECREATE_FROM_SERIALIZED):
             try:
-                snowpark_session = self._create_default_session()
+                snowpark_session = self._get_default_session()
             except Exception as e:
                 raise snowml_exceptions.SnowflakeMLException(
                     error_code=error_codes.SNOWML_DESERIALIZATION_FAILED,
@@ -103,7 +104,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
 
         super().__init__(**kwargs)
 
-    def _create_default_session(self) -> snowpark.Session:
+    def _get_default_session(self) -> snowpark.Session:
         """Create a Snowpark Session from default login options.
 
         Returns:
@@ -114,6 +115,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
             ValueError: Snowflake Connection could not be created.
 
         """
+        try:
+            return context.get_active_session()
+        except snowpark_exceptions.SnowparkSessionException:
+            pass
+
         try:
             snowflake_config = connection_params.SnowflakeLoginOptions()
         except Exception as e:
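For orientation, the fallback added to _get_default_session (used when an SFFileSystem is re-created from its serialized form) boils down to the following pattern; a minimal standalone sketch assuming snowflake-snowpark-python is installed:

from snowflake.snowpark import context, exceptions as snowpark_exceptions

try:
    session = context.get_active_session()            # reuse the ambient session if one exists
except snowpark_exceptions.SnowparkSessionException:
    session = None                                    # the filesystem then falls back to SnowflakeLoginOptions()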
@@ -328,7 +334,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
                 ),
             )
         try:
-            res = identifier.parse_schema_level_object_identifier(path[1:])
+            res = identifier.parse_snowflake_stage_path(path[1:])
             if res[1] is None or res[0] is None or (res[3] and not res[3].startswith("/")):
                 raise ValueError("Invalid path. Missing database or schema identifier.")
             logging.debug(f"Parsed path: {res}")
@@ -306,6 +306,23 @@ class ModelVersion(lineage_node.LineageNode):
             statement_params=statement_params,
         )
 
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+    )
+    def get_model_objective(self) -> model_types.ModelObjective:
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+        return self._model_ops.get_model_objective(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+            statement_params=statement_params,
+        )
+
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
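A hedged usage sketch for the new accessor, assuming a registry with an already-logged model and an existing Snowpark session; "MY_MODEL" and "V1" are placeholders:

from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.get_model("MY_MODEL").version("V1")
print(mv.get_model_objective())     # e.g. ModelObjective.REGRESSION, or UNKNOWN for models logged by older versions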
@@ -606,8 +623,8 @@ class ModelVersion(lineage_node.LineageNode):
             "image_repo_database",
             "image_repo_schema",
             "image_repo",
-            "image_name",
             "gpu_requests",
+            "num_workers",
         ],
     )
     def create_service(
@@ -617,11 +634,10 @@ class ModelVersion(lineage_node.LineageNode):
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
         image_repo: str,
-        image_name: Optional[str] = None,
         ingress_enabled: bool = False,
-        min_instances: int = 1,
         max_instances: int = 1,
         gpu_requests: Optional[str] = None,
+        num_workers: Optional[int] = None,
         force_rebuild: bool = False,
         build_external_access_integration: str,
     ) -> str:
@@ -635,12 +651,12 @@ class ModelVersion(lineage_node.LineageNode):
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
                 or schema of the model will be used.
-            image_name: The name of the model inference image. Use a generated name if None.
             ingress_enabled: Whether to enable ingress.
-            min_instances: The minimum number of inference service instances to run.
             max_instances: The maximum number of inference service instances to run.
             gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
                 if None.
+            num_workers: The number of workers (replicas of models) to run the inference service.
+                Auto determined if None.
             force_rebuild: Whether to force a model inference image rebuild.
             build_external_access_integration: The external access integration for image build.
 
@@ -670,11 +686,10 @@ class ModelVersion(lineage_node.LineageNode):
             image_repo_database_name=image_repo_db_id,
             image_repo_schema_name=image_repo_schema_id,
             image_repo_name=image_repo_id,
-            image_name=sql_identifier.SqlIdentifier(image_name) if image_name else None,
             ingress_enabled=ingress_enabled,
-            min_instances=min_instances,
             max_instances=max_instances,
             gpu_requests=gpu_requests,
+            num_workers=num_workers,
             force_rebuild=force_rebuild,
             build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
             statement_params=statement_params,
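Putting the signature change together, a 1.6.2 call looks roughly like this; a sketch in which mv is a ModelVersion and the pool, repository, and access-integration names are placeholders that must already exist:

service = mv.create_service(
    service_name="MY_INFERENCE_SERVICE",
    service_compute_pool="MY_SERVICE_POOL",
    image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
    ingress_enabled=True,
    max_instances=1,
    gpu_requests=None,                      # CPU inference
    num_workers=2,                          # new in 1.6.2; image_name and min_instances were removed
    force_rebuild=False,
    build_external_access_integration="MY_ACCESS_INTEGRATION",
)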
@@ -554,15 +554,14 @@ class ModelOperator:
             res[function_name] = target_method
         return res
 
-    def get_functions(
+    def _fetch_model_spec(
         self,
-        *,
         database_name: Optional[sql_identifier.SqlIdentifier],
         schema_name: Optional[sql_identifier.SqlIdentifier],
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) -> List[model_manifest_schema.ModelFunctionInfo]:
+    ) -> model_meta_schema.ModelMetadataDict:
         raw_model_spec_res = self._model_client.show_versions(
             database_name=database_name,
             schema_name=schema_name,
@@ -573,6 +572,43 @@ class ModelOperator:
         )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
         model_spec_dict = yaml.safe_load(raw_model_spec_res)
         model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
+        return model_spec
+
+    def get_model_objective(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> type_hints.ModelObjective:
+        model_spec = self._fetch_model_spec(
+            database_name=database_name,
+            schema_name=schema_name,
+            model_name=model_name,
+            version_name=version_name,
+            statement_params=statement_params,
+        )
+        model_objective_val = model_spec.get("model_objective", type_hints.ModelObjective.UNKNOWN.value)
+        return type_hints.ModelObjective(model_objective_val)
+
+    def get_functions(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> List[model_manifest_schema.ModelFunctionInfo]:
+        model_spec = self._fetch_model_spec(
+            database_name=database_name,
+            schema_name=schema_name,
+            model_name=model_name,
+            version_name=version_name,
+            statement_params=statement_params,
+        )
         show_functions_res = self._model_version_client.show_functions(
             database_name=database_name,
             schema_name=schema_name,
@@ -1,15 +1,45 @@
+import dataclasses
+import hashlib
+import logging
 import pathlib
+import queue
+import sys
 import tempfile
-from typing import Any, Dict, Optional
+import threading
+import time
+import uuid
+from typing import Any, Dict, List, Optional, Tuple, cast
 
+from snowflake import snowpark
 from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
-from snowflake.snowpark import session
+from snowflake.snowpark import exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
+def get_logger(logger_name: str) -> logging.Logger:
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setLevel(logging.INFO)
+    handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
+    logger.addHandler(handler)
+    return logger
+
+
+logger = get_logger(__name__)
+logger.propagate = False
+
+
+@dataclasses.dataclass
+class ServiceLogInfo:
+    service_name: str
+    container_name: str
+    instance_id: str = "0"
+
+
 class ServiceOperator:
     """Service operator for container services logic."""
 
@@ -62,11 +92,10 @@ class ServiceOperator:
         image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
         image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
         image_repo_name: sql_identifier.SqlIdentifier,
-        image_name: Optional[sql_identifier.SqlIdentifier],
         ingress_enabled: bool,
-        min_instances: int,
         max_instances: int,
         gpu_requests: Optional[str],
+        num_workers: Optional[int],
         force_rebuild: bool,
         build_external_access_integration: sql_identifier.SqlIdentifier,
         statement_params: Optional[Dict[str, Any]] = None,
@@ -96,11 +125,10 @@ class ServiceOperator:
             image_repo_database_name=image_repo_database_name,
             image_repo_schema_name=image_repo_schema_name,
             image_repo_name=image_repo_name,
-            image_name=image_name,
             ingress_enabled=ingress_enabled,
-            min_instances=min_instances,
             max_instances=max_instances,
             gpu=gpu_requests,
+            num_workers=num_workers,
            force_rebuild=force_rebuild,
            external_access_integration=build_external_access_integration,
        )
@@ -111,11 +139,174 @@ class ServiceOperator:
             statement_params=statement_params,
         )
 
+        # check if the inference service is already running
+        try:
+            model_inference_service_status, _ = self._service_client.get_service_status(
+                service_name=service_name,
+                include_message=False,
+                statement_params=statement_params,
+            )
+            model_inference_service_exists = model_inference_service_status == service_sql.ServiceStatus.READY
+        except exceptions.SnowparkSQLException:
+            model_inference_service_exists = False
+
         # deploy the model service
-        self._service_client.deploy_model(
+        query_id, async_job = self._service_client.deploy_model(
             stage_path=stage_path,
             model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
             statement_params=statement_params,
         )
 
+        # stream service logs in a thread
+        services = [
+            ServiceLogInfo(service_name=self._get_model_build_service_name(query_id), container_name="model-build"),
+            ServiceLogInfo(service_name=service_name, container_name="model-inference"),
+        ]
+        exception_queue: queue.Queue = queue.Queue()  # type: ignore[type-arg]
+        log_thread = self._start_service_log_streaming(
+            async_job, services, model_inference_service_exists, exception_queue, statement_params
+        )
+        log_thread.join()
+
+        try:
+            # non-blocking check for an exception
+            exception = exception_queue.get(block=False)
+            if exception:
+                raise exception
+        except queue.Empty:
+            pass
+
         return service_name
+
+    def _start_service_log_streaming(
+        self,
+        async_job: snowpark.AsyncJob,
+        services: List[ServiceLogInfo],
+        model_inference_service_exists: bool,
+        exception_queue: queue.Queue,  # type: ignore[type-arg]
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> threading.Thread:
+        """Start the service log streaming in a separate thread."""
+        log_thread = threading.Thread(
+            target=self._stream_service_logs,
+            args=(async_job, services, model_inference_service_exists, exception_queue, statement_params),
+        )
+        log_thread.start()
+        return log_thread
+
+    def _stream_service_logs(
+        self,
+        async_job: snowpark.AsyncJob,
+        services: List[ServiceLogInfo],
+        model_inference_service_exists: bool,
+        exception_queue: queue.Queue,  # type: ignore[type-arg]
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Stream service logs while the async job is running."""
+
+        def fetch_logs(service_name: str, container_name: str, offset: int) -> Tuple[str, int]:
+            service_logs = self._service_client.get_service_logs(
+                service_name=service_name,
+                container_name=container_name,
+                statement_params=statement_params,
+            )
+
+            # return only new logs starting after the offset
+            if len(service_logs) > offset:
+                new_logs = service_logs[offset:]
+                new_offset = len(service_logs)
+            else:
+                new_logs = ""
+                new_offset = offset
+
+            return new_logs, new_offset
+
+        is_model_build_service_done = False
+        log_offset = 0
+        model_build_service, model_inference_service = services[0], services[1]
+        service_name, container_name = model_build_service.service_name, model_build_service.container_name
+        # BuildJobName
+        service_logger = get_logger(service_name)
+        service_logger.propagate = False
+        while not async_job.is_done():
+            if model_inference_service_exists:
+                time.sleep(5)
+                continue
+
+            try:
+                block_size = 180
+                service_status, message = self._service_client.get_service_status(
+                    service_name=service_name, include_message=True, statement_params=statement_params
+                )
+                logger.info(f"Inference service {service_name} is {service_status.value}.")
+
+                new_logs, new_offset = fetch_logs(service_name, container_name, log_offset)
+                if new_logs:
+                    service_logger.info(new_logs)
+                    log_offset = new_offset
+
+                # check if model build service is done
+                if not is_model_build_service_done:
+                    service_status, _ = self._service_client.get_service_status(
+                        service_name=model_build_service.service_name,
+                        include_message=False,
+                        statement_params=statement_params,
+                    )
+
+                    if service_status == service_sql.ServiceStatus.DONE:
+                        is_model_build_service_done = True
+                        log_offset = 0
+                        service_name = model_inference_service.service_name
+                        container_name = model_inference_service.container_name
+                        # InferenceServiceName-InstanceId
+                        service_logger = get_logger(f"{service_name}-{model_inference_service.instance_id}")
+                        service_logger.propagate = False
+                        logger.info(f"Model build service {model_build_service.service_name} complete.")
+                        logger.info("-" * block_size)
+            except ValueError:
+                logger.warning(f"Unknown service status: {service_status.value}")
+            except Exception as ex:
+                logger.warning(f"Caught an exception when logging: {repr(ex)}")
+
+            time.sleep(5)
+
+        if model_inference_service_exists:
+            logger.info(f"Inference service {model_inference_service.service_name} is already RUNNING.")
+        else:
+            self._finalize_logs(service_logger, services[-1], log_offset, statement_params)
+
+        # catch exceptions from the deploy model execution
+        try:
+            res = cast(List[row.Row], async_job.result())
+            logger.info(f"Model deployment for inference service {model_inference_service.service_name} complete.")
+            logger.info(res[0][0])
+        except Exception as ex:
+            exception_queue.put(ex)
+
+    def _finalize_logs(
+        self,
+        service_logger: logging.Logger,
+        service: ServiceLogInfo,
+        offset: int,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Fetch service logs after the async job is done to ensure no logs are missed."""
+        try:
+            service_logs = self._service_client.get_service_logs(
+                service_name=service.service_name,
+                container_name=service.container_name,
+                statement_params=statement_params,
+            )
+
+            if len(service_logs) > offset:
+                service_logger.info(service_logs[offset:])
+        except Exception as ex:
+            logger.warning(f"Caught an exception when logging: {repr(ex)}")
+
+    @staticmethod
+    def _get_model_build_service_name(query_id: str) -> str:
+        """Get the model build service name through the server-side logic."""
+        most_significant_bits = uuid.UUID(query_id).int >> 64
+        md5_hash = hashlib.md5(str(most_significant_bits).encode()).hexdigest()
+        identifier = md5_hash[:6]
+        return ("model_build_" + identifier).upper()
@@ -34,11 +34,10 @@ class ModelDeploymentSpec:
         image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
         image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
         image_repo_name: sql_identifier.SqlIdentifier,
-        image_name: Optional[sql_identifier.SqlIdentifier],
         ingress_enabled: bool,
-        min_instances: int,
         max_instances: int,
         gpu: Optional[str],
+        num_workers: Optional[int],
         force_rebuild: bool,
         external_access_integration: sql_identifier.SqlIdentifier,
     ) -> None:
@@ -61,8 +60,6 @@ class ModelDeploymentSpec:
             force_rebuild=force_rebuild,
             external_access_integrations=[external_access_integration.identifier()],
         )
-        if image_name:
-            image_build_dict["image_name"] = image_name.identifier()
 
         # service spec
         saved_service_database = service_database_name or database_name
@@ -74,12 +71,14 @@ class ModelDeploymentSpec:
             name=fq_service_name,
             compute_pool=service_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
-            min_instances=min_instances,
             max_instances=max_instances,
         )
         if gpu:
             service_dict["gpu"] = gpu
 
+        if num_workers:
+            service_dict["num_workers"] = num_workers
+
         # model deployment spec
         model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
             models=[model_dict],
@@ -11,7 +11,6 @@ class ModelDict(TypedDict):
 class ImageBuildDict(TypedDict):
     compute_pool: Required[str]
     image_repo: Required[str]
-    image_name: NotRequired[str]
     force_rebuild: Required[bool]
     external_access_integrations: Required[List[str]]
 
@@ -20,9 +19,9 @@ class ServiceDict(TypedDict):
     name: Required[str]
     compute_pool: Required[str]
     ingress_enabled: Required[bool]
-    min_instances: Required[int]
     max_instances: Required[int]
     gpu: NotRequired[str]
+    num_workers: NotRequired[int]
 
 
 class ModelDeploymentSpecDict(TypedDict):
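To illustrate the schema change, a service entry that validates against the updated ServiceDict might look like this; a sketch whose names and values are placeholders:

service_dict: ServiceDict = {
    "name": "MY_DB.MY_SCHEMA.MY_SERVICE",
    "compute_pool": "MY_SERVICE_POOL",
    "ingress_enabled": True,
    "max_instances": 1,
    "gpu": "1",           # optional, only emitted when a GPU limit was requested
    "num_workers": 2,     # optional, new in 1.6.2; min_instances has been removed
}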