DeepFabric 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. deepfabric/__init__.py +70 -0
  2. deepfabric/__main__.py +6 -0
  3. deepfabric/auth.py +382 -0
  4. deepfabric/builders.py +303 -0
  5. deepfabric/builders_agent.py +1304 -0
  6. deepfabric/cli.py +1288 -0
  7. deepfabric/config.py +899 -0
  8. deepfabric/config_manager.py +251 -0
  9. deepfabric/constants.py +94 -0
  10. deepfabric/dataset_manager.py +534 -0
  11. deepfabric/error_codes.py +581 -0
  12. deepfabric/evaluation/__init__.py +47 -0
  13. deepfabric/evaluation/backends/__init__.py +32 -0
  14. deepfabric/evaluation/backends/ollama_backend.py +137 -0
  15. deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
  16. deepfabric/evaluation/backends/transformers_backend.py +326 -0
  17. deepfabric/evaluation/evaluator.py +845 -0
  18. deepfabric/evaluation/evaluators/__init__.py +13 -0
  19. deepfabric/evaluation/evaluators/base.py +104 -0
  20. deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
  21. deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
  22. deepfabric/evaluation/evaluators/registry.py +66 -0
  23. deepfabric/evaluation/inference.py +155 -0
  24. deepfabric/evaluation/metrics.py +397 -0
  25. deepfabric/evaluation/parser.py +304 -0
  26. deepfabric/evaluation/reporters/__init__.py +13 -0
  27. deepfabric/evaluation/reporters/base.py +56 -0
  28. deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
  29. deepfabric/evaluation/reporters/file_reporter.py +61 -0
  30. deepfabric/evaluation/reporters/multi_reporter.py +56 -0
  31. deepfabric/exceptions.py +67 -0
  32. deepfabric/factory.py +26 -0
  33. deepfabric/generator.py +1084 -0
  34. deepfabric/graph.py +545 -0
  35. deepfabric/hf_hub.py +214 -0
  36. deepfabric/kaggle_hub.py +219 -0
  37. deepfabric/llm/__init__.py +41 -0
  38. deepfabric/llm/api_key_verifier.py +534 -0
  39. deepfabric/llm/client.py +1206 -0
  40. deepfabric/llm/errors.py +105 -0
  41. deepfabric/llm/rate_limit_config.py +262 -0
  42. deepfabric/llm/rate_limit_detector.py +278 -0
  43. deepfabric/llm/retry_handler.py +270 -0
  44. deepfabric/metrics.py +212 -0
  45. deepfabric/progress.py +262 -0
  46. deepfabric/prompts.py +290 -0
  47. deepfabric/schemas.py +1000 -0
  48. deepfabric/spin/__init__.py +6 -0
  49. deepfabric/spin/client.py +263 -0
  50. deepfabric/spin/models.py +26 -0
  51. deepfabric/stream_simulator.py +90 -0
  52. deepfabric/tools/__init__.py +5 -0
  53. deepfabric/tools/defaults.py +85 -0
  54. deepfabric/tools/loader.py +87 -0
  55. deepfabric/tools/mcp_client.py +677 -0
  56. deepfabric/topic_manager.py +303 -0
  57. deepfabric/topic_model.py +20 -0
  58. deepfabric/training/__init__.py +35 -0
  59. deepfabric/training/api_key_prompt.py +302 -0
  60. deepfabric/training/callback.py +363 -0
  61. deepfabric/training/metrics_sender.py +301 -0
  62. deepfabric/tree.py +438 -0
  63. deepfabric/tui.py +1267 -0
  64. deepfabric/update_checker.py +166 -0
  65. deepfabric/utils.py +150 -0
  66. deepfabric/validation.py +143 -0
  67. deepfabric-4.4.0.dist-info/METADATA +702 -0
  68. deepfabric-4.4.0.dist-info/RECORD +71 -0
  69. deepfabric-4.4.0.dist-info/WHEEL +4 -0
  70. deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
  71. deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/schemas.py ADDED
@@ -0,0 +1,1000 @@
1
+ import ast
2
+ import json
3
+ import logging
4
+ import re
5
+ import secrets
6
+ import string
7
+
8
+ from decimal import ROUND_HALF_UP, Decimal
9
+ from typing import Annotated, Any, Literal
10
+
11
+ from pydantic import BaseModel, BeforeValidator, Field, field_validator
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
# Tool Call ID constants for DeepFabric Format
TOOL_CALL_ID_CHARS = string.ascii_letters + string.digits
TOOL_CALL_ID_LENGTH = 9
TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")


def generate_tool_call_id() -> str:
    """Generate a 9-character alphanumeric tool call ID.

    Uses the ``secrets`` module so IDs are unpredictable.

    Returns:
        A string of exactly 9 alphanumeric characters (A-Z, a-z, 0-9).
    """
    picks = [secrets.choice(TOOL_CALL_ID_CHARS) for _ in range(TOOL_CALL_ID_LENGTH)]
    return "".join(picks)
