DeepFabric 4.4.1__py3-none-any.whl → 4.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. deepfabric/__init__.py +8 -0
  2. deepfabric/auth.py +8 -2
  3. deepfabric/builders.py +2 -2
  4. deepfabric/builders_agent.py +18 -6
  5. deepfabric/cli.py +292 -13
  6. deepfabric/cloud_upload.py +884 -0
  7. deepfabric/config.py +47 -20
  8. deepfabric/config_manager.py +2 -2
  9. deepfabric/dataset.py +302 -0
  10. deepfabric/evaluation/backends/__init__.py +2 -0
  11. deepfabric/evaluation/backends/llm_eval_backend.py +527 -0
  12. deepfabric/evaluation/backends/ollama_backend.py +3 -3
  13. deepfabric/evaluation/backends/tool_call_parsers.py +7 -7
  14. deepfabric/evaluation/backends/transformers_backend.py +73 -16
  15. deepfabric/evaluation/evaluator.py +41 -7
  16. deepfabric/evaluation/evaluators/builtin/tool_calling.py +13 -8
  17. deepfabric/evaluation/inference.py +77 -5
  18. deepfabric/evaluation/metrics.py +4 -0
  19. deepfabric/evaluation/parser.py +8 -8
  20. deepfabric/evaluation/reporters/cloud_reporter.py +19 -6
  21. deepfabric/exceptions.py +14 -0
  22. deepfabric/generator.py +8 -4
  23. deepfabric/graph.py +38 -0
  24. deepfabric/hf_hub.py +1 -1
  25. deepfabric/loader.py +554 -0
  26. deepfabric/schemas.py +7 -7
  27. deepfabric/topic_manager.py +4 -0
  28. deepfabric/training/__init__.py +24 -5
  29. deepfabric/training/callback.py +43 -1
  30. deepfabric/training/dataset_utils.py +223 -0
  31. deepfabric/training/metrics_sender.py +50 -16
  32. deepfabric/tui.py +9 -1
  33. deepfabric/utils.py +14 -0
  34. deepfabric/validation.py +1 -1
  35. {deepfabric-4.4.1.dist-info → deepfabric-4.6.0.dist-info}/METADATA +84 -177
  36. {deepfabric-4.4.1.dist-info → deepfabric-4.6.0.dist-info}/RECORD +39 -34
  37. {deepfabric-4.4.1.dist-info → deepfabric-4.6.0.dist-info}/WHEEL +0 -0
  38. {deepfabric-4.4.1.dist-info → deepfabric-4.6.0.dist-info}/entry_points.txt +0 -0
  39. {deepfabric-4.4.1.dist-info → deepfabric-4.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -97,6 +97,8 @@ async def _process_graph_events(graph: Graph, debug: bool = False) -> dict | Non
97
97
  get_tui().error(f" [{idx}] Node ID: {node_id}, Attempts: {attempts}")
98
98
  get_tui().error(f" Error: {last_error}")
99
99
  except Exception as e:
100
+ # Stop TUI before printing error to ensure visibility
101
+ tui.stop_live()
100
102
  if debug:
101
103
  get_tui().error(f"Debug: Full traceback:\n{traceback.format_exc()}")
102
104
  get_tui().error(f"Graph build failed: {str(e)}")
@@ -147,6 +149,8 @@ async def _process_tree_events(tree: Tree, debug: bool = False) -> dict | None:
147
149
  )
148
150
  get_tui().error(f" Error: {failure.get('error', 'Unknown error')}")
149
151
  except Exception as e:
152
+ # Stop TUI before printing error to ensure visibility
153
+ tui.stop_live()
150
154
  if debug:
151
155
  get_tui().error(f"Debug: Full traceback:\n{traceback.format_exc()}")
152
156
  get_tui().error(f"Tree build failed: {str(e)}")
@@ -1,20 +1,27 @@
1
- """DeepFabric training metrics logging.
1
+ """DeepFabric training utilities.
2
2
 
