data-designer-engine 0.4.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/_version.py +34 -0
- data_designer/engine/analysis/column_profilers/base.py +49 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
- data_designer/engine/analysis/column_profilers/registry.py +22 -0
- data_designer/engine/analysis/column_statistics.py +145 -0
- data_designer/engine/analysis/dataset_profiler.py +149 -0
- data_designer/engine/analysis/errors.py +9 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +122 -0
- data_designer/engine/column_generators/generators/embedding.py +35 -0
- data_designer/engine/column_generators/generators/expression.py +55 -0
- data_designer/engine/column_generators/generators/llm_completion.py +116 -0
- data_designer/engine/column_generators/generators/samplers.py +69 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
- data_designer/engine/column_generators/generators/validation.py +140 -0
- data_designer/engine/column_generators/registry.py +60 -0
- data_designer/engine/column_generators/utils/errors.py +15 -0
- data_designer/engine/column_generators/utils/generator_classification.py +43 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
- data_designer/engine/compiler.py +97 -0
- data_designer/engine/configurable_task.py +71 -0
- data_designer/engine/dataset_builders/artifact_storage.py +283 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
- data_designer/engine/dataset_builders/errors.py +15 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
- data_designer/engine/dataset_builders/utils/dag.py +62 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
- data_designer/engine/dataset_builders/utils/errors.py +15 -0
- data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
- data_designer/engine/errors.py +51 -0
- data_designer/engine/model_provider.py +77 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +300 -0
- data_designer/engine/models/facade.py +284 -0
- data_designer/engine/models/factory.py +42 -0
- data_designer/engine/models/litellm_overrides.py +179 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +235 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +62 -0
- data_designer/engine/models/parsers/types.py +84 -0
- data_designer/engine/models/recipes/base.py +81 -0
- data_designer/engine/models/recipes/response_recipes.py +293 -0
- data_designer/engine/models/registry.py +151 -0
- data_designer/engine/models/telemetry.py +362 -0
- data_designer/engine/models/usage.py +73 -0
- data_designer/engine/models/utils.py +101 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +65 -0
- data_designer/engine/processing/ginja/environment.py +463 -0
- data_designer/engine/processing/ginja/exceptions.py +56 -0
- data_designer/engine/processing/ginja/record.py +32 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +15 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
- data_designer/engine/processing/gsonschema/types.py +10 -0
- data_designer/engine/processing/gsonschema/validators.py +202 -0
- data_designer/engine/processing/processors/base.py +13 -0
- data_designer/engine/processing/processors/drop_columns.py +42 -0
- data_designer/engine/processing/processors/registry.py +25 -0
- data_designer/engine/processing/processors/schema_transform.py +71 -0
- data_designer/engine/processing/utils.py +169 -0
- data_designer/engine/registry/base.py +99 -0
- data_designer/engine/registry/data_designer_registry.py +39 -0
- data_designer/engine/registry/errors.py +12 -0
- data_designer/engine/resources/managed_dataset_generator.py +39 -0
- data_designer/engine/resources/managed_dataset_repository.py +197 -0
- data_designer/engine/resources/managed_storage.py +65 -0
- data_designer/engine/resources/resource_provider.py +77 -0
- data_designer/engine/resources/seed_reader.py +154 -0
- data_designer/engine/sampling_gen/column.py +91 -0
- data_designer/engine/sampling_gen/constraints.py +100 -0
- data_designer/engine/sampling_gen/data_sources/base.py +217 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
- data_designer/engine/sampling_gen/entities/errors.py +10 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
- data_designer/engine/sampling_gen/entities/person.py +144 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
- data_designer/engine/sampling_gen/errors.py +26 -0
- data_designer/engine/sampling_gen/generator.py +122 -0
- data_designer/engine/sampling_gen/jinja_utils.py +64 -0
- data_designer/engine/sampling_gen/people_gen.py +199 -0
- data_designer/engine/sampling_gen/person_constants.py +56 -0
- data_designer/engine/sampling_gen/schema.py +147 -0
- data_designer/engine/sampling_gen/schema_builder.py +61 -0
- data_designer/engine/sampling_gen/utils.py +46 -0
- data_designer/engine/secret_resolver.py +82 -0
- data_designer/engine/testing/__init__.py +12 -0
- data_designer/engine/testing/stubs.py +133 -0
- data_designer/engine/testing/utils.py +20 -0
- data_designer/engine/validation.py +367 -0
- data_designer/engine/validators/__init__.py +19 -0
- data_designer/engine/validators/base.py +38 -0
- data_designer/engine/validators/local_callable.py +39 -0
- data_designer/engine/validators/python.py +254 -0
- data_designer/engine/validators/remote.py +89 -0
- data_designer/engine/validators/sql.py +65 -0
- data_designer_engine-0.4.0.dist-info/METADATA +50 -0
- data_designer_engine-0.4.0.dist-info/RECORD +114 -0
- data_designer_engine-0.4.0.dist-info/WHEEL +4 -0

data_designer/engine/column_generators/utils/prompt_renderer.py
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import logging

from data_designer.config.column_configs import SingleColumnConfig
from data_designer.config.column_types import DataDesignerColumnType
from data_designer.config.models import ModelConfig
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.utils.misc import extract_keywords_from_jinja2_template
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.engine.column_generators.utils.errors import PromptTemplateRenderError
from data_designer.engine.column_generators.utils.judge_score_factory import (
    create_judge_response_model,
    create_judge_structured_output_model,
)
from data_designer.engine.models.recipes.base import ResponseRecipe
from data_designer.engine.models.recipes.response_recipes import (
    CodeResponseRecipe,
    PydanticResponseRecipe,
    StructuredResponseRecipe,
    TextResponseRecipe,
)
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
from data_designer.engine.processing.ginja.exceptions import UserTemplateError, UserTemplateUnsupportedFiltersError

logger = logging.getLogger(__name__)


class PromptType(StrEnum):
    SYSTEM_PROMPT = "system_prompt"
    USER_PROMPT = "user_prompt"


class RecordBasedPromptRenderer(WithJinja2UserTemplateRendering):
    def __init__(self, response_recipe: ResponseRecipe, *, error_message_context: dict[str, str] | None = None):
        self.response_recipe = response_recipe
        self._error_message_context = error_message_context

    def render(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> str | None:
        self._prepare_environment(prompt_template=prompt_template, record=record, prompt_type=prompt_type)
        rendered_prompt = self.render_multi_template(prompt_type, record) if prompt_template else ""
        recipe_applicator = (
            self.response_recipe.apply_recipe_to_user_prompt
            if prompt_type == PromptType.USER_PROMPT
            else self.response_recipe.apply_recipe_to_system_prompt
        )
        return recipe_applicator(rendered_prompt)

    def _prepare_environment(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> None:
        try:
            self.prepare_jinja2_multi_template_renderer(
                template_name=prompt_type.value,
                prompt_template=prompt_template,
                dataset_variables=list(record.keys()),
            )
        except (UserTemplateUnsupportedFiltersError, UserTemplateError) as exc:
            template_variables = extract_keywords_from_jinja2_template(prompt_template)
            missing_columns = list(set(template_variables) - set(record.keys()))

            error_msg = (
                f"There was an error preparing the {prompt_type.value.replace('_', ' ')} "
                "template. Please double check that the template is valid Jinja2 syntax, that all "
                "referenced variables are defined, and that any filters you are using are supported."
            )
            if len(missing_columns) > 0:
                error_msg += f"\nThe following {missing_columns} columns are missing!"
            if self._error_message_context is not None:
                error_msg += f"\n{json.dumps(self._error_message_context, indent=2)}"
            logger.error(f"🛑 {error_msg}")
            raise PromptTemplateRenderError(f"{exc!s} {error_msg}")


def create_response_recipe(
    column_config: SingleColumnConfig, model_config: ModelConfig | None = None
) -> ResponseRecipe:
    if model_config and column_config.model_alias != model_config.alias:
        raise ValueError(
            f"Column config model alias {column_config.model_alias} does not match model config alias {model_config.alias}"
        )
    if column_config.column_type == DataDesignerColumnType.LLM_TEXT:
        return TextResponseRecipe()
    if column_config.column_type == DataDesignerColumnType.LLM_CODE:
        return CodeResponseRecipe(
            syntax=CodeLang.parse_lang(column_config.code_lang),
        )
    if column_config.column_type == DataDesignerColumnType.LLM_STRUCTURED:
        return StructuredResponseRecipe(
            json_schema=column_config.output_format,
        )
    if column_config.column_type == DataDesignerColumnType.LLM_JUDGE:
        return PydanticResponseRecipe(
            data_type=create_judge_structured_output_model(
                [create_judge_response_model(s) for s in column_config.scores]
            ),
        )
    raise ValueError(f"No response recipe found for column type: {column_config.column_type}")
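
Not part of the wheel contents above: a minimal usage sketch of the prompt renderer, assuming a SingleColumnConfig named `column_config` (of the LLM_TEXT type) built elsewhere; the template and record values below are hypothetical.

# Minimal sketch, not from the package: `column_config` is assumed to be an
# LLM_TEXT SingleColumnConfig constructed elsewhere.
from data_designer.engine.column_generators.utils.prompt_renderer import (
    PromptType,
    RecordBasedPromptRenderer,
    create_response_recipe,
)

recipe = create_response_recipe(column_config)  # recipe chosen by column_type
renderer = RecordBasedPromptRenderer(recipe)

record = {"topic": "unit testing", "language": "Python"}  # hypothetical upstream columns
user_prompt = renderer.render(
    prompt_template="Write a short note about {{ topic }} in {{ language }}.",
    record=record,
    prompt_type=PromptType.USER_PROMPT,
)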

data_designer/engine/compiler.py
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging

from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig
from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.config.errors import InvalidConfigError
from data_designer.config.sampler_params import UUIDSamplerParams
from data_designer.engine.resources.resource_provider import ResourceProvider
from data_designer.engine.resources.seed_reader import SeedReader
from data_designer.engine.validation import ViolationLevel, rich_print_violations, validate_data_designer_config

logger = logging.getLogger(__name__)


def compile_data_designer_config(config: DataDesignerConfig, resource_provider: ResourceProvider) -> DataDesignerConfig:
    _resolve_and_add_seed_columns(config, resource_provider.seed_reader)
    _add_internal_row_id_column_if_needed(config)
    _validate(config)
    return config


def _resolve_and_add_seed_columns(config: DataDesignerConfig, seed_reader: SeedReader | None) -> None:
    """Fetches the seed dataset column names, ensures there are no conflicts
    with other columns, and adds seed column configs to the DataDesignerConfig.
    """

    if not seed_reader:
        return

    seed_col_names = seed_reader.get_column_names()
    existing_columns = {column.name for column in config.columns}
    colliding_columns = {name for name in seed_col_names if name in existing_columns}
    if colliding_columns:
        raise InvalidConfigError(
            f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
            "Please remove the conflicting columns or use a seed dataset with different column names."
        )

    config.columns.extend([SeedDatasetColumnConfig(name=col_name) for col_name in seed_col_names])


def _add_internal_row_id_column_if_needed(config: DataDesignerConfig) -> None:
    """Adds a UUID sampler column named '_internal_row_id' (set to drop) if needed to enable generation.

    Generation requires either:
    - At least one sampler column (which can generate data from scratch), OR
    - A seed dataset (which provides initial data rows)

    If neither exists, a UUID sampler column '_internal_row_id' is automatically added and marked for drop
    to enable the generation process to start.

    Args:
        config: The DataDesigner configuration to potentially modify.
    """
    has_sampler_column = any(isinstance(col, SamplerColumnConfig) for col in config.columns)
    has_seed_dataset_column = any(isinstance(col, SeedDatasetColumnConfig) for col in config.columns)

    if not has_sampler_column and not has_seed_dataset_column:
        logger.warning(
            "🔔 No sampler column or seed dataset detected. Adding UUID column '_internal_row_id' (marked for drop) to enable generation."
        )
        id_column = SamplerColumnConfig(
            name="_internal_row_id",
            sampler_type="uuid",
            params=UUIDSamplerParams(),
            drop=True,
        )
        config.columns.insert(0, id_column)


def _validate(config: DataDesignerConfig) -> None:
    allowed_references = _get_allowed_references(config)
    violations = validate_data_designer_config(
        columns=config.columns,
        processor_configs=config.processors or [],
        allowed_references=allowed_references,
    )
    rich_print_violations(violations)
    if len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
        raise InvalidConfigError(
            "🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
        )
    if len(violations) == 0:
        logger.info("✅ Validation passed")


def _get_allowed_references(config: DataDesignerConfig) -> list[str]:
    refs = set[str]()
    for column_config in config.columns:
        refs.add(column_config.name)
        for side_effect_column in column_config.side_effect_columns:
            refs.add(side_effect_column)
    return list(refs)
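
Not part of the wheel contents above: a minimal sketch of the compile step, assuming a DataDesignerConfig `config` and a ResourceProvider `resource_provider` constructed elsewhere.

# Minimal sketch, not from the package: `config` and `resource_provider` are assumed
# to be built elsewhere.
from data_designer.engine.compiler import compile_data_designer_config

compiled = compile_data_designer_config(config, resource_provider)

# Seed dataset columns (if any) are appended to compiled.columns, and a
# '_internal_row_id' UUID sampler column (marked for drop) is prepended when the
# config has neither a sampler column nor a seed dataset.
print([column.name for column in compiled.columns])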

data_designer/engine/configurable_task.py
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC
from pathlib import Path
from typing import TYPE_CHECKING, Generic, TypeVar, get_origin

from data_designer.config.base import ConfigBase
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
from data_designer.engine.resources.resource_provider import ResourceProvider
from data_designer.lazy_heavy_imports import pd

if TYPE_CHECKING:
    import pandas as pd

DataT = TypeVar("DataT", dict, pd.DataFrame)
TaskConfigT = TypeVar("ConfigT", bound=ConfigBase)


class ConfigurableTask(ABC, Generic[TaskConfigT]):
    def __init__(self, config: TaskConfigT, resource_provider: ResourceProvider):
        self._config = self.get_config_type().model_validate(config)
        self._resource_provider = resource_provider
        self._validate()
        self._initialize()

    @classmethod
    def get_config_type(cls) -> type[TaskConfigT]:
        for base in cls.__orig_bases__:
            if hasattr(base, "__args__") and len(base.__args__) == 1:
                arg = base.__args__[0]
                origin = get_origin(arg) or arg
                if isinstance(origin, type) and issubclass(origin, ConfigBase):
                    return base.__args__[0]
        raise TypeError(
            f"Could not determine config type for `{cls.__name__}`. Please ensure that the "
            "`ConfigurableTask` is defined with a generic type argument, where the type argument "
            "is a subclass of `ConfigBase`."
        )

    @property
    def artifact_path(self) -> Path:
        return self.artifact_storage.artifact_path

    @property
    def artifact_storage(self) -> ArtifactStorage:
        return self.resource_provider.artifact_storage

    @property
    def base_dataset_path(self) -> Path:
        return self.artifact_storage.base_dataset_path

    @property
    def config(self) -> TaskConfigT:
        return self._config

    @property
    def name(self) -> str:
        return self.__class__.__name__

    @property
    def resource_provider(self) -> ResourceProvider:
        return self._resource_provider

    def _initialize(self) -> None:
        """An internal method for custom initialization logic, which will be called in the constructor."""

    def _validate(self) -> None:
        """An internal method for custom validation logic, which will be called in the constructor."""
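
Not part of the wheel contents above: a sketch of how a task subclass binds its config type, assuming ConfigBase behaves like a pydantic model (the base constructor calls model_validate on it); the task and config names are hypothetical.

# Minimal sketch, not from the package: ReportTaskConfig and ReportTask are hypothetical.
from data_designer.config.base import ConfigBase
from data_designer.engine.configurable_task import ConfigurableTask


class ReportTaskConfig(ConfigBase):  # assumes ConfigBase is a pydantic-style model
    title: str = "report"


class ReportTask(ConfigurableTask[ReportTaskConfig]):
    def _validate(self) -> None:
        if not self.config.title:
            raise ValueError("title must be non-empty")


# ReportTask.get_config_type() resolves ReportTaskConfig from __orig_bases__,
# so __init__ can call ReportTaskConfig.model_validate(config).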

data_designer/engine/dataset_builders/artifact_storage.py
@@ -0,0 +1,283 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import logging
import shutil
from datetime import datetime
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic import BaseModel, field_validator, model_validator

from data_designer.config.utils.io_helpers import read_parquet_dataset
from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
from data_designer.lazy_heavy_imports import pd

if TYPE_CHECKING:
    import pandas as pd

logger = logging.getLogger(__name__)

BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
SDG_CONFIG_FILENAME = "sdg.json"


class BatchStage(StrEnum):
    PARTIAL_RESULT = "partial_results_path"
    FINAL_RESULT = "final_dataset_path"
    DROPPED_COLUMNS = "dropped_columns_dataset_path"
    PROCESSORS_OUTPUTS = "processors_outputs_path"


class ArtifactStorage(BaseModel):
    artifact_path: Path | str
    dataset_name: str = "dataset"
    final_dataset_folder_name: str = "parquet-files"
    partial_results_folder_name: str = "tmp-partial-parquet-files"
    dropped_columns_folder_name: str = "dropped-columns-parquet-files"
    processors_outputs_folder_name: str = "processors-files"

    @property
    def artifact_path_exists(self) -> bool:
        return self.artifact_path.exists()

    @cached_property
    def resolved_dataset_name(self) -> str:
        dataset_path = self.artifact_path / self.dataset_name
        if dataset_path.exists() and len(list(dataset_path.iterdir())) > 0:
            new_dataset_name = f"{self.dataset_name}_{datetime.now().strftime('%m-%d-%Y_%H%M%S')}"
            logger.info(
                f"📂 Dataset path {str(dataset_path)!r} already exists. Dataset from this session"
                f"\n\t\t will be saved to {str(self.artifact_path / new_dataset_name)!r} instead."
            )
            return new_dataset_name
        return self.dataset_name

    @property
    def base_dataset_path(self) -> Path:
        return self.artifact_path / self.resolved_dataset_name

    @property
    def dropped_columns_dataset_path(self) -> Path:
        return self.base_dataset_path / self.dropped_columns_folder_name

    @property
    def final_dataset_path(self) -> Path:
        return self.base_dataset_path / self.final_dataset_folder_name

    @property
    def metadata_file_path(self) -> Path:
        return self.base_dataset_path / "metadata.json"

    @property
    def partial_results_path(self) -> Path:
        return self.base_dataset_path / self.partial_results_folder_name

    @property
    def processors_outputs_path(self) -> Path:
        return self.base_dataset_path / self.processors_outputs_folder_name

    @field_validator("artifact_path")
    def validate_artifact_path(cls, v: Path | str) -> Path:
        v = Path(v)
        if not v.is_dir():
            raise ArtifactStorageError("Artifact path must exist and be a directory")
        return v

    @model_validator(mode="after")
    def validate_folder_names(self):
        folder_names = [
            self.dataset_name,
            self.final_dataset_folder_name,
            self.partial_results_folder_name,
            self.dropped_columns_folder_name,
            self.processors_outputs_folder_name,
        ]

        for name in folder_names:
            if len(name) == 0:
                raise ArtifactStorageError("🛑 Directory names must be non-empty strings.")

        if len(set(folder_names)) != len(folder_names):
            raise ArtifactStorageError("🛑 Folder names must be unique (no collisions allowed).")

        invalid_chars = {"<", ">", ":", '"', "/", "\\", "|", "?", "*"}
        for name in folder_names:
            if any(char in invalid_chars for char in name):
                raise ArtifactStorageError(f"🛑 Directory name '{name}' contains invalid characters.")

        return self

    @staticmethod
    def mkdir_if_needed(path: Path | str) -> Path:
        """Create the directory if it does not exist."""
        path = Path(path)
        if not path.exists():
            logger.debug(f"📁 Creating directory: {path}")
            path.mkdir(parents=True, exist_ok=True)
        return path

    @staticmethod
    def read_parquet_files(path: Path) -> pd.DataFrame:
        return read_parquet_dataset(path)

    def create_batch_file_path(
        self,
        batch_number: int,
        batch_stage: BatchStage,
    ) -> Path:
        if batch_number < 0:
            raise ArtifactStorageError("🛑 Batch number must be non-negative.")
        return self._get_stage_path(batch_stage) / BATCH_FILE_NAME_FORMAT.format(batch_number=batch_number)

    def load_dataset(self, batch_stage: BatchStage = BatchStage.FINAL_RESULT) -> pd.DataFrame:
        return read_parquet_dataset(self._get_stage_path(batch_stage))

    def load_dataset_with_dropped_columns(self) -> pd.DataFrame:
        # The pyarrow backend has better support for nested data types.
        df = self.load_dataset()
        if (
            self.dropped_columns_dataset_path.exists()
            and self.create_batch_file_path(0, BatchStage.DROPPED_COLUMNS).is_file()
        ):
            logger.debug("Concatenating dropped columns to the final dataset.")
            df_dropped = self.load_dataset(batch_stage=BatchStage.DROPPED_COLUMNS)
            if len(df_dropped) != len(df):
                raise ArtifactStorageError(
                    "🛑 The dropped-columns dataset has a different number of rows than the main dataset. "
                    "Something unexpected must have happened to the dataset builder's artifacts."
                )
            # To ensure indexes are aligned and avoid silent misalignment (which would introduce NaNs),
            # check that the indexes are identical before concatenation.
            if not df.index.equals(df_dropped.index):
                raise ArtifactStorageError(
                    "🛑 The indexes of the main and dropped columns DataFrames are not aligned. "
                    "Something unexpected must have happened to the dataset builder's artifacts."
                )
            df = pd.concat([df, df_dropped], axis=1)
        return df

    def move_partial_result_to_final_file_path(self, batch_number: int) -> Path:
        partial_result_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.PARTIAL_RESULT)
        if not partial_result_path.exists():
            raise ArtifactStorageError("🛑 Partial result file not found.")
        self.mkdir_if_needed(self._get_stage_path(BatchStage.FINAL_RESULT))
        final_file_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.FINAL_RESULT)
        shutil.move(partial_result_path, final_file_path)
        return final_file_path

    def write_batch_to_parquet_file(
        self,
        batch_number: int,
        dataframe: pd.DataFrame,
        batch_stage: BatchStage,
        subfolder: str | None = None,
    ) -> Path:
        file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
        self.write_parquet_file(file_path.name, dataframe, batch_stage, subfolder=subfolder)
        return file_path

    def write_parquet_file(
        self,
        parquet_file_name: str,
        dataframe: pd.DataFrame,
        batch_stage: BatchStage,
        subfolder: str | None = None,
    ) -> Path:
        subfolder = subfolder or ""
        self.mkdir_if_needed(self._get_stage_path(batch_stage) / subfolder)
        file_path = self._get_stage_path(batch_stage) / subfolder / parquet_file_name
        dataframe.to_parquet(file_path, index=False)
        return file_path

    def get_parquet_file_paths(self) -> list[str]:
        """Get list of parquet file paths relative to base_dataset_path.

        Returns:
            List of relative paths to parquet files in the final dataset folder.
        """
        return [str(f.relative_to(self.base_dataset_path)) for f in sorted(self.final_dataset_path.glob("*.parquet"))]

    def get_processor_file_paths(self) -> dict[str, list[str]]:
        """Get processor output files organized by processor name.

        Returns:
            Dictionary mapping processor names to lists of relative file paths.
        """
        processor_files: dict[str, list[str]] = {}
        if self.processors_outputs_path.exists():
            for processor_dir in sorted(self.processors_outputs_path.iterdir()):
                if processor_dir.is_dir():
                    processor_name = processor_dir.name
                    processor_files[processor_name] = [
                        str(f.relative_to(self.base_dataset_path))
                        for f in sorted(processor_dir.rglob("*"))
                        if f.is_file()
                    ]
        return processor_files

    def get_file_paths(self) -> dict[str, list[str] | dict[str, list[str]]]:
        """Get all file paths organized by type.

        Returns:
            Dictionary with 'parquet-files' and 'processor-files' keys.
        """
        file_paths = {
            "parquet-files": self.get_parquet_file_paths(),
        }
        processor_file_paths = self.get_processor_file_paths()
        if processor_file_paths:
            file_paths["processor-files"] = processor_file_paths

        return file_paths

    def read_metadata(self) -> dict:
        """Read metadata from the metadata.json file.

        Returns:
            Dictionary containing the metadata.

        Raises:
            FileNotFoundError: If metadata file doesn't exist.
        """
        with open(self.metadata_file_path, "r") as file:
            return json.load(file)

    def write_metadata(self, metadata: dict) -> Path:
        """Write metadata to the metadata.json file.

        Args:
            metadata: Dictionary containing metadata to write.

        Returns:
            Path to the written metadata file.
        """
        self.mkdir_if_needed(self.base_dataset_path)
        with open(self.metadata_file_path, "w") as file:
            json.dump(metadata, file, indent=4, sort_keys=True)
        return self.metadata_file_path

    def update_metadata(self, updates: dict) -> Path:
        """Update existing metadata with new fields.

        Args:
            updates: Dictionary of fields to add/update in metadata.

        Returns:
            Path to the updated metadata file.
        """
        try:
            existing_metadata = self.read_metadata()
        except FileNotFoundError:
            existing_metadata = {}

        existing_metadata.update(updates)
        return self.write_metadata(existing_metadata)

    def _get_stage_path(self, stage: BatchStage) -> Path:
        return getattr(self, resolve_string_enum(stage, BatchStage).value)
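
Not part of the wheel contents above: a minimal sketch exercising the ArtifactStorage layout with a temporary directory and a toy DataFrame; the directory structure follows the field defaults shown above.

# Minimal sketch, not from the package.
import tempfile

import pandas as pd

from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage

with tempfile.TemporaryDirectory() as tmp_dir:
    storage = ArtifactStorage(artifact_path=tmp_dir)  # must be an existing directory

    df = pd.DataFrame({"x": [1, 2, 3]})
    storage.write_batch_to_parquet_file(0, df, batch_stage=BatchStage.PARTIAL_RESULT)
    storage.move_partial_result_to_final_file_path(0)
    storage.write_metadata({"rows": len(df)})

    # e.g. {'parquet-files': ['parquet-files/batch_00000.parquet']}
    print(storage.get_file_paths())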