google-cloud-pipeline-components 2.14.1__py3-none-any.whl → 2.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of google-cloud-pipeline-components might be problematic. Click here for more details.
- google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +24 -0
- google_cloud_pipeline_components/_implementation/starry_net/__init__.py +41 -0
- google_cloud_pipeline_components/_implementation/{model_evaluation/import_evaluation → starry_net/dataprep}/__init__.py +1 -2
- google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py +173 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py +23 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml +197 -0
- google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py +62 -0
- google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py +77 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/component.py +97 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/component.py +76 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_test_set/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_test_set/component.py +48 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/component.py +70 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_train_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_train_args/component.py +90 -0
- google_cloud_pipeline_components/_implementation/starry_net/train/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/train/component.py +220 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py +64 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/component.py +23 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/upload_model.yaml +37 -0
- google_cloud_pipeline_components/_implementation/starry_net/version.py +18 -0
- google_cloud_pipeline_components/container/preview/custom_job/remote_runner.py +22 -0
- google_cloud_pipeline_components/container/utils/error_surfacing.py +45 -0
- google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +36 -7
- google_cloud_pipeline_components/preview/automl/forecasting/forecasting_ensemble.py +1 -1
- google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_1_tuner.py +2 -2
- google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_2_tuner.py +2 -2
- google_cloud_pipeline_components/preview/automl/forecasting/learn_to_learn_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/sequence_to_sequence_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/temporal_fusion_transformer_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/time_series_dense_encoder_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/utils.py +49 -7
- google_cloud_pipeline_components/preview/automl/tabular/auto_feature_engineering.py +1 -1
- google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_feature_selection_pipeline.yaml +39 -39
- google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_v2_pipeline.yaml +41 -41
- google_cloud_pipeline_components/preview/automl/tabular/distillation_stage_feature_transform_engine.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/feature_selection.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/feature_selection_pipeline.yaml +4 -4
- google_cloud_pipeline_components/preview/automl/tabular/feature_transform_engine.py +3 -3
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job_pipeline.yaml +15 -15
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer_pipeline.yaml +13 -13
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job_pipeline.yaml +14 -14
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer_pipeline.yaml +13 -13
- google_cloud_pipeline_components/preview/automl/tabular/xgboost_hyperparameter_tuning_job_pipeline.yaml +14 -14
- google_cloud_pipeline_components/preview/automl/tabular/xgboost_trainer_pipeline.yaml +13 -13
- google_cloud_pipeline_components/preview/custom_job/utils.py +45 -6
- google_cloud_pipeline_components/preview/llm/rlhf/component.py +3 -6
- google_cloud_pipeline_components/preview/starry_net/__init__.py +19 -0
- google_cloud_pipeline_components/preview/starry_net/component.py +469 -0
- google_cloud_pipeline_components/proto/task_error_pb2.py +0 -1
- google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_predict_pipeline.yaml +10 -10
- google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_train_pipeline.yaml +31 -31
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +3 -3
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +14 -14
- google_cloud_pipeline_components/v1/automl/tabular/automl_tabular_pipeline.yaml +37 -37
- google_cloud_pipeline_components/v1/automl/tabular/cv_trainer.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/ensemble.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/finalizer.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/infra_validator.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/split_materialized_data.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/stage_1_tuner.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/stats_and_example_gen.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/training_configurator_and_validator.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/transform.py +2 -2
- google_cloud_pipeline_components/v1/custom_job/component.py +3 -0
- google_cloud_pipeline_components/v1/custom_job/utils.py +4 -0
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +21 -0
- google_cloud_pipeline_components/version.py +1 -1
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/METADATA +17 -20
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/RECORD +87 -58
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/WHEEL +1 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +0 -208
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/LICENSE +0 -0
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/top_level.txt +0 -0
google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net component to set TFRecord args if training with TF Records."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def maybe_set_tfrecord_args(
    dataprep_previous_run_dir: str,
    static_covariates: List[str],
) -> NamedTuple(
    'TfrecordArgs',
    static_covariates_vocab_path=str,
    train_tf_record_patterns=str,
    val_tf_record_patterns=str,
    test_tf_record_patterns=str,
):
  # fmt: off
  """Creates Trainer TFRecord args if training with TF Records.

  Args:
    dataprep_previous_run_dir: The dataprep dir from a previous run. Use this
      to save time if you've already created TFRecords from your BigQuery
      dataset with the same dataprep parameters as this run. Trailing `/`
      characters are ignored. An empty string means no previous run is reused.
    static_covariates: The static covariates to train the model with.

  Returns:
    A NamedTuple containing:
      static_covariates_vocab_path: `<dir>/static_covariate_vocab.json` when
        both a previous run dir and static covariates are given, else ''.
      train_tf_record_patterns, val_tf_record_patterns,
      test_tf_record_patterns: Single-element tuple literals (as strings) of
        glob patterns under `<dir>/tf_records/`, or '()' when no previous run
        dir is given.
  """

  outputs = NamedTuple(
      'TfrecordArgs',
      static_covariates_vocab_path=str,
      train_tf_record_patterns=str,
      val_tf_record_patterns=str,
      test_tf_record_patterns=str,
  )

  if not dataprep_previous_run_dir:
    # Nothing to reuse: no vocab file and empty pattern tuples.
    return outputs(
        '',  # pylint: disable=too-many-function-args
        '()',  # pylint: disable=too-many-function-args
        '()',  # pylint: disable=too-many-function-args
        '()',  # pylint: disable=too-many-function-args
    )

  # Drop trailing slashes so `gs://bucket/dir/` does not turn into
  # `gs://bucket/dir//tf_records/...`. GCS object names are flat strings, so
  # `dir//x` and `dir/x` refer to different objects.
  run_dir = dataprep_previous_run_dir.rstrip('/')

  def tf_record_pattern(split_prefix: str) -> str:
    # A one-element tuple literal holding a glob under `<run_dir>/tf_records`.
    return f"('{run_dir}/tf_records/{split_prefix}*',)"

  static_covariates_vocab_path = (
      f'{run_dir}/static_covariate_vocab.json' if static_covariates else ''
  )
  return outputs(
      static_covariates_vocab_path,  # pylint: disable=too-many-function-args
      tf_record_pattern('train'),  # pylint: disable=too-many-function-args
      tf_record_pattern('val'),  # pylint: disable=too-many-function-args
      tf_record_pattern('test_path_for_plot'),  # pylint: disable=too-many-function-args
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""StarryNet Set Dataprep Args Component."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_dataprep_args(
    model_blocks: List[str],
    ts_identifier_columns: List[str],
    static_covariate_columns: List[str],
    csv_data_path: str,
    previous_run_dir: str,
    location: str,
) -> NamedTuple(
    'DataprepArgs',
    model_blocks=str,
    ts_identifier_columns=str,
    static_covariate_columns=str,
    create_tf_records=bool,
    docker_region=str,
):
  # fmt: off
  """Creates Dataprep args.

  Args:
    model_blocks: The list of model blocks to use in the order they will appear
      in the model. Possible values are `cleaning`, `change_point`, `trend`,
      `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
      `month_of_year`, `residual`.
    ts_identifier_columns: The list of ts_identifier columns from the BigQuery
      data source.
    static_covariate_columns: The list of strings of static covariate names.
    csv_data_path: The path to the training data csv in the format
      gs://bucket_name/sub_dir/blob_name.csv.
    previous_run_dir: The dataprep dir from a previous run. Use this
      to save time if you've already created TFRecords from your BigQuery
      dataset with the same dataprep parameters as this run.
    location: The location where the pipeline is run.

  Returns:
    A NamedTuple containing:
      model_blocks: The model blocks as a tuple literal string. Blocks whose
        name contains `_of_` (e.g. `day_of_week`) get a `-hybrid` suffix.
      ts_identifier_columns: The ts_identifier columns joined with commas.
      static_covariate_columns: The static covariates as a tuple literal
        string.
      create_tf_records: True only when neither `csv_data_path` nor
        `previous_run_dir` is set.
      docker_region: The region of the dataprep docker image: 'europe',
        'asia', or 'us'.
  """
  outputs = NamedTuple(
      'DataprepArgs',
      model_blocks=str,
      ts_identifier_columns=str,
      static_covariate_columns=str,
      create_tf_records=bool,
      docker_region=str,
  )

  def maybe_update_model_blocks(model_blocks: List[str]) -> List[str]:
    # `*_of_*` blocks (e.g. `day_of_week`) are replaced by their `-hybrid` form.
    return [f'{b}-hybrid' if '_of_' in b else b for b in model_blocks]

  def create_name_tuple_from_list(input_list: List[str]) -> str:
    # repr() of a tuple is a valid tuple literal: `()` when empty and
    # `('a',)` for one element. Unlike rewriting `[`/`]` in str(list), it
    # leaves names that themselves contain brackets intact.
    return repr(tuple(input_list))

  def set_docker_region(location: str) -> str:
    if location.startswith(('africa', 'europe')):
      return 'europe'
    if location.startswith(('asia', 'australia', 'me')):
      return 'asia'
    return 'us'

  return outputs(
      create_name_tuple_from_list(maybe_update_model_blocks(model_blocks)),  # pylint: disable=too-many-function-args
      ','.join(ts_identifier_columns),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(static_covariate_columns),  # pylint: disable=too-many-function-args
      # TFRecords are created only when there is no CSV input and no previous
      # run to reuse.
      not (csv_data_path or previous_run_dir),  # pylint: disable=too-many-function-args
      set_docker_region(location),  # pylint: disable=too-many-function-args
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net Set Eval Args Component."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_eval_args(
    big_query_source: dsl.Input[dsl.Artifact], quantiles: List[float]
) -> NamedTuple(
    'EvalArgs',
    big_query_source=str,
    forecasting_type=str,
    quantiles=List[float],
    prediction_score_column=str,
):
  # fmt: off
  """Creates Evaluation args.

  Args:
    big_query_source: The BQ Table containing the test set. Its metadata must
      contain `projectId`, `datasetId` and `tableId`.
    quantiles: The quantiles the model was trained to output.

  Returns:
    A NamedTuple containing:
      big_query_source: The test table as `bq://project.dataset.table`.
      forecasting_type: 'quantile' if any quantile other than 0.5 was
        requested, else 'point'.
      quantiles: [] for point forecasts. Otherwise 0.5 followed by the other
        requested quantiles, in their original order.
      prediction_score_column: 'predicted_x.quantile_predictions' for
        quantile forecasts, else 'predicted_x.value'.
  """
  outputs = NamedTuple(
      'EvalArgs',
      big_query_source=str,
      forecasting_type=str,
      quantiles=List[float],
      prediction_score_column=str)

  # The type, quantile list and score column are all derived from this one
  # list so they cannot disagree. For example, [0.1, 0.5] is a quantile
  # forecast even though its last element is 0.5.
  non_median_quantiles = [q for q in quantiles if q != 0.5]
  if non_median_quantiles:
    forecasting_type = 'quantile'
    eval_quantiles = [0.5] + non_median_quantiles
    prediction_score_column = 'predicted_x.quantile_predictions'
  else:
    forecasting_type = 'point'
    eval_quantiles = []
    prediction_score_column = 'predicted_x.value'

  project_id = big_query_source.metadata['projectId']
  dataset_id = big_query_source.metadata['datasetId']
  table_id = big_query_source.metadata['tableId']
  return outputs(
      f'bq://{project_id}.{dataset_id}.{table_id}',  # pylint: disable=too-many-function-args
      forecasting_type,  # pylint: disable=too-many-function-args
      eval_quantiles,  # pylint: disable=too-many-function-args
      prediction_score_column,  # pylint: disable=too-many-function-args
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net Set Test Set Component."""
|
|
15
|
+
|
|
16
|
+
from typing import NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component(packages_to_install=['tensorflow==2.11.0'])
def set_test_set(
    dataprep_dir: dsl.InputPath(),
) -> NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact):
  # fmt: off
  """Creates test set artifact.

  Reads `big_query_test_set.json` from the dataprep directory and wraps the
  BigQuery table it describes in an artifact.

  Args:
    dataprep_dir: The bucket where dataprep artifacts are stored.

  Returns:
    A NamedTuple with the test table URI (`bq://project.dataset.table`) and a
    dsl.Artifact carrying that URI, with the parsed JSON as its metadata.
  """
  # Imports stay inside the body (hence the pylint disables). KFP lightweight
  # components ship only the function source to the container.
  import json  # pylint: disable=g-import-not-at-top
  import os  # pylint: disable=g-import-not-at-top

  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  test_set_json = os.path.join(dataprep_dir, 'big_query_test_set.json')
  with tf.io.gfile.GFile(test_set_json) as json_file:
    table_metadata = json.load(json_file)

  table_uri = 'bq://{}.{}.{}'.format(
      table_metadata['projectId'],
      table_metadata['datasetId'],
      table_metadata['tableId'],
  )
  test_set_artifact = NamedTuple(
      'TestSetArtifact', uri=str, artifact=dsl.Artifact
  )
  return test_set_artifact(  # pylint: disable=too-many-function-args
      table_uri, dsl.Artifact(uri=table_uri, metadata=table_metadata)
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
""" "Starry Net component to set TFRecord args."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_tfrecord_args(
    dataprep_dir: dsl.InputPath(),
    static_covariates: List[str],
) -> NamedTuple(
    'TfrecordArgs',
    static_covariates_vocab_path=str,
    train_tf_record_patterns=str,
    val_tf_record_patterns=str,
    test_tf_record_patterns=str,
):
  # fmt: off
  """Creates Trainer TFRecord args.

  Args:
    dataprep_dir: The dataprep directory where dataprep artifacts are stored.
    static_covariates: The static covariates to train the model with.

  Returns:
    A NamedTuple containing the path to the static covariates vocabulary ('' if
    there are no static covariates or no dataprep dir), and the tf record
    patterns for the train, validation, and test sets, each a one-element
    tuple literal string ('()' when there is no dataprep dir).
  """

  tfrecord_args = NamedTuple(
      'TfrecordArgs',
      static_covariates_vocab_path=str,
      train_tf_record_patterns=str,
      val_tf_record_patterns=str,
      test_tf_record_patterns=str,
  )

  vocab_path = ''
  if static_covariates and dataprep_dir:
    vocab_path = f'{dataprep_dir}/static_covariate_vocab.json'

  if dataprep_dir:
    records_root = f'{dataprep_dir}/tf_records'
    split_patterns = [
        f"('{records_root}/{split}*',)"
        for split in ('train', 'val', 'test_path_for_plot')
    ]
  else:
    split_patterns = ['()', '()', '()']
  train_patterns, val_patterns, test_patterns = split_patterns

  return tfrecord_args(
      vocab_path,  # pylint: disable=too-many-function-args
      train_patterns,  # pylint: disable=too-many-function-args
      val_patterns,  # pylint: disable=too-many-function-args
      test_patterns,  # pylint: disable=too-many-function-args
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net component to set training args."""
|
|
15
|
+
|
|
16
|
+
from typing import List, NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component
def set_train_args(
    quantiles: List[float],
    model_blocks: List[str],
    static_covariates: List[str],
) -> NamedTuple(
    'TrainArgs',
    quantiles=str,
    use_static_covariates=bool,
    static_covariate_names=str,
    model_blocks=str,
    freeze_point_forecasts=bool,
):
  # fmt: off
  """Creates Trainer model args.

  Args:
    quantiles: The list of floats representing quantiles. Leave blank if
      only training to produce point forecasts.
    model_blocks: The list of model blocks to use in the order they will appear
      in the model. Possible values are `cleaning`, `change_point`, `trend`,
      `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
      `month_of_year`, `residual`.
    static_covariates: The list of strings of static covariate names.

  Returns:
    A NamedTuple containing:
      quantiles: The quantiles as a tuple literal string, always starting with
        0.5.
      use_static_covariates: Whether the job should train with static
        covariates.
      static_covariate_names: The static covariates as a tuple literal string.
      model_blocks: The model blocks as a tuple literal string. A trailing
        `quantile` block is present iff any quantile other than 0.5 is
        requested, and `*_of_*` blocks get a `-hybrid` suffix.
      freeze_point_forecasts: Whether to do two-pass training, first training
        for point forecasts and then quantiles. True under the same condition
        that adds the `quantile` block.
  """

  outputs = NamedTuple(
      'TrainArgs',
      quantiles=str,
      use_static_covariates=bool,
      static_covariate_names=str,
      model_blocks=str,
      freeze_point_forecasts=bool,
  )

  def has_non_median_quantiles(qs: List[float]) -> bool:
    # Single source of truth for "is this a quantile model?".
    return any(q != 0.5 for q in qs)

  def set_quantiles(input_list: List[float]) -> str:
    # Ensure 0.5 is the first quantile; prepend it unless it already is.
    if not input_list or input_list[0] != 0.5:
      input_list = [0.5] + input_list
    # NOTE(review): a 0.5 elsewhere in the list is kept, so [0.1, 0.5] becomes
    # (0.5, 0.1, 0.5). Confirm the train job tolerates the duplicate.
    # repr() of a tuple keeps the trailing comma for a single element.
    return repr(tuple(input_list))

  def maybe_update_model_blocks(
      quantiles: List[float], model_blocks: List[str]) -> List[str]:
    # Any user-supplied `quantile` block is dropped and re-added at the end
    # only when non-median quantiles are requested.
    model_blocks = [b for b in model_blocks if b != 'quantile']
    if has_non_median_quantiles(quantiles):
      model_blocks.append('quantile')
    return [f'{b}-hybrid' if '_of_' in b else b for b in model_blocks]

  def create_name_tuple_from_list(input_list: List[str]) -> str:
    # A valid tuple literal (`()`, `('a',)`, `('a', 'b')`), even for names
    # that themselves contain brackets.
    return repr(tuple(input_list))

  return outputs(
      set_quantiles(quantiles),  # pylint: disable=too-many-function-args
      bool(static_covariates),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(static_covariates),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(  # pylint: disable=too-many-function-args
          maybe_update_model_blocks(quantiles, model_blocks)),
      # Same rule as the `quantile` block above, so the two outputs agree.
      has_non_median_quantiles(quantiles),  # pylint: disable=too-many-function-args
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|