data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from data_designer.config.column_configs import ValidationColumnConfig
|
|
9
|
+
from data_designer.config.errors import InvalidConfigError
|
|
10
|
+
from data_designer.config.utils.code_lang import SQL_DIALECTS, CodeLang
|
|
11
|
+
from data_designer.config.validator_params import (
|
|
12
|
+
ValidatorParamsT,
|
|
13
|
+
ValidatorType,
|
|
14
|
+
)
|
|
15
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
16
|
+
ColumnGenerator,
|
|
17
|
+
GenerationStrategy,
|
|
18
|
+
GeneratorMetadata,
|
|
19
|
+
)
|
|
20
|
+
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
|
|
21
|
+
from data_designer.engine.errors import DataDesignerRuntimeError
|
|
22
|
+
from data_designer.engine.validators import (
|
|
23
|
+
BaseValidator,
|
|
24
|
+
LocalCallableValidator,
|
|
25
|
+
PythonValidator,
|
|
26
|
+
RemoteValidator,
|
|
27
|
+
SQLValidator,
|
|
28
|
+
ValidationResult,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_validator_from_params(validator_type: ValidatorType, validator_params: ValidatorParamsT) -> BaseValidator:
    """Instantiate the validator implementation that matches ``validator_type``.

    Args:
        validator_type: Kind of validator requested by the validation column config.
        validator_params: Parameters forwarded unchanged to the validator constructor. For
            ``ValidatorType.CODE`` the ``code_lang`` attribute selects the implementation.

    Returns:
        A ``PythonValidator`` or ``SQLValidator`` for code validation, a ``RemoteValidator`` for
        remote validation, and a ``LocalCallableValidator`` for any other validator type.

    Raises:
        InvalidConfigError: If a code validator is requested for a language that has no
            validator implementation (previously this silently returned ``None``).
    """
    if validator_type == ValidatorType.CODE:
        if validator_params.code_lang == CodeLang.PYTHON:
            return PythonValidator(validator_params)
        if validator_params.code_lang in SQL_DIALECTS:
            return SQLValidator(validator_params)
        # Fail loudly here instead of returning None and crashing later on `run_validation`.
        raise InvalidConfigError(
            f"Unsupported code language {validator_params.code_lang!r} for code validation; "
            "only Python and the supported SQL dialects can be validated."
        )
    if validator_type == ValidatorType.REMOTE:
        return RemoteValidator(validator_params)
    return LocalCallableValidator(validator_params)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ValidationColumnGenerator(ColumnGenerator[ValidationColumnConfig]):
    """Full-column generator that validates the configured target columns of a dataset.

    Every record's validation output is serialized with ``model_dump(mode="json")`` and stored
    in a new column named after the validation column config.
    """

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="validate",
            description="Validate data.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=None,
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Validate the target columns of ``data`` and append the results as a new column.

        Args:
            data: Dataset that must contain every column in ``self.config.target_columns``.

        Returns:
            A DataFrame with ``data``'s columns plus a ``self.config.name`` column holding one
            validation-result dict per record, aligned to ``data``'s index. When a code
            validator checks several target columns, each cell is a dict keyed by target
            column name whose values are that column's validation result.

        Raises:
            InvalidConfigError: If any target column is missing from ``data``.
        """
        logger.info(f"🔍 Validating column {self.config.name!r} with {len(data)} records")
        logger.info(f" |-- target columns: {self.config.target_columns}")
        logger.info(f" |-- validator type: {self.config.validator_type}")
        logger.info(f" |-- validator params: {self.config.validator_params}")
        logger.info(f" |-- batch size: {self.config.batch_size}")

        validator = get_validator_from_params(self.config.validator_type, self.config.validator_params)

        # Check if the target columns are present in the dataset
        missing_columns = set(self.config.target_columns) - set(data.columns)
        if missing_columns:
            raise InvalidConfigError(
                f"Target columns {missing_columns} defined in validation column {self.config.name!r} are missing in dataset"
            )

        # Code validators expect single-column input, so with several target columns each column
        # is validated on its own and the per-column results are merged into one dict per record.
        validate_columns_separately = (
            self.config.validator_type == ValidatorType.CODE and len(self.config.target_columns) > 1
        )
        if validate_columns_separately:
            columns_to_validate = [[col] for col in self.config.target_columns]
        else:
            columns_to_validate = [self.config.target_columns]

        outputs_as_dicts = None
        for cols in columns_to_validate:
            # Filter the dataset to only include the target columns, and convert to a list of dictionaries
            records = data[cols].to_dict(orient="records")
            serialized_outputs = [
                output.model_dump(mode="json") for output in self._validate_records(validator, records)
            ]

            if not validate_columns_separately:
                outputs_as_dicts = serialized_outputs
            elif outputs_as_dicts is None:
                outputs_as_dicts = [{cols[0]: output} for output in serialized_outputs]
            else:
                # Pair each record's accumulated dict with *its own* output for this column
                # (the previous code assigned the first record's output to every record).
                for dict_output, output in zip(outputs_as_dicts, serialized_outputs, strict=True):
                    dict_output[cols[0]] = output

        # Reuse data's index so the column-wise concat aligns results row-by-row even when the
        # incoming frame does not carry a default RangeIndex.
        validation_results = pd.DataFrame({self.config.name: outputs_as_dicts}, index=data.index)
        return pd.concat([data, validation_results], axis=1)

    def _validate_records(self, validator: BaseValidator, records: list[dict]) -> list:
        """Validate ``records`` in batches of ``config.batch_size``; return per-record outputs in order.

        Batches run in parallel only for remote validators configured with
        ``max_parallel_requests > 1``; otherwise they run sequentially.
        """
        batch_size = self.config.batch_size
        batched_records = [
            records[batch_start : batch_start + batch_size] for batch_start in range(0, len(records), batch_size)
        ]

        # Run validation in parallel or sequentially, depending on the validator type and parameters
        if (
            self.config.validator_type == ValidatorType.REMOTE
            and self.config.validator_params.max_parallel_requests > 1
        ):
            return self._validate_in_parallel(validator, batched_records)

        outputs = []
        for batch in batched_records:
            outputs.extend(self._validate_batch(validator, batch).data)
        return outputs

    def _validate_in_parallel(self, validator: BaseValidator, batched_records: list[list[dict]]) -> list:
        """Run validation in parallel and return the per-record outputs in input order.

        A batch whose validation raises is replaced by ``ValidationResult.empty`` of the same
        size, so one failing batch does not abort the whole column.

        Raises:
            DataDesignerRuntimeError: If any batch ends up with neither a result nor a fallback.
        """
        outputs: list[ValidationResult | None] = [None] * len(batched_records)

        def result_callback(result: ValidationResult, context: dict):
            outputs[context["index"]] = result

        def error_callback(error: Exception, context: dict):
            outputs[context["index"]] = ValidationResult.empty(size=len(batched_records[context["index"]]))

        with ConcurrentThreadExecutor(
            max_workers=self.config.validator_params.max_parallel_requests,
            column_name=self.config.name,
            result_callback=result_callback,
            error_callback=error_callback,
        ) as executor:
            for i, batch in enumerate(batched_records):
                executor.submit(lambda batch: self._validate_batch(validator, batch), batch, context={"index": i})

        if any(output is None for output in outputs):
            raise DataDesignerRuntimeError("Validation task failed due to an unexpected error in parallel execution")

        # Flatten the per-batch outputs in one linear pass (sum(lists, []) is quadratic).
        return [record_output for output in outputs for record_output in output.data]

    def _validate_batch(self, validator: BaseValidator, batch: list[dict]) -> ValidationResult:
        """Validate a single batch, logging (and re-raising) any error the validator raises."""
        try:
            return validator.run_validation(batch)
        except Exception as e:
            error_to_display = str(e).replace("\n", "\n ")  # add spaces to improve readability
            logger.error(f"Batch could not be validated:\n {error_to_display}")
            raise
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.config.base import ConfigBase
|
|
5
|
+
from data_designer.config.column_configs import (
|
|
6
|
+
ExpressionColumnConfig,
|
|
7
|
+
LLMCodeColumnConfig,
|
|
8
|
+
LLMJudgeColumnConfig,
|
|
9
|
+
LLMStructuredColumnConfig,
|
|
10
|
+
LLMTextColumnConfig,
|
|
11
|
+
ValidationColumnConfig,
|
|
12
|
+
)
|
|
13
|
+
from data_designer.config.column_types import DataDesignerColumnType
|
|
14
|
+
from data_designer.engine.column_generators.generators.base import ColumnGenerator
|
|
15
|
+
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
|
|
16
|
+
from data_designer.engine.column_generators.generators.llm_generators import (
|
|
17
|
+
LLMCodeCellGenerator,
|
|
18
|
+
LLMJudgeCellGenerator,
|
|
19
|
+
LLMStructuredCellGenerator,
|
|
20
|
+
LLMTextCellGenerator,
|
|
21
|
+
)
|
|
22
|
+
from data_designer.engine.column_generators.generators.samplers import SamplerColumnGenerator
|
|
23
|
+
from data_designer.engine.column_generators.generators.seed_dataset import SeedDatasetColumnGenerator
|
|
24
|
+
from data_designer.engine.column_generators.generators.validation import ValidationColumnGenerator
|
|
25
|
+
from data_designer.engine.dataset_builders.multi_column_configs import (
|
|
26
|
+
SamplerMultiColumnConfig,
|
|
27
|
+
SeedDatasetMultiColumnConfig,
|
|
28
|
+
)
|
|
29
|
+
from data_designer.engine.registry.base import TaskRegistry
|
|
30
|
+
from data_designer.plugins.plugin import PluginType
|
|
31
|
+
from data_designer.plugins.registry import PluginRegistry
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ColumnGeneratorRegistry(TaskRegistry[DataDesignerColumnType, ColumnGenerator, ConfigBase]):
    """Registry mapping each column type to its column generator class and its config class."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create_default_column_generator_registry(with_plugins: bool = True) -> ColumnGeneratorRegistry:
    """Build a column generator registry pre-populated with the built-in generators.

    Args:
        with_plugins: When true, additionally register every column-generator plugin reported
            by ``PluginRegistry``, keyed by ``DataDesignerColumnType(plugin.name)``.

    Returns:
        The populated ``ColumnGeneratorRegistry``.
    """
    registry = ColumnGeneratorRegistry()

    # (column type, generator class, config class) — registered in this order.
    builtin_generators = (
        (DataDesignerColumnType.LLM_TEXT, LLMTextCellGenerator, LLMTextColumnConfig),
        (DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig),
        (DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig),
        (DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig),
        (DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig),
        (DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig),
        (DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig),
        (DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig),
    )
    for column_type, generator_cls, config_cls in builtin_generators:
        registry.register(column_type, generator_cls, config_cls)

    if not with_plugins:
        return registry

    for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
        registry.register(DataDesignerColumnType(plugin.name), plugin.task_cls, plugin.config_cls)
    return registry
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.engine.errors import DataDesignerError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PromptTemplateRenderError(DataDesignerError):
    """Raised when a prompt template cannot be prepared or rendered for a record."""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ExpressionTemplateRenderError(DataDesignerError):
    """Error type for failures while rendering an expression template."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SeedDatasetError(DataDesignerError):
    """Error type for problems related to seed datasets."""
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Type
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
8
|
+
|
|
9
|
+
from data_designer.config.column_configs import Score
|
|
10
|
+
|
|
11
|
+
# One bullet line per score option: "* <score>: <description>".
SCORING_FORMAT: str = "* {score}: {description}"
# Header naming the score enum, followed by the newline-joined SCORING_FORMAT lines.
SCORE_FIELD_DESCRIPTION_FORMAT: str = "Score Descriptions for {enum_name}:\n{scoring}"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseJudgeResponse(BaseModel):
    """Base model for all rubrics."""

    # Store enum-typed fields (e.g. the dynamically created `score` enum that
    # create_judge_response_model adds to subclasses) as their raw values.
    model_config = ConfigDict(use_enum_values=True)
    # Required (`...`) free-text justification accompanying the score.
    reasoning: str = Field(..., description="Reasoning for the assigned score.")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
    """Render score options as a bulleted text block headed by the enum's class name.

    Args:
        options: Mapping from score value to its description, rendered in iteration order.
        enum_type: Enum whose ``__name__`` labels the block.

    Returns:
        ``SCORE_FIELD_DESCRIPTION_FORMAT`` filled with the enum name and one
        ``SCORING_FORMAT`` line per option.
    """
    bullet_lines = (SCORING_FORMAT.format(score=value, description=text) for value, text in options.items())
    scoring_block = "\n".join(bullet_lines)
    return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=scoring_block)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
    """Build a pydantic response model for a single judge score.

    The model is named ``score.name``, extends ``BaseJudgeResponse`` (so it also carries
    ``reasoning``), and adds a required ``score`` field typed by an enum whose members are
    ``VALUE_<option>`` for each key of ``score.options``.

    Args:
        score: Score definition providing ``name``, ``options`` and an optional ``description``.

    Returns:
        The dynamically created response model class.
    """
    scale_enum = Enum(
        f"{score.name}Enum",
        {f"VALUE_{option}": option for option in score.options.keys()},
    )
    score_field_description = _stringify_scoring(score.options, enum_type=scale_enum)

    return create_model(
        score.name,
        # An empty description falls back to None, i.e. no docstring.
        __doc__=score.description or None,
        __base__=BaseJudgeResponse,
        score=(scale_enum, Field(..., description=score_field_description)),
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def create_judge_structured_output_model(
    judge_responses: list[Type[BaseJudgeResponse]],
) -> Type[BaseModel]:
    """Build the ``JudgeStructuredOutput`` model that bundles several judge responses.

    Each response model becomes a required field named after the lowercased class name.

    Args:
        judge_responses: Response models, e.g. as produced by ``create_judge_response_model``.

    Returns:
        The dynamically created ``JudgeStructuredOutput`` model class.
    """
    response_names = [response.__name__ for response in judge_responses]
    response_fields = {response.__name__.lower(): (response, ...) for response in judge_responses}
    return create_model(
        "JudgeStructuredOutput",
        __doc__=f"Response schema for scores with the following names: {response_names}.",
        __base__=BaseModel,
        **response_fields,
    )
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from data_designer.config.column_configs import SingleColumnConfig
|
|
8
|
+
from data_designer.config.column_types import DataDesignerColumnType
|
|
9
|
+
from data_designer.config.models import ModelConfig
|
|
10
|
+
from data_designer.config.utils.code_lang import CodeLang
|
|
11
|
+
from data_designer.config.utils.misc import get_prompt_template_keywords
|
|
12
|
+
from data_designer.config.utils.type_helpers import StrEnum
|
|
13
|
+
from data_designer.engine.column_generators.utils.errors import PromptTemplateRenderError
|
|
14
|
+
from data_designer.engine.column_generators.utils.judge_score_factory import (
|
|
15
|
+
create_judge_response_model,
|
|
16
|
+
create_judge_structured_output_model,
|
|
17
|
+
)
|
|
18
|
+
from data_designer.engine.models.recipes.base import ResponseRecipe
|
|
19
|
+
from data_designer.engine.models.recipes.response_recipes import (
|
|
20
|
+
CodeResponseRecipe,
|
|
21
|
+
PydanticResponseRecipe,
|
|
22
|
+
StructuredResponseRecipe,
|
|
23
|
+
TextResponseRecipe,
|
|
24
|
+
)
|
|
25
|
+
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
|
|
26
|
+
from data_designer.engine.processing.ginja.exceptions import UserTemplateError, UserTemplateUnsupportedFiltersError
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PromptType(StrEnum):
    """Which prompt of an LLM request is being rendered."""

    # The enum value is also used as the Jinja2 template name when
    # RecordBasedPromptRenderer prepares the template.
    SYSTEM_PROMPT = "system_prompt"
    USER_PROMPT = "user_prompt"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RecordBasedPromptRenderer(WithJinja2UserTemplateRendering):
    """Render system/user prompt templates against a single record and apply a response recipe.

    Args:
        response_recipe: Recipe whose ``apply_recipe_to_user_prompt`` /
            ``apply_recipe_to_system_prompt`` post-processes the rendered prompt.
        error_message_context: Optional extra context, JSON-dumped into template error messages.
    """

    def __init__(self, response_recipe: ResponseRecipe, *, error_message_context: dict[str, str] | None = None):
        self.response_recipe = response_recipe
        self._error_message_context = error_message_context

    def render(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> str | None:
        """Render ``prompt_template`` with ``record`` and pass the result through the response recipe.

        An empty or missing template renders as ``""`` before the recipe is applied.

        Raises:
            PromptTemplateRenderError: If the template cannot be prepared.
        """
        self._prepare_environment(prompt_template=prompt_template, record=record, prompt_type=prompt_type)
        rendered_prompt = self.render_multi_template(prompt_type, record) if prompt_template else ""
        recipe_applicator = (
            self.response_recipe.apply_recipe_to_user_prompt
            if prompt_type == PromptType.USER_PROMPT
            else self.response_recipe.apply_recipe_to_system_prompt
        )
        return recipe_applicator(rendered_prompt)

    def _prepare_environment(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> None:
        """Prepare the template named ``prompt_type.value``, using the record's keys as dataset variables.

        Raises:
            PromptTemplateRenderError: Wrapping any ``UserTemplateError`` /
                ``UserTemplateUnsupportedFiltersError``, with a hint listing template variables
                that are absent from ``record``.
        """
        try:
            self.prepare_jinja2_multi_template_renderer(
                template_name=prompt_type.value,
                prompt_template=prompt_template,
                dataset_variables=list(record.keys()),
            )
        except (UserTemplateUnsupportedFiltersError, UserTemplateError) as exc:
            template_variables = get_prompt_template_keywords(prompt_template)
            # Sorted so the error message is deterministic (set order is arbitrary).
            missing_columns = sorted(set(template_variables) - set(record.keys()))

            error_msg = (
                f"There was an error preparing the {prompt_type.value.replace('_', ' ')} "
                "template. Please double check that the template is valid Jinja2 syntax, that all "
                "referenced variables are defined, and that any filters you are using are supported."
            )
            if missing_columns:
                error_msg += f"\nThe following {missing_columns} columns are missing!"
            if self._error_message_context is not None:
                error_msg += f"\n{json.dumps(self._error_message_context, indent=2)}"
            logger.error(f"🛑 {error_msg}")
            # Chain explicitly so the original template error is reported as the direct cause.
            raise PromptTemplateRenderError(f"{exc!s} {error_msg}") from exc
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def create_response_recipe(
|
|
76
|
+
column_config: SingleColumnConfig, model_config: ModelConfig | None = None
|
|
77
|
+
) -> ResponseRecipe:
|
|
78
|
+
if model_config and column_config.model_alias != model_config.alias:
|
|
79
|
+
raise ValueError(
|
|
80
|
+
f"Column config model alias {column_config.model_alias} does not match model config alias {model_config.alias}"
|
|
81
|
+
)
|
|
82
|
+
if column_config.column_type == DataDesignerColumnType.LLM_TEXT:
|
|
83
|
+
return TextResponseRecipe()
|
|
84
|
+
if column_config.column_type == DataDesignerColumnType.LLM_CODE:
|
|
85
|
+
return CodeResponseRecipe(
|
|
86
|
+
syntax=CodeLang.parse_lang(column_config.code_lang),
|
|
87
|
+
)
|
|
88
|
+
if column_config.column_type == DataDesignerColumnType.LLM_STRUCTURED:
|
|
89
|
+
return StructuredResponseRecipe(
|
|
90
|
+
json_schema=column_config.output_format,
|
|
91
|
+
)
|
|
92
|
+
if column_config.column_type == DataDesignerColumnType.LLM_JUDGE:
|
|
93
|
+
return PydanticResponseRecipe(
|
|
94
|
+
data_type=create_judge_structured_output_model(
|
|
95
|
+
[create_judge_response_model(s) for s in column_config.scores]
|
|
96
|
+
),
|
|
97
|
+
)
|
|
98
|
+
raise ValueError(f"No response recipe found for column type: {column_config.column_type}")
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Generic, Type, TypeVar, get_origin
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from data_designer.config.base import ConfigBase
|
|
11
|
+
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
|
|
12
|
+
from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType
|
|
13
|
+
|
|
14
|
+
# Type variable constrained to a single record (dict) or a pandas DataFrame.
DataT = TypeVar("DataT", dict, pd.DataFrame)
# Config type a ConfigurableTask is parametrized with. The TypeVar's name now matches the
# variable it is bound to (it was "ConfigT"), as type checkers require.
TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ConfigurableTaskMetadata(ConfigBase):
    # Task identifier; ConfigurableTask embeds it in its error messages.
    name: str
    # Short human-readable summary of what the task does.
    description: str
    # Resources that must be non-None on the ResourceProvider when the task is constructed;
    # None is treated the same as an empty list.
    required_resources: list[ResourceType] | None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ConfigurableTask(ABC, Generic[TaskConfigT]):
    """Base class for tasks parametrized by a ``ConfigBase`` subclass.

    On construction the given config is validated against the type argument the subclass was
    parametrized with (see ``get_config_type``). Then the required resources are checked, and
    the ``_validate`` and ``_initialize`` hooks run, in that order.

    Args:
        config: Task configuration; passed through ``model_validate`` of the config type.
        resource_provider: Provider of shared resources, or ``None`` if the task needs none.
    """

    def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider | None):
        self._config = self.get_config_type().model_validate(config)
        self._resource_provider = resource_provider
        self._validate_resources()
        self._validate()
        self._initialize()

    @classmethod
    def get_config_type(cls) -> Type[TaskConfigT]:
        """Return the ``ConfigBase`` subclass this task was parametrized with.

        Classes are searched in MRO order, so a subclass that inherits from an
        already-parametrized task without restating the type argument resolves to its nearest
        parametrized ancestor.

        Raises:
            TypeError: If no generic base in the hierarchy has a ``ConfigBase`` type argument.
        """
        for klass in cls.__mro__:
            # Only look at bases declared on this exact class, not inherited attribute values.
            for base in vars(klass).get("__orig_bases__", ()):
                type_args = getattr(base, "__args__", None)
                if not type_args or len(type_args) != 1:
                    continue
                arg = type_args[0]
                # Unwrap parametrized generics (e.g. Config[X]) to their origin class.
                origin = get_origin(arg) or arg
                if isinstance(origin, type) and issubclass(origin, ConfigBase):
                    return arg
        raise TypeError(
            f"Could not determine config type for `{cls.__name__}`. Please ensure that the "
            "`ConfigurableTask` is defined with a generic type argument, where the type argument "
            "is a subclass of `ConfigBase`."
        )

    @property
    def artifact_path(self) -> Path:
        """Artifact path of the resource provider's artifact storage."""
        return self.artifact_storage.artifact_path

    @property
    def artifact_storage(self) -> ArtifactStorage:
        """Artifact storage exposed by the resource provider."""
        return self.resource_provider.artifact_storage

    @property
    def base_dataset_path(self) -> Path:
        """Base dataset path of the resource provider's artifact storage."""
        return self.artifact_storage.base_dataset_path

    @property
    def config(self) -> TaskConfigT:
        """The validated task config."""
        return self._config

    @property
    def resource_provider(self) -> ResourceProvider:
        """The resource provider.

        Raises:
            ValueError: If the task was constructed without a resource provider.
        """
        if self._resource_provider is None:
            raise ValueError(f"No resource provider provided for the `{self.metadata().name}` task.")
        return self._resource_provider

    @staticmethod
    @abstractmethod
    def metadata() -> ConfigurableTaskMetadata: ...

    def _initialize(self) -> None:
        """An internal method for custom initialization logic, which will be called in the constructor."""

    def _validate(self) -> None:
        """An internal method for custom validation logic, which will be called in the constructor."""

    def _validate_resources(self) -> None:
        """Check that every resource in ``metadata().required_resources`` is set on the provider.

        Raises:
            ValueError: If resources are required but no provider was given, or if a required
                resource attribute on the provider is ``None``.
        """
        for resource in self.metadata().required_resources or []:
            if resource is None:
                continue
            if getattr(self.resource_provider, ResourceType(resource).value) is None:
                raise ValueError(f"Resource {resource} is required for the `{self.metadata().name}` task.")
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
import shutil
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from pydantic import BaseModel, field_validator, model_validator
|
|
12
|
+
|
|
13
|
+
from data_designer.config.utils.io_helpers import read_parquet_dataset
|
|
14
|
+
from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
|
|
15
|
+
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)

