snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/env_utils.py

@@ -6,12 +6,13 @@ import textwrap
 import warnings
 from enum import Enum
 from importlib import metadata as importlib_metadata
-from typing import Any, DefaultDict, Dict, List, Optional, Tuple
+from typing import Any, DefaultDict, Optional

 import yaml
 from packaging import requirements, specifiers, version

 import snowflake.connector
+from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
@@ -27,8 +28,8 @@ class CONDA_OS(Enum):


 _NODEFAULTS = "nodefaults"
-_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
-_SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
+_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
+_SNOWFLAKE_CONDA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
 _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]

 DEFAULT_CHANNEL_NAME = ""
@@ -64,7 +65,7 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
 return r


-def _validate_conda_dependency_string(dep_str: str) -> Tuple[str, requirements.Requirement]:
+def _validate_conda_dependency_string(dep_str: str) -> tuple[str, requirements.Requirement]:
 """Validate conda dependency string like `pytorch == 1.12.1` or `conda-forge::transformer` and split the channel
 name before the double colon and requirement specification after that.

@@ -115,7 +116,7 @@ class DuplicateDependencyInMultipleChannelsError(Exception):
 ...


-def append_requirement_list(req_list: List[requirements.Requirement], p_req: requirements.Requirement) -> None:
+def append_requirement_list(req_list: list[requirements.Requirement], p_req: requirements.Requirement) -> None:
 """Append a requirement to an existing requirement list. If need and able to merge, merge it, otherwise, append it.

 Args:
@@ -134,7 +135,7 @@ def append_requirement_list(req_list: List[requirements.Requirement], p_req: req


 def append_conda_dependency(
-conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], p_chan_dep: Tuple[str, requirements.Requirement]
+conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], p_chan_dep: tuple[str, requirements.Requirement]
 ) -> None:
 """Append a conda dependency to an existing conda dependencies dict, if not existed in any channel.
 To avoid making unnecessary modification to dict, we check the existence first, then try to merge, then append,
@@ -164,45 +165,73 @@ def append_conda_dependency(
 conda_chan_deps[p_channel].append(p_req)


-def validate_pip_requirement_string_list(req_str_list: List[str]) -> List[requirements.Requirement]:
-"""Validate the a list of pip requirement string according to PEP 508.
+def validate_pip_requirement_string_list(
+req_str_list: list[str], add_local_version_specifier: bool = False
+) -> list[requirements.Requirement]:
+"""Validate the list of pip requirement strings according to PEP 508.

 Args:
-req_str_list: The list of string contains the pip requirement specification.
+req_str_list: The list of strings containing the pip requirement specification.
+add_local_version_specifier: if True, add the version specifier of the locally installed package version to
+requirements without version specifiers.

 Returns:
 A requirements.Requirement list containing the requirement information.
 """
-seen_pip_requirement_list: List[requirements.Requirement] = []
+seen_pip_requirement_list: list[requirements.Requirement] = []
 for req_str in req_str_list:
 append_requirement_list(seen_pip_requirement_list, _validate_pip_requirement_string(req_str=req_str))

+if add_local_version_specifier:
+# For any requirement string that does not contain a specifier, add the specifier of a locally installed version
+# if it exists.
+seen_pip_requirement_list = list(
+map(
+lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req),
+seen_pip_requirement_list,
+)
+)
+
 return seen_pip_requirement_list


-def validate_conda_dependency_string_list(dep_str_list: List[str]) -> DefaultDict[str, List[requirements.Requirement]]:
+def validate_conda_dependency_string_list(
+dep_str_list: list[str], add_local_version_specifier: bool = False
+) -> DefaultDict[str, list[requirements.Requirement]]:
 """Validate a list of conda dependency string, find any duplicate package across different channel and create a dict
 to represent the whole dependencies.

 Args:
 dep_str_list: The list of string contains the conda dependency specification.
+add_local_version_specifier: if True, add the version specifier of the locally installed package version to
+requirements without version specifiers.

 Returns:
 A dict mapping from the channel name to the list of requirements from that channel.
 """
 validated_conda_dependency_list = list(map(_validate_conda_dependency_string, dep_str_list))
-ret_conda_dependency_dict: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
+ret_conda_dependency_dict: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
 for p_channel, p_req in validated_conda_dependency_list:
 append_conda_dependency(ret_conda_dependency_dict, (p_channel, p_req))

+if add_local_version_specifier:
+# For any conda dependency string that does not contain a specifier, add the specifier of a locally installed
+# version if it exists. This is best-effort: if the conda package does not have the same name as the pip
+# package, it won't be found in the local environment.
+for channel_str, reqs in ret_conda_dependency_dict.items():
+reqs = list(
+map(lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req), reqs)
+)
+ret_conda_dependency_dict[channel_str] = reqs
+
 return ret_conda_dependency_dict


 def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement) -> requirements.Requirement:
 """Get the local installed version of a given pip package requirement.
-If the package is locally installed, and the local version meet the specifier of the requirements, return a new
+If the package is locally installed, and the local version meets the specifier of the requirements, return a new
 requirement specifier that pins the version.
-If the local version does not meet the specifier of the requirements, a warn will be omitted and returns
+If the local version does not meet the specifier of the requirements, a warning will be emitted and returns
 the original package requirement.
 If the package is not locally installed or not found, the original package requirement is returned.
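The new add_local_version_specifier flag pins any requirement given without a version to whatever is installed locally. A minimal usage sketch (package names and the pinned version are illustrative; the result depends on the local environment):

from snowflake.ml._internal import env_utils

# "numpy" has no specifier, so it is pinned to the locally installed version if
# numpy is present; "pandas>=1.5" already has a specifier and is left as-is.
reqs = env_utils.validate_pip_requirement_string_list(
    ["numpy", "pandas>=1.5"],
    add_local_version_specifier=True,
)
for req in reqs:
    print(req)  # e.g. numpy==1.26.4, pandas>=1.5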
@@ -217,7 +246,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement
 local_dist_version = local_dist.version
 except importlib_metadata.PackageNotFoundError:
 if pip_req.name == SNOWPARK_ML_PKG_NAME:
-local_dist_version = snowml_env.VERSION
+local_dist_version = snowml_version.VERSION
 else:
 return pip_req
 new_pip_req = copy.deepcopy(pip_req)
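The pinning helper can also be called on its own; a small sketch (the printed version is illustrative):

from packaging import requirements

from snowflake.ml._internal import env_utils

# "packaging" is a dependency of the module itself, so a local version should
# be found; packages that are not installed are returned unchanged.
req = env_utils.get_local_installed_version_of_pip_package(requirements.Requirement("packaging"))
print(req)  # e.g. packaging==24.1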
@@ -372,8 +401,8 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req


 def get_matched_package_versions_in_information_schema_with_active_session(
-reqs: List[requirements.Requirement], python_version: str
-) -> Dict[str, List[version.Version]]:
+reqs: list[requirements.Requirement], python_version: str
+) -> dict[str, list[version.Version]]:
 try:
 session = context.get_active_session()
 except exceptions.SnowparkSessionException:
@@ -383,10 +412,10 @@

 def get_matched_package_versions_in_information_schema(
 session: session.Session,
-reqs: List[requirements.Requirement],
+reqs: list[requirements.Requirement],
 python_version: str,
-statement_params: Optional[Dict[str, Any]] = None,
-) -> Dict[str, List[version.Version]]:
+statement_params: Optional[dict[str, Any]] = None,
+) -> dict[str, list[version.Version]]:
 """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
 Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
 exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
@@ -400,8 +429,8 @@
 Returns:
 A Dict, whose key is the package name, and value is a list of versions match the requirements.
 """
-ret_dict: Dict[str, List[version.Version]] = {}
-reqs_to_request: List[requirements.Requirement] = []
+ret_dict: dict[str, list[version.Version]] = {}
+reqs_to_request: list[requirements.Requirement] = []
 for req in reqs:
 if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
 available_versions = list(
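A hedged sketch of the information_schema lookup; it requires an active Snowpark session, and the returned versions depend on what the Snowflake conda channel currently reports:

from packaging import requirements

from snowflake.ml._internal import env_utils

# Assumes a Snowpark session is already active in this process; the result maps
# package names to the versions that satisfy the given specifiers.
matched = env_utils.get_matched_package_versions_in_information_schema_with_active_session(
    reqs=[requirements.Requirement("xgboost<2")],
    python_version="3.10",
)
print(matched)  # e.g. {"xgboost": [<Version('1.7.6')>, ...]}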
@@ -457,7 +486,7 @@

 def save_conda_env_file(
 path: pathlib.Path,
-conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
+conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
 python_version: str,
 cuda_version: Optional[str] = None,
 default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
@@ -478,7 +507,7 @@
 """
 assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
 path.parent.mkdir(parents=True, exist_ok=True)
-env: Dict[str, Any] = dict()
+env: dict[str, Any] = dict()
 env["name"] = "snow-env"
 # Get all channels in the dependencies, ordered by the number of the packages which belongs to and put into
 # channels section.
@@ -505,7 +534,7 @@
 yaml.safe_dump(env, stream=f, default_flow_style=False)


-def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requirement]) -> None:
+def save_requirements_file(path: pathlib.Path, pip_deps: list[requirements.Requirement]) -> None:
 """Generate Python requirements.txt file in the given directory path.

 Args:
@@ -521,9 +550,9 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi

 def load_conda_env_file(
 path: pathlib.Path,
-) -> Tuple[
-DefaultDict[str, List[requirements.Requirement]],
-Optional[List[requirements.Requirement]],
+) -> tuple[
+DefaultDict[str, list[requirements.Requirement]],
+Optional[list[requirements.Requirement]],
 Optional[str],
 Optional[str],
 ]:
@@ -601,7 +630,7 @@
 return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version


-def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
+def load_requirements_file(path: pathlib.Path) -> list[requirements.Requirement]:
 """Load Python requirements.txt file from the given directory path.

 Args:
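A minimal round-trip sketch of the environment-file helpers, assuming a writable local path; the generated file targets the Snowflake conda channel by default:

import pathlib

from snowflake.ml._internal import env_utils

env_path = pathlib.Path("/tmp/snow-env/conda.yml")  # hypothetical location
conda_deps = env_utils.validate_conda_dependency_string_list(["pandas==2.2.2", "conda-forge::lightgbm"])
env_utils.save_conda_env_file(env_path, conda_deps, python_version="3.10")

loaded_deps, pip_deps, python_version, cuda_version = env_utils.load_conda_env_file(env_path)
print(python_version, dict(loaded_deps))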
@@ -641,8 +670,8 @@ def parse_python_version_string(dep: str) -> Optional[str]:


 def _find_conda_dep_spec(
-conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], pkg_name: str
-) -> Optional[Tuple[str, requirements.Requirement]]:
+conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], pkg_name: str
+) -> Optional[tuple[str, requirements.Requirement]]:
 for channel in conda_chan_deps:
 spec = next(filter(lambda req: req.name == pkg_name, conda_chan_deps[channel]), None)
 if spec:
@@ -650,14 +679,14 @@
 return None


-def _find_pip_req_spec(pip_reqs: List[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
+def _find_pip_req_spec(pip_reqs: list[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
 spec = next(filter(lambda req: req.name == pkg_name, pip_reqs), None)
 return spec


 def find_dep_spec(
-conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
-pip_reqs: List[requirements.Requirement],
+conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
+pip_reqs: list[requirements.Requirement],
 conda_pkg_name: str,
 pip_pkg_name: Optional[str] = None,
 remove_spec: bool = False,
snowflake/ml/_internal/file_utils.py

@@ -11,18 +11,7 @@ import sys
 import tarfile
 import tempfile
 import zipfile
-from typing import (
-Any,
-Callable,
-Dict,
-Generator,
-List,
-Literal,
-Optional,
-Set,
-Tuple,
-Union,
-)
+from typing import Any, Callable, Generator, Literal, Optional, Union
 from urllib import parse

 import cloudpickle
@@ -37,7 +26,7 @@ GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi")
 def copytree(
 src: "Union[str, os.PathLike[str]]",
 dst: "Union[str, os.PathLike[str]]",
-ignore: Optional[Callable[..., Set[str]]] = None,
+ignore: Optional[Callable[..., set[str]]] = None,
 dirs_exist_ok: bool = False,
 ) -> "Union[str, os.PathLike[str]]":
 """This is a forked version of shutil.copytree that remove all copystat, to make sure it works in Sproc.
@@ -170,7 +159,7 @@ def zip_python_package(zipfile_path: str, package_name: str, ignore_generated_py


 def hash_directory(
-directory: Union[str, pathlib.Path], *, ignore_hidden: bool = False, excluded_files: Optional[List[str]] = None
+directory: Union[str, pathlib.Path], *, ignore_hidden: bool = False, excluded_files: Optional[list[str]] = None
 ) -> str:
 """Hash the **content** of a folder recursively using SHA-1.

@@ -186,7 +175,7 @@
 excluded_files = []

 def _update_hash_from_dir(
-directory: Union[str, pathlib.Path], hash: "hashlib._Hash", *, ignore_hidden: bool, excluded_files: List[str]
+directory: Union[str, pathlib.Path], hash: "hashlib._Hash", *, ignore_hidden: bool, excluded_files: list[str]
 ) -> "hashlib._Hash":
 assert pathlib.Path(directory).is_dir(), "Provided path is not a directory."
 for path in sorted(pathlib.Path(directory).iterdir(), key=lambda p: str(p).lower()):
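A self-contained sketch of the content hash; hidden files and anything in excluded_files are left out of the digest:

import pathlib
import tempfile

from snowflake.ml._internal import file_utils

with tempfile.TemporaryDirectory() as tmp:
    pathlib.Path(tmp, "weights.bin").write_bytes(b"\x00\x01\x02")
    pathlib.Path(tmp, "notes.log").write_text("scratch")
    # Only weights.bin contributes to the digest here.
    digest = file_utils.hash_directory(tmp, ignore_hidden=True, excluded_files=["notes.log"])
    print(digest)  # 40-character SHA-1 hex digest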
@@ -208,7 +197,7 @@
 ).hexdigest()


-def get_all_modules(dirname: str, prefix: str = "") -> List[str]:
+def get_all_modules(dirname: str, prefix: str = "") -> list[str]:
 modules = [mod.name for mod in pkgutil.iter_modules([dirname], prefix=prefix)]
 subdirs = [f.path for f in os.scandir(dirname) if f.is_dir()]
 for sub_dirname in subdirs:
@@ -248,7 +237,7 @@ def _create_tar_gz_stream(source_dir: str, arcname: Optional[str] = None) -> Gen
 yield output_stream


-def get_package_path(package_name: str, strategy: Literal["first", "last"] = "first") -> Tuple[str, str]:
+def get_package_path(package_name: str, strategy: Literal["first", "last"] = "first") -> tuple[str, str]:
 """[Obsolete]Return the path to where a package is defined and its start location.
 Example 1: snowflake.ml -> path/to/site-packages/snowflake/ml, path/to/site-packages
 Example 2: zip_imported_module -> path/to/some/zipfile.zip/zip_imported_module, path/to/some/zipfile.zip
@@ -267,7 +256,7 @@ def get_package_path(package_name: str, strategy: Literal["first", "last"] = "fi
 return pkg_path, pkg_start_path


-def stage_object(session: snowpark.Session, object: object, stage_location: str) -> List[snowpark.PutResult]:
+def stage_object(session: snowpark.Session, object: object, stage_location: str) -> list[snowpark.PutResult]:
 temp_file = tempfile.NamedTemporaryFile(delete=False)
 temp_file_path = temp_file.name
 temp_file.close()
@@ -279,7 +268,7 @@ def stage_object(session: snowpark.Session, object: object, stage_location: str)


 def stage_file_exists(
-session: snowpark.Session, stage_location: str, file_name: str, statement_params: Dict[str, Any]
+session: snowpark.Session, stage_location: str, file_name: str, statement_params: dict[str, Any]
 ) -> bool:
 try:
 res = session.sql(f"list {stage_location}/{file_name}").collect(statement_params=statement_params)
@@ -297,7 +286,7 @@ def upload_directory_to_stage(
 local_path: pathlib.Path,
 stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
 *,
-statement_params: Optional[Dict[str, Any]] = None,
+statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
 """Upload a local folder recursively to a stage and keep the structure.

@@ -350,7 +339,7 @@ def download_directory_from_stage(
 stage_path: pathlib.PurePosixPath,
 local_path: pathlib.Path,
 *,
-statement_params: Optional[Dict[str, Any]] = None,
+statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
 """Upload a folder in stage recursively to a folder in local and keep the structure.
snowflake/ml/_internal/human_readable_id/hrid_generator_base.py

@@ -15,7 +15,6 @@ In this module you will find:

 import math
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple


 class HRIDBase(ABC):
@@ -28,12 +27,11 @@
 @abstractmethod
 def __id_generator__(self) -> int:
 """The generator to use to generate new IDs. The implementer needs to provide this."""
-pass

-__hrid_structure__: Tuple[str, ...]
+__hrid_structure__: tuple[str, ...]
 """The HRID structure to be generated. The implementer needs to provide this."""

-__hrid_words__: Dict[str, Tuple[str, ...]]
+__hrid_words__: dict[str, tuple[str, ...]]
 """The mapping between the HRID parts and the words to use. The implementer needs to provide this."""

 __separator__ = "_"
@@ -82,7 +80,7 @@
 hrid.append(str(values[idxs[i]]))
 return self.__separator__.join(hrid)

-def generate(self) -> Tuple[int, str]:
+def generate(self) -> tuple[int, str]:
 """Generate an ID and the corresponding HRID.

 Returns:
@@ -92,7 +90,7 @@
 hrid = self.id_to_hrid(id)
 return (id, hrid)

-def _id_to_idxs(self, id: int) -> List[int]:
+def _id_to_idxs(self, id: int) -> list[int]:
 """Take the ID and convert it to indices into the HRID words.

 Args:
@@ -109,7 +107,7 @@
 idxs.append((id & mask) >> shift)
 return idxs

-def _hrid_to_idxs(self, hrid: str) -> List[int]:
+def _hrid_to_idxs(self, hrid: str) -> list[int]:
 """Take the HRID and convert it to indices into the HRID words.

 Args:
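A hypothetical subclass to illustrate the contract: the implementer supplies the structure, the word lists, and an ID generator. This sketch assumes each part's word tuple has a power-of-two length, as the bit-mask arithmetic in _id_to_idxs suggests; the exact requirements of the base class may differ.

import secrets

from snowflake.ml._internal.human_readable_id.hrid_generator_base import HRIDBase


class AnimalHRID(HRIDBase):  # hypothetical example, not part of the package
    __hrid_structure__ = ("adjective", "number", "animal")
    __hrid_words__ = {
        "adjective": ("steady", "rapid", "quiet", "bright"),  # 4 = 2**2 words
        "number": tuple(str(i) for i in range(8)),            # 8 = 2**3 words
        "animal": ("owl", "fox", "lynx", "seal"),             # 4 = 2**2 words
    }

    def __id_generator__(self) -> int:
        # Draw a random ID that fits into the 2 + 3 + 2 bits implied above.
        return secrets.randbelow(2**7)


generator = AnimalHRID()
new_id, hrid = generator.generate()
print(new_id, hrid)  # e.g. 42 quiet_2_fox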
snowflake/ml/_internal/init_utils.py

@@ -2,10 +2,9 @@ import importlib
 import inspect
 import pkgutil
 from types import FunctionType
-from typing import Dict


-def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[str, type]:
+def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> dict[str, type]:
 """Finds classes defined all the python modules in the given package directory.

 Args:
@@ -36,7 +35,7 @@ def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[s
 return exportable_classes


-def fetch_functions_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[str, FunctionType]:
+def fetch_functions_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> dict[str, FunctionType]:
 """Finds functions defined all the python modules in the given package directory.

 Args:
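A sketch of how these discovery helpers are typically used from a package's __init__.py to re-export everything found in the package directory (the surrounding package is hypothetical):

import os

from snowflake.ml._internal import init_utils

pkg_dir = os.path.dirname(__file__)
pkg_name = __name__  # e.g. "snowflake.ml.modeling.preprocessing" when run from an __init__.py
classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
for class_name, cls in classes.items():
    globals()[class_name] = cls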
snowflake/ml/_internal/lineage/lineage_utils.py

@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Any, Callable, List, Optional, get_args
+from typing import Any, Callable, Optional, get_args

 from snowflake import snowpark
 from snowflake.ml.data import data_source
@@ -9,7 +9,7 @@ _DATA_SOURCES_ATTR = "_data_sources"


 def _wrap_func(
-fn: Callable[..., snowpark.DataFrame], data_sources: List[data_source.DataSource]
+fn: Callable[..., snowpark.DataFrame], data_sources: list[data_source.DataSource]
 ) -> Callable[..., snowpark.DataFrame]:
 """Wrap a DataFrame transform function to propagate data_sources to derived DataFrames."""

@@ -34,9 +34,9 @@ def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., sno
 return wrapped


-def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
+def get_data_sources(*args: Any) -> Optional[list[data_source.DataSource]]:
 """Helper method for extracting data sources attribute from DataFrames in an argument list"""
-result: Optional[List[data_source.DataSource]] = None
+result: Optional[list[data_source.DataSource]] = None
 for arg in args:
 srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
 if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
@@ -46,7 +46,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
 return result


-def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
+def set_data_sources(obj: Any, data_sources: Optional[list[data_source.DataSource]]) -> None:
 """Helper method for attaching data sources to an object"""
 if data_sources:
 assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
@@ -54,7 +54,7 @@ def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSourc


 def patch_dataframe(
-df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
+df: snowpark.DataFrame, data_sources: list[data_source.DataSource], inplace: bool = False
 ) -> snowpark.DataFrame:
 """
 Monkey patch a DataFrame to add attach the provided data_sources as an attribute of the DataFrame.
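A hedged sketch of the lineage propagation helpers; the session and table are placeholders, and get_data_sources simply returns None when no lineage has been attached upstream:

from snowflake.ml._internal.lineage import lineage_utils
from snowflake.snowpark import Session

connection_parameters = {"account": "...", "user": "...", "password": "..."}  # placeholders
session = Session.builder.configs(connection_parameters).create()

df = session.table("MY_DB.MY_SCHEMA.MY_FEATURE_TABLE")  # hypothetical table
derived = df.filter("LABEL IS NOT NULL")

# Carry any attached lineage from the source DataFrame over to the derived one.
sources = lineage_utils.get_data_sources(df)
if sources:
    derived = lineage_utils.patch_dataframe(derived, sources)
print(lineage_utils.get_data_sources(derived))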
snowflake/ml/_internal/platform_capabilities.py

@@ -1,5 +1,6 @@
 import json
-from typing import Any, Dict, Optional
+from contextlib import contextmanager
+from typing import Any, Optional

 from absl import logging

@@ -27,21 +28,50 @@ class PlatformCapabilities:
 """

 _instance: Optional["PlatformCapabilities"] = None
+# Used for unittesting only. This is to avoid the need to mock the session object or reaching out to Snowflake
+_mock_features: Optional[dict[str, Any]] = None

 @classmethod
 def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
+# Used for unittesting only. In this situation, _instance is not initialized.
+if cls._mock_features is not None:
+return cls(features=cls._mock_features)
 if not cls._instance:
-cls._instance = cls(session)
+cls._instance = cls(session=session)
 return cls._instance

+@classmethod
+def set_mock_features(cls, features: Optional[dict[str, Any]] = None) -> None:
+cls._mock_features = features
+
+@classmethod
+def clear_mock_features(cls) -> None:
+cls._mock_features = None
+
+# For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
+# Python 3.11. So, we are ignoring the type for this method.
+@classmethod # type: ignore[arg-type]
+@contextmanager
+def mock_features(cls, features: dict[str, Any]) -> None: # type: ignore[misc]
+logging.debug(f"Setting mock features: {features}")
+cls.set_mock_features(features)
+try:
+yield
+finally:
+logging.debug(f"Clearing mock features: {features}")
+cls.clear_mock_features()
+
 def is_nested_function_enabled(self) -> bool:
 return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)

+def is_inlined_deployment_spec_enabled(self) -> bool:
+return self._get_bool_feature("ENABLE_INLINE_DEPLOYMENT_SPEC", False)
+
 def is_live_commit_enabled(self) -> bool:
 return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)

 @staticmethod
-def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
+def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
 try:
 result = (
 query_result_checker.SqlResultValidator(
@@ -68,11 +98,17 @@ class PlatformCapabilities:
 # This can happen is server side is older than 9.2. That is fine.
 return {}

-def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
+def __init__(
+self, *, session: Optional[snowpark_session.Session] = None, features: Optional[dict[str, Any]] = None
+) -> None:
+# This is for testing purposes only.
+if features:
+self.features = features
+return
 if not session:
 session = next(iter(snowpark_session._get_active_sessions()))
 assert session, "Missing active session object"
-self.features: Dict[str, Any] = PlatformCapabilities._get_features(session)
+self.features = PlatformCapabilities._get_features(session)

 def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
 value = self.features.get(feature_name, default_value)
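The new mock_features context manager makes the capability flags testable without a Snowflake round trip; a test-style sketch:

from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

with PlatformCapabilities.mock_features({"ENABLE_INLINE_DEPLOYMENT_SPEC": True}):
    # get_instance() returns an instance backed by the mocked flags instead of
    # querying the server.
    pc = PlatformCapabilities.get_instance()
    assert pc.is_inlined_deployment_spec_enabled()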