snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -11,7 +11,7 @@ class StageSQLClient(_base._BaseSQLClient):
11
11
  database_name: Optional[sql_identifier.SqlIdentifier],
12
12
  schema_name: Optional[sql_identifier.SqlIdentifier],
13
13
  stage_name: sql_identifier.SqlIdentifier,
14
- statement_params: Optional[Dict[str, Any]] = None,
14
+ statement_params: Optional[dict[str, Any]] = None,
15
15
  ) -> None:
16
16
  query_result_checker.SqlResultValidator(
17
17
  self._session,
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -16,7 +16,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
16
16
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
17
17
  tag_name: sql_identifier.SqlIdentifier,
18
18
  tag_value: str,
19
- statement_params: Optional[Dict[str, Any]] = None,
19
+ statement_params: Optional[dict[str, Any]] = None,
20
20
  ) -> None:
21
21
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
22
22
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -35,7 +35,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
35
35
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
36
36
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
37
37
  tag_name: sql_identifier.SqlIdentifier,
38
- statement_params: Optional[Dict[str, Any]] = None,
38
+ statement_params: Optional[dict[str, Any]] = None,
39
39
  ) -> None:
40
40
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
41
41
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -54,7 +54,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
54
54
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
55
55
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
56
56
  tag_name: sql_identifier.SqlIdentifier,
57
- statement_params: Optional[Dict[str, Any]] = None,
57
+ statement_params: Optional[dict[str, Any]] = None,
58
58
  ) -> row.Row:
59
59
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
60
60
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -75,8 +75,8 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
75
75
  database_name: Optional[sql_identifier.SqlIdentifier],
76
76
  schema_name: Optional[sql_identifier.SqlIdentifier],
77
77
  model_name: sql_identifier.SqlIdentifier,
78
- statement_params: Optional[Dict[str, Any]] = None,
79
- ) -> List[row.Row]:
78
+ statement_params: Optional[dict[str, Any]] = None,
79
+ ) -> list[row.Row]:
80
80
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
81
81
  actual_database_name = database_name or self._database_name
82
82
  return (
@@ -3,13 +3,14 @@ import tempfile
3
3
  import uuid
4
4
  import warnings
5
5
  from types import ModuleType
6
- from typing import Any, Dict, List, Optional, Union
6
+ from typing import Any, Optional, Union
7
7
  from urllib import parse
8
8
 
9
9
  from absl import logging
10
10
  from packaging import requirements
11
11
 
12
12
  from snowflake import snowpark
13
+ from snowflake.ml import version as snowml_version
13
14
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
15
  from snowflake.ml._internal.lineage import lineage_utils
15
16
  from snowflake.ml.data import data_source
@@ -43,7 +44,7 @@ class ModelComposer:
43
44
  session: Session,
44
45
  stage_path: str,
45
46
  *,
46
- statement_params: Optional[Dict[str, Any]] = None,
47
+ statement_params: Optional[dict[str, Any]] = None,
47
48
  save_location: Optional[str] = None,
48
49
  ) -> None:
49
50
  self.session = session
@@ -122,17 +123,18 @@ class ModelComposer:
122
123
  *,
123
124
  name: str,
124
125
  model: model_types.SupportedModelType,
125
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
126
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
126
127
  sample_input_data: Optional[model_types.SupportedDataType] = None,
127
- metadata: Optional[Dict[str, str]] = None,
128
- conda_dependencies: Optional[List[str]] = None,
129
- pip_requirements: Optional[List[str]] = None,
130
- artifact_repository_map: Optional[Dict[str, str]] = None,
131
- target_platforms: Optional[List[model_types.TargetPlatform]] = None,
128
+ metadata: Optional[dict[str, str]] = None,
129
+ conda_dependencies: Optional[list[str]] = None,
130
+ pip_requirements: Optional[list[str]] = None,
131
+ artifact_repository_map: Optional[dict[str, str]] = None,
132
+ resource_constraint: Optional[dict[str, str]] = None,
133
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
132
134
  python_version: Optional[str] = None,
133
- user_files: Optional[Dict[str, List[str]]] = None,
134
- ext_modules: Optional[List[ModuleType]] = None,
135
- code_paths: Optional[List[str]] = None,
135
+ user_files: Optional[dict[str, list[str]]] = None,
136
+ ext_modules: Optional[list[ModuleType]] = None,
137
+ code_paths: Optional[list[str]] = None,
136
138
  task: model_types.Task = model_types.Task.UNKNOWN,