# File-name template for one batch, e.g. batch_number=3 -> "batch_00003.parquet".
# Zero-padding to five digits keeps lexicographic file order equal to numeric batch
# order for batch numbers below 100000.
BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BatchStage(StrEnum):
    """Stage of a batch file; selects which ArtifactStorage directory the file lives in."""

    # Each value is the name of the ArtifactStorage property returning that stage's
    # directory (resolved with getattr in ArtifactStorage._get_stage_path), so the
    # values must stay in sync with those property names.
    PARTIAL_RESULT = "partial_results_path"
    FINAL_RESULT = "final_dataset_path"
    DROPPED_COLUMNS = "dropped_columns_dataset_path"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ArtifactStorage(BaseModel):
    """Filesystem layout and read/write helpers for a dataset builder's artifacts.

    Everything is stored under ``base_dataset_path`` (``artifact_path / dataset_name``):

    - ``partial_results_folder_name``: batch parquet files that
      :meth:`move_partial_result_to_final_file_path` later moves to the final folder.
    - ``final_dataset_folder_name``: batch parquet files of the final dataset.
    - ``dropped_columns_folder_name``: batch parquet files whose columns
      :meth:`load_dataset_with_dropped_columns` re-attaches to the final dataset.
    - ``metadata.json``: written by :meth:`write_metadata`.

    ``artifact_path`` must already exist as a directory and is coerced to ``Path``
    during validation.
    """

    artifact_path: Path | str
    dataset_name: str = "dataset"
    final_dataset_folder_name: str = "parquet-files"
    partial_results_folder_name: str = "tmp-partial-parquet-files"
    dropped_columns_folder_name: str = "dropped-columns-parquet-files"

    @property
    def artifact_path_exists(self) -> bool:
        """Whether ``artifact_path`` currently exists on disk."""
        return self.artifact_path.exists()

    @property
    def base_dataset_path(self) -> Path:
        """Root directory for this dataset: ``artifact_path / dataset_name``."""
        return self.artifact_path / self.dataset_name

    @property
    def dropped_columns_dataset_path(self) -> Path:
        """Directory holding batch files of dropped columns."""
        return self.base_dataset_path / self.dropped_columns_folder_name

    @property
    def final_dataset_path(self) -> Path:
        """Directory holding batch files of the final dataset."""
        return self.base_dataset_path / self.final_dataset_folder_name

    @property
    def metadata_file_path(self) -> Path:
        """Path of the JSON file written by :meth:`write_metadata`."""
        return self.base_dataset_path / "metadata.json"

    @property
    def partial_results_path(self) -> Path:
        """Directory holding batch files not yet moved to the final dataset."""
        return self.base_dataset_path / self.partial_results_folder_name

    @field_validator("artifact_path")
    @classmethod
    def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
        """Coerce ``artifact_path`` to ``Path`` and require an existing directory.

        Raises:
            ArtifactStorageError: If the path does not exist or is not a directory.
        """
        v = Path(v)
        if not v.is_dir():
            raise ArtifactStorageError("Artifact path must exist and be a directory")
        return v

    @model_validator(mode="after")
    def validate_folder_names(self):
        """Validate the dataset name and the three stage folder names.

        Raises:
            ArtifactStorageError: If a name is empty, two names are equal, a name contains
                a disallowed character, or a name is ``"."`` or ``".."``.
        """
        folder_names = [
            self.dataset_name,
            self.final_dataset_folder_name,
            self.partial_results_folder_name,
            self.dropped_columns_folder_name,
        ]

        for name in folder_names:
            if not name:
                raise ArtifactStorageError("🛑 Directory names must be non-empty strings.")

        if len(set(folder_names)) != len(folder_names):
            raise ArtifactStorageError("🛑 Folder names must be unique (no collisions allowed).")

        invalid_chars = {"<", ">", ":", '"', "/", "\\", "|", "?", "*"}
        for name in folder_names:
            if any(char in invalid_chars for char in name):
                raise ArtifactStorageError(f"🛑 Directory name '{name}' contains invalid characters.")
            # "." and ".." pass the character check but refer to the current/parent directory,
            # which would make stage paths alias `base_dataset_path` or escape `artifact_path`.
            if name in {".", ".."}:
                raise ArtifactStorageError(f"🛑 Directory name '{name}' is not allowed.")

        return self

    @staticmethod
    def mkdir_if_needed(path: Path | str) -> Path:
        """Create the directory (and parents) if it does not exist; return it as ``Path``."""
        path = Path(path)
        if not path.exists():
            logger.debug(f"📁 Creating directory: {path}")
            path.mkdir(parents=True, exist_ok=True)
        return path

    @staticmethod
    def read_parquet_files(path: Path) -> pd.DataFrame:
        """Load parquet data at ``path`` using ``read_parquet_dataset``."""
        return read_parquet_dataset(path)

    def create_batch_file_path(
        self,
        batch_number: int,
        batch_stage: BatchStage,
    ) -> Path:
        """Return the path of batch ``batch_number`` inside the ``batch_stage`` directory.

        Only builds the path; nothing is created on disk.

        Raises:
            ArtifactStorageError: If ``batch_number`` is negative.
        """
        if batch_number < 0:
            raise ArtifactStorageError("🛑 Batch number must be non-negative.")
        return self._get_stage_path(batch_stage) / BATCH_FILE_NAME_FORMAT.format(batch_number=batch_number)

    def load_dataset(self, batch_stage: BatchStage = BatchStage.FINAL_RESULT) -> pd.DataFrame:
        """Load the data stored in the ``batch_stage`` directory (final dataset by default)."""
        return read_parquet_dataset(self._get_stage_path(batch_stage))

    def load_dataset_with_dropped_columns(self) -> pd.DataFrame:
        """Load the final dataset and, when present, append the dropped columns column-wise.

        Dropped columns are loaded only if batch 0 exists in the dropped-columns folder.

        Raises:
            ArtifactStorageError: If the dropped-columns data has a different row count or
                a different index than the final dataset.
        """
        df = self.load_dataset()
        # `is_file()` is False when the parent folder is missing, so this also covers that case.
        if not self.create_batch_file_path(0, BatchStage.DROPPED_COLUMNS).is_file():
            return df

        logger.debug("Concatenating dropped columns to the final dataset.")
        df_dropped = self.load_dataset(batch_stage=BatchStage.DROPPED_COLUMNS)
        if len(df_dropped) != len(df):
            raise ArtifactStorageError(
                "🛑 The dropped-columns dataset has a different number of rows than the main dataset. "
                "Something unexpected must have happened to the dataset builder's artifacts."
            )
        # Column-wise concat aligns on the index; a mismatch would silently introduce NaNs,
        # so require identical indexes.
        if not df.index.equals(df_dropped.index):
            raise ArtifactStorageError(
                "🛑 The indexes of the main and dropped columns DataFrames are not aligned. "
                "Something unexpected must have happened to the dataset builder's artifacts."
            )
        return pd.concat([df, df_dropped], axis=1)

    def move_partial_result_to_final_file_path(self, batch_number: int) -> Path:
        """Move batch ``batch_number`` from the partial-results folder to the final folder.

        Returns:
            The new path of the batch file in the final dataset folder.

        Raises:
            ArtifactStorageError: If the partial result file does not exist.
        """
        partial_result_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.PARTIAL_RESULT)
        if not partial_result_path.exists():
            raise ArtifactStorageError("🛑 Partial result file not found.")
        self.mkdir_if_needed(self._get_stage_path(BatchStage.FINAL_RESULT))
        final_file_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.FINAL_RESULT)
        shutil.move(partial_result_path, final_file_path)
        return final_file_path

    def write_configs(self, json_file_name: str, configs: list[BaseModel]) -> Path:
        """Write ``configs`` as a JSON list to ``base_dataset_path / json_file_name``.

        Each item is serialized with ``model_dump(mode="json")``, so items must be
        pydantic models.

        Returns:
            The path of the written JSON file.
        """
        self.mkdir_if_needed(self.base_dataset_path)
        file_path = self.base_dataset_path / json_file_name
        with open(file_path, "w", encoding="utf-8") as file:
            json.dump([c.model_dump(mode="json") for c in configs], file, indent=4)
        return file_path

    def write_batch_to_parquet_file(
        self,
        batch_number: int,
        dataframe: pd.DataFrame,
        batch_stage: BatchStage,
    ) -> Path:
        """Write ``dataframe`` as batch ``batch_number`` of ``batch_stage``; return its path."""
        file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
        self.write_parquet_file(file_path.name, dataframe, batch_stage)
        return file_path

    def write_parquet_file(
        self,
        parquet_file_name: str,
        dataframe: pd.DataFrame,
        batch_stage: BatchStage,
    ) -> Path:
        """Write ``dataframe`` (without its index) to ``parquet_file_name`` in the stage folder.

        The stage folder is created if needed. Returns the written file's path.
        """
        stage_path = self.mkdir_if_needed(self._get_stage_path(batch_stage))
        file_path = stage_path / parquet_file_name
        dataframe.to_parquet(file_path, index=False)
        return file_path

    def write_metadata(self, metadata: dict) -> Path:
        """Write ``metadata`` as JSON to :attr:`metadata_file_path` and return that path."""
        self.mkdir_if_needed(self.base_dataset_path)
        with open(self.metadata_file_path, "w", encoding="utf-8") as file:
            json.dump(metadata, file)
        return self.metadata_file_path

    def _get_stage_path(self, stage: BatchStage) -> Path:
        """Return the directory for ``stage``; each BatchStage value names one of the path properties."""
        return getattr(self, resolve_string_enum(stage, BatchStage).value)
|