DeepFabric 4.5.1-py3-none-any.whl → 4.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepfabric/graph.py CHANGED
@@ -1,7 +1,10 @@
  import asyncio
+ import hashlib
  import json
  import textwrap
+ import uuid

+ from datetime import datetime, timezone
  from typing import TYPE_CHECKING, Any

  from pydantic import BaseModel, ConfigDict, Field
@@ -69,6 +72,15 @@ class GraphConfig(BaseModel):
      )


+ class GraphMetadata(BaseModel):
+     """Metadata for the entire graph for provenance tracking."""
+
+     provider: str = Field(..., description="LLM provider used (e.g., openai, ollama)")
+     model: str = Field(..., description="Model name used (e.g., gpt-4o)")
+     temperature: float = Field(..., description="Temperature setting used for generation")
+     created_at: str = Field(..., description="ISO 8601 timestamp when graph was created")
+
+
  class NodeModel(BaseModel):
      """Pydantic model for a node in the graph."""

@@ -84,6 +96,9 @@ class GraphModel(BaseModel):

      nodes: dict[int, NodeModel]
      root_id: int
+     metadata: GraphMetadata | None = Field(
+         default=None, description="Graph-level metadata for provenance tracking"
+     )


  class Node:
@@ -96,6 +111,14 @@ class Node:
          self.parents: list[Node] = []
          self.metadata: dict[str, Any] = metadata.copy() if metadata is not None else {}

+         # Auto-generate uuid if not present (stable node identification)
+         if "uuid" not in self.metadata:
+             self.metadata["uuid"] = str(uuid.uuid4())
+
+         # Auto-generate topic_hash if not present (duplicate detection via SHA256)
+         if "topic_hash" not in self.metadata:
+             self.metadata["topic_hash"] = hashlib.sha256(topic.encode("utf-8")).hexdigest()
+
      def to_pydantic(self) -> NodeModel:
          """Converts the runtime Node to its Pydantic model representation."""
          return NodeModel(
@@ -140,6 +163,9 @@ class Graph(TopicModel):
          # Progress reporter for streaming feedback (set by topic_manager)
          self.progress_reporter: ProgressReporter | None = None

+         # Store creation timestamp for provenance tracking
+         self.created_at: datetime = datetime.now(timezone.utc)
+
          trace(
              "graph_created",
              {
@@ -181,6 +207,12 @@ class Graph(TopicModel):
          return GraphModel(
              nodes={node_id: node.to_pydantic() for node_id, node in self.nodes.items()},
              root_id=self.root.id,
+             metadata=GraphMetadata(
+                 provider=self.provider,
+                 model=self.model_name,
+                 temperature=self.temperature,
+                 created_at=self.created_at.isoformat(),
+             ),
          )

      def to_json(self) -> str:
@@ -203,6 +235,12 @@ class Graph(TopicModel):
          graph = cls(**params)
          graph.nodes = {}

+         # Restore original creation timestamp if present in the loaded graph
+         if graph_model.metadata and graph_model.metadata.created_at:
+             # Handle 'Z' suffix for Python < 3.11 compatibility
+             created_at_str = graph_model.metadata.created_at.replace("Z", "+00:00")
+             graph.created_at = datetime.fromisoformat(created_at_str)
+
          # Create nodes
          for node_model in graph_model.nodes.values():
              node = Node(node_model.topic, node_model.id, node_model.metadata)
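
With the GraphMetadata model above, a serialized graph now carries provider, model, temperature, and created_at alongside nodes and root_id. A minimal sketch of reading that provenance back from a saved graph file; the file name is a placeholder and only the field names come from the models in this diff:

# Hedged sketch: inspect graph-level provenance in a saved topic graph.
# "topic_graph.json" is an illustrative path; "metadata" may be absent (None)
# for graphs written by versions earlier than 4.7.0.
import json

with open("topic_graph.json", encoding="utf-8") as f:
    graph_model = json.load(f)

meta = graph_model.get("metadata")
if meta:
    print(meta["provider"], meta["model"], meta["temperature"], meta["created_at"])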
deepfabric/loader.py ADDED
@@ -0,0 +1,554 @@
+ """Dataset loading functionality for DeepFabric.
+
+ This module provides the load_dataset function for loading datasets from:
+ - Local text files (with line/paragraph/document sampling)
+ - Local JSONL files
+ - DeepFabric Cloud (via namespace/slug format)
+ """
+
+ import json
+ import re
+ import warnings
+
+ from http import HTTPStatus
+ from pathlib import Path
+ from typing import Any, Literal
+
+ import httpx
+
+ from .auth import DEFAULT_API_URL, get_stored_token
+ from .dataset import Dataset, DatasetDict
+ from .exceptions import LoaderError
+
+ # Default cache directory for cloud datasets
+ DEFAULT_CACHE_DIR = Path.home() / ".cache" / "deepfabric" / "datasets"
+
+
+ def _detect_source(path: str) -> Literal["local", "cloud"]:
+     """Detect if path refers to local files or DeepFabric Cloud.
+
+     Args:
+         path: The path argument to load_dataset
+
+     Returns:
+         "local" for local file loading, "cloud" for cloud loading
+     """
+     # Known local format types
+     if path.lower() in ("text", "json", "jsonl", "csv"):
+         return "local"
+
+     # Cloud pattern: namespace/slug (alphanumeric with hyphens/underscores, single slash)
+     cloud_pattern = re.compile(r"^[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$")
+     if cloud_pattern.match(path) and not Path(path).exists():
+         return "cloud"
+
+     # Default to local (file path)
+     return "local"
+
+
+ def _read_text_file(file_path: str, sample_by: str) -> list[str]:
+     """Read text file with specified sampling strategy.
+
+     Args:
+         file_path: Path to the text file
+         sample_by: Sampling strategy - "line", "paragraph", or "document"
+
+     Returns:
+         List of text samples
+     """
+     with open(file_path, encoding="utf-8") as f:
+         content = f.read()
+
+     if sample_by == "document":
+         # Entire file as one sample
+         return [content.strip()] if content.strip() else []
+     if sample_by == "paragraph":
+         # Split on double newlines (paragraph boundaries)
+         paragraphs = re.split(r"\n\s*\n", content)
+         return [p.strip() for p in paragraphs if p.strip()]
+     # Default: line-by-line
+     return [line.strip() for line in content.split("\n") if line.strip()]
+
+
+ def _normalize_data_files(
+     data_files: dict[str, str | list[str]] | str | list[str] | None,
+     data_dir: str | None,
+ ) -> dict[str, list[str]]:
+     """Normalize data_files to a consistent dict format.
+
+     Args:
+         data_files: Input data_files in various formats
+         data_dir: Optional directory prefix
+
+     Returns:
+         Dict mapping split names to lists of file paths
+     """
+     if data_files is None:
+         return {}
+
+     # Normalize to dict format
+     if isinstance(data_files, str):
+         files_dict: dict[str, list[str]] = {"train": [data_files]}
+     elif isinstance(data_files, list):
+         files_dict = {"train": data_files}
+     else:
+         files_dict = {k: [v] if isinstance(v, str) else list(v) for k, v in data_files.items()}
+
+     # Apply data_dir prefix if provided
+     if data_dir:
+         files_dict = {k: [str(Path(data_dir) / f) for f in v] for k, v in files_dict.items()}
+
+     return files_dict
+
+
+ def _load_text_files(
+     data_files: dict[str, str | list[str]] | str | list[str] | None,
+     data_dir: str | None,
+     sample_by: str,
+ ) -> Dataset | DatasetDict:
+     """Load text files into Dataset.
+
+     Args:
+         data_files: File paths specification
+         data_dir: Optional directory prefix
+         sample_by: Sampling strategy
+
+     Returns:
+         Dataset or DatasetDict
+     """
+     files_dict = _normalize_data_files(data_files, data_dir)
+
+     if not files_dict:
+         raise LoaderError("No data files specified for text format")
+
+     # Load each split
+     datasets: dict[str, Dataset] = {}
+     for split_name, file_list in files_dict.items():
+         samples: list[dict[str, Any]] = []
+         for file_path in file_list:
+             if not Path(file_path).exists():
+                 raise LoaderError(f"File not found: {file_path}")
+             texts = _read_text_file(file_path, sample_by)
+             samples.extend([{"text": t} for t in texts])
+         datasets[split_name] = Dataset(samples, {"source": "text", "split": split_name})
+
+     # Return single Dataset if only train split
+     if len(datasets) == 1 and "train" in datasets:
+         return datasets["train"]
+     return DatasetDict(datasets)
+
+
+ def _load_json_file(file_path: str) -> list[dict[str, Any]]:
+     """Load a standard JSON file (array of objects).
+
+     Args:
+         file_path: Path to the JSON file
+
+     Returns:
+         List of sample dictionaries
+
+     Raises:
+         LoaderError: If the file is not valid JSON or not an array of objects
+     """
+     with open(file_path, encoding="utf-8") as f:
+         try:
+             data = json.load(f)
+         except json.JSONDecodeError as e:
+             raise LoaderError(
+                 f"Invalid JSON in {file_path}: {e}",
+                 context={"file": file_path},
+             ) from e
+
+     if isinstance(data, list):
+         # Validate all items are dicts
+         for i, item in enumerate(data):
+             if not isinstance(item, dict):
+                 raise LoaderError(
+                     f"Expected array of objects in {file_path}, "
+                     f"but item at index {i} is {type(item).__name__}",
+                     context={"file": file_path, "index": i},
+                 )
+         return data
+     if isinstance(data, dict):
+         # Single object - wrap in list
+         return [data]
+     raise LoaderError(
+         f"Expected JSON array or object in {file_path}, got {type(data).__name__}",
+         context={"file": file_path},
+     )
+
+
+ def _load_jsonl_file(file_path: str) -> list[dict[str, Any]]:
+     """Load a JSONL file (one JSON object per line).
+
+     Args:
+         file_path: Path to the JSONL file
+
+     Returns:
+         List of sample dictionaries
+     """
+     samples: list[dict[str, Any]] = []
+     with open(file_path, encoding="utf-8") as f:
+         for line_num, line in enumerate(f, 1):
+             if not line.strip():
+                 continue
+             try:
+                 samples.append(json.loads(line))
+             except json.JSONDecodeError as e:
+                 raise LoaderError(
+                     f"Invalid JSON on line {line_num} of {file_path}: {e}",
+                     context={"file": file_path, "line": line_num},
+                 ) from e
+     return samples
+
+
+ def _load_json_or_jsonl_file(file_path: str) -> list[dict[str, Any]]:
+     """Load a JSON or JSONL file based on extension.
+
+     Args:
+         file_path: Path to the file
+
+     Returns:
+         List of sample dictionaries
+     """
+     path = Path(file_path)
+     if path.suffix == ".jsonl":
+         return _load_jsonl_file(file_path)
+     # .json files are standard JSON (array or object)
+     return _load_json_file(file_path)
+
+
+ def _load_jsonl_files(
+     data_files: dict[str, str | list[str]] | str | list[str] | None,
+     data_dir: str | None,
+ ) -> Dataset | DatasetDict:
+     """Load JSON/JSONL files into Dataset.
+
+     Args:
+         data_files: File paths specification
+         data_dir: Optional directory prefix
+
+     Returns:
+         Dataset or DatasetDict
+     """
+     files_dict = _normalize_data_files(data_files, data_dir)
+
+     if not files_dict:
+         raise LoaderError("No data files specified for json/jsonl format")
+
+     datasets: dict[str, Dataset] = {}
+     for split_name, file_list in files_dict.items():
+         samples: list[dict[str, Any]] = []
+         for file_path in file_list:
+             if not Path(file_path).exists():
+                 raise LoaderError(f"File not found: {file_path}")
+             samples.extend(_load_json_or_jsonl_file(file_path))
+         datasets[split_name] = Dataset(samples, {"source": "json", "split": split_name})
+
+     if len(datasets) == 1 and "train" in datasets:
+         return datasets["train"]
+     return DatasetDict(datasets)
+
+
+ def _load_from_directory(
+     data_dir: str,
+     sample_by: str,
+ ) -> Dataset | DatasetDict:
+     """Load all files from a directory.
+
+     Args:
+         data_dir: Directory path
+         sample_by: Sampling strategy for text files
+
+     Returns:
+         Dataset or DatasetDict
+     """
+     dir_path = Path(data_dir)
+     if not dir_path.is_dir():
+         raise LoaderError(f"Directory not found: {data_dir}")
+
+     # Find all supported files
+     text_files = list(dir_path.glob("*.txt"))
+     jsonl_files = list(dir_path.glob("*.jsonl"))
+     json_files = list(dir_path.glob("*.json"))
+
+     if not text_files and not jsonl_files and not json_files:
+         raise LoaderError(f"No .txt, .json, or .jsonl files found in {data_dir}")
+
+     samples: list[dict[str, Any]] = []
+
+     # Load text files
+     for file_path in text_files:
+         texts = _read_text_file(str(file_path), sample_by)
+         samples.extend([{"text": t} for t in texts])
+
+     # Load JSONL files (one JSON object per line)
+     for file_path in jsonl_files:
+         samples.extend(_load_jsonl_file(str(file_path)))
+
+     # Load JSON files (array of objects or single object)
+     for file_path in json_files:
+         samples.extend(_load_json_file(str(file_path)))
+
+     return Dataset(samples, {"source": "directory", "path": data_dir})
+
+
+ def _get_cache_path(namespace: str, slug: str, cache_dir: Path) -> Path:
+     """Get cache file path for a cloud dataset.
+
+     Args:
+         namespace: Dataset namespace
+         slug: Dataset slug
+         cache_dir: Base cache directory
+
+     Returns:
+         Path to cached JSONL file
+     """
+     return cache_dir / f"{namespace}_{slug}.jsonl"
+
+
+ def _load_from_cloud(
+     path: str,
+     split: str | None,
+     token: str | None,
+     api_url: str | None,
+     use_cache: bool,
+     streaming: bool,
+ ) -> Dataset:
+     """Load dataset from DeepFabric Cloud.
+
+     Args:
+         path: Dataset path in "namespace/slug" format
+         split: Optional split name (not used yet, reserved for future)
+         token: Optional auth token (uses stored token if not provided)
+         api_url: Optional API URL (uses default if not provided)
+         use_cache: Whether to use/store cached data
+         streaming: Whether to stream the dataset (not yet implemented on client side)
+
+     Returns:
+         Dataset loaded from cloud
+     """
+     # TODO: Implement streaming using Parquet shards endpoint
+     # Backend supports: GET /api/v1/datasets/{id}/parquet (manifest)
+     # and GET /api/v1/datasets/{id}/parquet/{filename} (shard with Range support)
+     if streaming:
+         warnings.warn(
+             "streaming=True is not yet implemented. "
+             "Falling back to loading entire dataset into memory. "
+             "For large datasets, this may cause memory issues.",
+             UserWarning,
+             stacklevel=3,
+         )
+
+     # TODO: Implement server-side split support
+     # For now, use dataset.split() after loading
+     if split:
+         warnings.warn(
+             f"split='{split}' is not yet implemented for cloud datasets. "
+             "Use dataset.split() after loading instead.",
+             UserWarning,
+             stacklevel=3,
+         )
+
+     # Parse namespace/slug
+     parts = path.split("/")
+     if len(parts) != 2:  # noqa: PLR2004
+         raise LoaderError(
+             f"Invalid cloud path format: {path}. Expected 'namespace/slug'.",
+             context={"path": path},
+         )
+     namespace, slug = parts
+
+     effective_token = token or get_stored_token()
+     effective_api_url = api_url or DEFAULT_API_URL
+
+     # Check cache first if enabled
+     cache_path = _get_cache_path(namespace, slug, DEFAULT_CACHE_DIR)
+     if use_cache and cache_path.exists():
+         return Dataset.from_jsonl(str(cache_path))
+
+     # Build request headers
+     headers: dict[str, str] = {}
+     if effective_token:
+         headers["Authorization"] = f"Bearer {effective_token}"
+
+     # Fetch from API
+     try:
+         with httpx.Client() as client:
+             response = client.get(
+                 f"{effective_api_url}/api/v1/datasets/by-slug/{namespace}/{slug}/with-samples",
+                 headers=headers,
+                 timeout=120.0,
+             )
+             response.raise_for_status()
+             data = response.json()
+     except httpx.HTTPStatusError as e:
+         if e.response.status_code == HTTPStatus.NOT_FOUND:
+             raise LoaderError(
+                 f"Dataset not found: {path}",
+                 context={"namespace": namespace, "slug": slug},
+             ) from e
+         if e.response.status_code == HTTPStatus.UNAUTHORIZED:
+             raise LoaderError(
+                 f"Authentication required for dataset: {path}. "
+                 "Run 'deepfabric auth login' or pass token parameter.",
+                 context={"namespace": namespace, "slug": slug},
+             ) from e
+         if e.response.status_code == HTTPStatus.FORBIDDEN:
+             raise LoaderError(
+                 f"Access denied for dataset: {path}. "
+                 "You may not have permission to access this private dataset.",
+                 context={"namespace": namespace, "slug": slug},
+             ) from e
+         raise LoaderError(
+             f"Failed to load dataset from cloud: {e}",
+             context={"path": path, "status_code": e.response.status_code},
+         ) from e
+     except httpx.RequestError as e:
+         raise LoaderError(
+             f"Network error while loading dataset: {e}",
+             context={"path": path},
+         ) from e
+
+     # Extract samples from response
+     samples = data.get("samples", [])
+     if not samples:
+         raise LoaderError(
+             f"Dataset is empty: {path}",
+             context={"namespace": namespace, "slug": slug},
+         )
+
+     # Extract sample data - API may wrap in {"data": ...} format
+     sample_data: list[dict[str, Any]] = []
+     for sample in samples:
+         if isinstance(sample, dict) and "data" in sample:
+             sample_data.append(sample["data"])
+         else:
+             sample_data.append(sample)
+
+     dataset = Dataset(sample_data, {"source": "cloud", "path": path})
+
+     # Cache the dataset by default
+     if use_cache:
+         cache_path.parent.mkdir(parents=True, exist_ok=True)
+         dataset.to_jsonl(str(cache_path))
+
+     return dataset
+
+
+ def load_dataset(
+     path: str,
+     *,
+     data_files: dict[str, str | list[str]] | str | list[str] | None = None,
+     data_dir: str | None = None,
+     split: str | None = None,
+     sample_by: Literal["line", "paragraph", "document"] = "line",
+     use_cache: bool = True,
+     token: str | None = None,
+     api_url: str | None = None,
+     streaming: bool = False,
+ ) -> Dataset | DatasetDict:
+     """Load a dataset from local files or DeepFabric Cloud.
+
+     Args:
+         path: Dataset path. Can be:
+             - "text" for local text files (requires data_files or data_dir)
+             - "json" or "jsonl" for local JSON/JSONL files
+             - "namespace/slug" for DeepFabric Cloud datasets
+             - Direct file path (e.g., "data.jsonl")
+         data_files: File paths for local loading. Can be:
+             - Single path: "train.txt"
+             - List of paths: ["file1.txt", "file2.txt"]
+             - Dict for splits: {"train": "train.txt", "test": "test.txt"}
+         data_dir: Directory containing data files, or directory to load from
+         split: Which split to load ("train", "test", "validation")
+         sample_by: How to sample text files:
+             - "line": Each line is a sample (default)
+             - "paragraph": Split on double newlines
+             - "document": Entire file is one sample
+         use_cache: Cache cloud datasets locally (default True).
+             Cache location: ~/.cache/deepfabric/datasets/
+         token: DeepFabric Cloud auth token (defaults to stored token)
+         api_url: DeepFabric API URL (defaults to production)
+         streaming: If True, return an iterable dataset (cloud only, not yet implemented)
+
+     Returns:
+         Dataset or DatasetDict depending on input structure
+
+     Raises:
+         LoaderError: If loading fails (file not found, invalid format, auth failure, etc.)
+
+     Examples:
+         >>> from deepfabric import load_dataset
+         >>>
+         >>> # Load from local text files
+         >>> ds = load_dataset("text", data_files={"train": "train.txt", "test": "test.txt"})
+         >>>
+         >>> # Load with paragraph sampling
+         >>> ds = load_dataset("text", data_files="my_text.txt", sample_by="paragraph")
+         >>>
+         >>> # Load from DeepFabric Cloud
+         >>> ds = load_dataset("username/my-dataset")
+         >>>
+         >>> # Access data
+         >>> messages = ds["messages"]
+         >>> first_sample = ds[0]
+         >>>
+         >>> # Split into train/test
+         >>> splits = ds.split(test_size=0.1, seed=42)
+     """
+     source = _detect_source(path)
+
+     if source == "cloud":
+         return _load_from_cloud(path, split, token, api_url, use_cache, streaming)
+
+     # Local loading
+     if path.lower() == "text":
+         if not data_files and not data_dir:
+             raise LoaderError(
+                 "text format requires data_files or data_dir parameter",
+                 context={"path": path},
+             )
+         if data_dir and not data_files:
+             dataset = _load_from_directory(data_dir, sample_by)
+         else:
+             dataset = _load_text_files(data_files, data_dir, sample_by)
+
+     elif path.lower() in ("json", "jsonl"):
+         if not data_files and not data_dir:
+             raise LoaderError(
+                 f"{path} format requires data_files or data_dir parameter",
+                 context={"path": path},
+             )
+         if data_dir and not data_files:
+             dataset = _load_from_directory(data_dir, sample_by)
+         else:
+             dataset = _load_jsonl_files(data_files, data_dir)
+
+     else:
+         # Assume it's a direct file path
+         file_path = Path(path)
+         if file_path.is_file():
+             if file_path.suffix in (".jsonl", ".json"):
+                 dataset = _load_jsonl_files(str(file_path), None)
+             else:
+                 dataset = _load_text_files(str(file_path), None, sample_by)
+         elif file_path.is_dir():
+             dataset = _load_from_directory(str(file_path), sample_by)
+         else:
+             raise LoaderError(
+                 f"Path not found: {path}",
+                 context={"path": path},
+             )
+
+     # Handle split parameter for DatasetDict
+     if split is not None and isinstance(dataset, DatasetDict):
+         if split not in dataset:
+             available = list(dataset.keys())
+             raise LoaderError(
+                 f"Split '{split}' not found. Available splits: {available}",
+                 context={"path": path, "split": split, "available": available},
+             )
+         return dataset[split]
+
+     return dataset
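
The new loader gives DeepFabric a datasets-style entry point. A short usage sketch under the assumptions visible in this diff: load_dataset is importable from the package root (as in its docstring) and LoaderError lives in deepfabric.exceptions, per the loader's own import; file names and the namespace/slug are placeholders.

# Hedged sketch: local JSONL splits plus a cloud dataset, with error handling.
from deepfabric import load_dataset
from deepfabric.exceptions import LoaderError  # assumed import path, per loader.py

try:
    # Local JSONL files mapped to named splits; split="train" selects one back out.
    train = load_dataset(
        "jsonl",
        data_files={"train": "train.jsonl", "test": "test.jsonl"},
        split="train",
    )
    # Cloud dataset by namespace/slug; cached under ~/.cache/deepfabric/datasets/.
    cloud = load_dataset("someuser/some-dataset", use_cache=True)
except LoaderError as exc:
    print(f"Dataset loading failed: {exc}")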
deepfabric/schemas.py CHANGED
@@ -708,7 +708,7 @@ class FreeTextCoT(BaseModel):
      """Chain of Thought dataset in free-text format (GSM8K style)."""

      question: str = Field(description="The question or problem to solve")
-     chain_of_thought: str = Field(description="Natural language reasoning explanation")
+     cot: str = Field(description="Natural language reasoning explanation")
      final_answer: str = Field(description="The definitive answer to the question")


@@ -726,7 +726,7 @@ class HybridCoT(BaseModel):
      """Chain of Thought dataset with both free-text and structured reasoning."""

      question: str = Field(description="The question or problem to solve")
-     chain_of_thought: str = Field(description="Natural language reasoning explanation")
+     cot: str = Field(description="Natural language reasoning explanation")
      reasoning_trace: list[ReasoningStep] = Field(
          description="Structured reasoning steps", min_length=1
      )
@@ -780,7 +780,7 @@ class MathematicalAnswerMixin:

  # Capability Models for Composable Conversation Schema
  class ReasoningTrace(BaseModel):
-     """Reasoning capability - present when conversation_type='chain_of_thought'."""
+     """Reasoning capability - present when conversation_type='cot'."""

      style: Literal["freetext", "agent"] = Field(
          description="The reasoning style: freetext (natural language) or agent (structured step-by-step for tool-calling)"
@@ -929,7 +929,7 @@ class FormattedSample(BaseModel):
  # Unified conversation schema mapping
  CONVERSATION_SCHEMAS = {
      "basic": Conversation,
-     "chain_of_thought": Conversation,
+     "cot": Conversation,
  }


@@ -943,7 +943,7 @@ def get_conversation_schema(
      populated based on the configuration during generation.

      Args:
-         conversation_type: Type of conversation (basic, chain_of_thought)
+         conversation_type: Type of conversation (basic, cot)

      Returns:
          Conversation schema (unified for all types)
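
The chain_of_thought field and the matching conversation_type value are renamed to cot across these schemas, so code that builds CoT samples directly must follow suit. A minimal sketch against the renamed field; the import path is assumed from the file location deepfabric/schemas.py:

# Hedged sketch: constructing a free-text CoT sample with the renamed field.
from deepfabric.schemas import FreeTextCoT  # assumed importable as shown in this diff

sample = FreeTextCoT(
    question="What is 7 * 8?",
    cot="Multiply 7 by 8 to get 56.",
    final_answer="56",
)
print(sample.model_dump())  # {'question': ..., 'cot': ..., 'final_answer': ...}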
@@ -97,6 +97,8 @@ async def _process_graph_events(graph: Graph, debug: bool = False) -> dict | None:
                  get_tui().error(f" [{idx}] Node ID: {node_id}, Attempts: {attempts}")
                  get_tui().error(f" Error: {last_error}")
      except Exception as e:
+         # Stop TUI before printing error to ensure visibility
+         tui.stop_live()
          if debug:
              get_tui().error(f"Debug: Full traceback:\n{traceback.format_exc()}")
          get_tui().error(f"Graph build failed: {str(e)}")
@@ -147,6 +149,8 @@ async def _process_tree_events(tree: Tree, debug: bool = False) -> dict | None:
                  )
                  get_tui().error(f" Error: {failure.get('error', 'Unknown error')}")
      except Exception as e:
+         # Stop TUI before printing error to ensure visibility
+         tui.stop_live()
          if debug:
              get_tui().error(f"Debug: Full traceback:\n{traceback.format_exc()}")
          get_tui().error(f"Tree build failed: {str(e)}")