3
- This module provides integration with HuggingFace Trainer and TRL trainers
4
- to log training metrics to the DeepFabric SaaS backend.
3
+ This module provides:
4
+ - Integration with HuggingFace Trainer and TRL trainers for metrics logging
5
+ - Dataset preparation utilities for optimizing training data
5
6
 
6
7
  Features:
7
8
  - Non-blocking async metrics sending
8
9
  - Notebook-friendly API key prompts (like wandb)
9
10
  - Graceful handling of failures without impacting training
11
+ - Tool filtering to reduce sequence lengths and memory usage
10
12
 
11
13
  Usage:
12
- from deepfabric.training import DeepFabricCallback
14
+ from deepfabric.training import DeepFabricCallback, prepare_dataset_for_training
13
15
 
16
+ # Prepare dataset (reduces tool overhead)
17
+ dataset = load_dataset("your/dataset", split="train")
18
+ prepared = prepare_dataset_for_training(dataset, tool_strategy="used_only")
19
+
20
+ # Train with metrics logging
14
21
  trainer = Trainer(
15
22
  model=model,
16
23
  args=training_args,
17
- train_dataset=train_dataset,
24
+ train_dataset=prepared,
18
25
  )
19
26
  trainer.add_callback(DeepFabricCallback(trainer))
20
27
  trainer.train()
@@ -27,9 +34,21 @@ Environment Variables:
27
34
  from __future__ import annotations
28
35
 
29
36
  from .callback import DeepFabricCallback
37
+ from .dataset_utils import (
38
+ ToolInclusionStrategy,
39
+ clean_tool_schema,
40
+ filter_tools_for_sample,
41
+ get_used_tool_names,
42
+ prepare_dataset_for_training,
43
+ )
30
44
  from .metrics_sender import MetricsSender
31
45
 
32
46
  __all__ = [
33
47
  "DeepFabricCallback",
34
48
  "MetricsSender",
49
+ "ToolInclusionStrategy",
50
+ "clean_tool_schema",
51
+ "filter_tools_for_sample",
52
+ "get_used_tool_names",
53
+ "prepare_dataset_for_training",
35
54
  ]
@@ -51,6 +51,7 @@ class DeepFabricCallback:
51
51
  trainer: Any | None = None,
52
52
  api_key: str | None = None,
53
53
  endpoint: str | None = None,
54
+ pipeline_id: str | None = None,
54
55
  enabled: bool = True,
55
56
  ):
56
57
  """Initialize the DeepFabric callback.
@@ -60,11 +61,14 @@ class DeepFabricCallback:
60
61
  api_key: DeepFabric API key (falls back to DEEPFABRIC_API_KEY env var,
61
62
  then prompts in interactive environments)
62
63
  endpoint: API endpoint URL (falls back to DEEPFABRIC_API_URL env var)
64
+ pipeline_id: Pipeline ID to associate training with (falls back to
65
+ DEEPFABRIC_PIPELINE_ID env var or pipeline_id.txt file)
63
66
  enabled: Whether logging is enabled (default: True)
64
67
  """
65
68
  # Get API key from arg, env, or prompt
66
69
  self.api_key = api_key or get_api_key()
67
70
  self.endpoint = endpoint or os.getenv("DEEPFABRIC_API_URL", "https://api.deepfabric.ai")
71
+ self.pipeline_id = pipeline_id or self._get_pipeline_id()
68
72
  self.run_id = str(uuid.uuid4())
69
73
  self.enabled = enabled and self.api_key is not None
70
74
 
@@ -75,14 +79,26 @@ class DeepFabricCallback:
75
79
  self.sender = MetricsSender(
76
80
  endpoint=self.endpoint,
77
81
  api_key=self.api_key if self.enabled else None,
82
+ pipeline_id=self.pipeline_id,
78
83
  )
79
84
 
80
85
  self._run_started = False
81
86
  self._model_name: str | None = None
82
87
  self._training_args_logged = False
88
+ self._start_time: datetime | None = None
83
89
 
84
90
  if self.enabled:
