DeepFabric 4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/schemas.py
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import re
|
|
5
|
+
import secrets
|
|
6
|
+
import string
|
|
7
|
+
|
|
8
|
+
from decimal import ROUND_HALF_UP, Decimal
|
|
9
|
+
from typing import Annotated, Any, Literal
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, BeforeValidator, Field, field_validator
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# Tool Call ID constants for DeepFabric Format
TOOL_CALL_ID_CHARS = string.ascii_letters + string.digits
TOOL_CALL_ID_LENGTH = 9
TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")


def generate_tool_call_id() -> str:
    """Generate a 9-character alphanumeric tool call ID.

    Uses the ``secrets`` module so IDs are cryptographically random.

    Returns:
        A string of exactly 9 alphanumeric characters (A-Z, a-z, 0-9).
    """
    picked = [secrets.choice(TOOL_CALL_ID_CHARS) for _ in range(TOOL_CALL_ID_LENGTH)]
    return "".join(picked)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Type alias for metadata/structured_data fields
# Provider-specific transformations in llm/client.py handle:
# - OpenAI: adds additionalProperties: false
# - Gemini: strips additionalProperties
# Kept as a permissive dict so arbitrary metadata keys validate.
MetadataDict = dict[str, Any] | None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ExcludeNoneBaseModel(BaseModel):
    """Base model that excludes None values during serialization."""

    def model_dump(self, **kwargs):
        """Serialize the model, defaulting ``exclude_none`` to True.

        Callers can still opt back in by passing ``exclude_none=False``
        explicitly; only the default is changed.
        """
        if "exclude_none" not in kwargs:
            kwargs["exclude_none"] = True
        return super().model_dump(**kwargs)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Basic message schema
