data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +34 -26
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +31 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +20 -11
  66. data_designer/engine/models/usage.py +7 -9
  67. data_designer/engine/processing/ginja/ast.py +1 -2
  68. data_designer/engine/processing/utils.py +40 -2
  69. data_designer/engine/registry/base.py +12 -12
  70. data_designer/engine/sampling_gen/constraints.py +1 -2
  71. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  72. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  73. data_designer/engine/sampling_gen/people_gen.py +3 -7
  74. data_designer/engine/validators/base.py +2 -2
  75. data_designer/logging.py +2 -2
  76. data_designer/plugin_manager.py +3 -3
  77. data_designer/plugins/plugin.py +3 -3
  78. data_designer/plugins/registry.py +2 -2
  79. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -8,7 +8,7 @@ import json
8
8
  import logging
9
9
  from concurrent.futures import Future, ThreadPoolExecutor
10
10
  from threading import Lock, Semaphore
11
- from typing import Any, Optional, Protocol
11
+ from typing import Any, Protocol
12
12
 
13
13
  from pydantic import BaseModel, Field
14
14
 
@@ -46,13 +46,13 @@ class ExecutorResults(BaseModel):
46
46
  class CallbackWithContext(Protocol):
47
47
  """Executor callback functions must accept a context kw argument."""
48
48
 
49
- def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ...
49
+ def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
50
50
 
51
51
 
52
52
  class ErrorCallbackWithContext(Protocol):
53
53
  """Error callbacks take the Exception instance and context."""
54
54
 
55
- def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ...
55
+ def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
56
56
 
57
57
 
58
58
  class ConcurrentThreadExecutor:
@@ -92,8 +92,8 @@ class ConcurrentThreadExecutor:
92
92
  *,
93
93
  max_workers: int,
94
94
  column_name: str,
95
- result_callback: Optional[CallbackWithContext] = None,
96
- error_callback: Optional[ErrorCallbackWithContext] = None,
95
+ result_callback: CallbackWithContext | None = None,
96
+ error_callback: ErrorCallbackWithContext | None = None,
97
97
  shutdown_error_rate: float = 0.50,
98
98
  shutdown_error_window: int = 10,
99
99
  ):
@@ -136,7 +136,7 @@ class ConcurrentThreadExecutor:
136
136
  )
137
137
  )
138
138
 
139
- def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None:
139
+ def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
140
140
  if self._executor is None:
141
141
  raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
142
142
 
@@ -9,9 +9,9 @@ from copy import deepcopy
9
9
  from typing import Any
10
10
 
11
11
  from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
12
- from litellm.types.utils import ModelResponse
12
+ from litellm.types.utils import EmbeddingResponse, ModelResponse
13
13
 
14
- from data_designer.config.models import ModelConfig, ModelProvider
14
+ from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
15
15
  from data_designer.engine.model_provider import ModelProviderRegistry