85
- logger.debug(f"DeepFabric callback initialized (run_id={self.run_id})")
91
+ if self.pipeline_id:
92
+ logger.debug(
93
+ f"DeepFabric callback initialized (run_id={self.run_id}, "
94
+ f"pipeline_id={self.pipeline_id})"
95
+ )
96
+ else:
97
+ logger.warning(
98
+ "DeepFabric callback initialized but no pipeline_id set. "
99
+ "Metrics will not be sent. Set DEEPFABRIC_PIPELINE_ID env var "
100
+ "or create pipeline_id.txt file."
101
+ )
86
102
  else:
87
103
  logger.debug("DeepFabric callback disabled (no API key)")
88
104
 
@@ -101,6 +117,7 @@ class DeepFabricCallback:
101
117
  return
102
118
 
103
119
  self._run_started = True
120
+ self._start_time = datetime.now(timezone.utc)
104
121
 
105
122
  # Extract model name from various sources
106
123
  model = kwargs.get("model")
@@ -121,6 +138,7 @@ class DeepFabricCallback:
121
138
  "num_train_epochs": state.num_train_epochs,
122
139
  "is_world_process_zero": getattr(state, "is_world_process_zero", True),
123
140
  },
141
+ "started_at": self._start_time.isoformat(),
124
142
  }
125
143
  )
126
144
 
@@ -204,6 +222,8 @@ class DeepFabricCallback:
204
222
  if not self.enabled or not self._run_started:
205
223
  return
206
224
 
225
+ completed_at = datetime.now(timezone.utc)
226
+
207
227
  self.sender.send_run_end(
208
228
  {
209
229
  "run_id": self.run_id,
@@ -212,6 +232,7 @@ class DeepFabricCallback:
212
232
  "total_flos": getattr(state, "total_flos", None),
213
233
  "best_metric": getattr(state, "best_metric", None),
214
234
  "best_model_checkpoint": getattr(state, "best_model_checkpoint", None),
235
+ "completed_at": completed_at.isoformat(),
215
236
  }
216
237
  )
217
238
 
@@ -246,6 +267,27 @@ class DeepFabricCallback:
246
267
  }
247
268
  )
248
269
 
270
def _get_pipeline_id(self) -> str | None:
    """Get pipeline ID from environment or file.

    Checks the DEEPFABRIC_PIPELINE_ID environment variable first, then a
    pipeline_id.txt file in the current working directory.

    Returns:
        Pipeline ID or None if neither source provides one
    """
    # Try environment variable first
    pipeline_id = os.getenv("DEEPFABRIC_PIPELINE_ID", "")
    if pipeline_id:
        return pipeline_id

    # Try pipeline_id.txt file (relative to the current working directory)
    pipeline_file = "pipeline_id.txt"
    if os.path.exists(pipeline_file):
        try:
            # Explicit encoding avoids platform-dependent defaults (PEP 597).
            with open(pipeline_file, encoding="utf-8") as f:
                pipeline_id = f.read().strip()
        except OSError:
            # Best-effort lookup: an unreadable/vanished file (e.g. a race
            # between exists() and open()) must not break callback init.
            return None
        if pipeline_id:
            return pipeline_id

    return None
290
+
249
291
  def _extract_model_name(self, args: TrainingArguments, model: Any | None) -> str | None:
250
292
  """Extract model name from various sources.
251
293
 
@@ -0,0 +1,223 @@
1
+ """Dataset preparation utilities for training.
2
+
3
+ This module provides utilities for preparing DeepFabric datasets for training,
4
+ including tool filtering to reduce sequence lengths and memory usage.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+
11
+ from typing import TYPE_CHECKING, Any, Literal
12
+
13
+ if TYPE_CHECKING:
14
+ from datasets import Dataset
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ ToolInclusionStrategy = Literal["all", "used_only", "used_plus_related"]
19
+
20
+
21
def get_used_tool_names(messages: list[dict[str, Any]]) -> set[str]:
    """Extract tool names that are actually called in a conversation.

    Args:
        messages: List of message dicts from the conversation

    Returns:
        Set of tool names that were called
    """
    names: set[str] = set()

    # Only assistant turns can carry tool calls; a missing/None
    # "tool_calls" field is treated the same as an empty list.
    calls = (
        call
        for message in messages
        if message.get("role") == "assistant"
        for call in (message.get("tool_calls") or [])
    )

    for call in calls:
        if not isinstance(call, dict):
            continue
        # OpenAI format: {"function": {"name": "..."}}
        function = call.get("function", {})
        if isinstance(function, dict) and function.get("name"):
            names.add(function["name"])
        # Alternative flat format: {"name": "..."}
        elif call.get("name"):
            names.add(call["name"])

    return names
