google-cloud-pipeline-components 2.14.1__py3-none-any.whl → 2.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of google-cloud-pipeline-components might be problematic. Click here for more details.
- google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py +24 -0
- google_cloud_pipeline_components/_implementation/starry_net/__init__.py +41 -0
- google_cloud_pipeline_components/_implementation/{model_evaluation/import_evaluation → starry_net/dataprep}/__init__.py +1 -2
- google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py +173 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/component.py +23 -0
- google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml +197 -0
- google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py +62 -0
- google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/component.py +77 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_dataprep_args/component.py +97 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_eval_args/component.py +76 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_test_set/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_test_set/component.py +48 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_tfrecord_args/component.py +70 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_train_args/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/set_train_args/component.py +90 -0
- google_cloud_pipeline_components/_implementation/starry_net/train/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/train/component.py +220 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_decomposition_plots/component.py +64 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/__init__.py +13 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/component.py +23 -0
- google_cloud_pipeline_components/_implementation/starry_net/upload_model/upload_model.yaml +37 -0
- google_cloud_pipeline_components/_implementation/starry_net/version.py +18 -0
- google_cloud_pipeline_components/container/preview/custom_job/remote_runner.py +22 -0
- google_cloud_pipeline_components/container/utils/error_surfacing.py +45 -0
- google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +36 -7
- google_cloud_pipeline_components/preview/automl/forecasting/forecasting_ensemble.py +1 -1
- google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_1_tuner.py +2 -2
- google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_2_tuner.py +2 -2
- google_cloud_pipeline_components/preview/automl/forecasting/learn_to_learn_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/sequence_to_sequence_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/temporal_fusion_transformer_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/time_series_dense_encoder_forecasting_pipeline.yaml +38 -34
- google_cloud_pipeline_components/preview/automl/forecasting/utils.py +49 -7
- google_cloud_pipeline_components/preview/automl/tabular/auto_feature_engineering.py +1 -1
- google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_feature_selection_pipeline.yaml +39 -39
- google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_v2_pipeline.yaml +41 -41
- google_cloud_pipeline_components/preview/automl/tabular/distillation_stage_feature_transform_engine.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/feature_selection.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/feature_selection_pipeline.yaml +4 -4
- google_cloud_pipeline_components/preview/automl/tabular/feature_transform_engine.py +3 -3
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job_pipeline.yaml +15 -15
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer_pipeline.yaml +13 -13
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job_pipeline.yaml +14 -14
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer.py +2 -2
- google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer_pipeline.yaml +13 -13
- google_cloud_pipeline_components/preview/automl/tabular/xgboost_hyperparameter_tuning_job_pipeline.yaml +14 -14
- google_cloud_pipeline_components/preview/automl/tabular/xgboost_trainer_pipeline.yaml +13 -13
- google_cloud_pipeline_components/preview/custom_job/utils.py +45 -6
- google_cloud_pipeline_components/preview/llm/rlhf/component.py +3 -6
- google_cloud_pipeline_components/preview/starry_net/__init__.py +19 -0
- google_cloud_pipeline_components/preview/starry_net/component.py +469 -0
- google_cloud_pipeline_components/proto/task_error_pb2.py +0 -1
- google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_predict_pipeline.yaml +10 -10
- google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_train_pipeline.yaml +31 -31
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +3 -3
- google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +14 -14
- google_cloud_pipeline_components/v1/automl/tabular/automl_tabular_pipeline.yaml +37 -37
- google_cloud_pipeline_components/v1/automl/tabular/cv_trainer.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/ensemble.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/finalizer.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/infra_validator.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/split_materialized_data.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/stage_1_tuner.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/stats_and_example_gen.py +2 -2
- google_cloud_pipeline_components/v1/automl/tabular/training_configurator_and_validator.py +1 -1
- google_cloud_pipeline_components/v1/automl/tabular/transform.py +2 -2
- google_cloud_pipeline_components/v1/custom_job/component.py +3 -0
- google_cloud_pipeline_components/v1/custom_job/utils.py +4 -0
- google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +21 -0
- google_cloud_pipeline_components/version.py +1 -1
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/METADATA +17 -20
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/RECORD +87 -58
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/WHEEL +1 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +0 -208
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/LICENSE +0 -0
- {google_cloud_pipeline_components-2.14.1.dist-info → google_cloud_pipeline_components-2.16.0.dist-info}/top_level.txt +0 -0
|
@@ -38,6 +38,9 @@ def evaluation_dataset_preprocessor_internal(
|
|
|
38
38
|
output_dirs: dsl.OutputPath(list),
|
|
39
39
|
gcp_resources: dsl.OutputPath(str),
|
|
40
40
|
input_field_name: str = 'input_text',
|
|
41
|
+
role_field_name: str = 'role',
|
|
42
|
+
target_field_name: str = 'ground_truth',
|
|
43
|
+
model_name: str = 'publishers/google/model/text-bison@002',
|
|
41
44
|
display_name: str = 'llm_evaluation_dataset_preprocessor_component',
|
|
42
45
|
machine_type: str = 'e2-highmem-16',
|
|
43
46
|
service_account: str = '',
|
|
@@ -56,6 +59,11 @@ def evaluation_dataset_preprocessor_internal(
|
|
|
56
59
|
gcs_source_uris: A json escaped list of GCS URIs of the input eval dataset.
|
|
57
60
|
input_field_name: The field name of the input eval dataset instances that
|
|
58
61
|
contains the input prompts to the LLM.
|
|
62
|
+
role_field_name: The field name of the role for input eval dataset instances
|
|
63
|
+
that contains the input prompts to the LLM.
|
|
64
|
+
target_field_name: The field name of the target for input eval dataset
|
|
65
|
+
instances.
|
|
66
|
+
model_name: Name of the model being used to create model-specific schemas.
|
|
59
67
|
machine_type: The machine type of this custom job. If not set, defaulted
|
|
60
68
|
to `e2-highmem-16`. More details:
|
|
61
69
|
https://cloud.google.com/compute/docs/machine-resource
|
|
@@ -92,6 +100,11 @@ def evaluation_dataset_preprocessor_internal(
|
|
|
92
100
|
f'--eval_dataset_preprocessor={True}',
|
|
93
101
|
f'--gcs_source_uris={gcs_source_uris}',
|
|
94
102
|
f'--input_field_name={input_field_name}',
|
|
103
|
+
f'--role_field_name={role_field_name}',
|
|
104
|
+
(
|
|
105
|
+
f'--target_field_name={target_field_name}'
|
|
106
|
+
f'--model_name={model_name}'
|
|
107
|
+
),
|
|
95
108
|
f'--output_dirs={output_dirs}',
|
|
96
109
|
'--executor_input={{$.json_escape[1]}}',
|
|
97
110
|
],
|
|
@@ -109,6 +122,9 @@ def llm_evaluation_dataset_preprocessor_graph_component(
|
|
|
109
122
|
location: str,
|
|
110
123
|
gcs_source_uris: List[str],
|
|
111
124
|
input_field_name: str = 'input_text',
|
|
125
|
+
role_field_name: str = 'role',
|
|
126
|
+
target_field_name: str = 'ground_truth',
|
|
127
|
+
model_name: str = 'publishers/google/model/text-bison@002',
|
|
112
128
|
display_name: str = 'llm_evaluation_dataset_preprocessor_component',
|
|
113
129
|
machine_type: str = 'e2-standard-4',
|
|
114
130
|
service_account: str = '',
|
|
@@ -126,6 +142,11 @@ def llm_evaluation_dataset_preprocessor_graph_component(
|
|
|
126
142
|
gcs_source_uris: A list of GCS URIs of the input eval dataset.
|
|
127
143
|
input_field_name: The field name of the input eval dataset instances that
|
|
128
144
|
contains the input prompts to the LLM.
|
|
145
|
+
role_field_name: The field name of the role for input eval dataset
|
|
146
|
+
instances that contains the input prompts to the LLM.
|
|
147
|
+
target_field_name: The field name of the target for input eval dataset
|
|
148
|
+
instances.
|
|
149
|
+
model_name: Name of the model being used to create model-specific schemas.
|
|
129
150
|
display_name: The name of the Evaluation job.
|
|
130
151
|
machine_type: The machine type of this custom job. If not set, defaulted
|
|
131
152
|
to `e2-standard-4`. More details:
|
|
@@ -163,6 +184,9 @@ def llm_evaluation_dataset_preprocessor_graph_component(
|
|
|
163
184
|
input_list=gcs_source_uris
|
|
164
185
|
).output,
|
|
165
186
|
input_field_name=input_field_name,
|
|
187
|
+
role_field_name=role_field_name,
|
|
188
|
+
target_field_name=target_field_name,
|
|
189
|
+
model_name=model_name,
|
|
166
190
|
display_name=display_name,
|
|
167
191
|
machine_type=machine_type,
|
|
168
192
|
service_account=service_account,
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from google_cloud_pipeline_components._implementation.starry_net.dataprep.component import dataprep as DataprepOp
|
|
15
|
+
from google_cloud_pipeline_components._implementation.starry_net.evaluation.component import evaluation as EvaluationOp
|
|
16
|
+
from google_cloud_pipeline_components._implementation.starry_net.get_training_artifacts.component import get_training_artifacts as GetTrainingArtifactsOp
|
|
17
|
+
from google_cloud_pipeline_components._implementation.starry_net.maybe_set_tfrecord_args.component import maybe_set_tfrecord_args as MaybeSetTfrecordArgsOp
|
|
18
|
+
from google_cloud_pipeline_components._implementation.starry_net.set_dataprep_args.component import set_dataprep_args as SetDataprepArgsOp
|
|
19
|
+
from google_cloud_pipeline_components._implementation.starry_net.set_eval_args.component import set_eval_args as SetEvalArgsOp
|
|
20
|
+
from google_cloud_pipeline_components._implementation.starry_net.set_test_set.component import set_test_set as SetTestSetOp
|
|
21
|
+
from google_cloud_pipeline_components._implementation.starry_net.set_tfrecord_args.component import set_tfrecord_args as SetTfrecordArgsOp
|
|
22
|
+
from google_cloud_pipeline_components._implementation.starry_net.set_train_args.component import set_train_args as SetTrainArgsOp
|
|
23
|
+
from google_cloud_pipeline_components._implementation.starry_net.train.component import train as TrainOp
|
|
24
|
+
from google_cloud_pipeline_components._implementation.starry_net.upload_decomposition_plots.component import upload_decomposition_plots as UploadDecompositionPlotsOp
|
|
25
|
+
from google_cloud_pipeline_components._implementation.starry_net.upload_model.component import upload_model as UploadModelOp
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Public StarryNet component ops re-exported by this package. Each name is
# bound by one of the `from ... import <component> as <Name>Op` lines above.
# Kept as a literal, alphabetically ordered list so static tools (pytype,
# linters) can resolve `from ... import *` against it.
__all__ = [
    'DataprepOp',
    'EvaluationOp',
    'GetTrainingArtifactsOp',
    'MaybeSetTfrecordArgsOp',
    'SetDataprepArgsOp',
    'SetEvalArgsOp',
    'SetTestSetOp',
    'SetTfrecordArgsOp',
    'SetTrainArgsOp',
    'TrainOp',
    'UploadDecompositionPlotsOp',
    'UploadModelOp',
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,4 +11,3 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
"""Google Cloud Pipeline Evaluation Import Evaluation Component."""
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Starry Net component for data preparation."""
|
|
15
|
+
|
|
16
|
+
from google_cloud_pipeline_components import utils
|
|
17
|
+
from google_cloud_pipeline_components._implementation.starry_net import version
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.container_component
def dataprep(
    gcp_resources: dsl.OutputPath(str),
    dataprep_dir: dsl.Output[dsl.Artifact],  # pytype: disable=unsupported-operands
    backcast_length: int,
    forecast_length: int,
    train_end_date: str,
    n_val_windows: int,
    n_test_windows: int,
    test_set_stride: int,
    model_blocks: str,
    bigquery_source: str,
    ts_identifier_columns: str,
    time_column: str,
    static_covariate_columns: str,
    static_covariates_vocab_path: str,  # pytype: disable=unused-argument
    target_column: str,
    machine_type: str,
    docker_region: str,
    location: str,
    project: str,
    job_id: str,
    job_name_prefix: str,
    num_workers: int,
    max_num_workers: int,
    disk_size_gb: int,
    test_set_only: bool,
    bigquery_output: str,
    nan_threshold: float,
    zero_threshold: float,
    gcs_source: str,
    gcs_static_covariate_source: str,
    encryption_spec_key_name: str,
):
  # fmt: off
  """Runs Dataprep for training and evaluating a STARRY-Net model.

  Launches a serverless Vertex AI CustomJob whose container runs the StarryNet
  dataprep entry point with the arguments below.

  Args:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    dataprep_dir: The gcp bucket path where all dataprep artifacts
      are saved. Its URI is passed to the container as ``--root_bucket``.
    backcast_length: The length of the input window to feed into the model.
    forecast_length: The length of the forecast horizon.
    train_end_date: The last date of data to use in the training set. All
      subsequent dates are part of the test set.
    n_val_windows: The number of windows to use for the val set. If 0, no
      validation set is used.
    n_test_windows: The number of windows to use for the test set. Must be >= 1.
    test_set_stride: The number of timestamps to roll forward when
      constructing the val and test sets. The same value is used for both the
      val and the test rolling window sizes.
    model_blocks: The stringified tuple of blocks to use in the order
      that they appear in the model. Possible values are `cleaning`,
      `change_point`, `trend`, `hour_of_week-hybrid`, `day_of_week-hybrid`,
      `day_of_year-hybrid`, `week_of_year-hybrid`, `month_of_year-hybrid`,
      `residual`, `quantile`.
    bigquery_source: The BigQuery source of the data.
    ts_identifier_columns: The columns that identify unique time series in the
      BigQuery data source.
    time_column: The column with timestamps in the BigQuery source.
    static_covariate_columns: The names of the static covariates.
    static_covariates_vocab_path: The path to the master static covariates vocab
      json. Currently accepted for interface compatibility only; it is not
      forwarded to the dataprep container.
    target_column: The target column in the BigQuery data source.
    machine_type: The machine type used both for the CustomJob that launches
      dataprep and for the dataflow workers (passed as ``--machine_type``).
    docker_region: The docker region, used to determine which image to use.
    location: The location where the job is run.
    project: The name of the project.
    job_id: The pipeline job id.
    job_name_prefix: The name of the dataflow job name prefix. The job name is
      ``{job_name_prefix}-{job_id}``.
    num_workers: The initial number of workers in the dataflow job.
    max_num_workers: The maximum number of workers in the dataflow job.
    disk_size_gb: The disk size of each dataflow worker.
    test_set_only: Whether to only create the test set BigQuery table or also
      to create TFRecords for training and validation.
    bigquery_output: The BigQuery dataset where the test set is written in the
      form bq://project.dataset.
    nan_threshold: Series having more nan / missing values than
      nan_threshold (inclusive) in percentage for either backtest or forecast
      will not be sampled in the training set (including missing due to
      train_start and train_end). All existing nans are replaced by zeros.
    zero_threshold: Series having more 0.0 values than zero_threshold
      (inclusive) in percentage for either backtest or forecast will not be
      sampled in the training set.
    gcs_source: The path to the csv file of the data source.
    gcs_static_covariate_source: The path to the csv file of static covariates.
    encryption_spec_key_name: Customer-managed encryption key options for the
      CustomJob. If this is set, then all resources created by the CustomJob
      will be encrypted with the provided encryption key.

  Returns:
    gcp_resources: Serialized JSON of ``gcp_resources`` which tracks the
      CustomJob.
    dataprep_dir: The gcp bucket path where all dataprep artifacts
      are saved.
  """
  # The same name is used as the CustomJob display name and as the job id
  # handed to the dataprep container.
  job_name = f'{job_name_prefix}-{job_id}'
  payload = {
      'display_name': job_name,
      'encryption_spec': {
          'kms_key_name': str(encryption_spec_key_name),
      },
      'job_spec': {
          'worker_pool_specs': [{
              'replica_count': '1',
              'machine_spec': {
                  'machine_type': str(machine_type),
              },
              # Boot disk of the launcher CustomJob only; the dataflow worker
              # disk size is controlled separately via --disk_size below.
              'disk_spec': {
                  'boot_disk_type': 'pd-ssd',
                  'boot_disk_size_gb': 100,
              },
              'container_spec': {
                  # The CustomJob runs the `captain_` image; the `replica_`
                  # image URI is passed to it via --image_uri (presumably for
                  # the dataflow workers -- NOTE(review): confirm).
                  'image_uri': f'{docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/dataprep:captain_{version.DATAPREP_VERSION}',
                  'args': [
                      '--config=starryn/experiments/configs/vertex.py',
                      f'--config.datasets.backcast_length={backcast_length}',
                      f'--config.datasets.forecast_length={forecast_length}',
                      f'--config.datasets.train_end_date={train_end_date}',
                      f'--config.datasets.n_val_windows={n_val_windows}',
                      f'--config.datasets.val_rolling_window_size={test_set_stride}',
                      f'--config.datasets.n_test_windows={n_test_windows}',
                      f'--config.datasets.test_rolling_window_size={test_set_stride}',
                      f'--config.datasets.nan_threshold={nan_threshold}',
                      f'--config.datasets.zero_threshold={zero_threshold}',
                      f'--config.model.static_cov_names={static_covariate_columns}',
                      f'--config.model.blocks_list={model_blocks}',
                      f'--bigquery_source={bigquery_source}',
                      f'--bigquery_output={bigquery_output}',
                      f'--gcs_source={gcs_source}',
                      f'--gcs_static_covariate_source={gcs_static_covariate_source}',
                      f'--ts_identifier_columns={ts_identifier_columns}',
                      f'--time_column={time_column}',
                      f'--target_column={target_column}',
                      f'--job_id={job_name}',
                      f'--num_workers={num_workers}',
                      f'--max_num_workers={max_num_workers}',
                      f'--root_bucket={dataprep_dir.uri}',
                      f'--disk_size={disk_size_gb}',
                      f'--machine_type={machine_type}',
                      f'--test_set_only={test_set_only}',
                      f'--image_uri={docker_region}-docker.pkg.dev/vertex-ai-restricted/starryn/dataprep:replica_{version.DATAPREP_VERSION}',
                  ],
              },
          }]
      }
  }
  return utils.build_serverless_customjob_container_spec(
      project=project,
      location=location,
      custom_job_payload=payload,
      gcp_resources=gcp_resources,
  )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""StarryNet Evaluation Component."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from kfp import components
|
|
19
|
+
|
|
20
|
+
# TODO(b/346580764)
# The component spec lives in `evaluation.yaml`, shipped alongside this module.
_COMPONENT_DIR = os.path.dirname(__file__)
evaluation = components.load_component_from_file(
    os.path.join(_COMPONENT_DIR, 'evaluation.yaml')
)
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
name: model_evaluation_forecasting
|
|
2
|
+
description: |
|
|
3
|
+
Computes a google.ForecastingMetrics Artifact, containing evaluation metrics given a model's prediction results.
|
|
4
|
+
Creates a dataflow job with Apache Beam and TFMA to compute evaluation metrics.
|
|
5
|
+
Supports point forecasting and quantile forecasting for tabular data.
|
|
6
|
+
Args:
|
|
7
|
+
project (str):
|
|
8
|
+
Project to run evaluation container.
|
|
9
|
+
location (Optional[str]):
|
|
10
|
+
Location for running the evaluation.
|
|
11
|
+
If not set, defaulted to `us-central1`.
|
|
12
|
+
root_dir (system.Artifact):
|
|
13
|
+
The GCS directory for keeping staging files.
|
|
14
|
+
A random subdirectory will be created under the directory to keep job info for resuming
|
|
15
|
+
the job in case of failure.
|
|
16
|
+
predictions_format (Optional[str]):
|
|
17
|
+
The file format for the batch prediction results. `jsonl` is currently the only allowed
|
|
18
|
+
format.
|
|
19
|
+
If not set, defaulted to `jsonl`.
|
|
20
|
+
predictions_gcs_source (Optional[system.Artifact]):
|
|
21
|
+
An artifact with its URI pointing toward a GCS directory with prediction or explanation
|
|
22
|
+
files to be used for this evaluation.
|
|
23
|
+
For prediction results, the files should be named "prediction.results-*".
|
|
24
|
+
For explanation results, the files should be named "explanation.results-*".
|
|
25
|
+
predictions_bigquery_source (Optional[google.BQTable]):
|
|
26
|
+
BigQuery table with prediction or explanation data to be used for this evaluation.
|
|
27
|
+
For prediction results, the table column should be named "predicted_*".
|
|
28
|
+
ground_truth_format (Optional[str]):
|
|
29
|
+
Required for custom tabular and non tabular data.
|
|
30
|
+
The file format for the ground truth files. `jsonl` is currently the only allowed format.
|
|
31
|
+
If not set, defaulted to `jsonl`.
|
|
32
|
+
ground_truth_gcs_source (Optional[Sequence[str]]):
|
|
33
|
+
Required for custom tabular and non tabular data.
|
|
34
|
+
The GCS uris representing where the ground truth is located.
|
|
35
|
+
Used to provide ground truth for each prediction instance when they are not part of the batch prediction job's prediction instance.
|
|
36
|
+
ground_truth_bigquery_source (Optional[str]):
|
|
37
|
+
Required for custom tabular.
|
|
38
|
+
The BigQuery table uri representing where the ground truth is located.
|
|
39
|
+
Used to provide ground truth for each prediction instance when they are not part of the batch prediction job's prediction instance.
|
|
40
|
+
target_field_name (str):
|
|
41
|
+
The full name path of the features target field in the predictions file.
|
|
42
|
+
Formatted to be able to find nested columns, delimited by `.`.
|
|
43
|
+
Alternatively referred to as the ground truth (or ground_truth_column) field.
|
|
44
|
+
model (Optional[google.VertexModel]):
|
|
45
|
+
The Model used for predictions job.
|
|
46
|
+
Must share the same ancestor Location.
|
|
47
|
+
prediction_score_column (Optional[str]):
|
|
48
|
+
Optional. The column name of the field containing batch prediction scores.
|
|
49
|
+
Formatted to be able to find nested columns, delimited by `.`.
|
|
50
|
+
If not set, defaulted to `prediction.value` for a `point` forecasting_type and
|
|
51
|
+
`prediction.quantile_predictions` for a `quantile` forecasting_type.
|
|
52
|
+
forecasting_type (Optional[str]):
|
|
53
|
+
Optional. If the problem_type is `forecasting`, then the forecasting type being addressed
|
|
54
|
+
by this regression evaluation run. `point` and `quantile` are the supported types.
|
|
55
|
+
If not set, defaulted to `point`.
|
|
56
|
+
forecasting_quantiles (Optional[Sequence[Float]]):
|
|
57
|
+
Required for a `quantile` forecasting_type.
|
|
58
|
+
The list of quantiles in the same order appeared in the quantile prediction score column.
|
|
59
|
+
If one of the quantiles is set to `0.5f`, point evaluation will be set on that index.
|
|
60
|
+
example_weight_column (Optional[str]):
|
|
61
|
+
Optional. The column name of the field containing example weights.
|
|
62
|
+
Note: this value is currently not forwarded to the evaluation container, so setting it has no effect.
|
|
63
|
+
point_evaluation_quantile (Optional[Float]):
|
|
64
|
+
Required for a `quantile` forecasting_type.
|
|
65
|
+
A quantile in the list of forecasting_quantiles that will be used for point evaluation
|
|
66
|
+
metrics.
|
|
67
|
+
dataflow_service_account (Optional[str]):
|
|
68
|
+
Optional. Service account to run the dataflow job.
|
|
69
|
+
If not set, dataflow will use the default worker service account.
|
|
70
|
+
For more details, see https://cloud.google.com/dataflow/docs/concepts/security-and-permissions#default_worker_service_account
|
|
71
|
+
dataflow_disk_size (Optional[int]):
|
|
72
|
+
Optional. The disk size (in GB) of the machine executing the evaluation run.
|
|
73
|
+
If not set, defaulted to `50`.
|
|
74
|
+
dataflow_machine_type (Optional[str]):
|
|
75
|
+
Optional. The machine type executing the evaluation run.
|
|
76
|
+
If not set, defaulted to `n1-standard-4`.
|
|
77
|
+
dataflow_workers_num (Optional[int]):
|
|
78
|
+
Optional. The number of workers executing the evaluation run.
|
|
79
|
+
If not set, defaulted to `1`.
|
|
80
|
+
dataflow_max_workers_num (Optional[int]):
|
|
81
|
+
Optional. The max number of workers executing the evaluation run.
|
|
82
|
+
If not set, defaulted to `5`.
|
|
83
|
+
dataflow_subnetwork (Optional[str]):
|
|
84
|
+
Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be
|
|
85
|
+
used. More details:
|
|
86
|
+
https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications
|
|
87
|
+
dataflow_use_public_ips (Optional[bool]):
|
|
88
|
+
Specifies whether Dataflow workers use public IP addresses.
|
|
89
|
+
encryption_spec_key_name (Optional[str]):
|
|
90
|
+
Customer-managed encryption key.
|
|
91
|
+
Returns:
|
|
92
|
+
evaluation_metrics (google.ForecastingMetrics):
|
|
93
|
+
google.ForecastingMetrics artifact representing the forecasting evaluation metrics in GCS.
|
|
94
|
+
inputs:
|
|
95
|
+
- { name: project, type: String }
|
|
96
|
+
- { name: location, type: String, default: "us-central1" }
|
|
97
|
+
- { name: root_dir, type: system.Artifact }
|
|
98
|
+
- { name: predictions_format, type: String, default: "jsonl" }
|
|
99
|
+
- { name: predictions_gcs_source, type: Artifact, optional: True }
|
|
100
|
+
- { name: predictions_bigquery_source, type: google.BQTable, optional: True }
|
|
101
|
+
- { name: ground_truth_format, type: String, default: "jsonl" }
|
|
102
|
+
- { name: ground_truth_gcs_source, type: JsonArray, default: "[]" }
|
|
103
|
+
- { name: ground_truth_bigquery_source, type: String, default: "" }
|
|
104
|
+
- { name: target_field_name, type: String }
|
|
105
|
+
- { name: model, type: google.VertexModel, optional: True }
|
|
106
|
+
- { name: prediction_score_column, type: String, default: "" }
|
|
107
|
+
- { name: forecasting_type, type: String, default: "point" }
|
|
108
|
+
- { name: forecasting_quantiles, type: JsonArray, default: "[0.5]" }
|
|
109
|
+
- { name: example_weight_column, type: String, default: "" }
|
|
110
|
+
- { name: point_evaluation_quantile, type: Float, default: 0.5 }
|
|
111
|
+
- { name: dataflow_service_account, type: String, default: "" }
|
|
112
|
+
- { name: dataflow_disk_size, type: Integer, default: 50 }
|
|
113
|
+
- { name: dataflow_machine_type, type: String, default: "n1-standard-4" }
|
|
114
|
+
- { name: dataflow_workers_num, type: Integer, default: 1 }
|
|
115
|
+
- { name: dataflow_max_workers_num, type: Integer, default: 5 }
|
|
116
|
+
- { name: dataflow_subnetwork, type: String, default: "" }
|
|
117
|
+
- { name: dataflow_use_public_ips, type: Boolean, default: "true" }
|
|
118
|
+
- { name: encryption_spec_key_name, type: String, default: "" }
|
|
119
|
+
outputs:
|
|
120
|
+
- { name: evaluation_metrics, type: google.ForecastingMetrics }
|
|
121
|
+
- { name: gcp_resources, type: String }
|
|
122
|
+
implementation:
|
|
123
|
+
container:
|
|
124
|
+
image: gcr.io/ml-pipeline/model-evaluation:v0.9
|
|
125
|
+
command:
|
|
126
|
+
- python
|
|
127
|
+
- /main.py
|
|
128
|
+
args:
|
|
129
|
+
- --setup_file
|
|
130
|
+
- /setup.py
|
|
131
|
+
- --json_mode
|
|
132
|
+
- "true"
|
|
133
|
+
- --project_id
|
|
134
|
+
- { inputValue: project }
|
|
135
|
+
- --location
|
|
136
|
+
- { inputValue: location }
|
|
137
|
+
- --problem_type
|
|
138
|
+
- "forecasting"
|
|
139
|
+
- --forecasting_type
|
|
140
|
+
- { inputValue: forecasting_type }
|
|
141
|
+
- --forecasting_quantiles
|
|
142
|
+
- { inputValue: forecasting_quantiles }
|
|
143
|
+
- --point_evaluation_quantile
|
|
144
|
+
- { inputValue: point_evaluation_quantile }
|
|
145
|
+
- --batch_prediction_format
|
|
146
|
+
- { inputValue: predictions_format }
|
|
147
|
+
- if:
|
|
148
|
+
cond: {isPresent: predictions_gcs_source}
|
|
149
|
+
then:
|
|
150
|
+
- --batch_prediction_gcs_source
|
|
151
|
+
- "{{$.inputs.artifacts['predictions_gcs_source'].uri}}"
|
|
152
|
+
- if:
|
|
153
|
+
cond: {isPresent: predictions_bigquery_source}
|
|
154
|
+
then:
|
|
155
|
+
- --batch_prediction_bigquery_source
|
|
156
|
+
- "bq://{{$.inputs.artifacts['predictions_bigquery_source'].metadata['projectId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['datasetId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['tableId']}}"
|
|
157
|
+
- if:
|
|
158
|
+
cond: {isPresent: model}
|
|
159
|
+
then:
|
|
160
|
+
- --model_name
|
|
161
|
+
- "{{$.inputs.artifacts['model'].metadata['resourceName']}}"
|
|
162
|
+
- --ground_truth_format
|
|
163
|
+
- { inputValue: ground_truth_format }
|
|
164
|
+
- --ground_truth_gcs_source
|
|
165
|
+
- { inputValue: ground_truth_gcs_source }
|
|
166
|
+
- --ground_truth_bigquery_source
|
|
167
|
+
- { inputValue: ground_truth_bigquery_source }
|
|
168
|
+
- --root_dir
|
|
169
|
+
- "{{$.inputs.artifacts['root_dir'].uri}}"
|
|
170
|
+
- --target_field_name
|
|
171
|
+
- "instance.{{$.inputs.parameters['target_field_name']}}"
|
|
172
|
+
- --prediction_score_column
|
|
173
|
+
- { inputValue: prediction_score_column }
|
|
174
|
+
- --dataflow_job_prefix
|
|
175
|
+
- "evaluation-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}"
|
|
176
|
+
- --dataflow_service_account
|
|
177
|
+
- { inputValue: dataflow_service_account }
|
|
178
|
+
- --dataflow_disk_size
|
|
179
|
+
- { inputValue: dataflow_disk_size }
|
|
180
|
+
- --dataflow_machine_type
|
|
181
|
+
- { inputValue: dataflow_machine_type }
|
|
182
|
+
- --dataflow_workers_num
|
|
183
|
+
- { inputValue: dataflow_workers_num }
|
|
184
|
+
- --dataflow_max_workers_num
|
|
185
|
+
- { inputValue: dataflow_max_workers_num }
|
|
186
|
+
- --dataflow_subnetwork
|
|
187
|
+
- { inputValue: dataflow_subnetwork }
|
|
188
|
+
- --dataflow_use_public_ips
|
|
189
|
+
- { inputValue: dataflow_use_public_ips }
|
|
190
|
+
- --kms_key_name
|
|
191
|
+
- { inputValue: encryption_spec_key_name }
|
|
192
|
+
- --output_metrics_gcs_path
|
|
193
|
+
- { outputUri: evaluation_metrics }
|
|
194
|
+
- --gcp_resources
|
|
195
|
+
- { outputPath: gcp_resources }
|
|
196
|
+
- --executor_input
|
|
197
|
+
- "{{$}}"
|
google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""StarryNet get training artifacts component."""
|
|
15
|
+
|
|
16
|
+
from typing import NamedTuple
|
|
17
|
+
|
|
18
|
+
from kfp import dsl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dsl.component(packages_to_install=['tensorflow==2.11.0'])
def get_training_artifacts(
    docker_region: str,
    trainer_dir: dsl.InputPath(),
) -> NamedTuple(
    'TrainingArtifacts',
    image_uri=str,
    artifact_uri=str,
    prediction_schema_uri=str,
    instance_schema_uri=str,
):
  # fmt: off
  """Gets the artifact URIs from the training job.

  Reads `trainer.txt` inside `trainer_dir`. Its whitespace-stripped contents
  are used as the directory that holds the model artifacts and schema files.

  Args:
    docker_region: The region from which the training docker image is pulled.
    trainer_dir: The directory where training artifacts were stored.

  Returns:
    A NamedTuple containing the image_uri for the prediction server,
    the artifact_uri with model artifacts, the prediction_schema_uri,
    and the instance_schema_uri.
  """
  # Imports are kept inside the function body, as is conventional for KFP
  # lightweight components whose function source runs in its own container.
  import os  # pylint: disable=g-import-not-at-top
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  with tf.io.gfile.GFile(os.path.join(trainer_dir, 'trainer.txt')) as f:
    private_dir = f.read().strip()

  # Field types must agree with the return annotation above. Every field,
  # including prediction_schema_uri, is a string URI (it was previously
  # mis-declared as bool).
  outputs = NamedTuple(
      'TrainingArtifacts',
      image_uri=str,
      artifact_uri=str,
      prediction_schema_uri=str,
      instance_schema_uri=str,
  )
  return outputs(
      f'{docker_region}-docker.pkg.dev/vertex-ai/starryn/predictor:20240723_0542_RC00',  # pylint: disable=too-many-function-args
      private_dir,  # pylint: disable=too-many-function-args
      os.path.join(private_dir, 'predict_schema.yaml'),  # pylint: disable=too-many-function-args
      os.path.join(private_dir, 'instance_schema.yaml'),  # pylint: disable=too-many-function-args
  )
|
google_cloud_pipeline_components/_implementation/starry_net/maybe_set_tfrecord_args/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|