16
16
  from data_designer.engine.models.errors import (
17
17
  GenerationValidationFailureError,
@@ -49,6 +49,10 @@ class ModelFacade:
49
49
  def model_provider(self) -> ModelProvider:
50
50
  return self._model_provider_registry.get_provider(self._model_config.provider)
51
51
 
52
+ @property
53
+ def model_generation_type(self) -> GenerationType:
54
+ return self._model_config.generation_type
55
+
52
56
  @property
53
57
  def model_provider_name(self) -> str:
54
58
  return self.model_provider.name
@@ -64,13 +68,12 @@ class ModelFacade:
64
68
  def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
65
69
  logger.debug(
66
70
  f"Prompting model {self.model_name!r}...",
67
- extra={"model": self.model_name, "messages": messages, "sensitive": True},
71
+ extra={"model": self.model_name, "messages": messages},
68
72
  )
69
73
  response = None
70
- if self.model_provider.extra_body:
71
- kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
74
+ kwargs = self.consolidate_kwargs(**kwargs)
72
75
  try:
73
- response = self._router.completion(self.model_name, messages, **kwargs)
76
+ response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
74
77
  logger.debug(
75
78
  f"Received completion from model {self.model_name!r}",
76
79
  extra={
@@ -84,9 +87,50 @@ class ModelFacade:
84
87
  except Exception as e:
85
88
  raise e
86
89
  finally:
87
- if not skip_usage_tracking:
90
+ if not skip_usage_tracking and response is not None:
88
91
  self._track_usage(response)
89
92
 
93
+ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
94
+ # Remove purpose from kwargs to avoid passing it to the model
95
+ kwargs.pop("purpose", None)
96
+ kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
97
+ if self.model_provider.extra_body:
98
+ kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
99
+ return kwargs
100
+
101
+ @catch_llm_exceptions
102
+ def generate_text_embeddings(
103
+ self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
104
+ ) -> list[list[float]]:
105
+ logger.debug(
106
+ f"Generating embeddings with model {self.model_name!r}...",
107
+ extra={
108
+ "model": self.model_name,
109
+ "input_count": len(input_texts),
110
+ },
111
+ )
112
+ kwargs = self.consolidate_kwargs(**kwargs)
113
+ response = None
114
+ try:
115
+ response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
116
+ logger.debug(
117
+ f"Received embeddings from model {self.model_name!r}",
118
+ extra={
119
+ "model": self.model_name,
120
+ "embedding_count": len(response.data) if response.data else 0,
121
+ "usage": self._usage_stats.model_dump(),
122
+ },
123
+ )
124
+ if response.data and len(response.data) == len(input_texts):
125
+ return [data["embedding"] for data in response.data]
126
+ else:
127
+ raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
128
+ except Exception as e:
129
+ raise e
130
+ finally:
131
+ if not skip_usage_tracking and response is not None:
132
+ self._track_usage_from_embedding(response)
133
+
90
134
  @catch_llm_exceptions
91
135
  def generate(
92
136
  self,
@@ -218,8 +262,21 @@ class ModelFacade:
218
262
  ):
219
263
  self._usage_stats.extend(
220
264
  token_usage=TokenUsageStats(
221
- prompt_tokens=response.usage.prompt_tokens,
222
- completion_tokens=response.usage.completion_tokens,
265
+ input_tokens=response.usage.prompt_tokens,
266
+ output_tokens=response.usage.completion_tokens,
267
+ ),
268
+ request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
269
+ )
270
+
271
+ def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
272
+ if response is None:
273
+ self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
274
+ return
275
+ if response.usage is not None and response.usage.prompt_tokens is not None:
276
+ self._usage_stats.extend(
277
+ token_usage=TokenUsageStats(
278
+ input_tokens=response.usage.prompt_tokens,
279
+ output_tokens=0,
223
280
  ),
224
281
  request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
225
282
  )
@@ -5,7 +5,6 @@ from __future__ import annotations
5
5
 
6
6
  import random
7
7
  import threading
8
- from typing import Optional, Union
9
8
 
10
9
  import httpx
11
10
  import litellm
@@ -90,7 +89,7 @@ class CustomRouter(Router):
90
89
  self._initial_retry_after_s = initial_retry_after_s
91
90
  self._jitter_pct = jitter_pct
92
91
 
93
- def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]:
92
+ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
94
93
  """
95
94
  Most of this code logic was extracted directly from the parent
96
95
  `Router`'s `_time_to_sleep_before_retry` function. Our override
@@ -99,7 +98,7 @@ class CustomRouter(Router):
99
98
  return this info, we'll simply use that retry value returned here.
100
99
  """
101
100
 
102
- response_headers: Optional[httpx.Headers] = None
101
+ response_headers: httpx.Headers | None = None
103
102
  if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
104
103
  response_headers = e.response.headers # type: ignore
105
104
  if hasattr(e, "litellm_response_headers"):
@@ -119,9 +118,9 @@ class CustomRouter(Router):
119
118
  e: Exception,
120
119
  remaining_retries: int,
121
120
  num_retries: int,
122
- healthy_deployments: Optional[list] = None,
123
- all_deployments: Optional[list] = None,
124
- ) -> Union[int, float]:
121
+ healthy_deployments: list | None = None,
122
+ all_deployments: list | None = None,
123
+ ) -> int | float:
125
124
  """
126
125
  Implements exponential backoff for retries.
127
126
 
@@ -1,8 +1,6 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Optional
5
-
6
4
 
7
5
  class ParserException(Exception):
8
6
  """Identifies errors resulting from generic parser errors.