47
+
48
+
49
def clean_tool_schema(tool: dict[str, Any]) -> dict[str, Any]:
    """Remove null/None values from tool schema to reduce size.

    Args:
        tool: Tool definition in OpenAI format

    Returns:
        Cleaned tool definition with nulls removed
    """
    # Non-dict inputs are passed through untouched.
    if not isinstance(tool, dict):
        return tool

    def _clean_list(items: list[Any]) -> list[Any]:
        # Clean dict elements recursively, drop None elements, and drop
        # dict elements that become empty after cleaning.
        kept: list[Any] = []
        for item in items:
            if isinstance(item, dict):
                nested = clean_tool_schema(item)
                if nested:
                    kept.append(nested)
            elif item is not None:
                kept.append(item)
        return kept

    cleaned: dict[str, Any] = {}
    for key, value in tool.items():
        if value is None:
            continue  # drop nulls entirely
        if isinstance(value, dict):
            value = clean_tool_schema(value)
            if not value:
                continue  # drop dicts that are empty after cleaning
        elif isinstance(value, list):
            value = _clean_list(value)
            if not value:
                continue  # drop lists that are empty after cleaning
        cleaned[key] = value
    return cleaned
86
+
87
+
88
def filter_tools_for_sample(
    sample: dict[str, Any],
    strategy: ToolInclusionStrategy = "used_only",
    min_tools: int = 1,
    clean_schemas: bool = True,
) -> dict[str, Any]:
    """Filter tools in a sample to only include relevant ones.

    The sample is modified in place and returned, matching how
    ``datasets.Dataset.map`` consumes the callback.

    Args:
        sample: Dataset sample with 'messages' and 'tools' fields
        strategy: Tool inclusion strategy:
            - "all": Keep all tools (no filtering)
            - "used_only": Only include tools that are called in the conversation
            - "used_plus_related": Include used tools plus related ones (not implemented)
        min_tools: Minimum number of tools to include (fallback if filtering
            removes all tools)
        clean_schemas: Whether to remove null values from tool schemas

    Returns:
        Modified sample with filtered tools
    """
    if strategy == "all" and not clean_schemas:
        return sample

    messages = sample.get("messages", [])
    all_tools = sample.get("tools", [])

    if not all_tools:
        return sample

    # Clean schemas if requested
    if clean_schemas:
        all_tools = [clean_tool_schema(tool) for tool in all_tools]

    if strategy == "all":
        sample["tools"] = all_tools
        return sample

    # Get tools actually used
    used_names = get_used_tool_names(messages)

    if not used_names:
        # No tools used - keep minimum number of tools
        sample["tools"] = all_tools[:min_tools] if min_tools > 0 else []
        return sample

    # Filter to used tools. Guard against non-dict entries:
    # clean_tool_schema passes non-dict tools through unchanged, so an
    # unguarded tool.get() would raise AttributeError here.
    filtered_tools = []
    for tool in all_tools:
        if not isinstance(tool, dict):
            continue
        func = tool.get("function", {})
        if isinstance(func, dict) and func.get("name") in used_names:
            filtered_tools.append(tool)

    # Ensure minimum tools by topping up from the original list, in order
    if len(filtered_tools) < min_tools:
        for tool in all_tools:
            if tool not in filtered_tools:
                filtered_tools.append(tool)
                if len(filtered_tools) >= min_tools:
                    break

    sample["tools"] = filtered_tools
    return sample
