snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -6,17 +6,7 @@ import pickle
6
6
  import sys
7
7
  import textwrap
8
8
  from pathlib import Path, PurePath
9
- from typing import (
10
- Any,
11
- Callable,
12
- List,
13
- Optional,
14
- Type,
15
- Union,
16
- cast,
17
- get_args,
18
- get_origin,
19
- )
9
+ from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
20
10
 
21
11
  import cloudpickle as cp
22
12
 
@@ -27,6 +17,7 @@ from snowflake.snowpark._internal import code_generation
27
17
 
28
18
  _SUPPORTED_ARG_TYPES = {str, int, float}
29
19
  _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
20
+ _ENTRYPOINT_FUNC_NAME = "func"
30
21
  _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
31
22
  _STARTUP_SCRIPT_CODE = textwrap.dedent(
32
23
  f"""
@@ -73,14 +64,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
73
64
  ##### Ray configuration #####
74
65
  shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
75
66
 
76
- # Check if the instance ip retrieval module exists, which is a prerequisite for multi node jobs
67
+ # Check if the local get_instance_ip.py script exists
77
68
  HELPER_EXISTS=$(
78
- python3 -c "import snowflake.runtime.utils.get_instance_ip" 2>/dev/null && echo "true" || echo "false"
69
+ [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
79
70
  )
80
71
 
81
72
  # Configure IP address and logging directory
82
73
  if [ "$HELPER_EXISTS" = "true" ]; then
83
- eth0Ip=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
74
+ eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
84
75
  else
85
76
  eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
86
77
  fi
@@ -103,7 +94,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
103
94
 
104
95
  # Determine if it should be a worker or a head node for batch jobs
105
96
  if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
106
- head_info=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --head)
97
+ head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
107
98
  if [ $? -eq 0 ]; then
108
99
  # Parse the output using read
109
100
  read head_index head_ip <<< "$head_info"
@@ -166,10 +157,17 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
166
157
  "--object-store-memory=${{shm_size}}"
167
158
  )
168
159
 
169
- # Start Ray on a worker node
170
- ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
171
- else
160
+ # Start Ray on a worker node - run in background
161
+ ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
162
+
163
+ # Start the worker shutdown listener in the background
164
+ echo "Starting worker shutdown listener..."
165
+ python worker_shutdown_listener.py
166
+ WORKER_EXIT_CODE=$?
172
167
 
168
+ echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
169
+ exit $WORKER_EXIT_CODE
170
+ else
173
171
  # Additional head-specific parameters
174
172
  head_params=(
175
173
  "--head"
@@ -193,13 +191,39 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
193
191
  # Run user's Python entrypoint
194
192
  echo Running command: python "$@"
195
193
  python "$@"
194
+
195
+ # After the user's job completes, signal workers to shut down
196
+ echo "User job completed. Signaling workers to shut down..."
197
+ python signal_workers.py --wait-time 15
198
+ echo "Head node job completed. Exiting."
196
199
  fi
197
200
  """
198
201
  ).strip()
199
202
 
200
203
 
201
- def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
202
- parent = parent.absolute()
204
def resolve_source(source: Union[Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
    """Validate and normalize a job payload source.

    Args:
        source: Either a callable (serialized later into a generated entrypoint) or a
            filesystem path to a payload file or directory.

    Returns:
        The callable unchanged, or the absolute form of the given path.

    Raises:
        FileNotFoundError: If ``source`` is a path that does not exist.
        ValueError: If ``source`` is neither callable nor a ``Path``.
    """
    if callable(source):
        return source
    if isinstance(source, Path):
        if not source.exists():
            raise FileNotFoundError(f"{source} does not exist")
        return source.absolute()
    raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
215
+
216
+
217
+ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
218
+ if callable(source):
219
+ # Entrypoint is generated for callable payloads
220
+ return types.PayloadEntrypoint(
221
+ file_path=entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH),
222
+ main_func=_ENTRYPOINT_FUNC_NAME,
223
+ )
224
+
225
+ # Resolve entrypoint path for file-based payloads
226
+ parent = source.absolute()
203
227
  if entrypoint is None:
204
228
  if parent.is_file():
205
229
  # Infer entrypoint from source
@@ -218,12 +242,23 @@ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
218
242
  else:
219
243
  # Relative to source dir
220
244
  entrypoint = parent.joinpath(entrypoint)
245
+
246
+ # Validate resolved entrypoint file
221
247
  if not entrypoint.is_file():
222
248
  raise FileNotFoundError(
223
249
  "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
224
250
  f" the source directory (source={parent}, entrypoint={entrypoint})"
225
251
  )
226
- return entrypoint
252
+ if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
253
+ raise ValueError(
254
+ "Unsupported entrypoint type:"
255
+ f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
256
+ )
257
+
258
+ return types.PayloadEntrypoint(
259
+ file_path=entrypoint, # entrypoint is an absolute path at this point
260
+ main_func=None,
261
+ )
227
262
 
228
263
 
229
264
  class JobPayload:
@@ -232,46 +267,17 @@ class JobPayload:
232
267
  source: Union[str, Path, Callable[..., Any]],
233
268
  entrypoint: Optional[Union[str, Path]] = None,
234
269
  *,
235
- pip_requirements: Optional[List[str]] = None,
270
+ pip_requirements: Optional[list[str]] = None,
236
271
  ) -> None:
