snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +9 -14
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +61 -19
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
- snowflake/ml/jobs/_utils/spec_utils.py +44 -13
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +7 -8
- snowflake/ml/jobs/job.py +34 -18
- snowflake/ml/jobs/manager.py +107 -24
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +225 -73
- snowflake/ml/model/_client/ops/service_ops.py +128 -174
- snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/feature_view.py:

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import logging
 import re
 import warnings
 from collections import OrderedDict
@@ -31,10 +32,12 @@ from snowflake.snowpark.types import (
     _NumericType,
 )
 
+_DEFAULT_TARGET_LAG = "10 seconds"
 _FEATURE_VIEW_NAME_DELIMITER = "$"
 _LEGACY_TIMESTAMP_COL_PLACEHOLDER_VALS = ["FS_TIMESTAMP_COL_PLACEHOLDER_VAL", "NULL"]
 _TIMESTAMP_COL_PLACEHOLDER = "NULL"
 _FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
+_ONLINE_TABLE_SUFFIX = "$ONLINE"
 # Feature view version rule is aligned with dataset version rule in SQL.
 _FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$")
 _FEATURE_VIEW_VERSION_MAX_LENGTH = 128
@@ -45,6 +48,44 @@ _RESULT_SCAN_QUERY_PATTERN = re.compile(
 )
 
 
+@dataclass(frozen=True)
+class OnlineConfig:
+    """Configuration for online feature storage."""
+
+    enable: bool = False
+    target_lag: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        if self.target_lag is None:
+            return
+        if not isinstance(self.target_lag, str) or not self.target_lag.strip():
+            raise ValueError("target_lag must be a non-empty string")
+
+        object.__setattr__(self, "target_lag", self.target_lag.strip())
+
+    def to_json(self) -> str:
+        data: dict[str, Any] = asdict(self)
+        return json.dumps(data)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> OnlineConfig:
+        data = json.loads(json_str)
+        return cls(**data)
+
+
+class StoreType(Enum):
+    """
+    Enumeration for specifying the storage type when reading from or refreshing feature views.
+
+    The Feature View supports two storage modes:
+    - OFFLINE: Traditional batch storage for historical feature data and training
+    - ONLINE: Low-latency storage optimized for real-time feature serving
+    """
+
+    ONLINE = "online"
+    OFFLINE = "offline"
+
+
 @dataclass(frozen=True)
 class _FeatureViewMetadata:
     """Represent metadata tracked on top of FV backend object"""
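
A minimal sketch of how the new OnlineConfig behaves, assuming it is importable from this module as defined in the hunk above; the values are illustrative:

# Sketch only: OnlineConfig validation and JSON round-trip.
from snowflake.ml.feature_store.feature_view import OnlineConfig, StoreType

config = OnlineConfig(enable=True, target_lag="  15s  ")
print(config.target_lag)  # "15s" -- __post_init__ strips surrounding whitespace

restored = OnlineConfig.from_json(config.to_json())
assert restored == config  # frozen dataclass, so equality is field-wise

print(StoreType.ONLINE.value)  # "online"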
@@ -171,6 +212,7 @@ class FeatureView(lineage_node.LineageNode):
         initialize: str = "ON_CREATE",
         refresh_mode: str = "AUTO",
         cluster_by: Optional[list[str]] = None,
+        online_config: Optional[OnlineConfig] = None,
         **_kwargs: Any,
     ) -> None:
         """
@@ -204,6 +246,8 @@ class FeatureView(lineage_node.LineageNode):
             cluster_by: Columns to cluster the feature view by.
                 - Defaults to the join keys from entities.
                 - If `timestamp_col` is provided, it is added to the default clustering keys.
+            online_config: Optional configuration for online storage. If provided with enable=True,
+                online storage will be enabled. Defaults to None (no online storage).
             _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
 
         Example::
@@ -227,9 +271,26 @@ class FeatureView(lineage_node.LineageNode):
             >>> registered_fv = fs.register_feature_view(draft_fv, "v1")
             >>> print(registered_fv.status)
             FeatureViewStatus.ACTIVE
+            <BLANKLINE>
+            >>> # Example with online configuration for online feature storage
+            >>> config = OnlineConfig(enable=True, target_lag='15s')
+            >>> online_fv = FeatureView(
+            ...     name="my_online_fv",
+            ...     entities=[e1, e2],
+            ...     feature_df=feature_df,
+            ...     timestamp_col='TS',
+            ...     refresh_freq='1d',
+            ...     desc='Feature view with online storage',
+            ...     online_config=config  # optional, enables online feature storage
+            ... )
+            >>> registered_online_fv = fs.register_feature_view(online_fv, "v1")
+            >>> print(registered_online_fv.online)
+            True
 
         # noqa: DAR401
         """
+        if online_config is not None:
+            logging.warning("'online_config' is in private preview since 1.12.0. Do not use it in production.")
 
         self._name: SqlIdentifier = SqlIdentifier(name)
         self._entities: list[Entity] = entities
@@ -257,6 +318,7 @@ class FeatureView(lineage_node.LineageNode):
         self._cluster_by: list[SqlIdentifier] = (
             [SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
         )
+        self._online_config: Optional[OnlineConfig] = online_config
 
         # Validate kwargs
         if _kwargs:
@@ -470,6 +532,31 @@ class FeatureView(lineage_node.LineageNode):
     def feature_descs(self) -> Optional[dict[SqlIdentifier, str]]:
         return self._feature_desc
 
+    @property
+    def online(self) -> bool:
+        return self._online_config.enable if self._online_config else False
+
+    @property
+    def online_config(self) -> Optional[OnlineConfig]:
+        return self._online_config
+
+    def fully_qualified_online_table_name(self) -> str:
+        """Get the fully qualified name for the online feature table.
+
+        Returns:
+            The fully qualified name (<database_name>.<schema_name>.<online_table_name>) for the
+            online feature table in Snowflake.
+
+        Raises:
+            RuntimeError: if the FeatureView is not registered or not configured for online storage.
+        """
+        if self.status == FeatureViewStatus.DRAFT or self.version is None:
+            raise RuntimeError(f"FeatureView {self.name} has not been registered.")
+        if not self.online:
+            raise RuntimeError(f"FeatureView {self.name} is not configured for online storage.")
+        online_table_name = self._get_online_table_name(self.name, self.version)
+        return f"{self._database}.{self._schema}.{online_table_name}"
+
     def list_columns(self) -> DataFrame:
         """List all columns and their information.
 
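
A short sketch of the guard behavior added above; the entity and dataframe here are hypothetical placeholders:

# Sketch only: an unregistered draft view refuses to produce an online table name.
fv = FeatureView(
    name="MY_FV",
    entities=[my_entity],      # hypothetical Entity
    feature_df=my_feature_df,  # hypothetical Snowpark DataFrame
    online_config=OnlineConfig(enable=True),
)
print(fv.online)  # True
try:
    fv.fully_qualified_online_table_name()
except RuntimeError as err:
    print(err)  # FeatureView MY_FV has not been registered.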
@@ -756,6 +843,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             feature_desc_dict[k.identifier()] = v
         fv_dict["_feature_desc"] = feature_desc_dict
 
+        fv_dict["_online_config"] = self._online_config.to_json() if self._online_config is not None else None
+
         lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
 
         for key in lineage_node_keys:
@@ -844,6 +933,9 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             owner=json_dict["_owner"],
             infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)),
             session=session,
+            online_config=OnlineConfig.from_json(json_dict["_online_config"])
+            if json_dict.get("_online_config")
+            else None,
         )
 
     def _get_compact_repr(self) -> _CompactRepresentation:
@@ -916,6 +1008,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
         infer_schema_df: Optional[DataFrame],
         session: Session,
         cluster_by: Optional[list[str]] = None,
+        online_config: Optional[OnlineConfig] = None,
     ) -> FeatureView:
         fv = FeatureView(
             name=name,
@@ -925,6 +1018,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
             desc=desc,
             _infer_schema_df=infer_schema_df,
             cluster_by=cluster_by,
+            online_config=online_config,
         )
         fv._version = FeatureViewVersion(version) if version is not None else None
         fv._status = status
@@ -961,5 +1055,33 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
 
         return default_cluster_by_cols
 
+    @staticmethod
+    def _get_online_table_name(
+        feature_view_name: Union[SqlIdentifier, str], version: Optional[Union[FeatureViewVersion, str]] = None
+    ) -> SqlIdentifier:
+        """Get the online feature table name without qualification.
+
+        Args:
+            feature_view_name: Offline feature view name.
+            version: Feature view version. If not provided, feature_view_name must be a SqlIdentifier.
+
+        Returns:
+            The online table name SqlIdentifier
+        """
+        if version is None:
+            assert isinstance(feature_view_name, SqlIdentifier), "Single argument must be SqlIdentifier"
+            online_name = f"{feature_view_name.resolved()}{_ONLINE_TABLE_SUFFIX}"
+            return SqlIdentifier(online_name, case_sensitive=True)
+        else:
+            fv_name = (
+                feature_view_name
+                if isinstance(feature_view_name, SqlIdentifier)
+                else SqlIdentifier(feature_view_name, case_sensitive=True)
+            )
+            fv_version = version if isinstance(version, FeatureViewVersion) else FeatureViewVersion(version)
+            physical_name = FeatureView._get_physical_name(fv_name, fv_version).resolved()
+            online_name = f"{physical_name}{_ONLINE_TABLE_SUFFIX}"
+            return SqlIdentifier(online_name, case_sensitive=True)
+
 
 lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
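
Assuming the module's existing physical-name scheme joins name and version with the "$" delimiter defined at the top of the file, the helper above produces names like the following (illustrative; FeatureView._get_physical_name itself is unchanged by this diff):

physical_name = "MY_FV$V1"                # assumed NAME$VERSION form
online_table = f"{physical_name}$ONLINE"  # _ONLINE_TABLE_SUFFIX
print(online_table)                       # MY_FV$V1$ONLINE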
snowflake/ml/jobs/_utils/__init__.py: file without changes.

snowflake/ml/jobs/_utils/constants.py:

@@ -3,32 +3,29 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 
 # SPCS specification constants
 DEFAULT_CONTAINER_NAME = "main"
+MEMORY_VOLUME_NAME = "dshm"
+STAGE_VOLUME_NAME = "stage-volume"
+
+# Environment variables
+STAGE_MOUNT_PATH_ENV_VAR = "MLRS_STAGE_MOUNT_PATH"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
 TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
 RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
-MEMORY_VOLUME_NAME = "dshm"
-STAGE_VOLUME_NAME = "stage-volume"
-# Base mount path
-STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
 
-# Stage
+# Stage mount paths
+STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
 APP_STAGE_SUBPATH = "app"
 SYSTEM_STAGE_SUBPATH = "system"
 OUTPUT_STAGE_SUBPATH = "output"
-
-# Complete mount paths (automatically generated from base + subpath)
-APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
-SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
-OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
-
+RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.
+DEFAULT_IMAGE_TAG = "1.6.2"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
@@ -59,8 +56,6 @@ ENABLE_HEALTH_CHECKS = "false"
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 30
 
-RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
-
 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
 LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
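
With the precomputed APP_MOUNT_PATH, SYSTEM_MOUNT_PATH, and OUTPUT_MOUNT_PATH constants removed, full container paths are now composed at the call site from the base mount path plus a stage-relative subpath. A standalone sketch using the values above:

import os

STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
OUTPUT_STAGE_SUBPATH = "output"
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"  # now stage-relative

print(os.path.join(STAGE_VOLUME_MOUNT_PATH, RESULT_PATH_DEFAULT_VALUE))
# /mnt/job_stage/output/mljob_result.pkl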
snowflake/ml/jobs/_utils/feature_flags.py (new file):

@@ -0,0 +1,16 @@
+import os
+from enum import Enum
+
+
+class FeatureFlags(Enum):
+    USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
+    ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
+
+    def is_enabled(self) -> bool:
+        return os.getenv(self.value, "false").lower() == "true"
+
+    def is_disabled(self) -> bool:
+        return not self.is_enabled()
+
+    def __str__(self) -> str:
+        return self.value
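
Usage sketch for the new FeatureFlags helper; the check is case-insensitive and an unset variable counts as disabled:

import os

from snowflake.ml.jobs._utils.feature_flags import FeatureFlags

os.environ["MLRS_USE_SUBMIT_JOB_V2"] = "TRUE"
assert FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled()              # "TRUE".lower() == "true"
assert FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_disabled()  # unset -> "false"
print(FeatureFlags.USE_SUBMIT_JOB_V2)                           # MLRS_USE_SUBMIT_JOB_V2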
snowflake/ml/jobs/_utils/payload_utils.py:

@@ -1,4 +1,5 @@
 import functools
+import importlib
 import inspect
 import io
 import itertools
@@ -7,6 +8,7 @@ import logging
 import pickle
 import sys
 import textwrap
+from importlib.abc import Traversable
 from pathlib import Path, PurePath
 from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
@@ -58,7 +60,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     # Change directory to user payload directory
     if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
-        cd ${constants.PAYLOAD_DIR_ENV_VAR}
+        cd ${constants.STAGE_MOUNT_PATH_ENV_VAR}/${constants.PAYLOAD_DIR_ENV_VAR}
     fi
 
     ##### Set up Python environment #####
@@ -67,7 +69,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
    if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
        echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
-       pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+       if ! pip install --no-index -r $MLRS_SYSTEM_REQUIREMENTS_FILE; then
+           echo "Offline install failed, falling back to regular pip install"
+           pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+       fi
    fi
 
    MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
@@ -262,11 +267,24 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
     # Manually traverse the directory and upload each file, since Snowflake PUT
     # can't handle directories. Reduce the number of PUT operations by using
     # wildcard patterns to batch upload files with the same extension.
-
-
-
-
-
+    upload_path_patterns = set()
+    for p in source_path.resolve().rglob("*"):
+        if p.is_dir():
+            continue
+        if p.name.startswith("."):
+            # Hidden files: use .* pattern for batch upload
+            if p.suffix:
+                upload_path_patterns.add(p.parent.joinpath(f".*{p.suffix}"))
+            else:
+                upload_path_patterns.add(p.parent.joinpath(".*"))
+        else:
+            # Regular files: use * pattern for batch upload
+            if p.suffix:
+                upload_path_patterns.add(p.parent.joinpath(f"*{p.suffix}"))
+            else:
+                upload_path_patterns.add(p)
+
+    for path in upload_path_patterns:
         session.file.put(
             str(path),
             payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
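
A standalone sketch of the pattern-batching logic added above: many files collapse into a few wildcard PUT patterns (pure pathlib, no Snowflake session needed):

from pathlib import Path

def batch_patterns(source_path: Path) -> set[Path]:
    patterns: set[Path] = set()
    for p in source_path.resolve().rglob("*"):
        if p.is_dir():
            continue
        if p.name.startswith("."):
            # Hidden files share a ".*<suffix>" (or ".*") pattern per directory
            patterns.add(p.parent / (f".*{p.suffix}" if p.suffix else ".*"))
        elif p.suffix:
            # Regular files with an extension share one "*<suffix>" pattern per directory
            patterns.add(p.parent / f"*{p.suffix}")
        else:
            # Extensionless regular files are uploaded individually
            patterns.add(p)
    return patterns

# e.g. src/a.py, src/b.py, src/.env  ->  {src/*.py, src/.*}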
@@ -282,6 +300,27 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
     )
 
 
+def upload_system_resources(session: snowpark.Session, stage_path: PurePath) -> None:
+    resource_ref = importlib.resources.files(__package__).joinpath("scripts")
+
+    def upload_dir(ref: Traversable, relative_path: str = "") -> None:
+        for item in ref.iterdir():
+            current_path = Path(relative_path) / item.name if relative_path else Path(item.name)
+            if item.is_dir():
+                # Recursively process subdirectories
+                upload_dir(item, str(current_path))
+            elif item.is_file():
+                content = item.read_bytes()
+                session.file.put_stream(
+                    io.BytesIO(content),
+                    stage_path.joinpath(current_path).as_posix(),
+                    auto_compress=False,
+                    overwrite=True,
+                )
+
+    upload_dir(resource_ref)
+
+
 def resolve_source(
     source: Union[types.PayloadPath, Callable[..., Any]]
 ) -> Union[types.PayloadPath, Callable[..., Any]]:
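
The new helper walks packaged resources via importlib.resources instead of globbing the filesystem, so it also works when the package is imported from a zip or wheel. A sketch of the same traversal that just lists the bundled scripts:

# Sketch only: enumerate the packaged scripts the helper would upload.
import importlib.resources

scripts = importlib.resources.files("snowflake.ml.jobs._utils").joinpath("scripts")
for item in scripts.iterdir():
    if item.is_file():
        print(item.name)  # e.g. mljob_launcher.py, get_instance_ip.py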
@@ -497,29 +536,32 @@ class JobPayload:
             overwrite=False,  # FIXME
         )
 
-
-        for script_file in scripts_dir.glob("*"):
-            if script_file.is_file():
-                session.file.put(
-                    script_file.as_posix(),
-                    system_stage_path.as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
+        upload_system_resources(session, system_stage_path)
         python_entrypoint: list[Union[str, PurePath]] = [
-            PurePath(
-            PurePath(
+            PurePath(constants.STAGE_VOLUME_MOUNT_PATH, constants.SYSTEM_STAGE_SUBPATH, "mljob_launcher.py"),
+            PurePath(
+                constants.STAGE_VOLUME_MOUNT_PATH,
+                constants.APP_STAGE_SUBPATH,
+                entrypoint.file_path.relative_to(source).as_posix(),
+            ),
         ]
         if entrypoint.main_func:
             python_entrypoint += ["--script_main_func", entrypoint.main_func]
 
+        env_vars = {
+            constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
+            constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
+            constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
+        }
+
         return types.UploadedPayload(
             stage_path=stage_path,
             entrypoint=[
                 "bash",
-                f"{constants.
+                f"{constants.STAGE_VOLUME_MOUNT_PATH}/{constants.SYSTEM_STAGE_SUBPATH}/{_STARTUP_SCRIPT_PATH}",
                 *python_entrypoint,
             ],
+            env_vars=env_vars,
         )
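
The uploaded payload now carries stage-relative environment variables and an entrypoint built from the mount-path constants. A sketch of the resulting command; the startup-script and user-entrypoint filenames are assumptions, since _STARTUP_SCRIPT_PATH and the user payload are defined elsewhere:

from pathlib import PurePath

STAGE = "/mnt/job_stage"  # constants.STAGE_VOLUME_MOUNT_PATH
entrypoint = [
    "bash",
    f"{STAGE}/system/startup.sh",  # assumed value of _STARTUP_SCRIPT_PATH
    str(PurePath(STAGE, "system", "mljob_launcher.py")),
    str(PurePath(STAGE, "app", "main.py")),  # hypothetical user entrypoint
]
env_vars = {
    "MLRS_STAGE_MOUNT_PATH": STAGE,
    "MLRS_PAYLOAD_DIR": "app",
    "MLRS_RESULT_PATH": "output/mljob_result.pkl",
}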
snowflake/ml/jobs/_utils/query_helper.py:

@@ -4,6 +4,7 @@ from snowflake import snowpark
 from snowflake.snowpark import Row
 from snowflake.snowpark._internal import utils
 from snowflake.snowpark._internal.analyzer import snowflake_plan
+from snowflake.snowpark._internal.utils import is_in_stored_procedure
 
 
 def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
@@ -14,7 +15,10 @@ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> lis
 
 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
 def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
-
+    kwargs: dict[str, Any] = {"query": query_text, "params": params}
+    if not is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        kwargs["_force_qmark_paramstyle"] = True
+    result = session._conn.run_query(**kwargs)
     if not isinstance(result, dict) or "data" not in result:
         raise ValueError(f"Unprocessable result: {result}")
     return result_set_to_rows(session, result)
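
With qmark paramstyle forced outside stored procedures, callers can bind positional parameters with "?" placeholders. A hedged usage sketch, assuming an existing Snowpark session and a hypothetical table:

from snowflake.ml.jobs._utils.query_helper import run_query

# "?" placeholders are bound positionally from params.
rows = run_query(session, "select name, status from my_jobs where status = ?", params=["RUNNING"])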
snowflake/ml/jobs/_utils/runtime_env_utils.py (new file):

@@ -0,0 +1,63 @@
+from typing import Any, Optional, Union
+
+from packaging.version import Version
+from pydantic import BaseModel, Field, RootModel, field_validator
+
+
+class SpcsContainerRuntime(BaseModel):
+    python_version: Version = Field(alias="pythonVersion")
+    hardware_type: str = Field(alias="hardwareType")
+    runtime_container_image: str = Field(alias="runtimeContainerImage")
+
+    @field_validator("python_version", mode="before")
+    @classmethod
+    def validate_python_version(cls, v: Union[str, Version]) -> Version:
+        if isinstance(v, Version):
+            return v
+        try:
+            return Version(v)
+        except Exception:
+            raise ValueError(f"Invalid Python version format: {v}")
+
+    class Config:
+        frozen = True
+        extra = "allow"
+        arbitrary_types_allowed = True
+
+
+class RuntimeEnvironmentEntry(BaseModel):
+    spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
+
+    class Config:
+        extra = "allow"
+        frozen = True
+
+
+class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
+    @field_validator("root", mode="before")
+    @classmethod
+    def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
+        """
+        Pre-validation hook: keep only those items at the root level
+        whose values are dicts. Non-dict values will be dropped.
+
+        Args:
+            data: The input data to filter, expected to be a dictionary.
+
+        Returns:
+            A dictionary containing only the key-value pairs where values are dictionaries.
+
+        Raises:
+            ValueError: If input data is not a dictionary.
+        """
+        # If the entire root is not a dict, raise error immediately
+        if not isinstance(data, dict):
+            raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
+
+        # Filter out any key whose value is not a dict
+        return {key: value for key, value in data.items() if isinstance(value, dict)}
+
+    def get_spcs_container_runtimes(self) -> list[SpcsContainerRuntime]:
+        return [
+            entry.spcs_container_runtime for entry in self.root.values() if entry.spcs_container_runtime is not None
+        ]
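
A sketch of validating a runtime-environments payload with the new models; the payload shape is illustrative, inferred from the field aliases above:

from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict

payload = {
    "runtime_1.0": {
        "spcsContainerRuntime": {
            "pythonVersion": "3.10",
            "hardwareType": "CPU",
            "runtimeContainerImage": "st_plat/runtime/x86/runtime_image/snowbooks:1.6.2",
        }
    },
    "note": "non-dict entries like this string are dropped by the pre-validator",
}
envs = RuntimeEnvironmentsDict.model_validate(payload)
for rt in envs.get_spcs_container_runtimes():
    print(rt.python_version, rt.hardware_type)  # 3.10 CPU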
snowflake/ml/jobs/_utils/scripts/get_instance_ip.py:

@@ -41,18 +41,29 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
     from snowflake.runtime.utils import session_utils
 
     session = session_utils.get_session()
-
-    result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
+    result = session.sql(f"show service instances in service {service_name}").collect()
 
     if not result:
         return None
-
-    #
-
-
+    # we have already integrated with first_instance startup policy,
+    # the instance 0 is guaranteed to be the head instance
+    head_instance = next(
+        (
+            row
+            for row in result
+            if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
+        ),
+        None,
+    )
+    # fallback to find the first instance if the instance 0 is not found
+    if not head_instance:
+        # Sort by start_time first, then by instance_id. If start_time is null/empty, it will be sorted to the end.
+        sorted_instances = sorted(
+            result, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"]))
+        )
+        head_instance = sorted_instances[0]
     if not head_instance["instance_id"] or not head_instance["ip_address"]:
         return None
-
     # Validate head instance IP
     ip_address = head_instance["ip_address"]
     try:
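
A standalone sketch of the fallback ordering added above: rows with an empty start_time sort last, earlier start_time wins, and ties break by numeric instance_id:

rows = [
    {"instance_id": "2", "start_time": "2025-01-01T00:00:02"},
    {"instance_id": "1", "start_time": ""},
    {"instance_id": "3", "start_time": "2025-01-01T00:00:01"},
]
head = sorted(rows, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"])))[0]
print(head["instance_id"])  # 3 -- earliest non-empty start_time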
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py:

@@ -48,8 +48,8 @@ MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_IN
 TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
 
 # Fallbacks in case of SnowML version mismatch
+STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH")
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
 PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
 
 # Constants for the wait_for_instances function
@@ -57,6 +57,9 @@ MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1)  # seconds
 TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720)  # seconds
 CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10)  # seconds
 
+STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage")
+JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl")
+
 
 try:
     from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
@@ -173,10 +176,10 @@ def wait_for_instances(
 
     start_time = time.time()
     current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
-    logger.
+    logger.info(
         "Waiting for instances to be ready "
-        "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
-            min_instances, target_instances, timeout, check_interval
+        "(min_instances={}, target_instances={}, min_wait_time={}s, timeout={}s, max_check_interval={}s)".format(
+            min_instances, target_instances, min_wait_time, timeout, check_interval
        )
     )
 
@@ -226,6 +229,8 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
     # but user scripts are in the payload directory and may import from each other
     payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+    if payload_dir and not os.path.isabs(payload_dir):
+        payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir)
     if payload_dir and payload_dir not in sys.path:
         sys.path.insert(0, payload_dir)
 
@@ -276,7 +281,10 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
         Exception: Re-raises any exception caught during script execution.
     """
     # Ensure the output directory exists before trying to write result files.
-
+    result_abs_path = (
+        JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
+    )
+    output_dir = os.path.dirname(result_abs_path)
     os.makedirs(output_dir, exist_ok=True)
 
     try:
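
The launcher now resolves the result path against the stage mount unless an absolute path was supplied, which keeps older clients that set an absolute MLRS_RESULT_PATH working. A standalone sketch:

import os

STAGE_MOUNT_PATH = "/mnt/job_stage"
for job_result_path in ("output/mljob_result.pkl", "/custom/result.pkl"):
    print(job_result_path if os.path.isabs(job_result_path) else os.path.join(STAGE_MOUNT_PATH, job_result_path))
# /mnt/job_stage/output/mljob_result.pkl
# /custom/result.pkl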
@@ -317,7 +325,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     result_dict = result_obj.to_dict()
     try:
         # Serialize result using cloudpickle
-        result_pickle_path =
+        result_pickle_path = result_abs_path
         with open(result_pickle_path, "wb") as f:
             cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
     except Exception as pkl_exc:
@@ -326,7 +334,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
         try:
             # Serialize result to JSON as fallback path in case of cross version incompatibility
             # TODO: Manually convert non-serializable types to strings
-            result_json_path = os.path.splitext(
+            result_json_path = os.path.splitext(result_abs_path)[0] + ".json"
             with open(result_json_path, "w") as f:
                 json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
         except Exception as json_exc: