snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208):
  1. snowflake/cortex/_complete.py +19 -0
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/platform_capabilities.py +87 -0
  4. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  5. snowflake/ml/_internal/telemetry.py +21 -0
  6. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  7. snowflake/ml/dataset/dataset.py +0 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +6 -0
  11. snowflake/ml/jobs/__init__.py +21 -0
  12. snowflake/ml/jobs/_utils/constants.py +57 -0
  13. snowflake/ml/jobs/_utils/payload_utils.py +438 -0
  14. snowflake/ml/jobs/_utils/spec_utils.py +296 -0
  15. snowflake/ml/jobs/_utils/types.py +39 -0
  16. snowflake/ml/jobs/decorators.py +71 -0
  17. snowflake/ml/jobs/job.py +113 -0
  18. snowflake/ml/jobs/manager.py +298 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +11 -2
  20. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  21. snowflake/ml/model/_client/sql/service.py +13 -6
  22. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
  26. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  28. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  30. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  32. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  33. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  36. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  37. snowflake/ml/model/_signatures/base_handler.py +1 -2
  38. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  39. snowflake/ml/model/_signatures/core.py +2 -2
  40. snowflake/ml/model/_signatures/numpy_handler.py +11 -12
  41. snowflake/ml/model/_signatures/pandas_handler.py +11 -9
  42. snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
  43. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  44. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  45. snowflake/ml/model/model_signature.py +25 -4
  46. snowflake/ml/model/type_hints.py +15 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  51. snowflake/ml/modeling/cluster/birch.py +6 -3
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  53. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  55. snowflake/ml/modeling/cluster/k_means.py +6 -3
  56. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  58. snowflake/ml/modeling/cluster/optics.py +6 -3
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  62. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  69. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  70. snowflake/ml/modeling/covariance/oas.py +6 -3
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  74. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  79. snowflake/ml/modeling/decomposition/pca.py +6 -3
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  110. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  111. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  112. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  113. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  114. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  115. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  116. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  117. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  118. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  119. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  120. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  121. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  122. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  123. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  124. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  125. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  126. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  127. snowflake/ml/modeling/linear_model/lars.py +6 -3
  128. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  129. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  130. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  131. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  132. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  133. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  134. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  135. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  136. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  137. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  139. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  140. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  141. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  142. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  143. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  144. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  145. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  146. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  148. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  149. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  151. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  152. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  153. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  154. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  155. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  156. snowflake/ml/modeling/manifold/isomap.py +6 -3
  157. snowflake/ml/modeling/manifold/mds.py +6 -3
  158. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  159. snowflake/ml/modeling/manifold/tsne.py +6 -3
  160. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  161. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  162. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  163. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  174. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  184. snowflake/ml/modeling/pipeline/pipeline.py +28 -3
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  188. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  189. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  190. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  191. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  192. snowflake/ml/modeling/svm/svc.py +6 -3
  193. snowflake/ml/modeling/svm/svr.py +6 -3
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
  202. snowflake/ml/registry/registry.py +34 -4
  203. snowflake/ml/version.py +1 -1
  204. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
  205. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
  206. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  207. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  208. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,57 @@
1
+ from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
2
+ from snowflake.ml.jobs._utils.types import ComputeResources
3
+
4
# SPCS specification constants
DEFAULT_CONTAINER_NAME = "main"
# Name of the environment variable holding the payload directory. The generated startup
# script cds into the directory it names when the variable is non-empty.
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
MEMORY_VOLUME_NAME = "dshm"
STAGE_VOLUME_NAME = "stage-volume"
STAGE_VOLUME_MOUNT_PATH = "/mnt/app"

# Default container image information
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
DEFAULT_IMAGE_TAG = "0.9.2"
# File name used for the generated script when a callable payload has no explicit entrypoint.
DEFAULT_ENTRYPOINT_PATH = "func.py"

# Fraction of container memory to allocate for the /dev/shm volume (0.3 == 30%).
# NOTE(review): this was previously described as a "percent", which contradicts the value;
# the consumer is not visible here — confirm it multiplies memory by this fraction.
MEMORY_VOLUME_SIZE = 0.3

# Job status polling constants (delays in seconds, per the names)
JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
JOB_POLL_MAX_DELAY_SECONDS = 1

