snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -2,10 +2,11 @@ import enum
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Callable, Dict, List, Optional, Union, overload
5
+ from typing import Any, Callable, Optional, Union, overload
6
6
 
7
7
  import pandas as pd
8
8
 
9
+ from snowflake import snowpark
9
10
  from snowflake.ml._internal import telemetry
10
11
  from snowflake.ml._internal.utils import sql_identifier
11
12
  from snowflake.ml.lineage import lineage_node
@@ -32,7 +33,7 @@ class ModelVersion(lineage_node.LineageNode):
32
33
  _service_ops: service_ops.ServiceOperator
33
34
  _model_name: sql_identifier.SqlIdentifier
34
35
  _version_name: sql_identifier.SqlIdentifier
35
- _functions: List[model_manifest_schema.ModelFunctionInfo]
36
+ _functions: list[model_manifest_schema.ModelFunctionInfo]
36
37
 
37
38
  def __init__(self) -> None:
38
39
  raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -152,7 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
152
153
  project=_TELEMETRY_PROJECT,
153
154
  subproject=_TELEMETRY_SUBPROJECT,
154
155
  )
155
- def show_metrics(self) -> Dict[str, Any]:
156
+ def show_metrics(self) -> dict[str, Any]:
156
157
  """Show all metrics logged with the model version.
157
158
 
158
159
  Returns:
@@ -293,7 +294,7 @@ class ModelVersion(lineage_node.LineageNode):
293
294
  statement_params=statement_params,
294
295
  )
295
296
 
296
- def _get_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
297
+ def _get_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
297
298
  statement_params = telemetry.get_statement_params(
298
299
  project=_TELEMETRY_PROJECT,
299
300
  subproject=_TELEMETRY_SUBPROJECT,
@@ -327,7 +328,7 @@ class ModelVersion(lineage_node.LineageNode):
327
328
  project=_TELEMETRY_PROJECT,
328
329
  subproject=_TELEMETRY_SUBPROJECT,
329
330
  )
330
- def show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
331
+ def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
331
332
  """Show all functions information in a model version that is callable.
332
333
 
333
334
  Returns:
@@ -405,11 +406,6 @@ class ModelVersion(lineage_node.LineageNode):
405
406
  strict_input_validation: Enable stricter validation for the input data. This will result value range based
406
407
  type validation to make sure your input data won't overflow when providing to the model.
407
408
 
408
- Raises:
409
- ValueError: When no method with the corresponding name is available.
410
- ValueError: When there are more than 1 target methods available in the model but no function name specified.
411
- ValueError: When the partition column is not a valid Snowflake identifier.
412
-
413
409
  Returns:
414
410
  The prediction data. It would be the same type dataframe as your input.
415
411
  """
@@ -422,29 +418,7 @@ class ModelVersion(lineage_node.LineageNode):
422
418
  # Partition column must be a valid identifier
423
419
  partition_column = sql_identifier.SqlIdentifier(partition_column)
424
420
 
425
- functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
426
-
427
- if function_name:
428
- req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
429
- find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
430
- lambda method: method["name"] == req_method_name
431
- )
432
- target_function_info = next(
433
- filter(find_method, functions),
434
- None,
435
- )
436
- if target_function_info is None:
437
- raise ValueError(
438
- f"There is no method with name {function_name} available in the model"
439
- f" {self.fully_qualified_model_name} version {self.version_name}"
440
- )
441
- elif len(functions) != 1:
442
- raise ValueError(
443
- f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
444
- f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
445
- )
446
- else:
447
- target_function_info = functions[0]
421
+ target_function_info = self._get_function_info(function_name=function_name)
448
422
 
449
423
  if service_name:
450
424
  database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
@@ -475,6 +449,33 @@ class ModelVersion(lineage_node.LineageNode):
475
449
  is_partitioned=target_function_info["is_partitioned"],
476
450
  )
477
451
 
452
+ def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
453
+ functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
454
+
455
+ if function_name:
456
+ req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
457
+ find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
458
+ lambda method: method["name"] == req_method_name
459
+ )
460
+ target_function_info = next(
461
+ filter(find_method, functions),
462
+ None,
463
+ )
464
+ if target_function_info is None:
465
+ raise ValueError(
466
+ f"There is no method with name {function_name} available in the model"
467
+ f" {self.fully_qualified_model_name} version {self.version_name}"
468
+ )
469
+ elif len(functions) != 1:
470
+ raise ValueError(
471
+ f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
472
+ f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
473
+ )
474
+ else:
475
+ target_function_info = functions[0]
476
+
477
+ return target_function_info
478
+
478
479
  @telemetry.send_api_usage_telemetry(
479
480
  project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
480
481
  )
@@ -684,7 +685,7 @@ class ModelVersion(lineage_node.LineageNode):
684
685
  num_workers: Optional[int] = None,
685
686
  max_batch_rows: Optional[int] = None,
686
687
  force_rebuild: bool = False,
687
- build_external_access_integrations: Optional[List[str]] = None,
688
+ build_external_access_integrations: Optional[list[str]] = None,
688
689
  block: bool = True,
689
690
  ) -> Union[str, async_job.AsyncJob]:
690
691
  """Create an inference service with the given spec.