28
+
29
+
30
# Type alias for metadata/structured_data fields
# Provider-specific transformations in llm/client.py handle:
# - OpenAI: adds additionalProperties: false
# - Gemini: strips additionalProperties
# Kept deliberately loose (Any values) since metadata shape is caller-defined.
MetadataDict = dict[str, Any] | None
35
+
36
+
37
class ExcludeNoneBaseModel(BaseModel):
    """Base model that excludes None values during serialization."""

    def model_dump(self, **kwargs):
        """Dump the model, defaulting to ``exclude_none=True``.

        Callers may still pass ``exclude_none=False`` explicitly to
        override the default.
        """
        if "exclude_none" not in kwargs:
            kwargs["exclude_none"] = True
        return super().model_dump(**kwargs)
44
+
45
+
46
# Basic message schema
class ChatMessage(ExcludeNoneBaseModel):
    """A single message in a conversation.

    None-valued fields are dropped on serialization (via ExcludeNoneBaseModel),
    so plain text messages serialize without tool_calls/tool_call_id keys.
    """

    model_config = {"extra": "forbid"}

    role: Literal["system", "user", "assistant", "tool"] = Field(
        description="The role of the message sender"
    )
    content: str | None = Field(
        default=None, description="The content of the message (optional when tool_calls is present)"
    )
    # "ToolCall" is a forward reference; it is resolved by the
    # ChatMessage.model_rebuild() call later in this module.
    tool_calls: list["ToolCall"] | None = Field(
        default=None, description="Tool calls made by the assistant (DeepFabric format)"
    )
    tool_call_id: str | None = Field(
        default=None,
        description="ID linking tool result to the original tool call (9-char alphanumeric)",
    )

    @field_validator("tool_call_id")
    @classmethod
    def validate_tool_call_id_format(cls, v: str | None) -> str | None:
        """Validate tool_call_id matches the 9-char alphanumeric format when present.

        None is allowed (messages without tool results simply omit the field).
        """
        if v is not None and not TOOL_CALL_ID_PATTERN.match(v):
            raise ValueError(f"tool_call_id must be exactly 9 alphanumeric characters, got: '{v}'")
        return v
73
+
74
+
75
class ChatTranscript(BaseModel):
    """A complete conversation transcript with messages."""

    # min_length=1 rejects empty conversations at validation time.
    messages: list[ChatMessage] = Field(
        description="List of messages in the conversation", min_length=1
    )
81
+
82
+
83
class ReasoningStep(BaseModel):
    """A single step in a chain of reasoning."""

    step_number: int = Field(description="The step number in the reasoning chain")
    thought: str = Field(description="The reasoning or thought for this step")
    # Optional: not every reasoning step performs an action.
    action: str | None = Field(
        default=None,
        description=(
            "Action taken in this reasoning step. For tool-calling, use one of these formats: "
            "1) Plain function name: 'get_weather' "
            "2) Function call: 'get_weather(city=\"Paris\")' "
            "3) Descriptive text (less reliable): 'I will call the get_weather tool'"
        ),
    )
97
+
98
+
99
class StructuredConversation(BaseModel):
    """A conversation with optional structured reasoning and metadata.

    Both reasoning_trace and metadata are optional; a bare message list is valid.
    """

    messages: list[ChatMessage] = Field(
        description="List of messages in the conversation", min_length=1
    )
    reasoning_trace: list[ReasoningStep] | None = Field(
        default=None, description="Optional chain of reasoning steps"
    )
    metadata: dict[str, Any] | None = Field(
        default=None, description="Optional metadata about the conversation"
    )
111
+
112
+
113
# Tool definition schemas for structured tool system
class ToolParameter(BaseModel):
    """A single parameter for a tool/function."""

    name: str = Field(description="Parameter name")
    type: Literal["str", "int", "float", "bool", "list", "dict"] = Field(
        description="Parameter type"
    )
    description: str = Field(description="What this parameter does")
    required: bool = Field(default=True, description="Whether this parameter is required")
    # NOTE: stored as a string (not the native type) so datasets serialize with
    # a consistent Arrow/Parquet column type; ToolDefinition.to_openai()
    # converts it back to a typed value using the 'type' field.
    default: str | None = Field(
        default=None,
        description=(
            "Default value if not provided. Stored as string for HuggingFace Datasets compatibility "
            "(Arrow/Parquet requires consistent types). Actual type is preserved in 'type' field."
        ),
    )
130
+
131
+
132
# MCP (Model Context Protocol) tool schema models
class MCPInputSchemaProperty(BaseModel):
    """A single property in an MCP input schema."""

    # "extra": "allow" keeps unrecognized JSON Schema keywords rather than rejecting them.
    model_config = {"extra": "allow"}

    type: str = Field(default="string", description="JSON Schema type")
    description: str = Field(default="", description="Property description")
    default: Any | None = Field(default=None, description="Default value")
141
+
142
+
143
class MCPInputSchema(BaseModel):
    """MCP tool input schema (JSON Schema format)."""

    # Permissive: MCP servers may attach extra JSON Schema keywords.
    model_config = {"extra": "allow"}

    type: str = Field(default="object", description="Schema type")
    properties: dict[str, MCPInputSchemaProperty] = Field(
        default_factory=dict, description="Parameter properties"
    )
    required: list[str] = Field(default_factory=list, description="Required parameter names")