# Magic attributes
# When a callable has a truthy attribute with this name, generate_python_code() emits
# `kwargs = {}` for it instead of command-line argument parsing.
IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"

# Compute pool resource information: instance family name -> ComputeResources.
# NOTE(review): cpu/memory are presumably vCPUs and GiB, as in the linked docs table — confirm.
# TODO: Query Snowflake for resource information instead of relying on this hardcoded
# table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
COMMON_INSTANCE_FAMILIES = {
    "CPU_X64_XS": ComputeResources(cpu=1, memory=6),
    "CPU_X64_S": ComputeResources(cpu=3, memory=13),
    "CPU_X64_M": ComputeResources(cpu=6, memory=28),
    "CPU_X64_L": ComputeResources(cpu=28, memory=116),
    "HIGHMEM_X64_S": ComputeResources(cpu=6, memory=58),
}
AWS_INSTANCE_FAMILIES = {
    "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=240),
    "HIGHMEM_X64_L": ComputeResources(cpu=124, memory=984),
    "GPU_NV_S": ComputeResources(cpu=6, memory=27, gpu=1, gpu_type="A10G"),
    "GPU_NV_M": ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="A10G"),
    "GPU_NV_L": ComputeResources(cpu=92, memory=1112, gpu=8, gpu_type="A100"),
}
AZURE_INSTANCE_FAMILIES = {
    "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
    "HIGHMEM_X64_L": ComputeResources(cpu=92, memory=654),
    "GPU_NV_XS": ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"),
    "GPU_NV_SM": ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"),
    "GPU_NV_2M": ComputeResources(cpu=68, memory=858, gpu=2, gpu_type="A10"),
    "GPU_NV_3M": ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"),
    "GPU_NV_SL": ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"),
}
# Cloud-specific instance families keyed by cloud type.
# NOTE(review): presumably merged with COMMON_INSTANCE_FAMILIES by the consumer — confirm.
CLOUD_INSTANCE_FAMILIES = {
    SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
    SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
}
@@ -0,0 +1,438 @@
1
+ import functools
2
+ import inspect
3
+ import io
4
+ import itertools
5
+ import pickle
6
+ import sys
7
+ import textwrap
8
+ from pathlib import Path, PurePath
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ List,
13
+ Optional,
14
+ Type,
15
+ Union,
16
+ cast,
17
+ get_args,
18
+ get_origin,
19
+ )
20
+
21
+ import cloudpickle as cp
22
+
23
+ from snowflake import snowpark
24
+ from snowflake.ml.jobs._utils import constants, types
25
+ from snowflake.snowpark import exceptions as sp_exceptions
26
+ from snowflake.snowpark._internal import code_generation
27
+
28
_SUPPORTED_ARG_TYPES = {str, int, float}
_SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
_STARTUP_SCRIPT_PATH = PurePath("startup.sh")

