snowflake-ml-python 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/telemetry.py +3 -1
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/experiment_tracking.py +113 -6
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/constants.py +8 -16
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +19 -5
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +23 -5
- snowflake/ml/jobs/_utils/spec_utils.py +4 -6
- snowflake/ml/jobs/_utils/types.py +2 -1
- snowflake/ml/jobs/job.py +38 -19
- snowflake/ml/jobs/manager.py +136 -19
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +25 -0
- snowflake/ml/model/_client/model/model_version_impl.py +62 -65
- snowflake/ml/model/_client/ops/model_ops.py +42 -9
- snowflake/ml/model/_client/ops/service_ops.py +75 -154
- snowflake/ml/model/_client/service/model_deployment_spec.py +23 -37
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +15 -4
- snowflake/ml/model/_client/sql/service.py +4 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +309 -22
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/models/huggingface_pipeline.py +23 -0
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/METADATA +82 -5
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/RECORD +198 -194
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import logging
|
|
4
5
|
import re
|
|
5
6
|
import warnings
|
|
6
7
|
from collections import OrderedDict
|
|
@@ -31,10 +32,12 @@ from snowflake.snowpark.types import (
|
|
|
31
32
|
_NumericType,
|
|
32
33
|
)
|
|
33
34
|
|
|
35
|
+
_DEFAULT_TARGET_LAG = "10 seconds"
|
|
34
36
|
_FEATURE_VIEW_NAME_DELIMITER = "$"
|
|
35
37
|
_LEGACY_TIMESTAMP_COL_PLACEHOLDER_VALS = ["FS_TIMESTAMP_COL_PLACEHOLDER_VAL", "NULL"]
|
|
36
38
|
_TIMESTAMP_COL_PLACEHOLDER = "NULL"
|
|
37
39
|
_FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
|
|
40
|
+
_ONLINE_TABLE_SUFFIX = "$ONLINE"
|
|
38
41
|
# Feature view version rule is aligned with dataset version rule in SQL.
|
|
39
42
|
_FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$")
|
|
40
43
|
_FEATURE_VIEW_VERSION_MAX_LENGTH = 128
|
|
@@ -45,6 +48,44 @@ _RESULT_SCAN_QUERY_PATTERN = re.compile(
|
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
|
|
51
|
+
@dataclass(frozen=True)
|
|
52
|
+
class OnlineConfig:
|
|
53
|
+
"""Configuration for online feature storage."""
|
|
54
|
+
|
|
55
|
+
enable: bool = False
|
|
56
|
+
target_lag: Optional[str] = None
|
|
57
|
+
|
|
58
|
+
def __post_init__(self) -> None:
|
|
59
|
+
if self.target_lag is None:
|
|
60
|
+
return
|
|
61
|
+
if not isinstance(self.target_lag, str) or not self.target_lag.strip():
|
|
62
|
+
raise ValueError("target_lag must be a non-empty string")
|
|
63
|
+
|
|
64
|
+
object.__setattr__(self, "target_lag", self.target_lag.strip())
|
|
65
|
+
|
|
66
|
+
def to_json(self) -> str:
|
|
67
|
+
data: dict[str, Any] = asdict(self)
|
|
68
|
+
return json.dumps(data)
|
|
69
|
+
|
|
70
|
+
@classmethod
|
|
71
|
+
def from_json(cls, json_str: str) -> OnlineConfig:
|
|
72
|
+
data = json.loads(json_str)
|
|
73
|
+
return cls(**data)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class StoreType(Enum):
|
|
77
|
+
"""
|
|
78
|
+
Enumeration for specifying the storage type when reading from or refreshing feature views.
|
|
79
|
+
|
|
80
|
+
The Feature View supports two storage modes:
|
|
81
|
+
- OFFLINE: Traditional batch storage for historical feature data and training
|
|
82
|
+
- ONLINE: Low-latency storage optimized for real-time feature serving
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
ONLINE = "online"
|
|
86
|
+
OFFLINE = "offline"
|
|
87
|
+
|
|
88
|
+
|
|
48
89
|
@dataclass(frozen=True)
|
|
49
90
|
class _FeatureViewMetadata:
|
|
50
91
|
"""Represent metadata tracked on top of FV backend object"""
|
|
@@ -171,6 +212,7 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
171
212
|
initialize: str = "ON_CREATE",
|
|
172
213
|
refresh_mode: str = "AUTO",
|
|
173
214
|
cluster_by: Optional[list[str]] = None,
|
|
215
|
+
online_config: Optional[OnlineConfig] = None,
|
|
174
216
|
**_kwargs: Any,
|
|
175
217
|
) -> None:
|
|
176
218
|
"""
|
|
@@ -204,6 +246,8 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
204
246
|
cluster_by: Columns to cluster the feature view by.
|
|
205
247
|
- Defaults to the join keys from entities.
|
|
206
248
|
- If `timestamp_col` is provided, it is added to the default clustering keys.
|
|
249
|
+
online_config: Optional configuration for online storage. If provided with enable=True,
|
|
250
|
+
online storage will be enabled. Defaults to None (no online storage).
|
|
207
251
|
_kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
|
|
208
252
|
|
|
209
253
|
Example::
|
|
@@ -227,9 +271,26 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
227
271
|
>>> registered_fv = fs.register_feature_view(draft_fv, "v1")
|
|
228
272
|
>>> print(registered_fv.status)
|
|
229
273
|
FeatureViewStatus.ACTIVE
|
|
274
|
+
<BLANKLINE>
|
|
275
|
+
>>> # Example with online configuration for online feature storage
|
|
276
|
+
>>> config = OnlineConfig(enable=True, target_lag='15s')
|
|
277
|
+
>>> online_fv = FeatureView(
|
|
278
|
+
... name="my_online_fv",
|
|
279
|
+
... entities=[e1, e2],
|
|
280
|
+
... feature_df=feature_df,
|
|
281
|
+
... timestamp_col='TS',
|
|
282
|
+
... refresh_freq='1d',
|
|
283
|
+
... desc='Feature view with online storage',
|
|
284
|
+
... online_config=config # optional, enables online feature storage
|
|
285
|
+
... )
|
|
286
|
+
>>> registered_online_fv = fs.register_feature_view(online_fv, "v1")
|
|
287
|
+
>>> print(registered_online_fv.online)
|
|
288
|
+
True
|
|
230
289
|
|
|
231
290
|
# noqa: DAR401
|
|
232
291
|
"""
|
|
292
|
+
if online_config is not None:
|
|
293
|
+
logging.warning("'online_config' is in private preview since 1.12.0. Do not use it in production.")
|
|
233
294
|
|
|
234
295
|
self._name: SqlIdentifier = SqlIdentifier(name)
|
|
235
296
|
self._entities: list[Entity] = entities
|
|
@@ -257,6 +318,7 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
257
318
|
self._cluster_by: list[SqlIdentifier] = (
|
|
258
319
|
[SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
|
|
259
320
|
)
|
|
321
|
+
self._online_config: Optional[OnlineConfig] = online_config
|
|
260
322
|
|
|
261
323
|
# Validate kwargs
|
|
262
324
|
if _kwargs:
|
|
@@ -470,6 +532,31 @@ class FeatureView(lineage_node.LineageNode):
|
|
|
470
532
|
def feature_descs(self) -> Optional[dict[SqlIdentifier, str]]:
|
|
471
533
|
return self._feature_desc
|
|
472
534
|
|
|
535
|
+
@property
|
|
536
|
+
def online(self) -> bool:
|
|
537
|
+
return self._online_config.enable if self._online_config else False
|
|
538
|
+
|
|
539
|
+
@property
|
|
540
|
+
def online_config(self) -> Optional[OnlineConfig]:
|
|
541
|
+
return self._online_config
|
|
542
|
+
|
|
543
|
+
def fully_qualified_online_table_name(self) -> str:
|
|
544
|
+
"""Get the fully qualified name for the online feature table.
|
|
545
|
+
|
|
546
|
+
Returns:
|
|
547
|
+
The fully qualified name (<database_name>.<schema_name>.<online_table_name>) for the
|
|
548
|
+
online feature table in Snowflake.
|
|
549
|
+
|
|
550
|
+
Raises:
|
|
551
|
+
RuntimeError: if the FeatureView is not registered or not configured for online storage.
|
|
552
|
+
"""
|
|
553
|
+
if self.status == FeatureViewStatus.DRAFT or self.version is None:
|
|
554
|
+
raise RuntimeError(f"FeatureView {self.name} has not been registered.")
|
|
555
|
+
if not self.online:
|
|
556
|
+
raise RuntimeError(f"FeatureView {self.name} is not configured for online storage.")
|
|
557
|
+
online_table_name = self._get_online_table_name(self.name, self.version)
|
|
558
|
+
return f"{self._database}.{self._schema}.{online_table_name}"
|
|
559
|
+
|
|
473
560
|
def list_columns(self) -> DataFrame:
|
|
474
561
|
"""List all columns and their information.
|
|
475
562
|
|
|
@@ -756,6 +843,8 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
756
843
|
feature_desc_dict[k.identifier()] = v
|
|
757
844
|
fv_dict["_feature_desc"] = feature_desc_dict
|
|
758
845
|
|
|
846
|
+
fv_dict["_online_config"] = self._online_config.to_json() if self._online_config is not None else None
|
|
847
|
+
|
|
759
848
|
lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
|
|
760
849
|
|
|
761
850
|
for key in lineage_node_keys:
|
|
@@ -844,6 +933,9 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
844
933
|
owner=json_dict["_owner"],
|
|
845
934
|
infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)),
|
|
846
935
|
session=session,
|
|
936
|
+
online_config=OnlineConfig.from_json(json_dict["_online_config"])
|
|
937
|
+
if json_dict.get("_online_config")
|
|
938
|
+
else None,
|
|
847
939
|
)
|
|
848
940
|
|
|
849
941
|
def _get_compact_repr(self) -> _CompactRepresentation:
|
|
@@ -916,6 +1008,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
916
1008
|
infer_schema_df: Optional[DataFrame],
|
|
917
1009
|
session: Session,
|
|
918
1010
|
cluster_by: Optional[list[str]] = None,
|
|
1011
|
+
online_config: Optional[OnlineConfig] = None,
|
|
919
1012
|
) -> FeatureView:
|
|
920
1013
|
fv = FeatureView(
|
|
921
1014
|
name=name,
|
|
@@ -925,6 +1018,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
925
1018
|
desc=desc,
|
|
926
1019
|
_infer_schema_df=infer_schema_df,
|
|
927
1020
|
cluster_by=cluster_by,
|
|
1021
|
+
online_config=online_config,
|
|
928
1022
|
)
|
|
929
1023
|
fv._version = FeatureViewVersion(version) if version is not None else None
|
|
930
1024
|
fv._status = status
|
|
@@ -961,5 +1055,33 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
|
961
1055
|
|
|
962
1056
|
return default_cluster_by_cols
|
|
963
1057
|
|
|
1058
|
+
@staticmethod
|
|
1059
|
+
def _get_online_table_name(
|
|
1060
|
+
feature_view_name: Union[SqlIdentifier, str], version: Optional[Union[FeatureViewVersion, str]] = None
|
|
1061
|
+
) -> SqlIdentifier:
|
|
1062
|
+
"""Get the online feature table name without qualification.
|
|
1063
|
+
|
|
1064
|
+
Args:
|
|
1065
|
+
feature_view_name: Offline feature view name.
|
|
1066
|
+
version: Feature view version. If not provided, feature_view_name must be a SqlIdentifier.
|
|
1067
|
+
|
|
1068
|
+
Returns:
|
|
1069
|
+
The online table name SqlIdentifier
|
|
1070
|
+
"""
|
|
1071
|
+
if version is None:
|
|
1072
|
+
assert isinstance(feature_view_name, SqlIdentifier), "Single argument must be SqlIdentifier"
|
|
1073
|
+
online_name = f"{feature_view_name.resolved()}{_ONLINE_TABLE_SUFFIX}"
|
|
1074
|
+
return SqlIdentifier(online_name, case_sensitive=True)
|
|
1075
|
+
else:
|
|
1076
|
+
fv_name = (
|
|
1077
|
+
feature_view_name
|
|
1078
|
+
if isinstance(feature_view_name, SqlIdentifier)
|
|
1079
|
+
else SqlIdentifier(feature_view_name, case_sensitive=True)
|
|
1080
|
+
)
|
|
1081
|
+
fv_version = version if isinstance(version, FeatureViewVersion) else FeatureViewVersion(version)
|
|
1082
|
+
physical_name = FeatureView._get_physical_name(fv_name, fv_version).resolved()
|
|
1083
|
+
online_name = f"{physical_name}{_ONLINE_TABLE_SUFFIX}"
|
|
1084
|
+
return SqlIdentifier(online_name, case_sensitive=True)
|
|
1085
|
+
|
|
964
1086
|
|
|
965
1087
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
|
|
@@ -3,26 +3,23 @@ from snowflake.ml.jobs._utils.types import ComputeResources
|
|
|
3
3
|
|
|
4
4
|
# SPCS specification constants
|
|
5
5
|
DEFAULT_CONTAINER_NAME = "main"
|
|
6
|
+
MEMORY_VOLUME_NAME = "dshm"
|
|
7
|
+
STAGE_VOLUME_NAME = "stage-volume"
|
|
8
|
+
|
|
9
|
+
# Environment variables
|
|
10
|
+
STAGE_MOUNT_PATH_ENV_VAR = "MLRS_STAGE_MOUNT_PATH"
|
|
6
11
|
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
|
7
12
|
RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
|
|
8
13
|
MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
|
|
9
14
|
TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
|
|
10
15
|
RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
|
|
11
|
-
MEMORY_VOLUME_NAME = "dshm"
|
|
12
|
-
STAGE_VOLUME_NAME = "stage-volume"
|
|
13
|
-
# Base mount path
|
|
14
|
-
STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
|
|
15
16
|
|
|
16
|
-
# Stage
|
|
17
|
+
# Stage mount paths
|
|
18
|
+
STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
|
|
17
19
|
APP_STAGE_SUBPATH = "app"
|
|
18
20
|
SYSTEM_STAGE_SUBPATH = "system"
|
|
19
21
|
OUTPUT_STAGE_SUBPATH = "output"
|
|
20
|
-
|
|
21
|
-
# Complete mount paths (automatically generated from base + subpath)
|
|
22
|
-
APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
|
|
23
|
-
SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
|
|
24
|
-
OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
|
|
25
|
-
|
|
22
|
+
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"
|
|
26
23
|
|
|
27
24
|
# Default container image information
|
|
28
25
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
|
@@ -59,8 +56,6 @@ ENABLE_HEALTH_CHECKS = "false"
|
|
|
59
56
|
JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
|
|
60
57
|
JOB_POLL_MAX_DELAY_SECONDS = 30
|
|
61
58
|
|
|
62
|
-
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
|
|
63
|
-
|
|
64
59
|
# Log start and end messages
|
|
65
60
|
LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
|
|
66
61
|
LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
|
|
@@ -98,6 +93,3 @@ CLOUD_INSTANCE_FAMILIES = {
|
|
|
98
93
|
SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
|
|
99
94
|
SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
|
|
100
95
|
}
|
|
101
|
-
|
|
102
|
-
# runtime version environment variable
|
|
103
|
-
ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FeatureFlags(Enum):
|
|
6
|
+
USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
|
|
7
|
+
ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
|
|
8
|
+
|
|
9
|
+
def is_enabled(self) -> bool:
|
|
10
|
+
return os.getenv(self.value, "false").lower() == "true"
|
|
11
|
+
|
|
12
|
+
def is_disabled(self) -> bool:
|
|
13
|
+
return not self.is_enabled()
|
|
14
|
+
|
|
15
|
+
def __str__(self) -> str:
|
|
16
|
+
return self.value
|
|
@@ -60,7 +60,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
|
60
60
|
|
|
61
61
|
# Change directory to user payload directory
|
|
62
62
|
if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
|
|
63
|
-
cd ${constants.PAYLOAD_DIR_ENV_VAR}
|
|
63
|
+
cd ${constants.STAGE_MOUNT_PATH_ENV_VAR}/${constants.PAYLOAD_DIR_ENV_VAR}
|
|
64
64
|
fi
|
|
65
65
|
|
|
66
66
|
##### Set up Python environment #####
|
|
@@ -69,7 +69,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
|
69
69
|
|
|
70
70
|
if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
|
|
71
71
|
echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
|
|
72
|
-
pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
|
|
72
|
+
if ! pip install --no-index -r $MLRS_SYSTEM_REQUIREMENTS_FILE; then
|
|
73
|
+
echo "Offline install failed, falling back to regular pip install"
|
|
74
|
+
pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
|
|
75
|
+
fi
|
|
73
76
|
fi
|
|
74
77
|
|
|
75
78
|
MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
|
|
@@ -535,19 +538,30 @@ class JobPayload:
|
|
|
535
538
|
|
|
536
539
|
upload_system_resources(session, system_stage_path)
|
|
537
540
|
python_entrypoint: list[Union[str, PurePath]] = [
|
|
538
|
-
PurePath(
|
|
539
|
-
PurePath(
|
|
541
|
+
PurePath(constants.STAGE_VOLUME_MOUNT_PATH, constants.SYSTEM_STAGE_SUBPATH, "mljob_launcher.py"),
|
|
542
|
+
PurePath(
|
|
543
|
+
constants.STAGE_VOLUME_MOUNT_PATH,
|
|
544
|
+
constants.APP_STAGE_SUBPATH,
|
|
545
|
+
entrypoint.file_path.relative_to(source).as_posix(),
|
|
546
|
+
),
|
|
540
547
|
]
|
|
541
548
|
if entrypoint.main_func:
|
|
542
549
|
python_entrypoint += ["--script_main_func", entrypoint.main_func]
|
|
543
550
|
|
|
551
|
+
env_vars = {
|
|
552
|
+
constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
|
|
553
|
+
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
|
|
554
|
+
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
555
|
+
}
|
|
556
|
+
|
|
544
557
|
return types.UploadedPayload(
|
|
545
558
|
stage_path=stage_path,
|
|
546
559
|
entrypoint=[
|
|
547
560
|
"bash",
|
|
548
|
-
f"{constants.
|
|
561
|
+
f"{constants.STAGE_VOLUME_MOUNT_PATH}/{constants.SYSTEM_STAGE_SUBPATH}/{_STARTUP_SCRIPT_PATH}",
|
|
549
562
|
*python_entrypoint,
|
|
550
563
|
],
|
|
564
|
+
env_vars=env_vars,
|
|
551
565
|
)
|
|
552
566
|
|
|
553
567
|
|
|
@@ -41,18 +41,29 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
|
|
|
41
41
|
from snowflake.runtime.utils import session_utils
|
|
42
42
|
|
|
43
43
|
session = session_utils.get_session()
|
|
44
|
-
|
|
45
|
-
result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
|
|
44
|
+
result = session.sql(f"show service instances in service {service_name}").collect()
|
|
46
45
|
|
|
47
46
|
if not result:
|
|
48
47
|
return None
|
|
49
|
-
|
|
50
|
-
#
|
|
51
|
-
|
|
52
|
-
|
|
48
|
+
# we have already integrated with first_instance startup policy,
|
|
49
|
+
# the instance 0 is guaranteed to be the head instance
|
|
50
|
+
head_instance = next(
|
|
51
|
+
(
|
|
52
|
+
row
|
|
53
|
+
for row in result
|
|
54
|
+
if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
|
|
55
|
+
),
|
|
56
|
+
None,
|
|
57
|
+
)
|
|
58
|
+
# fallback to find the first instance if the instance 0 is not found
|
|
59
|
+
if not head_instance:
|
|
60
|
+
# Sort by start_time first, then by instance_id. If start_time is null/empty, it will be sorted to the end.
|
|
61
|
+
sorted_instances = sorted(
|
|
62
|
+
result, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"]))
|
|
63
|
+
)
|
|
64
|
+
head_instance = sorted_instances[0]
|
|
53
65
|
if not head_instance["instance_id"] or not head_instance["ip_address"]:
|
|
54
66
|
return None
|
|
55
|
-
|
|
56
67
|
# Validate head instance IP
|
|
57
68
|
ip_address = head_instance["ip_address"]
|
|
58
69
|
try:
|
|
@@ -48,8 +48,8 @@ MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_IN
|
|
|
48
48
|
TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
|
|
49
49
|
|
|
50
50
|
# Fallbacks in case of SnowML version mismatch
|
|
51
|
+
STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH")
|
|
51
52
|
RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
|
|
52
|
-
JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
|
|
53
53
|
PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
|
|
54
54
|
|
|
55
55
|
# Constants for the wait_for_instances function
|
|
@@ -57,6 +57,9 @@ MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1) # seconds
|
|
|
57
57
|
TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720) # seconds
|
|
58
58
|
CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10) # seconds
|
|
59
59
|
|
|
60
|
+
STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage")
|
|
61
|
+
JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl")
|
|
62
|
+
|
|
60
63
|
|
|
61
64
|
try:
|
|
62
65
|
from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
|
|
@@ -226,12 +229,16 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
226
229
|
# This is needed because mljob_launcher.py is now in /mnt/job_stage/system
|
|
227
230
|
# but user scripts are in the payload directory and may import from each other
|
|
228
231
|
payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
|
|
232
|
+
if payload_dir and not os.path.isabs(payload_dir):
|
|
233
|
+
payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir)
|
|
229
234
|
if payload_dir and payload_dir not in sys.path:
|
|
230
235
|
sys.path.insert(0, payload_dir)
|
|
231
236
|
|
|
232
237
|
# Create a Snowpark session before running the script
|
|
233
238
|
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
234
|
-
|
|
239
|
+
config = SnowflakeLoginOptions()
|
|
240
|
+
config["client_session_keep_alive"] = "True"
|
|
241
|
+
session = Session.builder.configs(config).create() # noqa: F841
|
|
235
242
|
|
|
236
243
|
try:
|
|
237
244
|
|
|
@@ -259,6 +266,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
259
266
|
finally:
|
|
260
267
|
# Restore original sys.argv
|
|
261
268
|
sys.argv = original_argv
|
|
269
|
+
session.close()
|
|
262
270
|
|
|
263
271
|
|
|
264
272
|
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
|
|
@@ -276,9 +284,19 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
276
284
|
Exception: Re-raises any exception caught during script execution.
|
|
277
285
|
"""
|
|
278
286
|
# Ensure the output directory exists before trying to write result files.
|
|
279
|
-
|
|
287
|
+
result_abs_path = (
|
|
288
|
+
JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
|
|
289
|
+
)
|
|
290
|
+
output_dir = os.path.dirname(result_abs_path)
|
|
280
291
|
os.makedirs(output_dir, exist_ok=True)
|
|
281
292
|
|
|
293
|
+
try:
|
|
294
|
+
import ray
|
|
295
|
+
|
|
296
|
+
ray.init(address="auto")
|
|
297
|
+
except ModuleNotFoundError:
|
|
298
|
+
warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1)
|
|
299
|
+
|
|
282
300
|
try:
|
|
283
301
|
# Wait for minimum required instances if specified
|
|
284
302
|
min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
|
|
@@ -317,7 +335,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
317
335
|
result_dict = result_obj.to_dict()
|
|
318
336
|
try:
|
|
319
337
|
# Serialize result using cloudpickle
|
|
320
|
-
result_pickle_path =
|
|
338
|
+
result_pickle_path = result_abs_path
|
|
321
339
|
with open(result_pickle_path, "wb") as f:
|
|
322
340
|
cloudpickle.dump(result_dict, f) # Pickle dictionary form for compatibility
|
|
323
341
|
except Exception as pkl_exc:
|
|
@@ -326,7 +344,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
326
344
|
try:
|
|
327
345
|
# Serialize result to JSON as fallback path in case of cross version incompatibility
|
|
328
346
|
# TODO: Manually convert non-serializable types to strings
|
|
329
|
-
result_json_path = os.path.splitext(
|
|
347
|
+
result_json_path = os.path.splitext(result_abs_path)[0] + ".json"
|
|
330
348
|
with open(result_json_path, "w") as f:
|
|
331
349
|
json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
|
|
332
350
|
except Exception as json_exc:
|
|
@@ -7,7 +7,7 @@ from typing import Any, Literal, Optional, Union
|
|
|
7
7
|
|
|
8
8
|
from snowflake import snowpark
|
|
9
9
|
from snowflake.ml._internal.utils import snowflake_env
|
|
10
|
-
from snowflake.ml.jobs._utils import constants, query_helper, types
|
|
10
|
+
from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
|
|
11
11
|
from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
|
|
12
12
|
|
|
13
13
|
|
|
@@ -63,7 +63,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
|
|
|
63
63
|
# Use MLRuntime image
|
|
64
64
|
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
65
65
|
container_image = None
|
|
66
|
-
if
|
|
66
|
+
if feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
|
|
67
67
|
container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type]
|
|
68
68
|
|
|
69
69
|
if not container_image:
|
|
@@ -98,6 +98,7 @@ def generate_spec_overrides(
|
|
|
98
98
|
container_spec: dict[str, Any] = {
|
|
99
99
|
"name": constants.DEFAULT_CONTAINER_NAME,
|
|
100
100
|
}
|
|
101
|
+
|
|
101
102
|
if environment_vars:
|
|
102
103
|
# TODO: Validate environment variables
|
|
103
104
|
container_spec["env"] = environment_vars
|
|
@@ -213,10 +214,7 @@ def generate_service_spec(
|
|
|
213
214
|
|
|
214
215
|
# TODO: Add hooks for endpoints for integration with TensorBoard etc
|
|
215
216
|
|
|
216
|
-
env_vars =
|
|
217
|
-
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
|
|
218
|
-
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
219
|
-
}
|
|
217
|
+
env_vars = payload.env_vars
|
|
220
218
|
endpoints: list[dict[str, Any]] = []
|
|
221
219
|
|
|
222
220
|
if target_instances > 1:
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
3
|
from pathlib import PurePath
|
|
4
4
|
from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
|
|
5
5
|
|
|
@@ -90,6 +90,7 @@ class UploadedPayload:
|
|
|
90
90
|
# TODO: Include manifest of payload files for validation
|
|
91
91
|
stage_path: PurePath
|
|
92
92
|
entrypoint: list[Union[str, PurePath]]
|
|
93
|
+
env_vars: dict[str, str] = field(default_factory=dict)
|
|
93
94
|
|
|
94
95
|
|
|
95
96
|
@dataclass(frozen=True)
|
snowflake/ml/jobs/job.py
CHANGED
|
@@ -50,7 +50,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
50
50
|
def min_instances(self) -> int:
|
|
51
51
|
try:
|
|
52
52
|
return int(self._container_spec["env"].get(constants.MIN_INSTANCES_ENV_VAR, 1))
|
|
53
|
-
except TypeError:
|
|
53
|
+
except (TypeError, ValueError):
|
|
54
54
|
return 1
|
|
55
55
|
|
|
56
56
|
@property
|
|
@@ -83,7 +83,10 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
83
83
|
def _container_spec(self) -> dict[str, Any]:
|
|
84
84
|
"""Get the job's main container spec."""
|
|
85
85
|
containers = self._service_spec["spec"]["containers"]
|
|
86
|
-
|
|
86
|
+
try:
|
|
87
|
+
container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
|
|
88
|
+
except StopIteration:
|
|
89
|
+
raise ValueError(f"Container '{constants.DEFAULT_CONTAINER_NAME}' not found in job {self.name}")
|
|
87
90
|
return cast(dict[str, Any], container_spec)
|
|
88
91
|
|
|
89
92
|
@property
|
|
@@ -99,21 +102,23 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
99
102
|
result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
|
|
100
103
|
if result_path_str is None:
|
|
101
104
|
raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
|
|
102
|
-
volume_mounts = self._container_spec["volumeMounts"]
|
|
103
|
-
stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
|
|
104
105
|
|
|
106
|
+
# If result path is relative, it is relative to the stage mount path
|
|
105
107
|
result_path = Path(result_path_str)
|
|
108
|
+
if not result_path.is_absolute():
|
|
109
|
+
return f"{self._stage_path}/{result_path.as_posix()}"
|
|
110
|
+
|
|
111
|
+
# If result path is absolute, it is relative to the stage mount path
|
|
112
|
+
volume_mounts = self._container_spec["volumeMounts"]
|
|
113
|
+
stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
|
|
106
114
|
stage_mount = Path(stage_mount_str)
|
|
107
115
|
try:
|
|
108
116
|
relative_path = result_path.relative_to(stage_mount)
|
|
117
|
+
return f"{self._stage_path}/{relative_path.as_posix()}"
|
|
109
118
|
except ValueError:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
)
|
|
114
|
-
relative_path = result_path
|
|
115
|
-
|
|
116
|
-
return f"{self._stage_path}/{relative_path.as_posix()}"
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
|
|
121
|
+
)
|
|
117
122
|
|
|
118
123
|
@overload
|
|
119
124
|
def get_logs(
|
|
@@ -419,15 +424,29 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
|
419
424
|
if not rows:
|
|
420
425
|
return None
|
|
421
426
|
|
|
422
|
-
|
|
423
|
-
|
|
427
|
+
# we have already integrated with first_instance startup policy,
|
|
428
|
+
# the instance 0 is guaranteed to be the head instance
|
|
429
|
+
head_instance = next(
|
|
430
|
+
(
|
|
431
|
+
row
|
|
432
|
+
for row in rows
|
|
433
|
+
if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
|
|
434
|
+
),
|
|
435
|
+
None,
|
|
436
|
+
)
|
|
437
|
+
# fallback to find the first instance if the instance 0 is not found
|
|
438
|
+
if not head_instance:
|
|
439
|
+
if target_instances > len(rows):
|
|
440
|
+
raise RuntimeError(
|
|
441
|
+
f"Couldn’t retrieve head instance due to missing instances. {target_instances} > {len(rows)}"
|
|
442
|
+
)
|
|
443
|
+
# Sort by start_time first, then by instance_id
|
|
444
|
+
try:
|
|
445
|
+
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
|
446
|
+
except TypeError:
|
|
447
|
+
raise RuntimeError("Job instance information unavailable.")
|
|
448
|
+
head_instance = sorted_instances[0]
|
|
424
449
|
|
|
425
|
-
# Sort by start_time first, then by instance_id
|
|
426
|
-
try:
|
|
427
|
-
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
|
428
|
-
except TypeError:
|
|
429
|
-
raise RuntimeError("Job instance information unavailable.")
|
|
430
|
-
head_instance = sorted_instances[0]
|
|
431
450
|
if not head_instance["start_time"]:
|
|
432
451
|
# If head instance hasn't started yet, return None
|
|
433
452
|
return None
|