data-designer 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the content of publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (84)
  1. data_designer/_version.py +2 -2
  2. data_designer/cli/README.md +15 -1
  3. data_designer/cli/commands/download.py +56 -0
  4. data_designer/cli/commands/list.py +4 -18
  5. data_designer/cli/controllers/__init__.py +2 -1
  6. data_designer/cli/controllers/download_controller.py +217 -0
  7. data_designer/cli/controllers/model_controller.py +4 -3
  8. data_designer/cli/forms/field.py +65 -19
  9. data_designer/cli/forms/model_builder.py +251 -44
  10. data_designer/cli/main.py +11 -1
  11. data_designer/cli/repositories/persona_repository.py +88 -0
  12. data_designer/cli/services/__init__.py +2 -1
  13. data_designer/cli/services/download_service.py +97 -0
  14. data_designer/cli/ui.py +131 -0
  15. data_designer/cli/utils.py +34 -0
  16. data_designer/config/analysis/__init__.py +2 -0
  17. data_designer/config/analysis/column_profilers.py +75 -7
  18. data_designer/config/analysis/column_statistics.py +192 -48
  19. data_designer/config/analysis/dataset_profiler.py +23 -5
  20. data_designer/config/analysis/utils/reporting.py +3 -3
  21. data_designer/config/base.py +3 -3
  22. data_designer/config/column_configs.py +27 -6
  23. data_designer/config/column_types.py +24 -17
  24. data_designer/config/config_builder.py +36 -27
  25. data_designer/config/data_designer_config.py +7 -7
  26. data_designer/config/datastore.py +6 -6
  27. data_designer/config/default_model_settings.py +27 -34
  28. data_designer/config/exports.py +8 -0
  29. data_designer/config/models.py +155 -29
  30. data_designer/config/preview_results.py +6 -8
  31. data_designer/config/processors.py +63 -2
  32. data_designer/config/sampler_constraints.py +1 -2
  33. data_designer/config/sampler_params.py +50 -31
  34. data_designer/config/seed.py +1 -2
  35. data_designer/config/utils/code_lang.py +4 -5
  36. data_designer/config/utils/constants.py +31 -8
  37. data_designer/config/utils/io_helpers.py +5 -5
  38. data_designer/config/utils/misc.py +1 -4
  39. data_designer/config/utils/numerical_helpers.py +2 -2
  40. data_designer/config/utils/type_helpers.py +3 -3
  41. data_designer/config/utils/validation.py +7 -8
  42. data_designer/config/utils/visualization.py +32 -17
  43. data_designer/config/validator_params.py +4 -8
  44. data_designer/engine/analysis/column_profilers/base.py +0 -7
  45. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
  46. data_designer/engine/analysis/column_statistics.py +16 -16
  47. data_designer/engine/analysis/dataset_profiler.py +25 -4
  48. data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
  49. data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
  50. data_designer/engine/column_generators/generators/base.py +34 -0
  51. data_designer/engine/column_generators/generators/embedding.py +45 -0
  52. data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
  53. data_designer/engine/column_generators/registry.py +4 -2
  54. data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
  55. data_designer/engine/configurable_task.py +2 -2
  56. data_designer/engine/dataset_builders/artifact_storage.py +1 -2
  57. data_designer/engine/dataset_builders/column_wise_builder.py +58 -15
  58. data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
  59. data_designer/engine/models/facade.py +66 -9
  60. data_designer/engine/models/litellm_overrides.py +5 -6
  61. data_designer/engine/models/parsers/errors.py +2 -4
  62. data_designer/engine/models/parsers/parser.py +2 -3
  63. data_designer/engine/models/parsers/postprocessors.py +3 -4
  64. data_designer/engine/models/parsers/types.py +4 -4
  65. data_designer/engine/models/registry.py +47 -12
  66. data_designer/engine/models/telemetry.py +355 -0
  67. data_designer/engine/models/usage.py +7 -9
  68. data_designer/engine/processing/ginja/ast.py +1 -2
  69. data_designer/engine/processing/utils.py +40 -2
  70. data_designer/engine/registry/base.py +12 -12
  71. data_designer/engine/sampling_gen/constraints.py +1 -2
  72. data_designer/engine/sampling_gen/data_sources/base.py +14 -14
  73. data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
  74. data_designer/engine/sampling_gen/people_gen.py +3 -7
  75. data_designer/engine/validators/base.py +2 -2
  76. data_designer/logging.py +2 -2
  77. data_designer/plugin_manager.py +3 -3
  78. data_designer/plugins/plugin.py +3 -3
  79. data_designer/plugins/registry.py +2 -2
  80. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/METADATA +32 -1
  81. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/RECORD +84 -77
  82. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/WHEEL +0 -0
  83. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/entry_points.txt +0 -0
  84. {data_designer-0.1.5.dist-info → data_designer-0.2.1.dist-info}/licenses/LICENSE +0 -0
data_designer/engine/models/registry.py

@@ -5,10 +5,11 @@ from __future__ import annotations
 
 import logging
 
-from data_designer.config.models import ModelConfig
+from data_designer.config.models import GenerationType, ModelConfig
 from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
 from data_designer.engine.models.facade import ModelFacade
 from data_designer.engine.models.litellm_overrides import apply_litellm_patches
+from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
 from data_designer.engine.secret_resolver import SecretResolver
 
 logger = logging.getLogger(__name__)
@@ -25,7 +26,7 @@ class ModelRegistry:
         self._secret_resolver = secret_resolver
         self._model_provider_registry = model_provider_registry
         self._model_configs = {}
-        self._models = {}
+        self._models: dict[str, ModelFacade] = {}
         self._set_model_configs(model_configs)
 
     @property
@@ -69,11 +70,36 @@ class ModelRegistry:
             if model.usage_stats.has_usage
         }
 
+    def get_model_usage_snapshot(self) -> dict[str, ModelUsageStats]:
+        return {
+            model.model_name: model.usage_stats.model_copy(deep=True)
+            for model in self._models.values()
+            if model.usage_stats.has_usage
+        }
+
+    def get_usage_deltas(self, snapshot: dict[str, ModelUsageStats]) -> dict[str, ModelUsageStats]:
+        deltas = {}
+        for model_name, current in self.get_model_usage_snapshot().items():
+            prev = snapshot.get(model_name)
+            delta_input = current.token_usage.input_tokens - (prev.token_usage.input_tokens if prev else 0)
+            delta_output = current.token_usage.output_tokens - (prev.token_usage.output_tokens if prev else 0)
+            delta_successful = current.request_usage.successful_requests - (
+                prev.request_usage.successful_requests if prev else 0
+            )
+            delta_failed = current.request_usage.failed_requests - (prev.request_usage.failed_requests if prev else 0)
+
+            if delta_input > 0 or delta_output > 0 or delta_successful > 0 or delta_failed > 0:
+                deltas[model_name] = ModelUsageStats(
+                    token_usage=TokenUsageStats(input_tokens=delta_input, output_tokens=delta_output),
+                    request_usage=RequestUsageStats(successful_requests=delta_successful, failed_requests=delta_failed),
+                )
+        return deltas
+
     def get_model_provider(self, *, model_alias: str) -> ModelProvider:
         model_config = self.get_model_config(model_alias=model_alias)
         return self._model_provider_registry.get_provider(model_config.provider)
 
-    def run_health_check(self, model_aliases: set[str]) -> None:
+    def run_health_check(self, model_aliases: list[str]) -> None:
         logger.info("🩺 Running health checks for models...")
         for model_alias in model_aliases:
             model = self.get_model(model_alias=model_alias)
@@ -81,15 +107,24 @@ class ModelRegistry:
                 f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
             )
             try:
-                model.generate(
-                    prompt="Hello!",
-                    parser=lambda x: x,
-                    system_prompt="You are a helpful assistant.",
-                    max_correction_steps=0,
-                    max_conversation_restarts=0,
-                    skip_usage_tracking=True,
-                    purpose="running health checks",
-                )
+                if model.model_generation_type == GenerationType.EMBEDDING:
+                    model.generate_text_embeddings(
+                        input_texts=["Hello!"],
+                        skip_usage_tracking=True,
+                        purpose="running health checks",
+                    )
+                elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
+                    model.generate(
+                        prompt="Hello!",
+                        parser=lambda x: x,
+                        system_prompt="You are a helpful assistant.",
+                        max_correction_steps=0,
+                        max_conversation_restarts=0,
+                        skip_usage_tracking=True,
+                        purpose="running health checks",
+                    )
+                else:
+                    raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
                 logger.info(" |-- ✅ Passed!")
             except Exception as e:
                 logger.error(" |-- ❌ Failed!")
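
The get_model_usage_snapshot() / get_usage_deltas() pair added above lets callers measure per-model usage accrued between two points in time: deep-copy the counters, run some generation work, then subtract the snapshot from the current totals. A minimal sketch of that arithmetic (not part of the diff; the numbers are invented, and running it assumes data-designer 0.2.1 is installed):

    from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats

    # Counters as they looked when the snapshot was taken.
    before = ModelUsageStats(
        token_usage=TokenUsageStats(input_tokens=100, output_tokens=40),
        request_usage=RequestUsageStats(successful_requests=2, failed_requests=0),
    )
    # Counters after more generation work has run.
    after = ModelUsageStats(
        token_usage=TokenUsageStats(input_tokens=150, output_tokens=65),
        request_usage=RequestUsageStats(successful_requests=3, failed_requests=1),
    )

    # get_usage_deltas() performs this subtraction per model and keeps the
    # entry only if at least one delta is positive.
    delta = ModelUsageStats(
        token_usage=TokenUsageStats(
            input_tokens=after.token_usage.input_tokens - before.token_usage.input_tokens,
            output_tokens=after.token_usage.output_tokens - before.token_usage.output_tokens,
        ),
        request_usage=RequestUsageStats(
            successful_requests=after.request_usage.successful_requests
            - before.request_usage.successful_requests,
            failed_requests=after.request_usage.failed_requests - before.request_usage.failed_requests,
        ),
    )
    assert delta.token_usage.total_tokens == 75  # 50 new input + 25 new output tokens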

data_designer/engine/models/telemetry.py (new file)

@@ -0,0 +1,355 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Telemetry handler for NeMo products.
+
+Environment variables:
+- NEMO_TELEMETRY_ENABLED: Whether telemetry is enabled.
+- NEMO_DEPLOYMENT_TYPE: The deployment type the event came from.
+- NEMO_TELEMETRY_ENDPOINT: The endpoint to send the telemetry events to.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import platform
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, ClassVar
+
+import httpx
+from pydantic import BaseModel, Field
+
+TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")
+CLIENT_ID = "184482118588404"
+NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"
+MAX_RETRIES = 3
+NEMO_TELEMETRY_ENDPOINT = os.getenv(
+    "NEMO_TELEMETRY_ENDPOINT", "https://events.telemetry.data.nvidia.com/v1.1/events/json"
+).lower()
+CPU_ARCHITECTURE = platform.uname().machine
+
+
+class NemoSourceEnum(str, Enum):
+    INFERENCE = "inference"
+    AUDITOR = "auditor"
+    DATADESIGNER = "datadesigner"
+    EVALUATOR = "evaluator"
+    GUARDRAILS = "guardrails"
+    UNDEFINED = "undefined"
+
+
+class DeploymentTypeEnum(str, Enum):
+    LIBRARY = "library"
+    API = "api"
+    UNDEFINED = "undefined"
+
+
+_deployment_type_raw = os.getenv("NEMO_DEPLOYMENT_TYPE", "library").lower()
+try:
+    DEPLOYMENT_TYPE = DeploymentTypeEnum(_deployment_type_raw)
+except ValueError:
+    valid_values = [e.value for e in DeploymentTypeEnum]
+    raise ValueError(
+        f"Invalid NEMO_DEPLOYMENT_TYPE: {_deployment_type_raw!r}. Must be one of: {valid_values}"
+    ) from None
+
+
+class TaskStatusEnum(str, Enum):
+    SUCCESS = "success"
+    FAILURE = "failure"
+    UNDEFINED = "undefined"
+
+
+class TelemetryEvent(BaseModel):
+    _event_name: ClassVar[str]  # Subclasses must define this
+    _schema_version: ClassVar[str] = "1.3"
+
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        super().__init_subclass__(**kwargs)
+        if "_event_name" not in cls.__dict__:
+            raise TypeError(f"{cls.__name__} must define '_event_name' class variable")
+
+
+class InferenceEvent(TelemetryEvent):
+    _event_name: ClassVar[str] = "inference_event"
+
+    nemo_source: NemoSourceEnum = Field(
+        ...,
+        alias="nemoSource",
+        description="The NeMo product that created the event (i.e. data-designer).",
+    )
+    task: str = Field(
+        ...,
+        description="The type of task that was performed that generated the inference event (i.e. preview-job, batch-job).",
+    )
+    task_status: TaskStatusEnum = Field(
+        ...,
+        alias="taskStatus",
+        description="The status of the task.",
+    )
+    deployment_type: DeploymentTypeEnum = Field(
+        default=DEPLOYMENT_TYPE,
+        alias="deploymentType",
+        description="The deployment type the event came from.",
+    )
+    model: str = Field(
+        ...,
+        description="The name of the model that was used.",
+    )
+    model_group: str = Field(
+        default="undefined",
+        alias="modelGroup",
+        description="An optional identifier to group models together.",
+    )
+    input_bytes: int = Field(
+        default=-1,
+        alias="inputBytes",
+        description="Number of bytes provided as input to the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+    input_tokens: int = Field(
+        default=-1,
+        alias="inputTokens",
+        description="Number of tokens provided as input to the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+    output_bytes: int = Field(
+        default=-1,
+        alias="outputBytes",
+        description="Number of bytes returned by the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+    output_tokens: int = Field(
+        default=-1,
+        alias="outputTokens",
+        description="Number of tokens returned by the model. -1 if not available.",
+        ge=-9223372036854775808,
+        le=9223372036854775807,
+    )
+
+    model_config = {"populate_by_name": True}
+
+
+@dataclass
+class QueuedEvent:
+    event: TelemetryEvent
+    timestamp: datetime
+    retry_count: int = 0
+
+
+def _get_iso_timestamp(dt: datetime | None = None) -> str:
+    if dt is None:
+        dt = datetime.now(timezone.utc)
+    return dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{dt.microsecond // 1000:03d}Z"
+
+
+def build_payload(
+    events: list[QueuedEvent], *, source_client_version: str, session_id: str = "undefined"
+) -> dict[str, Any]:
+    return {
+        "browserType": "undefined",  # do not change
+        "clientId": CLIENT_ID,
+        "clientType": "Native",  # do not change
+        "clientVariant": "Release",  # do not change
+        "clientVer": source_client_version,
+        "cpuArchitecture": CPU_ARCHITECTURE,
+        "deviceGdprBehOptIn": "None",  # do not change
+        "deviceGdprFuncOptIn": "None",  # do not change
+        "deviceGdprTechOptIn": "None",  # do not change
+        "deviceId": "undefined",  # do not change
+        "deviceMake": "undefined",  # do not change
+        "deviceModel": "undefined",  # do not change
+        "deviceOS": "undefined",  # do not change
+        "deviceOSVersion": "undefined",  # do not change
+        "deviceType": "undefined",  # do not change
+        "eventProtocol": "1.6",  # do not change
+        "eventSchemaVer": events[0].event._schema_version,
+        "eventSysVer": NEMO_TELEMETRY_VERSION,
+        "externalUserId": "undefined",  # do not change
+        "gdprBehOptIn": "None",  # do not change
+        "gdprFuncOptIn": "None",  # do not change
+        "gdprTechOptIn": "None",  # do not change
+        "idpId": "undefined",  # do not change
+        "integrationId": "undefined",  # do not change
+        "productName": "undefined",  # do not change
+        "productVersion": "undefined",  # do not change
+        "sentTs": _get_iso_timestamp(),
+        "sessionId": session_id,
+        "userId": "undefined",  # do not change
+        "events": [
+            {
+                "ts": _get_iso_timestamp(queued.timestamp),
+                "parameters": queued.event.model_dump(by_alias=True),
+                "name": queued.event._event_name,
+            }
+            for queued in events
+        ],
+    }
+
+
+class TelemetryHandler:
+    """
+    Handles telemetry event batching, flushing, and retry logic for NeMo products.
+
+    Args:
+        flush_interval_seconds (float): The interval in seconds to flush the events.
+        max_queue_size (int): The maximum number of events to queue before flushing.
+        max_retries (int): The maximum number of times to retry sending an event.
+        source_client_version (str): The version of the source client. This should be the version of
+            the actual NeMo product that is sending the events, typically the same as the version of
+            a PyPi package that a user would install.
+        session_id (str): An optional session ID to associate with the events.
+            This should be a unique identifier for the session, such as a UUID.
+            It is used to group events together.
+    """
+
+    def __init__(
+        self,
+        flush_interval_seconds: float = 120.0,
+        max_queue_size: int = 50,
+        max_retries: int = MAX_RETRIES,
+        source_client_version: str = "undefined",
+        session_id: str = "undefined",
+    ):
+        self._flush_interval = flush_interval_seconds
+        self._max_queue_size = max_queue_size
+        self._max_retries = max_retries
+        self._events: list[QueuedEvent] = []
+        self._dlq: list[QueuedEvent] = []  # Dead letter queue for retry
+        self._flush_signal = asyncio.Event()
+        self._timer_task: asyncio.Task | None = None
+        self._running = False
+        self._source_client_version = source_client_version
+        self._session_id = session_id
+
+    async def astart(self) -> None:
+        if self._running:
+            return
+        self._running = True
+        self._timer_task = asyncio.create_task(self._timer_loop())
+
+    async def astop(self) -> None:
+        self._running = False
+        self._flush_signal.set()
+        if self._timer_task:
+            self._timer_task.cancel()
+            try:
+                await self._timer_task
+            except asyncio.CancelledError:
+                pass
+            self._timer_task = None
+        await self._flush_events()
+
+    async def aflush(self) -> None:
+        self._flush_signal.set()
+
+    def start(self) -> None:
+        self._run_sync(self.astart())
+
+    def stop(self) -> None:
+        self._run_sync(self.astop())
+
+    def flush(self) -> None:
+        self._flush_signal.set()
+
+    def enqueue(self, event: TelemetryEvent) -> None:
+        if not TELEMETRY_ENABLED:
+            return
+        if not isinstance(event, TelemetryEvent):
+            # Silently fail as we prioritize not disrupting upstream call sites and telemetry is best effort
+            return
+        queued = QueuedEvent(event=event, timestamp=datetime.now(timezone.utc))
+        self._events.append(queued)
+        if len(self._events) >= self._max_queue_size:
+            self._flush_signal.set()
+
+    def _run_sync(self, coro: Any) -> Any:
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop and loop.is_running():
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                future = pool.submit(asyncio.run, coro)
+                return future.result()
+        else:
+            return asyncio.run(coro)
+
+    def __enter__(self) -> TelemetryHandler:
+        self.start()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.stop()
+
+    async def __aenter__(self) -> TelemetryHandler:
+        await self.astart()
+        return self
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        await self.astop()
+
+    async def _timer_loop(self) -> None:
+        while self._running:
+            try:
+                await asyncio.wait_for(
+                    self._flush_signal.wait(),
+                    timeout=self._flush_interval,
+                )
+            except asyncio.TimeoutError:
+                pass
+            self._flush_signal.clear()
+            await self._flush_events()
+
+    async def _flush_events(self) -> None:
+        dlq_events, self._dlq = self._dlq, []
+        new_events, self._events = self._events, []
+        events_to_send = dlq_events + new_events
+        if events_to_send:
+            await self._send_events(events_to_send)
+
+    async def _send_events(self, events: list[QueuedEvent]) -> None:
+        async with httpx.AsyncClient() as client:
+            await self._send_events_with_client(client, events)
+
+    async def _send_events_with_client(self, client: httpx.AsyncClient, events: list[QueuedEvent]) -> None:
+        if not events:
+            return
+
+        payload = build_payload(events, source_client_version=self._source_client_version, session_id=self._session_id)
+        try:
+            response = await client.post(NEMO_TELEMETRY_ENDPOINT, json=payload)
+            # 2xx, 400, 422 are all considered complete (no retry)
+            # 400/422 indicate bad payload which retrying won't fix
+            if response.status_code in (400, 422) or response.is_success:
+                return
+            # 413 (payload too large) - split and retry
+            if response.status_code == 413:
+                if len(events) == 1:
+                    # Can't split further, drop the event
+                    return
+                mid = len(events) // 2
+                await self._send_events_with_client(client, events[:mid])
+                await self._send_events_with_client(client, events[mid:])
+                return
+            if response.status_code == 408 or response.status_code >= 500:
+                self._add_to_dlq(events)
+        except httpx.HTTPError:
+            self._add_to_dlq(events)
+
+    def _add_to_dlq(self, events: list[QueuedEvent]) -> None:
+        for queued in events:
+            queued.retry_count += 1
+            if queued.retry_count > self._max_retries:
+                continue
+            self._dlq.append(queued)
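
A minimal sketch of driving the new TelemetryHandler directly as an async context manager (not part of the diff). The model name and session ID below are placeholders; delivery is best effort against NEMO_TELEMETRY_ENDPOINT, and enqueue() is a no-op when NEMO_TELEMETRY_ENABLED was false at import time:

    import asyncio

    from data_designer.engine.models.telemetry import (
        InferenceEvent,
        NemoSourceEnum,
        TaskStatusEnum,
        TelemetryHandler,
    )

    async def main() -> None:
        # Events are flushed every flush_interval_seconds, or as soon as
        # max_queue_size events are queued; astop() flushes whatever remains.
        async with TelemetryHandler(
            flush_interval_seconds=30.0,
            max_queue_size=10,
            source_client_version="0.2.1",
            session_id="5d1f5e7a-0000-4000-8000-000000000000",  # placeholder UUID
        ) as handler:
            handler.enqueue(
                InferenceEvent(
                    nemo_source=NemoSourceEnum.DATADESIGNER,
                    task="preview-job",
                    task_status=TaskStatusEnum.SUCCESS,
                    model="placeholder-model",
                    input_tokens=128,
                    output_tokens=256,
                )
            )

    asyncio.run(main())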

data_designer/engine/models/usage.py

@@ -11,20 +11,20 @@ logger = logging.getLogger(__name__)
 
 
 class TokenUsageStats(BaseModel):
-    prompt_tokens: int = 0
-    completion_tokens: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
 
     @computed_field
     def total_tokens(self) -> int:
-        return self.prompt_tokens + self.completion_tokens
+        return self.input_tokens + self.output_tokens
 
     @property
     def has_usage(self) -> bool:
         return self.total_tokens > 0
 
-    def extend(self, *, prompt_tokens: int, completion_tokens: int) -> None:
-        self.prompt_tokens += prompt_tokens
-        self.completion_tokens += completion_tokens
+    def extend(self, *, input_tokens: int, output_tokens: int) -> None:
+        self.input_tokens += input_tokens
+        self.output_tokens += output_tokens
 
 
 class RequestUsageStats(BaseModel):
@@ -56,9 +56,7 @@ class ModelUsageStats(BaseModel):
         self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
     ) -> None:
         if token_usage is not None:
-            self.token_usage.extend(
-                prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
-            )
+            self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
         if request_usage is not None:
             self.request_usage.extend(
                 successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
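
The rename from prompt_tokens/completion_tokens to input_tokens/output_tokens is breaking for any code that reads these stats programmatically, and it aligns the field names with the inputTokens/outputTokens telemetry fields above. A small sketch of the 0.2.1 spelling (not part of the diff):

    from data_designer.engine.models.usage import TokenUsageStats

    stats = TokenUsageStats()
    stats.extend(input_tokens=10, output_tokens=5)  # formerly prompt_tokens= / completion_tokens=
    assert stats.total_tokens == 15
    assert stats.has_usage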

data_designer/engine/processing/ginja/ast.py

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import deque
-from typing import Optional, Type
 
 from jinja2 import nodes as j_nodes
 
@@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
     return max_depth
 
 
-def ast_descendant_count(ast: j_nodes.Node, only_type: Optional[Type[j_nodes.Node]] = None) -> int:
+def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
     """Count the number of nodes which descend from the given node.
 
     Args:
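
This hunk, like the registry and constraints hunks further down, is a pure typing modernization: PEP 585 (Python 3.9+) allows parameterizing the builtin type directly, and PEP 604 (Python 3.10+) replaces Optional[X] with X | None, so the typing.Type and typing.Optional imports can be dropped. The two spellings are interchangeable at runtime, as this small illustrative check shows (not part of the diff):

    import typing

    # The builtin generic carries the same origin and arguments that the
    # typing-module spelling did:
    assert typing.get_origin(type[int]) is type
    assert typing.get_args(type[int]) == (int,)
    assert typing.get_origin(typing.Type[int]) is type  # deprecated spelling, same result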

data_designer/engine/processing/utils.py

@@ -1,9 +1,11 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import ast
 import json
 import logging
-from typing import Any, TypeVar, Union, overload
+import re
+from typing import Any, TypeVar, overload
 
 import pandas as pd
 
@@ -25,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
 # Overloads to help static type checker better understand
 # the input/output types of the deserialize_json_values function.
 @overload
-def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ...
+def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...
 
 
 @overload
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
     return data
 
 
+def parse_list_string(text: str) -> list[str]:
+    """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
+    text = text.strip()
+
+    # Try JSON first
+    try:
+        list_obj = json.loads(text)
+        if isinstance(list_obj, list):
+            return _clean_whitespace(list_obj)
+    except json.JSONDecodeError:
+        pass
+
+    # Remove trailing commas before closing brackets (common in JSON-like strings)
+    text_cleaned = re.sub(r",\s*]", "]", text)
+    text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
+
+    # Try JSON again with cleaned text
+    try:
+        return _clean_whitespace(json.loads(text_cleaned))
+    except json.JSONDecodeError:
+        pass
+
+    # Try Python literal eval (handles single quotes)
+    try:
+        return _clean_whitespace(ast.literal_eval(text_cleaned))
+    except (ValueError, SyntaxError):
+        pass
+
+    # If all else fails, return the original text
+    return [text.strip()]
+
+
+def _clean_whitespace(texts: list[str]) -> list[str]:
+    return [text.strip() for text in texts]
+
+
 def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
     joined_columns = set()
     for df in datasets:
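
The new parse_list_string() works through a fallback chain: strict JSON, then JSON with trailing commas stripped, then ast.literal_eval for Python-style single-quoted lists, and finally the raw text wrapped in a one-element list. A few illustrative calls (not part of the diff; assumes data-designer 0.2.1):

    from data_designer.engine.processing.utils import parse_list_string

    assert parse_list_string('["a", "b"]') == ["a", "b"]      # valid JSON array
    assert parse_list_string('["a", "b",]') == ["a", "b"]     # trailing comma removed, then JSON
    assert parse_list_string("['a', ' b ']") == ["a", "b"]    # Python literal; whitespace stripped
    assert parse_list_string("not a list") == ["not a list"]  # fallback: wrap the raw text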

data_designer/engine/registry/base.py

@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import threading
-from typing import Any, Generic, Type, TypeVar
+from typing import Any, Generic, TypeVar
 
 from data_designer.config.base import ConfigBase
 from data_designer.config.utils.type_helpers import StrEnum
@@ -16,14 +16,14 @@ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
 
 class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
     # registered type name -> type
-    _registry: dict[EnumNameT, Type[TaskT]] = {}
+    _registry: dict[EnumNameT, type[TaskT]] = {}
     # type -> registered type name
-    _reverse_registry: dict[Type[TaskT], EnumNameT] = {}
+    _reverse_registry: dict[type[TaskT], EnumNameT] = {}
 
     # registered type name -> config type
-    _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {}
+    _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
     # config type -> registered type name
-    _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {}
+    _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}
 
     # all registries are singletons
     _instance = None
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
     def register(
         cls,
         name: EnumNameT,
-        task: Type[TaskT],
-        config: Type[TaskConfigT],
+        task: type[TaskT],
+        config: type[TaskConfigT],
         raise_on_collision: bool = False,
     ) -> None:
         if cls._has_been_registered(name):
@@ -52,22 +52,22 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
         cls._reverse_config_registry[config] = name
 
     @classmethod
-    def get_task_type(cls, name: EnumNameT) -> Type[TaskT]:
+    def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
         cls._raise_if_not_registered(name, cls._registry)
         return cls._registry[name]
 
     @classmethod
-    def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]:
+    def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
         cls._raise_if_not_registered(name, cls._config_registry)
         return cls._config_registry[name]
 
     @classmethod
-    def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT:
+    def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
         cls._raise_if_not_registered(task, cls._reverse_registry)
         return cls._reverse_registry[task]
 
     @classmethod
-    def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]:
+    def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
         cls._raise_if_not_registered(config, cls._reverse_config_registry)
         name = cls._reverse_config_registry[config]
         return cls.get_task_type(name)
@@ -77,7 +77,7 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
         return name in cls._registry
 
     @classmethod
-    def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None:
+    def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
         if not (isinstance(key, StrEnum) or isinstance(key, str)):
             cls._raise_if_not_type(key)
         if key not in mapping:

data_designer/engine/sampling_gen/constraints.py

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from typing import Type
 
 import numpy as np
 import pandas as pd
@@ -91,5 +90,5 @@ CONSTRAINT_TYPE_TO_CHECKER = {
 }
 
 
-def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]:
+def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
     return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]