google-cloud-pipeline-components 2.14.0__py3-none-any.whl → 2.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of google-cloud-pipeline-components might be problematic. Click here for more details.
- google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -26
- google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
- google_cloud_pipeline_components/_implementation/llm/infer_preprocessor.py +109 -0
- google_cloud_pipeline_components/_implementation/llm/online_evaluation_pairwise.py +8 -0
- google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +5 -6
- google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py +24 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +0 -12
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +2 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +14 -0
- google_cloud_pipeline_components/_implementation/starry_net/__init__.py +41 -0
- google_cloud_pipeline_components/_implementation/{model_evaluation/import_evaluation → starry_net/dataprep}/__init__.py +1 -2
- google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py +159 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py +23 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml +197 -0
- google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py +62 -0
- google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py +77 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/component.py +97 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/component.py +76 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_test_set/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_test_set/component.py +48 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/component.py +70 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_train_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_train_args/component.py +90 -0
- google_cloud_pipeline_components/_implementation/starry_net/train/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/train/component.py +209 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py +59 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/component.py +23 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/upload_model.yaml +37 -0
- google_cloud_pipeline_components/_implementation/starry_net/version.py +18 -0
- google_cloud_pipeline_components/container/utils/error_surfacing.py +45 -0
- google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +36 -7
- google_cloud_pipeline_components/preview/llm/infer/component.py +22 -25
- google_cloud_pipeline_components/preview/llm/rlhf/component.py +15 -8
- google_cloud_pipeline_components/preview/model_evaluation/__init__.py +4 -1
- google_cloud_pipeline_components/{_implementation/model_evaluation/import_evaluation/component.py → preview/model_evaluation/model_evaluation_import_component.py} +4 -3
- google_cloud_pipeline_components/preview/starry_net/__init__.py +19 -0
- google_cloud_pipeline_components/preview/starry_net/component.py +443 -0
- google_cloud_pipeline_components/proto/task_error_pb2.py +32 -0
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +10 -0
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +4 -1
- google_cloud_pipeline_components/v1/model_evaluation/error_analysis_pipeline.py +8 -10
- google_cloud_pipeline_components/v1/model_evaluation/evaluated_annotation_pipeline.py +2 -2
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_feature_attribution_pipeline.py +2 -2
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_pipeline.py +2 -2
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_unstructure_data_pipeline.py +2 -2
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_feature_attribution_pipeline.py +2 -2
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_classification_pipeline.py +4 -2
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +8 -2
- google_cloud_pipeline_components/v1/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +1 -0
- google_cloud_pipeline_components/version.py +1 -1
- {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/METADATA +17 -20
- {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/RECORD +64 -32
- {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/WHEEL +1 -1
- {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/LICENSE +0 -0
- {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net Set Eval Args Component."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_eval_args(
    big_query_source: dsl.Input[dsl.Artifact], quantiles: List[float]
) -> NamedTuple(
    'EvalArgs',
    big_query_source=str,
    forecasting_type=str,
    quantiles=List[float],
    prediction_score_column=str,
):
  # fmt: off
  """Creates Evaluation args.

  The forecasting type, the evaluation quantiles and the prediction score
  column are all derived from the same question: does the model output any
  quantile other than the median (0.5)? Deriving them from a single filtered
  list keeps the three outputs consistent with each other regardless of where
  0.5 appears in `quantiles`.

  Args:
    big_query_source: The BQ Table containing the test set. Its metadata must
      provide `projectId`, `datasetId` and `tableId`.
    quantiles: The quantiles the model was trained to output. An empty (or
      missing) list, or a list holding only 0.5, means point forecasts.

  Returns:
    A NamedTuple containing big_query_source as a `bq://project.dataset.table`
    string, the forecasting_type used for the evaluation step (`quantile` or
    `point`), the quantiles in the format expected by the evaluation job (0.5
    first, followed by the other quantiles; empty for point forecasts), and
    the prediction_score_column used to evaluate.
  """
  outputs = NamedTuple(
      'EvalArgs',
      big_query_source=str,
      forecasting_type=str,
      quantiles=List[float],
      prediction_score_column=str,
  )

  # Quantiles other than the median decide whether this is a quantile or a
  # point-forecast evaluation.
  non_median_quantiles = [q for q in (quantiles or []) if q != 0.5]
  if non_median_quantiles:
    forecasting_type = 'quantile'
    # The evaluation job expects the median first, exactly once.
    eval_quantiles = [0.5] + non_median_quantiles
    prediction_score_column = 'predicted_x.quantile_predictions'
  else:
    forecasting_type = 'point'
    eval_quantiles = []
    prediction_score_column = 'predicted_x.value'

  project_id = big_query_source.metadata['projectId']
  dataset_id = big_query_source.metadata['datasetId']
  table_id = big_query_source.metadata['tableId']
  return outputs(
      big_query_source=f'bq://{project_id}.{dataset_id}.{table_id}',
      forecasting_type=forecasting_type,
      quantiles=eval_quantiles,
      prediction_score_column=prediction_score_column,
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net Set Test Set Component."""
|
|
15
|
+
|
|
16
|
+
from typing import NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component(packages_to_install=['tensorflow==2.11.0'])
def set_test_set(
    dataprep_dir: dsl.InputPath(),
) -> NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact):
  # fmt: off
  """Builds the test set artifact from the dataprep output.

  Reads `big_query_test_set.json` from the dataprep directory and turns the
  table reference stored there into a `bq://` URI and an artifact.

  Args:
    dataprep_dir: The bucket where dataprep artifacts are stored.

  Returns:
    A NamedTuple with the `bq://project.dataset.table` URI of the test set and
    a dsl.Artifact pointing at that URI, carrying the loaded JSON as metadata.
  """
  import json  # pylint: disable=g-import-not-at-top
  import os  # pylint: disable=g-import-not-at-top
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  result_type = NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact)

  # GFile is used so the JSON can be read through tf.io.gfile.
  test_set_json = os.path.join(dataprep_dir, 'big_query_test_set.json')
  with tf.io.gfile.GFile(test_set_json) as reader:
    table_ref = json.load(reader)

  project_id = table_ref['projectId']
  dataset_id = table_ref['datasetId']
  table_id = table_ref['tableId']
  bq_uri = f'bq://{project_id}.{dataset_id}.{table_id}'
  return result_type(
      uri=bq_uri,
      artifact=dsl.Artifact(uri=bq_uri, metadata=table_ref),
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
""" "Starry Net component to set TFRecord args."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_tfrecord_args(
    dataprep_dir: dsl.InputPath(),
    static_covariates: List[str],
) -> NamedTuple(
    'TfrecordArgs',
    static_covariates_vocab_path=str,
    train_tf_record_patterns=str,
    val_tf_record_patterns=str,
    test_tf_record_patterns=str,
):
  # fmt: off
  """Builds the TFRecord-related arguments passed to the trainer.

  Args:
    dataprep_dir: The dataprep directory where dataprep artifacts are stored.
    static_covariates: The static covariates to train the model with.

  Returns:
    A NamedTuple containing the path to the static covariates vocabulary
    (empty when no static covariates or no dataprep directory is given), and
    the stringified tuples of tf record glob patterns for the train,
    validation, and test sets (`()` when no dataprep directory is given).
  """
  result_type = NamedTuple(
      'TfrecordArgs',
      static_covariates_vocab_path=str,
      train_tf_record_patterns=str,
      val_tf_record_patterns=str,
      test_tf_record_patterns=str,
  )

  # Without a dataprep directory every output falls back to its empty form.
  if not dataprep_dir:
    return result_type(
        static_covariates_vocab_path='',
        train_tf_record_patterns='()',
        val_tf_record_patterns='()',
        test_tf_record_patterns='()',
    )

  # The vocabulary path is only set when static covariates are in use.
  vocab_path = ''
  if static_covariates:
    vocab_path = f'{dataprep_dir}/static_covariate_vocab.json'

  return result_type(
      static_covariates_vocab_path=vocab_path,
      train_tf_record_patterns=f"('{dataprep_dir}/tf_records/train*',)",
      val_tf_record_patterns=f"('{dataprep_dir}/tf_records/val*',)",
      test_tf_record_patterns=(
          f"('{dataprep_dir}/tf_records/test_path_for_plot*',)"),
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net component to set training args."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_train_args(
    quantiles: List[float],
    model_blocks: List[str],
    static_covariates: List[str],
) -> NamedTuple(
    'TrainArgs',
    quantiles=str,
    use_static_covariates=bool,
    static_covariate_names=str,
    model_blocks=str,
    freeze_point_forecasts=bool,
):
  # fmt: off
  """Creates Trainer model args.

  Whether a quantile block is added, whether two-pass training is enabled and
  which quantiles are sent to the trainer are all derived from the same
  filtered list of non-median quantiles, so the three outputs always agree
  and the median (0.5) appears exactly once, in first position.

  Args:
    quantiles: The list of floats representing quantiles. Leave blank if
      only training to produce point forecasts.
    model_blocks: The list of model blocks to use in the order they will appear
      in the model. Possible values are `cleaning`, `change_point`, `trend`,
      `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
      `month_of_year`, `residual`.
    static_covariates: The list of strings of static covariate names.

  Returns:
    A NamedTuple containing the quantiles formatted as expected by the train
    job (a stringified tuple starting with 0.5), a bool indicating whether the
    job should train with static covariates, the static covariate names as a
    stringified tuple, the model blocks formatted as expected by the train job
    (calendar blocks suffixed with `-hybrid`, `quantile` appended last when
    non-median quantiles are requested), and a bool indicating whether or not
    to do two-pass training, first training for point forecasts and then
    quantiles.
  """
  outputs = NamedTuple(
      'TrainArgs',
      quantiles=str,
      use_static_covariates=bool,
      static_covariate_names=str,
      model_blocks=str,
      freeze_point_forecasts=bool,
  )

  def as_tuple_literal(values) -> str:
    # str(tuple) already yields '()', "('a',)" and "('a', 'b')", without
    # touching any '[' or ']' characters that occur inside the values.
    return str(tuple(values))

  # Quantiles other than the median decide whether a quantile block is needed.
  non_median_quantiles = [q for q in (quantiles or []) if q != 0.5]
  train_quantiles = [0.5] + non_median_quantiles

  # Drop any user-supplied 'quantile' block; it is re-added last if needed.
  blocks = [b for b in model_blocks if b != 'quantile']
  if non_median_quantiles:
    blocks.append('quantile')
  # Calendar blocks (e.g. `day_of_week`) are passed on in their hybrid form.
  blocks = [f'{b}-hybrid' if '_of_' in b else b for b in blocks]

  covariate_names = list(static_covariates or [])
  return outputs(
      quantiles=as_tuple_literal(train_quantiles),
      use_static_covariates=bool(covariate_names),
      static_covariate_names=as_tuple_literal(covariate_names),
      model_blocks=as_tuple_literal(blocks),
      # Two-pass training goes hand in hand with the quantile block.
      freeze_point_forecasts=bool(non_median_quantiles),
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Container Component for training STARRY-Net."""
|
|
15
|
+
|
|
16
|
+
from google_cloud_pipeline_components import _placeholders
|
|
17
|
+
from google_cloud_pipeline_components import utils
|
|
18
|
+
from google_cloud_pipeline_components._implementation.starry_net import version
|
|
19
|
+
|
|
20
|
+
from kfp import dsl
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dsl.container_component
def train(
    gcp_resources: dsl.OutputPath(str),
    trainer_dir: dsl.Output[dsl.Artifact],  # pytype: disable=unsupported-operands
    num_epochs: int,
    backcast_length: int,
    forecast_length: int,
    train_end_date: str,
    csv_data_path: str,
    csv_static_covariates_path: str,
    static_covariates_vocab_path: str,
    train_tf_record_patterns: str,
    val_tf_record_patterns: str,
    test_tf_record_patterns: str,
    n_decomposition_plots: int,
    n_val_windows: int,
    n_test_windows: int,
    test_set_stride: int,
    cleaning_activation_regularizer_coeff: float,
    change_point_activation_regularizer_coeff: float,
    change_point_output_regularizer_coeff: float,
    alpha_upper_bound: float,
    beta_upper_bound: float,
    phi_lower_bound: float,
    b_fixed_val: int,
    b0_fixed_val: int,
    phi_fixed_val: int,
    quantiles: str,
    use_static_covariates: bool,
    static_covariate_names: str,
    model_blocks: str,
    freeze_point_forecasts: bool,
    machine_type: str,
    accelerator_type: str,
    docker_region: str,
    location: str,
    job_id: str,
    project: str,
    encryption_spec_key_name: str,
):
  # fmt: off
  """Launches a CustomJob that trains a STARRY-Net model.

  Args:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    trainer_dir: The gcp bucket path where training artifacts are saved.
    num_epochs: The number of epochs to train for.
    backcast_length: The length of the input window to feed into the model.
    forecast_length: The length of the forecast horizon.
    train_end_date: The last date of data to use in the training set. All
      subsequent dates are part of the test set.
    csv_data_path: The path to the training data csv.
    csv_static_covariates_path: The path to the static covariates csv.
    static_covariates_vocab_path: The path to the master static covariates
      vocab json.
    train_tf_record_patterns: The glob patterns to the tf records to use for
      training.
    val_tf_record_patterns: The glob patterns to the tf records to use for
      validation.
    test_tf_record_patterns: The glob patterns to the tf records to use for
      testing.
    n_decomposition_plots: How many decomposition plots to save to
      tensorboard.
    n_val_windows: The number of windows to use for the val set. If 0, no
      validation set is used.
    n_test_windows: The number of windows to use for the test set. Must be
      >= 1.
    test_set_stride: The number of timestamps to roll forward when
      constructing the val and test sets.
    cleaning_activation_regularizer_coeff: The regularization coefficient for
      the cleaning param estimator's final layer's activation in the cleaning
      block.
    change_point_activation_regularizer_coeff: The regularization coefficient
      for the change point param estimator's final layer's activation in the
      change_point block.
    change_point_output_regularizer_coeff: The regularization coefficient
      for the change point param estimator's output in the change_point
      block.
    alpha_upper_bound: The upper bound for data smooth parameter alpha in the
      trend block.
    beta_upper_bound: The upper bound for data smooth parameter beta in the
      trend block.
    phi_lower_bound: The lower bound for damping param phi in the trend
      block.
    b_fixed_val: The fixed value for b in the trend block. If set to anything
      other than -1, the trend block will not learn to provide estimates
      but use the fixed value directly.
    b0_fixed_val: The fixed value for b0 in the trend block. If set to
      anything other than -1, the trend block will not learn to provide
      estimates but use the fixed value directly.
    phi_fixed_val: The fixed value for phi in the trend block. If set to
      anything other than -1, the trend block will not learn to provide
      estimates but use the fixed value directly.
    quantiles: The stringified tuple of quantiles to learn in the quantile
      block, e.g., 0.5,0.9,0.95. This should always start with 0.5,
      representing the point forecasts.
    use_static_covariates: Whether to use static covariates.
    static_covariate_names: The stringified tuple of names of the static
      covariates.
    model_blocks: The stringified tuple of blocks to use in the order
      that they appear in the model. Possible values are `cleaning`,
      `change_point`, `trend`, `hour_of_week-hybrid`, `day_of_week-hybrid`,
      `day_of_year-hybrid`, `week_of_year-hybrid`, `month_of_year-hybrid`,
      `residual`, `quantile`.
    freeze_point_forecasts: Whether or not to do two pass training, where
      first the point forecast model is trained, then the quantile block is
      added, all preceding blocks are frozen, and the quantile block is
      trained. This should always be True if quantiles != [0.5].
    machine_type: The machine type.
    accelerator_type: The accelerator type.
    docker_region: The docker region, used to determine which image to use.
    location: Location for creating the custom training job. If not set,
      defaults to us-central1.
    job_id: The pipeline job id.
    project: Project to create the custom training job in. Defaults to
      the project in which the PipelineJob is run.
    encryption_spec_key_name: Customer-managed encryption key options for the
      CustomJob. If this is set, then all resources created by the CustomJob
      will be encrypted with the provided encryption key.

  Returns:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    trainer_dir: The gcp bucket path where training artifacts are saved.
  """
  # The trainer image is regional and pinned to the STARRY-Net trainer version.
  image_uri = f'{docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/trainer:{version.TRAINER_VERSION}'

  # Command-line flags handed to the trainer container.
  trainer_args = [
      f'--vertex_experiment_dir={trainer_dir.path}',
      f'--vertex_job_id={job_id}',
      '--config=analysis/trafficforecast/starryn/experiments/configs/vertex.py',
      f'--config.num_epochs={num_epochs}',
      f'--config.freeze_point_forecasts={freeze_point_forecasts}',
      f'--config.callbacks.tensorboard.n_decomposition_plots={n_decomposition_plots}',
      f'--config.datasets.backcast_length={backcast_length}',
      f'--config.datasets.forecast_length={forecast_length}',
      f'--config.datasets.train_end_date={train_end_date}',
      f'--config.datasets.train_path={csv_data_path}',
      f'--config.datasets.static_covariates_path={csv_static_covariates_path}',
      f'--config.datasets.static_covariates_vocab_path={static_covariates_vocab_path}',
      f'--config.datasets.train_tf_record_patterns={train_tf_record_patterns}',
      f'--config.datasets.val_tf_record_patterns={val_tf_record_patterns}',
      f'--config.datasets.test_tf_record_patterns={test_tf_record_patterns}',
      f'--config.datasets.n_val_windows={n_val_windows}',
      # The same stride is used for both the val and the test rolling windows.
      f'--config.datasets.val_rolling_window_size={test_set_stride}',
      f'--config.datasets.n_test_windows={n_test_windows}',
      f'--config.datasets.test_rolling_window_size={test_set_stride}',
      f'--config.model.regularizer_coeff={cleaning_activation_regularizer_coeff}',
      f'--config.model.activation_regularizer_coeff={change_point_activation_regularizer_coeff}',
      f'--config.model.output_regularizer_coeff={change_point_output_regularizer_coeff}',
      f'--config.model.alpha_upper_bound={alpha_upper_bound}',
      f'--config.model.beta_upper_bound={beta_upper_bound}',
      f'--config.model.phi_lower_bound={phi_lower_bound}',
      f'--config.model.b_fixed_val={b_fixed_val}',
      f'--config.model.b0_fixed_val={b0_fixed_val}',
      f'--config.model.phi_fixed_val={phi_fixed_val}',
      f'--config.model.quantiles={quantiles}',
      # One flag value drives static covariates in both trend and calendar.
      f'--config.model.use_static_covariates_trend={use_static_covariates}',
      f'--config.model.use_static_covariates_calendar={use_static_covariates}',
      f'--config.model.static_cov_names={static_covariate_names}',
      f'--config.model.blocks_list={model_blocks}',
  ]

  # A single worker pool with one replica and one accelerator.
  worker_pool_spec = {
      'replica_count': '1',
      'machine_spec': {
          'machine_type': str(machine_type),
          'accelerator_type': str(accelerator_type),
          'accelerator_count': 1,
      },
      'disk_spec': {
          'boot_disk_type': 'pd-ssd',
          'boot_disk_size_gb': 100,
      },
      'container_spec': {
          'image_uri': image_uri,
          'args': trainer_args,
      },
  }

  custom_job_payload = {
      'display_name': f'trainer-{job_id}',
      'encryption_spec': {
          'kms_key_name': str(encryption_spec_key_name),
      },
      'job_spec': {
          'worker_pool_specs': [worker_pool_spec]
      }
  }
  return utils.build_serverless_customjob_container_spec(
      project=project,
      location=location,
      custom_job_payload=custom_job_payload,
      gcp_resources=gcp_resources,
  )
|
google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net upload decomposition plots component."""
|
|
15
|
+
|
|
16
|
+
from kfp import dsl
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dsl.component(packages_to_install=['google-cloud-aiplatform[tensorboard]'])
def upload_decomposition_plots(
    project: str,
    location: str,
    tensorboard_id: str,
    display_name: str,
    trainer_dir: dsl.InputPath(),
) -> dsl.Artifact:
  # fmt: off
  """Uploads decomposition plots to Tensorboard.

  The logs are read from `<trainer_dir>/tensorboard/r=1:gc=0` and uploaded with
  `display_name` used as both the Tensorboard experiment name and the
  experiment display name.

  Args:
    project: The project where the pipeline is run.
    location: The location where the pipeline components are run.
    tensorboard_id: The tensorboard instance ID.
    display_name: The display name of the job.
    trainer_dir: The directory where training artifacts were stored.

  Returns:
    A dsl.Artifact whose URI is the URL where the decomposition plots can be
    viewed.
  """
  import os  # pylint: disable=g-import-not-at-top
  from google.cloud import aiplatform  # pylint: disable=g-import-not-at-top

  plots_log_dir = os.path.join(trainer_dir, 'tensorboard', 'r=1:gc=0')
  # NOTE(review): CLOUD_ML_PROJECT_ID is presumably injected by the runtime
  # environment -- confirm. It is read before any upload happens, so a missing
  # variable raises KeyError without side effects.
  project_number = os.environ['CLOUD_ML_PROJECT_ID']

  upload_args = {
      'tensorboard_id': tensorboard_id,
      'tensorboard_experiment_name': display_name,
      'logdir': plots_log_dir,
      'experiment_display_name': display_name,
      'description': f'Tensorboard for {display_name}',
  }
  aiplatform.init(project=project, location=location)
  aiplatform.upload_tb_log(**upload_args)

  # The returned URI addresses the uploaded experiment by project number (not
  # the `project` argument) and ends in the `#images` fragment.
  return dsl.Artifact(
      uri=(
          f'https://{location}.tensorboard.googleusercontent.com/experiment/'
          f'projects+{project_number}+locations+{location}+tensorboards+'
          f'{tensorboard_id}+experiments+{display_name}/#images'
      )
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|