data-designer-engine 0.4.0__py3-none-any.whl → 0.4.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/data_designer/engine/_version.py
+++ b/data_designer/engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.4.0'
-__version_tuple__ = version_tuple = (0, 4, 0)
+__version__ = version = '0.4.0rc2'
+__version_tuple__ = version_tuple = (0, 4, 0, 'rc2')
 
 __commit_id__ = commit_id = None
--- a/data_designer/engine/column_generators/generators/llm_completion.py
+++ b/data_designer/engine/column_generators/generators/llm_completion.py
@@ -12,7 +12,7 @@ from data_designer.config.column_configs import (
     LLMStructuredColumnConfig,
     LLMTextColumnConfig,
 )
-from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX
+from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
 from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
 from data_designer.engine.column_generators.utils.prompt_renderer import (
     PromptType,
@@ -66,7 +66,7 @@ class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfig
             for context in self.config.multi_modal_context:
                 multi_modal_context.extend(context.get_contexts(deserialized_record))
 
-        response, trace = self.model.generate(
+        response, reasoning_trace = self.model.generate(
             prompt=self.prompt_renderer.render(
                 record=deserialized_record,
                 prompt_template=self.config.prompt,
@@ -87,11 +87,8 @@ class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfig
         serialized_output = self.response_recipe.serialize_output(response)
         data[self.config.name] = self._process_serialized_output(serialized_output)
 
-        should_save_trace = (
-            self.config.with_trace or self.resource_provider.run_config.debug_override_save_all_column_traces
-        )
-        if should_save_trace:
-            data[self.config.name + TRACE_COLUMN_POSTFIX] = [message.to_dict() for message in trace]
+        if reasoning_trace:
+            data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace
 
         return data
 
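The net effect of the llm_completion.py change: the old code saved the full conversation trace only when `with_trace` (or a debug override) was set, while the new code always stores the model's reasoning content when one is returned, under a dedicated column postfix. A minimal sketch of the resulting row shape, assuming an illustrative postfix value (the real constant lives in `data_designer.config.utils.constants`):

```python
# Illustrative only: the actual value of REASONING_TRACE_COLUMN_POSTFIX
# is defined in data_designer.config.utils.constants.
REASONING_TRACE_COLUMN_POSTFIX = "__reasoning_trace"

data = {"summary": "The record describes a late shipment."}
reasoning_trace = "The user asked for a summary, so I condensed the key fields..."

# Mirrors the new branch in ColumnGeneratorWithModelChatCompletion:
if reasoning_trace:
    data["summary" + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace
```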
--- a/data_designer/engine/dataset_builders/column_wise_builder.py
+++ b/data_designer/engine/dataset_builders/column_wise_builder.py
@@ -34,7 +34,6 @@ from data_designer.engine.dataset_builders.multi_column_configs import MultiColu
 from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
 from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
 from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
-from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker
 from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
 from data_designer.engine.processing.processors.base import Processor
 from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
@@ -222,18 +221,16 @@ class ColumnWiseDatasetBuilder:
                 "generator so concurrency through threads is not supported."
             )
 
-        progress_tracker = ProgressTracker(
-            total_records=self.batch_manager.num_records_batch,
-            label=f"{generator.config.column_type} column '{generator.config.name}'",
+        logger.info(
+            f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' "
+            f"with {max_workers} concurrent workers"
         )
-        progress_tracker.log_start(max_workers)
-
         settings = self._resource_provider.run_config
         with ConcurrentThreadExecutor(
             max_workers=max_workers,
             column_name=generator.config.name,
-            result_callback=self._make_result_callback(progress_tracker),
-            error_callback=self._make_error_callback(progress_tracker),
+            result_callback=self._worker_result_callback,
+            error_callback=self._worker_error_callback,
             shutdown_error_rate=settings.shutdown_error_rate,
             shutdown_error_window=settings.shutdown_error_window,
             disable_early_shutdown=settings.disable_early_shutdown,
@@ -241,26 +238,10 @@ class ColumnWiseDatasetBuilder:
             for i, record in self.batch_manager.iter_current_batch():
                 executor.submit(lambda record: generator.generate(record), record, context={"index": i})
 
-        progress_tracker.log_final()
-
         if len(self._records_to_drop) > 0:
             self.batch_manager.drop_records(self._records_to_drop)
             self._records_to_drop.clear()
 
-    def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]:
-        def callback(result: dict, *, context: dict | None = None) -> None:
-            self._worker_result_callback(result, context=context)
-            progress_tracker.record_success()
-
-        return callback
-
-    def _make_error_callback(self, progress_tracker: ProgressTracker) -> Callable[[Exception], None]:
-        def callback(exc: Exception, *, context: dict | None = None) -> None:
-            self._worker_error_callback(exc, context=context)
-            progress_tracker.record_failure()
-
-        return callback
-
     def _write_processed_batch(self, dataframe: pd.DataFrame) -> None:
         self.batch_manager.update_records(dataframe.to_dict(orient="records"))
         self.batch_manager.write()
--- a/data_designer/engine/models/facade.py
+++ b/data_designer/engine/models/facade.py
@@ -18,7 +18,7 @@ from data_designer.engine.models.errors import (
 from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
 from data_designer.engine.models.parsers.errors import ParserException
 from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
-from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
+from data_designer.engine.models.utils import prompt_to_messages, str_to_message
 from data_designer.engine.secret_resolver import SecretResolver
 from data_designer.lazy_heavy_imports import litellm
 
@@ -67,17 +67,16 @@ class ModelFacade:
         return self._usage_stats
 
     def completion(
-        self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs
+        self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs
     ) -> litellm.ModelResponse:
-        message_payloads = [message.to_dict() for message in messages]
         logger.debug(
             f"Prompting model {self.model_name!r}...",
-            extra={"model": self.model_name, "messages": message_payloads},
+            extra={"model": self.model_name, "messages": messages},
         )
         response = None
         kwargs = self.consolidate_kwargs(**kwargs)
         try:
-            response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs)
+            response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
             logger.debug(
                 f"Received completion from model {self.model_name!r}",
                 extra={
@@ -150,7 +149,7 @@ class ModelFacade:
         skip_usage_tracking: bool = False,
         purpose: str | None = None,
         **kwargs,
-    ) -> tuple[Any, list[ChatMessage]]:
+    ) -> tuple[Any, str | None]:
         """Generate a parsed output with correction steps.
 
         This generation call will attempt to generate an output which is
@@ -183,12 +182,6 @@ class ModelFacade:
                 It is expected to be used by the @catch_llm_exceptions decorator.
             **kwargs: Additional arguments to pass to the model.
 
-        Returns:
-            A tuple containing:
-            - The parsed output object from the parser.
-            - The full trace of ChatMessage entries in the conversation, including any
-              corrections and reasoning traces. Callers can decide whether to store this.
-
         Raises:
             GenerationValidationFailureError: If the maximum number of retries or
                 correction steps are met and the last response failures on
@@ -197,17 +190,29 @@ class ModelFacade:
         output_obj = None
         curr_num_correction_steps = 0
         curr_num_restarts = 0
+        curr_generation_attempt = 0
+        max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
 
         starting_messages = prompt_to_messages(
             user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context
         )
-        messages: list[ChatMessage] = deepcopy(starting_messages)
+        messages = deepcopy(starting_messages)
 
         while True:
+            curr_generation_attempt += 1
+            logger.debug(
+                f"Starting generation attempt {curr_generation_attempt} of {max_generation_attempts} attempts."
+            )
+
             completion_response = self.completion(messages, skip_usage_tracking=skip_usage_tracking, **kwargs)
             response = completion_response.choices[0].message.content or ""
             reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None)
-            messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
+
+            if reasoning_trace:
+                ## There are generally some extra newlines with how these get parsed.
+                response = response.strip()
+                reasoning_trace = reasoning_trace.strip()
+
             curr_num_correction_steps += 1
 
             try:
@@ -218,23 +223,21 @@ class ModelFacade:
                     raise GenerationValidationFailureError(
                         "Unsuccessful generation attempt. No retries were attempted."
                     ) from exc
-
                 if curr_num_correction_steps <= max_correction_steps:
-                    # Add user message with error for correction
-                    messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))
-
+                    ## Add turns to loop-back errors for correction
+                    messages += [
+                        str_to_message(content=response, role="assistant"),
+                        str_to_message(content=str(get_exception_primary_cause(exc)), role="user"),
+                    ]
                 elif curr_num_restarts < max_conversation_restarts:
                     curr_num_correction_steps = 0
                     curr_num_restarts += 1
                     messages = deepcopy(starting_messages)
-
                 else:
                     raise GenerationValidationFailureError(
-                        f"Unsuccessful generation despite {max_correction_steps} correction steps "
-                        f"and {max_conversation_restarts} conversation restarts."
+                        f"Unsuccessful generation attempt despite {max_generation_attempts} attempts."
                     ) from exc
-
-        return output_obj, messages
+        return output_obj, reasoning_trace
 
     def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict:
         provider = self._model_provider_registry.get_provider(model_config.provider)
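A small worked example of the new attempt accounting in `generate`: each conversation gets one initial try plus `max_correction_steps` corrections, and the whole conversation can be restarted from the starting messages `max_conversation_restarts` times, so the budget multiplies. A sketch with illustrative limit values (the real limits are passed in by the caller):

```python
# Illustrative limits; the actual values come from the caller.
max_correction_steps = 2
max_conversation_restarts = 1

# One initial attempt + 2 corrections = 3 attempts per conversation,
# times (1 original + 1 restarted) conversations = 6 total attempts.
max_generation_attempts = (max_correction_steps + 1) * (max_conversation_restarts + 1)
assert max_generation_attempts == 6
```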
--- a/data_designer/engine/models/utils.py
+++ b/data_designer/engine/models/utils.py
@@ -3,81 +3,7 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass, field
-from typing import Any, Literal
-
-
-@dataclass
-class ChatMessage:
-    """A chat message in an LLM conversation.
-
-    This dataclass represents messages exchanged in a conversation with an LLM,
-    supporting various message types including user prompts, assistant responses,
-    system instructions, and tool interactions.
-
-    Attributes:
-        role: The role of the message sender. One of 'user', 'assistant', 'system', or 'tool'.
-        content: The message content. Can be a string or a list of content blocks
-            for multimodal messages (e.g., text + images).
-        reasoning_content: Optional reasoning/thinking content from the assistant,
-            typically from extended thinking or chain-of-thought models.
-        tool_calls: Optional list of tool calls requested by the assistant.
-            Each tool call contains 'id', 'type', and 'function' keys.
-        tool_call_id: Optional ID linking a tool response to its corresponding
-            tool call. Required for messages with role='tool'.
-    """
-
-    role: Literal["user", "assistant", "system", "tool"]
-    content: str | list[dict[str, Any]] = ""
-    reasoning_content: str | None = None
-    tool_calls: list[dict[str, Any]] = field(default_factory=list)
-    tool_call_id: str | None = None
-
-    def to_dict(self) -> dict[str, Any]:
-        """Convert the message to a dictionary format for API calls.
-
-        Returns:
-            A dictionary containing the message fields. Only includes non-empty
-            optional fields to keep the output clean.
-        """
-        result: dict[str, Any] = {"role": self.role, "content": self.content}
-        if self.reasoning_content:
-            result["reasoning_content"] = self.reasoning_content
-        if self.tool_calls:
-            result["tool_calls"] = self.tool_calls
-        if self.tool_call_id:
-            result["tool_call_id"] = self.tool_call_id
-        return result
-
-    @classmethod
-    def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
-        """Create a user message."""
-        return cls(role="user", content=content)
-
-    @classmethod
-    def as_assistant(
-        cls,
-        content: str = "",
-        reasoning_content: str | None = None,
-        tool_calls: list[dict[str, Any]] | None = None,
-    ) -> ChatMessage:
-        """Create an assistant message."""
-        return cls(
-            role="assistant",
-            content=content,
-            reasoning_content=reasoning_content,
-            tool_calls=tool_calls or [],
-        )
-
-    @classmethod
-    def as_system(cls, content: str) -> ChatMessage:
-        """Create a system message."""
-        return cls(role="system", content=content)
-
-    @classmethod
-    def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
-        """Create a tool response message."""
-        return cls(role="tool", content=content, tool_call_id=tool_call_id)
+from typing import Any
 
 
 def prompt_to_messages(
@@ -85,17 +11,28 @@ def prompt_to_messages(
     user_prompt: str,
     system_prompt: str | None = None,
     multi_modal_context: list[dict[str, Any]] | None = None,
-) -> list[ChatMessage]:
-    """Convert a user and system prompt into ChatMessage list.
+) -> list[dict[str, str | list[dict]]]:
+    """Convert a user and system prompt into Messages format.
 
     Args:
         user_prompt (str): A user prompt.
         system_prompt (str, optional): An optional system prompt.
     """
-    user_content: str | list[dict[str, Any]] = user_prompt
-    if multi_modal_context:
-        user_content = [*multi_modal_context, {"type": "text", "text": user_prompt}]
-
-    if system_prompt:
-        return [ChatMessage.as_system(system_prompt), ChatMessage.as_user(user_content)]
-    return [ChatMessage.as_user(user_content)]
+    user_content = user_prompt
+    if multi_modal_context and len(multi_modal_context) > 0:
+        user_content = []
+        for context in multi_modal_context:
+            user_content.append(context)
+        user_content.append({"type": "text", "text": user_prompt})
+    return (
+        [
+            str_to_message(content=system_prompt, role="system"),
+            str_to_message(content=user_content, role="user"),
+        ]
+        if system_prompt
+        else [str_to_message(content=user_content, role="user")]
+    )
+
+
+def str_to_message(content: str | list[dict], role: str = "user") -> dict[str, str | list[dict]]:
+    return {"content": content, "role": role}
--- a/data_designer_engine-0.4.0.dist-info/METADATA
+++ b/data_designer_engine-0.4.0rc2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: data-designer-engine
-Version: 0.4.0
+Version: 0.4.0rc2
 Summary: Generation engine for DataDesigner synthetic data generation
 License-Expression: Apache-2.0
 Classifier: Development Status :: 4 - Beta
--- a/data_designer_engine-0.4.0.dist-info/RECORD
+++ b/data_designer_engine-0.4.0rc2.dist-info/RECORD
@@ -1,5 +1,5 @@
 data_designer/engine/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
-data_designer/engine/_version.py,sha256=2_0GUP7yBCXRus-qiJKxQD62z172WSs1sQ6DVpPsbmM,704
+data_designer/engine/_version.py,sha256=FvItxCBzPigrdVpFPfL1gQeV1-km5r7nCNGUzrYebTU,714
 data_designer/engine/compiler.py,sha256=4QAeCJjINtH0afSXygdhiKMyq2KIfaDthK3ApZLgrQ0,4152
 data_designer/engine/configurable_task.py,sha256=6R4FPXPzIeK0lqNVSEXzRDtK14B3dFz38lplr-nkvRE,2539
 data_designer/engine/errors.py,sha256=YXI7ny83BQ16sOK43CpTm384hJTKuZkPTEAjlHlDIfA,1303
@@ -20,7 +20,7 @@ data_designer/engine/column_generators/generators/__init__.py,sha256=ObZ6NUPeEvv
 data_designer/engine/column_generators/generators/base.py,sha256=QElk5KsaUQ3EYwlv40NcZgQsw3HIkX3YQV_0S3erl7Q,4209
 data_designer/engine/column_generators/generators/embedding.py,sha256=uB0jgHlCgctgIUf9ZfMqG1YThbJ0g-GCX3VdNbdDSko,1407
 data_designer/engine/column_generators/generators/expression.py,sha256=BiQcfVTinvQl3OI9nkdhB9B7FGBueWiHJwxTA8uNVuY,2330
-data_designer/engine/column_generators/generators/llm_completion.py,sha256=gMOOdd0_BY-RLXrArx1u8GL7YJfVvKceTqn_Zg1xHPI,4897
+data_designer/engine/column_generators/generators/llm_completion.py,sha256=udYWE3lwaQhZqxRTHQc6w1kWGEvLAfIh2OUjX6vxMB0,4750
 data_designer/engine/column_generators/generators/samplers.py,sha256=gNzURmu9K8Zb5MHamKvZPIxmWlFgl2W4FIVgaFcy4f0,3371
 data_designer/engine/column_generators/generators/seed_dataset.py,sha256=CoQPbz4Ww7pBLaGw8-CYqIk1sjfkBaoRMKZQexdfgKY,6824
 data_designer/engine/column_generators/generators/validation.py,sha256=YfYbk-8_ZUye0No6_Q7hIqpZv_tunnEZ6HkLSMFXlDE,6659
@@ -29,7 +29,7 @@ data_designer/engine/column_generators/utils/generator_classification.py,sha256=
 data_designer/engine/column_generators/utils/judge_score_factory.py,sha256=gESiqMrQzbbcFpZas0sAAAkrH2DL0Z4Nq5ywBO-pQ6k,2141
 data_designer/engine/column_generators/utils/prompt_renderer.py,sha256=LATVAlDYwL7HyM7Nogd6n9XTTk-j9s64o4z0LpKHMhQ,4819
 data_designer/engine/dataset_builders/artifact_storage.py,sha256=CKpTBtJTde7OQvsFZQa1v1autVz5yUxlBHkIKeATFnE,10999
-data_designer/engine/dataset_builders/column_wise_builder.py,sha256=UAfl-iejVYqvmVx2anGmtPKfmqztM5o8nvyVzxYrM_0,16581
+data_designer/engine/dataset_builders/column_wise_builder.py,sha256=9n_UYWOulUVvSnqJE9cW9f4ObF4Xa9wRxHiabJvJW8c,15723
 data_designer/engine/dataset_builders/errors.py,sha256=gLXtPcGSMBG10PzQ85dOXskdA0mKbBQrHa_VtP9sbVY,400
 data_designer/engine/dataset_builders/multi_column_configs.py,sha256=U4Pg0ETCBq5phRhb2zt8IFa4fRx-aTMakomKOBnrs0U,1660
 data_designer/engine/dataset_builders/utils/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
@@ -38,16 +38,15 @@ data_designer/engine/dataset_builders/utils/config_compiler.py,sha256=NGI6U0vgG8
 data_designer/engine/dataset_builders/utils/dag.py,sha256=RIEI75OtiphkuDl1vfI_MQC1xMiiIg29s-0C_fNZkWQ,2613
 data_designer/engine/dataset_builders/utils/dataset_batch_manager.py,sha256=IfWd_HcfEzIPhgFp2dJaxNIKRlrPsHqYATFXauvCfaw,8133
 data_designer/engine/dataset_builders/utils/errors.py,sha256=G1MIkQDXguSqHK1EP-60FkG_bys7bJ1UgJnSvcNgtt8,411
-data_designer/engine/dataset_builders/utils/progress_tracker.py,sha256=3zSljzDHwhqgP9IqPUR3XbwC231JvLNWslpmhqKIbUg,4255
 data_designer/engine/models/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
 data_designer/engine/models/errors.py,sha256=k9oZnmk8DRD8U2SVKJJRLwrcdsCcVoJiOb_Q7ZyEdvg,12271
-data_designer/engine/models/facade.py,sha256=ckwFxcMHC23-qKU8bdBC0eWKYx6vfVjvp9-0AtCXMX0,12497
+data_designer/engine/models/facade.py,sha256=UBMpw_o2JcsWpJsPdpTPKfFZCh_i0eeG_oaWi1XeKds,12582
 data_designer/engine/models/factory.py,sha256=2NjI0iiGv8ayQ1c249lsJtha4pDmvmtSjdwvlvitRds,1581
 data_designer/engine/models/litellm_overrides.py,sha256=e9IZCFQ6BhNWlOTncm8ErL8w4rtE1_4USh2mtUYxCZI,6207
 data_designer/engine/models/registry.py,sha256=Bid7Mv_ebzbTrlfzN-1wbcFxp_qQwilL0h2iwN5UPJ0,7099
 data_designer/engine/models/telemetry.py,sha256=_VZR6Iatr6-5Hypw3bes5Jr4y7Y3VagxFEVAv36eHcE,12733
 data_designer/engine/models/usage.py,sha256=A0LV9Ycuj_7snOsaqnirs4mlkAjozv2mzj2om2FpDoU,2410
-data_designer/engine/models/utils.py,sha256=Szy3lOg_E14DRAx6U2Dpr3HXPg09xIr3VUnoREiZ1mw,3807
+data_designer/engine/models/utils.py,sha256=sLBs-STJSe7BGzDAngRGGxo6GwAvFmtimqUs54zZ6DU,1259
 data_designer/engine/models/parsers/__init__.py,sha256=ObZ6NUPeEvvpGTJ5WIGKUyIrIjaI747OM6ErweRtHxQ,137
 data_designer/engine/models/parsers/errors.py,sha256=ODcZ4TOsmZyH4-MoNkKXhjiMm_4gLWPsz90qKtNF9_Q,1053
 data_designer/engine/models/parsers/parser.py,sha256=XkdDt2WEnolvsv2bArq4hhujfJ3kLmG6G2jkRXMYA8c,9489
@@ -109,6 +108,6 @@ data_designer/engine/validators/local_callable.py,sha256=JaL-yOXrTFpubiO2QlSt4Qb
 data_designer/engine/validators/python.py,sha256=omXjwMaomQYiyq4g6XqKt2wexVuI_rWue9Dk-CYc-do,8039
 data_designer/engine/validators/remote.py,sha256=rythhIrH2GvqncMQeF3FiJa9Om0KZWeK3cWjW-ZubaM,3077
 data_designer/engine/validators/sql.py,sha256=AMaEdA-gj9j0zwVp809x3ycKltd51wVEhI8mMYGyxd4,2408
-data_designer_engine-0.4.0.dist-info/METADATA,sha256=hHuNlKxfNErQUPbmwmBkux0M2q9ebuFna97Xoe8y2lc,1873
-data_designer_engine-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-data_designer_engine-0.4.0.dist-info/RECORD,,
+data_designer_engine-0.4.0rc2.dist-info/METADATA,sha256=ZChyQl5ksGCWVi_XE6wD-GXG9-wWHko1vBDnd9ecLqw,1876
+data_designer_engine-0.4.0rc2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+data_designer_engine-0.4.0rc2.dist-info/RECORD,,
--- a/data_designer/engine/dataset_builders/utils/progress_tracker.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import logging
-import time
-from threading import Lock
-
-from data_designer.logging import RandomEmoji
-
-logger = logging.getLogger(__name__)
-
-
-class ProgressTracker:
-    """
-    Thread-safe progress tracker for monitoring concurrent task completion.
-
-    Tracks completed, successful, and failed task counts and logs progress
-    at configurable intervals. Designed for use with ConcurrentThreadExecutor
-    to provide visibility into long-running batch operations.
-
-    Example usage:
-        tracker = ProgressTracker(total_records=100, label="LLM_TEXT column 'response'")
-        tracker.log_start(max_workers=8)
-
-        # In callbacks from ConcurrentThreadExecutor:
-        tracker.record_success()  # or tracker.record_failure()
-
-        # After executor completes:
-        tracker.log_final()
-    """
-
-    def __init__(self, total_records: int, label: str, log_interval_percent: int = 10):
-        """
-        Initialize the progress tracker.
-
-        Args:
-            total_records: Total number of records to process.
-            label: Human-readable label for log messages (e.g., "LLM_TEXT column 'response'").
-            log_interval_percent: How often to log progress as a percentage (default 10%).
-        """
-        self.total_records = total_records
-        self.label = label
-
-        self.completed = 0
-        self.success = 0
-        self.failed = 0
-
-        interval_fraction = max(1, log_interval_percent) / 100.0
-        self.log_interval = max(1, int(total_records * interval_fraction)) if total_records > 0 else 1
-        self.next_log_at = self.log_interval
-
-        self.start_time = time.perf_counter()
-        self.lock = Lock()
-        self._random_emoji = RandomEmoji()
-
-    def log_start(self, max_workers: int) -> None:
-        """Log the start of processing with worker count and interval information."""
-        logger.info(
-            "🐙 Processing %s with %d concurrent workers",
-            self.label,
-            max_workers,
-        )
-        logger.info(
-            "🧭 %s will report progress every %d record(s).",
-            self.label,
-            self.log_interval,
-        )
-
-    def record_success(self) -> None:
-        """Record a successful task completion and log progress if at interval."""
-        self._record_completion(success=True)
-
-    def record_failure(self) -> None:
-        """Record a failed task completion and log progress if at interval."""
-        self._record_completion(success=False)
-
-    def log_final(self) -> None:
-        """Log final progress summary."""
-        with self.lock:
-            if self.completed > 0:
-                self._log_progress_unlocked()
-
-    def _record_completion(self, *, success: bool) -> None:
-        should_log = False
-        with self.lock:
-            self.completed += 1
-            if success:
-                self.success += 1
-            else:
-                self.failed += 1
-
-            if self.completed >= self.next_log_at and self.completed < self.total_records:
-                should_log = True
-                while self.next_log_at <= self.completed:
-                    self.next_log_at += self.log_interval
-
-        if should_log:
-            with self.lock:
-                self._log_progress_unlocked()
-
-    def _log_progress_unlocked(self) -> None:
-        """Log current progress. Must be called while holding the lock."""
-        elapsed = time.perf_counter() - self.start_time
-        rate = self.completed / elapsed if elapsed > 0 else 0.0
-        remaining = max(0, self.total_records - self.completed)
-        eta = f"{(remaining / rate):.1f}s" if rate > 0 else "unknown"
-        percent = (self.completed / self.total_records) * 100 if self.total_records else 100.0
-
-        logger.info(
-            " |-- %s %s progress: %d/%d (%.0f%%) complete, %d ok, %d failed, %.2f rec/s, eta %s",
-            self._random_emoji.progress(percent),
-            self.label,
-            self.completed,
-            self.total_records,
-            percent,
-            self.success,
-            self.failed,
-            rate,
-            eta,
-        )