137
139
  options: Optional[model_types.ModelSaveOption] = None,
138
140
  ) -> model_meta.ModelMetadata:
@@ -166,14 +168,14 @@ class ModelComposer:
166
168
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
167
169
  snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
168
170
  self.session,
169
- reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
171
+ reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
170
172
  python_version=python_version or snowml_env.PYTHON_VERSION,
171
173
  statement_params=self._statement_params,
172
174
  ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
173
175
 
174
176
  if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
175
177
  logging.info(
176
- f"Local snowflake-ml-python library has version {snowml_env.VERSION},"
178
+ f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
177
179
  " which is not available in the Snowflake server, embedding local ML library automatically."
178
180
  )
179
181
  options["embed_local_ml_library"] = True
@@ -187,6 +189,7 @@ class ModelComposer:
187
189
  conda_dependencies=conda_dependencies,
188
190
  pip_requirements=pip_requirements,
189
191
  artifact_repository_map=artifact_repository_map,
192
+ resource_constraint=resource_constraint,
190
193
  target_platforms=target_platforms,
191
194
  python_version=python_version,
192
195
  ext_modules=ext_modules,
@@ -226,7 +229,7 @@ class ModelComposer:
226
229
 
227
230
  def _get_data_sources(
228
231
  self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
229
- ) -> Optional[List[data_source.DataSource]]:
232
+ ) -> Optional[list[data_source.DataSource]]:
230
233
  data_sources = lineage_utils.get_data_sources(model)
231
234
  if not data_sources and sample_input_data is not None:
232
235
  data_sources = lineage_utils.get_data_sources(sample_input_data)
@@ -2,7 +2,7 @@ import collections
2
2
  import logging
3
3
  import pathlib
4
4
  import warnings
5
- from typing import Dict, List, Optional, cast
5
+ from typing import Optional, cast
6
6
 
7
7
  import yaml
8
8
 
@@ -45,10 +45,10 @@ class ModelManifest:
45
45
  self,
46
46
  model_meta: model_meta_api.ModelMetadata,
47
47
  model_rel_path: pathlib.PurePosixPath,
48
- user_files: Optional[Dict[str, List[str]]] = None,
48
+ user_files: Optional[dict[str, list[str]]] = None,
49
49
  options: Optional[type_hints.ModelSaveOption] = None,
50
- data_sources: Optional[List[data_source.DataSource]] = None,
51
- target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
50
+ data_sources: Optional[list[data_source.DataSource]] = None,
51
+ target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
52
52
  ) -> None:
53
53
  if options is None:
54
54
  options = {}
@@ -78,12 +78,13 @@ class ModelManifest:
78
78
  logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
79
79
  logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
80
80
  logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
81
+ logger.info(f"resource_constraint: {runtime_to_use.runtime_env.resource_constraint}")
81
82
  runtime_dict = runtime_to_use.save(
82
83
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
83
84
  )
84
85
 
85
86
  self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
86
- self.methods: List[model_method.ModelMethod] = []
87
+ self.methods: list[model_method.ModelMethod] = []
87
88
 
88
89
  for target_method in model_meta.signatures.keys():