153
+
154
+
155
class MCPToolDefinition(BaseModel):
    """MCP (Model Context Protocol) tool definition.

    See: https://modelcontextprotocol.io/specification/2025-06-18/schema#tool
    """

    model_config = {"extra": "allow"}

    name: str = Field(description="Tool name")
    description: str = Field(default="", description="Tool description")
    # The alias accepts the MCP wire-format key "inputSchema".
    input_schema: MCPInputSchema | None = Field(
        default=None,
        alias="inputSchema",
        description="JSON Schema for tool parameters (optional, some tools have no params)",
    )

    @property
    def input_schema_safe(self) -> MCPInputSchema:
        """Get input_schema, returning an empty schema if None (parameterless tools)."""
        return self.input_schema or MCPInputSchema()
175
+
176
+
177
class ToolDefinition(BaseModel):
    """Complete definition of a tool/function."""

    name: str = Field(description="Tool name (function name)")
    description: str = Field(description="What this tool does")
    parameters: list[ToolParameter] = Field(description="List of parameters this tool accepts")
    returns: str = Field(description="Description of what this tool returns")
    category: str = Field(default="general", description="Tool category for grouping")
    component: str | None = Field(
        default=None,
        description="Spin component name that implements this tool (e.g., 'vfs', 'github-mock')",
    )

    def to_signature(self) -> str:
        """Generate a human-readable function signature string.

        Optional parameters render with their (string-stored) default value.
        The "→" arrow is intentional Unicode, used for display only.
        """
        params = []
        for p in self.parameters:
            if p.required:
                params.append(f"{p.name}: {p.type}")
            else:
                params.append(f"{p.name}: {p.type} = {p.default}")
        return f"{self.name}({', '.join(params)}) → {self.returns}"

    def to_openai(self) -> dict[str, Any]:
        """
        Convert tool definition to OpenAI function calling schema format.

        This format is compatible with TRL's SFTTrainer and other HuggingFace
        training frameworks that support tool/function calling.

        Returns:
            Dictionary in OpenAI function calling schema format with:
            - type: Always "function"
            - function: Object containing name, description, and parameters schema
        """
        # Map DeepFabric types to JSON Schema types
        type_mapping = {
            "str": "string",
            "int": "integer",
            "float": "number",
            "bool": "boolean",
            "list": "array",
            "dict": "object",
        }

        properties = {}
        required = []

        for param in self.parameters:
            # Unknown types fall back to "string".
            json_type = type_mapping.get(param.type, "string")
            properties[param.name] = {
                "type": json_type,
                "description": param.description,
            }

            # Add default value if present and not required
            # Convert string default back to proper type for JSON Schema
            # (ToolParameter.default is stored as a string for Arrow/Parquet).
            if not param.required and param.default is not None and param.default != "":
                default_value = param.default
                # Convert string representation back to typed value
                try:
                    if param.type == "int":
                        default_value = int(param.default)
                    elif param.type == "float":
                        default_value = float(param.default)
                    elif param.type == "bool":
                        default_value = param.default.lower() in ("true", "1", "yes")
                    elif param.type in ("list", "dict"):
                        # Handle special cases
                        if param.default.lower() == "none":
                            default_value = None
                        else:
                            # Try JSON first
                            try:
                                default_value = json.loads(param.default)
                            except json.JSONDecodeError:
                                # Fallback: try Python literal (e.g., ['markdown'] -> ["markdown"])
                                try:
                                    default_value = ast.literal_eval(param.default)
                                except (ValueError, SyntaxError):
                                    # Give up and skip this default. Safe to `continue`:
                                    # param.required is False in this branch, so the
                                    # required-tracking below would be a no-op anyway.
                                    continue
                    # str remains as-is
                    properties[param.name]["default"] = default_value
                except (ValueError, AttributeError):
                    # Skip invalid default values (e.g., empty strings for int/float)
                    pass

            # Track required parameters
            if param.required:
                required.append(param.name)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    @classmethod
    def from_openai(cls, openai_tool: dict[str, Any]) -> "ToolDefinition":
        """Create a ToolDefinition from OpenAI function calling schema format.

        Args:
            openai_tool: Dictionary in OpenAI format with type="function" and function object

        Returns:
            ToolDefinition instance
        """
        # Reverse type mapping (JSON Schema type -> DeepFabric type)
        type_mapping = {
            "string": "str",
            "integer": "int",
            "number": "float",
            "boolean": "bool",
            "array": "list",
            "object": "dict",
        }

        func = openai_tool.get("function", {})
        name = func.get("name", "")
        description = func.get("description", "")
        params_schema = func.get("parameters", {})

        properties = params_schema.get("properties", {})
        required_params = set(params_schema.get("required", []))

        parameters = []
        for param_name, param_props in properties.items():
            # Skip None/null property values (can occur in some tool schemas)
            if param_props is None:
                continue
            json_type = param_props.get("type", "string")
            df_type = type_mapping.get(json_type, "str")

            # Defaults are stringified for ToolParameter ("" means no default);
            # to_openai() converts them back using the ast/json fallback above.
            default = param_props.get("default")
            default_str = str(default) if default is not None else ""

            parameters.append(
                ToolParameter(
                    name=param_name,
                    type=df_type,  # type: ignore[arg-type]
                    description=param_props.get("description", ""),
                    required=param_name in required_params,
                    default=default_str,
                )
            )

        return cls(
            name=name,
            description=description,
            parameters=parameters,
            returns="",  # OpenAI format doesn't include return description
            category="general",
        )

    @classmethod
    def from_mcp(cls, mcp_tool: MCPToolDefinition | dict[str, Any]) -> "ToolDefinition":
        """Create a ToolDefinition from MCP (Model Context Protocol) tool schema.

        Args:
            mcp_tool: Either an MCPToolDefinition instance or a dict in MCP format

        Returns:
            ToolDefinition instance
        """
        # Type mapping from JSON Schema to DeepFabric types
        type_mapping = {
            "string": "str",
            "integer": "int",
            "number": "float",
            "boolean": "bool",
            "array": "list",
            "object": "dict",
        }

        # Parse dict to MCPToolDefinition if needed
        if isinstance(mcp_tool, dict):
            mcp_tool = MCPToolDefinition.model_validate(mcp_tool)

        # Use safe property to handle None input_schema
        input_schema = mcp_tool.input_schema_safe
        parameters = []
        required_params = set(input_schema.required)

        for param_name, param_props in input_schema.properties.items():
            df_type = type_mapping.get(param_props.type, "str")
            # Same string-storage convention as from_openai: "" means no default.
            default_str = str(param_props.default) if param_props.default is not None else ""

            parameters.append(
                ToolParameter(
                    name=param_name,
                    type=df_type,  # type: ignore[arg-type]
                    description=param_props.description,
                    required=param_name in required_params,
                    default=default_str,
                )
            )

        return cls(
            name=mcp_tool.name,
            description=mcp_tool.description,
            parameters=parameters,
            returns="",  # MCP format doesn't include return description
            category="general",
        )