# Bash entrypoint uploaded next to the payload by JobPayload.upload(). It:
#   * creates log directories and starts cron;
#   * cds into the payload directory;
#   * installs requirements.txt if present;
#   * starts a Ray head node and the `web.ml_runtime_grpc_server` module in the background;
#   * finally runs `python "$@"`.
#
# NOTE: this is a non-raw f-string, so
#   * `{{` / `}}` emit literal braces for bash `${VAR}` expansions,
#   * `{constants.PAYLOAD_DIR_ENV_VAR}` is interpolated when this module is imported, and
#   * every backslash intended for bash must be doubled; otherwise Python treats it as an
#     escape sequence (a bare `\1` would be the octal escape for U+0001, not a sed
#     back-reference).
_STARTUP_SCRIPT_CODE = textwrap.dedent(
    f"""
    #!/bin/bash

    ##### Perform common set up steps #####
    set -e # exit if a command fails

    echo "Creating log directories..."
    mkdir -p /var/log/managedservices/user/mlrs
    mkdir -p /var/log/managedservices/system/mlrs
    mkdir -p /var/log/managedservices/system/ray

    echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
    echo "" >> /etc/cron.d/ray_copy_cron
    chmod 744 /etc/cron.d/ray_copy_cron

    service cron start

    mkdir -p /tmp/prometheus-multi-dir

    # Change directory to user payload directory
    if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
        cd "${constants.PAYLOAD_DIR_ENV_VAR}"
    fi

    ##### Set up Python environment #####
    export PYTHONPATH=/opt/env/site-packages/
    MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
    if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
        # TODO: Prevent collisions with MLRS packages using virtualenvs
        echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
        pip install -r "$MLRS_REQUIREMENTS_FILE"
    fi

    MLRS_CONDA_ENV_FILE=${{MLRS_CONDA_ENV_FILE:-"environment.yml"}}
    if [ -f "${{MLRS_CONDA_ENV_FILE}}" ]; then
        # TODO: Handle conda environment
        echo "Custom conda environments not currently supported"
        exit 1
    fi
    ##### End Python environment setup #####

    ##### Ray configuration #####
    shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)

    # Configure IP address and logging directory
    eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\\1/p')
    log_dir="/tmp/ray"

    # Check if eth0Ip is a valid IP address and fall back to default if necessary
    if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
        eth0Ip="127.0.0.1"
    fi

    # Common parameters for both head and worker nodes
    common_params=(
        "--node-ip-address=$eth0Ip"
        "--object-manager-port=${{RAY_OBJECT_MANAGER_PORT:-12011}}"
        "--node-manager-port=${{RAY_NODE_MANAGER_PORT:-12012}}"
        "--runtime-env-agent-port=${{RAY_RUNTIME_ENV_AGENT_PORT:-12013}}"
        "--dashboard-agent-grpc-port=${{RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}}"
        "--dashboard-agent-listen-port=${{RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}}"
        "--min-worker-port=${{RAY_MIN_WORKER_PORT:-12031}}"
        "--max-worker-port=${{RAY_MAX_WORKER_PORT:-13000}}"
        "--metrics-export-port=11502"
        "--temp-dir=$log_dir"
        "--disable-usage-stats"
    )

    # Additional head-specific parameters
    head_params=(
        "--head"
        "--port=${{RAY_HEAD_GCS_PORT:-12001}}" # Port of Ray (GCS server)
        "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}" # Listening port for Ray Client Server
        "--dashboard-host=${{NODE_IP_ADDRESS}}" # Host to bind the dashboard server
        "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}" # Dashboard head to listen for grpc on
        "--dashboard-port=${{DASHBOARD_PORT}}" # Port to bind the dashboard server for local debugging
        "--resources={{\\"node_tag:head\\":1}}" # Resource tag for selecting head as coordinator
    )

    # Start Ray on the head node
    ray start "${{common_params[@]}}" "${{head_params[@]}}" &
    ##### End Ray configuration #####

    # TODO: Monitor MLRS and handle process crashes
    python -m web.ml_runtime_grpc_server &

    # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started

    # Run user's Python entrypoint
    echo Running command: python "$@"
    python "$@"
    """
).strip()
125
+
126
+
127
+ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
128
+ parent = parent.absolute()
129
+ if entrypoint is None:
130
+ if parent.is_file():
131
+ # Infer entrypoint from source
132
+ entrypoint = parent
133
+ else:
134
+ raise ValueError("entrypoint must be provided when source is a directory")
135
+ elif entrypoint.is_absolute():
136
+ # Absolute path - validate it's a subpath of source dir
137
+ if not entrypoint.is_relative_to(parent):
138
+ raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
139
+ else:
140
+ # Relative path
141
+ if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
142
+ # Relative to working dir iff path is relative to source dir and exists
143
+ entrypoint = abs_entrypoint
144
+ else:
145
+ # Relative to source dir
146
+ entrypoint = parent.joinpath(entrypoint)
147
+ if not entrypoint.is_file():
148
+ raise FileNotFoundError(
149
+ "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
150
+ f" the source directory (source={parent}, entrypoint={entrypoint})"
151
+ )
152
+ return entrypoint
153
+
154
+
155
class JobPayload:
    """A job payload that can be validated and uploaded to a Snowflake stage.

    ``source`` is either a path (a single file or a directory) or a Python callable.
    * For a path source, ``entrypoint`` selects the script to run.
    * For a callable, a script is generated from the serialized callable and uploaded
      as ``entrypoint`` (``constants.DEFAULT_ENTRYPOINT_PATH`` when not given).
    """

    def __init__(
        self,
        source: Union[str, Path, Callable[..., Any]],
        entrypoint: Optional[Union[str, Path]] = None,
        *,
        pip_requirements: Optional[List[str]] = None,
    ) -> None:
        """Store the payload description, converting string paths to ``Path``.

        Args:
            source: File or directory path, or a callable to serialize.
            entrypoint: Optional entrypoint script path.
            pip_requirements: Optional pip requirement specifiers. When non-empty, they
                are uploaded as ``requirements.txt`` next to the payload.
        """
        self.source = Path(source) if isinstance(source, str) else source
        self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
        self.pip_requirements = pip_requirements

    def validate(self) -> None:
        """Validate the payload and normalize ``source``/``entrypoint`` in place.

        For path sources:
            * ``source`` becomes an absolute path.
            * ``entrypoint`` becomes the absolute entrypoint file returned by
              ``_resolve_entrypoint``.

        Raises:
            FileNotFoundError: If a path source or its entrypoint does not exist.
            ValueError: If the entrypoint is invalid or has an unsupported extension, or
                the source is neither a path nor a callable.
        """
        if callable(self.source):
            # Any entrypoint value is OK for callable payloads (including None aka default)
            # since we will generate the file from the serialized callable
            pass
        elif isinstance(self.source, Path):
            # Validate source
            source = self.source
            if not source.exists():
                raise FileNotFoundError(f"{source} does not exist")
            source = source.absolute()

            # Validate entrypoint
            entrypoint = _resolve_entrypoint(source, self.entrypoint)
            if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
                raise ValueError(
                    "Unsupported entrypoint type:"
                    f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
                )

            # Update fields with normalized values
            self.source = source
            self.entrypoint = entrypoint
        else:
            raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")

    def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
        """Validate the payload and upload it, plus support files, to ``stage_path``.

        Uploads:
            * the payload files (or the script generated from a callable);
            * a ``requirements.txt`` built from ``pip_requirements`` when non-empty;
            * the startup script.

        The stage named by the first component of ``stage_path`` is created when
        ``DESCRIBE STAGE`` fails.

        Returns:
            The stage path and the command ``["bash", startup.sh, <entrypoint>]``, where
            the entrypoint is relative to the payload root.
        """
        # Validate payload
        self.validate()

        # Prepare local variables
        stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
        source = self.source
        entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)

        # Create stage if necessary
        stage_name = stage_path.parts[0].lstrip("@")
        # Explicitly check if stage exists first since we may not have CREATE STAGE privilege.
        # NOTE(review): stage_name comes from the caller-supplied stage_path and is
        # interpolated into SQL without escaping.
        try:
            session.sql(f"describe stage {stage_name}").collect()
        except sp_exceptions.SnowparkSQLException:
            session.sql(
                f"create stage if not exists {stage_name}"
                " encryption = ( type = 'SNOWFLAKE_SSE' )"
                " comment = 'Created by snowflake.ml.jobs Python API'"
            ).collect()

        # Upload payload to stage
        if not isinstance(source, Path):
            # Callable payload: upload a generated script as the entrypoint file.
            # NOTE(review): the script is stored at stage_path/<entrypoint>, but the returned
            # command uses only the entrypoint's file name (source = entrypoint.parent), so an
            # entrypoint with directory components may not be found at runtime — confirm.
            source_code = generate_python_code(source, source_code_display=True)
            _ = session.file.put_stream(
                io.BytesIO(source_code.encode()),
                stage_location=stage_path.joinpath(entrypoint).as_posix(),
                auto_compress=False,
                overwrite=True,
            )
            source = entrypoint.parent
        elif source.is_dir():
            # Manually traverse the directory and upload each file, since Snowflake PUT
            # can't handle directories. Reduce the number of PUT operations by using
            # wildcard patterns to batch upload files with the same extension.
            # Use the same resolved root for globbing and for computing stage-relative
            # directories: mixing resolved and unresolved paths makes relative_to() raise
            # ValueError when the source path goes through a symlink (e.g. /tmp on macOS).
            resolved_source = source.resolve()
            upload_patterns = {
                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
                for p in resolved_source.rglob("*")
                if p.is_file()
            }
            for path in upload_patterns:
                session.file.put(
                    str(path),
                    stage_path.joinpath(path.parent.relative_to(resolved_source)).as_posix(),
                    overwrite=True,
                    auto_compress=False,
                )
        else:
            session.file.put(
                str(source.resolve()),
                stage_path.as_posix(),
                overwrite=True,
                auto_compress=False,
            )
            source = source.parent

        # Upload requirements
        # TODO: Check if payload includes both a requirements.txt file and pip_requirements
        if self.pip_requirements:
            # Upload requirements.txt to stage
            session.file.put_stream(
                io.BytesIO("\n".join(self.pip_requirements).encode()),
                stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                auto_compress=False,
                overwrite=True,
            )

        # Upload startup script
        # TODO: Make sure payload does not include file with same name
        session.file.put_stream(
            io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
            stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
            auto_compress=False,
            overwrite=False,  # FIXME
        )

        # `source` is now the payload root, so the entrypoint is expressed relative to it.
        return types.UploadedPayload(
            stage_path=stage_path,
            entrypoint=[
                "bash",
                _STARTUP_SCRIPT_PATH,
                entrypoint.relative_to(source),
            ],
        )