152
+
153
+
154
def _avg_tool_count(dataset: Dataset) -> float:
    """Average number of tools per sample; 0.0 for an empty dataset.

    Note: this iterates the whole dataset, so callers should only invoke
    it when the resulting statistic will actually be logged.
    """
    counts = [len(sample.get("tools", []) or []) for sample in dataset]
    return sum(counts) / len(counts) if counts else 0.0


def prepare_dataset_for_training(
    dataset: Dataset,
    tool_strategy: ToolInclusionStrategy = "used_only",
    clean_tool_schemas: bool = True,
    min_tools: int = 1,
    num_proc: int | None = None,
) -> Dataset:
    """Prepare a DeepFabric dataset for training with optimizations.

    This function applies various optimizations to reduce dataset size and
    memory usage during training:
    - Filters tools to only include those actually used in each conversation
    - Removes null values from tool schemas
    - Can be extended with additional preprocessing steps

    Args:
        dataset: HuggingFace Dataset with DeepFabric conversation format
        tool_strategy: How to filter tools (see filter_tools_for_sample)
        clean_tool_schemas: Whether to remove null values from tool schemas
        min_tools: Minimum tools to keep per sample
        num_proc: Number of processes for parallel processing

    Returns:
        Processed dataset ready for training

    Example:
        >>> from datasets import load_dataset
        >>> from deepfabric.training import prepare_dataset_for_training
        >>>
        >>> dataset = load_dataset("your/dataset", split="train")
        >>> prepared = prepare_dataset_for_training(
        ...     dataset,
        ...     tool_strategy="used_only",
        ...     clean_tool_schemas=True,
        ... )
        >>> # Now use prepared dataset for training
    """
    logger.info(
        "Preparing dataset for training: tool_strategy=%s, clean_schemas=%s",
        tool_strategy,
        clean_tool_schemas,
    )

    # The before/after statistics require a full pass over the dataset each,
    # so compute them only when INFO logging is actually enabled.
    log_stats = logger.isEnabledFor(logging.INFO)

    if log_stats and "tools" in dataset.column_names:
        logger.info("Initial average tools per sample: %.1f", _avg_tool_count(dataset))

    # Apply tool filtering
    processed = dataset.map(
        lambda x: filter_tools_for_sample(
            x,
            strategy=tool_strategy,
            min_tools=min_tools,
            clean_schemas=clean_tool_schemas,
        ),
        num_proc=num_proc,
        desc="Filtering tools",
    )

    if log_stats and "tools" in processed.column_names:
        logger.info("Final average tools per sample: %.1f", _avg_tool_count(processed))

    return processed
@@ -35,6 +35,7 @@ class MetricsSender:
35
35
  self,
36
36
  endpoint: str,
37
37
  api_key: str | None,
38
+ pipeline_id: str | None = None,
38
39
  batch_size: int = 10,
39
40
  flush_interval: float = 5.0,
40
41
  max_queue_size: int = 1000,
@@ -45,6 +46,7 @@ class MetricsSender:
45
46
  Args:
46
47
  endpoint: Base URL for the DeepFabric API
47
48
  api_key: API key for authentication (None disables sending)
49
+ pipeline_id: Pipeline ID to associate training runs with (required)
48
50
  batch_size: Number of metrics to batch before sending
49
51
  flush_interval: Seconds between automatic flushes
50
52
  max_queue_size: Maximum queue size (overflow drops metrics)
