snowflake-ml-python 1.11.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/constants.py +8 -16
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +19 -5
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +12 -4
- snowflake/ml/jobs/_utils/spec_utils.py +4 -6
- snowflake/ml/jobs/_utils/types.py +2 -1
- snowflake/ml/jobs/job.py +33 -17
- snowflake/ml/jobs/manager.py +107 -12
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +61 -65
- snowflake/ml/model/_client/ops/service_ops.py +73 -154
- snowflake/ml/model/_client/service/model_deployment_spec.py +20 -37
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +14 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +66 -5
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +192 -188
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import logging
|
|
4
5
|
import re
|
|
5
6
|
import warnings
|
|
6
7
|
from collections import OrderedDict
|
|
@@ -31,10 +32,12 @@ from snowflake.snowpark.types import (
|
|
|
31
32
|
_NumericType,
|
|
32
33
|
)
|
|
33
34
|
|
|
35
|
+
_DEFAULT_TARGET_LAG = "10 seconds"
|
|
34
36
|
_FEATURE_VIEW_NAME_DELIMITER = "$"
|
|
35
37
|
_LEGACY_TIMESTAMP_COL_PLACEHOLDER_VALS = ["FS_TIMESTAMP_COL_PLACEHOLDER_VAL", "NULL"]
|
|
36
38
|
_TIMESTAMP_COL_PLACEHOLDER = "NULL"
|
|
37
39
|
_FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
|
|
40
|
+
_ONLINE_TABLE_SUFFIX = "$ONLINE"
|
|
38
41
|
# Feature view version rule is aligned with dataset version rule in SQL.
|
|
39
42
|
_FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$")
|
|
40
43
|
_FEATURE_VIEW_VERSION_MAX_LENGTH = 128
|
|
@@ -45,6 +48,44 @@ _RESULT_SCAN_QUERY_PATTERN = re.compile(
|
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
|
|
51
|
+
@dataclass(frozen=True)
|
|
52
|
+
class OnlineConfig:
|
|
53
|
+
"""Configuration for online feature storage."""
|
|
54
|
+
|
|
55
|
+
enable: bool = False
|
|
56
|
+
target_lag: Optional[str] = None
|
|
57
|
+
|
|
58
|
+
def __post_init__(self) -> None:
|
|
59
|
+
if self.target_lag is None:
|
|
60
|
+
return
|
|
61
|
+
if not isinstance(self.target_lag, str) or not self.target_lag.strip():
|
|
62
|
+
raise ValueError("target_lag must be a non-empty string")
|
|
63
|
+
|
|
64
|
+
object.__setattr__(self, "target_lag", self.target_lag.strip())
|
|
65
|
+
|
|
66
|
+
def to_json(self) -> str:
|
|
67
|
+
data: dict[str, Any] = asdict(self)
|
|
68
|
+
return json.dumps(data)
|
|
69
|
+
|
|
70
|
+
@classmethod
|
|
71
|
+
def from_json(cls, json_str: str) -> OnlineConfig:
|
|
72
|
+
data = json.loads(json_str)
|
|
73
|
+
return cls(**data)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class StoreType(Enum):
|
|
77
|
+
"""
|
|
78
|
+
Enumeration for specifying the storage type when reading from or refreshing feature views.
|
|
79
|
+
|
|
80
|
+
The Feature View supports two storage modes:
|
|
81
|
+
- OFFLINE: Traditional batch storage for historical feature data and training
|
|
82
|
+
- ONLINE: Low-latency storage optimized for real-time feature serving
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
ONLINE = "online"
|
|
86
|
+
OFFLINE = "offline"
|
|
87
|
+
|
|
88
|
+
|
|
48
89
|
@dataclass(frozen=True)
|
|
49
90
|
class _FeatureViewMetadata:
|
|
50
91
|
"""Represent metadata tracked on top of FV backend object"""
|
|
@@ -171,6 +212,7 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
171
212
|
initialize: str = "ON_CREATE",
|
|
172
213
|
refresh_mode: str = "AUTO",
|
|
173
214
|
cluster_by: Optional[list[str]] = None,
|
|
215
|
+
online_config: Optional[OnlineConfig] = None,
|
|
174
216
|
**_kwargs: Any,
|
|
175
217
|
) -> None:
|
|
176
218
|
"""
|
|
@@ -204,6 +246,8 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
204
246
|
cluster_by: Columns to cluster the feature view by.
|
|
205
247
|
- Defaults to the join keys from entities.
|
|
206
248
|
- If `timestamp_col` is provided, it is added to the default clustering keys.
|
|
249
|
+
online_config: Optional configuration for online storage. If provided with enable=True,
|
|
250
|
+
online storage will be enabled. Defaults to None (no online storage).
|
|
207
251
|
_kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
|
|
208
252
|
|
|
209
253
|
Example::
|
|
@@ -227,9 +271,26 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
227
271
|
>>> registered_fv = fs.register_feature_view(draft_fv, "v1")
|
|
228
272
|
>>> print(registered_fv.status)
|
|
229
273
|
FeatureViewStatus.ACTIVE
|
|
274
|
+
<BLANKLINE>
|
|
275
|
+
>>> # Example with online configuration for online feature storage
|
|
276
|
+
>>> config = OnlineConfig(enable=True, target_lag='15s')
|
|
277
|
+
>>> online_fv = FeatureView(
|
|
278
|
+
... name="my_online_fv",
|
|
279
|
+
... entities=[e1, e2],
|
|
280
|
+
... feature_df=feature_df,
|
|
281
|
+
... timestamp_col='TS',
|
|
282
|
+
... refresh_freq='1d',
|
|
283
|
+
... desc='Feature view with online storage',
|
|
284
|
+
... online_config=config # optional, enables online feature storage
|
|
285
|
+
... )
|
|
286
|
+
>>> registered_online_fv = fs.register_feature_view(online_fv, "v1")
|
|
287
|
+
>>> print(registered_online_fv.online)
|
|
288
|
+
True
|
|
230
289
|
|
|
231
290
|
# noqa: DAR401
|
|
232
291
|
"""
|
|
292
|
+
if online_config is not None:
|
|
293
|
+
logging.warning("'online_config' is in private preview since 1.12.0. Do not use it in production.")
|
|
233
294
|
|
|
234
295
|
self._name: SqlIdentifier = SqlIdentifier(name)
|
|
235
296
|
self._entities: list[Entity] = entities
|
|
@@ -257,6 +318,7 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
257
318
|
self._cluster_by: list[SqlIdentifier] = (
|
|
258
319
|
[SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
|
|
259
320
|
)
|
|
321
|
+
self._online_config: Optional[OnlineConfig] = online_config
|
|
260
322
|
|
|
261
323
|
# Validate kwargs
|
|
262
324
|
if _kwargs:
|
|
@@ -470,6 +532,31 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
470
532
|
def feature_descs(self) -> Optional[dict[SqlIdentifier, str]]:
|
|
471
533
|
return self._feature_desc
|
|
472
534
|
|
|
535
|
+
@property
|
|
536
|
+
def online(self) -> bool:
|
|
537
|
+
return self._online_config.enable if self._online_config else False
|
|
538
|
+
|
|
539
|
+
@property
|
|
540
|
+
def online_config(self) -> Optional[OnlineConfig]:
|
|
541
|
+
return self._online_config
|
|
542
|
+
|
|
543
|
+
def fully_qualified_online_table_name(self) -> str:
|
|
544
|
+
"""Get the fully qualified name for the online feature table.
|
|
545
|
+
|
|
546
|
+
Returns:
|
|
547
|
+
The fully qualified name (<database_name>.<schema_name>.<online_table_name>) for the
|
|
548
|
+
online feature table in Snowflake.
|
|
549
|
+
|
|
550
|
+
Raises:
|
|
551
|
+
RuntimeError: if the FeatureView is not registered or not configured for online storage.
|
|
552
|
+
"""
|
|
553
|
+
if self.status == FeatureViewStatus.DRAFT or self.version is None:
|
|
554
|
+
raise RuntimeError(f"FeatureView {self.name} has not been registered.")
|
|
555
|
+
if not self.online:
|
|
556
|
+
raise RuntimeError(f"FeatureView {self.name} is not configured for online storage.")
|
|
557
|
+
online_table_name = self._get_online_table_name(self.name, self.version)
|
|
558
|
+
return f"{self._database}.{self._schema}.{online_table_name}"
|
|
559
|
+
|
|
473
560
|
def list_columns(self) -> DataFrame:
|
|
474
561
|
"""List all columns and their information.
|
|
475
562
|
|
|
@@ -756,6 +843,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
756
843
|
feature_desc_dict[k.identifier()] = v
|
|
757
844
|
fv_dict["_feature_desc"] = feature_desc_dict
|
|
758
845
|
|
|
846
|
+
fv_dict["_online_config"] = self._online_config.to_json() if self._online_config is not None else None
|
|
847
|
+
|
|
759
848
|
lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
|
|
760
849
|
|
|
761
850
|
for key in lineage_node_keys:
|
|
@@ -844,6 +933,9 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
844
933
|
owner=json_dict["_owner"],
|
|
845
934
|
infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)),
|
|
846
935
|
session=session,
|
|
936
|
+
online_config=OnlineConfig.from_json(json_dict["_online_config"])
|
|
937
|
+
if json_dict.get("_online_config")
|
|
938
|
+
else None,
|
|
847
939
|
)
|
|
848
940
|
|
|
849
941
|
def _get_compact_repr(self) -> _CompactRepresentation:
|
|
@@ -916,6 +1008,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
916
1008
|
infer_schema_df: Optional[DataFrame],
|
|
917
1009
|
session: Session,
|
|
918
1010
|
cluster_by: Optional[list[str]] = None,
|
|
1011
|
+
online_config: Optional[OnlineConfig] = None,
|
|
919
1012
|
) -> FeatureView:
|
|
920
1013
|
fv = FeatureView(
|
|
921
1014
|
name=name,
|
|
@@ -925,6 +1018,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
925
1018
|
desc=desc,
|
|
926
1019
|
_infer_schema_df=infer_schema_df,
|
|
927
1020
|
cluster_by=cluster_by,
|
|
1021
|
+
online_config=online_config,
|
|
928
1022
|
)
|
|
929
1023
|
fv._version = FeatureViewVersion(version) if version is not None else None
|
|
930
1024
|
fv._status = status
|
|
@@ -961,5 +1055,33 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
961
1055
|
|
|
962
1056
|
return default_cluster_by_cols
|
|
963
1057
|
|
|
1058
|
+
@staticmethod
|
|
1059
|
+
def _get_online_table_name(
|
|
1060
|
+
feature_view_name: Union[SqlIdentifier, str], version: Optional[Union[FeatureViewVersion, str]] = None
|
|
1061
|
+
) -> SqlIdentifier:
|
|
1062
|
+
"""Get the online feature table name without qualification.
|
|
1063
|
+
|
|
1064
|
+
Args:
|
|
1065
|
+
feature_view_name: Offline feature view name.
|
|
1066
|
+
version: Feature view version. If not provided, feature_view_name must be a SqlIdentifier.
|
|
1067
|
+
|
|
1068
|
+
Returns:
|
|
1069
|
+
The online table name SqlIdentifier
|
|
1070
|
+
"""
|
|
1071
|
+
if version is None:
|
|
1072
|
+
assert isinstance(feature_view_name, SqlIdentifier), "Single argument must be SqlIdentifier"
|
|
1073
|
+
online_name = f"{feature_view_name.resolved()}{_ONLINE_TABLE_SUFFIX}"
|
|
1074
|
+
return SqlIdentifier(online_name, case_sensitive=True)
|
|
1075
|
+
else:
|
|
1076
|
+
fv_name = (
|
|
1077
|
+
feature_view_name
|
|
1078
|
+
if isinstance(feature_view_name, SqlIdentifier)
|
|
1079
|
+
else SqlIdentifier(feature_view_name, case_sensitive=True)
|
|
1080
|
+
)
|
|
1081
|
+
fv_version = version if isinstance(version, FeatureViewVersion) else FeatureViewVersion(version)
|
|
1082
|
+
physical_name = FeatureView._get_physical_name(fv_name, fv_version).resolved()
|
|
1083
|
+
online_name = f"{physical_name}{_ONLINE_TABLE_SUFFIX}"
|
|
1084
|
+
return SqlIdentifier(online_name, case_sensitive=True)
|
|
1085
|
+
|
|
964
1086
|
|
|
965
1087
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
|
|
@@ -3,26 +3,23 @@ from snowflake.ml.jobs._utils.types import ComputeResources
|
|
|
3
3
|
|
|
4
4
|
# SPCS specification constants
|
|
5
5
|
DEFAULT_CONTAINER_NAME = "main"
|
|
6
|
+
MEMORY_VOLUME_NAME = "dshm"
|
|
7
|
+
STAGE_VOLUME_NAME = "stage-volume"
|
|
8
|
+
|
|
9
|
+
# Environment variables
|
|
10
|
+
STAGE_MOUNT_PATH_ENV_VAR = "MLRS_STAGE_MOUNT_PATH"
|
|
6
11
|
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
|
7
12
|
RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
|
|
8
13
|
MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
|
|
9
14
|
TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
|
|
10
15
|
RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
|
|
11
|
-
MEMORY_VOLUME_NAME = "dshm"
|
|
12
|
-
STAGE_VOLUME_NAME = "stage-volume"
|
|
13
|
-
# Base mount path
|
|
14
|
-
STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
|
|
15
16
|
|
|
16
|
-
# Stage
|
|
17
|
+
# Stage mount paths
|
|
18
|
+
STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
|
|
17
19
|
APP_STAGE_SUBPATH = "app"
|
|
18
20
|
SYSTEM_STAGE_SUBPATH = "system"
|
|
19
21
|
OUTPUT_STAGE_SUBPATH = "output"
|
|
20
|
-
|
|
21
|
-
# Complete mount paths (automatically generated from base + subpath)
|
|
22
|
-
APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
|
|
23
|
-
SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
|
|
24
|
-
OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
|
|
25
|
-
|
|
22
|
+
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"
|
|
26
23
|
|
|
27
24
|
# Default container image information
|
|
28
25
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
|
@@ -59,8 +56,6 @@ ENABLE_HEALTH_CHECKS = "false"
|
|
|
59
56
|
JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
|
|
60
57
|
JOB_POLL_MAX_DELAY_SECONDS = 30
|
|
61
58
|
|
|
62
|
-
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
|
|
63
|
-
|
|
64
59
|
# Log start and end messages
|
|
65
60
|
LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
|
|
66
61
|
LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
|
|
@@ -98,6 +93,3 @@ CLOUD_INSTANCE_FAMILIES = {
|
|
|
98
93
|
SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
|
|
99
94
|
SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
|
|
100
95
|
}
|
|
101
|
-
|
|
102
|
-
# runtime version environment variable
|
|
103
|
-
ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FeatureFlags(Enum):
|
|
6
|
+
USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
|
|
7
|
+
ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
|
|
8
|
+
|
|
9
|
+
def is_enabled(self) -> bool:
|
|
10
|
+
return os.getenv(self.value, "false").lower() == "true"
|
|
11
|
+
|
|
12
|
+
def is_disabled(self) -> bool:
|
|
13
|
+
return not self.is_enabled()
|
|
14
|
+
|
|
15
|
+
def __str__(self) -> str:
|
|
16
|
+
return self.value
|
|
@@ -60,7 +60,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
|
60
60
|
|
|
61
61
|
# Change directory to user payload directory
|
|
62
62
|
if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
|
|
63
|
-
cd ${constants.PAYLOAD_DIR_ENV_VAR}
|
|
63
|
+
cd ${constants.STAGE_MOUNT_PATH_ENV_VAR}/${constants.PAYLOAD_DIR_ENV_VAR}
|
|
64
64
|
fi
|
|
65
65
|
|
|
66
66
|
##### Set up Python environment #####
|
|
@@ -69,7 +69,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
|
69
69
|
|
|
70
70
|
if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
|
|
71
71
|
echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
|
|
72
|
-
pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
|
|
72
|
+
if ! pip install --no-index -r $MLRS_SYSTEM_REQUIREMENTS_FILE; then
|
|
73
|
+
echo "Offline install failed, falling back to regular pip install"
|
|
74
|
+
pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
|
|
75
|
+
fi
|
|
73
76
|
fi
|
|
74
77
|
|
|
75
78
|
MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
|
|
@@ -535,19 +538,30 @@ class JobPayload:
|
|
|
535
538
|
|
|
536
539
|
upload_system_resources(session, system_stage_path)
|
|
537
540
|
python_entrypoint: list[Union[str, PurePath]] = [
|
|
538
|
-
PurePath(
|
|
539
|
-
PurePath(
|
|
541
|
+
PurePath(constants.STAGE_VOLUME_MOUNT_PATH, constants.SYSTEM_STAGE_SUBPATH, "mljob_launcher.py"),
|
|
542
|
+
PurePath(
|
|
543
|
+
constants.STAGE_VOLUME_MOUNT_PATH,
|
|
544
|
+
constants.APP_STAGE_SUBPATH,
|
|
545
|
+
entrypoint.file_path.relative_to(source).as_posix(),
|
|
546
|
+
),
|
|
540
547
|
]
|
|
541
548
|
if entrypoint.main_func:
|
|
542
549
|
python_entrypoint += ["--script_main_func", entrypoint.main_func]
|
|
543
550
|
|
|
551
|
+
env_vars = {
|
|
552
|
+
constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
|
|
553
|
+
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
|
|
554
|
+
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
555
|
+
}
|
|
556
|
+
|
|
544
557
|
return types.UploadedPayload(
|
|
545
558
|
stage_path=stage_path,
|
|
546
559
|
entrypoint=[
|
|
547
560
|
"bash",
|
|
548
|
-
f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
|
|
561
|
+
f"{constants.STAGE_VOLUME_MOUNT_PATH}/{constants.SYSTEM_STAGE_SUBPATH}/{_STARTUP_SCRIPT_PATH}",
|
|
549
562
|
*python_entrypoint,
|
|
550
563
|
],
|
|
564
|
+
env_vars=env_vars,
|
|
551
565
|
)
|
|
552
566
|
|
|
553
567
|
|
|
@@ -41,18 +41,29 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
|
|
|
41
41
|
from snowflake.runtime.utils import session_utils
|
|
42
42
|
|
|
43
43
|
session = session_utils.get_session()
|
|
44
|
-
|
|
45
|
-
result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
|
|
44
|
+
result = session.sql(f"show service instances in service {service_name}").collect()
|
|
46
45
|
|
|
47
46
|
if not result:
|
|
48
47
|
return None
|
|
49
|
-
|
|
50
|
-
#
|
|
51
|
-
|
|
52
|
-
|
|
48
|
+
# we have already integrated with first_instance startup policy,
|
|
49
|
+
# the instance 0 is guaranteed to be the head instance
|
|
50
|
+
head_instance = next(
|
|
51
|
+
(
|
|
52
|
+
row
|
|
53
|
+
for row in result
|
|
54
|
+
if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
|
|
55
|
+
),
|
|
56
|
+
None,
|
|
57
|
+
)
|
|
58
|
+
# fallback to find the first instance if the instance 0 is not found
|
|
59
|
+
if not head_instance:
|
|
60
|
+
# Sort by start_time first, then by instance_id. If start_time is null/empty, it will be sorted to the end.
|
|
61
|
+
sorted_instances = sorted(
|
|
62
|
+
result, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"]))
|
|
63
|
+
)
|
|
64
|
+
head_instance = sorted_instances[0]
|
|
53
65
|
if not head_instance["instance_id"] or not head_instance["ip_address"]:
|
|
54
66
|
return None
|
|
55
|
-
|
|
56
67
|
# Validate head instance IP
|
|
57
68
|
ip_address = head_instance["ip_address"]
|
|
58
69
|
try:
|
|
@@ -48,8 +48,8 @@ MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_IN
|
|
|
48
48
|
TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
|
|
49
49
|
|
|
50
50
|
# Fallbacks in case of SnowML version mismatch
|
|
51
|
+
STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH")
|
|
51
52
|
RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
|
|
52
|
-
JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
|
|
53
53
|
PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
|
|
54
54
|
|
|
55
55
|
# Constants for the wait_for_instances function
|
|
@@ -57,6 +57,9 @@ MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1) # seconds
|
|
|
57
57
|
TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720) # seconds
|
|
58
58
|
CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10) # seconds
|
|
59
59
|
|
|
60
|
+
STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage")
|
|
61
|
+
JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl")
|
|
62
|
+
|
|
60
63
|
|
|
61
64
|
try:
|
|
62
65
|
from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
|
|
@@ -226,6 +229,8 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
226
229
|
# This is needed because mljob_launcher.py is now in /mnt/job_stage/system
|
|
227
230
|
# but user scripts are in the payload directory and may import from each other
|
|
228
231
|
payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
|
|
232
|
+
if payload_dir and not os.path.isabs(payload_dir):
|
|
233
|
+
payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir)
|
|
229
234
|
if payload_dir and payload_dir not in sys.path:
|
|
230
235
|
sys.path.insert(0, payload_dir)
|
|
231
236
|
|
|
@@ -276,7 +281,10 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
276
281
|
Exception: Re-raises any exception caught during script execution.
|
|
277
282
|
"""
|
|
278
283
|
# Ensure the output directory exists before trying to write result files.
|
|
279
|
-
|
|
284
|
+
result_abs_path = (
|
|
285
|
+
JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
|
|
286
|
+
)
|
|
287
|
+
output_dir = os.path.dirname(result_abs_path)
|
|
280
288
|
os.makedirs(output_dir, exist_ok=True)
|
|
281
289
|
|
|
282
290
|
try:
|
|
@@ -317,7 +325,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
317
325
|
result_dict = result_obj.to_dict()
|
|
318
326
|
try:
|
|
319
327
|
# Serialize result using cloudpickle
|
|
320
|
-
result_pickle_path =
|
|
328
|
+
result_pickle_path = result_abs_path
|
|
321
329
|
with open(result_pickle_path, "wb") as f:
|
|
322
330
|
cloudpickle.dump(result_dict, f) # Pickle dictionary form for compatibility
|
|
323
331
|
except Exception as pkl_exc:
|
|
@@ -326,7 +334,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
326
334
|
try:
|
|
327
335
|
# Serialize result to JSON as fallback path in case of cross version incompatibility
|
|
328
336
|
# TODO: Manually convert non-serializable types to strings
|
|
329
|
-
result_json_path = os.path.splitext(
|
|
337
|
+
result_json_path = os.path.splitext(result_abs_path)[0] + ".json"
|
|
330
338
|
with open(result_json_path, "w") as f:
|
|
331
339
|
json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
|
|
332
340
|
except Exception as json_exc:
|
|
@@ -7,7 +7,7 @@ from typing import Any, Literal, Optional, Union
|
|
|
7
7
|
|
|
8
8
|
from snowflake import snowpark
|
|
9
9
|
from snowflake.ml._internal.utils import snowflake_env
|
|
10
|
-
from snowflake.ml.jobs._utils import constants, query_helper, types
|
|
10
|
+
from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
|
|
11
11
|
from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
|
|
12
12
|
|
|
13
13
|
|
|
@@ -63,7 +63,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
|
|
|
63
63
|
# Use MLRuntime image
|
|
64
64
|
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
65
65
|
container_image = None
|
|
66
|
-
if
|
|
66
|
+
if feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
|
|
67
67
|
container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type]
|
|
68
68
|
|
|
69
69
|
if not container_image:
|
|
@@ -98,6 +98,7 @@ def generate_spec_overrides(
|
|
|
98
98
|
container_spec: dict[str, Any] = {
|
|
99
99
|
"name": constants.DEFAULT_CONTAINER_NAME,
|
|
100
100
|
}
|
|
101
|
+
|
|
101
102
|
if environment_vars:
|
|
102
103
|
# TODO: Validate environment variables
|
|
103
104
|
container_spec["env"] = environment_vars
|
|
@@ -213,10 +214,7 @@ def generate_service_spec(
|
|
|
213
214
|
|
|
214
215
|
# TODO: Add hooks for endpoints for integration with TensorBoard etc
|
|
215
216
|
|
|
216
|
-
env_vars =
|
|
217
|
-
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
|
|
218
|
-
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
219
|
-
}
|
|
217
|
+
env_vars = payload.env_vars
|
|
220
218
|
endpoints: list[dict[str, Any]] = []
|
|
221
219
|
|
|
222
220
|
if target_instances > 1:
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
3
|
from pathlib import PurePath
|
|
4
4
|
from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
|
|
5
5
|
|
|
@@ -90,6 +90,7 @@ class UploadedPayload:
|
|
|
90
90
|
# TODO: Include manifest of payload files for validation
|
|
91
91
|
stage_path: PurePath
|
|
92
92
|
entrypoint: list[Union[str, PurePath]]
|
|
93
|
+
env_vars: dict[str, str] = field(default_factory=dict)
|
|
93
94
|
|
|
94
95
|
|
|
95
96
|
@dataclass(frozen=True)
|
snowflake/ml/jobs/job.py
CHANGED
|
@@ -99,21 +99,23 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
99
99
|
result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
|
|
100
100
|
if result_path_str is None:
|
|
101
101
|
raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
|
|
102
|
-
volume_mounts = self._container_spec["volumeMounts"]
|
|
103
|
-
stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
|
|
104
102
|
|
|
103
|
+
# If result path is relative, it is relative to the stage mount path
|
|
105
104
|
result_path = Path(result_path_str)
|
|
105
|
+
if not result_path.is_absolute():
|
|
106
|
+
return f"{self._stage_path}/{result_path.as_posix()}"
|
|
107
|
+
|
|
108
|
+
# If result path is absolute, it must be located under the stage mount path
|
|
109
|
+
volume_mounts = self._container_spec["volumeMounts"]
|
|
110
|
+
stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
|
|
106
111
|
stage_mount = Path(stage_mount_str)
|
|
107
112
|
try:
|
|
108
113
|
relative_path = result_path.relative_to(stage_mount)
|
|
114
|
+
return f"{self._stage_path}/{relative_path.as_posix()}"
|
|
109
115
|
except ValueError:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
)
|
|
114
|
-
relative_path = result_path
|
|
115
|
-
|
|
116
|
-
return f"{self._stage_path}/{relative_path.as_posix()}"
|
|
116
|
+
raise ValueError(
|
|
117
|
+
f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
|
|
118
|
+
)
|
|
117
119
|
|
|
118
120
|
@overload
|
|
119
121
|
def get_logs(
|
|
@@ -419,15 +421,29 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
|
419
421
|
if not rows:
|
|
420
422
|
return None
|
|
421
423
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
+
# we have already integrated with first_instance startup policy,
|
|
425
|
+
# the instance 0 is guaranteed to be the head instance
|
|
426
|
+
head_instance = next(
|
|
427
|
+
(
|
|
428
|
+
row
|
|
429
|
+
for row in rows
|
|
430
|
+
if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
|
|
431
|
+
),
|
|
432
|
+
None,
|
|
433
|
+
)
|
|
434
|
+
# Fall back to the earliest-started instance if instance 0 is not found
|
|
435
|
+
if not head_instance:
|
|
436
|
+
if target_instances > len(rows):
|
|
437
|
+
raise RuntimeError(
|
|
438
|
+
f"Couldn’t retrieve head instance due to missing instances. {target_instances} > {len(rows)}"
|
|
439
|
+
)
|
|
440
|
+
# Sort by start_time first, then by instance_id
|
|
441
|
+
try:
|
|
442
|
+
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
|
443
|
+
except TypeError:
|
|
444
|
+
raise RuntimeError("Job instance information unavailable.")
|
|
445
|
+
head_instance = sorted_instances[0]
|
|
424
446
|
|
|
425
|
-
# Sort by start_time first, then by instance_id
|
|
426
|
-
try:
|
|
427
|
-
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
|
428
|
-
except TypeError:
|
|
429
|
-
raise RuntimeError("Job instance information unavailable.")
|
|
430
|
-
head_instance = sorted_instances[0]
|
|
431
447
|
if not head_instance["start_time"]:
|
|
432
448
|
# If head instance hasn't started yet, return None
|
|
433
449
|
return None
|