snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +145 -11
- snowflake/ml/model/_client/ops/model_ops.py +56 -16
- snowflake/ml/model/_client/ops/service_ops.py +46 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
- snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +87 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +40 -9
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +37 -0
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +32 -37
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/_client/model_monitor.py +0 -126
- snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
- snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -614,6 +614,102 @@ class ModelVersion(lineage_node.LineageNode):
|
|
614
614
|
version_name=sql_identifier.SqlIdentifier(version),
|
615
615
|
)
|
616
616
|
|
617
|
+
@overload
|
618
|
+
def create_service(
|
619
|
+
self,
|
620
|
+
*,
|
621
|
+
service_name: str,
|
622
|
+
image_build_compute_pool: Optional[str] = None,
|
623
|
+
service_compute_pool: str,
|
624
|
+
image_repo: str,
|
625
|
+
ingress_enabled: bool = False,
|
626
|
+
max_instances: int = 1,
|
627
|
+
cpu_requests: Optional[str] = None,
|
628
|
+
memory_requests: Optional[str] = None,
|
629
|
+
gpu_requests: Optional[str] = None,
|
630
|
+
num_workers: Optional[int] = None,
|
631
|
+
max_batch_rows: Optional[int] = None,
|
632
|
+
force_rebuild: bool = False,
|
633
|
+
build_external_access_integration: Optional[str] = None,
|
634
|
+
) -> str:
|
635
|
+
"""Create an inference service with the given spec.
|
636
|
+
|
637
|
+
Args:
|
638
|
+
service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
|
639
|
+
schema of the model will be used.
|
640
|
+
image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
|
641
|
+
the service compute pool if None.
|
642
|
+
service_compute_pool: The name of the compute pool used to run the inference service.
|
643
|
+
image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
|
644
|
+
or schema of the model will be used.
|
645
|
+
ingress_enabled: If true, creates a service endpoint associated with the service. User must have
|
646
|
+
BIND SERVICE ENDPOINT privilege on the account.
|
647
|
+
max_instances: The maximum number of inference service instances to run. The same value is set to
|
648
|
+
MIN_INSTANCES property of the service.
|
649
|
+
cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
|
650
|
+
None, we attempt to utilize all the vCPU of the node.
|
651
|
+
memory_requests: The memory limit for CPU based inference. Can be an integer or a fractional value, but
|
652
|
+
requires a unit (GiB, MiB). If None, we attempt to utilize all the memory of the node.
|
653
|
+
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
654
|
+
if None.
|
655
|
+
num_workers: The number of workers to run the inference service for handling requests in parallel within an
|
656
|
+
instance of the service. By default, it is set to 2*vCPU+1 of the node for CPU based inference and 1 for
|
657
|
+
GPU based inference. For GPU based inference, please see best practices before playing with this value.
|
658
|
+
max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
|
659
|
+
force_rebuild: Whether to force a model inference image rebuild.
|
660
|
+
build_external_access_integration: (Deprecated) The external access integration for image build. This is
|
661
|
+
usually permitting access to conda & PyPI repositories.
|
662
|
+
"""
|
663
|
+
...
|
664
|
+
|
665
|
+
@overload
|
666
|
+
def create_service(
|
667
|
+
self,
|
668
|
+
*,
|
669
|
+
service_name: str,
|
670
|
+
image_build_compute_pool: Optional[str] = None,
|
671
|
+
service_compute_pool: str,
|
672
|
+
image_repo: str,
|
673
|
+
ingress_enabled: bool = False,
|
674
|
+
max_instances: int = 1,
|
675
|
+
cpu_requests: Optional[str] = None,
|
676
|
+
memory_requests: Optional[str] = None,
|
677
|
+
gpu_requests: Optional[str] = None,
|
678
|
+
num_workers: Optional[int] = None,
|
679
|
+
max_batch_rows: Optional[int] = None,
|
680
|
+
force_rebuild: bool = False,
|
681
|
+
build_external_access_integrations: Optional[List[str]] = None,
|
682
|
+
) -> str:
|
683
|
+
"""Create an inference service with the given spec.
|
684
|
+
|
685
|
+
Args:
|
686
|
+
service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
|
687
|
+
schema of the model will be used.
|
688
|
+
image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
|
689
|
+
the service compute pool if None.
|
690
|
+
service_compute_pool: The name of the compute pool used to run the inference service.
|
691
|
+
image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
|
692
|
+
or schema of the model will be used.
|
693
|
+
ingress_enabled: If true, creates a service endpoint associated with the service. User must have
|
694
|
+
BIND SERVICE ENDPOINT privilege on the account.
|
695
|
+
max_instances: The maximum number of inference service instances to run. The same value is set to
|
696
|
+
MIN_INSTANCES property of the service.
|
697
|
+
cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
|
698
|
+
None, we attempt to utilize all the vCPU of the node.
|
699
|
+
memory_requests: The memory limit for CPU based inference. Can be an integer or a fractional value, but
|
700
|
+
requires a unit (GiB, MiB). If None, we attempt to utilize all the memory of the node.
|
701
|
+
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
702
|
+
if None.
|
703
|
+
num_workers: The number of workers to run the inference service for handling requests in parallel within an
|
704
|
+
instance of the service. By default, it is set to 2*vCPU+1 of the node for CPU based inference and 1 for
|
705
|
+
GPU based inference. For GPU based inference, please see best practices before playing with this value.
|
706
|
+
max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
|
707
|
+
force_rebuild: Whether to force a model inference image rebuild.
|
708
|
+
build_external_access_integrations: The external access integrations for image build. This is usually
|
709
|
+
permitting access to conda & PyPI repositories.
|
710
|
+
"""
|
711
|
+
...
|
712
|
+
|
617
713
|
@telemetry.send_api_usage_telemetry(
|
618
714
|
project=_TELEMETRY_PROJECT,
|
619
715
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -638,11 +734,14 @@ class ModelVersion(lineage_node.LineageNode):
|
|
638
734
|
image_repo: str,
|
639
735
|
ingress_enabled: bool = False,
|
640
736
|
max_instances: int = 1,
|
737
|
+
cpu_requests: Optional[str] = None,
|
738
|
+
memory_requests: Optional[str] = None,
|
641
739
|
gpu_requests: Optional[str] = None,
|
642
740
|
num_workers: Optional[int] = None,
|
643
741
|
max_batch_rows: Optional[int] = None,
|
644
742
|
force_rebuild: bool = False,
|
645
|
-
build_external_access_integration: str,
|
743
|
+
build_external_access_integration: Optional[str] = None,
|
744
|
+
build_external_access_integrations: Optional[List[str]] = None,
|
646
745
|
) -> str:
|
647
746
|
"""Create an inference service with the given spec.
|
648
747
|
|
@@ -658,6 +757,10 @@ class ModelVersion(lineage_node.LineageNode):
|
|
658
757
|
BIND SERVICE ENDPOINT privilege on the account.
|
659
758
|
max_instances: The maximum number of inference service instances to run. The same value is set to
|
660
759
|
MIN_INSTANCES property of the service.
|
760
|
+
cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
|
761
|
+
None, we attempt to utilize all the vCPU of the node.
|
762
|
+
memory_requests: The memory limit for CPU based inference. Can be an integer or a fractional value, but
|
763
|
+
requires a unit (GiB, MiB). If None, we attempt to utilize all the memory of the node.
|
661
764
|
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
662
765
|
if None.
|
663
766
|
num_workers: The number of workers to run the inference service for handling requests in parallel within an
|
@@ -665,9 +768,14 @@ class ModelVersion(lineage_node.LineageNode):
|
|
665
768
|
GPU based inference. For GPU based inference, please see best practices before playing with this value.
|
666
769
|
max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
|
667
770
|
force_rebuild: Whether to force a model inference image rebuild.
|
668
|
-
build_external_access_integration: The external access integration for image build. This is
|
771
|
+
build_external_access_integration: (Deprecated) The external access integration for image build. This is
|
772
|
+
usually permitting access to conda & PyPI repositories.
|
773
|
+
build_external_access_integrations: The external access integrations for image build. This is usually
|
669
774
|
permitting access to conda & PyPI repositories.
|
670
775
|
|
776
|
+
Raises:
|
777
|
+
ValueError: Illegal external access integration arguments.
|
778
|
+
|
671
779
|
Returns:
|
672
780
|
Result information about service creation from server.
|
673
781
|
"""
|
@@ -675,6 +783,20 @@ class ModelVersion(lineage_node.LineageNode):
|
|
675
783
|
project=_TELEMETRY_PROJECT,
|
676
784
|
subproject=_TELEMETRY_SUBPROJECT,
|
677
785
|
)
|
786
|
+
if build_external_access_integration is not None:
|
787
|
+
msg = (
|
788
|
+
"`build_external_access_integration` is deprecated. "
|
789
|
+
"Please use `build_external_access_integrations` instead."
|
790
|
+
)
|
791
|
+
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
792
|
+
if build_external_access_integrations is not None:
|
793
|
+
msg = (
|
794
|
+
"`build_external_access_integration` and `build_external_access_integrations` cannot be set at the"
|
795
|
+
"same time. Please use `build_external_access_integrations` only."
|
796
|
+
)
|
797
|
+
raise ValueError(msg)
|
798
|
+
build_external_access_integrations = [build_external_access_integration]
|
799
|
+
|
678
800
|
service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
|
679
801
|
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
680
802
|
return self._service_ops.create_service(
|
@@ -696,11 +818,17 @@ class ModelVersion(lineage_node.LineageNode):
|
|
696
818
|
image_repo_name=image_repo_id,
|
697
819
|
ingress_enabled=ingress_enabled,
|
698
820
|
max_instances=max_instances,
|
821
|
+
cpu_requests=cpu_requests,
|
822
|
+
memory_requests=memory_requests,
|
699
823
|
gpu_requests=gpu_requests,
|
700
824
|
num_workers=num_workers,
|
701
825
|
max_batch_rows=max_batch_rows,
|
702
826
|
force_rebuild=force_rebuild,
|
703
|
-
|
827
|
+
build_external_access_integrations=(
|
828
|
+
None
|
829
|
+
if build_external_access_integrations is None
|
830
|
+
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
831
|
+
),
|
704
832
|
statement_params=statement_params,
|
705
833
|
)
|
706
834
|
|
@@ -710,7 +838,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
710
838
|
)
|
711
839
|
def list_services(
|
712
840
|
self,
|
713
|
-
) ->
|
841
|
+
) -> pd.DataFrame:
|
714
842
|
"""List all the service names using this model version.
|
715
843
|
|
716
844
|
Returns:
|
@@ -722,12 +850,14 @@ class ModelVersion(lineage_node.LineageNode):
|
|
722
850
|
subproject=_TELEMETRY_SUBPROJECT,
|
723
851
|
)
|
724
852
|
|
725
|
-
return
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
853
|
+
return pd.DataFrame(
|
854
|
+
self._model_ops.show_services(
|
855
|
+
database_name=None,
|
856
|
+
schema_name=None,
|
857
|
+
model_name=self._model_name,
|
858
|
+
version_name=self._version_name,
|
859
|
+
statement_params=statement_params,
|
860
|
+
)
|
731
861
|
)
|
732
862
|
|
733
863
|
@telemetry.send_api_usage_telemetry(
|
@@ -755,12 +885,16 @@ class ModelVersion(lineage_node.LineageNode):
|
|
755
885
|
project=_TELEMETRY_PROJECT,
|
756
886
|
subproject=_TELEMETRY_SUBPROJECT,
|
757
887
|
)
|
888
|
+
|
889
|
+
database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
|
758
890
|
self._model_ops.delete_service(
|
759
891
|
database_name=None,
|
760
892
|
schema_name=None,
|
761
893
|
model_name=self._model_name,
|
762
894
|
version_name=self._version_name,
|
763
|
-
|
895
|
+
service_database_name=database_name_id,
|
896
|
+
service_schema_name=schema_name_id,
|
897
|
+
service_name=service_name_id,
|
764
898
|
statement_params=statement_params,
|
765
899
|
)
|
766
900
|
|
@@ -3,7 +3,7 @@ import os
|
|
3
3
|
import pathlib
|
4
4
|
import tempfile
|
5
5
|
import warnings
|
6
|
-
from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
6
|
+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
|
7
7
|
|
8
8
|
import yaml
|
9
9
|
|
@@ -31,7 +31,15 @@ from snowflake.snowpark import dataframe, row, session
|
|
31
31
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
32
32
|
|
33
33
|
|
34
|
+
class ServiceInfo(TypedDict):
|
35
|
+
name: str
|
36
|
+
inference_endpoint: Optional[str]
|
37
|
+
|
38
|
+
|
34
39
|
class ModelOperator:
|
40
|
+
INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
|
41
|
+
INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
|
42
|
+
|
35
43
|
def __init__(
|
36
44
|
self,
|
37
45
|
session: session.Session,
|
@@ -514,7 +522,7 @@ class ModelOperator:
|
|
514
522
|
statement_params=statement_params,
|
515
523
|
)
|
516
524
|
|
517
|
-
def
|
525
|
+
def show_services(
|
518
526
|
self,
|
519
527
|
*,
|
520
528
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -522,7 +530,7 @@ class ModelOperator:
|
|
522
530
|
model_name: sql_identifier.SqlIdentifier,
|
523
531
|
version_name: sql_identifier.SqlIdentifier,
|
524
532
|
statement_params: Optional[Dict[str, Any]] = None,
|
525
|
-
) -> List[
|
533
|
+
) -> List[ServiceInfo]:
|
526
534
|
res = self._model_client.show_versions(
|
527
535
|
database_name=database_name,
|
528
536
|
schema_name=schema_name,
|
@@ -530,8 +538,8 @@ class ModelOperator:
|
|
530
538
|
version_name=version_name,
|
531
539
|
statement_params=statement_params,
|
532
540
|
)
|
533
|
-
|
534
|
-
if
|
541
|
+
service_col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
|
542
|
+
if service_col_name not in res[0]:
|
535
543
|
# User needs to opt into BCR 2024_08
|
536
544
|
raise exceptions.SnowflakeMLException(
|
537
545
|
error_code=error_codes.OPT_IN_REQUIRED,
|
@@ -540,9 +548,31 @@ class ModelOperator:
|
|
540
548
|
"https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)."
|
541
549
|
),
|
542
550
|
)
|
543
|
-
|
551
|
+
|
552
|
+
json_array = json.loads(res[0][service_col_name])
|
544
553
|
# TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
|
545
|
-
|
554
|
+
fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
555
|
+
|
556
|
+
result = []
|
557
|
+
ingress_url: Optional[str] = None
|
558
|
+
for fully_qualified_service_name in fully_qualified_service_names:
|
559
|
+
db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
|
560
|
+
for res_row in self._service_client.show_endpoints(
|
561
|
+
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
562
|
+
):
|
563
|
+
if (
|
564
|
+
res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
|
565
|
+
== self.INFERENCE_SERVICE_ENDPOINT_NAME
|
566
|
+
and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
|
567
|
+
):
|
568
|
+
ingress_url = str(
|
569
|
+
res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
|
570
|
+
)
|
571
|
+
if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
|
572
|
+
ingress_url = None
|
573
|
+
result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
|
574
|
+
|
575
|
+
return result
|
546
576
|
|
547
577
|
def delete_service(
|
548
578
|
self,
|
@@ -551,32 +581,42 @@ class ModelOperator:
|
|
551
581
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
552
582
|
model_name: sql_identifier.SqlIdentifier,
|
553
583
|
version_name: sql_identifier.SqlIdentifier,
|
554
|
-
|
584
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
585
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
586
|
+
service_name: sql_identifier.SqlIdentifier,
|
555
587
|
statement_params: Optional[Dict[str, Any]] = None,
|
556
588
|
) -> None:
|
557
|
-
services = self.
|
589
|
+
services = self.show_services(
|
558
590
|
database_name=database_name,
|
559
591
|
schema_name=schema_name,
|
560
592
|
model_name=model_name,
|
561
593
|
version_name=version_name,
|
562
594
|
statement_params=statement_params,
|
563
595
|
)
|
564
|
-
|
596
|
+
|
597
|
+
# Fall back to the model's database and schema.
|
598
|
+
# database_name or schema_name are set if the model is created or retrieved using a fully qualified name.
|
599
|
+
# Otherwise, the model's database and schema are same as registry's database and schema, which are set in the
|
600
|
+
# self._model_client.
|
601
|
+
|
602
|
+
service_database_name = service_database_name or database_name or self._model_client._database_name
|
603
|
+
service_schema_name = service_schema_name or schema_name or self._model_client._schema_name
|
565
604
|
fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
|
566
|
-
|
605
|
+
service_database_name, service_schema_name, service_name
|
567
606
|
)
|
568
607
|
|
569
|
-
for
|
570
|
-
if
|
608
|
+
for service_info in services:
|
609
|
+
if service_info["name"] == fully_qualified_service_name:
|
571
610
|
self._service_client.drop_service(
|
572
|
-
database_name=
|
573
|
-
schema_name=
|
611
|
+
database_name=service_database_name,
|
612
|
+
schema_name=service_schema_name,
|
574
613
|
service_name=service_name,
|
575
614
|
statement_params=statement_params,
|
576
615
|
)
|
577
616
|
return
|
578
617
|
raise ValueError(
|
579
|
-
f"Service '{
|
618
|
+
f"Service '{fully_qualified_service_name}' does not exist "
|
619
|
+
"or unauthorized or not associated with this model version."
|
580
620
|
)
|
581
621
|
|
582
622
|
def get_model_version_manifest(
|
@@ -100,13 +100,26 @@ class ServiceOperator:
|
|
100
100
|
image_repo_name: sql_identifier.SqlIdentifier,
|
101
101
|
ingress_enabled: bool,
|
102
102
|
max_instances: int,
|
103
|
+
cpu_requests: Optional[str],
|
104
|
+
memory_requests: Optional[str],
|
103
105
|
gpu_requests: Optional[str],
|
104
106
|
num_workers: Optional[int],
|
105
107
|
max_batch_rows: Optional[int],
|
106
108
|
force_rebuild: bool,
|
107
|
-
|
109
|
+
build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
|
108
110
|
statement_params: Optional[Dict[str, Any]] = None,
|
109
111
|
) -> str:
|
112
|
+
|
113
|
+
# Fall back to the registry's database and schema if not provided
|
114
|
+
database_name = database_name or self._database_name
|
115
|
+
schema_name = schema_name or self._schema_name
|
116
|
+
|
117
|
+
# Fall back to the model's database and schema if not provided then to the registry's database and schema
|
118
|
+
service_database_name = service_database_name or database_name or self._database_name
|
119
|
+
service_schema_name = service_schema_name or schema_name or self._schema_name
|
120
|
+
|
121
|
+
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
122
|
+
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
110
123
|
# create a temp stage
|
111
124
|
stage_name = sql_identifier.SqlIdentifier(
|
112
125
|
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
@@ -119,9 +132,17 @@ class ServiceOperator:
|
|
119
132
|
)
|
120
133
|
stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
|
121
134
|
|
135
|
+
# TODO(hayu): Remove the version check after Snowflake 8.40.0 release
|
136
|
+
if (
|
137
|
+
snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
|
138
|
+
< version.parse("8.40.0")
|
139
|
+
and build_external_access_integrations is None
|
140
|
+
):
|
141
|
+
raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
|
142
|
+
|
122
143
|
self._model_deployment_spec.save(
|
123
|
-
database_name=database_name
|
124
|
-
schema_name=schema_name
|
144
|
+
database_name=database_name,
|
145
|
+
schema_name=schema_name,
|
125
146
|
model_name=model_name,
|
126
147
|
version_name=version_name,
|
127
148
|
service_database_name=service_database_name,
|
@@ -134,11 +155,13 @@ class ServiceOperator:
|
|
134
155
|
image_repo_name=image_repo_name,
|
135
156
|
ingress_enabled=ingress_enabled,
|
136
157
|
max_instances=max_instances,
|
158
|
+
cpu=cpu_requests,
|
159
|
+
memory=memory_requests,
|
137
160
|
gpu=gpu_requests,
|
138
161
|
num_workers=num_workers,
|
139
162
|
max_batch_rows=max_batch_rows,
|
140
163
|
force_rebuild=force_rebuild,
|
141
|
-
|
164
|
+
external_access_integrations=build_external_access_integrations,
|
142
165
|
)
|
143
166
|
file_utils.upload_directory_to_stage(
|
144
167
|
self._session,
|
@@ -163,32 +186,25 @@ class ServiceOperator:
|
|
163
186
|
statement_params=statement_params,
|
164
187
|
)
|
165
188
|
|
166
|
-
#
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
model_build_service_name
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
log_thread = self._start_service_log_streaming(
|
186
|
-
async_job, services, model_inference_service_exists, force_rebuild, statement_params
|
187
|
-
)
|
188
|
-
log_thread.join()
|
189
|
-
else:
|
190
|
-
while not async_job.is_done():
|
191
|
-
time.sleep(5)
|
189
|
+
# stream service logs in a thread
|
190
|
+
model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
|
191
|
+
model_build_service = ServiceLogInfo(
|
192
|
+
database_name=service_database_name,
|
193
|
+
schema_name=service_schema_name,
|
194
|
+
service_name=model_build_service_name,
|
195
|
+
container_name="model-build",
|
196
|
+
)
|
197
|
+
model_inference_service = ServiceLogInfo(
|
198
|
+
database_name=service_database_name,
|
199
|
+
schema_name=service_schema_name,
|
200
|
+
service_name=service_name,
|
201
|
+
container_name="model-inference",
|
202
|
+
)
|
203
|
+
services = [model_build_service, model_inference_service]
|
204
|
+
log_thread = self._start_service_log_streaming(
|
205
|
+
async_job, services, model_inference_service_exists, force_rebuild, statement_params
|
206
|
+
)
|
207
|
+
log_thread.join()
|
192
208
|
|
193
209
|
res = cast(str, cast(List[row.Row], async_job.result())[0][0])
|
194
210
|
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pathlib
|
2
|
-
from typing import Optional
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
import yaml
|
5
5
|
|
@@ -36,11 +36,13 @@ class ModelDeploymentSpec:
|
|
36
36
|
image_repo_name: sql_identifier.SqlIdentifier,
|
37
37
|
ingress_enabled: bool,
|
38
38
|
max_instances: int,
|
39
|
+
cpu: Optional[str],
|
40
|
+
memory: Optional[str],
|
39
41
|
gpu: Optional[str],
|
40
42
|
num_workers: Optional[int],
|
41
43
|
max_batch_rows: Optional[int],
|
42
44
|
force_rebuild: bool,
|
43
|
-
|
45
|
+
external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
|
44
46
|
) -> None:
|
45
47
|
# create the deployment spec
|
46
48
|
# models spec
|
@@ -55,12 +57,15 @@ class ModelDeploymentSpec:
|
|
55
57
|
fq_image_repo_name = identifier.get_schema_level_object_identifier(
|
56
58
|
saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
|
57
59
|
)
|
58
|
-
image_build_dict
|
59
|
-
compute_pool
|
60
|
-
image_repo
|
61
|
-
force_rebuild
|
62
|
-
|
63
|
-
|
60
|
+
image_build_dict: model_deployment_spec_schema.ImageBuildDict = {
|
61
|
+
"compute_pool": image_build_compute_pool_name.identifier(),
|
62
|
+
"image_repo": fq_image_repo_name,
|
63
|
+
"force_rebuild": force_rebuild,
|
64
|
+
}
|
65
|
+
if external_access_integrations is not None:
|
66
|
+
image_build_dict["external_access_integrations"] = [
|
67
|
+
eai.identifier() for eai in external_access_integrations
|
68
|
+
]
|
64
69
|
|
65
70
|
# service spec
|
66
71
|
saved_service_database = service_database_name or database_name
|
@@ -74,6 +79,12 @@ class ModelDeploymentSpec:
|
|
74
79
|
ingress_enabled=ingress_enabled,
|
75
80
|
max_instances=max_instances,
|
76
81
|
)
|
82
|
+
if cpu:
|
83
|
+
service_dict["cpu"] = cpu
|
84
|
+
|
85
|
+
if memory:
|
86
|
+
service_dict["memory"] = memory
|
87
|
+
|
77
88
|
if gpu:
|
78
89
|
service_dict["gpu"] = gpu
|
79
90
|
|
@@ -12,7 +12,7 @@ class ImageBuildDict(TypedDict):
|
|
12
12
|
compute_pool: Required[str]
|
13
13
|
image_repo: Required[str]
|
14
14
|
force_rebuild: Required[bool]
|
15
|
-
external_access_integrations:
|
15
|
+
external_access_integrations: NotRequired[List[str]]
|
16
16
|
|
17
17
|
|
18
18
|
class ServiceDict(TypedDict):
|
@@ -20,6 +20,8 @@ class ServiceDict(TypedDict):
|
|
20
20
|
compute_pool: Required[str]
|
21
21
|
ingress_enabled: Required[bool]
|
22
22
|
max_instances: Required[int]
|
23
|
+
cpu: NotRequired[str]
|
24
|
+
memory: NotRequired[str]
|
23
25
|
gpu: NotRequired[str]
|
24
26
|
num_workers: NotRequired[int]
|
25
27
|
max_batch_rows: NotRequired[int]
|
@@ -10,7 +10,7 @@ from snowflake.ml._internal.utils import (
|
|
10
10
|
sql_identifier,
|
11
11
|
)
|
12
12
|
from snowflake.ml.model._client.sql import _base
|
13
|
-
from snowflake.snowpark import dataframe, functions as F, types as spt
|
13
|
+
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
14
14
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
15
15
|
|
16
16
|
|
@@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum):
|
|
26
26
|
|
27
27
|
|
28
28
|
class ServiceSQLClient(_base._BaseSQLClient):
|
29
|
+
MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
|
30
|
+
MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
|
31
|
+
|
29
32
|
def build_model_container(
|
30
33
|
self,
|
31
34
|
*,
|
@@ -216,3 +219,24 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
216
219
|
f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
|
217
220
|
statement_params=statement_params,
|
218
221
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
222
|
+
|
223
|
+
def show_endpoints(
|
224
|
+
self,
|
225
|
+
*,
|
226
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
227
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
228
|
+
service_name: sql_identifier.SqlIdentifier,
|
229
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
230
|
+
) -> List[row.Row]:
# Run `SHOW ENDPOINTS IN SERVICE <service>` and return the validated result rows.
#
# Args (all keyword-only):
#   database_name, schema_name: optional qualifiers; passed together with
#     service_name to self.fully_qualified_object_name to build the name used in the query.
#     NOTE(review): how None qualifiers are resolved is decided by that helper -- confirm there.
#   service_name: identifier of the service whose endpoints are listed.
#   statement_params: forwarded unchanged to the SqlResultValidator.
#
# Returns: whatever `validate()` yields for the query, annotated as List[row.Row].
|
231
|
+
fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
|
232
|
+
res = (
|
233
|
+
query_result_checker.SqlResultValidator(
|
234
|
+
self._session,
|
235
|
+
(f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"),
|
236
|
+
statement_params=statement_params,
|
237
|
+
)
# Require the endpoint "name" and "ingress_url" columns (the class-level *_COL_NAME constants).
# NOTE(review): allow_empty=True presumably lets a service with no endpoints pass validation -- confirm
# against query_result_checker.
|
238
|
+
.has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
|
239
|
+
.has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
|
240
|
+
)
|
241
|
+
|
242
|
+
# The checks above are only registered on `res`; validate() is the call whose result is returned.
return res.validate()
|
@@ -86,6 +86,7 @@ class ModelComposer:
|
|
86
86
|
metadata: Optional[Dict[str, str]] = None,
|
87
87
|
conda_dependencies: Optional[List[str]] = None,
|
88
88
|
pip_requirements: Optional[List[str]] = None,
|
89
|
+
target_platforms: Optional[List[model_types.TargetPlatform]] = None,
|
89
90
|
python_version: Optional[str] = None,
|
90
91
|
ext_modules: Optional[List[ModuleType]] = None,
|
91
92
|
code_paths: Optional[List[str]] = None,
|
@@ -131,6 +132,7 @@ class ModelComposer:
|
|
131
132
|
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
132
133
|
options=options,
|
133
134
|
data_sources=self._get_data_sources(model, sample_input_data),
|
135
|
+
target_platforms=target_platforms,
|
134
136
|
)
|
135
137
|
|
136
138
|
file_utils.upload_directory_to_stage(
|
@@ -44,6 +44,7 @@ class ModelManifest:
|
|
44
44
|
model_rel_path: pathlib.PurePosixPath,
|
45
45
|
options: Optional[type_hints.ModelSaveOption] = None,
|
46
46
|
data_sources: Optional[List[data_source.DataSource]] = None,
|
47
|
+
target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
|
47
48
|
) -> None:
|
48
49
|
if options is None:
|
49
50
|
options = {}
|
@@ -132,6 +133,9 @@ class ModelManifest:
|
|
132
133
|
if lineage_sources:
|
133
134
|
manifest_dict["lineage_sources"] = lineage_sources
|
134
135
|
|
136
|
+
if target_platforms:
|
137
|
+
manifest_dict["target_platforms"] = [platform.value for platform in target_platforms]
|
138
|
+
|
135
139
|
with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
|
136
140
|
# Anchors are not supported by the server, so disable YAML aliases when dumping.
|
137
141
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
@@ -5,6 +5,7 @@ import sys
|
|
5
5
|
|
6
6
|
import anyio
|
7
7
|
import pandas as pd
|
8
|
+
import numpy as np
|
8
9
|
from _snowflake import vectorized
|
9
10
|
|
10
11
|
from snowflake.ml.model._packager import model_packager
|
@@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict:
|
|
47
48
|
df.columns = input_cols
|
48
49
|
input_df = df.astype(dtype=dtype_map)
|
49
50
|
predictions_df = runner(input_df[input_cols])
|
50
|
-
return predictions_df.to_dict("records")
|
51
|
+
return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
|