389
+
390
+
391
class ToolRegistry(BaseModel):
    """Registry of available tools."""

    tools: list[ToolDefinition] = Field(description="List of available tool definitions")

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Look up a single tool by name; None if no tool matches."""
        for tool in self.tools:
            if tool.name == name:
                return tool
        return None

    def get_tools_by_category(self, category: str) -> list[ToolDefinition]:
        """Return every registered tool whose category matches."""
        return [tool for tool in self.tools if tool.category == category]

    def get_tool_names(self) -> list[str]:
        """Return the names of all registered tools, in registration order."""
        return [tool.name for tool in self.tools]

    def to_openai_format(self) -> list[dict[str, Any]]:
        """
        Convert all tools to OpenAI function calling schema format.

        This method is specifically designed for use with HuggingFace TRL's
        SFTTrainer and other training frameworks that require tools to be
        provided in OpenAI function calling format.

        Returns:
            List of tool definitions in OpenAI function calling schema format.
            Each tool includes type="function" and a function object with
            name, description, and parameters.

        Example:
            >>> registry = ToolRegistry(tools=[...])
            >>> trl_tools = registry.to_openai_format()
            >>> # Use in dataset: {"messages": [...], "tools": trl_tools}
        """
        return [tool.to_openai() for tool in self.tools]
427
+
428
+
429
# Agent tool-calling schemas
class ToolReasoningStep(BaseModel):
    """A reasoning step that leads to tool selection and parameter construction."""

    step_number: int = Field(description="The step number in the tool planning sequence")
    reasoning: str = Field(description="Why this tool is needed at this point")
    selected_tool_name: str = Field(description="Name of the tool being selected")
    parameter_reasoning: str = Field(description="Reasoning for parameter values")
    expected_result: str = Field(description="What the tool should return and how it helps")
438
+
439
+
440
class ToolExecution(BaseModel):
    """Represents actual execution of a tool with reasoning context.

    Unlike PendingToolCall, this model carries the observed result of the call.
    """

    function_name: str = Field(min_length=1, description="Name of the function/tool being called")
    # min_length=2 because the smallest valid JSON object is "{}".
    arguments: str = Field(
        min_length=2, description="JSON string of arguments passed to the function"
    )
    reasoning: str = Field(min_length=1, description="Brief explanation of why executing now")
    result: str = Field(min_length=1, description="The result returned from the tool execution")

    @field_validator("arguments")
    @classmethod
    def validate_arguments_json(cls, v: str) -> str:
        """Validate that arguments are valid JSON with no null/empty placeholders.

        Empty objects {} are allowed for parameterless tools like list_files().
        Mirrors PendingToolCall.validate_arguments_json. Returns the original
        (unstripped) value; stripping is only for parsing.
        """
        stripped = v.strip()

        # Parse and validate JSON
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict):
                for key, value in parsed.items():
                    if value is None:
                        raise ValueError(f"Argument '{key}' is null - must provide actual value")
                    if isinstance(value, str) and value == "":
                        raise ValueError(
                            f"Argument '{key}' is empty string - must provide actual value"
                        )
        except json.JSONDecodeError as e:
            raise ValueError(f"Arguments must be valid JSON: {e}") from e

        return v

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Parse arguments JSON string to dict.

        Uses Any for values as function arguments can be strings, numbers, booleans, lists, nested dicts, etc.
        """
        return json.loads(self.arguments)

    def to_tool_call(self, tool_call_id: str | None = None) -> "ToolCall":
        """Convert ToolExecution to a ToolCall for the final dataset.

        Args:
            tool_call_id: The 9-char alphanumeric ID. If None, generates one.

        Returns:
            ToolCall with JSON string arguments (HuggingFace compatible).
        """
        return ToolCall(
            id=tool_call_id or generate_tool_call_id(),
            type="function",
            function=ToolCallFunction(
                name=self.function_name,
                arguments=self.arguments,  # Pass string directly
            ),
        )

    class Config:
        extra = "forbid"
