data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +14 -1
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +5 -4
- data_designer/config/processors.py +109 -4
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +39 -9
- data_designer/config/utils/visualization.py +62 -15
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +14 -5
- data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/processors/drop_columns.py +1 -1
- data_designer/engine/processing/processors/registry.py +3 -0
- data_designer/engine/processing/processors/schema_transform.py +53 -0
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/interface/data_designer.py +12 -0
- data_designer/interface/results.py +36 -0
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/engine/models/parsers/types.py
CHANGED

@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any,
+from typing import Any, Protocol, runtime_checkable

 from lxml.etree import _Element
 from pydantic import BaseModel, Field
@@ -30,7 +30,7 @@ class LLMStructuredResponse(BaseModel):
         out.parsed = out.parsed[-n:]
         return out

-    def filter(self, block_types: list[
+    def filter(self, block_types: list[type[BaseModel]]) -> Self:
         out = self.model_copy()
         out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
         return out
@@ -44,7 +44,7 @@ class TagParser(Protocol):
     element, do some computation, and return some kind of structured
     output, represented as a subclass of Pydantic `BaseModel`.
     This protocol implementation can cover both classes as well
-    as curried
+    as curried functions as parsers (e.g. `partial`).
     """

     def __call__(self, element: _Element) -> BaseModel: ...
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):

 class CodeBlock(BaseModel):
     code: str
-    code_lang:
+    code_lang: str | None = None


 class StructuredDataBlock(BaseModel):
data_designer/engine/models/registry.py
CHANGED

@@ -5,7 +5,7 @@ from __future__ import annotations

 import logging

-from data_designer.config.models import ModelConfig
+from data_designer.config.models import GenerationType, ModelConfig
 from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
 from data_designer.engine.models.facade import ModelFacade
 from data_designer.engine.models.litellm_overrides import apply_litellm_patches
@@ -73,7 +73,7 @@ class ModelRegistry:
         model_config = self.get_model_config(model_alias=model_alias)
         return self._model_provider_registry.get_provider(model_config.provider)

-    def run_health_check(self, model_aliases:
+    def run_health_check(self, model_aliases: list[str]) -> None:
        logger.info("🩺 Running health checks for models...")
        for model_alias in model_aliases:
            model = self.get_model(model_alias=model_alias)
@@ -81,15 +81,24 @@ class ModelRegistry:
                 f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
             )
             try:
-                model.
-
-
-
-
-
-
-
-
+                if model.model_generation_type == GenerationType.EMBEDDING:
+                    model.generate_text_embeddings(
+                        input_texts=["Hello!"],
+                        skip_usage_tracking=True,
+                        purpose="running health checks",
+                    )
+                elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
+                    model.generate(
+                        prompt="Hello!",
+                        parser=lambda x: x,
+                        system_prompt="You are a helpful assistant.",
+                        max_correction_steps=0,
+                        max_conversation_restarts=0,
+                        skip_usage_tracking=True,
+                        purpose="running health checks",
+                    )
+                else:
+                    raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
                 logger.info(" |-- ✅ Passed!")
             except Exception as e:
                 logger.error(" |-- ❌ Failed!")
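
The rewritten health check dispatches on `model.model_generation_type` instead of issuing a single hard-coded probe. A minimal usage sketch; the `registry` object and aliases below are hypothetical, only `run_health_check` itself comes from the diff:

```python
# Hypothetical usage sketch: `registry` is assumed to be a configured
# ModelRegistry; the aliases are illustrative. Embedding aliases are probed
# via generate_text_embeddings, chat aliases via a minimal generate() call.
registry.run_health_check(model_aliases=["my-chat-model", "my-embedding-model"])
```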
data_designer/engine/models/usage.py
CHANGED

@@ -11,20 +11,20 @@ logger = logging.getLogger(__name__)


 class TokenUsageStats(BaseModel):
-
-
+    input_tokens: int = 0
+    output_tokens: int = 0

     @computed_field
     def total_tokens(self) -> int:
-        return self.
+        return self.input_tokens + self.output_tokens

     @property
     def has_usage(self) -> bool:
         return self.total_tokens > 0

-    def extend(self, *,
-        self.
-        self.
+    def extend(self, *, input_tokens: int, output_tokens: int) -> None:
+        self.input_tokens += input_tokens
+        self.output_tokens += output_tokens


 class RequestUsageStats(BaseModel):
@@ -56,9 +56,7 @@ class ModelUsageStats(BaseModel):
         self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
     ) -> None:
         if token_usage is not None:
-            self.token_usage.extend(
-                prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
-            )
+            self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
         if request_usage is not None:
             self.request_usage.extend(
                 successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
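
As the removed `ModelUsageStats.extend` lines show, the token counters are renamed from `prompt_tokens`/`completion_tokens` to `input_tokens`/`output_tokens`. A small sketch of the new field names, grounded in the hunks above:

```python
# Sketch using the renamed fields from the diff above.
from data_designer.engine.models.usage import TokenUsageStats

stats = TokenUsageStats()  # input_tokens=0, output_tokens=0
stats.extend(input_tokens=120, output_tokens=45)
assert stats.total_tokens == 165  # computed_field: input + output
assert stats.has_usage
```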
data_designer/engine/processing/ginja/ast.py
CHANGED

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections import deque
-from typing import Optional, Type

 from jinja2 import nodes as j_nodes

@@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
     return max_depth


-def ast_descendant_count(ast: j_nodes.Node, only_type:
+def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
     """Count the number of nodes which descend from the given node.

     Args:
data_designer/engine/processing/processors/drop_columns.py
CHANGED

@@ -17,7 +17,7 @@ class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
     @staticmethod
     def metadata() -> ConfigurableTaskMetadata:
         return ConfigurableTaskMetadata(
-            name="
+            name="drop_columns_processor",
             description="Drop columns from the input dataset.",
             required_resources=None,
         )
data_designer/engine/processing/processors/registry.py
CHANGED

@@ -5,9 +5,11 @@ from data_designer.config.base import ConfigBase
 from data_designer.config.processors import (
     DropColumnsProcessorConfig,
     ProcessorType,
+    SchemaTransformProcessorConfig,
 )
 from data_designer.engine.processing.processors.base import Processor
 from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
+from data_designer.engine.processing.processors.schema_transform import SchemaTransformProcessor
 from data_designer.engine.registry.base import TaskRegistry


@@ -16,5 +18,6 @@ class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ...

 def create_default_processor_registry() -> ProcessorRegistry:
     registry = ProcessorRegistry()
+    registry.register(ProcessorType.SCHEMA_TRANSFORM, SchemaTransformProcessor, SchemaTransformProcessorConfig, False)
     registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False)
     return registry
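
With the registration in place, the default registry can resolve the new processor type. A sketch using the registry methods shown in this diff:

```python
# Sketch: resolve the newly registered processor from the default registry.
from data_designer.config.processors import ProcessorType, SchemaTransformProcessorConfig
from data_designer.engine.processing.processors.registry import create_default_processor_registry

registry = create_default_processor_registry()
assert registry.get_config_type(ProcessorType.SCHEMA_TRANSFORM) is SchemaTransformProcessorConfig
```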
data_designer/engine/processing/processors/schema_transform.py
ADDED

@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+
+import pandas as pd
+
+from data_designer.config.processors import SchemaTransformProcessorConfig
+from data_designer.engine.configurable_task import ConfigurableTaskMetadata
+from data_designer.engine.dataset_builders.artifact_storage import BatchStage
+from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
+from data_designer.engine.processing.processors.base import Processor
+from data_designer.engine.processing.utils import deserialize_json_values
+
+logger = logging.getLogger(__name__)
+
+
+class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[SchemaTransformProcessorConfig]):
+    @staticmethod
+    def metadata() -> ConfigurableTaskMetadata:
+        return ConfigurableTaskMetadata(
+            name="schema_transform_processor",
+            description="Generate dataset with transformed schema using a Jinja2 template.",
+            required_resources=None,
+        )
+
+    @property
+    def template_as_str(self) -> str:
+        return json.dumps(self.config.template)
+
+    def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
+        self.prepare_jinja2_template_renderer(self.template_as_str, data.columns.to_list())
+        formatted_records = [
+            json.loads(self.render_template(deserialize_json_values(record)).replace("\n", "\\n"))
+            for record in data.to_dict(orient="records")
+        ]
+        formatted_data = pd.DataFrame(formatted_records)
+        if current_batch_number is not None:
+            self.artifact_storage.write_batch_to_parquet_file(
+                batch_number=current_batch_number,
+                dataframe=formatted_data,
+                batch_stage=BatchStage.PROCESSORS_OUTPUTS,
+                subfolder=self.config.name,
+            )
+        else:
+            self.artifact_storage.write_parquet_file(
+                parquet_file_name=f"{self.config.name}.parquet",
+                dataframe=formatted_data,
+                batch_stage=BatchStage.PROCESSORS_OUTPUTS,
+            )
+
+        return data
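
The processor renders each record through a JSON-valued Jinja2 template and rebuilds a DataFrame from the parsed results. A simplified standalone sketch of that core loop, using plain `jinja2` rather than the package's renderer; the template and data are illustrative:

```python
# Simplified sketch of the render-then-parse loop in SchemaTransformProcessor.process.
import json

import jinja2
import pandas as pd

template = jinja2.Template('{"full_name": "{{ first }} {{ last }}", "user_id": {{ id }}}')
data = pd.DataFrame([{"first": "Ada", "last": "Lovelace", "id": 1}])

# Render one JSON document per input record, then parse it back into a dict.
records = [json.loads(template.render(**rec)) for rec in data.to_dict(orient="records")]
print(pd.DataFrame(records))  # columns follow the transformed schema
```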
data_designer/engine/processing/utils.py
CHANGED

@@ -1,9 +1,11 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0

+import ast
 import json
 import logging
-
+import re
+from typing import Any, TypeVar, overload

 import pandas as pd

@@ -25,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
 # Overloads to help static type checker better understand
 # the input/output types of the deserialize_json_values function.
 @overload
-def deserialize_json_values(data: str) ->
+def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...


 @overload
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
     return data


+def parse_list_string(text: str) -> list[str]:
+    """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
+    text = text.strip()
+
+    # Try JSON first
+    try:
+        list_obj = json.loads(text)
+        if isinstance(list_obj, list):
+            return _clean_whitespace(list_obj)
+    except json.JSONDecodeError:
+        pass
+
+    # Remove trailing commas before closing brackets (common in JSON-like strings)
+    text_cleaned = re.sub(r",\s*]", "]", text)
+    text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
+
+    # Try JSON again with cleaned text
+    try:
+        return _clean_whitespace(json.loads(text_cleaned))
+    except json.JSONDecodeError:
+        pass
+
+    # Try Python literal eval (handles single quotes)
+    try:
+        return _clean_whitespace(ast.literal_eval(text_cleaned))
+    except (ValueError, SyntaxError):
+        pass
+
+    # If all else fails, return the original text
+    return [text.strip()]
+
+
+def _clean_whitespace(texts: list[str]) -> list[str]:
+    return [text.strip() for text in texts]
+
+
 def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
     joined_columns = set()
     for df in datasets:
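
`parse_list_string` tries strict JSON, then JSON with trailing commas stripped, then a Python-literal fallback, and finally wraps unparseable input as a single-item list:

```python
# Expected behavior of parse_list_string, per the implementation above.
from data_designer.engine.processing.utils import parse_list_string

parse_list_string('["a", "b", ]')  # -> ["a", "b"]      (trailing comma stripped)
parse_list_string("['a', 'b']")    # -> ["a", "b"]      (ast.literal_eval fallback)
parse_list_string("not a list")    # -> ["not a list"]  (wrapped as one item)
```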
data_designer/engine/registry/base.py
CHANGED

@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import threading
-from typing import Any, Generic,
+from typing import Any, Generic, TypeVar

 from data_designer.config.base import ConfigBase
 from data_designer.config.utils.type_helpers import StrEnum
@@ -16,14 +16,14 @@ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)

 class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
     # registered type name -> type
-    _registry: dict[EnumNameT,
+    _registry: dict[EnumNameT, type[TaskT]] = {}
     # type -> registered type name
-    _reverse_registry: dict[
+    _reverse_registry: dict[type[TaskT], EnumNameT] = {}

     # registered type name -> config type
-    _config_registry: dict[EnumNameT,
+    _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
     # config type -> registered type name
-    _reverse_config_registry: dict[
+    _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}

     # all registries are singletons
     _instance = None
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
     def register(
         cls,
         name: EnumNameT,
-        task:
-        config:
+        task: type[TaskT],
+        config: type[TaskConfigT],
         raise_on_collision: bool = False,
     ) -> None:
         if cls._has_been_registered(name):
@@ -52,22 +52,22 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
         cls._reverse_config_registry[config] = name

     @classmethod
-    def get_task_type(cls, name: EnumNameT) ->
+    def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
         cls._raise_if_not_registered(name, cls._registry)
         return cls._registry[name]

     @classmethod
-    def get_config_type(cls, name: EnumNameT) ->
+    def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
         cls._raise_if_not_registered(name, cls._config_registry)
         return cls._config_registry[name]

     @classmethod
-    def get_registered_name(cls, task:
+    def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
         cls._raise_if_not_registered(task, cls._reverse_registry)
         return cls._reverse_registry[task]

     @classmethod
-    def get_for_config_type(cls, config:
+    def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
         cls._raise_if_not_registered(config, cls._reverse_config_registry)
         name = cls._reverse_config_registry[config]
         return cls.get_task_type(name)
@@ -77,7 +77,7 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
         return name in cls._registry

     @classmethod
-    def _raise_if_not_registered(cls, key: EnumNameT |
+    def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
         if not (isinstance(key, StrEnum) or isinstance(key, str)):
             cls._raise_if_not_type(key)
         if key not in mapping:
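
The registry's annotations are completed with `type[...]` generics throughout. A hedged sketch of defining and using a registry subclass; `MyTask` and `MyTaskConfig` are illustrative stand-ins, not package types:

```python
# Illustrative sketch: MyTask/MyTaskConfig are hypothetical stand-ins.
from data_designer.config.base import ConfigBase
from data_designer.engine.registry.base import TaskRegistry


class MyTaskConfig(ConfigBase): ...


class MyTask: ...


class MyRegistry(TaskRegistry[str, MyTask, MyTaskConfig]): ...


MyRegistry.register("my_task", MyTask, MyTaskConfig)
assert MyRegistry.get_task_type("my_task") is MyTask
assert MyRegistry.get_for_config_type(MyTaskConfig) is MyTask
```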
data_designer/engine/sampling_gen/constraints.py
CHANGED

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 from abc import ABC, abstractmethod
-from typing import Type

 import numpy as np
 import pandas as pd
@@ -91,5 +90,5 @@ CONSTRAINT_TYPE_TO_CHECKER = {
 }


-def get_constraint_checker(constraint_type: ConstraintType) ->
+def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
     return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
data_designer/engine/sampling_gen/data_sources/base.py
CHANGED

@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from abc import ABC, abstractmethod
-from typing import Any, Generic,
+from typing import Any, Generic, TypeVar

 import numpy as np
 import pandas as pd
@@ -45,7 +45,7 @@ class PassthroughMixin:
         return series

     @staticmethod
-    def validate_data_conversion(convert_to:
+    def validate_data_conversion(convert_to: str | None) -> None:
         pass


@@ -71,7 +71,7 @@ class TypeConversionMixin:
         return series

     @staticmethod
-    def postproc(series: pd.Series, convert_to:
+    def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
         if convert_to is not None:
             if convert_to == "int":
                 series = series.round()
@@ -79,18 +79,18 @@ class TypeConversionMixin:
         return series

     @staticmethod
-    def validate_data_conversion(convert_to:
+    def validate_data_conversion(convert_to: str | None) -> None:
         if convert_to is not None and convert_to not in ["float", "int", "str"]:
             raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]")


 class DatetimeFormatMixin:
     @staticmethod
-    def preproc(series: pd.Series, convert_to:
+    def preproc(series: pd.Series, convert_to: str | None) -> pd.Series:
         return series

     @staticmethod
-    def postproc(series: pd.Series, convert_to:
+    def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
         if convert_to is not None:
             return series.dt.strftime(convert_to)
         if series.dt.month.nunique() == 1:
@@ -104,7 +104,7 @@ class DatetimeFormatMixin:
         return series.apply(lambda dt: dt.isoformat()).astype(str)

     @staticmethod
-    def validate_data_conversion(convert_to:
+    def validate_data_conversion(convert_to: str | None) -> None:
         if convert_to is not None:
             try:
                 pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to))
@@ -121,7 +121,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
     def __init__(
         self,
         params: GenericParamsT,
-        random_state:
+        random_state: RadomStateT | None = None,
         **kwargs,
     ):
         self.rng = check_random_state(random_state)
@@ -130,7 +130,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
         self._validate()

     @classmethod
-    def get_param_type(cls) ->
+    def get_param_type(cls) -> type[GenericParamsT]:
         return cls.__orig_bases__[-1].__args__[0]

     @abstractmethod
@@ -138,7 +138,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
         self,
         dataframe: pd.DataFrame,
         column_name: str,
-        index:
+        index: list[int] | None = None,
     ) -> pd.DataFrame: ...

     @staticmethod
@@ -147,11 +147,11 @@ class DataSource(ABC, Generic[GenericParamsT]):

     @staticmethod
     @abstractmethod
-    def postproc(series: pd.Series, convert_to:
+    def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: ...

     @staticmethod
     @abstractmethod
-    def validate_data_conversion(convert_to:
+    def validate_data_conversion(convert_to: str | None) -> None: ...

     def get_required_column_names(self) -> tuple[str, ...]:
         return tuple()
@@ -182,7 +182,7 @@ class Sampler(DataSource[GenericParamsT], ABC):
         self,
         dataframe: pd.DataFrame,
         column_name: str,
-        index:
+        index: list[int] | None = None,
     ) -> pd.DataFrame:
         index = slice(None) if index is None else index

@@ -208,7 +208,7 @@ class Sampler(DataSource[GenericParamsT], ABC):
 class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
     @property
     @abstractmethod
-    def distribution(self) ->
+    def distribution(self) -> stats.rv_continuous | stats.rv_discrete: ...

     def sample(self, num_samples: int) -> NumpyArray1dT:
         return self.distribution.rvs(size=num_samples, random_state=self.rng)
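
One behavioral detail worth noting from the `TypeConversionMixin.postproc` hunk above: an `"int"` conversion rounds before casting. A pandas-only reproduction of that detail:

```python
# Pandas-only sketch of the round-before-int behavior in TypeConversionMixin.postproc.
import pandas as pd

s = pd.Series([1.4, 2.6])
print(s.round().astype("int").tolist())  # [1, 3] -- not the truncation [1, 2]
```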
data_designer/engine/sampling_gen/entities/phone_number.py
CHANGED

@@ -3,7 +3,6 @@

 import random
 from pathlib import Path
-from typing import Optional

 import pandas as pd
 from pydantic import BaseModel, Field, field_validator
@@ -13,7 +12,7 @@ ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DA
 ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"]))


-def get_area_code(zip_prefix:
+def get_area_code(zip_prefix: str | None = None) -> str:
     """
     Sample an area code for the given ZIP code prefix, population-weighted.

data_designer/engine/sampling_gen/people_gen.py
CHANGED

@@ -8,12 +8,12 @@ import uuid
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, TypeAlias

 import pandas as pd
 from faker import Faker

-from data_designer.config.utils.constants import
+from data_designer.config.utils.constants import DEFAULT_AGE_RANGE
 from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
 from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
 from data_designer.engine.sampling_gen.entities.person import (
@@ -27,17 +27,13 @@ if TYPE_CHECKING:
     from data_designer.engine.sampling_gen.schema import DataSchema


-EngineT =
+EngineT: TypeAlias = Faker | ManagedDatasetGenerator


 class PeopleGen(ABC):
     """Unified interface for generating people data."""

     def __init__(self, engine: EngineT, locale: str):
-        if locale not in AVAILABLE_LOCALES:
-            raise ValueError(
-                f"Locale {locale} is not a supported locale.Supported locales: {', '.join(AVAILABLE_LOCALES)}"
-            )
         self.locale = locale
         self._engine = engine
data_designer/engine/validators/base.py
CHANGED

@@ -2,14 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0

 from abc import ABC, abstractmethod
-from typing import Iterator
+from typing import Iterator

 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Self


 class ValidationOutput(BaseModel):
-    is_valid:
+    is_valid: bool | None
     model_config = ConfigDict(extra="allow")

data_designer/interface/data_designer.py
CHANGED

@@ -249,6 +249,17 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
         except Exception as e:
             raise DataDesignerProfilingError(f"🛑 Error profiling preview dataset: {e}")

+        if builder.artifact_storage.processors_outputs_path.exists():
+            processor_artifacts = {
+                processor_config.name: pd.read_parquet(
+                    builder.artifact_storage.processors_outputs_path / f"{processor_config.name}.parquet",
+                    dtype_backend="pyarrow",
+                ).to_dict(orient="records")
+                for processor_config in config_builder.get_processor_configs()
+            }
+        else:
+            processor_artifacts = {}
+
         if (
             len(processed_dataset) > 0
             and isinstance(analysis, DatasetProfilerResults)
@@ -259,6 +270,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
         return PreviewResults(
             dataset=processed_dataset,
             analysis=analysis,
+            processor_artifacts=processor_artifacts,
             config_builder=config_builder,
         )
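
`PreviewResults` now carries a `processor_artifacts` mapping from processor name to the records read back from that processor's Parquet output. Given a `PreviewResults` instance `preview` obtained from DataDesigner's preview flow, a sketch:

```python
# Sketch: iterate the new processor_artifacts payload on a PreviewResults
# instance `preview`; each value is a list of record dicts.
for processor_name, records in preview.processor_artifacts.items():
    print(processor_name, len(records), records[:1])
```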
data_designer/interface/results.py
CHANGED

@@ -3,12 +3,15 @@

 from __future__ import annotations

+from pathlib import Path
+
 import pandas as pd

 from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
 from data_designer.config.config_builder import DataDesignerConfigBuilder
 from data_designer.config.utils.visualization import WithRecordSamplerMixin
 from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+from data_designer.engine.dataset_builders.errors import ArtifactStorageError


 class DatasetCreationResults(WithRecordSamplerMixin):
@@ -53,3 +56,36 @@ class DatasetCreationResults(WithRecordSamplerMixin):
         A pandas DataFrame containing the full generated dataset.
         """
         return self.artifact_storage.load_dataset()
+
+    def load_processor_dataset(self, processor_name: str) -> pd.DataFrame:
+        """Load the dataset generated by a processor.
+
+        This only works for processors that write their artifacts in Parquet format.
+
+        Args:
+            processor_name: The name of the processor to load the dataset from.
+
+        Returns:
+            A pandas DataFrame containing the dataset generated by the processor.
+        """
+        try:
+            dataset = self.artifact_storage.read_parquet_files(
+                self.artifact_storage.processors_outputs_path / processor_name
+            )
+        except Exception as e:
+            raise ArtifactStorageError(f"Failed to load dataset for processor {processor_name}: {e}")
+
+        return dataset
+
+    def get_path_to_processor_artifacts(self, processor_name: str) -> Path:
+        """Get the path to the artifacts generated by a processor.
+
+        Args:
+            processor_name: The name of the processor to load the artifact from.
+
+        Returns:
+            The path to the artifacts.
+        """
+        if not self.artifact_storage.processors_outputs_path.exists():
+            raise ArtifactStorageError(f"Processor {processor_name} has no artifacts.")
+        return self.artifact_storage.processors_outputs_path / processor_name
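
The two new `DatasetCreationResults` accessors pair naturally: one loads a processor's Parquet output as a DataFrame, the other exposes the artifact directory. A hypothetical usage; the processor name stands in for whatever `config.name` was set to:

```python
# Hypothetical usage on a DatasetCreationResults instance `results`;
# "my_schema_transform" stands in for the configured processor name.
df = results.load_processor_dataset("my_schema_transform")
artifact_dir = results.get_path_to_processor_artifacts("my_schema_transform")
print(artifact_dir, df.shape)
```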
data_designer/logging.py
CHANGED

@@ -6,7 +6,7 @@ import random
 import sys
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TextIO
+from typing import TextIO

 from pythonjsonlogger import jsonlogger

@@ -19,7 +19,7 @@ class LoggerConfig:

 @dataclass
 class OutputConfig:
-    destination:
+    destination: TextIO | Path
     structured: bool