@@ -12,7 +10,7 @@ class ParserException(Exception):
12
10
  attempted to parse.
13
11
  """
14
12
 
15
- source: Optional[str]
13
+ source: str | None
16
14
 
17
15
  @staticmethod
18
16
  def _log_format(source: str) -> str:
@@ -24,7 +22,7 @@ class ParserException(Exception):
24
22
  # return f"<source>{source}</source>"
25
23
  return ""
26
24
 
27
- def __init__(self, msg: Optional[str] = None, source: Optional[str] = None):
25
+ def __init__(self, msg: str | None = None, source: str | None = None):
28
26
  msg = "" if msg is None else msg.strip()
29
27
 
30
28
  if source is not None:
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from functools import reduce
5
- from typing import Optional
6
5
 
7
6
  import marko
8
7
  from lxml import etree
@@ -105,8 +104,8 @@ class LLMResponseParser:
105
104
 
106
105
  def __init__(
107
106
  self,
108
- tag_parsers: Optional[dict[str, TagParser]] = None,
109
- postprocessors: Optional[list[PostProcessor]] = None,
107
+ tag_parsers: dict[str, TagParser] | None = None,
108
+ postprocessors: list[PostProcessor] | None = None,
110
109
  ):
111
110
  """
112
111
  Initializes the LLMResponseParser with optional tag parsers and post-processors.
@@ -1,7 +1,6 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Optional, Type
5
4
 
6
5
  import json_repair
7
6
  from pydantic import BaseModel, ValidationError