503
+
504
+
505
class PendingToolCall(BaseModel):
    """A tool call request before execution (no result yet).

    Used in AgentStep for the ReAct loop - the LLM generates these,
    then Spin executes them and populates the result separately.
    """

    function_name: str = Field(min_length=1, description="Name of the function/tool to call")
    # min_length=2 because the smallest valid JSON object is "{}".
    arguments: str = Field(
        min_length=2, description="JSON string of arguments to pass to the function"
    )
    reasoning: str = Field(min_length=1, description="Brief explanation of why calling this tool")

    @field_validator("arguments")
    @classmethod
    def validate_arguments_json(cls, v: str) -> str:
        """Validate that arguments are valid JSON with no null/empty placeholders.

        Empty objects {} are allowed for parameterless tools like list_files().
        Mirrors ToolExecution.validate_arguments_json. Returns the original
        (unstripped) value; stripping is only for parsing.
        """
        stripped = v.strip()

        # Parse and validate JSON
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict):
                for key, value in parsed.items():
                    if value is None:
                        raise ValueError(f"Argument '{key}' is null - must provide actual value")
                    if isinstance(value, str) and value == "":
                        raise ValueError(
                            f"Argument '{key}' is empty string - must provide actual value"
                        )
        except json.JSONDecodeError as e:
            raise ValueError(f"Arguments must be valid JSON: {e}") from e

        return v

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Parse arguments JSON string to dict."""
        return json.loads(self.arguments)

    def to_tool_execution(self, result: str) -> "ToolExecution":
        """Convert to ToolExecution after getting result from Spin.

        Args:
            result: The result from tool execution

        Returns:
            ToolExecution with the result populated
        """
        return ToolExecution(
            function_name=self.function_name,
            arguments=self.arguments,
            reasoning=self.reasoning,
            result=result,
        )

    class Config:
        extra = "forbid"
566
+
567
+
568
class AgentStep(BaseModel):
    """A single step in ReAct-style agent reasoning.

    Each step represents one iteration of the think-act-observe loop:
    1. Agent thinks about what to do next (thought)
    2. Agent decides on tool calls for THIS step only
    3. Tools are executed and results observed
    4. Process repeats until is_final=True

    This ensures tool calls are made based on observed results,
    not hallucinated assumptions about what tools will return.
    """

    thought: str = Field(
        min_length=1,
        description="Agent's reasoning about what to do next based on observations so far",
    )
    # default_factory (not a literal []) avoids a shared mutable default.
    tool_calls: list["PendingToolCall"] = Field(
        default_factory=list,
        description="Tool calls for THIS step only. Empty if agent is done.",
    )
    is_final: bool = Field(
        default=False,
        description="True if agent has enough information and is ready to respond to user",
    )

    class Config:
        extra = "forbid"
596
+
597
+
598
# Tool calling schemas for conversations that include function calls
class FunctionCall(BaseModel):
    """A function call with arguments.

    Note: arguments here are a dict, unlike ToolCallFunction which stores a
    JSON string for HuggingFace dataset compatibility.
    """

    name: str = Field(description="The name of the function to call")
    arguments: dict[str, Any] = Field(description="Arguments to pass to the function")
604
+
605
+
606
+ def _serialize_arguments(v: dict[str, Any] | str) -> str:
607
+ """Serialize arguments to JSON string, stripping None values.
608
+
609
+ This ensures consistent schema for HuggingFace datasets (Arrow/Parquet compatibility).
610
+ Accepts both dict and str inputs for backward compatibility.
611
+ """
612
+ match v:
613
+ case dict():
614
+ cleaned = {k: val for k, val in v.items() if val is not None}
615
+ return json.dumps(cleaned, separators=(",", ":"))
616
+ case str():
617
+ # Validate JSON, strip nulls, re-serialize for consistency
618
+ parsed = json.loads(v)
619
+ cleaned = (
620
+ {k: val for k, val in parsed.items() if val is not None}
621
+ if isinstance(parsed, dict)
622
+ else parsed
623
+ )
624
+ return json.dumps(cleaned, separators=(",", ":"))
625
+ case _:
626
+ raise ValueError(f"arguments must be dict or str, got {type(v)}")
627
+
628
+
629
# Type alias for JSON-serialized arguments (HuggingFace compatible).
# BeforeValidator runs _serialize_arguments before str validation, so callers
# may supply either a dict or a JSON string; both normalize to a JSON string.
ArgumentsStr = Annotated[str, BeforeValidator(_serialize_arguments)]
631
+
632
+
633
class ToolCallFunction(ExcludeNoneBaseModel):
    """Function details within a tool call (DeepFabric Format)."""

    name: str = Field(min_length=1, description="The name of the function to call")
    # ArgumentsStr normalizes dict or JSON-string input to a compact JSON string.
    arguments: ArgumentsStr = Field(description="Arguments as JSON string (HuggingFace compatible)")

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Parse arguments JSON string to dict when needed at runtime."""
        return json.loads(self.arguments)

    class Config:
        extra = "forbid"