89
90
  method = model_method.ModelMethod(
@@ -100,7 +101,7 @@ class ModelManifest:
100
101
 
101
102
  self.methods.append(method)
102
103
 
103
- self.user_files: List[model_user_file.ModelUserFile] = []
104
+ self.user_files: list[model_user_file.ModelUserFile] = []
104
105
 
105
106
  if user_files is not None:
106
107
  for subdirectory, paths in user_files.items():
@@ -127,16 +128,19 @@ class ModelManifest:
127
128
  if model_meta.env.artifact_repository_map:
128
129
  dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
129
130
 
131
+ runtime = model_manifest_schema.ModelRuntimeDict(
132
+ language="PYTHON",
133
+ version=runtime_to_use.runtime_env.python_version,
134
+ imports=runtime_dict["imports"],
135
+ dependencies=dependencies,
136
+ )
137
+
138
+ if runtime_dict["resource_constraint"]:
139
+ runtime["resource_constraint"] = runtime_dict["resource_constraint"]
140
+
130
141
  manifest_dict = model_manifest_schema.ModelManifestDict(
131
142
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
132
- runtimes={
133
- self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
134
- language="PYTHON",
135
- version=runtime_to_use.runtime_env.python_version,
136
- imports=runtime_dict["imports"],
137
- dependencies=dependencies,
138
- )
139
- },
143
+ runtimes={self._DEFAULT_RUNTIME_NAME: runtime},
140
144
  methods=[
141
145
  method.save(
142
146
  self.workspace_path,
@@ -178,8 +182,8 @@ class ModelManifest:
178
182
  return res
179
183
 
180
184
  def _extract_lineage_info(
181
- self, data_sources: Optional[List[data_source.DataSource]]
182
- ) -> List[model_manifest_schema.LineageSourceDict]:
185
+ self, data_sources: Optional[list[data_source.DataSource]]
186
+ ) -> list[model_manifest_schema.LineageSourceDict]:
183
187
  result = []
184
188
  if data_sources:
185
189
  for source in data_sources:
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
2
  import enum
3
- from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
3
+ from typing import Any, Literal, Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
@@ -20,14 +20,15 @@ class ModelMethodFunctionTypes(enum.Enum):
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
21
  conda: NotRequired[str]
22
22
  pip: NotRequired[str]
23
- artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
23
+ artifact_repository_map: NotRequired[Optional[dict[str, str]]]
24
24
 
25
25
 
26
26
  class ModelRuntimeDict(TypedDict):
27
27
  language: Required[Literal["PYTHON"]]
28
28
  version: Required[str]
29
- imports: Required[List[str]]
29
+ imports: Required[list[str]]
30
30
  dependencies: Required[ModelRuntimeDependenciesDict]
31
+ resource_constraint: NotRequired[Optional[dict[str, str]]]
31
32
 
32
33
 
33
34
  class ModelMethodSignatureField(TypedDict):
@@ -43,8 +44,8 @@ class ModelFunctionMethodDict(TypedDict):
43
44
  runtime: Required[str]
44
45
  type: Required[str]
45
46
  handler: Required[str]
46
- inputs: Required[List[ModelMethodSignatureFieldWithName]]
47
- outputs: Required[Union[List[ModelMethodSignatureField], List[ModelMethodSignatureFieldWithName]]]
47
+ inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
+ outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
48
49
 
49
50
 
50
51
  ModelMethodDict = ModelFunctionMethodDict
@@ -71,12 +72,12 @@ class ModelFunctionInfo(TypedDict):
71
72
  class ModelFunctionInfoDict(TypedDict):
72
73
  name: Required[str]
73
74
  target_method: Required[str]
74
- signature: Required[Dict[str, Any]]
75
+ signature: Required[dict[str, Any]]
75
76
 
76
77
 
77
78
  class SnowparkMLDataDict(TypedDict):
78
79
  schema_version: Required[str]
79
- functions: Required[List[ModelFunctionInfoDict]]
80
+ functions: Required[list[ModelFunctionInfoDict]]
80
81
 
81
82
 
82
83
  class LineageSourceTypes(enum.Enum):
@@ -92,9 +93,9 @@ class LineageSourceDict(TypedDict):
92
93
 
93
94
  class ModelManifestDict(TypedDict):
94
95
  manifest_version: Required[str]
95
- runtimes: Required[Dict[str, ModelRuntimeDict]]
96
- methods: Required[List[ModelMethodDict]]
97
- user_data: NotRequired[Dict[str, Any]]
98
- user_files: NotRequired[List[str]]
99
- lineage_sources: NotRequired[List[LineageSourceDict]]
100
- target_platforms: NotRequired[List[str]]
96
+ runtimes: Required[dict[str, ModelRuntimeDict]]
97
+ methods: Required[list[ModelMethodDict]]
98
+ user_data: NotRequired[dict[str, Any]]
99
+ user_files: NotRequired[list[str]]
100
+ lineage_sources: NotRequired[list[LineageSourceDict]]
101
+ target_platforms: NotRequired[list[str]]
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import pathlib
3
- from typing import List, Optional, TypedDict, Union
3
+ from typing import Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired
6
6
 
@@ -137,8 +137,8 @@ class ModelMethod:
137
137
  )
138
138
 
139
139
  outputs: Union[
140
- List[model_manifest_schema.ModelMethodSignatureField],
141
- List[model_manifest_schema.ModelMethodSignatureFieldWithName],
140
+ list[model_manifest_schema.ModelMethodSignatureField],
141
+ list[model_manifest_schema.ModelMethodSignatureFieldWithName],
142
142
  ]
143
143
  if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
144
144
  outputs = [
@@ -3,10 +3,11 @@ import itertools
3
3
  import os
4
4
  import pathlib
5
5
  import warnings
6
- from typing import DefaultDict, Dict, List, Optional
6
+ from typing import DefaultDict, Optional
7
7
 
8
8
  from packaging import requirements, version
9
9
 
10
+ from snowflake.ml import version as snowml_version
10
11
  from snowflake.ml._internal import env as snowml_env, env_utils
11
12
  from snowflake.ml.model._packager.model_meta import model_meta_schema
12
13
 
@@ -19,9 +20,8 @@ _DEFAULT_CONDA_ENV_FILENAME = "conda.yml"
19
20
  _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
20
21
 
21
22
  # The default CUDA version is chosen based on the driver availability in SPCS.
22
- # If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
23
- # make sure they are compatible.
24
- DEFAULT_CUDA_VERSION = "11.8"
23
+ # Make sure they are aligned with default CUDA version in inference server.
24
+ DEFAULT_CUDA_VERSION = "12.4"
25
25
 
26
26
 
27
27
  class ModelEnv:
@@ -38,15 +38,16 @@ class ModelEnv:
38
38
  self.prefer_pip: bool = prefer_pip
39
39
  self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
40
40
  self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
41
- self.artifact_repository_map: Optional[Dict[str, str]] = None
42
- self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
43
- self._pip_requirements: List[requirements.Requirement] = []
41
+ self.artifact_repository_map: Optional[dict[str, str]] = None
42
+ self.resource_constraint: Optional[dict[str, str]] = None
43
+ self._conda_dependencies: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
44
+ self._pip_requirements: list[requirements.Requirement] = []
44
45
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
45
46
  self._cuda_version: Optional[version.Version] = None
46
- self._snowpark_ml_version: version.Version = version.parse(snowml_env.VERSION)
47
+ self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
47
48
 
48
49
  @property
49
- def conda_dependencies(self) -> List[str]:
50
+ def conda_dependencies(self) -> list[str]:
50
51
  """List of conda channel and dependencies from that to run the model"""
51
52
  return sorted(
52
53
  f"{chan}::{str(req)}" if chan else str(req)
@@ -57,24 +58,24 @@ class ModelEnv:
57
58
  @conda_dependencies.setter
58
59
  def conda_dependencies(
59
60
  self,
60
- conda_dependencies: Optional[List[str]] = None,
61
+ conda_dependencies: Optional[list[str]] = None,
61
62
  ) -> None:
62
63
  self._conda_dependencies = env_utils.validate_conda_dependency_string_list(
63
- conda_dependencies if conda_dependencies else []
64
+ conda_dependencies if conda_dependencies else [], add_local_version_specifier=True
64
65
  )
65
66
 
66
67
  @property
67
- def pip_requirements(self) -> List[str]:
68
+ def pip_requirements(self) -> list[str]:
68
69
  """List of pip Python packages requirements for running the model."""
69
70
  return sorted(list(map(str, self._pip_requirements)))
70
71
 
71
72
  @pip_requirements.setter
72
73
  def pip_requirements(
73
74
  self,
74
- pip_requirements: Optional[List[str]] = None,
75
+ pip_requirements: Optional[list[str]] = None,
75
76
  ) -> None:
76
77
  self._pip_requirements = env_utils.validate_pip_requirement_string_list(
77
- pip_requirements if pip_requirements else []
78
+ pip_requirements if pip_requirements else [], add_local_version_specifier=True
78
79
  )
79
80
 
80
81
  @property
@@ -117,7 +118,7 @@ class ModelEnv:
117
118
 
118
119
  def include_if_absent(
119
120
  self,
120
- pkgs: List[ModelDependency],
121
+ pkgs: list[ModelDependency],
121
122
  check_local_version: bool = False,
122
123
  ) -> None:
123
124
  """Append requirements into model env if absent. Depending on the environment, requirements may be added
@@ -128,7 +129,7 @@ class ModelEnv:
128
129
  check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
129
130
  """
130
131
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
131
- pip_pkg_reqs: List[str] = []
132
+ pip_pkg_reqs: list[str] = []
132
133
  warnings.warn(
133
134
  (
134
135
  "Dependencies specified from pip requirements."
@@ -145,7 +146,7 @@ class ModelEnv:
145
146
  else:
146
147
  self._include_if_absent_conda(pkgs, check_local_version)
147
148
 
148
- def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
149
+ def _include_if_absent_conda(self, pkgs: list[ModelDependency], check_local_version: bool = False) -> None:
149
150
  """Append requirements into model env conda dependencies if absent.
150
151
 
151
152
  Args:
@@ -190,7 +191,7 @@ class ModelEnv:
190
191
  stacklevel=2,
191
192
  )
192
193
 
193
- def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
194
+ def _include_if_absent_pip(self, pkgs: list[str], check_local_version: bool = False) -> None:
194
195
  """Append pip requirements into model env pip requirements if absent.
195
196
 
196
197
  Args:
@@ -207,7 +208,7 @@ class ModelEnv:
207
208
  except env_utils.DuplicateDependencyError:
208
209
  pass
209
210
 
210
- def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
211
+ def remove_if_present_conda(self, conda_pkgs: list[str]) -> None:
211
212
  """Remove conda requirements from model env if present.
212
213
 
213
214
  Args:
@@ -352,13 +353,14 @@ class ModelEnv:
352
353
  def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
353
354
  self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
354
355
  self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
355
- self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
356
+ self.artifact_repository_map = env_dict.get("artifact_repository_map")
357
+ self.resource_constraint = env_dict.get("resource_constraint")
356
358
 
357
359
  self.load_from_conda_file(base_dir / self.conda_env_rel_path)
358
360
  self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
359
361
 
360
362
  self.python_version = env_dict["python_version"]
361
- self.cuda_version = env_dict.get("cuda_version", None)
363
+ self.cuda_version = env_dict.get("cuda_version")
362
364
  self.snowpark_ml_version = env_dict["snowpark_ml_version"]
363
365
 
364
366
  def save_as_dict(
@@ -381,7 +383,8 @@ class ModelEnv:
381
383
  return {
382
384
  "conda": self.conda_env_rel_path.as_posix(),
383
385
  "pip": self.pip_requirements_rel_path.as_posix(),
384
- "artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
386
+ "artifact_repository_map": self.artifact_repository_map or {},
387
+ "resource_constraint": self.resource_constraint or {},
385
388
  "python_version": self.python_version,
386
389
  "cuda_version": self.cuda_version,
387
390
  "snowpark_ml_version": self.snowpark_ml_version,
@@ -389,7 +392,7 @@ class ModelEnv:
389
392
 
390
393
  def validate_with_local_env(
391
394
  self, check_snowpark_ml_version: bool = False
392
- ) -> List[env_utils.IncorrectLocalEnvironmentError]:
395
+ ) -> list[env_utils.IncorrectLocalEnvironmentError]:
393
396
  errors = []
394
397
  try:
395
398
  env_utils.validate_py_runtime_version(str(self._python_version))
@@ -413,10 +416,10 @@ class ModelEnv:
413
416
 
414
417
  if check_snowpark_ml_version:
415
418
  # For Modeling model
416
- if self._snowpark_ml_version.base_version != snowml_env.VERSION:
419
+ if self._snowpark_ml_version.base_version != snowml_version.VERSION:
417
420
  errors.append(
418
421
  env_utils.IncorrectLocalEnvironmentError(
419
- f"The local installed version of Snowpark ML library is {snowml_env.VERSION} "
422
+ f"The local installed version of Snowpark ML library is {snowml_version.VERSION} "
420
423
  f"which differs from required version {self.snowpark_ml_version}."
421
424
  )
422
425
  )
@@ -2,13 +2,13 @@ import functools
2
2
  import importlib
3
3
  import pkgutil
4
4
  from types import ModuleType
5
- from typing import Any, Callable, Dict, Optional, Type, TypeVar, cast
5
+ from typing import Any, Callable, Optional, TypeVar, cast
6
6
 
7
7
  from snowflake.ml.model import type_hints as model_types
8
8
  from snowflake.ml.model._packager.model_handlers import _base
9
9
 
10
10
  _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
11
- _MODEL_HANDLER_REGISTRY: Dict[str, Type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
11
+ _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
12
12
  _IS_HANDLER_LOADED = False
13
13
 
14
14
 
@@ -54,7 +54,7 @@ def ensure_handlers_registration(fn: F) -> F:
54
54
  @ensure_handlers_registration
55
55
  def find_handler(
56
56
  model: model_types.SupportedModelType,
57
- ) -> Optional[Type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
57
+ ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
58
58
  for handler in _MODEL_HANDLER_REGISTRY.values():
59
59
  if handler.can_handle(model):
60
60
  return handler
@@ -64,7 +64,7 @@ def find_handler(
64
64
  @ensure_handlers_registration
65
65
  def load_handler(
66
66
  target_model_type: model_types.SupportedModelHandlerType,
67
- ) -> Optional[Type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
67
+ ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
68
68
  for model_type, handler in _MODEL_HANDLER_REGISTRY.items():
69
69
  if target_model_type == model_type:
70
70
  return handler
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  from abc import abstractmethod
3
- from typing import Dict, Generic, Optional, Protocol, Type, final
3
+ from typing import Generic, Optional, Protocol, final
4
4
 
5
5
  import pandas as pd
6
6
  from typing_extensions import TypeGuard, Unpack
@@ -14,7 +14,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
14
14
  HANDLER_TYPE: model_types.SupportedModelHandlerType
15
15
  HANDLER_VERSION: str
16
16
  _MIN_SNOWPARK_ML_VERSION: str
17
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]]
17
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]]
18
18
 
19
19
  @classmethod
20
20
  @abstractmethod
@@ -1,8 +1,9 @@
1
+ import importlib
1
2
  import json
2
3
  import os
3
4
  import pathlib
4
5
  import warnings
5
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
6
+ from typing import Any, Callable, Iterable, Optional, Sequence, cast
6
7
 
7
8
  import numpy as np
8
9
  import numpy.typing as npt
@@ -10,8 +11,10 @@ import pandas as pd
10
11
  from absl import logging
11
12
 
12
13
  import snowflake.snowpark.dataframe as sp_df
14
+ from snowflake.ml._internal import env
13
15
  from snowflake.ml._internal.utils import identifier
14
16
  from snowflake.ml.model import model_signature, type_hints as model_types
17
+ from snowflake.ml.model._packager.model_env import model_env
15
18
  from snowflake.ml.model._packager.model_meta import model_meta
16
19
  from snowflake.ml.model._signatures import (
17
20
  core,
@@ -231,7 +234,7 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
231
234
 
232
235
 
233
236
  def get_explain_target_method(
234
- model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
237
+ model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
235
238
  ) -> Optional[str]:
236
239
  for method in model_metadata.signatures.keys():
237
240
  if method in target_methods_list:
@@ -248,7 +251,7 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
248
251
  config_dict = json.load(f)
249
252
 
250
253
  # a. get repository and class_path from configs
251
- auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
254
+ auto_map_configs = cast(dict[str, str], config_dict.get("auto_map", {}))
252
255
  for config_name, config_value in auto_map_configs.items():
253
256
  repository, _, class_path = config_value.rpartition("--")
254
257
 
@@ -261,3 +264,12 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
261
264
 
262
265
  with open(f_path, "w") as f:
263
266
  json.dump(config_dict, f)
267
+
268
+
269
+ def get_default_cuda_version() -> str:
270
+ # Default to the env cuda version when running in ML runtime
271
+ if env.IN_ML_RUNTIME and importlib.util.find_spec("torch") is not None:
272
+ import torch
273
+
274
+ return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
275
+ return model_env.DEFAULT_CUDA_VERSION
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -30,7 +30,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
30
30
  HANDLER_TYPE = "catboost"
31
31
  HANDLER_VERSION = "2024-03-21"
32
32
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
33
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
33
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
34
34
 
35
35
  MODEL_BLOB_FILE_OR_DIR = "model.bin"
36
36
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -147,7 +147,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
147
147
  if enable_explainability:
148
148
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
149
149
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
150
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
150
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
151
151
 
152
152
  return None
153
153
 
@@ -202,7 +202,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
202
202
  def _create_custom_model(
203
203
  raw_model: "catboost.CatBoost",
204
204
  model_meta: model_meta_api.ModelMetadata,
205
- ) -> Type[custom_model.CustomModel]:
205
+ ) -> type[custom_model.CustomModel]:
206
206
  def fn_factory(
207
207
  raw_model: "catboost.CatBoost",
208
208
  signature: model_signature.ModelSignature,
@@ -235,7 +235,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
235
235
 
236
236
  return fn
237
237
 
238
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
238
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
239
239
  for target_method_name, sig in model_meta.signatures.items():
240
240
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
241
241