google-cloud-pipeline-components 2.14.0__py3-none-any.whl → 2.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -26
  2. google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
  3. google_cloud_pipeline_components/_implementation/llm/infer_preprocessor.py +109 -0
  4. google_cloud_pipeline_components/_implementation/llm/online_evaluation_pairwise.py +8 -0
  5. google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +5 -6
  6. google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py +24 -0
  7. google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +0 -12
  8. google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +2 -1
  9. google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +14 -0
  10. google_cloud_pipeline_components/_implementation/starry_net/__init__.py +41 -0
  11. google_cloud_pipeline_components/_implementation/{model_evaluation/import_evaluation → starry_net/dataprep}/__init__.py +1 -2
  12. google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py +159 -0
  13. google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py +13 -0
  14. google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py +23 -0
  15. google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml +197 -0
  16. google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py +13 -0
  17. google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py +62 -0
  18. google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py +13 -0
  19. google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py +77 -0
  20. google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/__init__.py +13 -0
  21. google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/component.py +97 -0
  22. google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/__init__.py +13 -0
  23. google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/component.py +76 -0
  24. google_cloud_pipeline_components/_implementation/starry_net/set_test_set/__init__.py +13 -0
  25. google_cloud_pipeline_components/_implementation/starry_net/set_test_set/component.py +48 -0
  26. google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/__init__.py +13 -0
  27. google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/component.py +70 -0
  28. google_cloud_pipeline_components/_implementation/starry_net/set_train_args/__init__.py +13 -0
  29. google_cloud_pipeline_components/_implementation/starry_net/set_train_args/component.py +90 -0
  30. google_cloud_pipeline_components/_implementation/starry_net/train/__init__.py +13 -0
  31. google_cloud_pipeline_components/_implementation/starry_net/train/component.py +209 -0
  32. google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py +13 -0
  33. google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py +59 -0
  34. google_cloud_pipeline_components/_implementation/starry_net/upload_model/__init__.py +13 -0
  35. google_cloud_pipeline_components/_implementation/starry_net/upload_model/component.py +23 -0
  36. google_cloud_pipeline_components/_implementation/starry_net/upload_model/upload_model.yaml +37 -0
  37. google_cloud_pipeline_components/_implementation/starry_net/version.py +18 -0
  38. google_cloud_pipeline_components/container/utils/error_surfacing.py +45 -0
  39. google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +36 -7
  40. google_cloud_pipeline_components/preview/llm/infer/component.py +22 -25
  41. google_cloud_pipeline_components/preview/llm/rlhf/component.py +15 -8
  42. google_cloud_pipeline_components/preview/model_evaluation/__init__.py +4 -1
  43. google_cloud_pipeline_components/{_implementation/model_evaluation/import_evaluation/component.py → preview/model_evaluation/model_evaluation_import_component.py} +4 -3
  44. google_cloud_pipeline_components/preview/starry_net/__init__.py +19 -0
  45. google_cloud_pipeline_components/preview/starry_net/component.py +443 -0
  46. google_cloud_pipeline_components/proto/task_error_pb2.py +32 -0
  47. google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
  48. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +10 -0
  49. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +4 -1
  50. google_cloud_pipeline_components/v1/model_evaluation/error_analysis_pipeline.py +8 -10
  51. google_cloud_pipeline_components/v1/model_evaluation/evaluated_annotation_pipeline.py +2 -2
  52. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_feature_attribution_pipeline.py +2 -2
  53. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_pipeline.py +2 -2
  54. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_unstructure_data_pipeline.py +2 -2
  55. google_cloud_pipeline_components/v1/model_evaluation/evaluation_feature_attribution_pipeline.py +2 -2
  56. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_classification_pipeline.py +4 -2
  57. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +8 -2
  58. google_cloud_pipeline_components/v1/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +1 -0
  59. google_cloud_pipeline_components/version.py +1 -1
  60. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/METADATA +17 -20
  61. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/RECORD +64 -32
  62. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/WHEEL +1 -1
  63. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/LICENSE +0 -0
  64. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,76 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Starry Net Set Eval Args Component."""
+
+ from typing import List, NamedTuple
+
+ from kfp import dsl
+
+
+ @dsl.component
+ def set_eval_args(
+     big_query_source: dsl.Input[dsl.Artifact], quantiles: List[float]
+ ) -> NamedTuple(
+     'EvalArgs',
+     big_query_source=str,
+     forecasting_type=str,
+     quantiles=List[float],
+     prediction_score_column=str,
+ ):
+   # fmt: off
+   """Creates Evaluation args.
+
+   Args:
+     big_query_source: The BQ table containing the test set.
+     quantiles: The quantiles the model was trained to output.
+
+   Returns:
+     A NamedTuple containing big_query_source as a string, the forecasting_type
+     used in the evaluation step, quantiles in the format expected by the
+     evaluation job, and the prediction_score_column used to evaluate.
+   """
+   outputs = NamedTuple(
+       'EvalArgs',
+       big_query_source=str,
+       forecasting_type=str,
+       quantiles=List[float],
+       prediction_score_column=str)
+
+   def set_forecasting_type_for_eval(quantiles: List[float]) -> str:
+     if quantiles and quantiles[-1] != 0.5:
+       return 'quantile'
+     return 'point'
+
+   def set_quantiles_for_eval(quantiles: List[float]) -> List[float]:
+     updated_q = [q for q in quantiles if q != 0.5]
+     if updated_q:
+       updated_q = [0.5] + updated_q
+     return updated_q
+
+   def set_prediction_score_column(quantiles: List[float]) -> str:
+     updated_q = [q for q in quantiles if q != 0.5]
+     if updated_q:
+       return 'predicted_x.quantile_predictions'
+     return 'predicted_x.value'
+
+   project_id = big_query_source.metadata['projectId']
+   dataset_id = big_query_source.metadata['datasetId']
+   table_id = big_query_source.metadata['tableId']
+   return outputs(
+       f'bq://{project_id}.{dataset_id}.{table_id}',  # pylint: disable=too-many-function-args
+       set_forecasting_type_for_eval(quantiles),  # pylint: disable=too-many-function-args
+       set_quantiles_for_eval(quantiles),  # pylint: disable=too-many-function-args
+       set_prediction_score_column(quantiles),  # pylint: disable=too-many-function-args
+   )
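
The quantile handling above decides whether the evaluation job runs in point or quantile mode. A minimal standalone sketch of that mapping (helper bodies copied from the component; the sample inputs are hypothetical):

from typing import List

def set_forecasting_type_for_eval(quantiles: List[float]) -> str:
  # Any quantile other than the 0.5 point forecast triggers quantile eval.
  if quantiles and quantiles[-1] != 0.5:
    return 'quantile'
  return 'point'

def set_quantiles_for_eval(quantiles: List[float]) -> List[float]:
  # Drop 0.5 wherever it appears, then force it to the front.
  updated_q = [q for q in quantiles if q != 0.5]
  if updated_q:
    updated_q = [0.5] + updated_q
  return updated_q

print(set_forecasting_type_for_eval([0.5, 0.9]))  # quantile
print(set_forecasting_type_for_eval([0.5]))       # point
print(set_quantiles_for_eval([0.9, 0.5, 0.1]))    # [0.5, 0.9, 0.1]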
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,48 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Starry Net Set Test Set Component."""
+
+ from typing import NamedTuple
+
+ from kfp import dsl
+
+
+ @dsl.component(packages_to_install=['tensorflow==2.11.0'])
+ def set_test_set(
+     dataprep_dir: dsl.InputPath(),
+ ) -> NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact):
+   # fmt: off
+   """Creates test set artifact.
+
+   Args:
+     dataprep_dir: The bucket where dataprep artifacts are stored.
+
+   Returns:
+     The test set dsl.Artifact.
+   """
+   import os  # pylint: disable=g-import-not-at-top
+   import json  # pylint: disable=g-import-not-at-top
+   import tensorflow as tf  # pylint: disable=g-import-not-at-top
+
+   with tf.io.gfile.GFile(
+       os.path.join(dataprep_dir, 'big_query_test_set.json')
+   ) as f:
+     metadata = json.load(f)
+   project = metadata['projectId']
+   dataset = metadata['datasetId']
+   table = metadata['tableId']
+   output = NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact)
+   uri = f'bq://{project}.{dataset}.{table}'
+   artifact = dsl.Artifact(uri=uri, metadata=metadata)
+   return output(uri, artifact)  # pylint: disable=too-many-function-args
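
The component above expects dataprep to have written a big_query_test_set.json describing the BigQuery test table. A hypothetical example of that file and the URI derived from it (the keys match those the component reads; the values are made up):

import json

# Hypothetical contents of <dataprep_dir>/big_query_test_set.json.
metadata = json.loads(
    '{"projectId": "my-project", "datasetId": "starry_net", "tableId": "test_set"}'
)
uri = f"bq://{metadata['projectId']}.{metadata['datasetId']}.{metadata['tableId']}"
print(uri)  # bq://my-project.starry_net.test_set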
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,70 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Starry Net component to set TFRecord args."""
+
+ from typing import List, NamedTuple
+
+ from kfp import dsl
+
+
+ @dsl.component
+ def set_tfrecord_args(
+     dataprep_dir: dsl.InputPath(),
+     static_covariates: List[str],
+ ) -> NamedTuple(
+     'TfrecordArgs',
+     static_covariates_vocab_path=str,
+     train_tf_record_patterns=str,
+     val_tf_record_patterns=str,
+     test_tf_record_patterns=str,
+ ):
+   # fmt: off
+   """Creates Trainer TFRecord args.
+
+   Args:
+     dataprep_dir: The dataprep directory where dataprep artifacts are stored.
+     static_covariates: The static covariates to train the model with.
+
+   Returns:
+     A NamedTuple containing the path to the static covariates vocabulary, and
+     the tf record patterns for the train, validation, and test sets.
+   """
+
+   outputs = NamedTuple(
+       'TfrecordArgs',
+       static_covariates_vocab_path=str,
+       train_tf_record_patterns=str,
+       val_tf_record_patterns=str,
+       test_tf_record_patterns=str,
+   )
+
+   if static_covariates and dataprep_dir:
+     static_covariates_vocab_path = f'{dataprep_dir}/static_covariate_vocab.json'
+   else:
+     static_covariates_vocab_path = ''
+   if dataprep_dir:
+     train_tf_record_patterns = f"('{dataprep_dir}/tf_records/train*',)"
+     val_tf_record_patterns = f"('{dataprep_dir}/tf_records/val*',)"
+     test_tf_record_patterns = (
+         f"('{dataprep_dir}/tf_records/test_path_for_plot*',)")
+   else:
+     train_tf_record_patterns = '()'
+     val_tf_record_patterns = '()'
+     test_tf_record_patterns = '()'
+   return outputs(
+       static_covariates_vocab_path,  # pylint: disable=too-many-function-args
+       train_tf_record_patterns,  # pylint: disable=too-many-function-args
+       val_tf_record_patterns,  # pylint: disable=too-many-function-args
+       test_tf_record_patterns,  # pylint: disable=too-many-function-args
+   )
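
Note that the outputs are stringified Python tuples rather than lists, since they are later interpolated verbatim into the trainer's config flags. A sketch of the branching above with hypothetical inputs:

# Hypothetical inputs.
dataprep_dir = '/gcs/my-bucket/dataprep'
static_covariates = ['channel', 'region']

# Mirrors the component's branching.
if static_covariates and dataprep_dir:
  static_covariates_vocab_path = f'{dataprep_dir}/static_covariate_vocab.json'
else:
  static_covariates_vocab_path = ''
train_tf_record_patterns = (
    f"('{dataprep_dir}/tf_records/train*',)" if dataprep_dir else '()')

print(static_covariates_vocab_path)
# /gcs/my-bucket/dataprep/static_covariate_vocab.json
print(train_tf_record_patterns)
# ('/gcs/my-bucket/dataprep/tf_records/train*',)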
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,90 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Starry Net component to set training args."""
+
+ from typing import List, NamedTuple
+
+ from kfp import dsl
+
+
+ @dsl.component
+ def set_train_args(
+     quantiles: List[float],
+     model_blocks: List[str],
+     static_covariates: List[str],
+ ) -> NamedTuple(
+     'TrainArgs',
+     quantiles=str,
+     use_static_covariates=bool,
+     static_covariate_names=str,
+     model_blocks=str,
+     freeze_point_forecasts=bool,
+ ):
+   # fmt: off
+   """Creates Trainer model args.
+
+   Args:
+     quantiles: The list of floats representing quantiles. Leave blank if
+       only training to produce point forecasts.
+     model_blocks: The list of model blocks to use in the order they will
+       appear in the model. Possible values are `cleaning`, `change_point`,
+       `trend`, `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
+       `month_of_year`, `residual`.
+     static_covariates: The list of strings of static covariate names.
+
+   Returns:
+     A NamedTuple containing the quantiles formatted as expected by the train
+     job, a bool indicating whether the job should train with static
+     covariates, the model blocks formatted as expected by the train job, and
+     a bool indicating whether or not to do two-pass training, first training
+     for point forecasts and then quantiles.
+   """
+
+   outputs = NamedTuple(
+       'TrainArgs',
+       quantiles=str,
+       use_static_covariates=bool,
+       static_covariate_names=str,
+       model_blocks=str,
+       freeze_point_forecasts=bool,
+   )
+
+   def set_quantiles(input_list: List[float]) -> str:
+     if not input_list or input_list[0] != 0.5:
+       input_list = [0.5] + input_list
+     if len(input_list) == 1:
+       return str(input_list).replace('[', '(').replace(']', ',)')
+     return str(input_list).replace('[', '(').replace(']', ')')
+
+   def maybe_update_model_blocks(
+       quantiles: List[float], model_blocks: List[str]) -> List[str]:
+     updated_q = [q for q in quantiles if q != 0.5]
+     model_blocks = [b for b in model_blocks if b != 'quantile']
+     if updated_q:
+       model_blocks.append('quantile')
+     return [f'{b}-hybrid' if '_of_' in b else b for b in model_blocks]
+
+   def create_name_tuple_from_list(input_list: List[str]) -> str:
+     if len(input_list) == 1:
+       return str(input_list).replace('[', '(').replace(']', ',)')
+     return str(input_list).replace('[', '(').replace(']', ')')
+
+   return outputs(
+       set_quantiles(quantiles),  # pylint: disable=too-many-function-args
+       True if static_covariates else False,  # pylint: disable=too-many-function-args
+       create_name_tuple_from_list(static_covariates),  # pylint: disable=too-many-function-args
+       create_name_tuple_from_list(  # pylint: disable=too-many-function-args
+           maybe_update_model_blocks(quantiles, model_blocks)),
+       True if quantiles and quantiles[-1] != 0.5 else False,  # pylint: disable=too-many-function-args
+   )
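
The tuple-formatting helpers rely on str() of a Python list plus a bracket swap, with a trailing comma added for one-element tuples. A standalone sketch (helper bodies copied from the component; inputs hypothetical):

from typing import List

def set_quantiles(input_list: List[float]) -> str:
  # 0.5 (the point forecast) is always forced to the front.
  if not input_list or input_list[0] != 0.5:
    input_list = [0.5] + input_list
  if len(input_list) == 1:
    return str(input_list).replace('[', '(').replace(']', ',)')
  return str(input_list).replace('[', '(').replace(']', ')')

def maybe_update_model_blocks(
    quantiles: List[float], model_blocks: List[str]) -> List[str]:
  # Appends a quantile block when non-point quantiles are requested, and
  # renames calendar blocks to their -hybrid variants.
  updated_q = [q for q in quantiles if q != 0.5]
  model_blocks = [b for b in model_blocks if b != 'quantile']
  if updated_q:
    model_blocks.append('quantile')
  return [f'{b}-hybrid' if '_of_' in b else b for b in model_blocks]

print(set_quantiles([]))          # (0.5,)
print(set_quantiles([0.1, 0.9]))  # (0.5, 0.1, 0.9)
print(maybe_update_model_blocks([0.5, 0.9], ['trend', 'day_of_week', 'residual']))
# ['trend', 'day_of_week-hybrid', 'residual', 'quantile']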
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,209 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Container Component for training STARRY-Net."""
+
+ from google_cloud_pipeline_components import _placeholders
+ from google_cloud_pipeline_components import utils
+ from google_cloud_pipeline_components._implementation.starry_net import version
+
+ from kfp import dsl
+
+
+ @dsl.container_component
+ def train(
+     gcp_resources: dsl.OutputPath(str),
+     trainer_dir: dsl.Output[dsl.Artifact],  # pytype: disable=unsupported-operands
+     num_epochs: int,
+     backcast_length: int,
+     forecast_length: int,
+     train_end_date: str,
+     csv_data_path: str,
+     csv_static_covariates_path: str,
+     static_covariates_vocab_path: str,
+     train_tf_record_patterns: str,
+     val_tf_record_patterns: str,
+     test_tf_record_patterns: str,
+     n_decomposition_plots: int,
+     n_val_windows: int,
+     n_test_windows: int,
+     test_set_stride: int,
+     cleaning_activation_regularizer_coeff: float,
+     change_point_activation_regularizer_coeff: float,
+     change_point_output_regularizer_coeff: float,
+     alpha_upper_bound: float,
+     beta_upper_bound: float,
+     phi_lower_bound: float,
+     b_fixed_val: int,
+     b0_fixed_val: int,
+     phi_fixed_val: int,
+     quantiles: str,
+     use_static_covariates: bool,
+     static_covariate_names: str,
+     model_blocks: str,
+     freeze_point_forecasts: bool,
+     machine_type: str,
+     accelerator_type: str,
+     docker_region: str,
+     location: str,
+     job_id: str,
+     project: str,
+     encryption_spec_key_name: str,
+ ):
+   # fmt: off
+   """Trains a STARRY-Net model.
+
+   Args:
+     gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
+       CustomJob.
+     trainer_dir: The gcp bucket path where training artifacts are saved.
+     num_epochs: The number of epochs to train for.
+     backcast_length: The length of the input window to feed into the model.
+     forecast_length: The length of the forecast horizon.
+     train_end_date: The last date of data to use in the training set. All
+       subsequent dates are part of the test set.
+     csv_data_path: The path to the training data csv.
+     csv_static_covariates_path: The path to the static covariates csv.
+     static_covariates_vocab_path: The path to the master static covariates
+       vocab json.
+     train_tf_record_patterns: The glob patterns to the tf records to use for
+       training.
+     val_tf_record_patterns: The glob patterns to the tf records to use for
+       validation.
+     test_tf_record_patterns: The glob patterns to the tf records to use for
+       testing.
+     n_decomposition_plots: How many decomposition plots to save to
+       tensorboard.
+     n_val_windows: The number of windows to use for the val set. If 0, no
+       validation set is used.
+     n_test_windows: The number of windows to use for the test set. Must be
+       >= 1.
+     test_set_stride: The number of timestamps to roll forward when
+       constructing the val and test sets.
+     cleaning_activation_regularizer_coeff: The regularization coefficient for
+       the cleaning param estimator's final layer's activation in the cleaning
+       block.
+     change_point_activation_regularizer_coeff: The regularization coefficient
+       for the change point param estimator's final layer's activation in the
+       change_point block.
+     change_point_output_regularizer_coeff: The regularization coefficient
+       for the change point param estimator's output in the change_point
+       block.
+     alpha_upper_bound: The upper bound for data smoothing parameter alpha in
+       the trend block.
+     beta_upper_bound: The upper bound for data smoothing parameter beta in
+       the trend block.
+     phi_lower_bound: The lower bound for damping param phi in the trend
+       block.
+     b_fixed_val: The fixed value for b in the trend block. If set to anything
+       other than -1, the trend block will not learn to provide estimates
+       but use the fixed value directly.
+     b0_fixed_val: The fixed value for b0 in the trend block. If set to
+       anything other than -1, the trend block will not learn to provide
+       estimates but use the fixed value directly.
+     phi_fixed_val: The fixed value for phi in the trend block. If set to
+       anything other than -1, the trend block will not learn to provide
+       estimates but use the fixed value directly.
+     quantiles: The stringified tuple of quantiles to learn in the quantile
+       block, e.g., 0.5,0.9,0.95. This should always start with 0.5,
+       representing the point forecasts.
+     use_static_covariates: Whether to use static covariates.
+     static_covariate_names: The stringified tuple of names of the static
+       covariates.
+     model_blocks: The stringified tuple of blocks to use in the order
+       that they appear in the model. Possible values are `cleaning`,
+       `change_point`, `trend`, `hour_of_week-hybrid`, `day_of_week-hybrid`,
+       `day_of_year-hybrid`, `week_of_year-hybrid`, `month_of_year-hybrid`,
+       `residual`, `quantile`.
+     freeze_point_forecasts: Whether or not to do two-pass training, where
+       first the point forecast model is trained, then the quantile block is
+       added, all preceding blocks are frozen, and the quantile block is
+       trained. This should always be True if quantiles != [0.5].
+     machine_type: The machine type.
+     accelerator_type: The accelerator type.
+     docker_region: The docker region, used to determine which image to use.
+     location: Location for creating the custom training job. If not set,
+       defaults to us-central1.
+     job_id: The pipeline job id.
+     project: Project to create the custom training job in. Defaults to
+       the project in which the PipelineJob is run.
+     encryption_spec_key_name: Customer-managed encryption key options for the
+       CustomJob. If this is set, then all resources created by the CustomJob
+       will be encrypted with the provided encryption key.
+
+   Returns:
+     gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
+       CustomJob.
+     trainer_dir: The gcp bucket path where training artifacts are saved.
+   """
+   job_name = f'trainer-{job_id}'
+   payload = {
+       'display_name': job_name,
+       'encryption_spec': {
+           'kms_key_name': str(encryption_spec_key_name),
+       },
+       'job_spec': {
+           'worker_pool_specs': [{
+               'replica_count': '1',
+               'machine_spec': {
+                   'machine_type': str(machine_type),
+                   'accelerator_type': str(accelerator_type),
+                   'accelerator_count': 1,
+               },
+               'disk_spec': {
+                   'boot_disk_type': 'pd-ssd',
+                   'boot_disk_size_gb': 100,
+               },
+               'container_spec': {
+                   'image_uri': f'{docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/trainer:{version.TRAINER_VERSION}',
+                   'args': [
+                       f'--vertex_experiment_dir={trainer_dir.path}',
+                       f'--vertex_job_id={job_id}',
+                       '--config=analysis/trafficforecast/starryn/experiments/configs/vertex.py',
+                       f'--config.num_epochs={num_epochs}',
+                       f'--config.freeze_point_forecasts={freeze_point_forecasts}',
+                       f'--config.callbacks.tensorboard.n_decomposition_plots={n_decomposition_plots}',
+                       f'--config.datasets.backcast_length={backcast_length}',
+                       f'--config.datasets.forecast_length={forecast_length}',
+                       f'--config.datasets.train_end_date={train_end_date}',
+                       f'--config.datasets.train_path={csv_data_path}',
+                       f'--config.datasets.static_covariates_path={csv_static_covariates_path}',
+                       f'--config.datasets.static_covariates_vocab_path={static_covariates_vocab_path}',
+                       f'--config.datasets.train_tf_record_patterns={train_tf_record_patterns}',
+                       f'--config.datasets.val_tf_record_patterns={val_tf_record_patterns}',
+                       f'--config.datasets.test_tf_record_patterns={test_tf_record_patterns}',
+                       f'--config.datasets.n_val_windows={n_val_windows}',
+                       f'--config.datasets.val_rolling_window_size={test_set_stride}',
+                       f'--config.datasets.n_test_windows={n_test_windows}',
+                       f'--config.datasets.test_rolling_window_size={test_set_stride}',
+                       f'--config.model.regularizer_coeff={cleaning_activation_regularizer_coeff}',
+                       f'--config.model.activation_regularizer_coeff={change_point_activation_regularizer_coeff}',
+                       f'--config.model.output_regularizer_coeff={change_point_output_regularizer_coeff}',
+                       f'--config.model.alpha_upper_bound={alpha_upper_bound}',
+                       f'--config.model.beta_upper_bound={beta_upper_bound}',
+                       f'--config.model.phi_lower_bound={phi_lower_bound}',
+                       f'--config.model.b_fixed_val={b_fixed_val}',
+                       f'--config.model.b0_fixed_val={b0_fixed_val}',
+                       f'--config.model.phi_fixed_val={phi_fixed_val}',
+                       f'--config.model.quantiles={quantiles}',
+                       f'--config.model.use_static_covariates_trend={use_static_covariates}',
+                       f'--config.model.use_static_covariates_calendar={use_static_covariates}',
+                       f'--config.model.static_cov_names={static_covariate_names}',
+                       f'--config.model.blocks_list={model_blocks}',
+                   ],
+               },
+           }]
+       },
+   }
+   return utils.build_serverless_customjob_container_spec(
+       project=project,
+       location=location,
+       custom_job_payload=payload,
+       gcp_resources=gcp_resources,
+   )
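
The stringified tuples produced by set_train_args and set_tfrecord_args pass through unchanged into the container's --config.* override flags, which is why they must already look like Python literals. A small illustration of that interpolation with hypothetical values:

# Hypothetical outputs from the upstream components.
quantiles = '(0.5, 0.9)'
model_blocks = "('trend', 'day_of_week-hybrid', 'residual', 'quantile')"

# The same f-string interpolation the component uses in its args list.
args = [
    f'--config.model.quantiles={quantiles}',
    f'--config.model.blocks_list={model_blocks}',
]
print(args[0])  # --config.model.quantiles=(0.5, 0.9)
print(args[1])  # --config.model.blocks_list=('trend', 'day_of_week-hybrid', 'residual', 'quantile')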
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,59 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Starry Net upload decomposition plots component."""
+
+ from kfp import dsl
+
+
+ @dsl.component(packages_to_install=['google-cloud-aiplatform[tensorboard]'])
+ def upload_decomposition_plots(
+     project: str,
+     location: str,
+     tensorboard_id: str,
+     display_name: str,
+     trainer_dir: dsl.InputPath(),
+ ) -> dsl.Artifact:
+   # fmt: off
+   """Uploads decomposition plots to Tensorboard.
+
+   Args:
+     project: The project where the pipeline is run. Defaults to the current
+       project.
+     location: The location where the pipeline components are run.
+     tensorboard_id: The tensorboard instance ID.
+     display_name: The display name of the job.
+     trainer_dir: The directory where training artifacts were stored.
+
+   Returns:
+     A dsl.Artifact whose URI is the URI where decomposition plots can be
+     viewed.
+   """
+   import os  # pylint: disable=g-import-not-at-top
+   from google.cloud import aiplatform  # pylint: disable=g-import-not-at-top
+
+   log_dir = os.path.join(trainer_dir, 'tensorboard', 'r=1:gc=0')
+   project_number = os.environ['CLOUD_ML_PROJECT_ID']
+   aiplatform.init(project=project, location=location)
+   aiplatform.upload_tb_log(
+       tensorboard_id=tensorboard_id,
+       tensorboard_experiment_name=display_name,
+       logdir=log_dir,
+       experiment_display_name=display_name,
+       description=f'Tensorboard for {display_name}',
+   )
+   uri = (
+       f'https://{location}.tensorboard.googleusercontent.com/experiment/'
+       f'projects+{project_number}+locations+{location}+tensorboards+'
+       f'{tensorboard_id}+experiments+{display_name}/#images'
+   )
+   return dsl.Artifact(uri=uri)
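
For reference, the artifact URI assembled above points at the images tab of a TensorBoard experiment; with hypothetical identifiers it renders as follows:

# All values hypothetical; project_number is read from CLOUD_ML_PROJECT_ID
# at runtime inside the component.
location = 'us-central1'
project_number = '123456789012'
tensorboard_id = '6789'
display_name = 'starry-net-run'
uri = (
    f'https://{location}.tensorboard.googleusercontent.com/experiment/'
    f'projects+{project_number}+locations+{location}+tensorboards+'
    f'{tensorboard_id}+experiments+{display_name}/#images'
)
print(uri)
# https://us-central1.tensorboard.googleusercontent.com/experiment/projects+123456789012+locations+us-central1+tensorboards+6789+experiments+starry-net-run/#images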
@@ -0,0 +1,13 @@
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.