646
+
647
+
648
class ToolCall(ExcludeNoneBaseModel):
    """A tool call in DeepFabric Format.

    Implements the DeepFabric Format specification:
    - ID: Exactly 9 alphanumeric characters (A-Z, a-z, 0-9)
    - Type: Always "function"
    - Arguments: JSON string (for HuggingFace Arrow/Parquet compatibility)
    """

    # Length bounds plus the validator below enforce the 9-char alphanumeric format.
    id: str = Field(
        min_length=9,
        max_length=9,
        description="Unique tool call ID - exactly 9 alphanumeric characters",
    )
    type: Literal["function"] = Field(default="function", description="Tool call type")
    function: ToolCallFunction = Field(description="Function call details")

    @field_validator("id")
    @classmethod
    def validate_tool_call_id(cls, v: str) -> str:
        """Validate that ID is exactly 9 alphanumeric characters."""
        if not TOOL_CALL_ID_PATTERN.match(v):
            raise ValueError(
                f"Tool call ID must be exactly 9 alphanumeric characters (A-Z, a-z, 0-9), got: '{v}'"
            )
        return v

    class Config:
        extra = "forbid"
677
+
678
+
679
# Resolve forward reference for ChatMessage.tool_calls
# ("ToolCall" is declared after ChatMessage, so the model must be rebuilt here).
ChatMessage.model_rebuild()
681
+
682
+
683
class ToolMessage(BaseModel):
    """A message that includes tool/function calling.

    Note: tool_calls here uses FunctionCall (dict arguments), unlike
    ChatMessage.tool_calls which uses the DeepFabric-format ToolCall.
    """

    role: Literal["system", "user", "assistant", "tool"] = Field(
        description="The role of the message sender"
    )
    content: str | None = Field(default=None, description="The text content of the message")
    function_call: FunctionCall | None = Field(
        default=None, description="Function call made by the assistant"
    )
    tool_calls: list[FunctionCall] | None = Field(
        default=None, description="Multiple tool calls made by the assistant"
    )
696
+
697
+
698
class ToolConversation(BaseModel):
    """A conversation that may include function/tool calls."""

    # min_length=1 rejects empty conversations at validation time.
    messages: list[ToolMessage] = Field(
        description="List of messages that may include tool calls", min_length=1
    )
704
+
705
+
706
# Chain of Thought schemas
class FreeTextCoT(BaseModel):
    """Chain of Thought dataset in free-text format (GSM8K style)."""

    question: str = Field(description="The question or problem to solve")
    chain_of_thought: str = Field(description="Natural language reasoning explanation")
    final_answer: str = Field(description="The definitive answer to the question")
713
+
714
+
715
class StructuredCoT(BaseModel):
    """Chain of Thought dataset with structured reasoning trace.

    Unlike StructuredConversation, reasoning_trace is required here (min 1 step).
    """

    messages: list[ChatMessage] = Field(description="Conversation messages", min_length=1)
    reasoning_trace: list[ReasoningStep] = Field(
        description="Structured reasoning steps", min_length=1
    )
    final_answer: str = Field(description="The definitive answer to the question")
723
+
724
+
725
class HybridCoT(BaseModel):
    """Chain of Thought dataset with both free-text and structured reasoning."""

    question: str = Field(description="The question or problem to solve")
    chain_of_thought: str = Field(description="Natural language reasoning explanation")
    reasoning_trace: list[ReasoningStep] = Field(
        description="Structured reasoning steps", min_length=1
    )
    final_answer: str = Field(description="The definitive answer to the question")
