data-designer 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/_version.py +2 -2
- data_designer/cli/README.md +15 -1
- data_designer/cli/commands/download.py +56 -0
- data_designer/cli/commands/list.py +4 -18
- data_designer/cli/controllers/__init__.py +2 -1
- data_designer/cli/controllers/download_controller.py +217 -0
- data_designer/cli/controllers/model_controller.py +4 -3
- data_designer/cli/forms/field.py +65 -19
- data_designer/cli/forms/model_builder.py +251 -44
- data_designer/cli/main.py +11 -1
- data_designer/cli/repositories/persona_repository.py +88 -0
- data_designer/cli/services/__init__.py +2 -1
- data_designer/cli/services/download_service.py +97 -0
- data_designer/cli/ui.py +131 -0
- data_designer/cli/utils.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +75 -7
- data_designer/config/analysis/column_statistics.py +192 -48
- data_designer/config/analysis/dataset_profiler.py +23 -5
- data_designer/config/analysis/utils/reporting.py +3 -3
- data_designer/config/base.py +3 -3
- data_designer/config/column_configs.py +27 -6
- data_designer/config/column_types.py +24 -17
- data_designer/config/config_builder.py +34 -26
- data_designer/config/data_designer_config.py +7 -7
- data_designer/config/datastore.py +6 -6
- data_designer/config/default_model_settings.py +27 -34
- data_designer/config/exports.py +8 -0
- data_designer/config/models.py +155 -29
- data_designer/config/preview_results.py +6 -8
- data_designer/config/processors.py +63 -2
- data_designer/config/sampler_constraints.py +1 -2
- data_designer/config/sampler_params.py +31 -31
- data_designer/config/seed.py +1 -2
- data_designer/config/utils/code_lang.py +4 -5
- data_designer/config/utils/constants.py +31 -8
- data_designer/config/utils/io_helpers.py +5 -5
- data_designer/config/utils/misc.py +1 -4
- data_designer/config/utils/numerical_helpers.py +2 -2
- data_designer/config/utils/type_helpers.py +3 -3
- data_designer/config/utils/validation.py +7 -8
- data_designer/config/utils/visualization.py +32 -17
- data_designer/config/validator_params.py +4 -8
- data_designer/engine/analysis/column_profilers/base.py +0 -7
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +2 -3
- data_designer/engine/analysis/column_statistics.py +16 -16
- data_designer/engine/analysis/dataset_profiler.py +25 -4
- data_designer/engine/analysis/utils/column_statistics_calculations.py +71 -49
- data_designer/engine/analysis/utils/judge_score_processing.py +5 -5
- data_designer/engine/column_generators/generators/base.py +34 -0
- data_designer/engine/column_generators/generators/embedding.py +45 -0
- data_designer/engine/column_generators/generators/{llm_generators.py → llm_completion.py} +17 -49
- data_designer/engine/column_generators/registry.py +4 -2
- data_designer/engine/column_generators/utils/judge_score_factory.py +5 -6
- data_designer/engine/configurable_task.py +2 -2
- data_designer/engine/dataset_builders/artifact_storage.py +1 -2
- data_designer/engine/dataset_builders/column_wise_builder.py +11 -10
- data_designer/engine/dataset_builders/utils/concurrency.py +6 -6
- data_designer/engine/models/facade.py +66 -9
- data_designer/engine/models/litellm_overrides.py +5 -6
- data_designer/engine/models/parsers/errors.py +2 -4
- data_designer/engine/models/parsers/parser.py +2 -3
- data_designer/engine/models/parsers/postprocessors.py +3 -4
- data_designer/engine/models/parsers/types.py +4 -4
- data_designer/engine/models/registry.py +20 -11
- data_designer/engine/models/usage.py +7 -9
- data_designer/engine/processing/ginja/ast.py +1 -2
- data_designer/engine/processing/utils.py +40 -2
- data_designer/engine/registry/base.py +12 -12
- data_designer/engine/sampling_gen/constraints.py +1 -2
- data_designer/engine/sampling_gen/data_sources/base.py +14 -14
- data_designer/engine/sampling_gen/entities/phone_number.py +1 -2
- data_designer/engine/sampling_gen/people_gen.py +3 -7
- data_designer/engine/validators/base.py +2 -2
- data_designer/logging.py +2 -2
- data_designer/plugin_manager.py +3 -3
- data_designer/plugins/plugin.py +3 -3
- data_designer/plugins/registry.py +2 -2
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/METADATA +1 -1
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/RECORD +83 -77
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/WHEEL +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/entry_points.txt +0 -0
- {data_designer-0.1.5.dist-info → data_designer-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -8,7 +8,7 @@ import json
|
|
|
8
8
|
import logging
|
|
9
9
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
10
10
|
from threading import Lock, Semaphore
|
|
11
|
-
from typing import Any,
|
|
11
|
+
from typing import Any, Protocol
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel, Field
|
|
14
14
|
|
|
@@ -46,13 +46,13 @@ class ExecutorResults(BaseModel):
|
|
|
46
46
|
class CallbackWithContext(Protocol):
|
|
47
47
|
"""Executor callback functions must accept a context kw argument."""
|
|
48
48
|
|
|
49
|
-
def __call__(self, result: Any, *, context:
|
|
49
|
+
def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
class ErrorCallbackWithContext(Protocol):
|
|
53
53
|
"""Error callbacks take the Exception instance and context."""
|
|
54
54
|
|
|
55
|
-
def __call__(self, exc: Exception, *, context:
|
|
55
|
+
def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class ConcurrentThreadExecutor:
|
|
@@ -92,8 +92,8 @@ class ConcurrentThreadExecutor:
|
|
|
92
92
|
*,
|
|
93
93
|
max_workers: int,
|
|
94
94
|
column_name: str,
|
|
95
|
-
result_callback:
|
|
96
|
-
error_callback:
|
|
95
|
+
result_callback: CallbackWithContext | None = None,
|
|
96
|
+
error_callback: ErrorCallbackWithContext | None = None,
|
|
97
97
|
shutdown_error_rate: float = 0.50,
|
|
98
98
|
shutdown_error_window: int = 10,
|
|
99
99
|
):
|
|
@@ -136,7 +136,7 @@ class ConcurrentThreadExecutor:
|
|
|
136
136
|
)
|
|
137
137
|
)
|
|
138
138
|
|
|
139
|
-
def submit(self, fn, *args, context:
|
|
139
|
+
def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
|
|
140
140
|
if self._executor is None:
|
|
141
141
|
raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
|
|
142
142
|
|
|
@@ -9,9 +9,9 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import Any
|
|
10
10
|
|
|
11
11
|
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
|
|
12
|
-
from litellm.types.utils import ModelResponse
|
|
12
|
+
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
|
13
13
|
|
|
14
|
-
from data_designer.config.models import ModelConfig, ModelProvider
|
|
14
|
+
from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
|
|
15
15
|
from data_designer.engine.model_provider import ModelProviderRegistry
|
|
16
16
|
from data_designer.engine.models.errors import (
|
|
17
17
|
GenerationValidationFailureError,
|
|
@@ -49,6 +49,10 @@ class ModelFacade:
|
|
|
49
49
|
def model_provider(self) -> ModelProvider:
|
|
50
50
|
return self._model_provider_registry.get_provider(self._model_config.provider)
|
|
51
51
|
|
|
52
|
+
@property
|
|
53
|
+
def model_generation_type(self) -> GenerationType:
|
|
54
|
+
return self._model_config.generation_type
|
|
55
|
+
|
|
52
56
|
@property
|
|
53
57
|
def model_provider_name(self) -> str:
|
|
54
58
|
return self.model_provider.name
|
|
@@ -64,13 +68,12 @@ class ModelFacade:
|
|
|
64
68
|
def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
|
|
65
69
|
logger.debug(
|
|
66
70
|
f"Prompting model {self.model_name!r}...",
|
|
67
|
-
extra={"model": self.model_name, "messages": messages
|
|
71
|
+
extra={"model": self.model_name, "messages": messages},
|
|
68
72
|
)
|
|
69
73
|
response = None
|
|
70
|
-
|
|
71
|
-
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
74
|
+
kwargs = self.consolidate_kwargs(**kwargs)
|
|
72
75
|
try:
|
|
73
|
-
response = self._router.completion(self.model_name, messages, **kwargs)
|
|
76
|
+
response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
|
|
74
77
|
logger.debug(
|
|
75
78
|
f"Received completion from model {self.model_name!r}",
|
|
76
79
|
extra={
|
|
@@ -84,9 +87,50 @@ class ModelFacade:
|
|
|
84
87
|
except Exception as e:
|
|
85
88
|
raise e
|
|
86
89
|
finally:
|
|
87
|
-
if not skip_usage_tracking:
|
|
90
|
+
if not skip_usage_tracking and response is not None:
|
|
88
91
|
self._track_usage(response)
|
|
89
92
|
|
|
93
|
+
def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
|
|
94
|
+
# Remove purpose from kwargs to avoid passing it to the model
|
|
95
|
+
kwargs.pop("purpose", None)
|
|
96
|
+
kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
|
|
97
|
+
if self.model_provider.extra_body:
|
|
98
|
+
kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
|
|
99
|
+
return kwargs
|
|
100
|
+
|
|
101
|
+
@catch_llm_exceptions
|
|
102
|
+
def generate_text_embeddings(
|
|
103
|
+
self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
|
|
104
|
+
) -> list[list[float]]:
|
|
105
|
+
logger.debug(
|
|
106
|
+
f"Generating embeddings with model {self.model_name!r}...",
|
|
107
|
+
extra={
|
|
108
|
+
"model": self.model_name,
|
|
109
|
+
"input_count": len(input_texts),
|
|
110
|
+
},
|
|
111
|
+
)
|
|
112
|
+
kwargs = self.consolidate_kwargs(**kwargs)
|
|
113
|
+
response = None
|
|
114
|
+
try:
|
|
115
|
+
response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
|
|
116
|
+
logger.debug(
|
|
117
|
+
f"Received embeddings from model {self.model_name!r}",
|
|
118
|
+
extra={
|
|
119
|
+
"model": self.model_name,
|
|
120
|
+
"embedding_count": len(response.data) if response.data else 0,
|
|
121
|
+
"usage": self._usage_stats.model_dump(),
|
|
122
|
+
},
|
|
123
|
+
)
|
|
124
|
+
if response.data and len(response.data) == len(input_texts):
|
|
125
|
+
return [data["embedding"] for data in response.data]
|
|
126
|
+
else:
|
|
127
|
+
raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
|
|
128
|
+
except Exception as e:
|
|
129
|
+
raise e
|
|
130
|
+
finally:
|
|
131
|
+
if not skip_usage_tracking and response is not None:
|
|
132
|
+
self._track_usage_from_embedding(response)
|
|
133
|
+
|
|
90
134
|
@catch_llm_exceptions
|
|
91
135
|
def generate(
|
|
92
136
|
self,
|
|
@@ -218,8 +262,21 @@ class ModelFacade:
|
|
|
218
262
|
):
|
|
219
263
|
self._usage_stats.extend(
|
|
220
264
|
token_usage=TokenUsageStats(
|
|
221
|
-
|
|
222
|
-
|
|
265
|
+
input_tokens=response.usage.prompt_tokens,
|
|
266
|
+
output_tokens=response.usage.completion_tokens,
|
|
267
|
+
),
|
|
268
|
+
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
|
|
272
|
+
if response is None:
|
|
273
|
+
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
|
|
274
|
+
return
|
|
275
|
+
if response.usage is not None and response.usage.prompt_tokens is not None:
|
|
276
|
+
self._usage_stats.extend(
|
|
277
|
+
token_usage=TokenUsageStats(
|
|
278
|
+
input_tokens=response.usage.prompt_tokens,
|
|
279
|
+
output_tokens=0,
|
|
223
280
|
),
|
|
224
281
|
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
|
|
225
282
|
)
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import random
|
|
7
7
|
import threading
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
import httpx
|
|
11
10
|
import litellm
|
|
@@ -90,7 +89,7 @@ class CustomRouter(Router):
|
|
|
90
89
|
self._initial_retry_after_s = initial_retry_after_s
|
|
91
90
|
self._jitter_pct = jitter_pct
|
|
92
91
|
|
|
93
|
-
def _extract_retry_delay_from_headers(self, e: Exception) ->
|
|
92
|
+
def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
|
|
94
93
|
"""
|
|
95
94
|
Most of this code logic was extracted directly from the parent
|
|
96
95
|
`Router`'s `_time_to_sleep_before_retry` function. Our override
|
|
@@ -99,7 +98,7 @@ class CustomRouter(Router):
|
|
|
99
98
|
return this info, we'll simply use that retry value returned here.
|
|
100
99
|
"""
|
|
101
100
|
|
|
102
|
-
response_headers:
|
|
101
|
+
response_headers: httpx.Headers | None = None
|
|
103
102
|
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
|
|
104
103
|
response_headers = e.response.headers # type: ignore
|
|
105
104
|
if hasattr(e, "litellm_response_headers"):
|
|
@@ -119,9 +118,9 @@ class CustomRouter(Router):
|
|
|
119
118
|
e: Exception,
|
|
120
119
|
remaining_retries: int,
|
|
121
120
|
num_retries: int,
|
|
122
|
-
healthy_deployments:
|
|
123
|
-
all_deployments:
|
|
124
|
-
) ->
|
|
121
|
+
healthy_deployments: list | None = None,
|
|
122
|
+
all_deployments: list | None = None,
|
|
123
|
+
) -> int | float:
|
|
125
124
|
"""
|
|
126
125
|
Implements exponential backoff for retries.
|
|
127
126
|
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Optional
|
|
5
|
-
|
|
6
4
|
|
|
7
5
|
class ParserException(Exception):
|
|
8
6
|
"""Identifies errors resulting from generic parser errors.
|
|
@@ -12,7 +10,7 @@ class ParserException(Exception):
|
|
|
12
10
|
attempted to parse.
|
|
13
11
|
"""
|
|
14
12
|
|
|
15
|
-
source:
|
|
13
|
+
source: str | None
|
|
16
14
|
|
|
17
15
|
@staticmethod
|
|
18
16
|
def _log_format(source: str) -> str:
|
|
@@ -24,7 +22,7 @@ class ParserException(Exception):
|
|
|
24
22
|
# return f"<source>{source}</source>"
|
|
25
23
|
return ""
|
|
26
24
|
|
|
27
|
-
def __init__(self, msg:
|
|
25
|
+
def __init__(self, msg: str | None = None, source: str | None = None):
|
|
28
26
|
msg = "" if msg is None else msg.strip()
|
|
29
27
|
|
|
30
28
|
if source is not None:
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from functools import reduce
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
import marko
|
|
8
7
|
from lxml import etree
|
|
@@ -105,8 +104,8 @@ class LLMResponseParser:
|
|
|
105
104
|
|
|
106
105
|
def __init__(
|
|
107
106
|
self,
|
|
108
|
-
tag_parsers:
|
|
109
|
-
postprocessors:
|
|
107
|
+
tag_parsers: dict[str, TagParser] | None = None,
|
|
108
|
+
postprocessors: list[PostProcessor] | None = None,
|
|
110
109
|
):
|
|
111
110
|
"""
|
|
112
111
|
Initializes the LLMResponseParser with optional tag parsers and post-processors.
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Optional, Type
|
|
5
4
|
|
|
6
5
|
import json_repair
|
|
7
6
|
from pydantic import BaseModel, ValidationError
|
|
@@ -60,12 +59,12 @@ def deserialize_json_code(
|
|
|
60
59
|
|
|
61
60
|
|
|
62
61
|
class RealizePydanticTypes:
|
|
63
|
-
types: list[
|
|
62
|
+
types: list[type[BaseModel]]
|
|
64
63
|
|
|
65
|
-
def __init__(self, types: list[
|
|
64
|
+
def __init__(self, types: list[type[BaseModel]]):
|
|
66
65
|
self.types = types
|
|
67
66
|
|
|
68
|
-
def _fit_types(self, obj: dict) ->
|
|
67
|
+
def _fit_types(self, obj: dict) -> BaseModel | None:
|
|
69
68
|
final_obj = None
|
|
70
69
|
|
|
71
70
|
for t in self.types:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, Protocol, runtime_checkable
|
|
5
5
|
|
|
6
6
|
from lxml.etree import _Element
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
@@ -30,7 +30,7 @@ class LLMStructuredResponse(BaseModel):
|
|
|
30
30
|
out.parsed = out.parsed[-n:]
|
|
31
31
|
return out
|
|
32
32
|
|
|
33
|
-
def filter(self, block_types: list[
|
|
33
|
+
def filter(self, block_types: list[type[BaseModel]]) -> Self:
|
|
34
34
|
out = self.model_copy()
|
|
35
35
|
out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
|
|
36
36
|
return out
|
|
@@ -44,7 +44,7 @@ class TagParser(Protocol):
|
|
|
44
44
|
element, do some computation, and return some kind of structured
|
|
45
45
|
output, represented as a subclass of Pydantic `BaseModel`.
|
|
46
46
|
This protocol implementation can cover both classes as well
|
|
47
|
-
as curried
|
|
47
|
+
as curried functions as parsers (e.g. `partial`).
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
50
|
def __call__(self, element: _Element) -> BaseModel: ...
|
|
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):
|
|
|
69
69
|
|
|
70
70
|
class CodeBlock(BaseModel):
|
|
71
71
|
code: str
|
|
72
|
-
code_lang:
|
|
72
|
+
code_lang: str | None = None
|
|
73
73
|
|
|
74
74
|
|
|
75
75
|
class StructuredDataBlock(BaseModel):
|
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
7
|
|
|
8
|
-
from data_designer.config.models import ModelConfig
|
|
8
|
+
from data_designer.config.models import GenerationType, ModelConfig
|
|
9
9
|
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
|
|
10
10
|
from data_designer.engine.models.facade import ModelFacade
|
|
11
11
|
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
|
|
@@ -73,7 +73,7 @@ class ModelRegistry:
|
|
|
73
73
|
model_config = self.get_model_config(model_alias=model_alias)
|
|
74
74
|
return self._model_provider_registry.get_provider(model_config.provider)
|
|
75
75
|
|
|
76
|
-
def run_health_check(self, model_aliases:
|
|
76
|
+
def run_health_check(self, model_aliases: list[str]) -> None:
|
|
77
77
|
logger.info("🩺 Running health checks for models...")
|
|
78
78
|
for model_alias in model_aliases:
|
|
79
79
|
model = self.get_model(model_alias=model_alias)
|
|
@@ -81,15 +81,24 @@ class ModelRegistry:
|
|
|
81
81
|
f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
|
|
82
82
|
)
|
|
83
83
|
try:
|
|
84
|
-
model.
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
84
|
+
if model.model_generation_type == GenerationType.EMBEDDING:
|
|
85
|
+
model.generate_text_embeddings(
|
|
86
|
+
input_texts=["Hello!"],
|
|
87
|
+
skip_usage_tracking=True,
|
|
88
|
+
purpose="running health checks",
|
|
89
|
+
)
|
|
90
|
+
elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
|
|
91
|
+
model.generate(
|
|
92
|
+
prompt="Hello!",
|
|
93
|
+
parser=lambda x: x,
|
|
94
|
+
system_prompt="You are a helpful assistant.",
|
|
95
|
+
max_correction_steps=0,
|
|
96
|
+
max_conversation_restarts=0,
|
|
97
|
+
skip_usage_tracking=True,
|
|
98
|
+
purpose="running health checks",
|
|
99
|
+
)
|
|
100
|
+
else:
|
|
101
|
+
raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
|
|
93
102
|
logger.info(" |-- ✅ Passed!")
|
|
94
103
|
except Exception as e:
|
|
95
104
|
logger.error(" |-- ❌ Failed!")
|
|
@@ -11,20 +11,20 @@ logger = logging.getLogger(__name__)
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class TokenUsageStats(BaseModel):
|
|
14
|
-
|
|
15
|
-
|
|
14
|
+
input_tokens: int = 0
|
|
15
|
+
output_tokens: int = 0
|
|
16
16
|
|
|
17
17
|
@computed_field
|
|
18
18
|
def total_tokens(self) -> int:
|
|
19
|
-
return self.
|
|
19
|
+
return self.input_tokens + self.output_tokens
|
|
20
20
|
|
|
21
21
|
@property
|
|
22
22
|
def has_usage(self) -> bool:
|
|
23
23
|
return self.total_tokens > 0
|
|
24
24
|
|
|
25
|
-
def extend(self, *,
|
|
26
|
-
self.
|
|
27
|
-
self.
|
|
25
|
+
def extend(self, *, input_tokens: int, output_tokens: int) -> None:
|
|
26
|
+
self.input_tokens += input_tokens
|
|
27
|
+
self.output_tokens += output_tokens
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class RequestUsageStats(BaseModel):
|
|
@@ -56,9 +56,7 @@ class ModelUsageStats(BaseModel):
|
|
|
56
56
|
self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
|
|
57
57
|
) -> None:
|
|
58
58
|
if token_usage is not None:
|
|
59
|
-
self.token_usage.extend(
|
|
60
|
-
prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
|
|
61
|
-
)
|
|
59
|
+
self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
|
|
62
60
|
if request_usage is not None:
|
|
63
61
|
self.request_usage.extend(
|
|
64
62
|
successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from collections import deque
|
|
5
|
-
from typing import Optional, Type
|
|
6
5
|
|
|
7
6
|
from jinja2 import nodes as j_nodes
|
|
8
7
|
|
|
@@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
|
|
|
33
32
|
return max_depth
|
|
34
33
|
|
|
35
34
|
|
|
36
|
-
def ast_descendant_count(ast: j_nodes.Node, only_type:
|
|
35
|
+
def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
|
|
37
36
|
"""Count the number of nodes which descend from the given node.
|
|
38
37
|
|
|
39
38
|
Args:
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
|
+
import ast
|
|
4
5
|
import json
|
|
5
6
|
import logging
|
|
6
|
-
|
|
7
|
+
import re
|
|
8
|
+
from typing import Any, TypeVar, overload
|
|
7
9
|
|
|
8
10
|
import pandas as pd
|
|
9
11
|
|
|
@@ -25,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
|
|
|
25
27
|
# Overloads to help static type checker better understand
|
|
26
28
|
# the input/output types of the deserialize_json_values function.
|
|
27
29
|
@overload
|
|
28
|
-
def deserialize_json_values(data: str) ->
|
|
30
|
+
def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
@overload
|
|
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
|
|
|
100
102
|
return data
|
|
101
103
|
|
|
102
104
|
|
|
105
|
+
def parse_list_string(text: str) -> list[str]:
|
|
106
|
+
"""Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
|
|
107
|
+
text = text.strip()
|
|
108
|
+
|
|
109
|
+
# Try JSON first
|
|
110
|
+
try:
|
|
111
|
+
list_obj = json.loads(text)
|
|
112
|
+
if isinstance(list_obj, list):
|
|
113
|
+
return _clean_whitespace(list_obj)
|
|
114
|
+
except json.JSONDecodeError:
|
|
115
|
+
pass
|
|
116
|
+
|
|
117
|
+
# Remove trailing commas before closing brackets (common in JSON-like strings)
|
|
118
|
+
text_cleaned = re.sub(r",\s*]", "]", text)
|
|
119
|
+
text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
|
|
120
|
+
|
|
121
|
+
# Try JSON again with cleaned text
|
|
122
|
+
try:
|
|
123
|
+
return _clean_whitespace(json.loads(text_cleaned))
|
|
124
|
+
except json.JSONDecodeError:
|
|
125
|
+
pass
|
|
126
|
+
|
|
127
|
+
# Try Python literal eval (handles single quotes)
|
|
128
|
+
try:
|
|
129
|
+
return _clean_whitespace(ast.literal_eval(text_cleaned))
|
|
130
|
+
except (ValueError, SyntaxError):
|
|
131
|
+
pass
|
|
132
|
+
|
|
133
|
+
# If all else fails, return the original text
|
|
134
|
+
return [text.strip()]
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _clean_whitespace(texts: list[str]) -> list[str]:
|
|
138
|
+
return [text.strip() for text in texts]
|
|
139
|
+
|
|
140
|
+
|
|
103
141
|
def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
|
|
104
142
|
joined_columns = set()
|
|
105
143
|
for df in datasets:
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
import threading
|
|
5
|
-
from typing import Any, Generic,
|
|
5
|
+
from typing import Any, Generic, TypeVar
|
|
6
6
|
|
|
7
7
|
from data_designer.config.base import ConfigBase
|
|
8
8
|
from data_designer.config.utils.type_helpers import StrEnum
|
|
@@ -16,14 +16,14 @@ TaskConfigT = TypeVar("TaskConfigT", bound=ConfigBase)
|
|
|
16
16
|
|
|
17
17
|
class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
|
|
18
18
|
# registered type name -> type
|
|
19
|
-
_registry: dict[EnumNameT,
|
|
19
|
+
_registry: dict[EnumNameT, type[TaskT]] = {}
|
|
20
20
|
# type -> registered type name
|
|
21
|
-
_reverse_registry: dict[
|
|
21
|
+
_reverse_registry: dict[type[TaskT], EnumNameT] = {}
|
|
22
22
|
|
|
23
23
|
# registered type name -> config type
|
|
24
|
-
_config_registry: dict[EnumNameT,
|
|
24
|
+
_config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
|
|
25
25
|
# config type -> registered type name
|
|
26
|
-
_reverse_config_registry: dict[
|
|
26
|
+
_reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}
|
|
27
27
|
|
|
28
28
|
# all registries are singletons
|
|
29
29
|
_instance = None
|
|
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
|
|
|
33
33
|
def register(
|
|
34
34
|
cls,
|
|
35
35
|
name: EnumNameT,
|
|
36
|
-
task:
|
|
37
|
-
config:
|
|
36
|
+
task: type[TaskT],
|
|
37
|
+
config: type[TaskConfigT],
|
|
38
38
|
raise_on_collision: bool = False,
|
|
39
39
|
) -> None:
|
|
40
40
|
if cls._has_been_registered(name):
|
|
@@ -52,22 +52,22 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
|
|
|
52
52
|
cls._reverse_config_registry[config] = name
|
|
53
53
|
|
|
54
54
|
@classmethod
|
|
55
|
-
def get_task_type(cls, name: EnumNameT) ->
|
|
55
|
+
def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
|
|
56
56
|
cls._raise_if_not_registered(name, cls._registry)
|
|
57
57
|
return cls._registry[name]
|
|
58
58
|
|
|
59
59
|
@classmethod
|
|
60
|
-
def get_config_type(cls, name: EnumNameT) ->
|
|
60
|
+
def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
|
|
61
61
|
cls._raise_if_not_registered(name, cls._config_registry)
|
|
62
62
|
return cls._config_registry[name]
|
|
63
63
|
|
|
64
64
|
@classmethod
|
|
65
|
-
def get_registered_name(cls, task:
|
|
65
|
+
def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
|
|
66
66
|
cls._raise_if_not_registered(task, cls._reverse_registry)
|
|
67
67
|
return cls._reverse_registry[task]
|
|
68
68
|
|
|
69
69
|
@classmethod
|
|
70
|
-
def get_for_config_type(cls, config:
|
|
70
|
+
def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
|
|
71
71
|
cls._raise_if_not_registered(config, cls._reverse_config_registry)
|
|
72
72
|
name = cls._reverse_config_registry[config]
|
|
73
73
|
return cls.get_task_type(name)
|
|
@@ -77,7 +77,7 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
|
|
|
77
77
|
return name in cls._registry
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
|
-
def _raise_if_not_registered(cls, key: EnumNameT |
|
|
80
|
+
def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
|
|
81
81
|
if not (isinstance(key, StrEnum) or isinstance(key, str)):
|
|
82
82
|
cls._raise_if_not_type(key)
|
|
83
83
|
if key not in mapping:
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
|
-
from typing import Type
|
|
6
5
|
|
|
7
6
|
import numpy as np
|
|
8
7
|
import pandas as pd
|
|
@@ -91,5 +90,5 @@ CONSTRAINT_TYPE_TO_CHECKER = {
|
|
|
91
90
|
}
|
|
92
91
|
|
|
93
92
|
|
|
94
|
-
def get_constraint_checker(constraint_type: ConstraintType) ->
|
|
93
|
+
def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
|
|
95
94
|
return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
|