snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (196)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/telemetry.py +42 -13
  4. snowflake/ml/_internal/utils/identifier.py +2 -2
  5. snowflake/ml/data/data_connector.py +1 -1
  6. snowflake/ml/jobs/_utils/constants.py +10 -1
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +51 -34
  9. snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
  10. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  11. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
  12. snowflake/ml/jobs/_utils/spec_utils.py +8 -6
  13. snowflake/ml/jobs/decorators.py +13 -3
  14. snowflake/ml/jobs/job.py +206 -26
  15. snowflake/ml/jobs/manager.py +78 -34
  16. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  17. snowflake/ml/model/_client/ops/service_ops.py +31 -17
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  20. snowflake/ml/model/_client/sql/model_version.py +1 -1
  21. snowflake/ml/model/_client/sql/service.py +20 -32
  22. snowflake/ml/model/_model_composer/model_composer.py +44 -19
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  25. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  32. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
  33. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  34. snowflake/ml/model/custom_model.py +17 -4
  35. snowflake/ml/model/model_signature.py +3 -3
  36. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  37. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  38. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  39. snowflake/ml/modeling/cluster/birch.py +9 -1
  40. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  42. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  43. snowflake/ml/modeling/cluster/k_means.py +9 -1
  44. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  45. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  46. snowflake/ml/modeling/cluster/optics.py +9 -1
  47. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  48. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  49. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  50. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  51. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  52. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  53. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  54. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  55. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  56. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  57. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  58. snowflake/ml/modeling/covariance/oas.py +9 -1
  59. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  60. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  62. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  63. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  65. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  66. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  67. snowflake/ml/modeling/decomposition/pca.py +9 -1
  68. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  69. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  70. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  71. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  72. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  73. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  74. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  75. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  76. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  77. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  78. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  81. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  82. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  83. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  84. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  85. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  86. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  87. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  88. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  89. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  90. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  91. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  92. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  93. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  94. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  95. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  97. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  98. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  99. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  100. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  115. snowflake/ml/modeling/linear_model/lars.py +9 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  144. snowflake/ml/modeling/manifold/isomap.py +9 -1
  145. snowflake/ml/modeling/manifold/mds.py +9 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  147. snowflake/ml/modeling/manifold/tsne.py +9 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  170. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  171. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  172. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  173. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  174. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  175. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  176. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  177. snowflake/ml/modeling/svm/svc.py +9 -1
  178. snowflake/ml/modeling/svm/svr.py +9 -1
  179. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  180. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  181. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  182. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  183. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  184. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  185. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  186. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  187. snowflake/ml/monitoring/explain_visualize.py +424 -0
  188. snowflake/ml/registry/_manager/model_manager.py +23 -2
  189. snowflake/ml/registry/registry.py +10 -9
  190. snowflake/ml/utils/connection_params.py +8 -2
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
  193. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
  194. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
  195. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
@@ -1,5 +1,10 @@
  from snowflake.cortex._classify_text import ClassifyText, classify_text
- from snowflake.cortex._complete import Complete, CompleteOptions, complete
+ from snowflake.cortex._complete import (
+     Complete,
+     CompleteOptions,
+     ConversationMessage,
+     complete,
+ )
  from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
  from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
  from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
@@ -14,6 +19,7 @@ __all__ = [
      "Complete",
      "complete",
      "CompleteOptions",
+     "ConversationMessage",
      "EmbedText768",
      "embed_text_768",
      "EmbedText1024",

snowflake/ml/_internal/platform_capabilities.py
@@ -11,6 +11,9 @@ from snowflake.snowpark import (
      session as snowpark_session,
  )

+ LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
+ INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC"
+

  class PlatformCapabilities:
      """Class that retrieves platform feature values for the currently running server.
@@ -18,12 +21,12 @@ class PlatformCapabilities:
      Example usage:
      ```
      pc = PlatformCapabilities.get_instance(session)
-     if pc.is_nested_function_enabled():
-         # Nested functions are enabled.
-         print("Nested functions are enabled.")
+     if pc.is_inlined_deployment_spec_enabled():
+         # Inline deployment spec is enabled.
+         print("Inline deployment spec is enabled.")
      else:
-         # Nested functions are disabled.
-         print("Nested functions are disabled or not supported.")
+         # Inline deployment spec is disabled.
+         print("Inline deployment spec is disabled or not supported.")
      ```
      """

@@ -50,9 +53,11 @@ class PlatformCapabilities:

      # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
      # Python 3.11. So, we are ignoring the type for this method.
+     _dummy_features: dict[str, Any] = {"dummy": "dummy"}
+
      @classmethod  # type: ignore[arg-type]
      @contextmanager
-     def mock_features(cls, features: dict[str, Any]) -> None:  # type: ignore[misc]
+     def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None:  # type: ignore[misc]
          logging.debug(f"Setting mock features: {features}")
          cls.set_mock_features(features)
          try:
@@ -61,14 +66,11 @@ class PlatformCapabilities:
          logging.debug(f"Clearing mock features: {features}")
          cls.clear_mock_features()

-     def is_nested_function_enabled(self) -> bool:
-         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
-
      def is_inlined_deployment_spec_enabled(self) -> bool:
-         return self._get_bool_feature("ENABLE_INLINE_DEPLOYMENT_SPEC", False)
+         return self._get_bool_feature(INLINE_DEPLOYMENT_SPEC_PARAMETER, False)

      def is_live_commit_enabled(self) -> bool:
-         return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
+         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)

      @staticmethod
      def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
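
The renamed flag helpers pair naturally with mock_features in tests. A hedged sketch: the flag name comes from the module constants above, but how get_instance() resolves features while mocking is active is an assumption here.

    # Test sketch, assuming mocked features take effect for get_instance().
    from snowflake.ml._internal.platform_capabilities import (
        INLINE_DEPLOYMENT_SPEC_PARAMETER,
        PlatformCapabilities,
    )

    with PlatformCapabilities.mock_features({INLINE_DEPLOYMENT_SPEC_PARAMETER: True}):
        pc = PlatformCapabilities.get_instance(session)  # `session` assumed
        assert pc.is_inlined_deployment_spec_enabled()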

snowflake/ml/_internal/telemetry.py
@@ -4,6 +4,7 @@ import enum
  import functools
  import inspect
  import operator
+ import os
  import sys
  import time
  import traceback
@@ -13,7 +14,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, c
  from typing_extensions import ParamSpec

  from snowflake import connector
- from snowflake.connector import telemetry as connector_telemetry, time_util
+ from snowflake.connector import connect, telemetry as connector_telemetry, time_util
  from snowflake.ml import version as snowml_version
  from snowflake.ml._internal import env
  from snowflake.ml._internal.exceptions import (
@@ -37,6 +38,37 @@ _Args = ParamSpec("_Args")
  _ReturnValue = TypeVar("_ReturnValue")


+ def _get_login_token() -> Union[str, bytes]:
+     with open("/snowflake/session/token") as f:
+         return f.read()
+
+
+ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
+     conn = None
+     if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
+         try:
+             conn = connect(
+                 host=os.getenv("SNOWFLAKE_HOST"),
+                 account=os.getenv("SNOWFLAKE_ACCOUNT"),
+                 token=_get_login_token(),
+                 authenticator="oauth",
+             )
+         except Exception:
+             # Failed to get a new SnowflakeConnection in SPCS. Fall back to using the active session.
+             # This will work in some cases once SPCS enables multiple authentication modes, and users select any auth.
+             pass
+
+     if conn is None:
+         try:
+             active_session = next(iter(session._get_active_sessions()))
+             conn = active_session._conn._conn if active_session.telemetry_enabled else None
+         except snowpark_exceptions.SnowparkSessionException:
+             # Failed to get an active session. No connection available.
+             pass
+
+     return conn
+
+
  @enum.unique
  class TelemetryProject(enum.Enum):
      MLOPS = "MLOps"
@@ -378,10 +410,14 @@ def send_custom_usage(
      data: Optional[dict[str, Any]] = None,
      **kwargs: Any,
  ) -> None:
-     active_session = next(iter(session._get_active_sessions()))
-     assert active_session, "Missing active session object"
+     conn = _get_snowflake_connection()
+     if conn is None:
+         raise ValueError(
+             """Snowflake connection is required to send custom telemetry. This means there
+             must be at least one active session, or that telemetry is being sent from within an SPCS service."""
+         )

-     client = _SourceTelemetryClient(conn=active_session._conn._conn, project=project, subproject=subproject)
+     client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
      common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
      data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
      client._send(msg=data)
@@ -501,7 +537,6 @@ def send_api_usage_telemetry(
          return update_stmt_params_if_snowpark_df(result, statement_params)

      # prioritize `conn_attr_name` over the active session
-     telemetry_enabled = True
      if conn_attr_name:
          # raise AttributeError if conn attribute does not exist in `self`
          conn = operator.attrgetter(conn_attr_name)(args[0])
@@ -509,16 +544,10 @@ def send_api_usage_telemetry(
          raise TypeError(
              f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
          )
-     # get an active session
      else:
-         try:
-             active_session = next(iter(session._get_active_sessions()))
-             conn = active_session._conn._conn
-             telemetry_enabled = active_session.telemetry_enabled
-         except snowpark_exceptions.SnowparkSessionException:
-             conn = None
+         conn = _get_snowflake_connection()

-     if conn is None or not telemetry_enabled:
+     if conn is None:
          # Telemetry not enabled, just execute without our additional telemetry logic
          try:
              return ctx.run(execute_func_with_statement_params)
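
The new helper lets telemetry authenticate from inside an SPCS container even without an active Snowpark session. The pattern it relies on, as a standalone sketch: SNOWFLAKE_HOST, SNOWFLAKE_ACCOUNT, and the token file are injected by Snowpark Container Services at runtime.

    # Standalone sketch of the SPCS OAuth pattern used by _get_snowflake_connection.
    import os
    from snowflake.connector import connect

    with open("/snowflake/session/token") as f:  # token provisioned by SPCS
        token = f.read()

    conn = connect(
        host=os.environ["SNOWFLAKE_HOST"],
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        token=token,
        authenticator="oauth",
    )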

snowflake/ml/_internal/utils/identifier.py
@@ -12,7 +12,7 @@ SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
  _SF_SCHEMA_LEVEL_OBJECT = (
      rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
  )
- _SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>.*)"
+ _SF_STAGE_PATH = rf"@?{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>/.*)?"
  _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
  _SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)

@@ -197,7 +197,7 @@ def parse_snowflake_stage_path(
          res.group("db"),
          res.group("schema"),
          res.group("object"),
-         res.group("path"),
+         res.group("path") or "",
      )


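The relaxed pattern now tolerates a leading @ and makes the path suffix optional. An illustrative sketch; only the tuple shape and the ""-substitution follow directly from the code above, any identifier normalization is left out.

    # Sketch: a leading "@" is now accepted, and a missing path yields "".
    from snowflake.ml._internal.utils.identifier import parse_snowflake_stage_path

    db, schema, obj, path = parse_snowflake_stage_path("@db.schema.stage/dir/file.txt")
    # path == "/dir/file.txt"
    db, schema, obj, path = parse_snowflake_stage_path("my_stage")
    # path == "" (the regex group is None, so the function substitutes "")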
snowflake/ml/data/data_connector.py
@@ -249,7 +249,7 @@ class DataConnector:

      # Switch to use Runtime's Data Ingester if running in ML runtime
      # Fail silently if the data ingester is not found
-     if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
+     if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR, "").lower() in ("true", "1"):
          try:
              from runtime_external_entities import get_ingester_class

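The stricter check matters because os.getenv returns the raw string, so any non-empty value, including "false", used to enable the optimized ingester. A hedged illustration; the actual variable name is read from env.USE_OPTIMIZED_DATA_INGESTOR, so the literal below is hypothetical.

    import os

    os.environ["USE_OPTIMIZED_DATA_INGESTOR"] = "false"  # hypothetical name
    old_check = bool(os.getenv("USE_OPTIMIZED_DATA_INGESTOR"))  # True: "false" is non-empty
    new_check = os.getenv("USE_OPTIMIZED_DATA_INGESTOR", "").lower() in ("true", "1")  # False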

snowflake/ml/jobs/_utils/constants.py
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
  DEFAULT_CONTAINER_NAME = "main"
  PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
  RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
+ MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
  MEMORY_VOLUME_NAME = "dshm"
  STAGE_VOLUME_NAME = "stage-volume"
  STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -13,7 +14,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
  DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
  DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
- DEFAULT_IMAGE_TAG = "1.0.1"
+ DEFAULT_IMAGE_TAG = "1.2.3"
  DEFAULT_ENTRYPOINT_PATH = "func.py"

  # Percent of container memory to allocate for /dev/shm volume
@@ -37,6 +38,7 @@ RAY_PORTS = {
  # Node health check configuration
  # TODO(SNOW-1937020): Revisit the health check configuration
  ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
  ENABLE_HEALTH_CHECKS = "false"

  # Job status polling constants
@@ -47,6 +49,13 @@ JOB_POLL_MAX_DELAY_SECONDS = 1
  IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
  RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"

+ # Log start and end messages
+ LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
+ LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
+
+ # Default setting for verbose logging in get_log function
+ DEFAULT_VERBOSE_LOG = False
+
  # Compute pool resource information
  # TODO: Query Snowflake for resource information instead of relying on this hardcoded
  # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool

snowflake/ml/jobs/_utils/interop_utils.py
@@ -80,7 +80,7 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
          # TODO: Check if file exists
          with session.file.get_stream(result_path) as result_stream:
              return ExecutionResult.from_dict(pickle.load(result_stream))
-     except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
+     except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
          # Fall back to JSON result if loading pickled result fails for any reason
          result_json_path = os.path.splitext(result_path)[0] + ".json"
          with session.file.get_stream(result_json_path) as result_stream:

snowflake/ml/jobs/_utils/payload_utils.py
@@ -9,6 +9,7 @@ from pathlib import Path, PurePath
  from typing import Any, Callable, Optional, Union, cast, get_args, get_origin

  import cloudpickle as cp
+ from packaging import version

  from snowflake import snowpark
  from snowflake.ml.jobs._utils import constants, types
@@ -97,11 +98,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
      head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
      if [ $? -eq 0 ]; then
          # Parse the output using read
-         read head_index head_ip <<< "$head_info"
+         read head_index head_ip head_status <<< "$head_info"
+
+         if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+             NODE_TYPE="worker"
+             echo "{constants.LOG_START_MSG}"
+         fi

          # Use the parsed variables
          echo "Head Instance Index: $head_index"
          echo "Head Instance IP: $head_ip"
+         echo "Head Instance Status: $head_status"
+
+         # If the head status is not "READY" or "PENDING", exit early
+         if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
+             echo "Head instance status is not READY or PENDING. Exiting."
+             exit 0
+         fi

      else
          echo "Error: Failed to get head instance information."
@@ -109,9 +122,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
          exit 1
      fi

-     if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
-         NODE_TYPE="worker"
-     fi
+
  fi

  # Common parameters for both head and worker nodes
@@ -160,6 +171,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
      # Start Ray on a worker node - run in background
      ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &

+     echo "Worker node started on address $eth0Ip. See more logs in the head node."
+
+     echo "{constants.LOG_END_MSG}"
+
      # Start the worker shutdown listener in the background
      echo "Starting worker shutdown listener..."
      python worker_shutdown_listener.py
@@ -181,15 +196,16 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

      # Start Ray on the head node
      ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+
      ##### End Ray configuration #####

      # TODO: Monitor MLRS and handle process crashes
      python -m web.ml_runtime_grpc_server &

      # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+     echo Running command: python "$@"

      # Run user's Python entrypoint
-     echo Running command: python "$@"
      python "$@"

      # After the user's job completes, signal workers to shut down
@@ -278,17 +294,19 @@ class JobPayload:
          stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
          source = resolve_source(self.source)
          entrypoint = resolve_entrypoint(source, self.entrypoint)
+         pip_requirements = self.pip_requirements or []

          # Create stage if necessary
          stage_name = stage_path.parts[0].lstrip("@")
          # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
          try:
-             session.sql(f"describe stage {stage_name}").collect()
+             session.sql("describe stage identifier(?)", params=[stage_name]).collect()
          except sp_exceptions.SnowparkSQLException:
              session.sql(
-                 f"create stage if not exists {stage_name}"
+                 "create stage if not exists identifier(?)"
                  " encryption = ( type = 'SNOWFLAKE_SSE' )"
-                 " comment = 'Created by snowflake.ml.jobs Python API'"
+                 " comment = 'Created by snowflake.ml.jobs Python API'",
+                 params=[stage_name],
              ).collect()

          # Upload payload to stage
@@ -301,6 +319,8 @@ class JobPayload:
                  overwrite=True,
              )
              source = Path(entrypoint.file_path.parent)
+             if not any(r.startswith("cloudpickle") for r in pip_requirements):
+                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
          elif source.is_dir():
              # Manually traverse the directory and upload each file, since Snowflake PUT
              # can't handle directories. Reduce the number of PUT operations by using
@@ -325,10 +345,10 @@ class JobPayload:

          # Upload requirements
          # TODO: Check if payload includes both a requirements.txt file and pip_requirements
-         if self.pip_requirements:
+         if pip_requirements:
              # Upload requirements.txt to stage
              session.file.put_stream(
-                 io.BytesIO("\n".join(self.pip_requirements).encode()),
+                 io.BytesIO("\n".join(pip_requirements).encode()),
                  stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                  auto_compress=False,
                  overwrite=True,
@@ -495,13 +515,6 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
      # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
      source_code_comment = _generate_source_code_comment(func) if source_code_display else ""

-     func_code = f"""
- {source_code_comment}
-
- import pickle
- {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
- """
-
      arg_dict_name = "kwargs"
      if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
          param_code = f"{arg_dict_name} = {{}}"
@@ -509,25 +522,29 @@ import pickle
          param_code = _generate_param_handler_code(signature, arg_dict_name)

      return f"""
- ### Version guard to check compatibility across Python versions ###
- import os
  import sys
- import warnings
-
- if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
-     warnings.warn(
-         "Python version mismatch: job was created using"
-         " python{sys.version_info.major}.{sys.version_info.minor}"
-         f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
-         " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
-         " This will be fixed in a future release; for now, please use Python version"
-         f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
-         RuntimeWarning,
-         stacklevel=0,
-     )
- ### End version guard ###
+ import pickle

- {func_code.strip()}
+ try:
+ {textwrap.indent(source_code_comment, '    ')}
+     {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+ except (TypeError, pickle.PickleError):
+     if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
+         raise RuntimeError(
+             "Failed to deserialize function due to Python version mismatch."
+             f" Runtime environment is Python {{sys.version_info.major}}.{{sys.version_info.minor}}"
+             " but function was serialized using Python {sys.version_info.major}.{sys.version_info.minor}."
+         ) from None
+     raise
+ except AttributeError as e:
+     if 'cloudpickle' in str(e):
+         import cloudpickle as cp
+         raise RuntimeError(
+             "Failed to deserialize function due to cloudpickle version mismatch."
+             f" Runtime environment uses cloudpickle=={{cp.__version__}}"
+             " but job was serialized using cloudpickle=={cp.__version__}."
+         ) from e
+     raise

  if __name__ == '__main__':
  {textwrap.indent(param_code, '    ')}
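
Two patterns above are worth calling out: stage DDL now goes through identifier(?) bind parameters rather than f-string interpolation, and a compatible cloudpickle pin is derived from the locally installed version. A hedged sketch of both, assuming an existing Snowpark `session` and a hypothetical stage name:

    # Bind-parameter form: the stage name travels as a parameter, not as SQL text.
    session.sql("describe stage identifier(?)", params=["MY_JOB_STAGE"]).collect()

    # Derived pin: e.g. cloudpickle 2.2.1 installed locally yields "cloudpickle~=2.0".
    import cloudpickle as cp
    from packaging import version

    pin = f"cloudpickle~={version.parse(cp.__version__).major}.0"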

snowflake/ml/jobs/_utils/scripts/constants.py
@@ -2,3 +2,9 @@
  SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
  SHUTDOWN_ACTOR_NAMESPACE = "default"
  SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
+
+
+ # Log start and end messages
+ # Inherited from snowflake.ml.jobs._utils.constants
+ LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
+ LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"

snowflake/ml/jobs/_utils/scripts/get_instance_ip.py
@@ -29,7 +29,7 @@ def get_self_ip() -> Optional[str]:
      return None


- def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
+ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
      """Get the first instance of a batch job based on start time and instance ID.

      Args:
@@ -42,7 +42,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:

      session = session_utils.get_session()
      df = session.sql(f"show service instances in service {service_name}")
-     result = df.select('"instance_id"', '"ip_address"', '"start_time"').collect()
+     result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()

      if not result:
          return None
@@ -57,7 +57,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
      ip_address = head_instance["ip_address"]
      try:
          socket.inet_aton(ip_address)  # Validate IPv4 address
-         return (head_instance["instance_id"], ip_address)
+         return (head_instance["instance_id"], ip_address, head_instance["status"])
      except OSError:
          logger.error(f"Error: Invalid IP address format: {ip_address}")
          return None
@@ -110,7 +110,7 @@ def main():
          head_info = get_first_instance(args.service_name)
          if head_info:
              # Print to stdout to allow capture but don't use logger
-             sys.stdout.write(f"{head_info[0]} {head_info[1]}\n")
+             sys.stdout.write(" ".join(head_info) + "\n")
              sys.exit(0)
          time.sleep(args.retry_interval)
      # If we get here, we've timed out

snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -2,25 +2,35 @@ import argparse
  import copy
  import importlib.util
  import json
+ import logging
  import os
  import runpy
  import sys
+ import time
  import traceback
  import warnings
  from pathlib import Path
  from typing import Any, Optional

  import cloudpickle
+ from constants import LOG_END_MSG, LOG_START_MSG

  from snowflake.ml.jobs._utils import constants
  from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
  from snowflake.snowpark import Session

+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
  # Fallbacks in case of SnowML version mismatch
  RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-
  JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")

+ # Constants for the wait_for_min_instances function
+ CHECK_INTERVAL = 10  # seconds
+ TIMEOUT = 720  # seconds
+

  try:
      from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
@@ -59,7 +69,67 @@ class SimpleJSONEncoder(json.JSONEncoder):
          try:
              return super().default(obj)
          except TypeError:
-             return str(obj)
+             return f"Unserializable object: {repr(obj)}"
+
+
+ def get_active_node_count() -> int:
+     """
+     Count the number of active nodes in the Ray cluster.
+
+     Returns:
+         int: Total count of active nodes
+     """
+     import ray
+
+     if not ray.is_initialized():
+         ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)
+     try:
+         nodes = [node for node in ray.nodes() if node.get("Alive")]
+         total_active = len(nodes)
+
+         logger.info(f"Active nodes: {total_active}")
+         return total_active
+     except Exception as e:
+         logger.warning(f"Error getting active node count: {e}")
+         return 0
+
+
+ def wait_for_min_instances(min_instances: int) -> None:
+     """
+     Wait until the specified minimum number of instances are available in the Ray cluster.
+
+     Args:
+         min_instances: Minimum number of instances required
+
+     Raises:
+         TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
+     """
+     if min_instances <= 1:
+         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+         return
+
+     start_time = time.time()
+     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
+     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
+     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+
+     while time.time() - start_time < timeout:
+         total_nodes = get_active_node_count()
+
+         if total_nodes >= min_instances:
+             elapsed = time.time() - start_time
+             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+             return
+
+         logger.debug(
+             f"Waiting for instances: {total_nodes}/{min_instances} available "
+             f"(elapsed: {time.time() - start_time:.1f}s)"
+         )
+         time.sleep(check_interval)
+
+     raise TimeoutError(
+         f"Timed out after {timeout}s waiting for {min_instances} instances, only {get_active_node_count()} available"
+     )


  def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
@@ -86,6 +156,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
      session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841

      try:
+
          if main_func:
              # Use importlib for scripts with a main function defined
              module_name = Path(script_path).stem
@@ -126,9 +197,21 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
      Raises:
          Exception: Re-raises any exception caught during script execution.
      """
-     # Run the script with the specified arguments
      try:
+         # Wait for minimum required instances if specified
+         min_instances_str = os.environ.get("JOB_MIN_INSTANCES", 1)
+         if min_instances_str and int(min_instances_str) > 1:
+             wait_for_min_instances(int(min_instances_str))
+
+         # Log start marker for user script execution
+         print(LOG_START_MSG)  # noqa: T201
+
+         # Run the script with the specified arguments
          result = run_script(script_path, *script_args, main_func=script_main_func)
+
+         # Log end marker for user script execution
+         print(LOG_END_MSG)  # noqa: T201
+
          result_obj = ExecutionResult(result=result)
          return result_obj
      except Exception as e:
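
Multi-node jobs can therefore gate the user script on cluster readiness. A configuration sketch using the variable names the launcher reads above; values are illustrative, and note they arrive from the environment as strings:

    import os

    os.environ["JOB_MIN_INSTANCES"] = "4"                  # wait until 4 Ray nodes are alive
    os.environ["JOB_MIN_INSTANCES_TIMEOUT"] = "600"        # override the 720s default
    os.environ["JOB_MIN_INSTANCES_CHECK_INTERVAL"] = "15"  # poll every 15s instead of 10s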

snowflake/ml/jobs/_utils/spec_utils.py
@@ -11,7 +11,7 @@ from snowflake.ml.jobs._utils import constants, types
  def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
      """Extract resource information for the specified compute pool"""
      # Get the instance family
-     rows = session.sql(f"show compute pools like '{compute_pool}'").collect()
+     rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
      if not rows:
          raise ValueError(f"Compute pool '{compute_pool}' not found")
      instance_family: str = rows[0]["instance_family"]
@@ -85,7 +85,8 @@ def generate_service_spec(
      compute_pool: str,
      payload: types.UploadedPayload,
      args: Optional[list[str]] = None,
-     num_instances: Optional[int] = None,
+     target_instances: int = 1,
+     min_instances: int = 1,
      enable_metrics: bool = False,
  ) -> dict[str, Any]:
      """
@@ -96,13 +97,13 @@ def generate_service_spec(
          compute_pool: Compute pool for job execution
          payload: Uploaded job payload
          args: Arguments to pass to entrypoint script
-         num_instances: Number of instances for multi-node job
+         target_instances: Number of instances for multi-node job
          enable_metrics: Enable platform metrics for the job
+         min_instances: Minimum number of instances required to start the job

      Returns:
          Job service specification
      """
-     is_multi_node = num_instances is not None and num_instances > 1
      image_spec = _get_image_spec(session, compute_pool)

      # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
@@ -180,10 +181,11 @@ def generate_service_spec(
      }
      endpoints = []

-     if is_multi_node:
+     if target_instances > 1:
          # Update environment variables for multi-node job
          env_vars.update(constants.RAY_PORTS)
-         env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+         env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
+         env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)

          # Define Ray endpoints for intra-service instance communication
          ray_endpoints = [
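
A hedged call sketch for the updated signature; `session` and `payload` are assumed to exist, the pool name is hypothetical, and the keyword names follow the diff above:

    # num_instances is gone; target and minimum instance counts are now separate.
    spec = generate_service_spec(
        session,
        compute_pool="MY_POOL",   # hypothetical pool name
        payload=payload,
        target_instances=4,       # total nodes for the multi-node job
        min_instances=2,          # job may start once 2 nodes are ready
        enable_metrics=True,
    )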