data-designer 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +14 -1
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +5 -4
  31. data_designer/config/processors.py +109 -4
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +39 -9
  42. data_designer/config/utils/visualization.py +62 -15
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +14 -5
  57. data_designer/engine/dataset_builders/column_wise_builder.py +12 -8
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/processors/drop_columns.py +1 -1
  69. data_designer/engine/processing/processors/registry.py +3 -0
  70. data_designer/engine/processing/processors/schema_transform.py +53 -0
  71. data_designer/engine/processing/utils.py +40 -2
  72. data_designer/engine/registry/base.py +12 -12
  73. data_designer/engine/sampling_gen/constraints.py +1 -2
  74. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  75. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  76. data_designer/engine/sampling_gen/people_gen.py +3 -7
  77. data_designer/engine/validators/base.py +2 -2
  78. data_designer/interface/data_designer.py +12 -0
  79. data_designer/interface/results.py +36 -0
  80. data_designer/logging.py +2 -2
  81. data_designer/plugin_manager.py +3 -3
  82. data_designer/plugins/plugin.py +3 -3
  83. data_designer/plugins/registry.py +2 -2
  84. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/METADATA +9 -9
  85. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/RECORD +88 -81
  86. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  87. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  88. {data_designer-0.1.4.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
data_designer/engine/models/parsers/types.py CHANGED
@@ -1,7 +1,7 @@
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  # SPDX-License-Identifier: Apache-2.0

- from typing import Any, Optional, Protocol, Type, runtime_checkable
+ from typing import Any, Protocol, runtime_checkable

  from lxml.etree import _Element
  from pydantic import BaseModel, Field
@@ -30,7 +30,7 @@ class LLMStructuredResponse(BaseModel):
  out.parsed = out.parsed[-n:]
  return out

- def filter(self, block_types: list[Type[BaseModel]]) -> Self:
+ def filter(self, block_types: list[type[BaseModel]]) -> Self:
  out = self.model_copy()
  out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
  return out
@@ -44,7 +44,7 @@ class TagParser(Protocol):
  element, do some computation, and return some kind of structured
  output, represented as a subclass of Pydantic `BaseModel`.
  This protocol implementation can cover both classes as well
- as curried fuctions as parsers (e.g. `partial`).
+ as curried functions as parsers (e.g. `partial`).
  """

  def __call__(self, element: _Element) -> BaseModel: ...
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):

  class CodeBlock(BaseModel):
  code: str
- code_lang: Optional[str] = None
+ code_lang: str | None = None


  class StructuredDataBlock(BaseModel):
data_designer/engine/models/registry.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations

  import logging

- from data_designer.config.models import ModelConfig
+ from data_designer.config.models import GenerationType, ModelConfig
  from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
  from data_designer.engine.models.facade import ModelFacade
  from data_designer.engine.models.litellm_overrides import apply_litellm_patches
@@ -73,7 +73,7 @@ class ModelRegistry:
  model_config = self.get_model_config(model_alias=model_alias)
  return self._model_provider_registry.get_provider(model_config.provider)

- def run_health_check(self, model_aliases: set[str]) -> None:
+ def run_health_check(self, model_aliases: list[str]) -> None:
  logger.info("🩺 Running health checks for models...")
  for model_alias in model_aliases:
  model = self.get_model(model_alias=model_alias)
@@ -81,15 +81,24 @@
  f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
  )
  try:
- model.generate(
- prompt="Hello!",
- parser=lambda x: x,
- system_prompt="You are a helpful assistant.",
- max_correction_steps=0,
- max_conversation_restarts=0,
- skip_usage_tracking=True,
- purpose="running health checks",
- )
+ if model.model_generation_type == GenerationType.EMBEDDING:
+ model.generate_text_embeddings(
+ input_texts=["Hello!"],
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
+ elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
+ model.generate(
+ prompt="Hello!",
+ parser=lambda x: x,
+ system_prompt="You are a helpful assistant.",
+ max_correction_steps=0,
+ max_conversation_restarts=0,
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
+ else:
+ raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
  logger.info(" |-- ✅ Passed!")
  except Exception as e:
  logger.error(" |-- ❌ Failed!")
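Note: run_health_check now takes a list of aliases and dispatches on each model's GenerationType, as shown above. A minimal sketch of driving it, assuming `registry` is an already-configured ModelRegistry and the alias names are hypothetical:

    # Hypothetical aliases; ModelRegistry construction is not shown in this diff.
    registry.run_health_check(model_aliases=["chat-model", "embedding-model"])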
data_designer/engine/models/usage.py CHANGED
@@ -11,20 +11,20 @@ logger = logging.getLogger(__name__)


  class TokenUsageStats(BaseModel):
- prompt_tokens: int = 0
- completion_tokens: int = 0
+ input_tokens: int = 0
+ output_tokens: int = 0

  @computed_field
  def total_tokens(self) -> int:
- return self.prompt_tokens + self.completion_tokens
+ return self.input_tokens + self.output_tokens

  @property
  def has_usage(self) -> bool:
  return self.total_tokens > 0

- def extend(self, *, prompt_tokens: int, completion_tokens: int) -> None:
- self.prompt_tokens += prompt_tokens
- self.completion_tokens += completion_tokens
+ def extend(self, *, input_tokens: int, output_tokens: int) -> None:
+ self.input_tokens += input_tokens
+ self.output_tokens += output_tokens


  class RequestUsageStats(BaseModel):
@@ -56,9 +56,7 @@ class ModelUsageStats(BaseModel):
  self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
  ) -> None:
  if token_usage is not None:
- self.token_usage.extend(
- prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
- )
+ self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
  if request_usage is not None:
  self.request_usage.extend(
  successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
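Note: the token counters are renamed from prompt/completion to input/output. A small worked example with the renamed fields (token values are illustrative):

    from data_designer.engine.models.usage import TokenUsageStats

    usage = TokenUsageStats()                         # input_tokens=0, output_tokens=0
    usage.extend(input_tokens=120, output_tokens=45)
    usage.extend(input_tokens=80, output_tokens=30)
    print(usage.total_tokens)                         # 275
    print(usage.has_usage)                            # True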
data_designer/engine/processing/ginja/ast.py CHANGED
@@ -2,7 +2,6 @@
  # SPDX-License-Identifier: Apache-2.0

  from collections import deque
- from typing import Optional, Type

  from jinja2 import nodes as j_nodes

@@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
  return max_depth


- def ast_descendant_count(ast: j_nodes.Node, only_type: Optional[Type[j_nodes.Node]] = None) -> int:
+ def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
  """Count the number of nodes which descend from the given node.

  Args:
data_designer/engine/processing/processors/drop_columns.py CHANGED
@@ -17,7 +17,7 @@ class DropColumnsProcessor(Processor[DropColumnsProcessorConfig]):
  @staticmethod
  def metadata() -> ConfigurableTaskMetadata:
  return ConfigurableTaskMetadata(
- name="drop_columns",
+ name="drop_columns_processor",
  description="Drop columns from the input dataset.",
  required_resources=None,
  )
data_designer/engine/processing/processors/registry.py CHANGED
@@ -5,9 +5,11 @@ from data_designer.config.base import ConfigBase
  from data_designer.config.processors import (
  DropColumnsProcessorConfig,
  ProcessorType,
+ SchemaTransformProcessorConfig,
  )
  from data_designer.engine.processing.processors.base import Processor
  from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
+ from data_designer.engine.processing.processors.schema_transform import SchemaTransformProcessor
  from data_designer.engine.registry.base import TaskRegistry


@@ -16,5 +18,6 @@ class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ...

  def create_default_processor_registry() -> ProcessorRegistry:
  registry = ProcessorRegistry()
+ registry.register(ProcessorType.SCHEMA_TRANSFORM, SchemaTransformProcessor, SchemaTransformProcessorConfig, False)
  registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False)
  return registry
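Note: the default registry now resolves the new schema-transform processor alongside drop-columns. A minimal lookup sketch using the TaskRegistry helpers shown in registry/base.py further below:

    from data_designer.config.processors import ProcessorType
    from data_designer.engine.processing.processors.registry import create_default_processor_registry

    registry = create_default_processor_registry()
    print(registry.get_task_type(ProcessorType.SCHEMA_TRANSFORM))   # -> SchemaTransformProcessor class
    print(registry.get_config_type(ProcessorType.DROP_COLUMNS))     # -> DropColumnsProcessorConfig class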
data_designer/engine/processing/processors/schema_transform.py ADDED
@@ -0,0 +1,53 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import json
+ import logging
+
+ import pandas as pd
+
+ from data_designer.config.processors import SchemaTransformProcessorConfig
+ from data_designer.engine.configurable_task import ConfigurableTaskMetadata
+ from data_designer.engine.dataset_builders.artifact_storage import BatchStage
+ from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
+ from data_designer.engine.processing.processors.base import Processor
+ from data_designer.engine.processing.utils import deserialize_json_values
+
+ logger = logging.getLogger(__name__)
+
+
+ class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[SchemaTransformProcessorConfig]):
+ @staticmethod
+ def metadata() -> ConfigurableTaskMetadata:
+ return ConfigurableTaskMetadata(
+ name="schema_transform_processor",
+ description="Generate dataset with transformed schema using a Jinja2 template.",
+ required_resources=None,
+ )
+
+ @property
+ def template_as_str(self) -> str:
+ return json.dumps(self.config.template)
+
+ def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
+ self.prepare_jinja2_template_renderer(self.template_as_str, data.columns.to_list())
+ formatted_records = [
+ json.loads(self.render_template(deserialize_json_values(record)).replace("\n", "\\n"))
+ for record in data.to_dict(orient="records")
+ ]
+ formatted_data = pd.DataFrame(formatted_records)
+ if current_batch_number is not None:
+ self.artifact_storage.write_batch_to_parquet_file(
+ batch_number=current_batch_number,
+ dataframe=formatted_data,
+ batch_stage=BatchStage.PROCESSORS_OUTPUTS,
+ subfolder=self.config.name,
+ )
+ else:
+ self.artifact_storage.write_parquet_file(
+ parquet_file_name=f"{self.config.name}.parquet",
+ dataframe=formatted_data,
+ batch_stage=BatchStage.PROCESSORS_OUTPUTS,
+ )
+
+ return data
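Note: the processor serializes its `template` config (a JSON structure whose string values may contain Jinja2 expressions), renders it once per record, and parses the result back into the new schema. A standalone sketch of that mechanic using plain Jinja2 (not the library's renderer; the template and record shown are illustrative):

    import json
    from jinja2 import Template

    # A JSON-shaped template with Jinja2 placeholders, as json.dumps(config.template) would produce.
    template = {"question": "{{ prompt }}", "answer": "{{ response }}"}
    record = {"prompt": "What is 2 + 2?", "response": "4"}

    rendered = Template(json.dumps(template)).render(**record)
    print(json.loads(rendered))  # {'question': 'What is 2 + 2?', 'answer': '4'}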
data_designer/engine/processing/utils.py CHANGED
@@ -1,9 +1,11 @@
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  # SPDX-License-Identifier: Apache-2.0

+ import ast
  import json
  import logging
- from typing import Any, TypeVar, Union, overload
+ import re
+ from typing import Any, TypeVar, overload

  import pandas as pd

@@ -25,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
  # Overloads to help static type checker better understand
  # the input/output types of the deserialize_json_values function.
  @overload
- def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ...
+ def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...


  @overload
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
  return data


+ def parse_list_string(text: str) -> list[str]:
+ """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
+ text = text.strip()
+
+ # Try JSON first
+ try:
+ list_obj = json.loads(text)
+ if isinstance(list_obj, list):
+ return _clean_whitespace(list_obj)
+ except json.JSONDecodeError:
+ pass
+
+ # Remove trailing commas before closing brackets (common in JSON-like strings)
+ text_cleaned = re.sub(r",\s*]", "]", text)
+ text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
+
+ # Try JSON again with cleaned text
+ try:
+ return _clean_whitespace(json.loads(text_cleaned))
+ except json.JSONDecodeError:
+ pass
+
+ # Try Python literal eval (handles single quotes)
+ try:
+ return _clean_whitespace(ast.literal_eval(text_cleaned))
+ except (ValueError, SyntaxError):
+ pass
+
+ # If all else fails, return the original text
+ return [text.strip()]
+
+
+ def _clean_whitespace(texts: list[str]) -> list[str]:
+ return [text.strip() for text in texts]
+
+
  def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
  joined_columns = set()
  for df in datasets:
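Note: the new parse_list_string helper falls back from strict JSON to comma-cleaned JSON to ast.literal_eval, and finally wraps unparseable input in a single-element list. A few illustrative calls:

    from data_designer.engine.processing.utils import parse_list_string

    parse_list_string('["alpha", "beta"]')       # JSON array           -> ['alpha', 'beta']
    parse_list_string("['alpha', 'beta',]")      # Python-style quotes  -> ['alpha', 'beta']
    parse_list_string('[" alpha ", "beta", ]')   # trailing comma       -> ['alpha', 'beta']
    parse_list_string("not a list")              # unparseable          -> ['not a list']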
data_designer/engine/registry/base.py CHANGED
@@ -2,7 +2,7 @@
  # SPDX-License-Identifier: Apache-2.0

  import threading
- from typing import Any, Generic, Type, TypeVar
+ from typing import Any, Generic, TypeVar

  from data_designer.config.base import ConfigBase
  from data_designer.config.utils.type_helpers import StrEnum
@@ -16,14 +16,14 @@ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)

  class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
  # registered type name -> type
- _registry: dict[EnumNameT, Type[TaskT]] = {}
+ _registry: dict[EnumNameT, type[TaskT]] = {}
  # type -> registered type name
- _reverse_registry: dict[Type[TaskT], EnumNameT] = {}
+ _reverse_registry: dict[type[TaskT], EnumNameT] = {}

  # registered type name -> config type
- _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {}
+ _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
  # config type -> registered type name
- _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {}
+ _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}

  # all registries are singletons
  _instance = None
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
  def register(
  cls,
  name: EnumNameT,
- task: Type[TaskT],
- config: Type[TaskConfigT],
+ task: type[TaskT],
+ config: type[TaskConfigT],
  raise_on_collision: bool = False,
  ) -> None:
  if cls._has_been_registered(name):
@@ -52,22 +52,22 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
  cls._reverse_config_registry[config] = name

  @classmethod
- def get_task_type(cls, name: EnumNameT) -> Type[TaskT]:
+ def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
  cls._raise_if_not_registered(name, cls._registry)
  return cls._registry[name]

  @classmethod
- def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]:
+ def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
  cls._raise_if_not_registered(name, cls._config_registry)
  return cls._config_registry[name]

  @classmethod
- def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT:
+ def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
  cls._raise_if_not_registered(task, cls._reverse_registry)
  return cls._reverse_registry[task]

  @classmethod
- def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]:
+ def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
  cls._raise_if_not_registered(config, cls._reverse_config_registry)
  name = cls._reverse_config_registry[config]
  return cls.get_task_type(name)
@@ -77,7 +77,7 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
  return name in cls._registry

  @classmethod
- def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None:
+ def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
  if not (isinstance(key, StrEnum) or isinstance(key, str)):
  cls._raise_if_not_type(key)
  if key not in mapping:
data_designer/engine/sampling_gen/constraints.py CHANGED
@@ -2,7 +2,6 @@
  # SPDX-License-Identifier: Apache-2.0

  from abc import ABC, abstractmethod
- from typing import Type

  import numpy as np
  import pandas as pd
@@ -91,5 +90,5 @@
  }


- def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]:
+ def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
  return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
data_designer/engine/sampling_gen/data_sources/base.py CHANGED
@@ -2,7 +2,7 @@
  # SPDX-License-Identifier: Apache-2.0

  from abc import ABC, abstractmethod
- from typing import Any, Generic, Optional, Type, TypeVar, Union
+ from typing import Any, Generic, TypeVar

  import numpy as np
  import pandas as pd
@@ -45,7 +45,7 @@ class PassthroughMixin:
  return series

  @staticmethod
- def validate_data_conversion(convert_to: Optional[str]) -> None:
+ def validate_data_conversion(convert_to: str | None) -> None:
  pass


@@ -71,7 +71,7 @@ class TypeConversionMixin:
  return series

  @staticmethod
- def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
+ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
  if convert_to is not None:
  if convert_to == "int":
  series = series.round()
@@ -79,18 +79,18 @@
  return series

  @staticmethod
- def validate_data_conversion(convert_to: Optional[str]) -> None:
+ def validate_data_conversion(convert_to: str | None) -> None:
  if convert_to is not None and convert_to not in ["float", "int", "str"]:
  raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]")


  class DatetimeFormatMixin:
  @staticmethod
- def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
+ def preproc(series: pd.Series, convert_to: str | None) -> pd.Series:
  return series

  @staticmethod
- def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
+ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
  if convert_to is not None:
  return series.dt.strftime(convert_to)
  if series.dt.month.nunique() == 1:
@@ -104,7 +104,7 @@ class DatetimeFormatMixin:
  return series.apply(lambda dt: dt.isoformat()).astype(str)

  @staticmethod
- def validate_data_conversion(convert_to: Optional[str]) -> None:
+ def validate_data_conversion(convert_to: str | None) -> None:
  if convert_to is not None:
  try:
  pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to))
@@ -121,7 +121,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
  def __init__(
  self,
  params: GenericParamsT,
- random_state: Optional[RadomStateT] = None,
+ random_state: RadomStateT | None = None,
  **kwargs,
  ):
  self.rng = check_random_state(random_state)
@@ -130,7 +130,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
  self._validate()

  @classmethod
- def get_param_type(cls) -> Type[GenericParamsT]:
+ def get_param_type(cls) -> type[GenericParamsT]:
  return cls.__orig_bases__[-1].__args__[0]

  @abstractmethod
@@ -138,7 +138,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
  self,
  dataframe: pd.DataFrame,
  column_name: str,
- index: Optional[list[int]] = None,
+ index: list[int] | None = None,
  ) -> pd.DataFrame: ...

  @staticmethod
@@ -147,11 +147,11 @@ class DataSource(ABC, Generic[GenericParamsT]):

  @staticmethod
  @abstractmethod
- def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: ...
+ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: ...

  @staticmethod
  @abstractmethod
- def validate_data_conversion(convert_to: Optional[str]) -> None: ...
+ def validate_data_conversion(convert_to: str | None) -> None: ...

  def get_required_column_names(self) -> tuple[str, ...]:
  return tuple()
@@ -182,7 +182,7 @@ class Sampler(DataSource[GenericParamsT], ABC):
  self,
  dataframe: pd.DataFrame,
  column_name: str,
- index: Optional[list[int]] = None,
+ index: list[int] | None = None,
  ) -> pd.DataFrame:
  index = slice(None) if index is None else index

@@ -208,7 +208,7 @@ class Sampler(DataSource[GenericParamsT], ABC):
  class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
  @property
  @abstractmethod
- def distribution(self) -> Union[stats.rv_continuous, stats.rv_discrete]: ...
+ def distribution(self) -> stats.rv_continuous | stats.rv_discrete: ...

  def sample(self, num_samples: int) -> NumpyArray1dT:
  return self.distribution.rvs(size=num_samples, random_state=self.rng)
data_designer/engine/sampling_gen/entities/phone_number.py CHANGED
@@ -3,7 +3,6 @@

  import random
  from pathlib import Path
- from typing import Optional

  import pandas as pd
  from pydantic import BaseModel, Field, field_validator
@@ -13,7 +12,7 @@ ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DA
  ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"]))


- def get_area_code(zip_prefix: Optional[str] = None) -> str:
+ def get_area_code(zip_prefix: str | None = None) -> str:
  """
  Sample an area code for the given ZIP code prefix, population-weighted.

data_designer/engine/sampling_gen/people_gen.py CHANGED
@@ -8,12 +8,12 @@ import uuid
  from abc import ABC, abstractmethod
  from collections.abc import Callable
  from copy import deepcopy
- from typing import TYPE_CHECKING, Any, Union
+ from typing import TYPE_CHECKING, Any, TypeAlias

  import pandas as pd
  from faker import Faker

- from data_designer.config.utils.constants import AVAILABLE_LOCALES, DEFAULT_AGE_RANGE
+ from data_designer.config.utils.constants import DEFAULT_AGE_RANGE
  from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
  from data_designer.engine.sampling_gen.entities.dataset_based_person_fields import PERSONA_FIELDS, PII_FIELDS
  from data_designer.engine.sampling_gen.entities.person import (
@@ -27,17 +27,13 @@ if TYPE_CHECKING:
  from data_designer.engine.sampling_gen.schema import DataSchema


- EngineT = Union[Faker, ManagedDatasetGenerator]
+ EngineT: TypeAlias = Faker | ManagedDatasetGenerator


  class PeopleGen(ABC):
  """Unified interface for generating people data."""

  def __init__(self, engine: EngineT, locale: str):
- if locale not in AVAILABLE_LOCALES:
- raise ValueError(
- f"Locale {locale} is not a supported locale.Supported locales: {', '.join(AVAILABLE_LOCALES)}"
- )
  self.locale = locale
  self._engine = engine

data_designer/engine/validators/base.py CHANGED
@@ -2,14 +2,14 @@
  # SPDX-License-Identifier: Apache-2.0

  from abc import ABC, abstractmethod
- from typing import Iterator, Optional
+ from typing import Iterator

  from pydantic import BaseModel, ConfigDict
  from typing_extensions import Self


  class ValidationOutput(BaseModel):
- is_valid: Optional[bool]
+ is_valid: bool | None
  model_config = ConfigDict(extra="allow")

data_designer/interface/data_designer.py CHANGED
@@ -249,6 +249,17 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
  except Exception as e:
  raise DataDesignerProfilingError(f"🛑 Error profiling preview dataset: {e}")

+ if builder.artifact_storage.processors_outputs_path.exists():
+ processor_artifacts = {
+ processor_config.name: pd.read_parquet(
+ builder.artifact_storage.processors_outputs_path / f"{processor_config.name}.parquet",
+ dtype_backend="pyarrow",
+ ).to_dict(orient="records")
+ for processor_config in config_builder.get_processor_configs()
+ }
+ else:
+ processor_artifacts = {}
+
  if (
  len(processed_dataset) > 0
  and isinstance(analysis, DatasetProfilerResults)
@@ -259,6 +270,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]):
  return PreviewResults(
  dataset=processed_dataset,
  analysis=analysis,
+ processor_artifacts=processor_artifacts,
  config_builder=config_builder,
  )
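Note: PreviewResults now also carries each processor's output records. A minimal access sketch, assuming `preview` is the PreviewResults instance returned by the preview workflow and "schema_transform" is a hypothetical processor name:

    # Only the processor_artifacts mapping (processor name -> list of record dicts)
    # is confirmed by this diff; the name used here is a placeholder.
    records = preview.processor_artifacts.get("schema_transform", [])
    print(len(records))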
data_designer/interface/results.py CHANGED
@@ -3,12 +3,15 @@

  from __future__ import annotations

+ from pathlib import Path
+
  import pandas as pd

  from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
  from data_designer.config.config_builder import DataDesignerConfigBuilder
  from data_designer.config.utils.visualization import WithRecordSamplerMixin
  from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+ from data_designer.engine.dataset_builders.errors import ArtifactStorageError


  class DatasetCreationResults(WithRecordSamplerMixin):
@@ -53,3 +56,36 @@ class DatasetCreationResults(WithRecordSamplerMixin):
  A pandas DataFrame containing the full generated dataset.
  """
  return self.artifact_storage.load_dataset()
+
+ def load_processor_dataset(self, processor_name: str) -> pd.DataFrame:
+ """Load the dataset generated by a processor.
+
+ This only works for processors that write their artifacts in Parquet format.
+
+ Args:
+ processor_name: The name of the processor to load the dataset from.
+
+ Returns:
+ A pandas DataFrame containing the dataset generated by the processor.
+ """
+ try:
+ dataset = self.artifact_storage.read_parquet_files(
+ self.artifact_storage.processors_outputs_path / processor_name
+ )
+ except Exception as e:
+ raise ArtifactStorageError(f"Failed to load dataset for processor {processor_name}: {e}")
+
+ return dataset
+
+ def get_path_to_processor_artifacts(self, processor_name: str) -> Path:
+ """Get the path to the artifacts generated by a processor.
+
+ Args:
+ processor_name: The name of the processor to load the artifact from.
+
+ Returns:
+ The path to the artifacts.
+ """
+ if not self.artifact_storage.processors_outputs_path.exists():
+ raise ArtifactStorageError(f"Processor {processor_name} has no artifacts.")
+ return self.artifact_storage.processors_outputs_path / processor_name
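Note: a minimal usage sketch for the two new accessors, assuming `results` is a DatasetCreationResults instance and "my_schema_transform" is a hypothetical processor name:

    # Hypothetical processor name; only the two accessors are confirmed by this diff.
    processor_df = results.load_processor_dataset("my_schema_transform")
    print(processor_df.shape)
    print(results.get_path_to_processor_artifacts("my_schema_transform"))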
data_designer/logging.py CHANGED
@@ -6,7 +6,7 @@ import random
  import sys
  from dataclasses import dataclass, field
  from pathlib import Path
- from typing import TextIO, Union
+ from typing import TextIO

  from pythonjsonlogger import jsonlogger

@@ -19,7 +19,7 @@ class LoggerConfig:

  @dataclass
  class OutputConfig:
- destination: Union[TextIO, Path]
+ destination: TextIO | Path
  structured: bool