734
+
735
+
736
# Mathematical variants with numerical-only final answers
class MathematicalAnswerMixin:
    """Mixin class providing mathematical answer formatting and validation."""

    @classmethod
    def _format_mathematical_answer(cls, v: str) -> str:
        """Format mathematical answers with strict consistency rules.

        Accepts optionally-signed integers/decimals (with or without
        thousands separators) and scientific notation; raises ValueError
        for anything non-numerical.
        """
        answer = v.strip()

        # If the model returned a comma-space separated list (e.g. "2, 3"),
        # keep only the first entry.
        if ", " in answer:
            answer = answer.split(", ")[0].strip()

        # Basic validation pattern: thousands-separated or plain numbers,
        # each optionally with decimals and an exponent.
        numeric_pattern = (
            r"^-?\d{1,3}(,\d{3})*(\.\d+)?([eE][+-]?\d+)?$"
            r"|^-?\d+(\.\d+)?([eE][+-]?\d+)?$"
        )
        if re.match(numeric_pattern, answer) is None:
            raise ValueError(f"final_answer must be numerical, got: {v}")

        # Normalize: drop thousands separators before further formatting.
        normalized = answer.replace(",", "")

        # Scientific notation is preserved verbatim.
        if cls._is_scientific_notation(normalized):
            return normalized

        # Round long decimal tails (3+ places) down to 2 decimal places to
        # smooth over floating-point precision artifacts.
        if "." in normalized:
            pieces = normalized.split(".")
            if len(pieces) == 2 and len(pieces[1]) >= 3:  # noqa: PLR2004
                quantized = Decimal(normalized).quantize(
                    Decimal("0.01"), rounding=ROUND_HALF_UP
                )
                normalized = str(quantized)

        return normalized

    @staticmethod
    def _is_scientific_notation(value: str) -> bool:
        """Detect scientific notation (presence of an exponent marker)."""
        return "e" in value.lower()
779
+
780
+
781
# Capability Models for Composable Conversation Schema
class ReasoningTrace(BaseModel):
    """Reasoning capability - present when conversation_type='chain_of_thought'."""

    style: Literal["freetext", "agent"] = Field(
        description="The reasoning style: freetext (natural language) or agent (structured step-by-step for tool-calling)"
    )
    # Content shape follows 'style': str for freetext, list[ReasoningStep] for agent.
    content: str | list[ReasoningStep] = Field(
        description="Reasoning content - string for freetext, list of ReasoningStep for agent"
    )

    class Config:
        extra = "forbid"
794
+
795
+
796
class ToolContext(BaseModel):
    """Tool execution history - present when tools are used.

    Note: available_tools has been removed as it was redundant with the
    top-level 'tools' field in Conversation. Use 'tools' for the OpenAI-format
    tool definitions needed by chat templates.
    """

    # Chronological record of tool calls; defaults to an empty list rather
    # than None so consumers can iterate without a null check.
    executions: list[ToolExecution] = Field(
        default_factory=list,
        description="Tool executions performed during the conversation (may be empty if agent answered without tools)",
    )

    class Config:
        # Reject unknown fields so malformed model output fails fast.
        extra = "forbid"
811
+
812
+
813
class AgentContext(BaseModel):
    """Agent capability - present when agent_mode is enabled."""

    # Distinguishes one-shot tool use from extended multi-turn interactions.
    mode: Literal["single_turn", "multi_turn"] = Field(
        description="Agent interaction mode: single_turn for one-shot tool use, multi_turn for extended conversations"
    )
    # Optional free-text trace of the agent's up-front plan.
    planning_trace: str | None = Field(
        default=None, description="Agent's planning and reasoning about tool usage strategy"
    )
    # Optional free-text wrap-up of what the agent did and concluded.
    execution_summary: str | None = Field(
        default=None, description="Summary of agent's execution and results interpretation"
    )

    class Config:
        # Reject unknown fields so malformed model output fails fast.
        extra = "forbid"
