data-designer-engine 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/_version.py +34 -0
- data_designer/engine/analysis/column_profilers/base.py +49 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
- data_designer/engine/analysis/column_profilers/registry.py +22 -0
- data_designer/engine/analysis/column_statistics.py +145 -0
- data_designer/engine/analysis/dataset_profiler.py +149 -0
- data_designer/engine/analysis/errors.py +9 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +122 -0
- data_designer/engine/column_generators/generators/embedding.py +35 -0
- data_designer/engine/column_generators/generators/expression.py +55 -0
- data_designer/engine/column_generators/generators/llm_completion.py +116 -0
- data_designer/engine/column_generators/generators/samplers.py +69 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
- data_designer/engine/column_generators/generators/validation.py +140 -0
- data_designer/engine/column_generators/registry.py +60 -0
- data_designer/engine/column_generators/utils/errors.py +15 -0
- data_designer/engine/column_generators/utils/generator_classification.py +43 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
- data_designer/engine/compiler.py +97 -0
- data_designer/engine/configurable_task.py +71 -0
- data_designer/engine/dataset_builders/artifact_storage.py +283 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
- data_designer/engine/dataset_builders/errors.py +15 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
- data_designer/engine/dataset_builders/utils/dag.py +62 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
- data_designer/engine/dataset_builders/utils/errors.py +15 -0
- data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
- data_designer/engine/errors.py +51 -0
- data_designer/engine/model_provider.py +77 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +300 -0
- data_designer/engine/models/facade.py +284 -0
- data_designer/engine/models/factory.py +42 -0
- data_designer/engine/models/litellm_overrides.py +179 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +235 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +62 -0
- data_designer/engine/models/parsers/types.py +84 -0
- data_designer/engine/models/recipes/base.py +81 -0
- data_designer/engine/models/recipes/response_recipes.py +293 -0
- data_designer/engine/models/registry.py +151 -0
- data_designer/engine/models/telemetry.py +362 -0
- data_designer/engine/models/usage.py +73 -0
- data_designer/engine/models/utils.py +101 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +65 -0
- data_designer/engine/processing/ginja/environment.py +463 -0
- data_designer/engine/processing/ginja/exceptions.py +56 -0
- data_designer/engine/processing/ginja/record.py +32 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +15 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
- data_designer/engine/processing/gsonschema/types.py +10 -0
- data_designer/engine/processing/gsonschema/validators.py +202 -0
- data_designer/engine/processing/processors/base.py +13 -0
- data_designer/engine/processing/processors/drop_columns.py +42 -0
- data_designer/engine/processing/processors/registry.py +25 -0
- data_designer/engine/processing/processors/schema_transform.py +71 -0
- data_designer/engine/processing/utils.py +169 -0
- data_designer/engine/registry/base.py +99 -0
- data_designer/engine/registry/data_designer_registry.py +39 -0
- data_designer/engine/registry/errors.py +12 -0
- data_designer/engine/resources/managed_dataset_generator.py +39 -0
- data_designer/engine/resources/managed_dataset_repository.py +197 -0
- data_designer/engine/resources/managed_storage.py +65 -0
- data_designer/engine/resources/resource_provider.py +77 -0
- data_designer/engine/resources/seed_reader.py +154 -0
- data_designer/engine/sampling_gen/column.py +91 -0
- data_designer/engine/sampling_gen/constraints.py +100 -0
- data_designer/engine/sampling_gen/data_sources/base.py +217 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
- data_designer/engine/sampling_gen/entities/errors.py +10 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
- data_designer/engine/sampling_gen/entities/person.py +144 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
- data_designer/engine/sampling_gen/errors.py +26 -0
- data_designer/engine/sampling_gen/generator.py +122 -0
- data_designer/engine/sampling_gen/jinja_utils.py +64 -0
- data_designer/engine/sampling_gen/people_gen.py +199 -0
- data_designer/engine/sampling_gen/person_constants.py +56 -0
- data_designer/engine/sampling_gen/schema.py +147 -0
- data_designer/engine/sampling_gen/schema_builder.py +61 -0
- data_designer/engine/sampling_gen/utils.py +46 -0
- data_designer/engine/secret_resolver.py +82 -0
- data_designer/engine/testing/__init__.py +12 -0
- data_designer/engine/testing/stubs.py +133 -0
- data_designer/engine/testing/utils.py +20 -0
- data_designer/engine/validation.py +367 -0
- data_designer/engine/validators/__init__.py +19 -0
- data_designer/engine/validators/base.py +38 -0
- data_designer/engine/validators/local_callable.py +39 -0
- data_designer/engine/validators/python.py +254 -0
- data_designer/engine/validators/remote.py +89 -0
- data_designer/engine/validators/sql.py +65 -0
- data_designer_engine-0.4.0.dist-info/METADATA +50 -0
- data_designer_engine-0.4.0.dist-info/RECORD +114 -0
- data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
data_designer/engine/processing/gsonschema/validators.py
@@ -0,0 +1,202 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+import re
+from copy import deepcopy
+from decimal import ROUND_HALF_UP, Decimal
+from typing import TYPE_CHECKING, Any, overload
+
+from data_designer.engine.processing.gsonschema.exceptions import JSONSchemaValidationError
+from data_designer.engine.processing.gsonschema.schema_transformers import forbid_additional_properties
+from data_designer.engine.processing.gsonschema.types import DataObjectT, JSONSchemaT, T_primitive
+from data_designer.lazy_heavy_imports import jsonschema
+
+if TYPE_CHECKING:
+    import jsonschema
+
+DEFAULT_JSONSCHEMA_VALIDATOR = jsonschema.Draft202012Validator
+
+logger = logging.getLogger(__name__)
+
+
+def prune_additional_properties(
+    _, allow_additional_properties: bool, instance: DataObjectT, schema: JSONSchemaT
+) -> None:
+    """A JSONSchemaValidator extension function to prune additional properties.
+
+    Operates on an individual schema in-place.
+
+    Args:
+        allow_additional_properties (bool): The value of the `additionalProperties`
+            field for this schema.
+        instance (DataObjectT): The data object being validated.
+        schema (JSONSchemaT): The schema for this object.
+
+    Returns:
+        Nothing (in place)
+    """
+    # Only act if the instance is a dict.
+    if not isinstance(instance, dict) or allow_additional_properties:
+        return
+
+    # Allowed keys are those defined in the schema's "properties".
+    allowed = schema.get("properties", {}).keys()
+
+    # Iterate over a copy of keys so we can modify the dict in place.
+    n_removed = 0
+    for key in list(instance.keys()):
+        if key not in allowed:
+            instance.pop(key)
+            n_removed += 1
+            logger.info(f"Unspecified property removed from data object: {key}.")
+
+    if n_removed > 0:
+        logger.info(f"{n_removed} unspecified properties removed from data object.")
+
+
+def extend_jsonschema_validator_with_pruning(validator):
+    """Modify behavior of a jsonschema.Validator to use pruning.
+
+    Validators extended using this function will prune extra,
+    unspecified fields, rather than raising a ValidationError,
+    when `additionalProperties: False` is set in the validating
+    schema.
+
+    Args:
+        validator (Type[jsonschema.Validator]): A validator class
+            to extend with pruning behavior.
+
+    Returns:
+        Type[jsonschema.Validator]: A validator class that will
+            prune extra fields.
+    """
+    return jsonschema.validators.extend(validator, {"additionalProperties": prune_additional_properties})
+
+
+def _get_decimal_info_from_anyof(schema: dict) -> tuple[bool, int | None]:
+    """Check if schema is a Decimal anyOf and extract decimal places.
+
+    Returns (is_decimal, decimal_places) where decimal_places is None if no constraint.
+    """
+    any_of = schema.get("anyOf")
+    if not isinstance(any_of, list):
+        return False, None
+
+    has_number = any(item.get("type") == "number" for item in any_of)
+    if not has_number:
+        return False, None
+
+    for item in any_of:
+        if item.get("type") == "string" and "pattern" in item:
+            match = re.search(r"\\d\{0,(\d+)\}", item["pattern"])
+            if match:
+                return True, int(match.group(1))
+            return True, None  # Decimal without precision constraint
+    return False, None
+
+
+def normalize_decimal_fields(obj: DataObjectT, schema: JSONSchemaT) -> DataObjectT:
+    """Normalize Decimal-like anyOf fields to floats with proper precision."""
+    if not isinstance(obj, dict):
+        return obj
+
+    defs = schema.get("$defs", {})
+    obj_schema = defs.get(schema.get("$ref", "")[len("#/$defs/") :], schema)
+    props = obj_schema.get("properties", {})
+
+    for key, value in obj.items():
+        field_schema = props.get(key, {})
+        if "$ref" in field_schema:
+            field_schema = defs.get(field_schema["$ref"][len("#/$defs/") :], {})
+
+        if isinstance(value, dict):
+            obj[key] = normalize_decimal_fields(value, schema)
+        elif isinstance(value, list):
+            obj[key] = [normalize_decimal_fields(v, schema) if isinstance(v, dict) else v for v in value]
+        elif isinstance(value, (int, float, str)) and not isinstance(value, bool):
+            is_decimal, decimal_places = _get_decimal_info_from_anyof(field_schema)
+            if is_decimal:
+                d = Decimal(str(value))
+                if decimal_places is not None:
+                    d = d.quantize(Decimal(f"0.{'0' * decimal_places}"), rounding=ROUND_HALF_UP)
+                obj[key] = float(d)
+
+    return obj
+
+
+## We don't expect the outer data type (e.g. dict, list, or const) to be
+## modified by the pruning action.
+@overload
+def validate(
+    obj: dict[str, Any],
+    schema: JSONSchemaT,
+    pruning: bool = False,
+    no_extra_properties: bool = False,
+) -> dict[str, Any]: ...
+
+
+@overload
+def validate(
+    obj: list[Any],
+    schema: JSONSchemaT,
+    pruning: bool = False,
+    no_extra_properties: bool = False,
+) -> list[Any]: ...
+
+
+@overload
+def validate(
+    obj: T_primitive,
+    schema: JSONSchemaT,
+    pruning: bool = False,
+    no_extra_properties: bool = False,
+) -> T_primitive: ...
+
+
+def validate(
+    obj: DataObjectT,
+    schema: JSONSchemaT,
+    pruning: bool = False,
+    no_extra_properties: bool = False,
+) -> DataObjectT:
+    """Validate a data object against a JSONSchema.
+
+    Args:
+        obj (DataObjectT): A data structure to validate against the
+            schema.
+        schema (JSONSchemaT): A valid JSONSchema to use to validate
+            the provided data object.
+        pruning (bool): If set to `True`, then the default behavior for
+            `additionalProperties: False` is set to prune non-specified
+            properties instead of raising a ValidationError.
+            Default: `False`.
+        no_extra_properties (bool): If set to `True`, then
+            `additionalProperties: False` is set on the schema
+            and all of its sub-schemas. This operation overrides any
+            existing settings of `additionalProperties` within the
+            schema. Default: `False`.
+
+    Raises:
+        JSONSchemaValidationError: This exception is raised in the
+            event that the data object doesn't match the provided
+            schema.
+    """
+    final_object = deepcopy(obj)
+    validator = DEFAULT_JSONSCHEMA_VALIDATOR
+    if pruning:
+        validator = extend_jsonschema_validator_with_pruning(validator)
+
+    if no_extra_properties:
+        schema = forbid_additional_properties(schema)
+
+    try:
+        validator(schema).validate(final_object)
+    except jsonschema.ValidationError as exc:
+        raise JSONSchemaValidationError(str(exc)) from exc
+
+    final_object = normalize_decimal_fields(final_object, schema)
+
+    return final_object
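
The pruning hook above works by replacing jsonschema's `additionalProperties` keyword validator with a function that mutates the instance instead of yielding errors. A minimal standalone sketch of the same mechanism, using only the public jsonschema API (the schema and data below are illustrative, not from the package):

import jsonschema

def prune(_, allow_additional, instance, schema):
    # Drop keys not declared under "properties" instead of raising.
    if not isinstance(instance, dict) or allow_additional:
        return
    allowed = schema.get("properties", {}).keys()
    for key in list(instance):
        if key not in allowed:
            instance.pop(key)

PruningValidator = jsonschema.validators.extend(
    jsonschema.Draft202012Validator, {"additionalProperties": prune}
)

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "additionalProperties": False,
}
obj = {"name": "Ada", "extra": 1}
PruningValidator(schema).validate(obj)  # prunes instead of raising
assert obj == {"name": "Ada"}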

data_designer/engine/processing/processors/base.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
+
+
+class Processor(ConfigurableTask[TaskConfigT], ABC):
+    @abstractmethod
+    def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ...

data_designer/engine/processing/processors/drop_columns.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from data_designer.config.processors import DropColumnsProcessorConfig
+from data_designer.engine.dataset_builders.artifact_storage import BatchStage
+from data_designer.engine.processing.processors.base import Processor
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+
+class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
+    def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
+        logger.info(f"🙈 Dropping columns: {self.config.column_names}")
+        if current_batch_number is not None:  # not in preview mode
+            self._save_dropped_columns_if_needed(data, current_batch_number)
+        for column in self.config.column_names:
+            if column in data.columns:
+                data.drop(columns=[column], inplace=True)
+            else:
+                logger.warning(f"⚠️ Cannot drop column: `{column}` not found in the dataset.")
+        return data
+
+    def _save_dropped_columns_if_needed(self, data: pd.DataFrame, current_batch_number: int) -> None:
+        logger.debug("📦 Saving dropped columns to dropped-columns directory")
+        dropped_column_parquet_file_name = self.artifact_storage.create_batch_file_path(
+            batch_number=current_batch_number,
+            batch_stage=BatchStage.DROPPED_COLUMNS,
+        ).name
+        self.artifact_storage.write_parquet_file(
+            parquet_file_name=dropped_column_parquet_file_name,
+            dataframe=data[self.config.column_names],
+            batch_stage=BatchStage.DROPPED_COLUMNS,
+        )
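
DropColumnsProcessor archives the dropped columns before removing them from the working dataframe. A minimal sketch of that archive-then-drop pattern in plain pandas (column names are illustrative; the real processor persists the archived frame to parquet via artifact_storage rather than keeping it in memory):

import pandas as pd

df = pd.DataFrame({"keep": [1, 2], "secret": ["a", "b"]})
to_drop = ["secret", "missing"]

present = [c for c in to_drop if c in df.columns]  # warn on the rest
archived = df[present].copy()  # written to parquet in the real processor
df = df.drop(columns=present)

assert df.columns.tolist() == ["keep"]
assert archived.columns.tolist() == ["secret"]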

data_designer/engine/processing/processors/registry.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.processors import (
+    DropColumnsProcessorConfig,
+    ProcessorType,
+    SchemaTransformProcessorConfig,
+)
+from data_designer.engine.processing.processors.base import Processor
+from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
+from data_designer.engine.processing.processors.schema_transform import SchemaTransformProcessor
+from data_designer.engine.registry.base import TaskRegistry
+
+
+class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ...
+
+
+def create_default_processor_registry() -> ProcessorRegistry:
+    registry = ProcessorRegistry()
+    registry.register(ProcessorType.SCHEMA_TRANSFORM, SchemaTransformProcessor, SchemaTransformProcessorConfig, False)
+    registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False)
+    return registry

data_designer/engine/processing/processors/schema_transform.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+from data_designer.config.processors import SchemaTransformProcessorConfig
+from data_designer.engine.dataset_builders.artifact_storage import BatchStage
+from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
+from data_designer.engine.processing.processors.base import Processor
+from data_designer.engine.processing.utils import deserialize_json_values
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+
+def _json_escape_record(record: dict[str, Any]) -> dict[str, Any]:
+    """Escape record values for safe insertion into a JSON template."""
+
+    def escape_for_json_string(s: str) -> str:
+        """Use json.dumps to escape, then strip the surrounding quotes."""
+        return json.dumps(s)[1:-1]
+
+    escaped = {}
+    for key, value in record.items():
+        if isinstance(value, str):
+            escaped[key] = escape_for_json_string(value)
+        elif isinstance(value, (dict, list)):
+            escaped[key] = escape_for_json_string(json.dumps(value))
+        elif value is None:
+            escaped[key] = "null"
+        else:
+            escaped[key] = str(value)
+    return escaped
+
+
+class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[SchemaTransformProcessorConfig]):
+    @property
+    def template_as_str(self) -> str:
+        return json.dumps(self.config.template)
+
+    def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
+        self.prepare_jinja2_template_renderer(self.template_as_str, data.columns.to_list())
+        formatted_records = []
+        for record in data.to_dict(orient="records"):
+            deserialized = deserialize_json_values(record)
+            escaped = _json_escape_record(deserialized)
+            rendered = self.render_template(escaped)
+            formatted_records.append(json.loads(rendered))
+        formatted_data = pd.DataFrame(formatted_records)
+        if current_batch_number is not None:
+            self.artifact_storage.write_batch_to_parquet_file(
+                batch_number=current_batch_number,
+                dataframe=formatted_data,
+                batch_stage=BatchStage.PROCESSORS_OUTPUTS,
+                subfolder=self.config.name,
+            )
+        else:
+            self.artifact_storage.write_parquet_file(
+                parquet_file_name=f"{self.config.name}.parquet",
+                dataframe=formatted_data,
+                batch_stage=BatchStage.PROCESSORS_OUTPUTS,
+            )
+
+        return data
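
The escaping helper above leans on json.dumps to produce a JSON-safe string and then strips the surrounding quotes, so a value can be spliced into a JSON-shaped template without breaking it. A small standalone illustration (this sketch uses plain string substitution, whereas the processor renders a Jinja2 template):

import json

def escape_for_json_string(s: str) -> str:
    # json.dumps adds surrounding quotes; strip them, keep the escapes.
    return json.dumps(s)[1:-1]

template = '{"message": "VALUE"}'
value = 'He said "hi"\nbye'

rendered = template.replace("VALUE", escape_for_json_string(value))
parsed = json.loads(rendered)  # parses cleanly despite quotes and newline
assert parsed["message"] == value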

data_designer/engine/processing/utils.py
@@ -0,0 +1,169 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import ast
+import copy
+import json
+import logging
+import re
+from typing import TYPE_CHECKING, Any, TypeVar, overload
+
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
+    _verify_columns_are_unique(datasets)
+    _verify_dataset_lengths_are_equal(datasets)
+    emoji = " + ".join(["💾"] * len(datasets))
+    logger.info(f"({emoji}) Concatenating {len(datasets)} datasets")
+    return pd.concat([df for df in datasets], axis=1)
+
+
+# Overloads to help the static type checker better understand
+# the input/output types of the deserialize_json_values function.
+@overload
+def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...
+
+
+@overload
+def deserialize_json_values(data: list[T]) -> list[Any]: ...
+
+
+@overload
+def deserialize_json_values(data: dict[K, V]) -> dict[K, Any]: ...
+
+
+@overload
+def deserialize_json_values(data: T) -> T: ...
+
+
+def deserialize_json_values(data):
+    """De-serialize JSON strings in various input formats.
+
+    This function creates a deep copy of the input data and does not mutate the original.
+
+    Args:
+        data: Input data in one of four formats:
+            - Single string (JSON string to deserialize)
+            - List of strings (list of JSON strings to deserialize)
+            - Dictionary (potentially with nested JSON strings to deserialize)
+            - Some other object that can't be deserialized.
+
+    Returns:
+        Deserialized data in the corresponding format:
+            - Dictionary (when input is a single string)
+            - List of dictionaries (when input is a list of strings)
+            - Dictionary (when input is a dictionary, with nested JSON strings deserialized)
+            - The original object (if there is no deserialization to perform)
+
+    """
+    # Create a deep copy to avoid mutating the original data
+    data_copy = copy.deepcopy(data)
+
+    # Case 1: Single string input
+    if isinstance(data_copy, str):
+        try:
+            return json.loads(data_copy)
+        except json.JSONDecodeError:
+            return data_copy
+
+    # Case 2: List of strings input
+    elif isinstance(data_copy, list):
+        result = []
+        for item in data_copy:
+            if isinstance(item, str):
+                try:
+                    result.append(json.loads(item))
+                except json.JSONDecodeError:
+                    result.append(item)
+            else:
+                # If the list contains non-string items, recursively process them
+                result.append(deserialize_json_values(item))
+        return result
+
+    # Case 3: Dictionary input with potential nested JSON strings
+    elif isinstance(data_copy, dict):
+        result = {}
+        for key, value in data_copy.items():
+            if isinstance(value, str):
+                try:
+                    result[key] = json.loads(value)
+                except json.JSONDecodeError:
+                    result[key] = value
+            elif isinstance(value, (dict, list)):
+                # Recursively process nested dictionaries and lists
+                result[key] = deserialize_json_values(value)
+            else:
+                result[key] = value
+        return result
+
+    # Fallback for other data types
+    else:
+        return data_copy
+
+
+def parse_list_string(text: str) -> list[str]:
+    """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
+    text = text.strip()
+
+    # Try JSON first
+    try:
+        list_obj = json.loads(text)
+        if isinstance(list_obj, list):
+            return _clean_whitespace(list_obj)
+    except json.JSONDecodeError:
+        pass
+
+    # Remove trailing commas before closing brackets (common in JSON-like strings)
+    text_cleaned = re.sub(r",\s*]", "]", text)
+    text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
+
+    # Try JSON again with the cleaned text
+    try:
+        return _clean_whitespace(json.loads(text_cleaned))
+    except json.JSONDecodeError:
+        pass
+
+    # Try Python literal eval (handles single quotes)
+    try:
+        return _clean_whitespace(ast.literal_eval(text_cleaned))
+    except (ValueError, SyntaxError):
+        pass
+
+    # If all else fails, return the original text as a single-item list
+    return [text.strip()]
+
+
+def _clean_whitespace(texts: list[str]) -> list[str]:
+    return [text.strip() for text in texts]
+
+
+def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
+    joined_columns = set()
+    for df in datasets:
+        columns = set(df.columns)
+        overlapping_columns = joined_columns & columns
+        if len(overlapping_columns) > 0:
+            raise ValueError(
+                f"🛑 Input datasets have overlapping columns: {overlapping_columns}. "
+                "Please ensure that the column names are unique."
+            )
+        joined_columns.update(columns)
+
+
+def _verify_dataset_lengths_are_equal(datasets: list[pd.DataFrame]) -> None:
+    if len(set([len(df) for df in datasets])) > 1:
+        raise ValueError(
+            "🛑 Input datasets have different lengths. Please ensure that the datasets have the same number of rows."
+        )
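
parse_list_string above tries progressively looser parsers: strict JSON, JSON after stripping trailing commas, then ast.literal_eval for Python-style single-quoted lists, before giving up and wrapping the raw text in a list. A standalone illustration of that fallback chain (the inputs are made-up examples):

import ast
import json
import re

def loose_parse(text: str) -> list:
    text = text.strip()
    try:
        return json.loads(text)  # handles '["a", "b"]'
    except json.JSONDecodeError:
        pass
    cleaned = re.sub(r",\s*]", "]", text)  # handles '["a", "b",]'
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    return ast.literal_eval(cleaned)  # handles "['a', 'b']"

assert loose_parse('["a", "b"]') == ["a", "b"]
assert loose_parse('["a", "b",]') == ["a", "b"]
assert loose_parse("['a', 'b']") == ["a", "b"]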

data_designer/engine/registry/base.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import threading
+from typing import Any, Generic, TypeVar
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.utils.type_helpers import StrEnum
+from data_designer.engine.configurable_task import ConfigurableTask
+from data_designer.engine.registry.errors import NotFoundInRegistryError, RegistryItemNotTypeError
+
+EnumNameT = TypeVar("EnumNameT", bound=StrEnum)
+TaskT = TypeVar("TaskT", bound=ConfigurableTask)
+TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
+
+
+class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
+    # registered type name -> type
+    _registry: dict[EnumNameT, type[TaskT]] = {}
+    # type -> registered type name
+    _reverse_registry: dict[type[TaskT], EnumNameT] = {}
+
+    # registered type name -> config type
+    _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
+    # config type -> registered type name
+    _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}
+
+    # all registries are singletons
+    _instance = None
+    _lock = threading.Lock()
+
+    @classmethod
+    def register(
+        cls,
+        name: EnumNameT,
+        task: type[TaskT],
+        config: type[TaskConfigT],
+        raise_on_collision: bool = False,
+    ) -> None:
+        if cls._has_been_registered(name):
+            if not raise_on_collision:
+                return
+            raise ValueError(f"{name} has already been registered!")
+
+        cls._raise_if_not_type(task)
+        cls._raise_if_not_type(config)
+
+        with cls._lock:
+            cls._registry[name] = task
+            cls._reverse_registry[task] = name
+            cls._config_registry[name] = config
+            cls._reverse_config_registry[config] = name
+
+    @classmethod
+    def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
+        cls._raise_if_not_registered(name, cls._registry)
+        return cls._registry[name]
+
+    @classmethod
+    def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
+        cls._raise_if_not_registered(name, cls._config_registry)
+        return cls._config_registry[name]
+
+    @classmethod
+    def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
+        cls._raise_if_not_registered(task, cls._reverse_registry)
+        return cls._reverse_registry[task]
+
+    @classmethod
+    def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
+        cls._raise_if_not_registered(config, cls._reverse_config_registry)
+        name = cls._reverse_config_registry[config]
+        return cls.get_task_type(name)
+
+    @classmethod
+    def _has_been_registered(cls, name: EnumNameT) -> bool:
+        return name in cls._registry
+
+    @classmethod
+    def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
+        if not (isinstance(key, StrEnum) or isinstance(key, str)):
+            cls._raise_if_not_type(key)
+        if key not in mapping:
+            raise NotFoundInRegistryError(f"{key} not found in registry")
+
+    @classmethod
+    def _raise_if_not_type(cls, obj: Any) -> None:
+        if not isinstance(obj, type):
+            raise RegistryItemNotTypeError(f"{obj} is not a class!")
+
+    def __new__(cls, *args, **kwargs):
+        """Registry is a singleton."""
+        if not cls._instance:
+            with cls._lock:
+                if not cls._instance:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
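
TaskRegistry pairs class-level forward/reverse lookup maps with a double-checked-locking singleton __new__. A condensed standalone sketch of that pattern (names are illustrative; the real class is additionally generic over task and config types):

import threading

class SingletonRegistry:
    _items: dict[str, type] = {}
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: test, acquire the lock, test again, create.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def register(cls, name: str, item: type) -> None:
        cls._items[name] = item

    @classmethod
    def get(cls, name: str) -> type:
        return cls._items[name]

assert SingletonRegistry() is SingletonRegistry()
SingletonRegistry.register("int", int)
assert SingletonRegistry.get("int") is int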

data_designer/engine/registry/data_designer_registry.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.engine.analysis.column_profilers.registry import (
+    ColumnProfilerRegistry,
+    create_default_column_profiler_registry,
+)
+from data_designer.engine.column_generators.registry import (
+    ColumnGeneratorRegistry,
+    create_default_column_generator_registry,
+)
+from data_designer.engine.processing.processors.registry import ProcessorRegistry, create_default_processor_registry
+
+
+class DataDesignerRegistry:
+    def __init__(
+        self,
+        *,
+        column_generator_registry: ColumnGeneratorRegistry | None = None,
+        column_profiler_registry: ColumnProfilerRegistry | None = None,
+        processor_registry: ProcessorRegistry | None = None,
+    ):
+        self._column_generator_registry = column_generator_registry or create_default_column_generator_registry()
+        self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry()
+        self._processor_registry = processor_registry or create_default_processor_registry()
+
+    @property
+    def column_generators(self) -> ColumnGeneratorRegistry:
+        return self._column_generator_registry
+
+    @property
+    def column_profilers(self) -> ColumnProfilerRegistry:
+        return self._column_profiler_registry
+
+    @property
+    def processors(self) -> ProcessorRegistry:
+        return self._processor_registry

data_designer/engine/registry/errors.py
@@ -0,0 +1,12 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from data_designer.engine.errors import DataDesignerError
+
+
+class NotFoundInRegistryError(DataDesignerError): ...
+
+
+class RegistryItemNotTypeError(DataDesignerError): ...