@@ -52,12 +54,14 @@ class MetricsSender:
52
54
  """
53
55
  self.endpoint = endpoint.rstrip("/")
54
56
  self.api_key = api_key
57
+ self.pipeline_id = pipeline_id
55
58
  self.batch_size = batch_size
56
59
  self.flush_interval = flush_interval
57
60
  self.timeout = timeout
58
61
 
59
62
  self._queue: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=max_queue_size)
60
63
  self._stop_event = threading.Event()
64
+ self._flush_event = threading.Event()
61
65
  self._enabled = api_key is not None
62
66
 
63
67
  # Start background sender thread
@@ -177,19 +181,25 @@ class MetricsSender:
177
181
  should_flush = (
178
182
  len(batch) >= self.batch_size
179
183
  or (time.monotonic() - last_flush) >= self.flush_interval
184
+ or self._flush_event.is_set()
180
185
  )
181
186
 
182
187
  if should_flush:
183
188
  self._flush_batch(batch)
184
189
  batch = []
185
190
  last_flush = time.monotonic()
191
+ self._flush_event.clear()
186
192
 
187
193
  except queue.Empty:
188
- # Timeout - flush if we have pending items
189
- if batch and (time.monotonic() - last_flush) >= self.flush_interval:
194
+ # Timeout - flush if we have pending items or flush requested
195
+ if batch and (
196
+ (time.monotonic() - last_flush) >= self.flush_interval
197
+ or self._flush_event.is_set()
198
+ ):
190
199
  self._flush_batch(batch)
191
200
  batch = []
192
201
  last_flush = time.monotonic()
202
+ self._flush_event.clear()
193
203
 
194
204
  # On shutdown, drain the queue and flush everything
195
205
  while not self._queue.empty():
@@ -209,21 +219,34 @@ class MetricsSender:
209
219
  if not batch or not self._enabled:
210
220
  return
211
221
 
222
+ if not self.pipeline_id:
223
+ logger.debug("No pipeline_id set, skipping metrics send")
224
+ return
225
+
212
226
  # Separate events and metrics
213
- events = [item for item in batch if item["type"] != "metrics"]
227
+ run_start_events = [item for item in batch if item["type"] == "run_start"]
228
+ run_end_events = [item for item in batch if item["type"] == "run_end"]
214
229
  metrics = [item["data"] for item in batch if item["type"] == "metrics"]
215
230
 
216
- # Send events first (run_start, run_end)
217
- for event in events:
218
- self._send_to_api(
219
- endpoint=f"{self.endpoint}/v1/training/runs",
220
- payload={"event_type": event["type"], **event["data"]},
221
- )
231
+ # Build query string with pipeline_id
232
+ query = f"?pipeline_id={self.pipeline_id}"
233
+
234
+ def send_run_events(events: list[dict[str, Any]]) -> None:
235
+ """Send run start/end events."""
236
+ for event in events:
237
+ self._send_to_api(
238
+ endpoint=f"{self.endpoint}/api/v1/training/runs{query}",
239
+ payload={"event_type": event["type"], **event["data"]},
240
+ )
241
+
242
+ # Send run events, ensuring start events are processed before end events
243
+ send_run_events(run_start_events)
244
+ send_run_events(run_end_events)
222
245
 
223
246
  # Send metrics batch
224
247
  if metrics:
225
248
  self._send_to_api(
226
- endpoint=f"{self.endpoint}/v1/training/metrics",
249
+ endpoint=f"{self.endpoint}/api/v1/training/metrics{query}",
227
250
  payload={"metrics": metrics},
228
251
  )
229
252
  self._metrics_sent += len(metrics)
@@ -252,22 +275,27 @@ class MetricsSender:
252
275
 
253
276
  if not response.ok:
254
277
  self._send_errors += 1
255
- logger.debug(f"API request failed: {response.status_code} {response.text[:100]}")
278
+ logger.warning(
279
+ "API error: %s %s (endpoint: %s)",
280
+ response.status_code,
281
+ response.text[:200],
282
+ endpoint,
283
+ )
256
284
  return False
257
285
 
258
286
  except requests.exceptions.Timeout:
259
287
  self._send_errors += 1
260
- logger.debug("API request timed out")
288
+ logger.warning("Request timed out: %s", endpoint)
261
289
  return False
262
290
 
263
- except requests.exceptions.ConnectionError:
291
+ except requests.exceptions.ConnectionError as e:
264
292
  self._send_errors += 1
265
- logger.debug("API connection error")
293
+ logger.warning("Connection error: %s (endpoint: %s)", e, endpoint)
266
294
  return False
267
295
 
268
296
  except requests.exceptions.RequestException as e:
269
297
  self._send_errors += 1
270
- logger.debug(f"API request error: {e}")
298
+ logger.warning("Request error: %s (endpoint: %s)", e, endpoint)
271
299
  return False
272
300
 
273
301
  else:
@@ -282,8 +310,14 @@ class MetricsSender:
282
310
  if not self._enabled:
283
311
  return
284
312
 
313
+ # Signal the background thread to flush its current batch
314
+ self._flush_event.set()
315
+
285
316
  start = time.monotonic()
286
- while not self._queue.empty() and (time.monotonic() - start) < timeout:
317
+ # Wait for queue to empty and flush event to be cleared (indicates batch was sent)
318
+ while (time.monotonic() - start) < timeout:
319
+ if self._queue.empty() and not self._flush_event.is_set():
320
+ break
287
321
  time.sleep(0.1)
288
322
 
289
323
  def shutdown(self) -> None:
deepfabric/tui.py CHANGED
@@ -41,14 +41,22 @@ class TopicBuildingMixin:
41
41
 
42
42
  Subclasses must have these attributes:
43
43
  - tui: DeepFabricTUI instance
44
+ - live_display: Live | None
44
45
  - live_layout: Layout | None
45
46
  - events_log: deque
46
47
  """