@@ -751,7 +752,7 @@ class ModelVersion(lineage_node.LineageNode):
751
752
  max_batch_rows: Optional[int] = None,
752
753
  force_rebuild: bool = False,
753
754
  build_external_access_integration: Optional[str] = None,
754
- build_external_access_integrations: Optional[List[str]] = None,
755
+ build_external_access_integrations: Optional[list[str]] = None,
755
756
  block: bool = True,
756
757
  ) -> Union[str, async_job.AsyncJob]:
757
758
  """Create an inference service with the given spec.
@@ -914,5 +915,72 @@ class ModelVersion(lineage_node.LineageNode):
914
915
  statement_params=statement_params,
915
916
  )
916
917
 
918
+ @snowpark._internal.utils.private_preview(version="1.8.3")
919
+ @telemetry.send_api_usage_telemetry(
920
+ project=_TELEMETRY_PROJECT,
921
+ subproject=_TELEMETRY_SUBPROJECT,
922
+ )
923
+ def run_job(
924
+ self,
925
+ X: Union[pd.DataFrame, "dataframe.DataFrame"],
926
+ *,
927
+ job_name: str,
928
+ compute_pool: str,
929
+ image_repo: str,
930
+ output_table_name: str,
931
+ function_name: Optional[str] = None,
932
+ cpu_requests: Optional[str] = None,
933
+ memory_requests: Optional[str] = None,
934
+ gpu_requests: Optional[Union[str, int]] = None,
935
+ num_workers: Optional[int] = None,
936
+ max_batch_rows: Optional[int] = None,
937
+ force_rebuild: bool = False,
938
+ build_external_access_integrations: Optional[list[str]] = None,
939
+ ) -> Union[pd.DataFrame, dataframe.DataFrame]:
940
+ statement_params = telemetry.get_statement_params(
941
+ project=_TELEMETRY_PROJECT,
942
+ subproject=_TELEMETRY_SUBPROJECT,
943
+ )
944
+ target_function_info = self._get_function_info(function_name=function_name)
945
+ job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
946
+ image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
947
+ output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
948
+ output_table_name
949
+ )
950
+ warehouse = self._service_ops._session.get_current_warehouse()
951
+ assert warehouse, "No active warehouse selected in the current session."
952
+ return self._service_ops.invoke_job_method(
953
+ target_method=target_function_info["target_method"],
954
+ signature=target_function_info["signature"],
955
+ X=X,
956
+ database_name=None,
957
+ schema_name=None,
958
+ model_name=self._model_name,
959
+ version_name=self._version_name,
960
+ job_database_name=job_db_id,
961
+ job_schema_name=job_schema_id,
962
+ job_name=job_id,
963
+ compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
964
+ warehouse_name=sql_identifier.SqlIdentifier(warehouse),
965
+ image_repo_database_name=image_repo_db_id,
966
+ image_repo_schema_name=image_repo_schema_id,
967
+ image_repo_name=image_repo_id,
968
+ output_table_database_name=output_table_db_id,
969
+ output_table_schema_name=output_table_schema_id,
970
+ output_table_name=output_table_id,
971
+ cpu_requests=cpu_requests,
972
+ memory_requests=memory_requests,
973
+ gpu_requests=gpu_requests,
974
+ num_workers=num_workers,
975
+ max_batch_rows=max_batch_rows,
976
+ force_rebuild=force_rebuild,
977
+ build_external_access_integrations=(
978
+ None
979
+ if build_external_access_integrations is None
980
+ else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
981
+ ),
982
+ statement_params=statement_params,
983
+ )
984
+
917
985
 
918
986
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, Optional, TypedDict
2
+ from typing import Any, Optional, TypedDict
3
3
 
4
4
  from typing_extensions import NotRequired
5
5
 
@@ -14,7 +14,7 @@ MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
14
14
 
15
15
 
16
16
  class ModelVersionMetadataSchema(TypedDict):
17
- metrics: NotRequired[Dict[str, Any]]
17
+ metrics: NotRequired[dict[str, Any]]
18
18
 
19
19
 
20
20
  class MetadataOperator:
@@ -44,7 +44,7 @@ class MetadataOperator:
44
44
  )
45
45
 
46
46
  @staticmethod
47
- def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema:
47
+ def _parse(metadata_dict: dict[str, Any]) -> ModelVersionMetadataSchema:
48
48
  loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
49
49
  if loaded_metadata_schema_version is None:
50
50
  return ModelVersionMetadataSchema(metrics={})
@@ -65,8 +65,8 @@ class MetadataOperator:
65
65
  schema_name: Optional[sql_identifier.SqlIdentifier],
66
66
  model_name: sql_identifier.SqlIdentifier,
67
67
  version_name: sql_identifier.SqlIdentifier,
68
- statement_params: Optional[Dict[str, Any]] = None,
69
- ) -> Dict[str, Any]:
68
+ statement_params: Optional[dict[str, Any]] = None,
69
+ ) -> dict[str, Any]:
70
70
  version_info_list = self._model_client.show_versions(
71
71
  database_name=database_name,
72
72
  schema_name=schema_name,
@@ -89,7 +89,7 @@ class MetadataOperator:
89
89
  schema_name: Optional[sql_identifier.SqlIdentifier],
90
90
  model_name: sql_identifier.SqlIdentifier,
91
91
  version_name: sql_identifier.SqlIdentifier,
92
- statement_params: Optional[Dict[str, Any]] = None,
92
+ statement_params: Optional[dict[str, Any]] = None,
93
93
  ) -> ModelVersionMetadataSchema:
94
94
  metadata_dict = self._get_current_metadata_dict(
95
95
  database_name=database_name,
@@ -108,7 +108,7 @@ class MetadataOperator:
108
108
  schema_name: Optional[sql_identifier.SqlIdentifier],
109
109
  model_name: sql_identifier.SqlIdentifier,
110
110
  version_name: sql_identifier.SqlIdentifier,
111
- statement_params: Optional[Dict[str, Any]] = None,
111
+ statement_params: Optional[dict[str, Any]] = None,
112
112
  ) -> None:
113
113
  metadata_dict = self._get_current_metadata_dict(
114
114
  database_name=database_name,
@@ -4,7 +4,7 @@ import os
4
4
  import pathlib
5
5
  import tempfile
6
6
  import warnings
7
- from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
7
+ from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
8
8
 
9
9
  import yaml
10
10
 
@@ -104,7 +104,7 @@ class ModelOperator:
104
104
  *,
105
105
  database_name: Optional[sql_identifier.SqlIdentifier],
106
106
  schema_name: Optional[sql_identifier.SqlIdentifier],
107
- statement_params: Optional[Dict[str, Any]] = None,
107
+ statement_params: Optional[dict[str, Any]] = None,
108
108
  ) -> str:
109
109
  stage_name = sql_identifier.SqlIdentifier(
110
110
  snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -137,7 +137,7 @@ class ModelOperator:
137
137
  schema_name: Optional[sql_identifier.SqlIdentifier],
138
138
  model_name: sql_identifier.SqlIdentifier,
139
139
  version_name: sql_identifier.SqlIdentifier,
140
- statement_params: Optional[Dict[str, Any]] = None,
140
+ statement_params: Optional[dict[str, Any]] = None,
141
141
  ) -> ModelAction:
142
142
  if self.validate_existence(
143
143
  database_name=database_name,
@@ -169,7 +169,7 @@ class ModelOperator:
169
169
  schema_name: Optional[sql_identifier.SqlIdentifier],
170
170
  model_name: sql_identifier.SqlIdentifier,
171
171
  version_name: sql_identifier.SqlIdentifier,
172
- statement_params: Optional[Dict[str, Any]] = None,
172
+ statement_params: Optional[dict[str, Any]] = None,
173
173
  ) -> None:
174
174
  model_action = self.get_model_action_from_model_name_and_version(
175
175
  database_name=database_name,
@@ -205,7 +205,7 @@ class ModelOperator:
205
205
  schema_name: Optional[sql_identifier.SqlIdentifier],
206
206
  model_name: sql_identifier.SqlIdentifier,
207
207
  version_name: sql_identifier.SqlIdentifier,
208
- statement_params: Optional[Dict[str, Any]] = None,
208
+ statement_params: Optional[dict[str, Any]] = None,
209
209
  use_live_commit: Optional[bool] = False,
210
210
  ) -> None:
211
211
 
@@ -263,7 +263,7 @@ class ModelOperator:
263
263
  model_name: sql_identifier.SqlIdentifier,
264
264
  version_name: sql_identifier.SqlIdentifier,
265
265
  model_exists: bool,
266
- statement_params: Optional[Dict[str, Any]] = None,
266
+ statement_params: Optional[dict[str, Any]] = None,
267
267
  ) -> None:
268
268
  if model_exists:
269
269
  return self._model_version_client.add_version_from_model_version(
@@ -296,8 +296,8 @@ class ModelOperator:
296
296
  database_name: Optional[sql_identifier.SqlIdentifier],
297
297
  schema_name: Optional[sql_identifier.SqlIdentifier],
298
298
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
299
- statement_params: Optional[Dict[str, Any]] = None,
300
- ) -> List[row.Row]:
299
+ statement_params: Optional[dict[str, Any]] = None,
300
+ ) -> list[row.Row]:
301
301
  if model_name:
302
302
  return self._model_client.show_versions(
303
303
  database_name=database_name,
@@ -320,8 +320,8 @@ class ModelOperator:
320
320
  database_name: Optional[sql_identifier.SqlIdentifier],
321
321
  schema_name: Optional[sql_identifier.SqlIdentifier],
322
322
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
323
- statement_params: Optional[Dict[str, Any]] = None,
324
- ) -> List[sql_identifier.SqlIdentifier]:
323
+ statement_params: Optional[dict[str, Any]] = None,
324
+ ) -> list[sql_identifier.SqlIdentifier]:
325
325
  res = self.show_models_or_versions(
326
326
  database_name=database_name,
327
327
  schema_name=schema_name,
@@ -341,7 +341,7 @@ class ModelOperator:
341
341
  schema_name: Optional[sql_identifier.SqlIdentifier],
342
342
  model_name: sql_identifier.SqlIdentifier,
343
343
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
344
- statement_params: Optional[Dict[str, Any]] = None,
344
+ statement_params: Optional[dict[str, Any]] = None,
345
345
  ) -> bool:
346
346
  if version_name:
347
347
  res = self._model_client.show_versions(
@@ -369,7 +369,7 @@ class ModelOperator:
369
369
  schema_name: Optional[sql_identifier.SqlIdentifier],
370
370
  model_name: sql_identifier.SqlIdentifier,
371
371
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
372
- statement_params: Optional[Dict[str, Any]] = None,
372
+ statement_params: Optional[dict[str, Any]] = None,
373
373
  ) -> str:
374
374
  if version_name:
375
375
  res = self._model_client.show_versions(
@@ -398,7 +398,7 @@ class ModelOperator:
398
398
  schema_name: Optional[sql_identifier.SqlIdentifier],
399
399
  model_name: sql_identifier.SqlIdentifier,
400
400
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
401
- statement_params: Optional[Dict[str, Any]] = None,
401
+ statement_params: Optional[dict[str, Any]] = None,
402
402
  ) -> None:
403
403
  if version_name:
404
404
  self._model_version_client.set_comment(
@@ -426,7 +426,7 @@ class ModelOperator:
426
426
  schema_name: Optional[sql_identifier.SqlIdentifier],
427
427
  model_name: sql_identifier.SqlIdentifier,
428
428
  version_name: sql_identifier.SqlIdentifier,
429
- statement_params: Optional[Dict[str, Any]] = None,
429
+ statement_params: Optional[dict[str, Any]] = None,
430
430
  ) -> None:
431
431
  self._model_version_client.set_alias(
432
432
  alias_name=alias_name,
@@ -444,7 +444,7 @@ class ModelOperator:
444
444
  database_name: Optional[sql_identifier.SqlIdentifier],
445
445
  schema_name: Optional[sql_identifier.SqlIdentifier],
446
446
  model_name: sql_identifier.SqlIdentifier,
447
- statement_params: Optional[Dict[str, Any]] = None,
447
+ statement_params: Optional[dict[str, Any]] = None,
448
448
  ) -> None:
449
449
  self._model_version_client.unset_alias(
450
450
  database_name=database_name,
@@ -461,7 +461,7 @@ class ModelOperator:
461
461
  schema_name: Optional[sql_identifier.SqlIdentifier],
462
462
  model_name: sql_identifier.SqlIdentifier,
463
463
  version_name: sql_identifier.SqlIdentifier,
464
- statement_params: Optional[Dict[str, Any]] = None,
464
+ statement_params: Optional[dict[str, Any]] = None,
465
465
  ) -> None:
466
466
  if not self.validate_existence(
467
467
  database_name=database_name,
@@ -485,7 +485,7 @@ class ModelOperator:
485
485
  database_name: Optional[sql_identifier.SqlIdentifier],
486
486
  schema_name: Optional[sql_identifier.SqlIdentifier],
487
487
  model_name: sql_identifier.SqlIdentifier,
488
- statement_params: Optional[Dict[str, Any]] = None,
488
+ statement_params: Optional[dict[str, Any]] = None,
489
489
  ) -> sql_identifier.SqlIdentifier:
490
490
  res = self._model_client.show_models(
491
491
  database_name=database_name,
@@ -504,7 +504,7 @@ class ModelOperator:
504
504
  schema_name: Optional[sql_identifier.SqlIdentifier],
505
505
  model_name: sql_identifier.SqlIdentifier,
506
506
  alias_name: sql_identifier.SqlIdentifier,
507
- statement_params: Optional[Dict[str, Any]] = None,
507
+ statement_params: Optional[dict[str, Any]] = None,
508
508
  ) -> Optional[sql_identifier.SqlIdentifier]:
509
509
  res = self._model_client.show_versions(
510
510
  database_name=database_name,
@@ -528,7 +528,7 @@ class ModelOperator:
528
528
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
529
529
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
530
530
  tag_name: sql_identifier.SqlIdentifier,
531
- statement_params: Optional[Dict[str, Any]] = None,
531
+ statement_params: Optional[dict[str, Any]] = None,
532
532
  ) -> Optional[str]:
533
533
  r = self._tag_client.get_tag_value(
534
534
  database_name=database_name,
@@ -550,15 +550,15 @@ class ModelOperator:
550
550
  database_name: Optional[sql_identifier.SqlIdentifier],
551
551
  schema_name: Optional[sql_identifier.SqlIdentifier],
552
552
  model_name: sql_identifier.SqlIdentifier,
553
- statement_params: Optional[Dict[str, Any]] = None,
554
- ) -> Dict[str, str]:
553
+ statement_params: Optional[dict[str, Any]] = None,
554
+ ) -> dict[str, str]:
555
555
  tags_info = self._tag_client.get_tag_list(
556
556
  database_name=database_name,
557
557
  schema_name=schema_name,
558
558
  model_name=model_name,
559
559
  statement_params=statement_params,
560
560
  )
561
- res: Dict[str, str] = {
561
+ res: dict[str, str] = {
562
562
  identifier.get_schema_level_object_identifier(
563
563
  sql_identifier.SqlIdentifier(r.TAG_DATABASE, case_sensitive=True),
564
564
  sql_identifier.SqlIdentifier(r.TAG_SCHEMA, case_sensitive=True),
@@ -578,7 +578,7 @@ class ModelOperator:
578
578
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
579
579
  tag_name: sql_identifier.SqlIdentifier,
580
580
  tag_value: str,
581
- statement_params: Optional[Dict[str, Any]] = None,
581
+ statement_params: Optional[dict[str, Any]] = None,
582
582
  ) -> None:
583
583
  self._tag_client.set_tag_on_model(
584
584
  database_name=database_name,
@@ -600,7 +600,7 @@ class ModelOperator:
600
600
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
601
601
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
602
602
  tag_name: sql_identifier.SqlIdentifier,
603
- statement_params: Optional[Dict[str, Any]] = None,
603
+ statement_params: Optional[dict[str, Any]] = None,
604
604
  ) -> None:
605
605
  self._tag_client.unset_tag_on_model(
606
606
  database_name=database_name,
@@ -619,8 +619,8 @@ class ModelOperator:
619
619
  schema_name: Optional[sql_identifier.SqlIdentifier],
620
620
  model_name: sql_identifier.SqlIdentifier,
621
621
  version_name: sql_identifier.SqlIdentifier,
622
- statement_params: Optional[Dict[str, Any]] = None,
623
- ) -> List[ServiceInfo]:
622
+ statement_params: Optional[dict[str, Any]] = None,
623
+ ) -> list[ServiceInfo]:
624
624
  res = self._model_client.show_versions(
625
625
  database_name=database_name,
626
626
  schema_name=schema_name,
@@ -682,7 +682,7 @@ class ModelOperator:
682
682
  service_database_name: Optional[sql_identifier.SqlIdentifier],
683
683
  service_schema_name: Optional[sql_identifier.SqlIdentifier],
684
684
  service_name: sql_identifier.SqlIdentifier,
685
- statement_params: Optional[Dict[str, Any]] = None,
685
+ statement_params: Optional[dict[str, Any]] = None,
686
686
  ) -> None:
687
687
  services = self.show_services(
688
688
  database_name=database_name,
@@ -724,7 +724,7 @@ class ModelOperator:
724
724
  schema_name: Optional[sql_identifier.SqlIdentifier],
725
725
  model_name: sql_identifier.SqlIdentifier,
726
726
  version_name: sql_identifier.SqlIdentifier,
727
- statement_params: Optional[Dict[str, Any]] = None,
727
+ statement_params: Optional[dict[str, Any]] = None,
728
728
  ) -> model_manifest_schema.ModelManifestDict:
729
729
  with tempfile.TemporaryDirectory() as tmpdir:
730
730
  self._model_version_client.get_file(
@@ -741,9 +741,9 @@ class ModelOperator:
741
741
 
742
742
  @staticmethod
743
743
  def _match_model_spec_with_sql_functions(
744
- sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
745
- ) -> Dict[sql_identifier.SqlIdentifier, str]:
746
- res: Dict[sql_identifier.SqlIdentifier, str] = {}
744
+ sql_functions_names: list[sql_identifier.SqlIdentifier], target_methods: list[str]
745
+ ) -> dict[sql_identifier.SqlIdentifier, str]:
746
+ res: dict[sql_identifier.SqlIdentifier, str] = {}
747
747
 
748
748
  for target_method in target_methods:
749
749
  # Here we need to find the SQL function corresponding to the Python function.
@@ -766,7 +766,7 @@ class ModelOperator:
766
766
  schema_name: Optional[sql_identifier.SqlIdentifier],
767
767
  model_name: sql_identifier.SqlIdentifier,
768
768
  version_name: sql_identifier.SqlIdentifier,
769
- statement_params: Optional[Dict[str, Any]] = None,
769
+ statement_params: Optional[dict[str, Any]] = None,
770
770
  ) -> model_meta_schema.ModelMetadataDict:
771
771
  raw_model_spec_res = self._model_client.show_versions(
772
772
  database_name=database_name,
@@ -787,7 +787,7 @@ class ModelOperator:
787
787
  schema_name: Optional[sql_identifier.SqlIdentifier],
788
788
  model_name: sql_identifier.SqlIdentifier,
789
789
  version_name: sql_identifier.SqlIdentifier,
790
- statement_params: Optional[Dict[str, Any]] = None,
790
+ statement_params: Optional[dict[str, Any]] = None,
791
791
  ) -> type_hints.Task:
792
792
  model_version = self._model_client.show_versions(
793
793
  database_name=database_name,
@@ -809,8 +809,8 @@ class ModelOperator:
809
809
  schema_name: Optional[sql_identifier.SqlIdentifier],
810
810
  model_name: sql_identifier.SqlIdentifier,
811
811
  version_name: sql_identifier.SqlIdentifier,
812
- statement_params: Optional[Dict[str, Any]] = None,
813
- ) -> List[model_manifest_schema.ModelFunctionInfo]:
812
+ statement_params: Optional[dict[str, Any]] = None,
813
+ ) -> list[model_manifest_schema.ModelFunctionInfo]:
814
814
  model_spec = self._fetch_model_spec(
815
815
  database_name=database_name,
816
816
  schema_name=schema_name,
@@ -907,7 +907,7 @@ class ModelOperator:
907
907
  version_name: sql_identifier.SqlIdentifier,
908
908
  strict_input_validation: bool = False,
909
909
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
910
- statement_params: Optional[Dict[str, str]] = None,
910
+ statement_params: Optional[dict[str, str]] = None,
911
911
  is_partitioned: Optional[bool] = None,
912
912
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
913
913
  ...
@@ -923,7 +923,7 @@ class ModelOperator:
923
923
  schema_name: Optional[sql_identifier.SqlIdentifier],
924
924
  service_name: sql_identifier.SqlIdentifier,
925
925
  strict_input_validation: bool = False,
926
- statement_params: Optional[Dict[str, str]] = None,
926
+ statement_params: Optional[dict[str, str]] = None,
927
927
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
928
928
  ...
929
929
 
@@ -941,7 +941,7 @@ class ModelOperator:
941
941
  service_name: Optional[sql_identifier.SqlIdentifier] = None,
942
942
  strict_input_validation: bool = False,
943
943
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
944
- statement_params: Optional[Dict[str, str]] = None,
944
+ statement_params: Optional[dict[str, str]] = None,
945
945
  is_partitioned: Optional[bool] = None,
946
946
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
947
947
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
@@ -1059,7 +1059,7 @@ class ModelOperator:
1059
1059
  schema_name: Optional[sql_identifier.SqlIdentifier],
1060
1060
  model_name: sql_identifier.SqlIdentifier,
1061
1061
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
1062
- statement_params: Optional[Dict[str, Any]] = None,
1062
+ statement_params: Optional[dict[str, Any]] = None,
1063
1063
  ) -> None:
1064
1064
  if version_name:
1065
1065
  self._model_version_client.drop_version(
@@ -1086,7 +1086,7 @@ class ModelOperator:
1086
1086
  new_model_db: Optional[sql_identifier.SqlIdentifier],
1087
1087
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
1088
1088
  new_model_name: sql_identifier.SqlIdentifier,
1089
- statement_params: Optional[Dict[str, Any]] = None,
1089
+ statement_params: Optional[dict[str, Any]] = None,
1090
1090
  ) -> None:
1091
1091
  self._model_client.rename(
1092
1092
  database_name=database_name,
@@ -1121,7 +1121,7 @@ class ModelOperator:
1121
1121
  version_name: sql_identifier.SqlIdentifier,
1122
1122
  target_path: pathlib.Path,
1123
1123
  mode: Literal["full", "model", "minimal"] = "model",
1124
- statement_params: Optional[Dict[str, Any]] = None,
1124
+ statement_params: Optional[dict[str, Any]] = None,
1125
1125
  ) -> None:
1126
1126
  for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
1127
1127
  list_file_res = self._model_version_client.list_file(