data-designer-engine 0.4.0__py3-none-any.whl → 0.4.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/engine/_version.py +2 -2
- data_designer/engine/column_generators/generators/llm_completion.py +7 -10
- data_designer/engine/dataset_builders/column_wise_builder.py +5 -24
- data_designer/engine/models/facade.py +26 -23
- data_designer/engine/models/registry.py +0 -5
- data_designer/engine/models/telemetry.py +5 -8
- data_designer/engine/models/utils.py +21 -84
- data_designer/engine/processing/processors/schema_transform.py +5 -27
- {data_designer_engine-0.4.0.dist-info → data_designer_engine-0.4.0rc1.dist-info}/METADATA +1 -1
- {data_designer_engine-0.4.0.dist-info → data_designer_engine-0.4.0rc1.dist-info}/RECORD +11 -12
- data_designer/engine/dataset_builders/utils/progress_tracker.py +0 -122
- {data_designer_engine-0.4.0.dist-info → data_designer_engine-0.4.0rc1.dist-info}/WHEEL +0 -0
data_designer/engine/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.4.0'
-__version_tuple__ = version_tuple = (0, 4, 0)
+__version__ = version = '0.4.0rc1'
+__version_tuple__ = version_tuple = (0, 4, 0, 'rc1')
 
 __commit_id__ = commit_id = None
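The only change in this module is the version stamp: the release build carries a plain (0, 4, 0) tuple, while the rc1 build appends a string pre-release segment. A minimal sketch of how a caller might branch on that segment (the is_prerelease helper below is illustrative, not part of the package):

# Illustrative only: detecting the (0, 4, 0) vs (0, 4, 0, 'rc1') shapes.
def is_prerelease(version_tuple: tuple) -> bool:
    # Release builds carry only integers; pre-releases append a string
    # segment such as 'rc1'.
    return any(isinstance(part, str) for part in version_tuple)

print(is_prerelease((0, 4, 0)))         # False
print(is_prerelease((0, 4, 0, "rc1")))  # True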
data_designer/engine/column_generators/generators/llm_completion.py
CHANGED
@@ -12,7 +12,7 @@ from data_designer.config.column_configs import (
     LLMStructuredColumnConfig,
     LLMTextColumnConfig,
 )
-from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX
+from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
 from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
 from data_designer.engine.column_generators.utils.prompt_renderer import (
     PromptType,
@@ -62,11 +62,11 @@ class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfig
 
         multi_modal_context = None
         if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
-            multi_modal_context = [
-
-
+            multi_modal_context = [
+                context.get_context(deserialized_record) for context in self.config.multi_modal_context
+            ]
 
-        response, trace = self.model.generate(
+        response, reasoning_trace = self.model.generate(
             prompt=self.prompt_renderer.render(
                 record=deserialized_record,
                 prompt_template=self.config.prompt,
@@ -87,11 +87,8 @@ class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfig
         serialized_output = self.response_recipe.serialize_output(response)
         data[self.config.name] = self._process_serialized_output(serialized_output)
 
-        should_save_trace = (
-            self.config.
-        )
-        if should_save_trace:
-            data[self.config.name + TRACE_COLUMN_POSTFIX] = [message.to_dict() for message in trace]
+        if reasoning_trace:
+            data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace
 
         return data
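On the 0.4.0 side the generator decides via config whether to save the full conversation trace; the rc1 side keeps only the model's reasoning trace, stored in a sibling column. A sketch of the resulting record shape (the "_reasoning_trace" suffix is an assumed value for REASONING_TRACE_COLUMN_POSTFIX, whose real definition is not shown in this diff):

# Assumed suffix; the actual constant lives in data_designer.config.utils.constants.
REASONING_TRACE_COLUMN_POSTFIX = "_reasoning_trace"

data = {}
column_name = "answer"
response = "Paris"
reasoning_trace = "The user asks for the capital of France; that is Paris."

data[column_name] = response
if reasoning_trace:  # the extra column only exists when the model returned a trace
    data[column_name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace

print(data)
# {'answer': 'Paris', 'answer_reasoning_trace': 'The user asks for ...'}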
data_designer/engine/dataset_builders/column_wise_builder.py
CHANGED
@@ -34,7 +34,6 @@ from data_designer.engine.dataset_builders.multi_column_configs import MultiColu
 from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
 from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
 from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
-from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker
 from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
 from data_designer.engine.processing.processors.base import Processor
 from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
@@ -222,18 +221,16 @@ class ColumnWiseDatasetBuilder:
                 "generator so concurrency through threads is not supported."
             )
 
-        progress_tracker = ProgressTracker(
-
-
+        logger.info(
+            f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
+            f"with {max_workers} concurrent workers"
         )
-        progress_tracker.log_start(max_workers)
-
         settings = self._resource_provider.run_config
         with ConcurrentThreadExecutor(
             max_workers=max_workers,
             column_name=generator.config.name,
-            result_callback=self._make_result_callback(progress_tracker),
-            error_callback=self._make_error_callback(progress_tracker),
+            result_callback=self._worker_result_callback,
+            error_callback=self._worker_error_callback,
             shutdown_error_rate=settings.shutdown_error_rate,
             shutdown_error_window=settings.shutdown_error_window,
             disable_early_shutdown=settings.disable_early_shutdown,
@@ -241,26 +238,10 @@ class ColumnWiseDatasetBuilder:
             for i, record in self.batch_manager.iter_current_batch():
                 executor.submit(lambda record: generator.generate(record), record, context={"index": i})
 
-        progress_tracker.log_final()
-
         if len(self._records_to_drop) > 0:
             self.batch_manager.drop_records(self._records_to_drop)
             self._records_to_drop.clear()
 
-    def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]:
-        def callback(result: dict, *, context: dict | None = None) -> None:
-            self._worker_result_callback(result, context=context)
-            progress_tracker.record_success()
-
-        return callback
-
-    def _make_error_callback(self, progress_tracker: ProgressTracker) -> Callable[[Exception], None]:
-        def callback(exc: Exception, *, context: dict | None = None) -> None:
-            self._worker_error_callback(exc, context=context)
-            progress_tracker.record_failure()
-
-        return callback
-
     def _write_processed_batch(self, dataframe: pd.DataFrame) -> None:
         self.batch_manager.update_records(dataframe.to_dict(orient="records"))
         self.batch_manager.write()
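The rc1 builder hands its worker callbacks straight to the executor, while 0.4.0 wraps them so that every result or error also updates a ProgressTracker. A self-contained sketch of that wrapping pattern (Tracker and the callbacks below are stand-ins, not data-designer-engine classes):

from collections.abc import Callable


class Tracker:
    def __init__(self) -> None:
        self.success = 0
        self.failed = 0


def make_result_callback(inner: Callable[[dict], None], tracker: Tracker) -> Callable[[dict], None]:
    # Wrap the real callback so each success also bumps the tracker,
    # mirroring the _make_result_callback helper removed on the rc1 side.
    def callback(result: dict) -> None:
        inner(result)
        tracker.success += 1

    return callback


tracker = Tracker()
on_result = make_result_callback(lambda r: print("got", r), tracker)
on_result({"index": 0})
print(tracker.success)  # 1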
data_designer/engine/models/facade.py
CHANGED
@@ -18,7 +18,7 @@ from data_designer.engine.models.errors import (
 from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
 from data_designer.engine.models.parsers.errors import ParserException
 from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
-from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
+from data_designer.engine.models.utils import prompt_to_messages, str_to_message
 from data_designer.engine.secret_resolver import SecretResolver
 from data_designer.lazy_heavy_imports import litellm
 
@@ -67,17 +67,16 @@ class ModelFacade:
         return self._usage_stats
 
     def completion(
-        self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs
+        self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
     ) -> litellm.ModelResponse:
-        message_payloads = [message.to_dict() for message in messages]
         logger.debug(
             f"Prompting model {self.model_name!r}...",
-            extra={"model": self.model_name, "messages": message_payloads},
+            extra={"model": self.model_name, "messages": messages},
         )
         response = None
         kwargs = self.consolidate_kwargs(**kwargs)
         try:
-            response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs)
+            response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
             logger.debug(
                 f"Received completion from model {self.model_name!r}",
                 extra={
@@ -150,7 +149,7 @@ class ModelFacade:
         skip_usage_tracking: bool = False,
         purpose: str | None = None,
         **kwargs,
-    ) -> tuple[Any, list[ChatMessage]]:
+    ) -> tuple[Any, str | None]:
         """Generate a parsed output with correction steps.
 
         This generation call will attempt to generate an output which is
@@ -183,12 +182,6 @@ class ModelFacade:
             It is expected to be used by the @catch_llm_exceptions decorator.
             **kwargs: Additional arguments to pass to the model.
 
-        Returns:
-            A tuple containing:
-                - The parsed output object from the parser.
-                - The full trace of ChatMessage entries in the conversation, including any
-                  corrections and reasoning traces. Callers can decide whether to store this.
-
         Raises:
             GenerationValidationFailureError: If the maximum number of retries or
                 correction steps are met and the last response failures on
@@ -197,17 +190,29 @@ class ModelFacade:
         output_obj = None
         curr_num_correction_steps = 0
         curr_num_restarts = 0
+        curr_generation_attempt = 0
+        max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
 
         starting_messages = prompt_to_messages(
             user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
         )
-        messages
+        messages = deepcopy(starting_messages)
 
         while True:
+            curr_generation_attempt += 1
+            logger.debug(
+                f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
+            )
+
             completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
             response = completion_response.choices[0].message.content or ""
             reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
-
+
+            if reasoning_trace:
+                ## There are generally some extra newlines with how these get parsed.
+                response = response.strip()
+                reasoning_trace = reasoning_trace.strip()
+
             curr_num_correction_steps += 1
 
             try:
@@ -218,23 +223,21 @@ class ModelFacade:
                     raise GenerationValidationFailureError(
                         "Unsuccessful generation attempt. No retries were attempted."
                     ) from exc
-
                 if curr_num_correction_steps <= max_correction_steps:
-
-                    messages
-
+                    ## Add turns to loop-back errors for correction
+                    messages += [
+                        str_to_message(content=response, role="assistant"),
+                        str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
+                    ]
                 elif curr_num_restarts < max_conversation_restarts:
                     curr_num_correction_steps = 0
                     curr_num_restarts += 1
                     messages = deepcopy(starting_messages)
-
                 else:
                     raise GenerationValidationFailureError(
-                        f"Unsuccessful generation despite {max_correction_steps} correction steps "
-                        f"and {max_conversation_restarts} conversation restarts."
+                        f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
                     ) from exc
-
-        return output_obj, messages
+        return output_obj, reasoning_trace
 
     def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
         provider = self._model_provider_registry.get_provider(model_config.provider)
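The rc1 loop caps the total number of model calls at (max_correction_steps + 1) * (max_conversation_restarts + 1). A self-contained sketch of that budget (the model call and parser are stubbed; only the counter logic mirrors the diff):

# Stubbed control flow for the correction/restart budget in
# generate_with_correction; "attempt == 5 succeeds" stands in for a parse.
max_correction_steps = 2
max_conversation_restarts = 1
max_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)  # 6

corrections = 0
restarts = 0
for attempt in range(1, max_attempts + 1):
    corrections += 1
    if attempt == 5:  # pretend the fifth response finally parses
        print(f"parsed on attempt {attempt}")
        break
    if corrections <= max_correction_steps:
        pass  # append the bad reply plus the error as a correction turn
    elif restarts < max_conversation_restarts:
        corrections = 0  # wipe the conversation and start over
        restarts += 1
else:
    raise RuntimeError(f"Unsuccessful generation despite {max_attempts} attempts.")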
data_designer/engine/models/registry.py
CHANGED
@@ -107,11 +107,6 @@ class ModelRegistry:
     def run_health_check(self, model_aliases: list[str]) -> None:
         logger.info("🩺 Running health checks for models...")
         for model_alias in model_aliases:
-            model_config = self.get_model_config(model_alias=model_alias)
-            if model_config.skip_health_check:
-                logger.info(f" |-- ⏭️ Skipping health check for model alias {model_alias!r} (skip_health_check=True)")
-                continue
-
             model = self.get_model(model_alias=model_alias)
             logger.info(
                 f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
data_designer/engine/models/telemetry.py
CHANGED
@@ -8,7 +8,6 @@ Environment variables:
 - NEMO_TELEMETRY_ENABLED: Whether telemetry is enabled.
 - NEMO_DEPLOYMENT_TYPE: The deployment type the event came from.
 - NEMO_TELEMETRY_ENDPOINT: The endpoint to send the telemetry events to.
-- NEMO_SESSION_PREFIX: Optional prefix to add to session IDs.
 """
 
 from __future__ import annotations
@@ -19,12 +18,15 @@ import platform
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from enum import Enum
-from typing import Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
 
 from pydantic import BaseModel, Field
 
 from data_designer.lazy_heavy_imports import httpx
 
+if TYPE_CHECKING:
+    import httpx
+
 TELEMETRY_ENABLED = os.getenv("NEMO_TELEMETRY_ENABLED", "true").lower() in ("1", "true", "yes")
 CLIENT_ID = "184482118588404"
 NEMO_TELEMETRY_VERSION = "nemo-telemetry/1.0"
@@ -33,7 +35,6 @@ NEMO_TELEMETRY_ENDPOINT = os.getenv(
     "NEMO_TELEMETRY_ENDPOINT", "https://events.telemetry.data.nvidia.com/v1.1/events/json"
 ).lower()
 CPU_ARCHITECTURE = platform.uname().machine
-SESSION_PREFIX = os.getenv("NEMO_SESSION_PREFIX")
 
 
 class NemoSourceEnum(str, Enum):
@@ -230,11 +231,7 @@ class TelemetryHandler:
         self._timer_task: asyncio.Task | None = None
         self._running = False
         self._source_client_version = source_client_version
-
-        if SESSION_PREFIX:
-            self._session_id = f"{SESSION_PREFIX}{session_id}"
-        else:
-            self._session_id = session_id
+        self._session_id = session_id
 
     async def astart(self) -> None:
         if self._running:
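The telemetry module keeps httpx behind a lazy runtime import while giving type checkers the real module through a TYPE_CHECKING import of the same name. A minimal sketch of that pattern (lazy_import is a stand-in for data_designer.lazy_heavy_imports):

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import httpx  # real module for the type checker only; never runs at runtime


def lazy_import(name: str):
    class _Lazy:
        def __getattr__(self, attr):
            # Import the heavy module on first attribute access.
            return getattr(importlib.import_module(name), attr)

    return _Lazy()


httpx = lazy_import("httpx")  # rebinds the name; no import cost until used


def make_client() -> httpx.Client:
    return httpx.Client()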
data_designer/engine/models/utils.py
CHANGED
@@ -3,81 +3,7 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass, field
-from typing import Any, Literal
-
-
-@dataclass
-class ChatMessage:
-    """A chat message in an LLM conversation.
-
-    This dataclass represents messages exchanged in a conversation with an LLM,
-    supporting various message types including user prompts, assistant responses,
-    system instructions, and tool interactions.
-
-    Attributes:
-        role: The role of the message sender. One of 'user', 'assistant', 'system', or 'tool'.
-        content: The message content. Can be a string or a list of content blocks
-            for multimodal messages (e.g., text + images).
-        reasoning_content: Optional reasoning/thinking content from the assistant,
-            typically from extended thinking or chain-of-thought models.
-        tool_calls: Optional list of tool calls requested by the assistant.
-            Each tool call contains 'id', 'type', and 'function' keys.
-        tool_call_id: Optional ID linking a tool response to its corresponding
-            tool call. Required for messages with role='tool'.
-    """
-
-    role: Literal["user", "assistant", "system", "tool"]
-    content: str | list[dict[str, Any]] = ""
-    reasoning_content: str | None = None
-    tool_calls: list[dict[str, Any]] = field(default_factory=list)
-    tool_call_id: str | None = None
-
-    def to_dict(self) -> dict[str, Any]:
-        """Convert the message to a dictionary format for API calls.
-
-        Returns:
-            A dictionary containing the message fields. Only includes non-empty
-            optional fields to keep the output clean.
-        """
-        result: dict[str, Any] = {"role": self.role, "content": self.content}
-        if self.reasoning_content:
-            result["reasoning_content"] = self.reasoning_content
-        if self.tool_calls:
-            result["tool_calls"] = self.tool_calls
-        if self.tool_call_id:
-            result["tool_call_id"] = self.tool_call_id
-        return result
-
-    @classmethod
-    def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
-        """Create a user message."""
-        return cls(role="user", content=content)
-
-    @classmethod
-    def as_assistant(
-        cls,
-        content: str = "",
-        reasoning_content: str | None = None,
-        tool_calls: list[dict[str, Any]] | None = None,
-    ) -> ChatMessage:
-        """Create an assistant message."""
-        return cls(
-            role="assistant",
-            content=content,
-            reasoning_content=reasoning_content,
-            tool_calls=tool_calls or [],
-        )
-
-    @classmethod
-    def as_system(cls, content: str) -> ChatMessage:
-        """Create a system message."""
-        return cls(role="system", content=content)
-
-    @classmethod
-    def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
-        """Create a tool response message."""
-        return cls(role="tool", content=content, tool_call_id=tool_call_id)
+from typing import Any
 
 
 def prompt_to_messages(
@@ -85,17 +11,28 @@ def prompt_to_messages(
     user_prompt: str,
     system_prompt: str | None = None,
     multi_modal_context: list[dict[str, Any]] | None = None,
-) -> list[ChatMessage]:
-    """Convert a user and system prompt into
+) -> list[dict[str, str | list[dict]]]:
+    """Convert a user and system prompt into Messages format.
 
     Args:
         user_prompt (str): A user prompt.
        system_prompt (str, optional): An optional system prompt.
     """
-    user_content
-    if multi_modal_context:
-        user_content = [
-
-
-
-    return
+    user_content = user_prompt
+    if multi_modal_context and len(multi_modal_context) > 0:
+        user_content = []
+        user_content.append({"type": "text", "text": user_prompt})
+        for context in multi_modal_context:
+            user_content.append(context)
+    return (
+        [
+            str_to_message(content=system_prompt, role="system"),
+            str_to_message(content=user_content, role="user"),
+        ]
+        if system_prompt
+        else [str_to_message(content=user_content, role="user")]
+    )
+
+
+def str_to_message(content: str | list[dict], role: str = "user") -> dict[str, str | list[dict]]:
+    return {"content": content, "role": role}
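The rc1 helpers build plain litellm-style message dicts instead of ChatMessage objects. A standalone usage example (function bodies lightly condensed from the diff above):

from typing import Any


def str_to_message(content: str | list[dict], role: str = "user") -> dict:
    return {"content": content, "role": role}


def prompt_to_messages(
    user_prompt: str,
    system_prompt: str | None = None,
    multi_modal_context: list[dict[str, Any]] | None = None,
) -> list[dict]:
    user_content: str | list[dict] = user_prompt
    if multi_modal_context:
        user_content = [{"type": "text", "text": user_prompt}, *multi_modal_context]
    if system_prompt:
        return [str_to_message(system_prompt, "system"), str_to_message(user_content, "user")]
    return [str_to_message(user_content, "user")]


image_block = {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
for message in prompt_to_messages("What is in this image?", "Be concise.", [image_block]):
    print(message)
# {'content': 'Be concise.', 'role': 'system'}
# {'content': [{'type': 'text', 'text': 'What is in this image?'}, {...image block...}], 'role': 'user'}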
data_designer/engine/processing/processors/schema_transform.py
CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
 
 import json
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 from data_designer.config.processors import SchemaTransformProcessorConfig
 from data_designer.engine.dataset_builders.artifact_storage import BatchStage
@@ -20,26 +20,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _json_escape_record(record: dict[str, Any]) -> dict[str, Any]:
-    """Escape record values for safe insertion into a JSON template."""
-
-    def escape_for_json_string(s: str) -> str:
-        """Use json.dumps to escape, then strip the surrounding quotes."""
-        return json.dumps(s)[1:-1]
-
-    escaped = {}
-    for key, value in record.items():
-        if isinstance(value, str):
-            escaped[key] = escape_for_json_string(value)
-        elif isinstance(value, (dict, list)):
-            escaped[key] = escape_for_json_string(json.dumps(value))
-        elif value is None:
-            escaped[key] = "null"
-        else:
-            escaped[key] = str(value)
-    return escaped
-
-
 class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[SchemaTransformProcessorConfig]):
     @property
     def template_as_str(self) -> str:
@@ -47,12 +27,10 @@ class SchemaTransformProcessor(WithJinja2UserTemplateRendering, Processor[Schema
 
     def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
         self.prepare_jinja2_template_renderer(self.template_as_str, data.columns.to_list())
-        formatted_records = [
-
-
-
-            rendered = self.render_template(escaped)
-            formatted_records.append(json.loads(rendered))
+        formatted_records = [
+            json.loads(self.render_template(deserialize_json_values(record)).replace("\n", "\\n"))
+            for record in data.to_dict(orient="records")
+        ]
         formatted_data = pd.DataFrame(formatted_records)
         if current_batch_number is not None:
             self.artifact_storage.write_batch_to_parquet_file(
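The escaping helper removed on the rc1 side relies on a small trick: json.dumps produces a fully escaped JSON string literal, and slicing off the surrounding quotes leaves a value safe to splice into a JSON template. A standalone demonstration:

import json


def escape_for_json_string(s: str) -> str:
    # json.dumps escapes quotes, newlines, etc.; [1:-1] strips the added quotes.
    return json.dumps(s)[1:-1]


value = 'line one\nline "two"'
template = '{"note": "%s"}' % escape_for_json_string(value)
print(json.loads(template))  # {'note': 'line one\nline "two"'}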
{data_designer_engine-0.4.0.dist-info → data_designer_engine-0.4.0rc1.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 data_designer/engine/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
-data_designer/engine/_version.py,sha256=
+data_designer/engine/_version.py,sha256=yib4WPM_pEWXdpIHBdFnf29aurTH5f4xrnwVlv7cijo,714
 data_designer/engine/compiler.py,sha256=4QAeCJjINtH0afSXygdhiKMyq2KIfaDthK3ApZLgrQ0,4152
 data_designer/engine/configurable_task.py,sha256=6R4FPXPzIeK0lqNVSEXzRDtK14B3dFz38lplr-nkvRE,2539
 data_designer/engine/errors.py,sha256=YXI7ny83BQ16sOK43CpTm384hJTKuZkPTEAjlHlDIfA,1303
@@ -20,7 +20,7 @@ data_designer/engine/column_generators/generators/__init__.py,sha256=ObZ6NUPeEvv
 data_designer/engine/column_generators/generators/base.py,sha256=QElk5KsaUQ3EYwlv40NcZgQsw3HIkX3YQV_0S3erl7Q,4209
 data_designer/engine/column_generators/generators/embedding.py,sha256=uB0jgHlCgctgIUf9ZfMqG1YThbJ0g-GCX3VdNbdDSko,1407
 data_designer/engine/column_generators/generators/expression.py,sha256=BiQcfVTinvQl3OI9nkdhB9B7FGBueWiHJwxTA8uNVuY,2330
-data_designer/engine/column_generators/generators/llm_completion.py,sha256=
+data_designer/engine/column_generators/generators/llm_completion.py,sha256=3S3ikNLLLGnutUdcuswL5dUfcLgT_-he8DiRZ9K706U,4721
 data_designer/engine/column_generators/generators/samplers.py,sha256=gNzURmu9K8Zb5MHamKvZPIxmWlFgl2W4FIVgaFcy4f0,3371
 data_designer/engine/column_generators/generators/seed_dataset.py,sha256=CoQPbz4Ww7pBLaGw8-CYqIk1sjfkBaoRMKZQexdfgKY,6824
 data_designer/engine/column_generators/generators/validation.py,sha256=YfYbk-8_ZUye0No6_Q7hIqpZv_tunnEZ6HkLSMFXlDE,6659
@@ -29,7 +29,7 @@ data_designer/engine/column_generators/utils/generator_classification.py,sha256=
 data_designer/engine/column_generators/utils/judge_score_factory.py,sha256=gESiqMrQzbbcFpZas0sAAAkrH2DL0Z4Nq5ywBO-pQ6k,2141
 data_designer/engine/column_generators/utils/prompt_renderer.py,sha256=LATVAlDYwL7HyM7Nogd6n9XTTk-j9s64o4z0LpKHMhQ,4819
 data_designer/engine/dataset_builders/artifact_storage.py,sha256=CKpTBtJTde7OQvsFZQa1v1autVz5yUxlBHkIKeATFnE,10999
-data_designer/engine/dataset_builders/column_wise_builder.py,sha256=
+data_designer/engine/dataset_builders/column_wise_builder.py,sha256=9n_UYWOulUVvSnqJE9cW9f4ObF4Xa9wRxHiabJvJW8c,15723
 data_designer/engine/dataset_builders/errors.py,sha256=gLXtPcGSMBG10PzQ85dOXskdA0mKbBQrHa_VtP9sbVY,400
 data_designer/engine/dataset_builders/multi_column_configs.py,sha256=U4Pg0ETCBq5phRhb2zt8IFa4fRx-aTMakomKOBnrs0U,1660
 data_designer/engine/dataset_builders/utils/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
@@ -38,16 +38,15 @@ data_designer/engine/dataset_builders/utils/config_compiler.py,sha256=NGI6U0vgG8
 data_designer/engine/dataset_builders/utils/dag.py,sha256=RIEI75OtiphkuDl1vfI_MQC1xMiiIg29s-0C_fNZkWQ,2613
 data_designer/engine/dataset_builders/utils/dataset_batch_manager.py,sha256=IfWd_HcfEzIPhgFp2dJaxNIKRlrPsHqYATFXauvCfaw,8133
 data_designer/engine/dataset_builders/utils/errors.py,sha256=G1MIkQDXguSqHK1EP-60FkG_bys7bJ1UgJnSvcNgtt8,411
-data_designer/engine/dataset_builders/utils/progress_tracker.py,sha256=3zSljzDHwhqgP9IqPUR3XbwC231JvLNWslpmhqKIbUg,4255
 data_designer/engine/models/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
 data_designer/engine/models/errors.py,sha256=k9oZnmk8DRD8U2SVKJJRLwrcdsCcVoJiOb_Q7ZyEdvg,12271
-data_designer/engine/models/facade.py,sha256=
+data_designer/engine/models/facade.py,sha256=UBMpw_o2JcsWpJsPdpTPKfFZCh_i0eeG_oaWi1XeKds,12582
 data_designer/engine/models/factory.py,sha256=2NjI0iiGv8ayQ1c249lsJtha4pDmvmtSjdwvlvitRds,1581
 data_designer/engine/models/litellm_overrides.py,sha256=e9IZCFQ6BhNWlOTncm8ErL8w4rtE1_4USh2mtUYxCZI,6207
-data_designer/engine/models/registry.py,sha256=
-data_designer/engine/models/telemetry.py,sha256=
+data_designer/engine/models/registry.py,sha256=7hZ6TQwwZf259yRZmc3ZI20a4wAo3PCOozPi9Mc5KLo,6827
+data_designer/engine/models/telemetry.py,sha256=wmuekvPRZjNz7p7ImKx5H_hqDRhTv_dSB-u2S6Ze3uo,12502
 data_designer/engine/models/usage.py,sha256=A0LV9Ycuj_7snOsaqnirs4mlkAjozv2mzj2om2FpDoU,2410
-data_designer/engine/models/utils.py,sha256=
+data_designer/engine/models/utils.py,sha256=HS5pXAAz7IcOcijeClC-xxq6R6DUmC2ykZu8Vr33Ivk,1259
 data_designer/engine/models/parsers/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
 data_designer/engine/models/parsers/errors.py,sha256=ODcZ4TOsmZyH4-MoNkKXhjiMm_4gLWPsz90qKtNF9_Q,1053
 data_designer/engine/models/parsers/parser.py,sha256=XkdDt2WEnolvsv2bArq4hhujfJ3kLmG6G2jkRXMYA8c,9489
@@ -70,7 +69,7 @@ data_designer/engine/processing/gsonschema/validators.py,sha256=ui3PzGjIclI6Hlw4
 data_designer/engine/processing/processors/base.py,sha256=bkAQO0yK6ATJ3zTwS7F9FXobenJqydCyfijSP2MM-70,472
 data_designer/engine/processing/processors/drop_columns.py,sha256=xT7ym2pQc-R0-YHIuYDQGFn2uAf74309-pV4H878Wlk,1866
 data_designer/engine/processing/processors/registry.py,sha256=ewuFY8QeXpql5CNTZZa_87aYPGPNv1H0hpJR7CBVuzI,1097
-data_designer/engine/processing/processors/schema_transform.py,sha256=
+data_designer/engine/processing/processors/schema_transform.py,sha256=RhLXXKoj9MFpOqsXZ2hfSaTr7_yUUNI3gmFBS4XtEy4,2006
 data_designer/engine/registry/base.py,sha256=eACpE7o_c2btiiXrOFJw7o0VvACo7DSqhj8AntkNkCQ,3579
 data_designer/engine/registry/data_designer_registry.py,sha256=mz8ksE49pS1JRVDNubYSxTs0j-8Q6sd08F_dYyTCWSE,1528
 data_designer/engine/registry/errors.py,sha256=k1EaV7egNQwNmRsI8EfymTfeNprcDutPf2M6Vc1nbn8,350
@@ -109,6 +108,6 @@ data_designer/engine/validators/local_callable.py,sha256=JaL-yOXrTFpubiO2QlSt4Qb
 data_designer/engine/validators/python.py,sha256=omXjwMaomQYiyq4g6XqKt2wexVuI_rWue9Dk-CYc-do,8039
 data_designer/engine/validators/remote.py,sha256=rythhIrH2GvqncMQeF3FiJa9Om0KZWeK3cWjW-ZubaM,3077
 data_designer/engine/validators/sql.py,sha256=AMaEdA-gj9j0zwVp809x3ycKltd51wVEhI8mMYGyxd4,2408
-data_designer_engine-0.4.0.dist-info/METADATA,sha256=
-data_designer_engine-0.4.0.dist-info/WHEEL,sha256=
-data_designer_engine-0.4.0.dist-info/RECORD,,
+data_designer_engine-0.4.0rc1.dist-info/METADATA,sha256=FybLz1fOjJ2bK0zQ93Ti17o7WZTxDFtrBeGx7Oa6jCo,1876
+data_designer_engine-0.4.0rc1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+data_designer_engine-0.4.0rc1.dist-info/RECORD,,
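Each RECORD line is the file path, a hash in the wheel format (sha256= followed by the urlsafe-base64 digest with padding stripped), and the file size in bytes. A sketch of how such a line is derived (the payload here is made up, so the printed hash will not match any entry above):

import base64
import hashlib

payload = b"__version__ = version = '0.4.0rc1'\n"  # illustrative file contents
digest = hashlib.sha256(payload).digest()
hash_part = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
print(f"data_designer/engine/_version.py,sha256={hash_part},{len(payload)}")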
data_designer/engine/dataset_builders/utils/progress_tracker.py
DELETED
@@ -1,122 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import logging
-import time
-from threading import Lock
-
-from data_designer.logging import RandomEmoji
-
-logger = logging.getLogger(__name__)
-
-
-class ProgressTracker:
-    """
-    Thread-safe progress tracker for monitoring concurrent task completion.
-
-    Tracks completed, successful, and failed task counts and logs progress
-    at configurable intervals. Designed for use with ConcurrentThreadExecutor
-    to provide visibility into long-running batch operations.
-
-    Example usage:
-        tracker = ProgressTracker(total_records=100, label="LLM_TEXT column 'response'")
-        tracker.log_start(max_workers=8)
-
-        # In callbacks from ConcurrentThreadExecutor:
-        tracker.record_success()  # or tracker.record_failure()
-
-        # After executor completes:
-        tracker.log_final()
-    """
-
-    def __init__(self, total_records: int, label: str, log_interval_percent: int = 10):
-        """
-        Initialize the progress tracker.
-
-        Args:
-            total_records: Total number of records to process.
-            label: Human-readable label for log messages (e.g., "LLM_TEXT column 'response'").
-            log_interval_percent: How often to log progress as a percentage (default 10%).
-        """
-        self.total_records = total_records
-        self.label = label
-
-        self.completed = 0
-        self.success = 0
-        self.failed = 0
-
-        interval_fraction = max(1, log_interval_percent) / 100.0
-        self.log_interval = max(1, int(total_records * interval_fraction)) if total_records > 0 else 1
-        self.next_log_at = self.log_interval
-
-        self.start_time = time.perf_counter()
-        self.lock = Lock()
-        self._random_emoji = RandomEmoji()
-
-    def log_start(self, max_workers: int) -> None:
-        """Log the start of processing with worker count and interval information."""
-        logger.info(
-            "🐙 Processing %s with %d concurrent workers",
-            self.label,
-            max_workers,
-        )
-        logger.info(
-            "🧭 %s will report progress every %d record(s).",
-            self.label,
-            self.log_interval,
-        )
-
-    def record_success(self) -> None:
-        """Record a successful task completion and log progress if at interval."""
-        self._record_completion(success=True)
-
-    def record_failure(self) -> None:
-        """Record a failed task completion and log progress if at interval."""
-        self._record_completion(success=False)
-
-    def log_final(self) -> None:
-        """Log final progress summary."""
-        with self.lock:
-            if self.completed > 0:
-                self._log_progress_unlocked()
-
-    def _record_completion(self, *, success: bool) -> None:
-        should_log = False
-        with self.lock:
-            self.completed += 1
-            if success:
-                self.success += 1
-            else:
-                self.failed += 1
-
-            if self.completed >= self.next_log_at and self.completed < self.total_records:
-                should_log = True
-                while self.next_log_at <= self.completed:
-                    self.next_log_at += self.log_interval
-
-        if should_log:
-            with self.lock:
-                self._log_progress_unlocked()
-
-    def _log_progress_unlocked(self) -> None:
-        """Log current progress. Must be called while holding the lock."""
-        elapsed = time.perf_counter() - self.start_time
-        rate = self.completed / elapsed if elapsed > 0 else 0.0
-        remaining = max(0, self.total_records - self.completed)
-        eta = f"{(remaining / rate):.1f}s" if rate > 0 else "unknown"
-        percent = (self.completed / self.total_records) * 100 if self.total_records else 100.0
-
-        logger.info(
-            " |-- %s %s progress: %d/%d (%.0f%%) complete, %d ok, %d failed, %.2f rec/s, eta %s",
-            self._random_emoji.progress(percent),
-            self.label,
-            self.completed,
-            self.total_records,
-            percent,
-            self.success,
-            self.failed,
-            rate,
-            eta,
-        )

{data_designer_engine-0.4.0.dist-info → data_designer_engine-0.4.0rc1.dist-info}/WHEEL
File without changes