google-cloud-pipeline-components 2.14.0__py3-none-any.whl → 2.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of google-cloud-pipeline-components might be problematic. Click here for more details.

Files changed (64)
  1. google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -26
  2. google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
  3. google_cloud_pipeline_components/_implementation/llm/infer_preprocessor.py +109 -0
  4. google_cloud_pipeline_components/_implementation/llm/online_evaluation_pairwise.py +8 -0
  5. google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +5 -6
  6. google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py +24 -0
  7. google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +0 -12
  8. google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +2 -1
  9. google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +14 -0
  10. google_cloud_pipeline_components/_implementation/starry_net/__init__.py +41 -0
  11. google_cloud_pipeline_components/_implementation/{model_evaluation/import_evaluation → starry_net/dataprep}/__init__.py +1 -2
  12. google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py +159 -0
  13. google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py +13 -0
  14. google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py +23 -0
  15. google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml +197 -0
  16. google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py +13 -0
  17. google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py +62 -0
  18. google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py +13 -0
  19. google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py +77 -0
  20. google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/__init__.py +13 -0
  21. google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/component.py +97 -0
  22. google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/__init__.py +13 -0
  23. google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/component.py +76 -0
  24. google_cloud_pipeline_components/_implementation/starry_net/set_test_set/__init__.py +13 -0
  25. google_cloud_pipeline_components/_implementation/starry_net/set_test_set/component.py +48 -0
  26. google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/__init__.py +13 -0
  27. google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/component.py +70 -0
  28. google_cloud_pipeline_components/_implementation/starry_net/set_train_args/__init__.py +13 -0
  29. google_cloud_pipeline_components/_implementation/starry_net/set_train_args/component.py +90 -0
  30. google_cloud_pipeline_components/_implementation/starry_net/train/__init__.py +13 -0
  31. google_cloud_pipeline_components/_implementation/starry_net/train/component.py +209 -0
  32. google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py +13 -0
  33. google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py +59 -0
  34. google_cloud_pipeline_components/_implementation/starry_net/upload_model/__init__.py +13 -0
  35. google_cloud_pipeline_components/_implementation/starry_net/upload_model/component.py +23 -0
  36. google_cloud_pipeline_components/_implementation/starry_net/upload_model/upload_model.yaml +37 -0
  37. google_cloud_pipeline_components/_implementation/starry_net/version.py +18 -0
  38. google_cloud_pipeline_components/container/utils/error_surfacing.py +45 -0
  39. google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +36 -7
  40. google_cloud_pipeline_components/preview/llm/infer/component.py +22 -25
  41. google_cloud_pipeline_components/preview/llm/rlhf/component.py +15 -8
  42. google_cloud_pipeline_components/preview/model_evaluation/__init__.py +4 -1
  43. google_cloud_pipeline_components/{_implementation/model_evaluation/import_evaluation/component.py → preview/model_evaluation/model_evaluation_import_component.py} +4 -3
  44. google_cloud_pipeline_components/preview/starry_net/__init__.py +19 -0
  45. google_cloud_pipeline_components/preview/starry_net/component.py +443 -0
  46. google_cloud_pipeline_components/proto/task_error_pb2.py +32 -0
  47. google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
  48. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +10 -0
  49. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +4 -1
  50. google_cloud_pipeline_components/v1/model_evaluation/error_analysis_pipeline.py +8 -10
  51. google_cloud_pipeline_components/v1/model_evaluation/evaluated_annotation_pipeline.py +2 -2
  52. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_feature_attribution_pipeline.py +2 -2
  53. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_pipeline.py +2 -2
  54. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_unstructure_data_pipeline.py +2 -2
  55. google_cloud_pipeline_components/v1/model_evaluation/evaluation_feature_attribution_pipeline.py +2 -2
  56. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_classification_pipeline.py +4 -2
  57. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +8 -2
  58. google_cloud_pipeline_components/v1/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +1 -0
  59. google_cloud_pipeline_components/version.py +1 -1
  60. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/METADATA +17 -20
  61. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/RECORD +64 -32
  62. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/WHEEL +1 -1
  63. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/LICENSE +0 -0
  64. {google_cloud_pipeline_components-2.14.0.dist-info → google_cloud_pipeline_components-2.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,159 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net component for data preparation."""
15
+
16
+ from google_cloud_pipeline_components import utils
17
+ from google_cloud_pipeline_components._implementation.starry_net import version
18
+ from kfp import dsl
19
+
20
+
@dsl.container_component
def dataprep(
    gcp_resources: dsl.OutputPath(str),
    dataprep_dir: dsl.Output[dsl.Artifact],  # pytype: disable=unsupported-operands
    backcast_length: int,
    forecast_length: int,
    train_end_date: str,
    n_val_windows: int,
    n_test_windows: int,
    test_set_stride: int,
    model_blocks: str,
    bigquery_source: str,
    ts_identifier_columns: str,
    time_column: str,
    static_covariate_columns: str,
    target_column: str,
    machine_type: str,
    docker_region: str,
    location: str,
    project: str,
    job_id: str,
    job_name_prefix: str,
    num_workers: int,
    max_num_workers: int,
    disk_size_gb: int,
    test_set_only: bool,
    bigquery_output: str,
    gcs_source: str,
    gcs_static_covariate_source: str,
    encryption_spec_key_name: str,
):
  # fmt: off
  """Runs Dataprep for training and evaluating a STARRY-Net model.

  Args:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    dataprep_dir: The gcp bucket path where all dataprep artifacts
      are saved.
    backcast_length: The length of the input window to feed into the model.
    forecast_length: The length of the forecast horizon.
    train_end_date: The last date of data to use in the training set. All
      subsequent dates are part of the test set.
    n_val_windows: The number of windows to use for the val set. If 0, no
      validation set is used.
    n_test_windows: The number of windows to use for the test set. Must be >= 1.
    test_set_stride: The number of timestamps to roll forward when
      constructing the val and test sets. The same stride is used for both.
    model_blocks: The stringified tuple of blocks to use in the order
      that they appear in the model. Possible values are `cleaning`,
      `change_point`, `trend`, `hour_of_week-hybrid`, `day_of_week-hybrid`,
      `day_of_year-hybrid`, `week_of_year-hybrid`, `month_of_year-hybrid`,
      `residual`, `quantile`.
    bigquery_source: The BigQuery source of the data.
    ts_identifier_columns: The columns that identify unique time series in the
      BigQuery data source.
    time_column: The column with timestamps in the BigQuery source.
    static_covariate_columns: The names of the static covariates.
    target_column: The target column in the BigQuery data source.
    machine_type: The machine type of the CustomJob worker; it is also
      forwarded as the machine type of the dataflow workers.
    docker_region: The docker region, used to determine which image to use.
    location: The location where the job is run.
    project: The name of the project.
    job_id: The pipeline job id.
    job_name_prefix: The prefix of the job name. The full name,
      ``{job_name_prefix}-{job_id}``, is used as the CustomJob display name
      and is forwarded as the dataflow job id.
    num_workers: The initial number of workers in the dataflow job.
    max_num_workers: The maximum number of workers in the dataflow job.
    disk_size_gb: The disk size of each dataflow worker. The CustomJob's own
      boot disk is fixed at 100 GB.
    test_set_only: Whether to only create the test set BigQuery table or also
      to create TFRecords for training and validation.
    bigquery_output: The BigQuery dataset where the test set is written in the
      form bq://project.dataset.
    gcs_source: The path to the csv file of the data source.
    gcs_static_covariate_source: The path to the csv file of static covariates.
    encryption_spec_key_name: Customer-managed encryption key options for the
      CustomJob. If this is set, then all resources created by the CustomJob
      will be encrypted with the provided encryption key.

  Returns:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    dataprep_dir: The gcp bucket path where all dataprep artifacts
      are saved.
  """
  job_name = f'{job_name_prefix}-{job_id}'
  # Both the CustomJob launcher image (``captain_``) and the dataflow worker
  # image (``replica_``) are pulled from this regional repository.
  image_repository = (
      f'{docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/dataprep'
  )

  # Command-line flags for the dataprep container, in the order it expects.
  dataprep_flags = [
      '--config=starryn/experiments/configs/vertex.py',
      f'--config.datasets.backcast_length={backcast_length}',
      f'--config.datasets.forecast_length={forecast_length}',
      f'--config.datasets.train_end_date={train_end_date}',
      f'--config.datasets.n_val_windows={n_val_windows}',
      f'--config.datasets.val_rolling_window_size={test_set_stride}',
      f'--config.datasets.n_test_windows={n_test_windows}',
      f'--config.datasets.test_rolling_window_size={test_set_stride}',
      f'--config.model.static_cov_names={static_covariate_columns}',
      f'--config.model.blocks_list={model_blocks}',
      f'--bigquery_source={bigquery_source}',
      f'--bigquery_output={bigquery_output}',
      f'--gcs_source={gcs_source}',
      f'--gcs_static_covariate_source={gcs_static_covariate_source}',
      f'--ts_identifier_columns={ts_identifier_columns}',
      f'--time_column={time_column}',
      f'--target_column={target_column}',
      f'--job_id={job_name}',
      f'--num_workers={num_workers}',
      f'--max_num_workers={max_num_workers}',
      f'--root_bucket={dataprep_dir.uri}',
      f'--disk_size={disk_size_gb}',
      f'--machine_type={machine_type}',
      f'--test_set_only={test_set_only}',
      f'--image_uri={image_repository}:replica_{version.DATAPREP_VERSION}',
  ]

  # A single-replica worker pool that launches the dataprep container.
  worker_pool_spec = {
      'replica_count': '1',
      'machine_spec': {'machine_type': str(machine_type)},
      'disk_spec': {'boot_disk_type': 'pd-ssd', 'boot_disk_size_gb': 100},
      'container_spec': {
          'image_uri': f'{image_repository}:captain_{version.DATAPREP_VERSION}',
          'args': dataprep_flags,
      },
  }

  custom_job_payload = {
      'display_name': job_name,
      'encryption_spec': {'kms_key_name': str(encryption_spec_key_name)},
      'job_spec': {'worker_pool_specs': [worker_pool_spec]},
  }
  return utils.build_serverless_customjob_container_spec(
      project=project,
      location=location,
      custom_job_payload=custom_job_payload,
      gcp_resources=gcp_resources,
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,23 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """StarryNet Evaluation Component."""
15
+
16
+ import os
17
+
18
+ from kfp import components
19
+
# TODO(b/346580764)
# Forecasting evaluation component, loaded from the `evaluation.yaml` spec that
# sits in the same directory as this module.
evaluation = components.load_component_from_file(
    os.path.join(os.path.dirname(__file__), 'evaluation.yaml')
)
@@ -0,0 +1,197 @@
name: model_evaluation_forecasting
description: |
  Computes a google.ForecastingMetrics Artifact, containing evaluation metrics given a model's prediction results.
  Creates a dataflow job with Apache Beam and TFMA to compute evaluation metrics.
  Supports point forecasting and quantile forecasting for tabular data.
  Args:
      project (str):
          Project to run evaluation container.
      location (Optional[str]):
          Location for running the evaluation.
          If not set, defaulted to `us-central1`.
      root_dir (system.Artifact):
          An artifact whose URI is the GCS directory for keeping staging files.
          A random subdirectory will be created under the directory to keep job info for resuming
          the job in case of failure.
      predictions_format (Optional[str]):
          The file format for the batch prediction results. `jsonl` is currently the only allowed
          format.
          If not set, defaulted to `jsonl`.
      predictions_gcs_source (Optional[system.Artifact]):
          An artifact with its URI pointing toward a GCS directory with prediction or explanation
          files to be used for this evaluation.
          For prediction results, the files should be named "prediction.results-*".
          For explanation results, the files should be named "explanation.results-*".
      predictions_bigquery_source (Optional[google.BQTable]):
          BigQuery table with prediction or explanation data to be used for this evaluation.
          For prediction results, the table column should be named "predicted_*".
      ground_truth_format (Optional[str]):
          Required for custom tabular and non tabular data.
          The file format for the ground truth files. `jsonl` is currently the only allowed format.
          If not set, defaulted to `jsonl`.
      ground_truth_gcs_source (Optional[Sequence[str]]):
          Required for custom tabular and non tabular data.
          The GCS uris representing where the ground truth is located.
          Used to provide ground truth for each prediction instance when they are not part of the batch prediction jobs prediction instance.
      ground_truth_bigquery_source (Optional[str]):
          Required for custom tabular.
          The BigQuery table uri representing where the ground truth is located.
          Used to provide ground truth for each prediction instance when they are not part of the batch prediction jobs prediction instance.
      target_field_name (str):
          The full name path of the features target field in the predictions file.
          Formatted to be able to find nested columns, delimited by `.`.
          Alternatively referred to as the ground truth (or ground_truth_column) field.
          The value is prefixed with `instance.` before being passed to the evaluation container.
      model (Optional[google.VertexModel]):
          The Model used for predictions job.
          Must share the same ancestor Location.
      prediction_score_column (Optional[str]):
          Optional. The column name of the field containing batch prediction scores.
          Formatted to be able to find nested columns, delimited by `.`.
          If not set, defaulted to `prediction.value` for a `point` forecasting_type and
          `prediction.quantile_predictions` for a `quantile` forecasting_type.
      forecasting_type (Optional[str]):
          Optional. If the problem_type is `forecasting`, then the forecasting type being addressed
          by this regression evaluation run. `point` and `quantile` are the supported types.
          If not set, defaulted to `point`.
      forecasting_quantiles (Optional[Sequence[Float]]):
          Required for a `quantile` forecasting_type.
          The list of quantiles in the same order appeared in the quantile prediction score column.
          If one of the quantiles is set to `0.5f`, point evaluation will be set on that index.
          If not set, defaulted to `[0.5]`.
      example_weight_column (Optional[str]):
          Optional. The column name of the field containing example weights.
          NOTE: this input is currently not forwarded to the evaluation container.
      point_evaluation_quantile (Optional[Float]):
          Required for a `quantile` forecasting_type.
          A quantile in the list of forecasting_quantiles that will be used for point evaluation
          metrics.
          If not set, defaulted to `0.5`.
      dataflow_service_account (Optional[str]):
          Optional. Service account to run the dataflow job.
          If not set, dataflow will use the default worker service account.
          For more details, see https://cloud.google.com/dataflow/docs/concepts/security-and-permissions#default_worker_service_account
      dataflow_disk_size (Optional[int]):
          Optional. The disk size (in GB) of the machine executing the evaluation run.
          If not set, defaulted to `50`.
      dataflow_machine_type (Optional[str]):
          Optional. The machine type executing the evaluation run.
          If not set, defaulted to `n1-standard-4`.
      dataflow_workers_num (Optional[int]):
          Optional. The number of workers executing the evaluation run.
          If not set, defaulted to `1`.
      dataflow_max_workers_num (Optional[int]):
          Optional. The max number of workers executing the evaluation run.
          If not set, defaulted to `5`.
      dataflow_subnetwork (Optional[str]):
          Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be
          used. More details:
          https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications
      dataflow_use_public_ips (Optional[bool]):
          Specifies whether Dataflow workers use public IP addresses.
          If not set, defaulted to `true`.
      encryption_spec_key_name (Optional[str]):
          Customer-managed encryption key.
  Returns:
      evaluation_metrics (google.ForecastingMetrics):
          google.ForecastingMetrics artifact representing the forecasting evaluation metrics in GCS.
      gcp_resources (str):
          Serialized `gcp_resources` written by the evaluation container.
inputs:
  - { name: project, type: String }
  - { name: location, type: String, default: "us-central1" }
  - { name: root_dir, type: system.Artifact }
  - { name: predictions_format, type: String, default: "jsonl" }
  - { name: predictions_gcs_source, type: Artifact, optional: True }
  - { name: predictions_bigquery_source, type: google.BQTable, optional: True }
  - { name: ground_truth_format, type: String, default: "jsonl" }
  - { name: ground_truth_gcs_source, type: JsonArray, default: "[]" }
  - { name: ground_truth_bigquery_source, type: String, default: "" }
  - { name: target_field_name, type: String }
  - { name: model, type: google.VertexModel, optional: True }
  - { name: prediction_score_column, type: String, default: "" }
  - { name: forecasting_type, type: String, default: "point" }
  - { name: forecasting_quantiles, type: JsonArray, default: "[0.5]" }
  - { name: example_weight_column, type: String, default: "" }
  - { name: point_evaluation_quantile, type: Float, default: 0.5 }
  - { name: dataflow_service_account, type: String, default: "" }
  - { name: dataflow_disk_size, type: Integer, default: 50 }
  - { name: dataflow_machine_type, type: String, default: "n1-standard-4" }
  - { name: dataflow_workers_num, type: Integer, default: 1 }
  - { name: dataflow_max_workers_num, type: Integer, default: 5 }
  - { name: dataflow_subnetwork, type: String, default: "" }
  - { name: dataflow_use_public_ips, type: Boolean, default: "true" }
  - { name: encryption_spec_key_name, type: String, default: "" }
outputs:
  - { name: evaluation_metrics, type: google.ForecastingMetrics }
  - { name: gcp_resources, type: String }
implementation:
  container:
    image: gcr.io/ml-pipeline/model-evaluation:v0.9
    command:
      - python
      - /main.py
    args:
      - --setup_file
      - /setup.py
      - --json_mode
      - "true"
      - --project_id
      - { inputValue: project }
      - --location
      - { inputValue: location }
      - --problem_type
      - "forecasting"
      - --forecasting_type
      - { inputValue: forecasting_type }
      - --forecasting_quantiles
      - { inputValue: forecasting_quantiles }
      - --point_evaluation_quantile
      - { inputValue: point_evaluation_quantile }
      - --batch_prediction_format
      - { inputValue: predictions_format }
      - if:
          cond: {isPresent: predictions_gcs_source}
          then:
            - --batch_prediction_gcs_source
            - "{{$.inputs.artifacts['predictions_gcs_source'].uri}}"
      - if:
          cond: {isPresent: predictions_bigquery_source}
          then:
            - --batch_prediction_bigquery_source
            - "bq://{{$.inputs.artifacts['predictions_bigquery_source'].metadata['projectId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['datasetId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['tableId']}}"
      - if:
          cond: {isPresent: model}
          then:
            - --model_name
            - "{{$.inputs.artifacts['model'].metadata['resourceName']}}"
      - --ground_truth_format
      - { inputValue: ground_truth_format }
      - --ground_truth_gcs_source
      - { inputValue: ground_truth_gcs_source }
      - --ground_truth_bigquery_source
      - { inputValue: ground_truth_bigquery_source }
      - --root_dir
      - "{{$.inputs.artifacts['root_dir'].uri}}"
      - --target_field_name
      - "instance.{{$.inputs.parameters['target_field_name']}}"
      - --prediction_score_column
      - { inputValue: prediction_score_column }
      - --dataflow_job_prefix
      - "evaluation-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}"
      - --dataflow_service_account
      - { inputValue: dataflow_service_account }
      - --dataflow_disk_size
      - { inputValue: dataflow_disk_size }
      - --dataflow_machine_type
      - { inputValue: dataflow_machine_type }
      - --dataflow_workers_num
      - { inputValue: dataflow_workers_num }
      - --dataflow_max_workers_num
      - { inputValue: dataflow_max_workers_num }
      - --dataflow_subnetwork
      - { inputValue: dataflow_subnetwork }
      - --dataflow_use_public_ips
      - { inputValue: dataflow_use_public_ips }
      - --kms_key_name
      - { inputValue: encryption_spec_key_name }
      - --output_metrics_gcs_path
      - { outputUri: evaluation_metrics }
      - --gcp_resources
      - { outputPath: gcp_resources }
      - --executor_input
      - "{{$}}"
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """StarryNet get training artifacts component."""
15
+
16
+ from typing import NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
@dsl.component(packages_to_install=['tensorflow==2.11.0'])
def get_training_artifacts(
    docker_region: str,
    trainer_dir: dsl.InputPath(),
) -> NamedTuple(
    'TrainingArtifacts',
    image_uri=str,
    artifact_uri=str,
    prediction_schema_uri=str,
    instance_schema_uri=str,
):
  # fmt: off
  """Gets the artifact URIs from the training job.

  Args:
    docker_region: The docker region used to build the prediction server
      image URI.
    trainer_dir: The directory where training artifacts were stored. It must
      contain a `trainer.txt` file whose (whitespace-stripped) contents are
      the directory holding the exported model artifacts.

  Returns:
    A NamedTuple containing the image_uri for the prediction server,
    the artifact_uri with model artifacts, the prediction_schema_uri,
    and the instance_schema_uri. All four fields are strings; the last three
    point inside the directory named in `trainer.txt`.
  """
  import os  # pylint: disable=g-import-not-at-top
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  # trainer.txt holds the path of the directory the trainer exported to.
  with tf.io.gfile.GFile(os.path.join(trainer_dir, 'trainer.txt')) as f:
    private_dir = f.read().strip()

  # Must mirror the return annotation above: every field is a string URI.
  outputs = NamedTuple(
      'TrainingArtifacts',
      image_uri=str,
      artifact_uri=str,
      prediction_schema_uri=str,
      instance_schema_uri=str,
  )
  # NOTE(review): the predictor image tag is pinned inline here, unlike the
  # dataprep image which is versioned through the starry_net version module.
  return outputs(
      f'{docker_region}-docker.pkg.dev/vertex-ai/starryn/predictor:20240617_2142_RC00',  # pylint: disable=too-many-function-args
      private_dir,  # pylint: disable=too-many-function-args
      os.path.join(private_dir, 'predict_schema.yaml'),  # pylint: disable=too-many-function-args
      os.path.join(private_dir, 'instance_schema.yaml'),  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Starry Net component to set TFRecord args if training with TF Records."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
@dsl.component
def maybe_set_tfrecord_args(
    dataprep_previous_run_dir: str,
    static_covariates: List[str],
) -> NamedTuple(
    'TfrecordArgs',
    static_covariates_vocab_path=str,
    train_tf_record_patterns=str,
    val_tf_record_patterns=str,
    test_tf_record_patterns=str,
):
  # fmt: off
  """Creates Trainer TFRecord args if training with TF Records.

  Args:
    dataprep_previous_run_dir: The dataprep dir from a previous run. Use this
      to save time if you've already created TFRecords from your BigQuery
      dataset with the same dataprep parameters as this run.
    static_covariates: The static covariates to train the model with.

  Returns:
    A NamedTuple containing the path to the static covariates vocabulary, and
    the tf record patterns for the train, validation, and test sets. Each
    pattern is a stringified one-element tuple, or '()' when
    `dataprep_previous_run_dir` is empty. The vocabulary path is '' unless
    both `dataprep_previous_run_dir` and `static_covariates` are non-empty.
  """
  outputs = NamedTuple(
      'TfrecordArgs',
      static_covariates_vocab_path=str,
      train_tf_record_patterns=str,
      val_tf_record_patterns=str,
      test_tf_record_patterns=str,
  )

  # Without a previous dataprep run there is nothing to reuse: no vocabulary
  # file and an empty tuple literal for every record pattern.
  if not dataprep_previous_run_dir:
    return outputs('', '()', '()', '()')  # pylint: disable=too-many-function-args

  vocab_path = (
      f'{dataprep_previous_run_dir}/static_covariate_vocab.json'
      if static_covariates
      else ''
  )
  # Patterns are rendered as tuple literals (note the trailing comma) that
  # glob the TFRecords written by the previous run.
  return outputs(
      vocab_path,  # pylint: disable=too-many-function-args
      f"('{dataprep_previous_run_dir}/tf_records/train*',)",  # pylint: disable=too-many-function-args
      f"('{dataprep_previous_run_dir}/tf_records/val*',)",  # pylint: disable=too-many-function-args
      f"('{dataprep_previous_run_dir}/tf_records/test_path_for_plot*',)",  # pylint: disable=too-many-function-args
  )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,97 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """StarryNet Set Dataprep Args Component."""
15
+
16
+ from typing import List, NamedTuple
17
+
18
+ from kfp import dsl
19
+
20
+
21
@dsl.component
def set_dataprep_args(
    model_blocks: List[str],
    ts_identifier_columns: List[str],
    static_covariate_columns: List[str],
    csv_data_path: str,
    previous_run_dir: str,
    location: str,
) -> NamedTuple(
    'DataprepArgs',
    model_blocks=str,
    ts_identifier_columns=str,
    static_covariate_columns=str,
    create_tf_records=bool,
    docker_region=str,
):
  # fmt: off
  """Creates Dataprep args.

  Args:
    model_blocks: The list of model blocks to use in the order they will appear
      in the model. Possible values are `cleaning`, `change_point`, `trend`,
      `hour_of_week`, `day_of_week`, `day_of_year`, `week_of_year`,
      `month_of_year`, `residual`.
    ts_identifier_columns: The list of ts_identifier columns from the BigQuery
      data source.
    static_covariate_columns: The list of strings of static covariate names.
    csv_data_path: The path to the training data csv in the format
      gs://bucket_name/sub_dir/blob_name.csv.
    previous_run_dir: The dataprep dir from a previous run. Use this
      to save time if you've already created TFRecords from your BigQuery
      dataset with the same dataprep parameters as this run.
    location: The location where the pipeline is run.

  Returns:
    A NamedTuple, formatted as expected by the dataprep job, containing:
      model_blocks: The model blocks rendered as a Python tuple literal, with
        `-hybrid` appended to every block whose name contains `_of_`, e.g.
        `('cleaning', 'hour_of_week-hybrid')`.
      ts_identifier_columns: The ts_identifier columns joined with commas.
      static_covariate_columns: The static covariate names rendered as a
        Python tuple literal, e.g. `()`, `('a',)` or `('a', 'b')`.
      create_tf_records: True only when neither `csv_data_path` nor
        `previous_run_dir` is set.
      docker_region: The region of the dataprep docker image: `europe` for
        locations starting with `africa`/`europe`, `asia` for locations
        starting with `asia`/`australia`/`me`, and `us` otherwise.
  """
  outputs = NamedTuple(
      'DataprepArgs',
      model_blocks=str,
      ts_identifier_columns=str,
      static_covariate_columns=str,
      create_tf_records=bool,
      docker_region=str,
  )

  # Helpers are nested so the component body stays self-contained.

  def maybe_update_model_blocks(model_blocks: List[str]) -> List[str]:
    # Blocks such as `hour_of_week` or `day_of_year` are swapped for their
    # `-hybrid` variant; all other blocks pass through unchanged.
    return [f'{b}-hybrid' if '_of_' in b else b for b in model_blocks]

  def create_name_tuple_from_list(input_list: List[str]) -> str:
    # The repr of a tuple is already the literal we want: `()` when empty,
    # `('a',)` for one item, `('a', 'b')` otherwise. Unlike rewriting the
    # brackets of a list repr, it never alters bracket characters that
    # happen to appear inside the names themselves.
    return repr(tuple(input_list))

  def set_docker_region(location: str) -> str:
    if location.startswith(('africa', 'europe')):
      return 'europe'
    if location.startswith(('asia', 'australia', 'me')):
      return 'asia'
    return 'us'

  return outputs(
      create_name_tuple_from_list(maybe_update_model_blocks(model_blocks)),  # pylint: disable=too-many-function-args
      ','.join(ts_identifier_columns),  # pylint: disable=too-many-function-args
      create_name_tuple_from_list(static_covariate_columns),  # pylint: disable=too-many-function-args
      # TFRecords are created only when no CSV path and no previous dataprep
      # run dir were supplied.
      not (csv_data_path or previous_run_dir),  # pylint: disable=too-many-function-args
      set_docker_region(location),  # pylint: disable=too-many-function-args
  )