@@ -60,12 +59,12 @@ def deserialize_json_code(
60
59
 
61
60
 
62
61
  class RealizePydanticTypes:
63
- types: list[Type[BaseModel]]
62
+ types: list[type[BaseModel]]
64
63
 
65
- def __init__(self, types: list[Type[BaseModel]]):
64
+ def __init__(self, types: list[type[BaseModel]]):
66
65
  self.types = types
67
66
 
68
- def _fit_types(self, obj: dict) -> Optional[BaseModel]:
67
+ def _fit_types(self, obj: dict) -> BaseModel | None:
69
68
  final_obj = None
70
69
 
71
70
  for t in self.types:
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
- from typing import Any, Optional, Protocol, Type, runtime_checkable
4
+ from typing import Any, Protocol, runtime_checkable
5
5
 
6
6
  from lxml.etree import _Element
7
7
  from pydantic import BaseModel, Field
@@ -30,7 +30,7 @@ class LLMStructuredResponse(BaseModel):
30
30
  out.parsed = out.parsed[-n:]
31
31
  return out
32
32
 
33
- def filter(self, block_types: list[Type[BaseModel]]) -> Self:
33
+ def filter(self, block_types: list[type[BaseModel]]) -> Self:
34
34
  out = self.model_copy()
35
35
  out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
36
36
  return out
@@ -44,7 +44,7 @@ class TagParser(Protocol):
44
44
  element, do some computation, and return some kind of structured
45
45
  output, represented as a subclass of Pydantic `BaseModel`.
46
46
  This protocol implementation can cover both classes as well
47
- as curried fuctions as parsers (e.g. `partial`).
47
+ as curried functions as parsers (e.g. `partial`).
48
48
  """
49
49
 
50
50
  def __call__(self, element: _Element) -> BaseModel: ...
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):
69
69
 
70
70
  class CodeBlock(BaseModel):
71
71
  code: str
72
- code_lang: Optional[str] = None
72
+ code_lang: str | None = None
73
73
 
74
74
 
75
75
  class StructuredDataBlock(BaseModel):
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
 
6
6
  import logging
7
7
 
8
- from data_designer.config.models import ModelConfig
8
+ from data_designer.config.models import GenerationType, ModelConfig
9
9
  from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
10
10
  from data_designer.engine.models.facade import ModelFacade
11
11
  from data_designer.engine.models.litellm_overrides import apply_litellm_patches
@@ -73,7 +73,7 @@ class ModelRegistry:
73
73
  model_config = self.get_model_config(model_alias=model_alias)
74
74
  return self._model_provider_registry.get_provider(model_config.provider)
75
75
 
76
- def run_health_check(self, model_aliases: set[str]) -> None:
76
+ def run_health_check(self, model_aliases: list[str]) -> None:
77
77
  logger.info("🩺 Running health checks for models...")
78
78
  for model_alias in model_aliases:
79
79
  model = self.get_model(model_alias=model_alias)
@@ -81,15 +81,24 @@ class ModelRegistry:
81
81
  f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
82
82
  )
83
83
  try:
84
- model.generate(
85
- prompt="Hello!",
86
- parser=lambda x: x,
87
- system_prompt="You are a helpful assistant.",
88
- max_correction_steps=0,
89
- max_conversation_restarts=0,
90
- skip_usage_tracking=True,
91
- purpose="running health checks",
92
- )
84
+ if model.model_generation_type == GenerationType.EMBEDDING:
85
+ model.generate_text_embeddings(
86
+ input_texts=["Hello!"],
87
+ skip_usage_tracking=True,
88
+ purpose="running health checks",
89
+ )
90
+ elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
91
+ model.generate(
92
+ prompt="Hello!",
93
+ parser=lambda x: x,
94
+ system_prompt="You are a helpful assistant.",
95
+ max_correction_steps=0,
96
+ max_conversation_restarts=0,
97
+ skip_usage_tracking=True,
98
+ purpose="running health checks",
99
+ )
100
+ else:
101
+ raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
93
102
  logger.info(" |-- ✅ Passed!")
94
103
  except Exception as e:
95
104
  logger.error(" |-- ❌ Failed!")
@@ -11,20 +11,20 @@ logger = logging.getLogger(__name__)
11
11
 
12
12
 
13
13
  class TokenUsageStats(BaseModel):
14
- prompt_tokens: int = 0
15
- completion_tokens: int = 0
14
+ input_tokens: int = 0
15
+ output_tokens: int = 0
16
16
 
17
17
  @computed_field
18
18
  def total_tokens(self) -> int:
19
- return self.prompt_tokens + self.completion_tokens
19
+ return self.input_tokens + self.output_tokens
20
20
 
21
21
  @property
22
22
  def has_usage(self) -> bool:
23
23
  return self.total_tokens > 0
24
24
 
25
- def extend(self, *, prompt_tokens: int, completion_tokens: int) -> None:
26
- self.prompt_tokens += prompt_tokens
27
- self.completion_tokens += completion_tokens
25
+ def extend(self, *, input_tokens: int, output_tokens: int) -> None:
26
+ self.input_tokens += input_tokens
27
+ self.output_tokens += output_tokens
28
28
 
29
29
 
30
30
  class RequestUsageStats(BaseModel):
@@ -56,9 +56,7 @@ class ModelUsageStats(BaseModel):
56
56
  self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
57
57
  ) -> None:
58
58
  if token_usage is not None:
59
- self.token_usage.extend(
60
- prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
61
- )
59
+ self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
62
60
  if request_usage is not None:
63
61
  self.request_usage.extend(
64
62
  successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from collections import deque
5
- from typing import Optional, Type
6
5
 
7
6
  from jinja2 import nodes as j_nodes
8
7
 
@@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
33
32
  return max_depth
34
33
 
35
34
 
36
- def ast_descendant_count(ast: j_nodes.Node, only_type: Optional[Type[j_nodes.Node]] = None) -> int:
35
+ def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
37
36
  """Count the number of nodes which descend from the given node.
38
37
 
39
38
  Args:
@@ -1,9 +1,11 @@
1
1
  # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
+ import ast
4
5
  import json
5
6
  import logging
6
- from typing import Any, TypeVar, Union, overload
7
+ import re
8
+ from typing import Any, TypeVar, overload
7
9
 
8
10
  import pandas as pd
9
11
 
@@ -25,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
25
27
  # Overloads to help static type checker better understand
26
28
  # the input/output types of the deserialize_json_values function.
27
29
  @overload
28
- def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ...
30
+ def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...
29
31
 
30
32
 
31
33
  @overload
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
100
102
  return data
101
103
 
102
104
 
105
+ def parse_list_string(text: str) -> list[str]:
106
+ """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
107
+ text = text.strip()
108
+
109
+ # Try JSON first
110
+ try:
111
+ list_obj = json.loads(text)
112
+ if isinstance(list_obj, list):
113
+ return _clean_whitespace(list_obj)
114
+ except json.JSONDecodeError:
115
+ pass
116
+
117
+ # Remove trailing commas before closing brackets (common in JSON-like strings)
118
+ text_cleaned = re.sub(r",\s*]", "]", text)
119
+ text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
120
+
121
+ # Try JSON again with cleaned text
122
+ try:
123
+ return _clean_whitespace(json.loads(text_cleaned))
124
+ except json.JSONDecodeError:
125
+ pass
126
+
127
+ # Try Python literal eval (handles single quotes)
128
+ try:
129
+ return _clean_whitespace(ast.literal_eval(text_cleaned))
130
+ except (ValueError, SyntaxError):
131
+ pass
132
+
133
+ # If all else fails, return the original text
134
+ return [text.strip()]
135
+
136
+
137
+ def _clean_whitespace(texts: list[str]) -> list[str]:
138
+ return [text.strip() for text in texts]
139
+
140
+
103
141
  def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
104
142
  joined_columns = set()
105
143
  for df in datasets:
@@ -2,7 +2,7 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  import threading
5
- from typing import Any, Generic, Type, TypeVar
5
+ from typing import Any, Generic, TypeVar
6
6
 
7
7
  from data_designer.config.base import ConfigBase
8
8
  from data_designer.config.utils.type_helpers import StrEnum
@@ -16,14 +16,14 @@ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
16
16
 
17
17
  class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
18
18
  # registered type name -> type
19
- _registry: dict[EnumNameT, Type[TaskT]] = {}
19
+ _registry: dict[EnumNameT, type[TaskT]] = {}
20
20
  # type -> registered type name
21
- _reverse_registry: dict[Type[TaskT], EnumNameT] = {}
21
+ _reverse_registry: dict[type[TaskT], EnumNameT] = {}
22
22
 
23
23
  # registered type name -> config type
24
- _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {}
24
+ _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
25
25
  # config type -> registered type name
26
- _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {}
26
+ _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}
27
27
 
28
28
  # all registries are singletons
29
29
  _instance = None
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
33
33
  def register(
34
34
  cls,
35
35
  name: EnumNameT,
36
- task: Type[TaskT],
37
- config: Type[TaskConfigT],
36
+ task: type[TaskT],
37
+ config: type[TaskConfigT],
38
38
  raise_on_collision: bool = False,
39
39
  ) -> None:
40
40
  if cls._has_been_registered(name):
@@ -52,22 +52,22 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
52
52
  cls._reverse_config_registry[config] = name
53
53
 
54
54
  @classmethod
55
- def get_task_type(cls, name: EnumNameT) -> Type[TaskT]:
55
+ def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
56
56
  cls._raise_if_not_registered(name, cls._registry)
57
57
  return cls._registry[name]
58
58
 
59
59
  @classmethod
60
- def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]:
60
+ def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
61
61
  cls._raise_if_not_registered(name, cls._config_registry)
62
62
  return cls._config_registry[name]
63
63
 
64
64
  @classmethod
65
- def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT:
65
+ def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
66
66
  cls._raise_if_not_registered(task, cls._reverse_registry)
67
67
  return cls._reverse_registry[task]
68
68
 
69
69
  @classmethod
70
- def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]:
70
+ def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
71
71
  cls._raise_if_not_registered(config, cls._reverse_config_registry)
72
72
  name = cls._reverse_config_registry[config]
73
73
  return cls.get_task_type(name)
@@ -77,7 +77,7 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
77
77
  return name in cls._registry
78
78
 
79
79
  @classmethod
80
- def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None:
80
+ def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
81
81
  if not (isinstance(key, StrEnum) or isinstance(key, str)):
82
82
  cls._raise_if_not_type(key)
83
83
  if key not in mapping:
@@ -2,7 +2,6 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  from abc import ABC, abstractmethod
5
- from typing import Type
6
5
 
7
6
  import numpy as np
8
7
  import pandas as pd
@@ -91,5 +90,5 @@ CONSTRAINT_TYPE_TO_CHECKER = {
91
90
  }
92
91
 
93
92
 
94
- def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]:
93
+ def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
95
94
  return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]