237
272
  self.source = Path(source) if isinstance(source, str) else source
238
273
  self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
239
274
  self.pip_requirements = pip_requirements
240
275
 
241
- def validate(self) -> None:
242
- if callable(self.source):
243
- # Any entrypoint value is OK for callable payloads (including None aka default)
244
- # since we will generate the file from the serialized callable
245
- pass
246
- elif isinstance(self.source, Path):
247
- # Validate source
248
- source = self.source
249
- if not source.exists():
250
- raise FileNotFoundError(f"{source} does not exist")
251
- source = source.absolute()
252
-
253
- # Validate entrypoint
254
- entrypoint = _resolve_entrypoint(source, self.entrypoint)
255
- if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
256
- raise ValueError(
257
- "Unsupported entrypoint type:"
258
- f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
259
- )
260
-
261
- # Update fields with normalized values
262
- self.source = source
263
- self.entrypoint = entrypoint
264
- else:
265
- raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
266
-
267
276
  def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
268
- # Validate payload
269
- self.validate()
270
-
271
277
  # Prepare local variables
272
278
  stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
273
- source = self.source
274
- entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)
279
+ source = resolve_source(self.source)
280
+ entrypoint = resolve_entrypoint(source, self.entrypoint)
275
281
 
276
282
  # Create stage if necessary
277
283
  stage_name = stage_path.parts[0].lstrip("@")
@@ -290,11 +296,11 @@ class JobPayload:
290
296
  source_code = generate_python_code(source, source_code_display=True)
291
297
  _ = session.file.put_stream(
292
298
  io.BytesIO(source_code.encode()),
293
- stage_location=stage_path.joinpath(entrypoint).as_posix(),
299
+ stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
294
300
  auto_compress=False,
295
301
  overwrite=True,
296
302
  )
297
- source = entrypoint.parent
303
+ source = Path(entrypoint.file_path.parent)
298
304
  elif source.is_dir():
299
305
  # Manually traverse the directory and upload each file, since Snowflake PUT
300
306
  # can't handle directories. Reduce the number of PUT operations by using
@@ -337,17 +343,35 @@ class JobPayload:
337
343
  overwrite=False, # FIXME
338
344
  )
339
345
 
346
+ # Upload system scripts
347
+ scripts_dir = Path(__file__).parent.joinpath("scripts")
348
+ for script_file in scripts_dir.glob("*"):
349
+ if script_file.is_file():
350
+ session.file.put(
351
+ script_file.as_posix(),
352
+ stage_path.as_posix(),
353
+ overwrite=True,
354
+ auto_compress=False,
355
+ )
356
+
357
+ python_entrypoint: list[Union[str, PurePath]] = [
358
+ PurePath("mljob_launcher.py"),
359
+ entrypoint.file_path.relative_to(source),
360
+ ]
361
+ if entrypoint.main_func:
362
+ python_entrypoint += ["--script_main_func", entrypoint.main_func]
363
+
340
364
  return types.UploadedPayload(
341
365
  stage_path=stage_path,
342
366
  entrypoint=[
343
367
  "bash",
344
368
  _STARTUP_SCRIPT_PATH,
345
- entrypoint.relative_to(source),
369
+ *python_entrypoint,
346
370
  ],
347
371
  )
348
372
 
349
373
 
350
- def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
374
+ def _get_parameter_type(param: inspect.Parameter) -> Optional[type[object]]:
351
375
  # Unwrap Optional type annotations
352
376
  param_type = param.annotation
353
377
  if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -356,10 +380,10 @@ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
356
380
  # Return None for empty type annotations
357
381
  if param_type == inspect.Parameter.empty:
358
382
  return None