47
48
 
48
49
  tui: "DeepFabricTUI"
50
+ live_display: "Live | None"
49
51
  live_layout: "Layout | None"
50
52
  events_log: "deque"
51
53
 
54
def stop_live(self) -> None:
    """Stop the Live display if it's running."""
    display = self.live_display
    if display:
        display.stop()
        # Clear the reference so a second call is a no-op.
        self.live_display = None
59
+
52
60
  def _refresh_left(self) -> None:
53
61
  """Update events panel in left column."""
54
62
  if self.live_layout is not None:
@@ -910,7 +918,7 @@ class DatasetGenerationTUI(StreamObserver):
910
918
  # Map conversation types to friendly names
911
919
  type_map = {
912
920
  "basic": "Basic Q&A",
913
- "chain_of_thought": "Chain of Thought",
921
+ "cot": "Chain of Thought",
914
922
  "single_turn_agent": "Single-Turn Agent (Tool Calling)",
915
923
  "multi_turn_agent": "Multi-Turn Agent (Tool Calling)",
916
924
  }
deepfabric/utils.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import ast
2
2
  import asyncio
3
3
  import json
4
+ import os
4
5
  import re
5
6
 
6
7
  VALIDATION_ERROR_INDICATORS = [
@@ -147,4 +148,17 @@ def read_topic_tree_from_jsonl(file_path: str) -> list[dict]:
147
148
  with open(file_path) as file:
148
149
  for line in file:
149
150
  topic_tree.append(json.loads(line.strip()))
151
+
150
152
  return topic_tree
153
+
154
+
155
def get_bool_env(key: str, default: bool = False) -> bool:
    """Get a boolean environment variable.

    Supports: '1', 'true', 'yes', 'on' (case-insensitive) as True.
    Everything else is False unless default is True and key is missing.
    """
    raw = os.getenv(key)
    # A missing variable falls back to the caller's default; any present
    # value is interpreted, with unrecognized strings treated as False.
    if raw is None:
        return default
    return raw.lower() in {"1", "true", "yes", "on"}
deepfabric/validation.py CHANGED
@@ -79,7 +79,7 @@ def validate_path_requirements(
79
79
  for steps, batch in optimal_combinations[:3]: # Show top 3
80
80
  total_samples = steps * batch
81
81
  recommendations.append(
82
- f" --num-steps {steps} --batch-size {batch} (generates {total_samples} samples)"
82
+ f" --num-samples {steps} --batch-size {batch} (generates {total_samples} samples)"
83
83
  )
84
84
 
85
85
  recommendations.extend(