class ChatMessage(ExcludeNoneBaseModel):
    """A single message in a conversation."""

    model_config = {"extra": "forbid"}

    role: Literal["system", "user", "assistant", "tool"] = Field(
        description="The role of the message sender"
    )
    content: str | None = Field(
        default=None, description="The content of the message (optional when tool_calls is present)"
    )
    tool_calls: list["ToolCall"] | None = Field(
        default=None, description="Tool calls made by the assistant (DeepFabric format)"
    )
    tool_call_id: str | None = Field(
        default=None,
        description="ID linking tool result to the original tool call (9-char alphanumeric)",
    )

    @field_validator("tool_call_id")
    @classmethod
    def validate_tool_call_id_format(cls, v: str | None) -> str | None:
        """Validate tool_call_id matches the 9-char alphanumeric format when present."""
        # None is allowed: only non-tool messages omit the ID.
        if v is None:
            return v
        if TOOL_CALL_ID_PATTERN.match(v):
            return v
        raise ValueError(f"tool_call_id must be exactly 9 alphanumeric characters, got: '{v}'")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ChatTranscript(BaseModel):
    """A complete conversation transcript with messages.

    The minimal chat container: an ordered, non-empty list of ChatMessage
    entries with no reasoning trace or metadata attached.
    """

    # min_length=1 rejects empty conversations at validation time
    messages: list[ChatMessage] = Field(
        description="List of messages in the conversation", min_length=1
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ReasoningStep(BaseModel):
    """A single step in a chain of reasoning."""

    # Position within the reasoning chain (numbering convention not enforced here)
    step_number: int = Field(description="The step number in the reasoning chain")
    thought: str = Field(description="The reasoning or thought for this step")
    # Optional action; the description enumerates the textual formats
    # expected for tool-calling traces.
    action: str | None = Field(
        default=None,
        description=(
            "Action taken in this reasoning step. For tool-calling, use one of these formats: "
            "1) Plain function name: 'get_weather' "
            "2) Function call: 'get_weather(city=\"Paris\")' "
            "3) Descriptive text (less reliable): 'I will call the get_weather tool'"
        ),
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class StructuredConversation(BaseModel):
    """A conversation with optional structured reasoning and metadata."""

    messages: list[ChatMessage] = Field(
        description="List of messages in the conversation", min_length=1
    )
    # Optional step-by-step reasoning attached alongside the messages
    reasoning_trace: list[ReasoningStep] | None = Field(
        default=None, description="Optional chain of reasoning steps"
    )
    # Free-form metadata; same shape as the MetadataDict alias
    metadata: dict[str, Any] | None = Field(
        default=None, description="Optional metadata about the conversation"
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Tool definition schemas for structured tool system
class ToolParameter(BaseModel):
    """A single parameter for a tool/function."""

    name: str = Field(description="Parameter name")
    # Restricted to the six DeepFabric primitive type names; these map
    # to/from JSON Schema types in ToolDefinition.to_openai()/from_openai().
    type: Literal["str", "int", "float", "bool", "list", "dict"] = Field(
        description="Parameter type"
    )
    description: str = Field(description="What this parameter does")
    required: bool = Field(default=True, description="Whether this parameter is required")
    # Stored as a string; "" is used as the "no default" marker by
    # ToolDefinition.from_openai()/from_mcp().
    default: str | None = Field(
        default=None,
        description=(
            "Default value if not provided. Stored as string for HuggingFace Datasets compatibility "
            "(Arrow/Parquet requires consistent types). Actual type is preserved in 'type' field."
        ),
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# MCP (Model Context Protocol) tool schema models
class MCPInputSchemaProperty(BaseModel):
    """A single property in an MCP input schema."""

    # Unknown JSON Schema keywords are preserved rather than rejected
    model_config = {"extra": "allow"}

    type: str = Field(default="string", description="JSON Schema type")
    description: str = Field(default="", description="Property description")
    default: Any | None = Field(default=None, description="Default value")
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class MCPInputSchema(BaseModel):
    """MCP tool input schema (JSON Schema format)."""

    # Accept extra JSON Schema keywords without failing validation
    model_config = {"extra": "allow"}

    type: str = Field(default="object", description="Schema type")
    properties: dict[str, MCPInputSchemaProperty] = Field(
        default_factory=dict, description="Parameter properties"
    )
    required: list[str] = Field(default_factory=list, description="Required parameter names")
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class MCPToolDefinition(BaseModel):
    """MCP (Model Context Protocol) tool definition.

    See: https://modelcontextprotocol.io/specification/2025-06-18/schema#tool
    """

    model_config = {"extra": "allow"}

    name: str = Field(description="Tool name")
    description: str = Field(default="", description="Tool description")
    input_schema: MCPInputSchema | None = Field(
        default=None,
        alias="inputSchema",
        description="JSON Schema for tool parameters (optional, some tools have no params)",
    )

    @property
    def input_schema_safe(self) -> MCPInputSchema:
        """Get input_schema, returning empty schema if None."""
        if self.input_schema is None:
            # Tools without parameters get an empty object schema.
            return MCPInputSchema()
        return self.input_schema
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class ToolDefinition(BaseModel):
    """Complete definition of a tool/function.

    Central tool representation: convertible to the OpenAI function-calling
    schema (to_openai) and constructible from OpenAI (from_openai) or MCP
    (from_mcp) formats. Parameter defaults travel as strings (see
    ToolParameter) and are re-typed on export.
    """

    name: str = Field(description="Tool name (function name)")
    description: str = Field(description="What this tool does")
    parameters: list[ToolParameter] = Field(description="List of parameters this tool accepts")
    returns: str = Field(description="Description of what this tool returns")
    category: str = Field(default="general", description="Tool category for grouping")
    component: str | None = Field(
        default=None,
        description="Spin component name that implements this tool (e.g., 'vfs', 'github-mock')",
    )

    def to_signature(self) -> str:
        """Generate a function signature string.

        Optional parameters render their (string-typed) default after '=';
        string defaults therefore appear unquoted.
        """
        params = []
        for p in self.parameters:
            if p.required:
                params.append(f"{p.name}: {p.type}")
            else:
                params.append(f"{p.name}: {p.type} = {p.default}")
        return f"{self.name}({', '.join(params)}) → {self.returns}"

    def to_openai(self) -> dict[str, Any]:
        """
        Convert tool definition to OpenAI function calling schema format.

        This format is compatible with TRL's SFTTrainer and other HuggingFace
        training frameworks that support tool/function calling.

        Returns:
            Dictionary in OpenAI function calling schema format with:
            - type: Always "function"
            - function: Object containing name, description, and parameters schema
        """
        # Map DeepFabric types to JSON Schema types
        type_mapping = {
            "str": "string",
            "int": "integer",
            "float": "number",
            "bool": "boolean",
            "list": "array",
            "dict": "object",
        }

        properties = {}
        required = []

        for param in self.parameters:
            # Unknown types fall back to "string"
            json_type = type_mapping.get(param.type, "string")
            properties[param.name] = {
                "type": json_type,
                "description": param.description,
            }

            # Add default value if present and not required
            # Convert string default back to proper type for JSON Schema
            if not param.required and param.default is not None and param.default != "":
                default_value = param.default
                # Convert string representation back to typed value
                try:
                    if param.type == "int":
                        default_value = int(param.default)
                    elif param.type == "float":
                        default_value = float(param.default)
                    elif param.type == "bool":
                        default_value = param.default.lower() in ("true", "1", "yes")
                    elif param.type in ("list", "dict"):
                        # Handle special cases
                        if param.default.lower() == "none":
                            default_value = None
                        else:
                            # Try JSON first
                            try:
                                default_value = json.loads(param.default)
                            except json.JSONDecodeError:
                                # Fallback: try Python literal (e.g., ['markdown'] -> ["markdown"])
                                try:
                                    default_value = ast.literal_eval(param.default)
                                except (ValueError, SyntaxError):
                                    # Give up and skip this default
                                    # (safe to skip the rest of the loop body: this
                                    # branch only runs for optional params, so the
                                    # required-tracking below would be a no-op anyway)
                                    continue
                    # str remains as-is
                    properties[param.name]["default"] = default_value
                except (ValueError, AttributeError):
                    # Skip invalid default values (e.g., empty strings for int/float)
                    pass

            # Track required parameters
            if param.required:
                required.append(param.name)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    @classmethod
    def from_openai(cls, openai_tool: dict[str, Any]) -> "ToolDefinition":
        """Create a ToolDefinition from OpenAI function calling schema format.

        Args:
            openai_tool: Dictionary in OpenAI format with type="function" and function object

        Returns:
            ToolDefinition instance
        """
        # Reverse type mapping
        type_mapping = {
            "string": "str",
            "integer": "int",
            "number": "float",
            "boolean": "bool",
            "array": "list",
            "object": "dict",
        }

        func = openai_tool.get("function", {})
        name = func.get("name", "")
        description = func.get("description", "")
        params_schema = func.get("parameters", {})

        properties = params_schema.get("properties", {})
        required_params = set(params_schema.get("required", []))

        parameters = []
        for param_name, param_props in properties.items():
            # Skip None/null property values (can occur in some tool schemas)
            if param_props is None:
                continue
            json_type = param_props.get("type", "string")
            df_type = type_mapping.get(json_type, "str")

            # Defaults are stringified; "" marks "no default" (see ToolParameter)
            default = param_props.get("default")
            default_str = str(default) if default is not None else ""

            parameters.append(
                ToolParameter(
                    name=param_name,
                    type=df_type,  # type: ignore[arg-type]
                    description=param_props.get("description", ""),
                    required=param_name in required_params,
                    default=default_str,
                )
            )

        return cls(
            name=name,
            description=description,
            parameters=parameters,
            returns="",  # OpenAI format doesn't include return description
            category="general",
        )

    @classmethod
    def from_mcp(cls, mcp_tool: MCPToolDefinition | dict[str, Any]) -> "ToolDefinition":
        """Create a ToolDefinition from MCP (Model Context Protocol) tool schema.

        Args:
            mcp_tool: Either an MCPToolDefinition instance or a dict in MCP format

        Returns:
            ToolDefinition instance
        """
        # Type mapping from JSON Schema to DeepFabric types
        type_mapping = {
            "string": "str",
            "integer": "int",
            "number": "float",
            "boolean": "bool",
            "array": "list",
            "object": "dict",
        }

        # Parse dict to MCPToolDefinition if needed
        if isinstance(mcp_tool, dict):
            mcp_tool = MCPToolDefinition.model_validate(mcp_tool)

        # Use safe property to handle None input_schema
        input_schema = mcp_tool.input_schema_safe
        parameters = []
        required_params = set(input_schema.required)

        for param_name, param_props in input_schema.properties.items():
            df_type = type_mapping.get(param_props.type, "str")
            default_str = str(param_props.default) if param_props.default is not None else ""

            parameters.append(
                ToolParameter(
                    name=param_name,
                    type=df_type,  # type: ignore[arg-type]
                    description=param_props.description,
                    required=param_name in required_params,
                    default=default_str,
                )
            )

        return cls(
            name=mcp_tool.name,
            description=mcp_tool.description,
            parameters=parameters,
            returns="",  # MCP format doesn't include return description
            category="general",
        )
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class ToolRegistry(BaseModel):
    """Registry of available tools."""

    tools: list[ToolDefinition] = Field(description="List of available tool definitions")

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Get a tool by name."""
        for tool in self.tools:
            if tool.name == name:
                return tool
        return None

    def get_tools_by_category(self, category: str) -> list[ToolDefinition]:
        """Get all tools in a category."""
        return [tool for tool in self.tools if tool.category == category]

    def get_tool_names(self) -> list[str]:
        """Get list of all tool names."""
        return [tool.name for tool in self.tools]

    def to_openai_format(self) -> list[dict[str, Any]]:
        """
        Convert all tools to OpenAI function calling schema format.

        This method is specifically designed for use with HuggingFace TRL's
        SFTTrainer and other training frameworks that require tools to be
        provided in OpenAI function calling format.

        Returns:
            List of tool definitions in OpenAI function calling schema format.
            Each tool includes type="function" and a function object with
            name, description, and parameters.

        Example:
            >>> registry = ToolRegistry(tools=[...])
            >>> trl_tools = registry.to_openai_format()
            >>> # Use in dataset: {"messages": [...], "tools": trl_tools}
        """
        converted = []
        for tool in self.tools:
            converted.append(tool.to_openai())
        return converted
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
# Agent tool-calling schemas
class ToolReasoningStep(BaseModel):
    """A reasoning step that leads to tool selection and parameter construction.

    Captures the planning phase only: why a tool is chosen and how its
    parameters are derived, before any execution happens.
    """

    step_number: int = Field(description="The step number in the tool planning sequence")
    reasoning: str = Field(description="Why this tool is needed at this point")
    selected_tool_name: str = Field(description="Name of the tool being selected")
    parameter_reasoning: str = Field(description="Reasoning for parameter values")
    expected_result: str = Field(description="What the tool should return and how it helps")
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class ToolExecution(BaseModel):
    """Represents actual execution of a tool with reasoning context."""

    # Pydantic v2 config; replaces the deprecated nested `class Config`
    # and matches the model_config style used elsewhere in this module.
    model_config = {"extra": "forbid"}

    function_name: str = Field(min_length=1, description="Name of the function/tool being called")
    arguments: str = Field(
        min_length=2, description="JSON string of arguments passed to the function"
    )
    reasoning: str = Field(min_length=1, description="Brief explanation of why executing now")
    result: str = Field(min_length=1, description="The result returned from the tool execution")

    @field_validator("arguments")
    @classmethod
    def validate_arguments_json(cls, v: str) -> str:
        """Validate that arguments are valid JSON with no null/empty placeholders.

        Empty objects {} are allowed for parameterless tools like list_files().

        Raises:
            ValueError: if the string is not valid JSON, or if any top-level
                value is null or an empty string (placeholder values).
        """
        stripped = v.strip()

        # Parse and validate JSON
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict):
                for key, value in parsed.items():
                    if value is None:
                        raise ValueError(f"Argument '{key}' is null - must provide actual value")
                    if isinstance(value, str) and value == "":
                        raise ValueError(
                            f"Argument '{key}' is empty string - must provide actual value"
                        )
        except json.JSONDecodeError as e:
            raise ValueError(f"Arguments must be valid JSON: {e}") from e

        return v

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Parse arguments JSON string to dict.

        Uses Any for values as function arguments can be strings, numbers, booleans, lists, nested dicts, etc.
        """
        return json.loads(self.arguments)

    def to_tool_call(self, tool_call_id: str | None = None) -> "ToolCall":
        """Convert ToolExecution to a ToolCall for the final dataset.

        Args:
            tool_call_id: The 9-char alphanumeric ID. If None, generates one.

        Returns:
            ToolCall with JSON string arguments (HuggingFace compatible).
        """
        return ToolCall(
            id=tool_call_id or generate_tool_call_id(),
            type="function",
            function=ToolCallFunction(
                name=self.function_name,
                arguments=self.arguments,  # Pass string directly
            ),
        )
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
class PendingToolCall(BaseModel):
    """A tool call request before execution (no result yet).

    Used in AgentStep for the ReAct loop - the LLM generates these,
    then Spin executes them and populates the result separately.
    """

    # Pydantic v2 config; replaces the deprecated nested `class Config`
    # and matches the model_config style used elsewhere in this module.
    model_config = {"extra": "forbid"}

    function_name: str = Field(min_length=1, description="Name of the function/tool to call")
    arguments: str = Field(
        min_length=2, description="JSON string of arguments to pass to the function"
    )
    reasoning: str = Field(min_length=1, description="Brief explanation of why calling this tool")

    @field_validator("arguments")
    @classmethod
    def validate_arguments_json(cls, v: str) -> str:
        """Validate that arguments are valid JSON with no null/empty placeholders.

        Empty objects {} are allowed for parameterless tools like list_files().

        Raises:
            ValueError: if the string is not valid JSON, or if any top-level
                value is null or an empty string (placeholder values).
        """
        stripped = v.strip()

        # Parse and validate JSON
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict):
                for key, value in parsed.items():
                    if value is None:
                        raise ValueError(f"Argument '{key}' is null - must provide actual value")
                    if isinstance(value, str) and value == "":
                        raise ValueError(
                            f"Argument '{key}' is empty string - must provide actual value"
                        )
        except json.JSONDecodeError as e:
            raise ValueError(f"Arguments must be valid JSON: {e}") from e

        return v

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Parse arguments JSON string to dict."""
        return json.loads(self.arguments)

    def to_tool_execution(self, result: str) -> "ToolExecution":
        """Convert to ToolExecution after getting result from Spin.

        Args:
            result: The result from tool execution

        Returns:
            ToolExecution with the result populated
        """
        return ToolExecution(
            function_name=self.function_name,
            arguments=self.arguments,
            reasoning=self.reasoning,
            result=result,
        )
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class AgentStep(BaseModel):
    """A single step in ReAct-style agent reasoning.

    Each step represents one iteration of the think-act-observe loop:
    1. Agent thinks about what to do next (thought)
    2. Agent decides on tool calls for THIS step only
    3. Tools are executed and results observed
    4. Process repeats until is_final=True

    This ensures tool calls are made based on observed results,
    not hallucinated assumptions about what tools will return.
    """

    # Pydantic v2 config; replaces the deprecated nested `class Config`
    # and matches the model_config style used elsewhere in this module.
    model_config = {"extra": "forbid"}

    thought: str = Field(
        min_length=1,
        description="Agent's reasoning about what to do next based on observations so far",
    )
    tool_calls: list["PendingToolCall"] = Field(
        default_factory=list,
        description="Tool calls for THIS step only. Empty if agent is done.",
    )
    is_final: bool = Field(
        default=False,
        description="True if agent has enough information and is ready to respond to user",
    )
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
# Tool calling schemas for conversations that include function calls
class FunctionCall(BaseModel):
    """A function call with arguments.

    NOTE: unlike ToolCallFunction, arguments here is a plain dict rather
    than a JSON string; this shape is used by ToolMessage below.
    """

    name: str = Field(description="The name of the function to call")
    arguments: dict[str, Any] = Field(description="Arguments to pass to the function")
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
def _serialize_arguments(v: dict[str, Any] | str) -> str:
|
|
607
|
+
"""Serialize arguments to JSON string, stripping None values.
|
|
608
|
+
|
|
609
|
+
This ensures consistent schema for HuggingFace datasets (Arrow/Parquet compatibility).
|
|
610
|
+
Accepts both dict and str inputs for backward compatibility.
|
|
611
|
+
"""
|
|
612
|
+
match v:
|
|
613
|
+
case dict():
|
|
614
|
+
cleaned = {k: val for k, val in v.items() if val is not None}
|
|
615
|
+
return json.dumps(cleaned, separators=(",", ":"))
|
|
616
|
+
case str():
|
|
617
|
+
# Validate JSON, strip nulls, re-serialize for consistency
|
|
618
|
+
parsed = json.loads(v)
|
|
619
|
+
cleaned = (
|
|
620
|
+
{k: val for k, val in parsed.items() if val is not None}
|
|
621
|
+
if isinstance(parsed, dict)
|
|
622
|
+
else parsed
|
|
623
|
+
)
|
|
624
|
+
return json.dumps(cleaned, separators=(",", ":"))
|
|
625
|
+
case _:
|
|
626
|
+
raise ValueError(f"arguments must be dict or str, got {type(v)}")
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
# Type alias for JSON-serialized arguments (HuggingFace compatible)
# BeforeValidator runs _serialize_arguments on the raw input before str
# validation, so both dict and JSON-string inputs normalize to a compact
# JSON string with null values stripped.
ArgumentsStr = Annotated[str, BeforeValidator(_serialize_arguments)]
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
class ToolCallFunction(ExcludeNoneBaseModel):
    """Function details within a tool call (DeepFabric Format)."""

    # Pydantic v2 config; replaces the deprecated nested `class Config`
    # and matches the model_config style used elsewhere in this module.
    model_config = {"extra": "forbid"}

    name: str = Field(min_length=1, description="The name of the function to call")
    arguments: ArgumentsStr = Field(description="Arguments as JSON string (HuggingFace compatible)")

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Parse arguments JSON string to dict when needed at runtime."""
        return json.loads(self.arguments)
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
class ToolCall(ExcludeNoneBaseModel):
    """A tool call in DeepFabric Format.

    Implements the DeepFabric Format specification:
    - ID: Exactly 9 alphanumeric characters (A-Z, a-z, 0-9)
    - Type: Always "function"
    - Arguments: JSON string (for HuggingFace Arrow/Parquet compatibility)
    """

    # Pydantic v2 config; replaces the deprecated nested `class Config`
    # and matches the model_config style used elsewhere in this module.
    model_config = {"extra": "forbid"}

    id: str = Field(
        min_length=9,
        max_length=9,
        description="Unique tool call ID - exactly 9 alphanumeric characters",
    )
    type: Literal["function"] = Field(default="function", description="Tool call type")
    function: ToolCallFunction = Field(description="Function call details")

    @field_validator("id")
    @classmethod
    def validate_tool_call_id(cls, v: str) -> str:
        """Validate that ID is exactly 9 alphanumeric characters.

        The length constraints on the field catch wrong lengths; this
        validator additionally rejects non-alphanumeric characters.
        """
        if not TOOL_CALL_ID_PATTERN.match(v):
            raise ValueError(
                f"Tool call ID must be exactly 9 alphanumeric characters (A-Z, a-z, 0-9), got: '{v}'"
            )
        return v
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
# Resolve forward reference for ChatMessage.tool_calls
# (ToolCall is defined after ChatMessage, so the string annotation can
# only be resolved now that the name exists at module level.)
ChatMessage.model_rebuild()
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
class ToolMessage(BaseModel):
    """A message that includes tool/function calling."""

    role: Literal["system", "user", "assistant", "tool"] = Field(
        description="The role of the message sender"
    )
    content: str | None = Field(default=None, description="The text content of the message")
    # Single function call variant
    function_call: FunctionCall | None = Field(
        default=None, description="Function call made by the assistant"
    )
    # Multiple-call variant; note these are FunctionCall (dict arguments),
    # not the DeepFabric-format ToolCall
    tool_calls: list[FunctionCall] | None = Field(
        default=None, description="Multiple tool calls made by the assistant"
    )
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
class ToolConversation(BaseModel):
    """A conversation that may include function/tool calls."""

    # min_length=1: an empty conversation is rejected at validation time.
    messages: list[ToolMessage] = Field(
        description="List of messages that may include tool calls", min_length=1
    )
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
# Chain of Thought schemas
class FreeTextCoT(BaseModel):
    """Chain of Thought dataset in free-text format (GSM8K style)."""

    # All three fields are required; the reasoning is a single prose string
    # rather than a structured trace (contrast with StructuredCoT below).
    question: str = Field(description="The question or problem to solve")
    chain_of_thought: str = Field(description="Natural language reasoning explanation")
    final_answer: str = Field(description="The definitive answer to the question")
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
class StructuredCoT(BaseModel):
    """Chain of Thought dataset with structured reasoning trace."""

    # min_length=1 on both list fields: an empty conversation or an empty
    # reasoning trace fails validation.
    messages: list[ChatMessage] = Field(description="Conversation messages", min_length=1)
    reasoning_trace: list[ReasoningStep] = Field(
        description="Structured reasoning steps", min_length=1
    )
    final_answer: str = Field(description="The definitive answer to the question")
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
class HybridCoT(BaseModel):
    """Chain of Thought dataset with both free-text and structured reasoning."""

    question: str = Field(description="The question or problem to solve")
    # Free-text reasoning (as in FreeTextCoT) ...
    chain_of_thought: str = Field(description="Natural language reasoning explanation")
    # ... plus a non-empty structured trace (as in StructuredCoT); both forms
    # are required for this hybrid format.
    reasoning_trace: list[ReasoningStep] = Field(
        description="Structured reasoning steps", min_length=1
    )
    final_answer: str = Field(description="The definitive answer to the question")
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
# Mathematical variants with numerical-only final answers
|
|
737
|
+
class MathematicalAnswerMixin:
    """Mixin class providing mathematical answer formatting and validation.

    Accepts plain integers/decimals (optionally with thousands separators),
    negatives, and scientific notation. Normalization removes thousands
    separators, preserves scientific notation verbatim, and rounds values
    with three or more decimal places to two (half-up).
    """

    @classmethod
    def _format_mathematical_answer(cls, v: str) -> str:
        """Format mathematical answers with strict consistency rules.

        Raises:
            ValueError: if the (first) answer is not a recognizable number.
        """
        candidate = v.strip()

        # Models occasionally emit several answers ("2, 3"); keep the first.
        # A ", " (with space) never appears inside a thousands-separated
        # number, so this cannot split "1,234".
        if ", " in candidate:
            candidate = candidate.partition(", ")[0].strip()

        # Either grouped digits ("1,234,567.89") or a plain number, each with
        # an optional exponent part.
        numeric_re = (
            r"^-?\d{1,3}(,\d{3})*(\.\d+)?([eE][+-]?\d+)?$"
            r"|^-?\d+(\.\d+)?([eE][+-]?\d+)?$"
        )
        if re.match(numeric_re, candidate) is None:
            raise ValueError(f"final_answer must be numerical, got: {v}")

        # Strip thousands separators before any further processing.
        normalized = candidate.replace(",", "")

        # Scientific notation is passed through untouched.
        if cls._is_scientific_notation(normalized):
            return normalized

        # Round precision artifacts (>= 3 decimal places) to two places,
        # half-up; shorter fractions are kept exactly as written.
        _, dot, fraction = normalized.partition(".")
        if dot and len(fraction) >= 3:
            quantized = Decimal(normalized).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
            normalized = str(quantized)

        return normalized

    @staticmethod
    def _is_scientific_notation(value: str) -> bool:
        """Detect scientific notation."""
        return "e" in value.lower()
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
# Capability Models for Composable Conversation Schema
class ReasoningTrace(BaseModel):
    """Reasoning capability - present when conversation_type='chain_of_thought'."""

    style: Literal["freetext", "agent"] = Field(
        description="The reasoning style: freetext (natural language) or agent (structured step-by-step for tool-calling)"
    )
    # Union type: the style/content pairing (str for freetext, list of
    # ReasoningStep for agent) is enforced by Conversation's
    # validate_reasoning_trace validator, not by this model itself.
    content: str | list[ReasoningStep] = Field(
        description="Reasoning content - string for freetext, list of ReasoningStep for agent"
    )

    class Config:
        # Reject unknown keys so malformed generator output fails loudly.
        extra = "forbid"
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
class ToolContext(BaseModel):
    """Tool execution history - present when tools are used.

    Note: available_tools has been removed as it was redundant with the
    top-level 'tools' field in Conversation. Use 'tools' for the OpenAI-format
    tool definitions needed by chat templates.
    """

    # default_factory (not a shared default) so each instance gets its own
    # list; empty is explicitly allowed per the description below.
    executions: list[ToolExecution] = Field(
        default_factory=list,
        description="Tool executions performed during the conversation (may be empty if agent answered without tools)",
    )

    class Config:
        extra = "forbid"
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
class AgentContext(BaseModel):
    """Agent capability - present when agent_mode is enabled.

    Conversation.validate_agent_requires_tools enforces that this capability
    only appears together with a tool_context.
    """

    mode: Literal["single_turn", "multi_turn"] = Field(
        description="Agent interaction mode: single_turn for one-shot tool use, multi_turn for extended conversations"
    )
    # Both traces are optional free-text annotations around tool usage.
    planning_trace: str | None = Field(
        default=None, description="Agent's planning and reasoning about tool usage strategy"
    )
    execution_summary: str | None = Field(
        default=None, description="Summary of agent's execution and results interpretation"
    )

    class Config:
        extra = "forbid"
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
class Conversation(ExcludeNoneBaseModel):
    """
    Unified conversation schema with optional capability fields.

    This composable schema supports various combinations:
    - Basic conversation: just messages
    - With reasoning: messages + reasoning capability
    - With tools: messages + tool_context capability
    - Agent mode: messages + tool_context + agent_context capabilities
    - Full combination: all capabilities enabled

    The schema validates that capability combinations are consistent
    (e.g., agent_context requires tool_context).
    """

    # Core payload: at least one message is required.
    messages: list[ChatMessage] = Field(description="Core conversation messages", min_length=1)
    metadata: MetadataDict = Field(
        default=None, description="Conversation metadata (topic, domain, etc.)"
    )

    # Optional capability fields - use empty strings/dicts instead of None for OpenAI compatibility
    reasoning: ReasoningTrace | None = Field(
        default=None, description="Reasoning capability - chain of thought traces"
    )
    # NOTE: tool_context is declared BEFORE agent_context on purpose — the
    # agent_context validator below reads it from info.data, which only
    # contains fields validated earlier in declaration order.
    tool_context: ToolContext | None = Field(
        default=None, description="Tool capability - available tools and executions"
    )
    tools: list[dict[str, Any]] | None = Field(
        default=None,
        description="OpenAI-compatible tool definitions (populated from tool_context for training)",
    )
    agent_context: AgentContext | None = Field(
        default=None, description="Agent capability - agentic behavior and planning"
    )
    structured_data: MetadataDict = Field(
        default=None, description="Additional structured data for specific formats"
    )

    # Fields for backward compatibility and specific use cases
    question: str = Field(default="", description="Original question (useful for Q&A formats)")
    final_answer: str = Field(default="", description="Final answer (useful for reasoning formats)")

    @field_validator("reasoning")
    @classmethod
    def validate_reasoning_trace(cls, v: ReasoningTrace | None) -> ReasoningTrace | None:
        """Validate reasoning trace content matches style.

        Raises:
            ValueError: if the content's runtime type does not match the
                declared style (str for 'freetext', list for 'agent').
        """
        if v is None:
            return None

        if v.style == "agent" and not isinstance(v.content, list):
            msg = (
                f"Reasoning style '{v.style}' requires list of ReasoningStep, got {type(v.content)}"
            )
            raise ValueError(msg)
        if v.style == "freetext" and not isinstance(v.content, str):
            msg = f"Reasoning style 'freetext' requires string content, got {type(v.content)}"
            raise ValueError(msg)

        return v

    @field_validator("agent_context")
    @classmethod
    def validate_agent_requires_tools(cls, v: AgentContext | None, info) -> AgentContext | None:
        """Validate that agent_context requires tool_context.

        Raises:
            ValueError: if agent_context is set while tool_context is None.
        """
        if v is not None:
            # Access tool_context from the model data
            tool_context = info.data.get("tool_context")
            if tool_context is None:
                msg = "agent_context requires tool_context to be present"
                raise ValueError(msg)
        return v

    class Config:
        # Reject unknown keys, and mirror that in the emitted JSON schema.
        extra = "forbid"
        json_schema_extra = {"additionalProperties": False}
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
class FormattedSample(BaseModel):
    """Output from chat template formatting.

    This model represents a dataset sample after applying a model-specific
    chat template via the Dataset.format() method. The text field contains
    the fully formatted prompt ready for training or inference.

    Attributes:
        text: Formatted text with chat template applied (e.g., ChatML, Llama format)

    Example:
        >>> sample = FormattedSample(text="<|im_start|>user\\nHello<|im_end|>...")
        >>> sample.text[:20]
        '<|im_start|>user\\nHe'
    """

    # min_length=1: an empty formatted sample is a formatting failure and is
    # rejected rather than silently emitted.
    text: str = Field(..., min_length=1, description="Chat template formatted text")

    class Config:
        extra = "forbid"
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
# Unified conversation schema mapping
# Both registered types intentionally resolve to the same unified
# Conversation model; capability fields (reasoning, tool_context, ...) are
# populated during generation, not selected by schema type.
CONVERSATION_SCHEMAS = {
    "basic": Conversation,
    "chain_of_thought": Conversation,
}
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
def get_conversation_schema(
    conversation_type: str = "basic",
) -> type[Conversation]:
    """Get the appropriate schema for a conversation configuration.

    With the unified Conversation schema, this now always returns Conversation.
    The schema's capability fields (reasoning, tool_context, agent_context) are
    populated based on the configuration during generation.

    Args:
        conversation_type: Type of conversation (basic, chain_of_thought)

    Returns:
        Conversation schema (unified for all types)

    Raises:
        ValueError: If conversation_type is not supported
    """
    # Every registered type currently maps to the unified Conversation
    # schema; capabilities are filled in during generation based on config.
    schema = CONVERSATION_SCHEMAS.get(conversation_type)
    if schema is None:
        valid_types = ", ".join(CONVERSATION_SCHEMAS.keys())
        msg = f"Unsupported conversation type: {conversation_type}. Valid types: {valid_types}"
        raise ValueError(msg)
    return schema
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
# Topic generation schemas for tree and graph (needed by other modules)
|
|
965
|
+
class TopicList(BaseModel):
    """A list of subtopics for tree/graph generation."""

    # A valid expansion step must propose at least one subtopic name.
    subtopics: list[str] = Field(min_length=1, description="List of subtopic names")
|
|
972
|
+
|
|
973
|
+
|
|
974
|
+
class TopicNode(BaseModel):
    """A topic node with subtopics for graph generation."""

    topic: str = Field(description="The topic name")
    # default_factory keeps each node's list independent; an empty list
    # simply means the node is a leaf.
    subtopics: list[str] = Field(default_factory=list, description="List of subtopic names")
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
class GraphSubtopic(BaseModel):
    """A subtopic with connections for graph generation."""

    topic: str = Field(description="The subtopic name")
    # Integer IDs referencing already-created graph nodes; default is a
    # fresh empty list per instance.
    connections: list[int] = Field(
        description="List of existing node IDs to connect to, empty list if none",
        default_factory=list,
    )
|
|
992
|
+
|
|
993
|
+
|
|
994
|
+
class GraphSubtopics(BaseModel):
    """List of subtopics with connections for graph generation."""

    # The generator must return at least one connected subtopic.
    subtopics: list[GraphSubtopic] = Field(
        min_length=1, description="List of subtopics with their connections"
    )
|