274
+
275
+
276
+ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
277
+ # Unwrap Optional type annotations
278
+ param_type = param.annotation
279
+ if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
280
+ param_type = next(t for t in get_args(param_type) if t is not type(None))
281
+
282
+ # Return None for empty type annotations
283
+ if param_type == inspect.Parameter.empty:
284
+ return None
285
+ return cast(Type[object], param_type)
286
+
287
+
288
def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
    """Ensure a parameter's type can be parsed from the command line.

    Args:
        param_type: The (already unwrapped) annotated type of the parameter.
        param_name: Parameter name, used only in the error message.

    Raises:
        ValueError: If ``param_type`` is not one of ``_SUPPORTED_ARG_TYPES``.
    """
    if param_type in _SUPPORTED_ARG_TYPES:
        return
    supported_names = ", ".join(t.__name__ for t in _SUPPORTED_ARG_TYPES)
    raise ValueError(f"Unsupported argument type {param_type} for '{param_name}'. Supported types: {supported_names}")
295
+
296
+
297
def _generate_source_code_comment(func: Callable[..., Any]) -> str:
    """Generate a comment string containing the source code of a function for readability.

    For a ``functools.partial``, the source of the wrapped function is used, and the
    invocation text is rewritten to show how the partial was built, e.g.
    ``functools.partial(train, 10, lr=0.1)``.

    Any exception raised while generating the source comment is caught and turned into a
    comment describing the error, since the comment is purely informational.
    """
    try:
        if isinstance(func, functools.partial):
            # Unwrap functools.partial and generate source code comment from the original function
            comment = code_generation.generate_source_code(func.func)  # type: ignore[arg-type]
            func_name = func.func.__name__
            bound_args = itertools.chain(
                (repr(a) for a in func.args),
                (f"{k}={v!r}" for k, v in func.keywords.items()),
            )

            # Update invocation comment to show arguments passed via functools.partial.
            # partial() takes the function first and the bound arguments after it,
            # i.e. partial(f, 1, x=2) -- not partial(f(1, x=2)), which reads as a call.
            # NOTE(review): relies on generate_source_code() output containing "= <func name>" — confirm.
            comment = comment.replace(
                f"= {func_name}",
                "= functools.partial({})".format(", ".join([func_name, *bound_args])),
            )
            return comment
        else:
            return code_generation.generate_source_code(func)  # type: ignore[arg-type]
    except Exception as exc:
        error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
        return code_generation.comment_source_code(error_msg)