359
- return cast(Type[object], param_type)
383
+ return cast(type[object], param_type)
360
384
 
361
385
 
362
- def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
386
+ def _validate_parameter_type(param_type: type[object], param_name: str) -> None:
363
387
  # Validate param_type is a supported type
364
388
  if param_type not in _SUPPORTED_ARG_TYPES:
365
389
  raise ValueError(
@@ -471,12 +495,11 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
471
495
  # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
472
496
  source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
473
497
 
474
- func_name = "func"
475
498
  func_code = f"""
476
499
  {source_code_comment}
477
500
 
478
501
  import pickle
479
- {func_name} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
502
+ {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
480
503
  """
481
504
 
482
505
  arg_dict_name = "kwargs"
@@ -487,6 +510,7 @@ import pickle
487
510
 
488
511
  return f"""
489
512
  ### Version guard to check compatibility across Python versions ###
513
+ import os
490
514
  import sys
491
515
  import warnings
492
516
 
@@ -508,5 +532,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
508
532
  if __name__ == '__main__':
509
533
  {textwrap.indent(param_code, ' ')}
510
534
 
511
- {func_name}(**{arg_dict_name})
535
+ __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
512
536
  """
@@ -0,0 +1,4 @@
1
# Configuration of the shutdown-signal actor.
SHUTDOWN_ACTOR_NAME: str = "ShutdownSignal"  # name the actor is registered/looked up under
SHUTDOWN_ACTOR_NAMESPACE: str = "default"  # namespace the actor lives in
SHUTDOWN_RPC_TIMEOUT_SECONDS: float = 5.0  # timeout for calls to the actor, in seconds
@@ -0,0 +1,136 @@
1
+ #!/usr/bin/env python3
2
+ # This file is modified from mlruntime/service/snowflake/runtime/utils
3
+ import argparse
4
+ import logging
5
+ import socket
6
+ import sys
7
+ import time
8
+ from typing import Optional
9
+
10
# Root logging: INFO and above, each record prefixed with a timestamp and level name.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
13
+
14
+
15
def get_self_ip() -> Optional[str]:
    """Return the IP address this service instance's own hostname resolves to.

    References:
        - https://docs.snowflake.com/en/developer-guide/snowpark-container-services/working-with-services#general-guidelines-related-to-service-to-service-communications # noqa: E501

    Returns:
        Optional[str]: The resolved address, or None if hostname lookup/resolution fails.
    """
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError as err:
        logger.error(f"Error: Unable to get IP address via socket. {err}")
        return None
30
+
31
+
32
def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
    """Find the head instance of a batch job.

    Instances are ordered by start time, with the numeric instance ID breaking ties;
    the first instance in that order is the head.

    Args:
        service_name (str): The name of the service to query.

    Returns:
        Optional[tuple[str, str]]: ``(instance_id, ip_address)`` of the head instance, or None
        when no instances are listed, the head row lacks an ID or IP, or the IP fails to
        parse as an IPv4 address.
    """
    from snowflake.runtime.utils import session_utils

    rows = (
        session_utils.get_session()
        .sql(f"show service instances in service {service_name}")
        .select('"instance_id"', '"ip_address"', '"start_time"')
        .collect()
    )
    if not rows:
        return None

    # Earliest start time wins; instance_id (as an int) breaks ties
    ordered = sorted(rows, key=lambda row: (row["start_time"], int(row["instance_id"])))
    head = ordered[0]
    instance_id = head["instance_id"]
    ip_address = head["ip_address"]
    if not (instance_id and ip_address):
        return None

    try:
        socket.inet_aton(ip_address)  # raises OSError unless the string parses as IPv4
    except OSError:
        logger.error(f"Error: Invalid IP address format: {ip_address}")
        return None
    return (instance_id, ip_address)
64
+
65
+
66
def main() -> None:
    """Print the IP address of this service instance, or of the job's head instance.

    Positional arguments:
        service_name (str, required): Name of the service to query.

    Options:
        --instance-index (int, optional): Index of the service instance to query. Default: -1.
            Only -1 (the current service instance) is supported; any other value exits with 1.
        --head (bool, optional): Look up the head instance via ``show service instances``
            instead. Mutually exclusive with --instance-index. On success prints the head's
            instance ID and IP address separated by a space. Default: False.
        --timeout (int, optional): Maximum seconds to keep retrying the head lookup
            (only used with --head). Default: 720.
        --retry-interval (int, optional): Seconds to wait between head lookup attempts
            (only used with --head). Default: 10.

    Usage examples:
        python get_instance_ip.py myservice --instance-index=-1
        python get_instance_ip.py myservice --head --timeout=300 --retry-interval=5

    Output:
        Writes the result to stdout. Exits with status code 0 on success, 1 on failure.
    """

    parser = argparse.ArgumentParser(description="Get IP address of a service instance")
    group = parser.add_mutually_exclusive_group()
    parser.add_argument("service_name", help="Name of the service")
    group.add_argument(
        "--instance-index",
        type=int,
        default=-1,
        help="Index of service instance (default: -1 for self instance)",
    )
    group.add_argument(
        "--head",
        action="store_true",
        help="Get head instance information using show services",
    )
    parser.add_argument("--timeout", type=int, default=720, help="Timeout in seconds (default: 720)")
    parser.add_argument(
        "--retry-interval",
        type=int,
        default=10,
        help="Retry interval in seconds (default: 10)",
    )

    args = parser.parse_args()
    start_time = time.time()

    if args.head:
        deadline = start_time + args.timeout
        while time.time() < deadline:
            head_info = get_first_instance(args.service_name)
            if head_info:
                # Print to stdout to allow capture but don't use logger
                sys.stdout.write(f"{head_info[0]} {head_info[1]}\n")
                sys.exit(0)
            # Wait before retrying, but never sleep past the deadline so --timeout is honored
            time.sleep(max(0.0, min(args.retry_interval, deadline - time.time())))
        # If we get here, we've timed out
        logger.error("Error: Unable to retrieve head IP address")
        sys.exit(1)

    # If the index is -1, use get_self_ip to get the IP address of the current service
    if args.instance_index == -1:
        ip_address = get_self_ip()
        if ip_address:
            sys.stdout.write(f"{ip_address}\n")
            sys.exit(0)
        else:
            logger.error("Error: Unable to retrieve self IP address")
            sys.exit(1)
    else:
        # We don't support querying a specific instance index other than -1
        logger.error("Error: Invalid arguments. Only --instance-index=-1 is supported for now.")
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,181 @@
1
+ import argparse
2
+ import copy
3
+ import importlib.util
4
+ import json
5
+ import os
6
+ import runpy
7
+ import sys
8
+ import traceback
9
+ import warnings
10
+ from pathlib import Path
11
+ from typing import Any, Optional
12
+
13
+ import cloudpickle
14
+
15
+ from snowflake.ml.jobs._utils import constants
16
+ from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
17
+ from snowflake.snowpark import Session
18
+
19
# Fallbacks in case of SnowML version mismatch: an older constants module may not
# define RESULT_PATH_ENV_VAR, so a hard-coded default name is used instead.
RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")

# File the launcher serializes the execution result to (overridable via the env var above).
JOB_RESULT_PATH = os.getenv(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
23
+
24
+
25
try:
    from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
except ImportError:
    # The installed SnowML lacks interop_utils (version mismatch): define a compatible stand-in.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ExecutionResult:  # type: ignore[no-redef]
        """Outcome of a script run: either its result value or the exception it raised."""

        result: Optional[Any] = None
        exception: Optional[BaseException] = None

        @property
        def success(self) -> bool:
            """True when no exception was recorded."""
            return self.exception is None

        def to_dict(self) -> dict[str, Any]:
            """Return the serializable dictionary."""
            exc = self.exception
            if not isinstance(exc, BaseException):
                return {
                    "success": True,
                    "result_type": type(self.result).__qualname__,
                    "result": self.result,
                }
            exc_cls = type(exc)
            return {
                "success": False,
                "exc_type": f"{exc_cls.__module__}.{exc_cls.__name__}",
                "exc_value": exc,
                "exc_tb": "".join(traceback.format_tb(exc.__traceback__)),
            }
54
+
55
+
56
# JSON encoder used for the fallback result file: anything the standard encoder
# rejects is written as its str() form instead of failing.
class SimpleJSONEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that stringifies objects it cannot natively serialize."""

    def default(self, obj: Any) -> Any:
        try:
            encoded = super().default(obj)
        except TypeError:
            encoded = str(obj)
        return encoded
63
+
64
+
65
def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
    """
    Execute a Python script and return its result.

    A Snowpark session is created before the script runs; the script can retrieve it with
    ``snowflake.snowpark.context.get_active_session()``. While the script runs, ``sys.argv`` is
    ``[script_path, *script_args]`` (in both execution modes). The original ``sys.argv`` is
    always restored, including when session creation or the script itself raises.

    Args:
        script_path: Path to the Python script
        script_args: Arguments to pass to the script (exposed as ``sys.argv[1:]``, and passed
            positionally to ``main_func`` when one is given)
        main_func: The name of the function to call in the script (if any)

    Returns:
        Result from script execution, either from the main function or the script's __return__ value

    Raises:
        RuntimeError: If the script cannot be loaded as a module, or the specified main_func is
            not found or not callable
    """
    # Save original sys.argv and expose the script's own argv (visible to both paths below)
    original_argv = sys.argv
    sys.argv = [script_path, *script_args]

    try:
        # Created inside the try so that a connection failure still restores sys.argv.
        # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
        session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841

        if main_func:
            # Use importlib for scripts with a main function defined
            module_name = Path(script_path).stem
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            # Explicit check instead of assert, which is stripped under `python -O`
            if spec is None or spec.loader is None:
                raise RuntimeError(f"Unable to load {script_path} as a Python module")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Validate main function
            if not (func := getattr(module, main_func, None)) or not callable(func):
                raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")

            # Call main function
            return func(*script_args)

        # Use runpy for other scripts; a script may expose a result via a module-level __return__
        globals_dict = runpy.run_path(script_path, run_name="__main__")
        return globals_dict.get("__return__", None)
    finally:
        # Restore original sys.argv
        sys.argv = original_argv
113
+
114
+
115
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
    """Executes a Python script and serializes the result to JOB_RESULT_PATH.

    The outcome (``ExecutionResult.to_dict()``) is written as a cloudpickle file at
    ``JOB_RESULT_PATH``. It is also written as a best-effort JSON rendering next to it (same
    stem, ``.json`` suffix). Serialization failures are reported as ``RuntimeWarning`` and never
    replace the script's own outcome.

    Args:
        script_path (str): Path to the Python script to execute.
        script_args (Any): Arguments to pass to the script.
        script_main_func (str, optional): The name of the function to call in the script (if any).

    Returns:
        ExecutionResult: Object containing execution results.

    Raises:
        Exception: Re-raises any exception raised during script execution after recording it.
            BaseExceptions that are not Exceptions (e.g. ``SystemExit`` from ``sys.exit()`` in
            the script, ``KeyboardInterrupt``) propagate unchanged and no result is written.
    """
    # Stays None if a non-Exception BaseException escapes run_script. The finally block then
    # skips serialization instead of failing with UnboundLocalError and masking the real exit.
    result_obj: Optional[ExecutionResult] = None
    try:
        result = run_script(script_path, *script_args, main_func=script_main_func)
        result_obj = ExecutionResult(result=result)
        return result_obj
    except Exception as e:
        tb = e.__traceback__
        skip_files = {__file__, runpy.__file__}
        while tb and tb.tb_frame.f_code.co_filename in skip_files:
            # Skip any frames preceding user script execution
            tb = tb.tb_next
        try:
            # Copy so trimming the traceback doesn't alter the exception being re-raised
            cleaned_ex = copy.copy(e).with_traceback(tb)
        except Exception:
            # Some exceptions can't be copied (e.g. custom __init__ signatures); record as-is
            cleaned_ex = e
        result_obj = ExecutionResult(exception=cleaned_ex)
        raise
    finally:
        if result_obj is not None:
            result_dict = result_obj.to_dict()
            try:
                # Serialize result using cloudpickle
                result_pickle_path = JOB_RESULT_PATH
                with open(result_pickle_path, "wb") as f:
                    cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
            except Exception as pkl_exc:
                warnings.warn(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}", RuntimeWarning, stacklevel=1)

            try:
                # Serialize result to JSON as fallback path in case of cross version incompatibility
                # TODO: Manually convert non-serializable types to strings
                result_json_path = os.path.splitext(JOB_RESULT_PATH)[0] + ".json"
                with open(result_json_path, "w") as f:
                    json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
            except Exception as json_exc:
                warnings.warn(
                    f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
                )
164
+
165
+
166
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
    parser.add_argument("script_path", help="Path to the Python script to execute")
    parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
    parser.add_argument(
        "--script_main_func", required=False, help="The name of the main function to call in the script"
    )
    # parse_known_args tolerates options the launcher doesn't define (e.g. flags meant for
    # the user script) instead of erroring; they are returned in unknown_args
    args, unknown_args = parser.parse_known_args()

    # Forward positional script_args first, then any unrecognized options.
    # NOTE(review): if positionals and unknown options were interleaved on the command line,
    # this concatenation may not preserve their original relative order — confirm acceptable.
    main(
        args.script_path,
        *args.script_args,
        *unknown_args,
        script_main_func=args.script_main_func,
    )