google-cloud-pipeline-components 2.14.1__py3-none-any.whl → 2.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of google-cloud-pipeline-components might be problematic. Click here for more details.

Files changed (88)
  1. google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
  2. google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +24 -0
  3. google_cloud_pipeline_components/_implementation/starry_net/__init__.py +41 -0
  4. google_cloud_pipeline_components/_implementation/{model_evaluation/import_evaluation → starry_net/dataprep}/__init__.py +1 -2
  5. google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py +173 -0
  6. google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py +13 -0
  7. google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py +23 -0
  8. google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml +197 -0
  9. google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py +13 -0
  10. google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py +62 -0
  11. google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py +13 -0
  12. google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py +77 -0
  13. google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/__init__.py +13 -0
  14. google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/component.py +97 -0
  15. google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/__init__.py +13 -0
  16. google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/component.py +76 -0
  17. google_cloud_pipeline_components/_implementation/starry_net/set_test_set/__init__.py +13 -0
  18. google_cloud_pipeline_components/_implementation/starry_net/set_test_set/component.py +48 -0
  19. google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/__init__.py +13 -0
  20. google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/component.py +70 -0
  21. google_cloud_pipeline_components/_implementation/starry_net/set_train_args/__init__.py +13 -0
  22. google_cloud_pipeline_components/_implementation/starry_net/set_train_args/component.py +90 -0
  23. google_cloud_pipeline_components/_implementation/starry_net/train/__init__.py +13 -0
  24. google_cloud_pipeline_components/_implementation/starry_net/train/component.py +220 -0
  25. google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py +13 -0
  26. google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py +64 -0
  27. google_cloud_pipeline_components/_implementation/starry_net/upload_model/__init__.py +13 -0
  28. google_cloud_pipeline_components/_implementation/starry_net/upload_model/component.py +23 -0
  29. google_cloud_pipeline_components/_implementation/starry_net/upload_model/upload_model.yaml +37 -0
  30. google_cloud_pipeline_components/_implementation/starry_net/version.py +18 -0
  31. google_cloud_pipeline_components/container/preview/custom_job/remote_runner.py +22 -0
  32. google_cloud_pipeline_components/container/utils/error_surfacing.py +45 -0
  33. google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +36 -7
  34. google_cloud_pipeline_components/preview/automl/forecasting/forecasting_ensemble.py +1 -1
  35. google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_1_tuner.py +2 -2
  36. google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_2_tuner.py +2 -2
  37. google_cloud_pipeline_components/preview/automl/forecasting/learn_to_learn_forecasting_pipeline.yaml +38 -34
  38. google_cloud_pipeline_components/preview/automl/forecasting/sequence_to_sequence_forecasting_pipeline.yaml +38 -34
  39. google_cloud_pipeline_components/preview/automl/forecasting/temporal_fusion_transformer_forecasting_pipeline.yaml +38 -34
  40. google_cloud_pipeline_components/preview/automl/forecasting/time_series_dense_encoder_forecasting_pipeline.yaml +38 -34
  41. google_cloud_pipeline_components/preview/automl/forecasting/utils.py +49 -7
  42. google_cloud_pipeline_components/preview/automl/tabular/auto_feature_engineering.py +1 -1
  43. google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_feature_selection_pipeline.yaml +39 -39
  44. google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_v2_pipeline.yaml +41 -41
  45. google_cloud_pipeline_components/preview/automl/tabular/distillation_stage_feature_transform_engine.py +2 -2
  46. google_cloud_pipeline_components/preview/automl/tabular/feature_selection.py +2 -2
  47. google_cloud_pipeline_components/preview/automl/tabular/feature_selection_pipeline.yaml +4 -4
  48. google_cloud_pipeline_components/preview/automl/tabular/feature_transform_engine.py +3 -3
  49. google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job.py +2 -2
  50. google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job_pipeline.yaml +15 -15
  51. google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer.py +2 -2
  52. google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer_pipeline.yaml +13 -13
  53. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job.py +2 -2
  54. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job_pipeline.yaml +14 -14
  55. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer.py +2 -2
  56. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer_pipeline.yaml +13 -13
  57. google_cloud_pipeline_components/preview/automl/tabular/xgboost_hyperparameter_tuning_job_pipeline.yaml +14 -14
  58. google_cloud_pipeline_components/preview/automl/tabular/xgboost_trainer_pipeline.yaml +13 -13
  59. google_cloud_pipeline_components/preview/custom_job/utils.py +45 -6
  60. google_cloud_pipeline_components/preview/llm/rlhf/component.py +3 -6
  61. google_cloud_pipeline_components/preview/starry_net/__init__.py +19 -0
  62. google_cloud_pipeline_components/preview/starry_net/component.py +469 -0
  63. google_cloud_pipeline_components/proto/task_error_pb2.py +0 -1
  64. google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_predict_pipeline.yaml +10 -10
  65. google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_train_pipeline.yaml +31 -31
  66. google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
  67. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +3 -3
  68. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +14 -14
  69. google_cloud_pipeline_components/v1/automl/tabular/automl_tabular_pipeline.yaml +37 -37
  70. google_cloud_pipeline_components/v1/automl/tabular/cv_trainer.py +2 -2
  71. google_cloud_pipeline_components/v1/automl/tabular/ensemble.py +2 -2
  72. google_cloud_pipeline_components/v1/automl/tabular/finalizer.py +1 -1
  73. google_cloud_pipeline_components/v1/automl/tabular/infra_validator.py +1 -1
  74. google_cloud_pipeline_components/v1/automl/tabular/split_materialized_data.py +1 -1
  75. google_cloud_pipeline_components/v1/automl/tabular/stage_1_tuner.py +2 -2
  76. google_cloud_pipeline_components/v1/automl/tabular/stats_and_example_gen.py +2 -2
  77. google_cloud_pipeline_components/v1/automl/tabular/training_configurator_and_validator.py +1 -1
  78. google_cloud_pipeline_components/v1/automl/tabular/transform.py +2 -2
  79. google_cloud_pipeline_components/v1/custom_job/component.py +3 -0
  80. google_cloud_pipeline_components/v1/custom_job/utils.py +4 -0
  81. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +21 -0
  82. google_cloud_pipeline_components/version.py +1 -1
  83. {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/METADATA +17 -20
  84. {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/RECORD +87 -58
  85. {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/WHEEL +1 -1
  86. google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +0 -208
  87. {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/LICENSE +0 -0
  88. {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net component to set TFRecord args if training with TF Records."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component
def maybe_set_tfrecord_args(
    dataprep_previous_run_dir: str,
    static_covariates: List[str],
) -> NamedTuple(
    'TfrecordArgs',
    static_covariates_vocab_path=str,
    train_tf_record_patterns=str,
    val_tf_record_patterns=str,
    test_tf_record_patterns=str,
):
  # fmt: off
  """Creates Trainer TFRecord args if training with TF Records.

  Args:
    dataprep_previous_run_dir: The dataprep dir from a previous run. Use this
      to save time if you've already created TFRecords from your BigQuery
      dataset with the same dataprep parameters as this run. An empty string
      means no previous run is reused.
    static_covariates: The static covariates to train the model with.

  Returns:
    A NamedTuple containing the path to the static covariates vocabulary
    (empty unless there are static covariates and a previous run dir), and the
    TFRecord patterns for the train, validation, and test sets, each rendered
    as a tuple-literal string (`'()'` when no previous run dir is given).
  """
  tfrecord_args = NamedTuple(
      'TfrecordArgs',
      static_covariates_vocab_path=str,
      train_tf_record_patterns=str,
      val_tf_record_patterns=str,
      test_tf_record_patterns=str,
  )

  # Nothing to reuse: no vocabulary and empty pattern tuples.
  if not dataprep_previous_run_dir:
    return tfrecord_args('', '()', '()', '()')  # pylint: disable=too-many-function-args

  vocab_path = (
      f'{dataprep_previous_run_dir}/static_covariate_vocab.json'
      if static_covariates
      else ''
  )
  record_dir = f'{dataprep_previous_run_dir}/tf_records'
  return tfrecord_args(
      vocab_path,  # pylint: disable=too-many-function-args
      f"('{record_dir}/train*',)",  # pylint: disable=too-many-function-args
      f"('{record_dir}/val*',)",  # pylint: disable=too-many-function-args
      f"('{record_dir}/test_path_for_plot*',)",  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,97 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """StarryNet Set Dataprep Args Component."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component
def set_dataprep_args(
    model_blocks: List[str],
    ts_identifier_columns: List[str],
    static_covariate_columns: List[str],
    csv_data_path: str,
    previous_run_dir: str,
    location: str,
) -> NamedTuple(
    'DataprepArgs',
    model_blocks=str,
    ts_identifier_columns=str,
    static_covariate_columns=str,
    create_tf_records=bool,
    docker_region=str,
):
  # fmt: off
  """Creates Dataprep args.

  Args:
    model_blocks: The list of model blocks to use in the order they will appear
      in the model. Possible values are `cleaning`, `change_point`, `trend`,
      `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
      `month_of_year`, `residual`.
    ts_identifier_columns: The list of ts_identifier columns from the BigQuery
      data source.
    static_covariate_columns: The list of strings of static covariate names.
    csv_data_path: The path to the training data csv in the format
      gs://bucket_name/sub_dir/blob_name.csv.
    previous_run_dir: The dataprep dir from a previous run. Use this
      to save time if you've already created TFRecords from your BigQuery
      dataset with the same dataprep parameters as this run.
    location: The location where the pipeline is run.

  Returns:
    A NamedTuple containing the model blocks formatted as expected by the
    dataprep job, the ts_identifier_columns formatted as expected by the
    dataprep job (comma-joined), the static_covariate_columns formatted as
    expected by the dataprep job, a boolean indicating whether to create tf
    records, and the region of the dataprep docker image (`europe`, `asia`,
    or `us`).
  """
  outputs = NamedTuple(
      'DataprepArgs',
      model_blocks=str,
      ts_identifier_columns=str,
      static_covariate_columns=str,
      create_tf_records=bool,
      docker_region=str,
  )

  def maybe_update_model_blocks(model_blocks: List[str]) -> List[str]:
    # Blocks whose names contain `_of_` (e.g. `day_of_week`) get a `-hybrid`
    # suffix.
    return [f'{b}-hybrid' if '_of_' in b else b for b in model_blocks]

  def create_name_tuple_from_list(input_list: List[str]) -> str:
    # Renders the list as a tuple literal, e.g. ['a'] -> "('a',)". This is the
    # same text the old bracket-replacement trick produced, but it does not
    # corrupt elements that themselves contain `[` or `]`.
    return str(tuple(input_list))

  def set_docker_region(location: str) -> str:
    # Maps the pipeline location prefix to the dataprep image region.
    if location.startswith(('africa', 'europe')):
      return 'europe'
    if location.startswith(('asia', 'australia', 'me')):
      return 'asia'
    return 'us'

  # TFRecords are created only when neither a CSV data path nor a previous
  # dataprep run dir is supplied.
  create_tf_records = not (csv_data_path or previous_run_dir)

  return outputs(
      create_name_tuple_from_list(maybe_update_model_blocks(model_blocks)),  # pylint: disable=too-many-function-args
      ','.join(ts_identifier_columns),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(static_covariate_columns),  # pylint: disable=too-many-function-args
      create_tf_records,  # pylint: disable=too-many-function-args
      set_docker_region(location),  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net Set Eval Args Component."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component
def set_eval_args(
    big_query_source: dsl.Input[dsl.Artifact], quantiles: List[float]
) -> NamedTuple(
    'EvalArgs',
    big_query_source=str,
    forecasting_type=str,
    quantiles=List[float],
    prediction_score_column=str,
):
  # fmt: off
  """Creates Evaluation args.

  Args:
    big_query_source: The BQ Table containing the test set. Its metadata must
      contain `projectId`, `datasetId` and `tableId`.
    quantiles: The quantiles the model was trained to output.

  Returns:
    A NamedTuple containing big_query_source as a `bq://project.dataset.table`
    string, the forecasting_type used for the evaluation step (`quantile` if
    any quantile other than 0.5 was given, else `point`), the quantiles in the
    format expected by the evaluation job (0.5 followed by the other
    quantiles, or empty for point forecasts), and the prediction_score_column
    used to evaluate.
  """
  outputs = NamedTuple(
      'EvalArgs',
      big_query_source=str,
      forecasting_type=str,
      quantiles=List[float],
      prediction_score_column=str,
  )

  # All three quantile-dependent outputs are derived from this single list so
  # they always agree, including when 0.5 is the last entry (e.g.
  # [0.1, 0.5] must be a `quantile` evaluation scored on
  # `quantile_predictions`, not a `point` one).
  non_median_quantiles = [q for q in (quantiles or []) if q != 0.5]
  if non_median_quantiles:
    forecasting_type = 'quantile'
    eval_quantiles = [0.5] + non_median_quantiles
    prediction_score_column = 'predicted_x.quantile_predictions'
  else:
    forecasting_type = 'point'
    eval_quantiles = []
    prediction_score_column = 'predicted_x.value'

  project_id = big_query_source.metadata['projectId']
  dataset_id = big_query_source.metadata['datasetId']
  table_id = big_query_source.metadata['tableId']
  return outputs(
      f'bq://{project_id}.{dataset_id}.{table_id}',  # pylint: disable=too-many-function-args
      forecasting_type,  # pylint: disable=too-many-function-args
      eval_quantiles,  # pylint: disable=too-many-function-args
      prediction_score_column,  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net Set Test Set Component."""
15
+
16
+ from typing import NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component(packages_to_install=['tensorflow==2.11.0'])
def set_test_set(
    dataprep_dir: dsl.InputPath(),
) -> NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact):
  # fmt: off
  """Creates test set artifact.

  Args:
    dataprep_dir: Path to the directory where dataprep artifacts are stored.
      It must contain `big_query_test_set.json`.

  Returns:
    A NamedTuple with `uri`, the test set table as a
    `bq://project.dataset.table` string, and `artifact`, a dsl.Artifact with
    that URI whose metadata is the parsed contents of
    `big_query_test_set.json`.
  """
  import json  # pylint: disable=g-import-not-at-top
  import os  # pylint: disable=g-import-not-at-top
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  test_set_file = os.path.join(dataprep_dir, 'big_query_test_set.json')
  with tf.io.gfile.GFile(test_set_file) as reader:
    table_metadata = json.load(reader)

  table_uri = 'bq://{}.{}.{}'.format(
      table_metadata['projectId'],
      table_metadata['datasetId'],
      table_metadata['tableId'],
  )
  test_set = NamedTuple('TestSetArtifact', uri=str, artifact=dsl.Artifact)
  return test_set(  # pylint: disable=too-many-function-args
      table_uri, dsl.Artifact(uri=table_uri, metadata=table_metadata)
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,70 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net component to set TFRecord args."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component
def set_tfrecord_args(
    dataprep_dir: dsl.InputPath(),
    static_covariates: List[str],
) -> NamedTuple(
    'TfrecordArgs',
    static_covariates_vocab_path=str,
    train_tf_record_patterns=str,
    val_tf_record_patterns=str,
    test_tf_record_patterns=str,
):
  # fmt: off
  """Creates Trainer TFRecord args.

  Args:
    dataprep_dir: The dataprep directory where dataprep artifacts are stored.
    static_covariates: The static covariates to train the model with.

  Returns:
    A NamedTuple containing the path to the static covariates vocabulary
    (empty when there are no static covariates), and the TFRecord patterns for
    the train, validation, and test sets, each rendered as a tuple-literal
    string.
  """
  tfrecord_args = NamedTuple(
      'TfrecordArgs',
      static_covariates_vocab_path=str,
      train_tf_record_patterns=str,
      val_tf_record_patterns=str,
      test_tf_record_patterns=str,
  )

  vocab_path = ''
  if static_covariates and dataprep_dir:
    vocab_path = f'{dataprep_dir}/static_covariate_vocab.json'

  if dataprep_dir:
    # One single-element tuple-literal glob per split, in train/val/test order.
    train_patterns, val_patterns, test_patterns = (
        f"('{dataprep_dir}/tf_records/{split}*',)"
        for split in ('train', 'val', 'test_path_for_plot')
    )
  else:
    train_patterns = val_patterns = test_patterns = '()'

  return tfrecord_args(
      vocab_path,  # pylint: disable=too-many-function-args
      train_patterns,  # pylint: disable=too-many-function-args
      val_patterns,  # pylint: disable=too-many-function-args
      test_patterns,  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,90 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net component to set training args."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component
def set_train_args(
    quantiles: List[float],
    model_blocks: List[str],
    static_covariates: List[str],
) -> NamedTuple(
    'TrainArgs',
    quantiles=str,
    use_static_covariates=bool,
    static_covariate_names=str,
    model_blocks=str,
    freeze_point_forecasts=bool,
):
  # fmt: off
  """Creates Trainer model args.

  Args:
    quantiles: The list of floats representing quantiles. Leave blank if
      only training to produce point forecasts.
    model_blocks: The list of model blocks to use in the order they will appear
      in the model. Possible values are `cleaning`, `change_point`, `trend`,
      `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
      `month_of_year`, `residual`.
    static_covariates: The list of strings of static covariate names.

  Returns:
    A NamedTuple containing the quantiles formatted as expected by the train
    job (0.5 first, exactly once, followed by the other quantiles), a bool
    indicating whether the job should train with static covariates, the static
    covariate names formatted as expected by the train job, the model blocks
    formatted as expected by the train job (with a trailing `quantile` block
    when quantiles other than 0.5 are given), and a bool indicating whether or
    not to do two-pass training, first training for point forecasts and then
    quantiles.
  """
  outputs = NamedTuple(
      'TrainArgs',
      quantiles=str,
      use_static_covariates=bool,
      static_covariate_names=str,
      model_blocks=str,
      freeze_point_forecasts=bool,
  )

  def create_name_tuple_from_list(input_list: List) -> str:
    # Renders a list as a tuple literal, e.g. [0.5] -> '(0.5,)'.
    return str(tuple(input_list))

  # Quantiles other than the median. The median (0.5) is emitted exactly once,
  # at the front, so a 0.5 elsewhere in the input is not duplicated; this is
  # the same [0.5] + others ordering that set_eval_args builds.
  non_median_quantiles = [q for q in (quantiles or []) if q != 0.5]
  # Quantile training is on iff any non-median quantile was given. The
  # `quantile` block and two-pass training both key off this one flag so they
  # stay consistent (e.g. for [0.1, 0.5]).
  train_quantiles = bool(non_median_quantiles)

  # Any user-supplied `quantile` block is dropped and re-added at the end only
  # when quantiles are trained.
  blocks = [b for b in model_blocks if b != 'quantile']
  if train_quantiles:
    blocks.append('quantile')
  # Blocks whose names contain `_of_` (e.g. `day_of_week`) get a `-hybrid`
  # suffix.
  blocks = [f'{b}-hybrid' if '_of_' in b else b for b in blocks]

  return outputs(
      create_name_tuple_from_list([0.5] + non_median_quantiles),  # pylint: disable=too-many-function-args
      bool(static_covariates),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(static_covariates),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(blocks),  # pylint: disable=too-many-function-args
      train_quantiles,  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.