319
+
320
+
321
def _serialize_callable(func: Callable[..., Any]) -> bytes:
    """Serialize ``func`` with cloudpickle.

    Raises:
        ValueError: If cloudpickle raises ``pickle.PicklingError``. For a
            ``functools.partial``, the message names the first component that fails to
            pickle (the wrapped function, a positional arg, or a keyword arg), if one does.
    """
    try:
        serialized: bytes = cp.dumps(func)
    except pickle.PicklingError as err:
        if isinstance(func, functools.partial):
            # Probe each component of the partial to pinpoint the unpicklable one.
            candidates = [("function", func.func)]
            candidates.extend((f"positional arg {idx}", arg) for idx, arg in enumerate(func.args))
            candidates.extend((f"keyword arg '{key}'", val) for key, val in func.keywords.items())
            for label, candidate in candidates:
                try:
                    cp.dumps(candidate)
                except pickle.PicklingError:
                    raise ValueError(f"Unable to serialize {label}: {candidate}") from err
        raise ValueError(f"Unable to serialize function: {func}") from err
    return serialized
339
+
340
+
341
+ def _generate_param_handler_code(signature: inspect.Signature, output_name: str = "kwargs") -> str:
342
+ # Generate argparse logic for argument handling (type coercion, default values, etc)
343
+ argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
344
+ argparse_postproc = []
345
+ for name, param in signature.parameters.items():
346
+ opts = {}
347
+
348
+ param_type = _get_parameter_type(param)
349
+ if param_type is not None:
350
+ _validate_parameter_type(param_type, name)
351
+ opts["type"] = param_type.__name__
352
+
353
+ if param.default != inspect.Parameter.empty:
354
+ opts["default"] = f"'{param.default}'" if isinstance(param.default, str) else param.default
355
+
356
+ if param.kind == inspect.Parameter.KEYWORD_ONLY:
357
+ # Keyword argument
358
+ argparse_code.append(
359
+ f"parser.add_argument('--{name}', required={'default' not in opts},"
360
+ f" {', '.join(f'{k}={v}' for k, v in opts.items())})"
361
+ )
362
+ else:
363
+ # Positional argument. Use `argparse.add_mutually_exclusive_group()`
364
+ # to allow passing positional args by name as well
365
+ group_name = f"{name}_group"
366
+ argparse_code.append(
367
+ f"{group_name} = parser.add_mutually_exclusive_group(required={'default' not in opts})"
368
+ )
369
+ argparse_code.append(
370
+ f"{group_name}.add_argument('pos-{name}', metavar='{name}', nargs='?',"
371
+ f" {', '.join(f'{k}={v}' for k, v in opts.items() if k != 'default')})"
372
+ )
373
+ argparse_code.append(
374
+ f"{group_name}.add_argument('--{name}', {', '.join(f'{k}={v}' for k, v in opts.items())})"
375
+ )
376
+ argparse_code.append("") # Add newline for readability
377
+ argparse_postproc.append(
378
+ f"args.{name} = {name} if ({name} := args.__dict__.pop('pos-{name}')) is not None else args.{name}"
379
+ )
380
+ argparse_code.append("args = parser.parse_args()")
381
+ param_code = "\n".join(argparse_code + argparse_postproc)
382
+ param_code += f"\n{output_name} = vars(args)"
383
+
384
+ return param_code
385
+
386
+
387
def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
    """Build the text of a standalone entrypoint script that runs ``func``.

    The generated script:
        * warns when its Python major/minor version differs from the one running this
          function;
        * unpickles ``func`` from an embedded hex string;
        * builds the keyword arguments, either via argparse or as an empty dict when
          ``func`` carries the ``constants.IS_MLJOB_REMOTE_ATTR`` marker;
        * calls ``func`` under ``if __name__ == '__main__':``.

    Args:
        func: The callable to embed.
        source_code_display: When true, prepend a comment with the callable's source.

    Raises:
        NotImplementedError: If ``func`` declares ``*args`` or ``**kwargs``.
    """
    signature = inspect.signature(func)
    unpacking_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    for parameter in signature.parameters.values():
        if parameter.kind in unpacking_kinds:
            raise NotImplementedError("Function must not have unpacking arguments (* or **)")

    # Mirrored from Snowpark generate_python_code() function
    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
    if source_code_display:
        source_code_comment = _generate_source_code_comment(func)
    else:
        source_code_comment = ""

    callable_var = "func"
    payload_hex = _serialize_callable(func).hex()
    func_code = f"""
{source_code_comment}

import pickle
{callable_var} = pickle.loads(bytes.fromhex('{payload_hex}'))
"""

    kwargs_var = "kwargs"
    is_remote = getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None)
    param_code = f"{kwargs_var} = {{}}" if is_remote else _generate_param_handler_code(signature, kwargs_var)

    return f"""
### Version guard to check compatibility across Python versions ###
import sys
import warnings

if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
    warnings.warn(
        "Python version mismatch: job was created using"
        " python{sys.version_info.major}.{sys.version_info.minor}"
        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
        " This will be fixed in a future release; for now, please use Python version"
        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
        RuntimeWarning,
        stacklevel=0,
    )
### End version guard ###

{func_code.strip()}

if __name__ == '__main__':
{textwrap.indent(param_code, '    ')}

    {callable_var}(**{kwargs_var})
"""