snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/telemetry.py +4 -2
  7. snowflake/ml/_internal/type_utils.py +3 -3
  8. snowflake/ml/_internal/utils/import_utils.py +31 -0
  9. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  10. snowflake/ml/data/__init__.py +5 -0
  11. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  12. snowflake/ml/data/data_connector.py +1 -1
  13. snowflake/ml/data/torch_utils.py +33 -14
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  16. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  18. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  19. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  20. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  21. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  22. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  23. snowflake/ml/feature_store/feature_store.py +1 -2
  24. snowflake/ml/feature_store/feature_view.py +5 -1
  25. snowflake/ml/model/_client/model/model_version_impl.py +145 -11
  26. snowflake/ml/model/_client/ops/model_ops.py +56 -16
  27. snowflake/ml/model/_client/ops/service_ops.py +46 -30
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  30. snowflake/ml/model/_client/sql/service.py +25 -1
  31. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  34. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  36. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  37. snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
  38. snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
  39. snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  41. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
  42. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  43. snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
  44. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
  45. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  46. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  47. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  48. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  49. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  50. snowflake/ml/model/_packager/model_packager.py +0 -11
  51. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  52. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  53. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
  54. snowflake/ml/model/_signatures/core.py +63 -16
  55. snowflake/ml/model/_signatures/pandas_handler.py +87 -27
  56. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  57. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  58. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  59. snowflake/ml/model/_signatures/utils.py +4 -0
  60. snowflake/ml/model/custom_model.py +47 -7
  61. snowflake/ml/model/model_signature.py +40 -9
  62. snowflake/ml/model/type_hints.py +9 -1
  63. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  64. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  65. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  66. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  67. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  68. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  69. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  70. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  71. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  72. snowflake/ml/modeling/cluster/k_means.py +14 -19
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  74. snowflake/ml/modeling/cluster/optics.py +6 -6
  75. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  76. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  77. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  78. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  79. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  80. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  81. snowflake/ml/modeling/covariance/oas.py +1 -1
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  85. snowflake/ml/modeling/decomposition/pca.py +28 -15
  86. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  87. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  88. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  89. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  90. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  91. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  92. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  93. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  94. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  96. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  104. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  106. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lars.py +0 -10
  109. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  119. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  120. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  122. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  123. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  124. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  125. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  126. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  127. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  128. snowflake/ml/modeling/manifold/isomap.py +1 -1
  129. snowflake/ml/modeling/manifold/mds.py +3 -3
  130. snowflake/ml/modeling/manifold/tsne.py +10 -4
  131. snowflake/ml/modeling/metrics/classification.py +12 -16
  132. snowflake/ml/modeling/metrics/ranking.py +3 -3
  133. snowflake/ml/modeling/metrics/regression.py +3 -3
  134. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  135. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  138. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  139. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  140. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  141. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  142. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  143. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  144. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  145. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  146. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  147. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  148. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  149. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  150. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  151. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  152. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  153. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  154. snowflake/ml/modeling/svm/svc.py +9 -5
  155. snowflake/ml/modeling/svm/svr.py +3 -1
  156. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  157. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  158. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  159. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  160. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
  161. snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
  162. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  163. snowflake/ml/monitoring/model_monitor.py +37 -0
  164. snowflake/ml/registry/_manager/model_manager.py +15 -1
  165. snowflake/ml/registry/registry.py +32 -37
  166. snowflake/ml/version.py +1 -1
  167. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
  168. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
  169. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  170. snowflake/ml/monitoring/_client/model_monitor.py +0 -126
  171. snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
  172. snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
  173. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  174. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  175. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  176. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/model/model_version_impl.py

@@ -614,6 +614,102 @@ class ModelVersion(lineage_node.LineageNode):
  version_name=sql_identifier.SqlIdentifier(version),
  )

+ @overload
+ def create_service(
+ self,
+ *,
+ service_name: str,
+ image_build_compute_pool: Optional[str] = None,
+ service_compute_pool: str,
+ image_repo: str,
+ ingress_enabled: bool = False,
+ max_instances: int = 1,
+ cpu_requests: Optional[str] = None,
+ memory_requests: Optional[str] = None,
+ gpu_requests: Optional[str] = None,
+ num_workers: Optional[int] = None,
+ max_batch_rows: Optional[int] = None,
+ force_rebuild: bool = False,
+ build_external_access_integration: Optional[str] = None,
+ ) -> str:
+ """Create an inference service with the given spec.
+
+ Args:
+ service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
+ schema of the model will be used.
+ image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
+ the service compute pool if None.
+ service_compute_pool: The name of the compute pool used to run the inference service.
+ image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
+ or schema of the model will be used.
+ ingress_enabled: If true, creates an service endpoint associated with the service. User must have
+ BIND SERVICE ENDPOINT privilege on the account.
+ max_instances: The maximum number of inference service instances to run. The same value it set to
+ MIN_INSTANCES property of the service.
+ cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
+ None, we attempt to utilize all the vCPU of the node.
+ memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
+ requires a unit (GiB, MiB). If None, we attempt to utilize all the memory of the node.
+ gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
+ if None.
+ num_workers: The number of workers to run the inference service for handling requests in parallel within an
+ instance of the service. By default, it is set to 2*vCPU+1 of the node for CPU based inference and 1 for
+ GPU based inference. For GPU based inference, please see best practices before playing with this value.
+ max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
+ force_rebuild: Whether to force a model inference image rebuild.
+ build_external_access_integration: (Deprecated) The external access integration for image build. This is
+ usually permitting access to conda & PyPI repositories.
+ """
+ ...
+
+ @overload
+ def create_service(
+ self,
+ *,
+ service_name: str,
+ image_build_compute_pool: Optional[str] = None,
+ service_compute_pool: str,
+ image_repo: str,
+ ingress_enabled: bool = False,
+ max_instances: int = 1,
+ cpu_requests: Optional[str] = None,
+ memory_requests: Optional[str] = None,
+ gpu_requests: Optional[str] = None,
+ num_workers: Optional[int] = None,
+ max_batch_rows: Optional[int] = None,
+ force_rebuild: bool = False,
+ build_external_access_integrations: Optional[List[str]] = None,
+ ) -> str:
+ """Create an inference service with the given spec.
+
+ Args:
+ service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
+ schema of the model will be used.
+ image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
+ the service compute pool if None.
+ service_compute_pool: The name of the compute pool used to run the inference service.
+ image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
+ or schema of the model will be used.
+ ingress_enabled: If true, creates an service endpoint associated with the service. User must have
+ BIND SERVICE ENDPOINT privilege on the account.
+ max_instances: The maximum number of inference service instances to run. The same value it set to
+ MIN_INSTANCES property of the service.
+ cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
+ None, we attempt to utilize all the vCPU of the node.
+ memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
+ requires a unit (GiB, MiB). If None, we attempt to utilize all the memory of the node.
+ gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
+ if None.
+ num_workers: The number of workers to run the inference service for handling requests in parallel within an
+ instance of the service. By default, it is set to 2*vCPU+1 of the node for CPU based inference and 1 for
+ GPU based inference. For GPU based inference, please see best practices before playing with this value.
+ max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
+ force_rebuild: Whether to force a model inference image rebuild.
+ build_external_access_integrations: The external access integrations for image build. This is usually
+ permitting access to conda & PyPI repositories.
+ """
+ ...
+
  @telemetry.send_api_usage_telemetry(
  project=_TELEMETRY_PROJECT,
  subproject=_TELEMETRY_SUBPROJECT,
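Taken together, the two overloads let a caller either keep the deprecated singular `build_external_access_integration` keyword or switch to the plural form, and they expose the new `cpu_requests`/`memory_requests` knobs for CPU inference. A minimal usage sketch, assuming a `Registry` handle `reg` and a logged model version; every identifier below (database, compute pool, image repo, integration name) is a placeholder:

    # Hypothetical names throughout; adjust to your own account objects.
    mv = reg.get_model("MY_MODEL").version("V1")
    mv.create_service(
        service_name="MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE",
        service_compute_pool="MY_CPU_POOL",
        image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
        ingress_enabled=True,
        max_instances=1,
        cpu_requests="2",        # new: explicit CPU request for CPU based inference
        memory_requests="8GiB",  # new: explicit memory request, unit required
        build_external_access_integrations=["MY_PYPI_ACCESS_INTEGRATION"],  # plural form
    )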
@@ -638,11 +734,14 @@ class ModelVersion(lineage_node.LineageNode):
  image_repo: str,
  ingress_enabled: bool = False,
  max_instances: int = 1,
+ cpu_requests: Optional[str] = None,
+ memory_requests: Optional[str] = None,
  gpu_requests: Optional[str] = None,
  num_workers: Optional[int] = None,
  max_batch_rows: Optional[int] = None,
  force_rebuild: bool = False,
- build_external_access_integration: str,
+ build_external_access_integration: Optional[str] = None,
+ build_external_access_integrations: Optional[List[str]] = None,
  ) -> str:
  """Create an inference service with the given spec.

@@ -658,6 +757,10 @@ class ModelVersion(lineage_node.LineageNode):
  BIND SERVICE ENDPOINT privilege on the account.
  max_instances: The maximum number of inference service instances to run. The same value it set to
  MIN_INSTANCES property of the service.
+ cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
+ None, we attempt to utilize all the vCPU of the node.
+ memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
+ requires a unit (GiB, MiB). If None, we attempt to utilize all the memory of the node.
  gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
  if None.
  num_workers: The number of workers to run the inference service for handling requests in parallel within an
@@ -665,9 +768,14 @@ class ModelVersion(lineage_node.LineageNode):
  GPU based inference. For GPU based inference, please see best practices before playing with this value.
  max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
  force_rebuild: Whether to force a model inference image rebuild.
- build_external_access_integration: The external access integration for image build. This is usually
+ build_external_access_integration: (Deprecated) The external access integration for image build. This is
+ usually permitting access to conda & PyPI repositories.
+ build_external_access_integrations: The external access integrations for image build. This is usually
  permitting access to conda & PyPI repositories.

+ Raises:
+ ValueError: Illegal external access integration arguments.
+
  Returns:
  Result information about service creation from server.
  """
@@ -675,6 +783,20 @@ class ModelVersion(lineage_node.LineageNode):
  project=_TELEMETRY_PROJECT,
  subproject=_TELEMETRY_SUBPROJECT,
  )
+ if build_external_access_integration is not None:
+ msg = (
+ "`build_external_access_integration` is deprecated. "
+ "Please use `build_external_access_integrations` instead."
+ )
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
+ if build_external_access_integrations is not None:
+ msg = (
+ "`build_external_access_integration` and `build_external_access_integrations` cannot be set at the"
+ "same time. Please use `build_external_access_integrations` only."
+ )
+ raise ValueError(msg)
+ build_external_access_integrations = [build_external_access_integration]
+
  service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
  image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
  return self._service_ops.create_service(
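The guard above keeps the old keyword working: a lone `build_external_access_integration` is wrapped into a one-element list after a `DeprecationWarning`, while supplying both spellings is rejected. A hedged sketch of what a caller observes, reusing the illustrative `mv` handle from the earlier example:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        mv.create_service(
            service_name="MY_INFERENCE_SERVICE",
            service_compute_pool="MY_CPU_POOL",
            image_repo="MY_IMAGE_REPO",
            build_external_access_integration="MY_PYPI_ACCESS_INTEGRATION",  # deprecated spelling
        )
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Passing both build_external_access_integration and
    # build_external_access_integrations in the same call raises ValueError.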
@@ -696,11 +818,17 @@ class ModelVersion(lineage_node.LineageNode):
  image_repo_name=image_repo_id,
  ingress_enabled=ingress_enabled,
  max_instances=max_instances,
+ cpu_requests=cpu_requests,
+ memory_requests=memory_requests,
  gpu_requests=gpu_requests,
  num_workers=num_workers,
  max_batch_rows=max_batch_rows,
  force_rebuild=force_rebuild,
- build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
+ build_external_access_integrations=(
+ None
+ if build_external_access_integrations is None
+ else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+ ),
  statement_params=statement_params,
  )

@@ -710,7 +838,7 @@ class ModelVersion(lineage_node.LineageNode):
  )
  def list_services(
  self,
- ) -> List[str]:
+ ) -> pd.DataFrame:
  """List all the service names using this model version.

  Returns:
@@ -722,12 +850,14 @@ class ModelVersion(lineage_node.LineageNode):
  subproject=_TELEMETRY_SUBPROJECT,
  )

- return self._model_ops.list_inference_services(
- database_name=None,
- schema_name=None,
- model_name=self._model_name,
- version_name=self._version_name,
- statement_params=statement_params,
+ return pd.DataFrame(
+ self._model_ops.show_services(
+ database_name=None,
+ schema_name=None,
+ model_name=self._model_name,
+ version_name=self._version_name,
+ statement_params=statement_params,
+ )
  )

  @telemetry.send_api_usage_telemetry(
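Since `list_services()` now wraps the `ServiceInfo` records (defined further down in `model_ops.py`) in a `pandas.DataFrame`, callers can filter the result like any other frame instead of iterating a list of strings. An illustrative sketch, with invented names:

    services_df = mv.list_services()
    print(services_df.columns.tolist())  # expected: ['name', 'inference_endpoint']
    # Keep only services that expose a public ingress endpoint.
    public_services = services_df[services_df["inference_endpoint"].notna()]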
@@ -755,12 +885,16 @@ class ModelVersion(lineage_node.LineageNode):
  project=_TELEMETRY_PROJECT,
  subproject=_TELEMETRY_SUBPROJECT,
  )
+
+ database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
  self._model_ops.delete_service(
  database_name=None,
  schema_name=None,
  model_name=self._model_name,
  version_name=self._version_name,
- service_name=service_name,
+ service_database_name=database_name_id,
+ service_schema_name=schema_name_id,
+ service_name=service_name_id,
  statement_params=statement_params,
  )
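Because the name is parsed with `parse_fully_qualified_name` before the call into `ModelOperator.delete_service`, both qualified and bare service names are accepted; the database and schema parts of a bare name are resolved downstream from the model's own location. A short sketch with placeholder identifiers:

    # Fully qualified: targets exactly this service.
    mv.delete_service("MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE")

    # Bare name: database and schema fall back to the model's location
    # (which in turn defaults to the registry's database and schema).
    mv.delete_service("MY_INFERENCE_SERVICE")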
snowflake/ml/model/_client/ops/model_ops.py

@@ -3,7 +3,7 @@ import os
  import pathlib
  import tempfile
  import warnings
- from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
+ from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload

  import yaml

@@ -31,7 +31,15 @@ from snowflake.snowpark import dataframe, row, session
  from snowflake.snowpark._internal import utils as snowpark_utils


+ class ServiceInfo(TypedDict):
+ name: str
+ inference_endpoint: Optional[str]
+
+
  class ModelOperator:
+ INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
+ INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
+
  def __init__(
  self,
  session: session.Session,
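`ServiceInfo` is a plain `TypedDict`, so at runtime each entry is an ordinary dict with exactly these two keys; a minimal sketch with invented values:

    info: ServiceInfo = {
        "name": "MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE",
        "inference_endpoint": None,  # filled in only when a public ingress endpoint exists
    }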
@@ -514,7 +522,7 @@ class ModelOperator:
  statement_params=statement_params,
  )

- def list_inference_services(
+ def show_services(
  self,
  *,
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -522,7 +530,7 @@
  model_name: sql_identifier.SqlIdentifier,
  version_name: sql_identifier.SqlIdentifier,
  statement_params: Optional[Dict[str, Any]] = None,
- ) -> List[str]:
+ ) -> List[ServiceInfo]:
  res = self._model_client.show_versions(
  database_name=database_name,
  schema_name=schema_name,
@@ -530,8 +538,8 @@
  version_name=version_name,
  statement_params=statement_params,
  )
- col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
- if col_name not in res[0]:
+ service_col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
+ if service_col_name not in res[0]:
  # User need to opt into BCR 2024_08
  raise exceptions.SnowflakeMLException(
  error_code=error_codes.OPT_IN_REQUIRED,
@@ -540,9 +548,31 @@
  "https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)."
  ),
  )
- json_array = json.loads(res[0][col_name])
+
+ json_array = json.loads(res[0][service_col_name])
  # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
- return [str(service) for service in json_array if "MODEL_BUILD_" not in service]
+ fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
+
+ result = []
+ ingress_url: Optional[str] = None
+ for fully_qualified_service_name in fully_qualified_service_names:
+ db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+ for res_row in self._service_client.show_endpoints(
+ database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+ ):
+ if (
+ res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
+ == self.INFERENCE_SERVICE_ENDPOINT_NAME
+ and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
+ ):
+ ingress_url = str(
+ res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
+ )
+ if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
+ ingress_url = None
+ result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+
+ return result

  def delete_service(
  self,
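The rewritten helper returns one `ServiceInfo` per user-facing service and only records an `inference_endpoint` when the service has an endpoint named `inference` whose ingress URL ends with `snowflakecomputing.app`. An illustrative return value (names and URL are placeholders):

    [
        {
            "name": "MY_DB.MY_SCHEMA.MY_INFERENCE_SERVICE",
            "inference_endpoint": "randomstr-myorg-myaccount.snowflakecomputing.app",
        },
        {"name": "MY_DB.MY_SCHEMA.PRIVATE_SERVICE", "inference_endpoint": None},
    ]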
@@ -551,32 +581,42 @@
  schema_name: Optional[sql_identifier.SqlIdentifier],
  model_name: sql_identifier.SqlIdentifier,
  version_name: sql_identifier.SqlIdentifier,
- service_name: str,
+ service_database_name: Optional[sql_identifier.SqlIdentifier],
+ service_schema_name: Optional[sql_identifier.SqlIdentifier],
+ service_name: sql_identifier.SqlIdentifier,
  statement_params: Optional[Dict[str, Any]] = None,
  ) -> None:
- services = self.list_inference_services(
+ services = self.show_services(
  database_name=database_name,
  schema_name=schema_name,
  model_name=model_name,
  version_name=version_name,
  statement_params=statement_params,
  )
- db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name)
+
+ # Fall back to the model's database and schema.
+ # database_name or schema_name are set if the model is created or get using fully qualified name
+ # Otherwise, the model's database and schema are same as registry's database and schema, which are set in the
+ # self._model_client.
+
+ service_database_name = service_database_name or database_name or self._model_client._database_name
+ service_schema_name = service_schema_name or schema_name or self._model_client._schema_name
  fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
- db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
+ service_database_name, service_schema_name, service_name
  )

- for service in services:
- if service == fully_qualified_service_name:
+ for service_info in services:
+ if service_info["name"] == fully_qualified_service_name:
  self._service_client.drop_service(
- database_name=db,
- schema_name=schema,
+ database_name=service_database_name,
+ schema_name=service_schema_name,
  service_name=service_name,
  statement_params=statement_params,
  )
  return
  raise ValueError(
- f"Service '{service_name}' does not exist or unauthorized or not associated with this model version."
+ f"Service '{fully_qualified_service_name}' does not exist "
+ "or unauthorized or not associated with this model version."
  )

  def get_model_version_manifest(

snowflake/ml/model/_client/ops/service_ops.py

@@ -100,13 +100,26 @@ class ServiceOperator:
  image_repo_name: sql_identifier.SqlIdentifier,
  ingress_enabled: bool,
  max_instances: int,
+ cpu_requests: Optional[str],
+ memory_requests: Optional[str],
  gpu_requests: Optional[str],
  num_workers: Optional[int],
  max_batch_rows: Optional[int],
  force_rebuild: bool,
- build_external_access_integration: sql_identifier.SqlIdentifier,
+ build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
  statement_params: Optional[Dict[str, Any]] = None,
  ) -> str:
+
+ # Fall back to the registry's database and schema if not provided
+ database_name = database_name or self._database_name
+ schema_name = schema_name or self._schema_name
+
+ # Fall back to the model's database and schema if not provided then to the registry's database and schema
+ service_database_name = service_database_name or database_name or self._database_name
+ service_schema_name = service_schema_name or schema_name or self._schema_name
+
+ image_repo_database_name = image_repo_database_name or database_name or self._database_name
+ image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
  # create a temp stage
  stage_name = sql_identifier.SqlIdentifier(
  snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -119,9 +132,17 @@
  )
  stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)

+ # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
+ if (
+ snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
+ < version.parse("8.40.0")
+ and build_external_access_integrations is None
+ ):
+ raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
+
  self._model_deployment_spec.save(
- database_name=database_name or self._database_name,
- schema_name=schema_name or self._schema_name,
+ database_name=database_name,
+ schema_name=schema_name,
  model_name=model_name,
  version_name=version_name,
  service_database_name=service_database_name,
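The new gate relies on `packaging.version` ordering, which compares release segments numerically rather than as strings, so a version such as 8.9.x correctly sorts below 8.40.0. A quick sketch of the comparison the check depends on:

    from packaging import version

    assert version.parse("8.39.2") < version.parse("8.40.0")
    assert version.parse("8.9.0") < version.parse("8.40.0")   # numeric, not lexicographic
    assert not ("8.9.0" < "8.40.0")                           # plain string comparison gets this wrong
    # Only pre-8.40.0 deployments hit the ValueError when no build
    # external access integrations are supplied.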
@@ -134,11 +155,13 @@
  image_repo_name=image_repo_name,
  ingress_enabled=ingress_enabled,
  max_instances=max_instances,
+ cpu=cpu_requests,
+ memory=memory_requests,
  gpu=gpu_requests,
  num_workers=num_workers,
  max_batch_rows=max_batch_rows,
  force_rebuild=force_rebuild,
- external_access_integration=build_external_access_integration,
+ external_access_integrations=build_external_access_integrations,
  )
  file_utils.upload_directory_to_stage(
  self._session,
@@ -163,32 +186,25 @@
  statement_params=statement_params,
  )

- # TODO(hayu): Remove the version check after Snowflake 8.37.0 release
- if snowflake_env.get_current_snowflake_version(
- self._session, statement_params=statement_params
- ) >= version.parse("8.37.0"):
- # stream service logs in a thread
- model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
- model_build_service = ServiceLogInfo(
- database_name=service_database_name,
- schema_name=service_schema_name,
- service_name=model_build_service_name,
- container_name="model-build",
- )
- model_inference_service = ServiceLogInfo(
- database_name=service_database_name,
- schema_name=service_schema_name,
- service_name=service_name,
- container_name="model-inference",
- )
- services = [model_build_service, model_inference_service]
- log_thread = self._start_service_log_streaming(
- async_job, services, model_inference_service_exists, force_rebuild, statement_params
- )
- log_thread.join()
- else:
- while not async_job.is_done():
- time.sleep(5)
+ # stream service logs in a thread
+ model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
+ model_build_service = ServiceLogInfo(
+ database_name=service_database_name,
+ schema_name=service_schema_name,
+ service_name=model_build_service_name,
+ container_name="model-build",
+ )
+ model_inference_service = ServiceLogInfo(
+ database_name=service_database_name,
+ schema_name=service_schema_name,
+ service_name=service_name,
+ container_name="model-inference",
+ )
+ services = [model_build_service, model_inference_service]
+ log_thread = self._start_service_log_streaming(
+ async_job, services, model_inference_service_exists, force_rebuild, statement_params
+ )
+ log_thread.join()

  res = cast(str, cast(List[row.Row], async_job.result())[0][0])
  module_logger.info(f"Inference service {service_name} deployment complete: {res}")

snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -1,5 +1,5 @@
  import pathlib
- from typing import Optional
+ from typing import List, Optional

  import yaml

@@ -36,11 +36,13 @@ class ModelDeploymentSpec:
  image_repo_name: sql_identifier.SqlIdentifier,
  ingress_enabled: bool,
  max_instances: int,
+ cpu: Optional[str],
+ memory: Optional[str],
  gpu: Optional[str],
  num_workers: Optional[int],
  max_batch_rows: Optional[int],
  force_rebuild: bool,
- external_access_integration: sql_identifier.SqlIdentifier,
+ external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
  ) -> None:
  # create the deployment spec
  # models spec
@@ -55,12 +57,15 @@ class ModelDeploymentSpec:
  fq_image_repo_name = identifier.get_schema_level_object_identifier(
  saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
  )
- image_build_dict = model_deployment_spec_schema.ImageBuildDict(
- compute_pool=image_build_compute_pool_name.identifier(),
- image_repo=fq_image_repo_name,
- force_rebuild=force_rebuild,
- external_access_integrations=[external_access_integration.identifier()],
- )
+ image_build_dict: model_deployment_spec_schema.ImageBuildDict = {
+ "compute_pool": image_build_compute_pool_name.identifier(),
+ "image_repo": fq_image_repo_name,
+ "force_rebuild": force_rebuild,
+ }
+ if external_access_integrations is not None:
+ image_build_dict["external_access_integrations"] = [
+ eai.identifier() for eai in external_access_integrations
+ ]

  # service spec
  saved_service_database = service_database_name or database_name
@@ -74,6 +79,12 @@ class ModelDeploymentSpec:
  ingress_enabled=ingress_enabled,
  max_instances=max_instances,
  )
+ if cpu:
+ service_dict["cpu"] = cpu
+
+ if memory:
+ service_dict["memory"] = memory
+
  if gpu:
  service_dict["gpu"] = gpu

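With the optional `cpu` and `memory` keys added, the spec this class serializes ends up shaped roughly as below. This is a hedged illustration written as a Python dict (the real artifact is YAML), showing only the sections touched by this diff with placeholder values; identity fields not visible here are omitted:

    deployment_spec_fragment = {
        "image_build": {
            "compute_pool": "MY_CPU_POOL",
            "image_repo": "MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
            "force_rebuild": False,
            # written only when build external access integrations were supplied
            "external_access_integrations": ["MY_PYPI_ACCESS_INTEGRATION"],
        },
        "service": {
            "compute_pool": "MY_CPU_POOL",
            "ingress_enabled": True,
            "max_instances": 1,
            # cpu and memory (new), like gpu/num_workers/max_batch_rows, appear only when set
            "cpu": "2",
            "memory": "8GiB",
        },
    }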
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -12,7 +12,7 @@ class ImageBuildDict(TypedDict):
  compute_pool: Required[str]
  image_repo: Required[str]
  force_rebuild: Required[bool]
- external_access_integrations: Required[List[str]]
+ external_access_integrations: NotRequired[List[str]]


  class ServiceDict(TypedDict):
@@ -20,6 +20,8 @@ class ServiceDict(TypedDict):
  compute_pool: Required[str]
  ingress_enabled: Required[bool]
  max_instances: Required[int]
+ cpu: NotRequired[str]
+ memory: NotRequired[str]
  gpu: NotRequired[str]
  num_workers: NotRequired[int]
  max_batch_rows: NotRequired[int]

snowflake/ml/model/_client/sql/service.py

@@ -10,7 +10,7 @@ from snowflake.ml._internal.utils import (
  sql_identifier,
  )
  from snowflake.ml.model._client.sql import _base
- from snowflake.snowpark import dataframe, functions as F, types as spt
+ from snowflake.snowpark import dataframe, functions as F, row, types as spt
  from snowflake.snowpark._internal import utils as snowpark_utils


@@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum):


  class ServiceSQLClient(_base._BaseSQLClient):
+ MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
+ MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+
  def build_model_container(
  self,
  *,
@@ -216,3 +219,24 @@ class ServiceSQLClient(_base._BaseSQLClient):
  f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
  statement_params=statement_params,
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+ def show_endpoints(
+ self,
+ *,
+ database_name: Optional[sql_identifier.SqlIdentifier],
+ schema_name: Optional[sql_identifier.SqlIdentifier],
+ service_name: sql_identifier.SqlIdentifier,
+ statement_params: Optional[Dict[str, Any]] = None,
+ ) -> List[row.Row]:
+ fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+ res = (
+ query_result_checker.SqlResultValidator(
+ self._session,
+ (f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"),
+ statement_params=statement_params,
+ )
+ .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
+ .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
+ )
+
+ return res.validate()
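`show_endpoints` simply issues `SHOW ENDPOINTS IN SERVICE ...` and asserts that the `name` and `ingress_url` columns come back; `show_services` in `model_ops.py` then scans the rows for the public inference endpoint. A hedged consumption sketch, assuming an already-constructed `ServiceSQLClient` named `service_sql_client` and a placeholder service name:

    rows = service_sql_client.show_endpoints(
        database_name=None,
        schema_name=None,
        service_name=sql_identifier.SqlIdentifier("MY_INFERENCE_SERVICE"),
    )
    for endpoint_row in rows:
        # Snowpark Row supports lookup by column name.
        if endpoint_row["name"] == "inference" and endpoint_row["ingress_url"] is not None:
            print("public endpoint:", endpoint_row["ingress_url"])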
snowflake/ml/model/_model_composer/model_composer.py

@@ -86,6 +86,7 @@ class ModelComposer:
  metadata: Optional[Dict[str, str]] = None,
  conda_dependencies: Optional[List[str]] = None,
  pip_requirements: Optional[List[str]] = None,
+ target_platforms: Optional[List[model_types.TargetPlatform]] = None,
  python_version: Optional[str] = None,
  ext_modules: Optional[List[ModuleType]] = None,
  code_paths: Optional[List[str]] = None,
@@ -131,6 +132,7 @@ class ModelComposer:
  model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
  options=options,
  data_sources=self._get_data_sources(model, sample_input_data),
+ target_platforms=target_platforms,
  )

  file_utils.upload_directory_to_stage(

snowflake/ml/model/_model_composer/model_manifest/model_manifest.py

@@ -44,6 +44,7 @@ class ModelManifest:
  model_rel_path: pathlib.PurePosixPath,
  options: Optional[type_hints.ModelSaveOption] = None,
  data_sources: Optional[List[data_source.DataSource]] = None,
+ target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
  ) -> None:
  if options is None:
  options = {}
@@ -132,6 +133,9 @@ class ModelManifest:
  if lineage_sources:
  manifest_dict["lineage_sources"] = lineage_sources

+ if target_platforms:
+ manifest_dict["target_platforms"] = [platform.value for platform in target_platforms]
+
  with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
  # Anchors are not supported in the server, avoid that.
  yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
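When target platforms are provided, the manifest gains a plain list of enum values. A hedged illustration of the resulting fragment; the concrete member names of `type_hints.TargetPlatform` are assumed here, not shown in this diff:

    # Shape of the extra MANIFEST.yml entry, expressed as the dict that gets dumped.
    manifest_fragment = {
        "target_platforms": ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"],  # assumed member values
    }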
snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py

@@ -95,3 +95,4 @@ class ModelManifestDict(TypedDict):
  methods: Required[List[ModelMethodDict]]
  user_data: NotRequired[Dict[str, Any]]
  lineage_sources: NotRequired[List[LineageSourceDict]]
+ target_platforms: NotRequired[List[str]]

snowflake/ml/model/_model_composer/model_method/infer_function.py_template

@@ -5,6 +5,7 @@ import sys

  import anyio
  import pandas as pd
+ import numpy as np
  from _snowflake import vectorized

  from snowflake.ml.model._packager import model_packager
@@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict:
  df.columns = input_cols
  input_df = df.astype(dtype=dtype_map)
  predictions_df = runner(input_df[input_cols])
- return predictions_df.to_dict("records")
+ return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
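The template change matters because `to_dict("records")` would otherwise leak `NaN`/`NA` markers into the returned records, which do not serialize cleanly as SQL NULL or JSON null; mapping them to `None` first produces well-formed output. A standalone sketch of the pandas behavior the new return line relies on:

    import numpy as np
    import pandas as pd

    predictions_df = pd.DataFrame({"output_feature_0": [0.3, np.nan]})

    print(predictions_df.to_dict("records"))
    # [{'output_feature_0': 0.3}, {'output_feature_0': nan}]

    cleaned = predictions_df.replace({pd.NA: None, np.nan: None})
    print(cleaned.to_dict("records"))
    # [{'output_feature_0': 0.3}, {'output_feature_0': None}]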