828
+
829
+
830
class Conversation(ExcludeNoneBaseModel):
    """
    Unified conversation schema with optional capability fields.

    This composable schema supports various combinations:
    - Basic conversation: just messages
    - With reasoning: messages + reasoning capability
    - With tools: messages + tool_context capability
    - Agent mode: messages + tool_context + agent_context capabilities
    - Full combination: all capabilities enabled

    The schema validates that capability combinations are consistent
    (e.g., agent_context requires tool_context).
    """

    # At least one message is required for a valid conversation.
    messages: list[ChatMessage] = Field(description="Core conversation messages", min_length=1)
    metadata: MetadataDict = Field(
        default=None, description="Conversation metadata (topic, domain, etc.)"
    )

    # Optional capability fields - use empty strings/dicts instead of None for OpenAI compatibility
    reasoning: ReasoningTrace | None = Field(
        default=None, description="Reasoning capability - chain of thought traces"
    )
    tool_context: ToolContext | None = Field(
        default=None, description="Tool capability - available tools and executions"
    )
    # Kept separate from tool_context: these are the raw OpenAI-format
    # definitions that chat templates consume directly.
    tools: list[dict[str, Any]] | None = Field(
        default=None,
        description="OpenAI-compatible tool definitions (populated from tool_context for training)",
    )
    agent_context: AgentContext | None = Field(
        default=None, description="Agent capability - agentic behavior and planning"
    )
    structured_data: MetadataDict = Field(
        default=None, description="Additional structured data for specific formats"
    )

    # Fields for backward compatibility and specific use cases
    question: str = Field(default="", description="Original question (useful for Q&A formats)")
    final_answer: str = Field(default="", description="Final answer (useful for reasoning formats)")

    @field_validator("reasoning")
    @classmethod
    def validate_reasoning_trace(cls, v: ReasoningTrace | None) -> ReasoningTrace | None:
        """Validate reasoning trace content matches style.

        Ensures 'agent' style carries a list (of ReasoningStep) and
        'freetext' style carries a plain string.

        Raises:
            ValueError: If the content type does not match the declared style.
        """
        if v is None:
            return None

        if v.style == "agent" and not isinstance(v.content, list):
            msg = (
                f"Reasoning style '{v.style}' requires list of ReasoningStep, got {type(v.content)}"
            )
            raise ValueError(msg)
        if v.style == "freetext" and not isinstance(v.content, str):
            msg = f"Reasoning style 'freetext' requires string content, got {type(v.content)}"
            raise ValueError(msg)

        return v

    @field_validator("agent_context")
    @classmethod
    def validate_agent_requires_tools(cls, v: AgentContext | None, info) -> AgentContext | None:
        """Validate that agent_context requires tool_context.

        NOTE(review): relies on field declaration order - tool_context is
        declared before agent_context, so it is already available in
        info.data when this validator runs.

        Raises:
            ValueError: If agent_context is set without tool_context.
        """
        if v is not None:
            # Access tool_context from the model data
            tool_context = info.data.get("tool_context")
            if tool_context is None:
                msg = "agent_context requires tool_context to be present"
                raise ValueError(msg)
        return v

    class Config:
        # Reject unknown fields and mirror that in the emitted JSON schema.
        extra = "forbid"
        json_schema_extra = {"additionalProperties": False}
905
+
906
+
907
class FormattedSample(BaseModel):
    """Output from chat template formatting.

    This model represents a dataset sample after applying a model-specific
    chat template via the Dataset.format() method. The text field contains
    the fully formatted prompt ready for training or inference.

    Attributes:
        text: Formatted text with chat template applied (e.g., ChatML, Llama format)

    Example:
        >>> sample = FormattedSample(text="<|im_start|>user\\nHello<|im_end|>...")
        >>> sample.text[:20]
        '<|im_start|>user\\nHe'
    """

    # min_length=1 rejects empty formatted output, which would indicate a
    # template failure upstream.
    text: str = Field(..., min_length=1, description="Chat template formatted text")

    class Config:
        # Reject unknown fields so malformed model output fails fast.
        extra = "forbid"
927
+
928
+
929
# Unified conversation schema mapping
# Both conversation types resolve to the same unified Conversation model;
# the distinction between them is applied during generation (which
# capability fields get populated), not at the schema level.
CONVERSATION_SCHEMAS = {
    "basic": Conversation,
    "chain_of_thought": Conversation,
}
934
+
935
+
936
def get_conversation_schema(
    conversation_type: str = "basic",
) -> type[Conversation]:
    """Resolve the schema class for a conversation configuration.

    Every supported conversation type maps to the unified Conversation
    schema; its capability fields (reasoning, tool_context, agent_context)
    are populated later, during generation, based on the configuration.

    Args:
        conversation_type: Type of conversation (basic, chain_of_thought)

    Returns:
        Conversation schema (unified for all types)

    Raises:
        ValueError: If conversation_type is not supported
    """
    schema = CONVERSATION_SCHEMAS.get(conversation_type)
    if schema is None:
        valid_types = ", ".join(CONVERSATION_SCHEMAS.keys())
        msg = f"Unsupported conversation type: {conversation_type}. Valid types: {valid_types}"
        raise ValueError(msg)
    return schema
962
+
963
+
964
# Topic generation schemas for tree and graph (needed by other modules)
class TopicList(BaseModel):
    """A list of subtopics for tree/graph generation."""

    # min_length=1: an empty subtopic list is treated as a generation failure.
    subtopics: list[str] = Field(
        description="List of subtopic names",
        min_length=1,
    )
972
+
973
+
974
class TopicNode(BaseModel):
    """A topic node with subtopics for graph generation."""

    topic: str = Field(description="The topic name")
    # Leaf nodes are represented by an empty subtopic list.
    subtopics: list[str] = Field(
        description="List of subtopic names",
        default_factory=list,
    )
982
+
983
+
984
class GraphSubtopic(BaseModel):
    """A subtopic with connections for graph generation."""

    topic: str = Field(description="The subtopic name")
    # Integer IDs of already-existing graph nodes this subtopic links to;
    # empty means the subtopic connects only to its parent.
    connections: list[int] = Field(
        default_factory=list,
        description="List of existing node IDs to connect to, empty list if none",
    )
992
+
993
+
994
class GraphSubtopics(BaseModel):
    """List of subtopics with connections for graph generation."""

    # min_length=1: an empty expansion is treated as a generation failure.
    subtopics: list[GraphSubtopic] = Field(
        description="List of subtopics with their connections",
        min_length=1,
    )