caudate-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/__init__.py +5 -0
- api/anthropic_compat.py +1518 -0
- api/artifact_viewer.py +366 -0
- api/caudate_middleware.py +618 -0
- api/forge_bootstrapper_routes.py +377 -0
- api/forge_routes.py +630 -0
- api/forge_system_routes.py +294 -0
- api/openai_compat.py +1993 -0
- api/server.py +667 -0
- api/storyboard_page.py +677 -0
- caudate_cli-0.1.0.dist-info/METADATA +354 -0
- caudate_cli-0.1.0.dist-info/RECORD +153 -0
- caudate_cli-0.1.0.dist-info/WHEEL +5 -0
- caudate_cli-0.1.0.dist-info/entry_points.txt +2 -0
- caudate_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- caudate_cli-0.1.0.dist-info/top_level.txt +14 -0
- cognos_mcp/__init__.py +4 -0
- cognos_mcp/bridge.py +41 -0
- cognos_mcp/client.py +70 -0
- cognos_mcp/config.py +49 -0
- cognos_mcp/server.py +66 -0
- config.py +82 -0
- core/__init__.py +0 -0
- core/agent.py +468 -0
- core/agentic_loop.py +731 -0
- core/anthropic_auth.py +91 -0
- core/background.py +113 -0
- core/banner.py +134 -0
- core/bootstrap.py +292 -0
- core/citations.py +131 -0
- core/compaction.py +109 -0
- core/constitution.py +198 -0
- core/diff_viewer.py +87 -0
- core/export.py +85 -0
- core/file_refs.py +119 -0
- core/files.py +199 -0
- core/hooks.py +209 -0
- core/image.py +599 -0
- core/input.py +91 -0
- core/loop.py +238 -0
- core/memory_md.py +147 -0
- core/notifications.py +99 -0
- core/ownership.py +181 -0
- core/paste.py +81 -0
- core/permissions.py +210 -0
- core/plan_mode.py +215 -0
- core/sandbox_prompt.py +185 -0
- core/scheduler.py +195 -0
- core/schemas.py +202 -0
- core/session.py +90 -0
- core/settings.py +132 -0
- core/skills.py +398 -0
- core/slash_commands.py +977 -0
- core/statusline.py +61 -0
- core/subagent.py +300 -0
- core/thinking.py +50 -0
- core/updater.py +122 -0
- core/usage.py +109 -0
- core/worktree.py +93 -0
- execution/__init__.py +0 -0
- execution/executor.py +329 -0
- execution/plugins.py +108 -0
- execution/tools/__init__.py +0 -0
- execution/tools/agent_tool.py +107 -0
- execution/tools/agentic_tool.py +297 -0
- execution/tools/artifact_tool.py +191 -0
- execution/tools/ask_user_question_tool.py +137 -0
- execution/tools/base.py +81 -0
- execution/tools/calculator_tool.py +137 -0
- execution/tools/cognos_card_tool.py +124 -0
- execution/tools/cron_tool.py +215 -0
- execution/tools/datetime_tool.py +215 -0
- execution/tools/describe_image_tool.py +161 -0
- execution/tools/draw_tool.py +164 -0
- execution/tools/edit_image_tool.py +262 -0
- execution/tools/edit_tool.py +245 -0
- execution/tools/file_tool.py +90 -0
- execution/tools/find_anywhere_tool.py +255 -0
- execution/tools/forge_feature_tools.py +377 -0
- execution/tools/glob_tool.py +59 -0
- execution/tools/grep_tool.py +89 -0
- execution/tools/http_request_tool.py +224 -0
- execution/tools/load_skill_tool.py +104 -0
- execution/tools/longcat_avatar_tool.py +384 -0
- execution/tools/mcp_tool.py +100 -0
- execution/tools/notebook_tool.py +279 -0
- execution/tools/openapi_tool.py +440 -0
- execution/tools/plan_mode_tool.py +95 -0
- execution/tools/push_notification_tool.py +157 -0
- execution/tools/python_tool.py +61 -0
- execution/tools/respond_tool.py +40 -0
- execution/tools/sandbox_tool.py +378 -0
- execution/tools/search_tool.py +153 -0
- execution/tools/semantic_search_tool.py +106 -0
- execution/tools/shell_tool.py +283 -0
- execution/tools/speak_tool.py +134 -0
- execution/tools/storyboard_tool.py +727 -0
- execution/tools/system_info_tool.py +212 -0
- execution/tools/task_tool.py +323 -0
- execution/tools/think_tool.py +49 -0
- execution/tools/transcribe_audio_tool.py +86 -0
- execution/tools/update_memory_tool.py +92 -0
- execution/tools/web_fetch_tool.py +82 -0
- execution/tools/worktree_tool.py +174 -0
- llm/__init__.py +0 -0
- llm/fallback.py +116 -0
- llm/models.py +320 -0
- llm/provider.py +1356 -0
- llm/router.py +373 -0
- main.py +1889 -0
- memory/__init__.py +0 -0
- memory/episodic.py +99 -0
- memory/procedural.py +145 -0
- memory/semantic.py +71 -0
- memory/working.py +64 -0
- nn/__init__.py +43 -0
- nn/auto_evolve.py +245 -0
- nn/caudate.py +136 -0
- nn/config.py +141 -0
- nn/consolidator.py +81 -0
- nn/data.py +1635 -0
- nn/encoder.py +258 -0
- nn/forge_advisor.py +303 -0
- nn/format.py +235 -0
- nn/heads.py +432 -0
- nn/observer.py +994 -0
- nn/policy.py +214 -0
- nn/runtime.py +343 -0
- nn/scorer.py +175 -0
- nn/trainer.py +515 -0
- nn/vision.py +352 -0
- personality/__init__.py +23 -0
- personality/engine.py +129 -0
- personality/identity.py +144 -0
- personality/inner_voice.py +100 -0
- personality/mood.py +205 -0
- planning/__init__.py +0 -0
- planning/dev_server.py +221 -0
- planning/forge_models.py +718 -0
- planning/orchestrator.py +1363 -0
- planning/planner.py +451 -0
- planning/task_graph.py +61 -0
- reflection/__init__.py +0 -0
- reflection/meta_learner.py +156 -0
- reflection/reflector.py +127 -0
- ui/__init__.py +5 -0
- ui/display.py +88 -0
- voice/__init__.py +0 -0
- voice/conversation.py +125 -0
- voice/listener.py +111 -0
- voice/speaker.py +59 -0
- voice/stt.py +126 -0
- voice/tts.py +214 -0
nn/data.py
ADDED
|
@@ -0,0 +1,1635 @@
|
|
|
1
|
+
"""Dataset + replay buffer for the cognitive controller.
|
|
2
|
+
|
|
3
|
+
Two data sources:
|
|
4
|
+
|
|
5
|
+
1. **Disk corpus** — every saved Session under `data/sessions/*.json`.
|
|
6
|
+
Each tool call inside a session becomes one (state, action, reward)
|
|
7
|
+
example. Reward defaults to 0.5 unless the session has an explicit
|
|
8
|
+
reflection score attached.
|
|
9
|
+
|
|
10
|
+
2. **Live replay buffer** — circular in-memory buffer fed by the agent
|
|
11
|
+
during use. Lets the controller train online from current sessions
|
|
12
|
+
before they're saved to disk.
|
|
13
|
+
|
|
14
|
+
ToolVocab maps the dynamic set of registered tools to stable integer
|
|
15
|
+
ids that survive across runs. The vocab grows append-only — adding a
|
|
16
|
+
new tool doesn't invalidate older checkpoints; it just adds a row to
|
|
17
|
+
the embedding table at training time.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import collections
|
|
23
|
+
import json
|
|
24
|
+
import logging
|
|
25
|
+
import random
|
|
26
|
+
import re
|
|
27
|
+
from dataclasses import dataclass, field
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
from typing import Any, Iterator
|
|
30
|
+
|
|
31
|
+
import torch
|
|
32
|
+
|
|
33
|
+
from nn.config import NNConfig
|
|
34
|
+
from nn.format import ChatMessage, Conversation, ToolCall, ToolDef
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ---------------------------------------------------------------------
|
|
40
|
+
# Tool vocab
|
|
41
|
+
# ---------------------------------------------------------------------
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ToolVocab:
    """Bidirectional id ↔ name map for tool tokens.

    Four special ids are reserved upfront:
        <pad>      — padding for batched sequences
        <unk>      — placeholder for missing/unmapped labels at inference
        <bos>      — start-of-sequence
        <no_tool>  — explicit "no tool was called" — a real action class
                     (distinct from <unk>: a turn where the assistant
                     correctly chose to just answer without any tool)

    Keeping <no_tool> separate from <unk> matters for training: ~40% of
    real conversation turns don't call a tool, and that's a valid
    decision the model should learn to *predict*, not a degenerate
    fallback class.

    Ids are handed out append-only: ``add`` never renumbers or reuses an
    existing id, and ``len(vocab)`` is one past the highest id in use.
    """

    SPECIAL = {
        "<pad>": 0,
        "<unk>": 1,
        "<bos>": 2,
        "<no_tool>": 3,
    }

    def __init__(self):
        self._id2name: dict[int, str] = {v: k for k, v in self.SPECIAL.items()}
        self._name2id: dict[str, int] = dict(self.SPECIAL)
        self._next_id = max(self.SPECIAL.values()) + 1

    def add(self, name: str) -> int:
        """Return the id of *name*, allocating the next free id if it is new."""
        if name in self._name2id:
            return self._name2id[name]
        idx = self._next_id
        self._next_id += 1
        self._name2id[name] = idx
        self._id2name[idx] = name
        return idx

    def get(self, name: str) -> int:
        """Return the id of *name*, or the ``<unk>`` id if it was never added."""
        return self._name2id.get(name, self.SPECIAL["<unk>"])

    def name(self, idx: int) -> str:
        """Return the name stored at *idx*, or ``"<unk>"`` for unknown ids."""
        return self._id2name.get(idx, "<unk>")

    def __len__(self) -> int:
        """One past the highest allocated id (special ids included)."""
        return self._next_id

    def to_dict(self) -> dict[str, int]:
        """Return a copy of the name → id mapping, specials included."""
        return dict(self._name2id)

    @classmethod
    def from_dict(cls, data: dict[str, int]) -> "ToolVocab":
        """Rebuild a vocab from a ``to_dict()``-style mapping.

        Special names in *data* are ignored; the class constants always
        win. Ids are coerced with ``int()``, so a mapping whose ids were
        round-tripped as strings still loads.

        A non-special name sitting on a reserved special id, or on a
        negative id, is corrupt data. Such entries are dropped rather than
        allowed to overwrite the reserved slot's name.
        """
        v = cls()
        reserved = set(v.SPECIAL.values())
        for name, raw_idx in data.items():
            if name in v.SPECIAL:
                continue
            idx = int(raw_idx)
            if idx < 0 or idx in reserved:
                continue
            v._name2id[name] = idx
            v._id2name[idx] = name
            v._next_id = max(v._next_id, idx + 1)
        return v
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class SourceVocab:
    """Bidirectional id ↔ name map for teacher-model sources.

    Phase 2 of CAUDATE_EVOLUTION.md: every ConversationSample carries a
    `model_source` string ("anthropic/claude-opus-4-7", "ollama/...",
    "<unknown>", etc.). The model conditions on that source via a
    learned embedding lookup. This vocab maps the names to stable ids,
    capped at `cfg.source_vocab_size` so the embedding table stays
    finite. Overflow (more than the cap) folds back to <unknown>.

    Reserved slot 0 = <unknown> = the zero-bias baseline. Legacy
    untagged samples and any model we haven't seen before share this
    slot.
    """

    SPECIAL = {"<unknown>": 0}
    DEFAULT_CAP = 16

    def __init__(self, cap: int = DEFAULT_CAP):
        # At least one slot is always kept so <unknown> exists.
        self.cap = max(1, int(cap))
        self._id2name: dict[int, str] = {v: k for k, v in self.SPECIAL.items()}
        self._name2id: dict[str, int] = dict(self.SPECIAL)
        self._next_id = max(self.SPECIAL.values()) + 1

    def add(self, name: str) -> int:
        """Return the id of *name*, allocating one if there is room.

        Non-string or empty names, and new names arriving once the cap is
        reached, map to the ``<unknown>`` id instead of raising.
        """
        if not isinstance(name, str) or not name:
            return self.SPECIAL["<unknown>"]
        if name in self._name2id:
            return self._name2id[name]
        if self._next_id >= self.cap:
            # Overflow: silently fold to <unknown>. Better than
            # blowing up; the embedding table can't grow at runtime.
            return self.SPECIAL["<unknown>"]
        idx = self._next_id
        self._next_id += 1
        self._name2id[name] = idx
        self._id2name[idx] = name
        return idx

    def get(self, name: str) -> int:
        """Return the id of *name*, or the ``<unknown>`` id if absent."""
        return self._name2id.get(name, self.SPECIAL["<unknown>"])

    def name(self, idx: int) -> str:
        """Return the name stored at *idx*, or ``"<unknown>"``."""
        return self._id2name.get(idx, "<unknown>")

    def __len__(self) -> int:
        """One past the highest allocated id (never exceeds ``cap``)."""
        return self._next_id

    def to_dict(self) -> dict[str, int]:
        """Return a copy of the name → id mapping, ``<unknown>`` included."""
        return dict(self._name2id)

    @classmethod
    def from_dict(cls, data: dict[str, int], cap: int = DEFAULT_CAP) -> "SourceVocab":
        """Rebuild a vocab from a ``to_dict()``-style mapping under *cap*.

        Entries are dropped when:
        - their id is at or beyond the (possibly smaller) new cap;
        - their id is negative; or
        - their id is a reserved special slot, because accepting one would
          overwrite the name of ``<unknown>`` in the id → name table.
        """
        v = cls(cap=cap)
        reserved = set(v.SPECIAL.values())
        for name, raw_idx in data.items():
            if name in v.SPECIAL:
                continue
            idx = int(raw_idx)
            if idx >= v.cap:
                continue  # truncate any persisted ids beyond the new cap
            if idx < 0 or idx in reserved:
                continue  # corrupt entry: would clobber a reserved slot
            v._name2id[name] = idx
            v._id2name[idx] = name
            v._next_id = max(v._next_id, idx + 1)
        return v
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ---------------------------------------------------------------------
|
|
172
|
+
# Sample structure — ConversationSample is THE training row.
|
|
173
|
+
# ---------------------------------------------------------------------
|
|
174
|
+
# Standard chat-tool-call schema (OpenAI shape) — same format used by
|
|
175
|
+
# every public function-calling dataset (xLAM, ToolBench, ToolAlpaca,
|
|
176
|
+
# Gorilla, API-Bank). Cognos sessions, external HuggingFace datasets,
|
|
177
|
+
# and the agent's live replay buffer all produce this single type, so
|
|
178
|
+
# the training path doesn't need to know where samples came from.
|
|
179
|
+
#
|
|
180
|
+
# Layout:
|
|
181
|
+
# - conversation: list[ChatMessage] ← the role-tagged turn list
|
|
182
|
+
# - tools: list[ToolDef] ← what was available to call
|
|
183
|
+
# - target_tool: str ← name of the tool actually called
|
|
184
|
+
# - target_arguments: str ← JSON-encoded args (reserved for
|
|
185
|
+
# a future generative arg head)
|
|
186
|
+
# - target_tool_call_index: int ← which tool_call within the
|
|
187
|
+
# target assistant message
|
|
188
|
+
# - Cognos-specific signals (mood, model_source, tier/think/value, the
|
|
189
|
+
# 14 optional D-heads) layer on top. External data leaves them at
|
|
190
|
+
# their default values; Cognos sessions populate them.
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@dataclass
class ConversationSample:
    """A single training row in the standard chat-tool-call layout.

    ``conversation`` plus ``tools`` is exactly the payload an OpenAI-style
    function-calling dataset supplies. The ``target_*`` fields name the
    tool the assistant picked next.

    Everything in the "Cognos signals" section is optional extra
    supervision. External data leaves those fields at their defaults
    (``None`` for the optional heads), so the trainer's optional_target
    machinery skips those heads.
    """

    # ── Chat payload ────────────────────────────────────────────────
    conversation: list[ChatMessage] = field(default_factory=list)
    tools: list[ToolDef] = field(default_factory=list)

    # ── Supervision targets ─────────────────────────────────────────
    # Tool the assistant invoked at the target turn. A direct answer is
    # encoded as "<no_tool>" or the empty string; the collate path is
    # expected to normalise both through the tool vocab.
    target_tool: str = "<no_tool>"
    # JSON-encoded call arguments (reserved for Phase 2 / an argument head).
    target_arguments: str = ""
    # Position of the target call inside the assistant message's tool_calls.
    target_tool_call_index: int = 0

    # ── Cognos signals (optional; external data keeps the defaults) ──
    mood: list[float] = field(default_factory=lambda: [0.5] * 4)
    image_paths: list[str] = field(default_factory=list)
    model_source: str = "<unknown>"
    surprise: float = 0.5
    target_tier: int = 0
    target_think: float = 0.0
    target_value: float = 0.0
    # Optional D-heads + Tier-1/2/4b heads; None means "no label".
    target_memory_write: float | None = None
    target_cache_hit: float | None = None
    target_permission: float | None = None
    target_refusal: float | None = None
    target_code_response: float | None = None
    target_stall: float | None = None
    target_difficulty: int | None = None
    target_stop_iter: float | None = None
    target_compaction: float | None = None
    target_latency_s: float | None = None
    target_token_budget: float | None = None
    target_mood_pred: list[float] | None = None
    target_subagent_spawn: float | None = None
    target_reward_model: float | None = None
    target_feature_success: float | None = None

    # ── Construction helpers ────────────────────────────────────────

    @classmethod
    def from_conversation(
        cls,
        conv: Conversation,
        prefix_len: int,
        tool_call: ToolCall,
        tool_call_index: int = 0,
        **cognos_signals: Any,
    ) -> "ConversationSample":
        """Build a sample whose context is ``conv.messages[:prefix_len]``.

        Args:
            conv: source conversation; its ``tools`` list is copied as-is.
            prefix_len: number of messages preceding the assistant message
                that emitted *tool_call*. Only that prefix is kept as the
                context the model sees.
            tool_call: the call to predict. A falsy ``name`` becomes
                ``"<no_tool>"``; falsy ``arguments`` become ``""``.
            tool_call_index: position of *tool_call* within its message.
            **cognos_signals: any other dataclass fields (mood,
                model_source, target_value, ...). They are forwarded
                unchanged, so the caller decides which ones to set.
        """
        context_messages = list(conv.messages[:prefix_len])
        available_tools = list(conv.tools)
        chosen_tool = tool_call.name if tool_call.name else "<no_tool>"
        chosen_args = tool_call.arguments if tool_call.arguments else ""
        return cls(
            conversation=context_messages,
            tools=available_tools,
            target_tool=chosen_tool,
            target_arguments=chosen_args,
            target_tool_call_index=tool_call_index,
            **cognos_signals,
        )
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# ---------------------------------------------------------------------
|
|
273
|
+
# Conversion helpers used by collate & every loader
|
|
274
|
+
# ---------------------------------------------------------------------
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def conversation_to_strings(conv: list[ChatMessage]) -> list[str]:
    """Render chat messages as the role-prefixed strings the StateEncoder reads.

    Rendering rules (every rendered entry is capped at 400 characters of
    body text):

    - ``role == "tool"`` → ``"tool {name}: {content}"``. The label falls
      back to ``"tool"`` when the message has no name; tool messages with
      empty content are omitted.
    - any message with tool calls → ``"{role}: {content} calls: a(..), b(..)"``.
      Each call's arguments are flattened to one line and cut to 60
      characters; the content prefix is dropped when empty.
    - everything else → ``"{role}: {content}"``, omitted when the stripped
      content is empty.

    This is the boundary between the structured chat schema upstream and
    the text-only encoder API downstream.
    """
    rendered: list[str] = []
    for msg in conv:
        body = (msg.content or "").strip()

        if msg.role == "tool":
            if body:
                rendered.append(f"tool {msg.name or 'tool'}: {body[:400]}")
            continue

        if msg.tool_calls:
            pieces: list[str] = []
            for call in msg.tool_calls:
                preview = (call.arguments or "").replace("\n", " ")[:60]
                pieces.append(f"{call.name}({preview})")
            lead = f"{body} " if body else ""
            text = f"{lead}calls: {', '.join(pieces)}"
            rendered.append(f"{msg.role}: {text[:400]}")
            continue

        if body:
            rendered.append(f"{msg.role}: {body[:400]}")
    return rendered
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def conversation_tool_history(conv: list[ChatMessage]) -> list[str]:
    """Return the names of tools called by assistant turns, in call order.

    Only ``role == "assistant"`` messages contribute, and calls with an
    empty or ``None`` name are skipped.

    A message whose ``tool_calls`` is ``None`` is treated as having no
    calls. ``conversation_to_strings`` guards the same field with a
    truthiness check; without this guard, iterating ``None`` would raise
    ``TypeError``.
    """
    return [
        call.name
        for msg in conv
        if msg.role == "assistant"
        for call in (msg.tool_calls or ())
        if call.name
    ]
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
# ---------------------------------------------------------------------
|
|
325
|
+
# Cognos-session loader — reads data/sessions/*.json into ConversationSamples
|
|
326
|
+
# ---------------------------------------------------------------------
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def load_corpus_from_sessions(
    sessions_dir: Path = Path("data/sessions"),
) -> list[ConversationSample]:
    """Convert each saved Cognos session into ConversationSample(s).

    Sessions are stored as OpenAI-shape JSON. We emit one
    ConversationSample per named assistant tool call, with the
    prefix-of-messages-before-the-call as context. Reward semantics:
    metadata.reflection_scores[call_id] → target_value, else 0.5.

    Sessions that recorded their tool registry under metadata.tools
    (falling back to a top-level ``tools`` key) have that list carried
    through; absent tools means empty (the open-vocab head still trains,
    just without negative candidates).

    Loading is best-effort. A session file is skipped when it:
    - cannot be read;
    - is not valid UTF-8 JSON;
    - has a top level that is not an object; or
    - cannot be parsed into a Conversation.

    Non-numeric ``reflection_scores`` / ``last_tier`` / ``thinking_used``
    values fall back to their defaults instead of aborting the whole load.

    Returns:
        list of ConversationSample; empty when *sessions_dir* is missing.
    """

    def _as_dict(value: Any) -> dict:
        # Metadata blocks that aren't objects are treated as absent.
        return value if isinstance(value, dict) else {}

    def _as_float(value: Any, default: float) -> float:
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _as_int(value: Any, default: int) -> int:
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    out: list[ConversationSample] = []
    if not sessions_dir.exists():
        return out

    for path in sorted(sessions_dir.glob("*.json")):
        try:
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            data = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue
        if not isinstance(data, dict):
            continue

        metadata = _as_dict(data.get("metadata"))
        scores = _as_dict(metadata.get("reflection_scores"))
        raw_tools = metadata.get("tools") or data.get("tools") or []
        try:
            conv = Conversation.from_dict(
                {"messages": data.get("messages") or [], "tools": raw_tools},
            )
        except Exception:
            continue

        # Session-level signals are identical for every call in the file.
        model_source = metadata.get("model_source", "<unknown>")
        tier = _as_int(metadata.get("last_tier", 0), 0)
        think = _as_float(metadata.get("thinking_used", 0.0), 0.0)

        for msg_idx, msg in enumerate(conv.messages):
            if msg.role != "assistant" or not msg.tool_calls:
                continue
            for call_idx, tc in enumerate(msg.tool_calls):
                if not tc.name:
                    continue
                score = _as_float(scores.get(tc.id, 0.5), 0.5) if tc.id else 0.5
                out.append(ConversationSample.from_conversation(
                    conv=conv,
                    prefix_len=msg_idx,
                    tool_call=tc,
                    tool_call_index=call_idx,
                    model_source=model_source,
                    target_tier=tier,
                    target_think=think,
                    target_value=score,
                ))
    return out
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
# ---------------------------------------------------------------------
|
|
391
|
+
# OpenAI-format JSONL loader (external datasets: xLAM, ToolBench, etc.)
|
|
392
|
+
# ---------------------------------------------------------------------
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def load_corpus_from_oai_jsonl(
    path: Path,
    model_source: str = "<external>",
    max_rows: int | None = None,
) -> list[ConversationSample]:
    """Load a JSONL file in standard chat-tool-call format.

    One line per conversation. Each conversation may contain multiple
    assistant tool calls; we emit one ConversationSample per tool call,
    with the prefix-of-messages-before-the-call as context.

    The standard format works for: xLAM, ToolBench, ToolAlpaca, Gorilla,
    API-Bank, and any OpenAI-fine-tuning-shaped corpus. Some datasets
    use ``conversations`` instead of ``messages`` for the turn list —
    ``Conversation.from_dict`` handles both.

    Rows carrying a row-level ``target_tool`` (cognos-toolbox
    amplified/replay rows) instead yield a single sample whose context is
    the full message list. Those rows are tagged with
    ``meta.teacher_model`` when present, else *model_source*.

    Malformed rows are skipped, not raised — external corpora are noisy
    and one bad row must not abort the whole training pass. That covers:
    - blank lines;
    - undecodable JSON;
    - non-object rows;
    - rows ``Conversation.from_dict`` rejects;
    - rows with no messages; and
    - invalid UTF-8 bytes. These are replaced with U+FFFD instead of
      raising mid-file.

    Args:
        path: JSONL file path.
        model_source: tagged on every emitted sample so the
            source-embedding (Phase 2 of CAUDATE_EVOLUTION) can bias
            predictions for external-data turns.
        max_rows: cap on accepted (non-skipped) rows for quick smoke
            runs; None = read everything.

    Returns:
        list of ConversationSample, one per assistant tool call.
    """
    out: list[ConversationSample] = []
    if not path.exists():
        logger.warning("oai jsonl not found: %s", path)
        return out

    n_rows = 0
    n_calls = 0
    # errors="replace": with the default strict handler, one invalid byte
    # sequence raises UnicodeDecodeError during iteration and aborts the
    # whole load. Replaced characters either survive inside JSON strings
    # or make that single row fail json.loads and be skipped.
    with path.open(encoding="utf-8", errors="replace") as f:
        for line in f:
            if max_rows is not None and n_rows >= max_rows:
                break
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except Exception:
                continue
            if not isinstance(row, dict):
                # A bare JSON list/number/string on its own line.
                continue
            try:
                conv = Conversation.from_dict(row)
            except Exception as e:
                logger.debug("skipping malformed oai row: %s", e)
                continue
            if not conv.messages:
                continue
            n_rows += 1

            # Row-level target_tool: the cognos-toolbox amplified/replay
            # rows are "predict the next tool given user turn + tool
            # catalog" — no assistant message exists yet. Emit one
            # sample with the row-level target and the full message
            # list as context.
            row_target = row.get("target_tool")
            if row_target is not None:
                row_meta = row.get("meta") or {}
                row_source = (
                    str(row_meta.get("teacher_model"))
                    if isinstance(row_meta, dict) and row_meta.get("teacher_model")
                    else model_source
                )
                out.append(ConversationSample(
                    conversation=list(conv.messages),
                    tools=list(conv.tools),
                    target_tool=str(row_target) or "<no_tool>",
                    model_source=row_source,
                ))
                n_calls += 1
                continue

            # Standard OAI shape: emit one sample per tool_call within
            # each assistant message. The prefix is everything *before*
            # that assistant message — that's the state the model sees
            # when deciding which tool to call.
            for msg_idx, msg in enumerate(conv.messages):
                if msg.role != "assistant" or not msg.tool_calls:
                    continue
                for call_idx, tc in enumerate(msg.tool_calls):
                    if not tc.name:
                        continue
                    out.append(ConversationSample.from_conversation(
                        conv=conv,
                        prefix_len=msg_idx,
                        tool_call=tc,
                        tool_call_index=call_idx,
                        model_source=model_source,
                    ))
                    n_calls += 1

    logger.info(
        "loaded %d rows / %d tool-call samples from %s", n_rows, n_calls, path,
    )
    return out
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
# ---------------------------------------------------------------------
|
|
500
|
+
# HuggingFace dataset loader — direct path from HF Hub to ConversationSample
|
|
501
|
+
# ---------------------------------------------------------------------
|
|
502
|
+
#
|
|
503
|
+
# HF function-calling datasets don't share one schema. Common shapes:
|
|
504
|
+
#
|
|
505
|
+
# OpenAI chat: row = {"messages": [...], "tools": [...]}
|
|
506
|
+
# xLAM-style: row = {"query": "...", "tools": "[...]",
|
|
507
|
+
# "answers": "[{...tool_call...}]"} (JSON strings)
|
|
508
|
+
# Conversations: row = {"conversations": [...same as messages...],
|
|
509
|
+
# "tools": [...]}
|
|
510
|
+
#
|
|
511
|
+
# `load_corpus_from_hf` auto-detects which one applies per dataset. For
|
|
512
|
+
# anything off the beaten path, pass a `row_to_conversation` callable
|
|
513
|
+
# that returns a Conversation given a single row — bypasses detection
|
|
514
|
+
# entirely.
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def _detect_hf_shape(row: dict[str, Any]) -> str:
|
|
518
|
+
"""Return 'openai' | 'block_content' | 'conversations' | 'sharegpt'
|
|
519
|
+
| 'xlam' | 'unknown'.
|
|
520
|
+
|
|
521
|
+
The ``block_content`` shape is used by llamafactory's
|
|
522
|
+
reason-tool-use-demo-style corpora: ``messages[].content`` is a
|
|
523
|
+
list of ``{type, value}`` blocks (``text``/``reasoning``/
|
|
524
|
+
``tool_call``) and ``tools`` is a JSON-encoded string rather than
|
|
525
|
+
a list. Detected ahead of plain ``openai`` because the row's
|
|
526
|
+
``messages[0]`` still has a ``role`` key — without the early
|
|
527
|
+
check we'd silently drop every tool_call.
|
|
528
|
+
"""
|
|
529
|
+
if isinstance(row.get("messages"), list) and row["messages"] \
|
|
530
|
+
and isinstance(row["messages"][0], dict) \
|
|
531
|
+
and "role" in row["messages"][0]:
|
|
532
|
+
# Distinguish OAI-shape (content: str|None) from block-typed
|
|
533
|
+
# content (content: list of {type, value} dicts).
|
|
534
|
+
first_content = row["messages"][0].get("content")
|
|
535
|
+
if (isinstance(first_content, list) and first_content
|
|
536
|
+
and isinstance(first_content[0], dict)
|
|
537
|
+
and "value" in first_content[0]):
|
|
538
|
+
return "block_content"
|
|
539
|
+
return "openai"
|
|
540
|
+
if isinstance(row.get("conversations"), list) and row["conversations"] \
|
|
541
|
+
and isinstance(row["conversations"][0], dict):
|
|
542
|
+
first = row["conversations"][0]
|
|
543
|
+
if "role" in first:
|
|
544
|
+
return "conversations"
|
|
545
|
+
if "from" in first and "value" in first:
|
|
546
|
+
return "sharegpt"
|
|
547
|
+
if "query" in row and "answers" in row:
|
|
548
|
+
return "xlam"
|
|
549
|
+
if isinstance(row.get("system"), str) and isinstance(row.get("chat"), str):
|
|
550
|
+
return "glaive"
|
|
551
|
+
return "unknown"
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def _parse_block_content_row(row: dict[str, Any]) -> Conversation:
    """Parse a llamafactory-style row whose messages carry typed content blocks.

    Each message's ``content`` is a list of ``{type, value}`` blocks; a
    ``text`` key is accepted when ``value`` is missing. Block-type mapping:

    - ``text`` / ``reasoning`` → appended to the message content (the
      trunk sees reasoning as additional context in the same text channel).
    - ``tool_call`` → decoded from JSON (or used directly if already
      decoded) and attached as OAI-shape tool calls on this message. Both
      a single ``{"name", "arguments"}`` object and a list of such objects
      are accepted. Entries without a ``name`` are dropped.
    - any other type → its string value is kept as text rather than lost.

    A plain-string ``content`` is used verbatim. Non-dict messages and
    blocks are skipped, and a missing role defaults to ``"user"``.

    Call ``arguments`` are always emitted as a string:
    - strings pass through unchanged;
    - ``None`` becomes ``""``;
    - any other value is JSON-encoded. Lists, numbers and booleans
      therefore yield valid JSON rather than a Python repr.

    The row's ``tools`` field may be a JSON-encoded string or a list.
    Non-dict entries are dropped, and an undecodable or empty field
    leaves the tools list empty.
    """

    def _oai_tool_call(spec: Any) -> dict[str, Any] | None:
        """Map one decoded ``{"name", "arguments"}`` object to OAI shape."""
        if not isinstance(spec, dict) or not spec.get("name"):
            return None
        args = spec.get("arguments")
        if args is None:
            args_str = ""
        elif isinstance(args, str):
            args_str = args
        else:
            try:
                args_str = json.dumps(args)
            except (TypeError, ValueError):
                # Only reachable when the block value arrived pre-decoded
                # with non-JSON objects; keep the previous str() form.
                args_str = str(args)
        return {
            "type": "function",
            "function": {"name": str(spec["name"]), "arguments": args_str},
        }

    raw_msgs: list[dict[str, Any]] = []
    for m in row.get("messages") or []:
        if not isinstance(m, dict):
            continue
        role = str(m.get("role") or "user")
        content_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        blocks = m.get("content")
        if isinstance(blocks, list):
            for blk in blocks:
                if not isinstance(blk, dict):
                    continue
                btype = str(blk.get("type") or "").lower()
                value = blk.get("value") or blk.get("text") or ""
                if btype == "tool_call":
                    # value is usually a JSON string like
                    # '{"name": "...", "arguments": {...}}'
                    try:
                        parsed = json.loads(value) if isinstance(value, str) else value
                    except (ValueError, RecursionError):
                        parsed = None  # best-effort: malformed call is dropped
                    specs = parsed if isinstance(parsed, list) else [parsed]
                    for spec in specs:
                        call = _oai_tool_call(spec)
                        if call is not None:
                            tool_calls.append(call)
                elif isinstance(value, str) and value:
                    # text / reasoning, and unknown block types too: better
                    # to feed noisy context than lose a labelled message.
                    content_parts.append(value)
        elif isinstance(blocks, str):
            content_parts.append(blocks)

        msg_dict: dict[str, Any] = {
            "role": role,
            "content": " ".join(content_parts).strip(),
        }
        if tool_calls:
            msg_dict["tool_calls"] = tool_calls
        raw_msgs.append(msg_dict)

    # tools is usually a JSON-encoded string in this shape; decode.
    raw_tools_field = row.get("tools")
    raw_tools: list[dict[str, Any]] = []
    if isinstance(raw_tools_field, str) and raw_tools_field.strip():
        try:
            decoded = json.loads(raw_tools_field)
        except (ValueError, RecursionError):
            decoded = None
        if isinstance(decoded, list):
            raw_tools = [t for t in decoded if isinstance(t, dict)]
    elif isinstance(raw_tools_field, list):
        raw_tools = [t for t in raw_tools_field if isinstance(t, dict)]

    return Conversation.from_dict({"messages": raw_msgs, "tools": raw_tools})
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
# Map ShareGPT speaker tags to standard role names. ShareGPT uses
|
|
641
|
+
# `human`/`gpt` historically; modern variants also include `system`,
|
|
642
|
+
# `tool`, and `function`. Unknown tags fall back to "user" rather
|
|
643
|
+
# than silently dropping the turn — better to keep noisy context than
|
|
644
|
+
# lose a labelled message.
|
|
645
|
+
_SHAREGPT_ROLE_MAP: dict[str, str] = {
|
|
646
|
+
"human": "user",
|
|
647
|
+
"user": "user",
|
|
648
|
+
"gpt": "assistant",
|
|
649
|
+
"chatgpt": "assistant",
|
|
650
|
+
"bing": "assistant",
|
|
651
|
+
"assistant": "assistant",
|
|
652
|
+
"system": "system",
|
|
653
|
+
"tool": "tool",
|
|
654
|
+
"function": "tool",
|
|
655
|
+
"observation": "tool",
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def _parse_sharegpt_row(row: dict[str, Any]) -> Conversation:
    """Convert a ShareGPT-format row into a standard ``Conversation``.

    The row's ``conversations`` column is a list of
    ``{"from": <speaker>, "value": <text>}`` dicts. Each speaker tag is
    lower-cased and mapped through ``_SHAREGPT_ROLE_MAP``; unknown tags
    become ``"user"``. Non-dict entries are skipped, and a missing or
    falsy ``value`` becomes ``""``. The resulting conversation has no
    tools.

    Used by datasets like Kimi-K2.5-Reasoning, Vicuna, WizardLM and
    Alpaca-fc-*. These are typically reasoning corpora without tool
    calls, so the loaders turn their assistant turns into ``<no_tool>``
    samples when ``emit_no_tool_turns=True``.
    """
    turns = row.get("conversations") or []
    messages: list[dict[str, Any]] = [
        {
            "role": _SHAREGPT_ROLE_MAP.get(
                str(turn.get("from") or "").lower(), "user",
            ),
            "content": turn.get("value") or "",
        }
        for turn in turns
        if isinstance(turn, dict)
    ]
    return Conversation.from_dict({"messages": messages, "tools": []})
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
def _parse_xlam_row(row: dict[str, Any]) -> Conversation:
    """Rebuild an xLAM row as a two-turn ``Conversation``.

    xLAM packs the tool catalogue (``tools``) and the gold calls
    (``answers``) into separate columns, usually as JSON-encoded
    strings. The row becomes::

        user: <query> ─► assistant: <tool_calls=...>

    so the standard pipeline sees one assistant tool-call turn. We
    deliberately drop tool-result content (xLAM doesn't simulate it):
    the model only learns to PREDICT the call, which is what Caudate's
    advisor head does anyway.

    Tolerated input variations:

    - ``answers``/``tools`` given as a single bare object instead of a
      list are treated as a one-element list.
    - Tool entries that are already OpenAI-shaped
      (``{"type": "function", "function": {...}}``) are kept as-is
      instead of being wrapped a second time, which would hide the name.
    - Non-dict tool entries are skipped rather than turned into
      nameless placeholders.

    Args:
        row: one xLAM dataset row (``query``, ``tools``, ``answers``).

    Returns:
        A ``Conversation`` with one user message and one assistant
        message whose ``tool_calls`` hold every dict in ``answers``.
    """
    def _maybe_json(v: Any) -> Any:
        """Decode JSON strings (undecodable -> []); pass others through, falsy -> []."""
        if isinstance(v, str):
            try:
                return json.loads(v)
            except Exception:
                return []
        return v or []

    tools_raw = _maybe_json(row.get("tools"))
    answers = _maybe_json(row.get("answers"))
    # A single call/tool is sometimes stored as a bare object, not a list.
    if isinstance(answers, dict):
        answers = [answers]
    if isinstance(tools_raw, dict):
        tools_raw = [tools_raw]
    if not isinstance(answers, list):
        answers = []
    if not isinstance(tools_raw, list):
        tools_raw = []

    # xLAM uses {"name": ..., "arguments": {...}}; "tool"/"args" are
    # accepted as aliases.
    tool_calls_raw: list[dict[str, Any]] = [
        {
            "type": "function",
            "function": {
                "name": a.get("name") or a.get("tool") or "",
                "arguments": a.get("arguments") or a.get("args") or {},
            },
        }
        for a in answers
        if isinstance(a, dict)
    ]

    msgs: list[dict[str, Any]] = [
        {"role": "user", "content": str(row.get("query") or "")},
        {"role": "assistant", "content": "", "tool_calls": tool_calls_raw},
    ]

    # xLAM's tool defs normally look like {"name", "description",
    # "parameters"}; wrap them in the OpenAI envelope and let
    # Conversation.from_dict normalise.
    tools_norm: list[dict[str, Any]] = []
    for t in tools_raw:
        if not isinstance(t, dict):
            continue
        if t.get("type") == "function" and isinstance(t.get("function"), dict):
            tools_norm.append(t)  # already OpenAI-shaped
        else:
            tools_norm.append({"type": "function", "function": t})
    return Conversation.from_dict({"messages": msgs, "tools": tools_norm})
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def load_corpus_from_hf(
    dataset_name: str,
    split: str = "train",
    model_source: str | None = None,
    max_rows: int | None = None,
    row_to_conversation: "Any | None" = None,
    streaming: bool = True,
    emit_no_tool_turns: bool = True,
    **load_dataset_kwargs: Any,
) -> list[ConversationSample]:
    """Stream a HuggingFace dataset into ConversationSamples.

    The row layout is detected once, from the first dict row, via
    ``_detect_hf_shape``, and then reused for every later row. Layouts
    dispatched here (the same set as ``load_corpus_from_local_arrow``):

    - **OpenAI chat** (``"openai"``): ``messages``/``tools`` columns of
      dicts with ``role``. Examples: anything saved via the OpenAI
      fine-tuning format, or Hermes-Function-Calling.
    - **Block content** (``"block_content"``): parsed by
      ``_parse_block_content_row``.
    - **Conversations alias** (``"conversations"``): same as OpenAI, but
      the column is called ``conversations`` (ToolBench-style).
    - **ShareGPT** (``"sharegpt"``): ``conversations`` of
      ``{from, value}`` dicts, parsed by ``_parse_sharegpt_row``.
    - **xLAM** (``"xlam"``): ``query``/``tools``/``answers`` columns where
      ``tools`` and ``answers`` are JSON-encoded strings. Example:
      Salesforce/xlam-function-calling-60k.
    - **Glaive** (``"glaive"``): text-embedded ``system``/``chat``
      columns, parsed by ``_parse_glaive_row``.

    Any other detected shape logs a warning and stops loading. For such
    datasets, pass ``row_to_conversation=fn``, where ``fn`` takes a row
    dict and returns a ``Conversation``; this bypasses the auto-detector.
    Rows that raise while being parsed are skipped and logged at DEBUG.

    Args:
        dataset_name: HF Hub identifier, e.g.
            ``"Salesforce/xlam-function-calling-60k"``.
        split: split name (``"train"``, ``"validation"``, etc.).
        model_source: tag set on every sample's ``model_source``.
            Defaults to ``"hf:<dataset_name>"``.
        max_rows: cap for smoke runs. None reads everything.
        row_to_conversation: optional custom parser; skips auto-detect.
        streaming: if True (default), uses ``datasets.load_dataset(...,
            streaming=True)`` to avoid materialising the whole dataset.
        emit_no_tool_turns: if True, every assistant turn that has content
            but no tool calls also yields a sample whose ``target_tool``
            is ``"<no_tool>"``.
        **load_dataset_kwargs: forwarded to ``load_dataset``.

    Returns:
        list[ConversationSample]: one per named assistant tool call, plus
        one ``"<no_tool>"`` sample per plain assistant reply when
        ``emit_no_tool_turns`` is True.

    Raises:
        RuntimeError: if the ``datasets`` library is not installed.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except ImportError as err:
        raise RuntimeError(
            "datasets library not installed. pip install datasets, or "
            "export the corpus to JSONL and use load_corpus_from_oai_jsonl."
        ) from err

    model_source = model_source or f"hf:{dataset_name}"
    ds = load_dataset(
        dataset_name, split=split, streaming=streaming, **load_dataset_kwargs,
    )

    out: list[ConversationSample] = []
    detected_shape: str | None = None
    n_rows = 0
    n_calls = 0
    for row in ds:
        if max_rows is not None and n_rows >= max_rows:
            break
        if not isinstance(row, dict):
            continue
        n_rows += 1

        # Build the Conversation. If the caller supplied a custom
        # parser, use it. Otherwise detect once and reuse.
        try:
            if row_to_conversation is not None:
                conv = row_to_conversation(row)
            else:
                if detected_shape is None:
                    detected_shape = _detect_hf_shape(row)
                    logger.info(
                        f"hf loader: detected shape '{detected_shape}' "
                        f"for {dataset_name}"
                    )
                if detected_shape == "openai":
                    conv = Conversation.from_dict(row)
                elif detected_shape == "block_content":
                    conv = _parse_block_content_row(row)
                elif detected_shape == "conversations":
                    conv = Conversation.from_dict({
                        "messages": row["conversations"],
                        "tools": row.get("tools") or [],
                    })
                elif detected_shape == "sharegpt":
                    conv = _parse_sharegpt_row(row)
                elif detected_shape == "xlam":
                    conv = _parse_xlam_row(row)
                elif detected_shape == "glaive":
                    # Same dispatch as load_corpus_from_local_arrow; without
                    # it a glaive-shaped hub dataset aborts as "unknown".
                    conv = _parse_glaive_row(row)
                else:
                    logger.warning(
                        f"hf loader: unknown shape for {dataset_name}; "
                        f"pass row_to_conversation= to handle it. "
                        f"keys={list(row.keys())[:8]}"
                    )
                    break
        except Exception as e:
            logger.debug(f"hf loader: skipping row {n_rows}: {e}")
            continue
        if not isinstance(conv, Conversation) or not conv.messages:
            continue

        for msg_idx, msg in enumerate(conv.messages):
            if msg.role != "assistant":
                continue
            if msg.tool_calls:
                for call_idx, tc in enumerate(msg.tool_calls):
                    if not tc.name:
                        continue
                    out.append(ConversationSample.from_conversation(
                        conv=conv,
                        prefix_len=msg_idx,
                        tool_call=tc,
                        tool_call_index=call_idx,
                        model_source=model_source,
                    ))
                    n_calls += 1
            elif emit_no_tool_turns and msg.content:
                # Plain assistant response → "<no_tool>" target. The
                # trunk still learns from the conversation prefix; the
                # contrastive head learns when NOT to call anything.
                out.append(ConversationSample(
                    conversation=list(conv.messages[:msg_idx]),
                    tools=list(conv.tools),
                    target_tool="<no_tool>",
                    target_arguments="",
                    target_tool_call_index=0,
                    model_source=model_source,
                ))
                n_calls += 1

    logger.info(
        f"hf loader: {dataset_name}#{split} → {n_rows} rows / "
        f"{n_calls} tool-call samples"
    )
    return out
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
# ---------------------------------------------------------------------
|
|
876
|
+
# Local HuggingFace cache loader — read .arrow shards directly
|
|
877
|
+
# ---------------------------------------------------------------------
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
def load_corpus_from_local_arrow(
    cache_dir: Path | str,
    model_source: str | None = None,
    max_rows: int | None = None,
    row_to_conversation: "Any | None" = None,
    emit_no_tool_turns: bool = True,
) -> list[ConversationSample]:
    """Load ConversationSamples from a local HuggingFace cache directory.

    Unlike ``load_corpus_from_hf``, this skips the ``datasets`` library
    entirely and reads the on-disk Arrow IPC shards directly. Useful
    when the dataset is already cached and you don't need the library's
    streaming/version-resolution layer.

    Expected directory layout (the layout HF datasets writes by
    default): one or more ``*.arrow`` files in ``cache_dir``, each an
    Arrow IPC stream produced by ``Dataset.save_to_disk`` or the
    ``json`` builder. Shards are read in sorted filename order. The
    schema is auto-detected once, from the first row overall, using the
    same logic as ``load_corpus_from_hf``, and reused for every shard.

    Args:
        cache_dir: directory containing the ``*.arrow`` shards.
        model_source: fallback sample tag. Defaults to
            ``"hf-local:<dir name>"``. A row's ``meta.teacher_model``,
            when present, overrides it for that row.
        max_rows: cap on rows read across all shards. None reads
            everything.
        row_to_conversation: optional custom parser; skips auto-detect.
        emit_no_tool_turns: if True, assistant turns that have content
            but no tool calls yield ``"<no_tool>"`` samples.

    Returns:
        list[ConversationSample]. An empty list when the directory is
        missing or holds no shards. Unreadable shards are skipped with a
        warning; an unknown shape stops loading.
    """
    import pyarrow as pa
    import pyarrow.ipc as ipc

    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        logger.warning(f"local arrow cache not found: {cache_dir}")
        return []

    shards = sorted(cache_dir.glob("*.arrow"))
    if not shards:
        logger.warning(f"no .arrow shards in {cache_dir}")
        return []

    model_source = model_source or f"hf-local:{cache_dir.name}"
    out: list[ConversationSample] = []
    detected_shape: str | None = None
    n_rows = 0
    n_calls = 0
    stop = False

    for shard in shards:
        # Check the row cap BEFORE opening the next shard, so hitting
        # max_rows exactly at a shard boundary doesn't read another one.
        if stop or (max_rows is not None and n_rows >= max_rows):
            break
        try:
            with pa.memory_map(str(shard), "r") as src:
                table = ipc.open_stream(src).read_all()
        except Exception as e:
            logger.warning(f"failed to read {shard}: {e}")
            continue

        # Convert to row dicts one record batch at a time, so only one
        # batch's worth of Python objects is alive at once, rather than
        # the whole shard.
        rows = (r for batch in table.to_batches() for r in batch.to_pylist())
        for row in rows:
            if max_rows is not None and n_rows >= max_rows:
                stop = True
                break
            if not isinstance(row, dict):
                continue
            n_rows += 1

            try:
                if row_to_conversation is not None:
                    conv = row_to_conversation(row)
                else:
                    if detected_shape is None:
                        detected_shape = _detect_hf_shape(row)
                        logger.info(
                            f"local arrow loader: detected shape "
                            f"'{detected_shape}' for {cache_dir.name}"
                        )
                    if detected_shape == "openai":
                        conv = Conversation.from_dict(row)
                    elif detected_shape == "block_content":
                        conv = _parse_block_content_row(row)
                    elif detected_shape == "conversations":
                        conv = Conversation.from_dict({
                            "messages": row["conversations"],
                            "tools": row.get("tools") or [],
                        })
                    elif detected_shape == "sharegpt":
                        conv = _parse_sharegpt_row(row)
                    elif detected_shape == "xlam":
                        conv = _parse_xlam_row(row)
                    elif detected_shape == "glaive":
                        conv = _parse_glaive_row(row)
                    else:
                        logger.warning(
                            f"local arrow loader: unknown shape; "
                            f"keys={list(row.keys())[:8]}"
                        )
                        stop = True
                        break
            except Exception as e:
                logger.debug(f"local arrow: skipping row {n_rows}: {e}")
                continue

            if not isinstance(conv, Conversation) or not conv.messages:
                continue

            # Prefer the per-row teacher model id when the dataset
            # exposes one (Kimi-K2.5-Reasoning has meta.teacher_model;
            # other corpora may not). Falls back to the loader-level
            # tag. Source-conditioning works either way.
            row_source = model_source
            meta = row.get("meta")
            if isinstance(meta, dict) and meta.get("teacher_model"):
                row_source = str(meta["teacher_model"])

            for msg_idx, msg in enumerate(conv.messages):
                if msg.role != "assistant":
                    continue
                if msg.tool_calls:
                    for call_idx, tc in enumerate(msg.tool_calls):
                        if not tc.name:
                            continue
                        out.append(ConversationSample.from_conversation(
                            conv=conv,
                            prefix_len=msg_idx,
                            tool_call=tc,
                            tool_call_index=call_idx,
                            model_source=row_source,
                        ))
                        n_calls += 1
                elif emit_no_tool_turns and msg.content:
                    out.append(ConversationSample(
                        conversation=list(conv.messages[:msg_idx]),
                        tools=list(conv.tools),
                        target_tool="<no_tool>",
                        target_arguments="",
                        target_tool_call_index=0,
                        model_source=row_source,
                    ))
                    n_calls += 1

    logger.info(
        f"local arrow loader: {cache_dir.name} → {n_rows} rows / "
        f"{n_calls} samples"
    )
    return out
|
|
1019
|
+
|
|
1020
|
+
|
|
1021
|
+
# ---------------------------------------------------------------------
|
|
1022
|
+
# Glaive function-calling-v2 — custom text-embedded JSON shape
|
|
1023
|
+
# ---------------------------------------------------------------------
|
|
1024
|
+
|
|
1025
|
+
_GLAIVE_ROLE_MARKER_RE = re.compile(
|
|
1026
|
+
r"^(USER|ASSISTANT|FUNCTION RESPONSE):", re.MULTILINE
|
|
1027
|
+
)
|
|
1028
|
+
_GLAIVE_ROLE_NORM = {
|
|
1029
|
+
"USER": "user", "ASSISTANT": "assistant", "FUNCTION RESPONSE": "tool",
|
|
1030
|
+
}
|
|
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
def _glaive_extract_json_objects(text: str) -> list[dict[str, Any]]:
|
|
1034
|
+
"""Walk balanced-brace JSON objects out of free text.
|
|
1035
|
+
|
|
1036
|
+
Glaive's `system` field packs zero or more tool defs as separate
|
|
1037
|
+
JSON objects glued together with newlines and an English preamble.
|
|
1038
|
+
Standard json.loads can't parse that; we scan for top-level `{...}`
|
|
1039
|
+
spans and decode each independently. Strings (including those with
|
|
1040
|
+
backslash escapes) are honoured so braces inside string values
|
|
1041
|
+
don't break the depth counter.
|
|
1042
|
+
"""
|
|
1043
|
+
out: list[dict[str, Any]] = []
|
|
1044
|
+
depth = 0
|
|
1045
|
+
start = -1
|
|
1046
|
+
in_str = False
|
|
1047
|
+
i = 0
|
|
1048
|
+
n = len(text)
|
|
1049
|
+
while i < n:
|
|
1050
|
+
c = text[i]
|
|
1051
|
+
if in_str:
|
|
1052
|
+
if c == "\\" and i + 1 < n:
|
|
1053
|
+
i += 2
|
|
1054
|
+
continue
|
|
1055
|
+
if c == '"':
|
|
1056
|
+
in_str = False
|
|
1057
|
+
else:
|
|
1058
|
+
if c == '"':
|
|
1059
|
+
in_str = True
|
|
1060
|
+
elif c == "{":
|
|
1061
|
+
if depth == 0:
|
|
1062
|
+
start = i
|
|
1063
|
+
depth += 1
|
|
1064
|
+
elif c == "}":
|
|
1065
|
+
depth -= 1
|
|
1066
|
+
if depth == 0 and start >= 0:
|
|
1067
|
+
blob = text[start:i + 1]
|
|
1068
|
+
try:
|
|
1069
|
+
parsed = json.loads(blob)
|
|
1070
|
+
if isinstance(parsed, dict):
|
|
1071
|
+
out.append(parsed)
|
|
1072
|
+
except Exception:
|
|
1073
|
+
pass
|
|
1074
|
+
start = -1
|
|
1075
|
+
i += 1
|
|
1076
|
+
return out
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
def _glaive_parse_functioncall(blob: str) -> dict[str, Any] | None:
|
|
1080
|
+
"""Extract {name, arguments} from a `<functioncall>` payload.
|
|
1081
|
+
|
|
1082
|
+
The payload looks like `{"name": "x", "arguments": '{...}'}` where
|
|
1083
|
+
`arguments` is a single-quoted JSON string — not valid JSON. Hand-roll
|
|
1084
|
+
the extraction so we don't have to rewrite the blob.
|
|
1085
|
+
"""
|
|
1086
|
+
name_m = re.search(r'"name"\s*:\s*"([^"]+)"', blob)
|
|
1087
|
+
if not name_m:
|
|
1088
|
+
return None
|
|
1089
|
+
name = name_m.group(1)
|
|
1090
|
+
args = ""
|
|
1091
|
+
# Try single-quoted arguments first (the common case)
|
|
1092
|
+
args_m = re.search(
|
|
1093
|
+
r'"arguments"\s*:\s*\'(.*?)\'(?=\s*[,}])', blob, re.DOTALL,
|
|
1094
|
+
)
|
|
1095
|
+
if args_m:
|
|
1096
|
+
args = args_m.group(1)
|
|
1097
|
+
else:
|
|
1098
|
+
# Fallback: double-quoted string args
|
|
1099
|
+
args_m2 = re.search(r'"arguments"\s*:\s*"((?:[^"\\]|\\.)*)"', blob)
|
|
1100
|
+
if args_m2:
|
|
1101
|
+
args = args_m2.group(1).encode().decode("unicode_escape")
|
|
1102
|
+
else:
|
|
1103
|
+
# Or an inline dict
|
|
1104
|
+
args_m3 = re.search(
|
|
1105
|
+
r'"arguments"\s*:\s*(\{.*?\})\s*[,}]', blob, re.DOTALL,
|
|
1106
|
+
)
|
|
1107
|
+
if args_m3:
|
|
1108
|
+
args = args_m3.group(1)
|
|
1109
|
+
return {"name": name, "arguments": args}
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
def _parse_glaive_row(row: dict[str, Any]) -> Conversation:
    """Parse a Glaive function-calling-v2 row into a standard Conversation.

    Row shape:
        ``{"system": "<preamble + JSON tool defs>", "chat": "<USER:...
        ASSISTANT:...<functioncall>{...}... FUNCTION RESPONSE:...>"}``

    Tools come from the JSON objects embedded in the system field; only
    objects with a truthy ``name`` are kept. Turns are split on the role
    markers (``_GLAIVE_ROLE_MARKER_RE``), and any text before the first
    marker is discarded. ``<functioncall>`` blocks inside assistant turns
    become tool calls and are removed from the message content.
    ``<|endoftext|>`` is stripped from all content.

    Args:
        row: one Glaive row; only ``system`` and ``chat`` are read.

    Returns:
        A ``Conversation`` built via ``Conversation.from_dict``.
    """
    sys_text = str(row.get("system") or "")
    tool_defs = _glaive_extract_json_objects(sys_text)
    tools_norm: list[dict[str, Any]] = []
    for t in tool_defs:
        # The preamble may contain non-tool JSON; require a name.
        if not t.get("name"):
            continue
        tools_norm.append({"type": "function", "function": t})

    chat = str(row.get("chat") or "")
    msgs: list[dict[str, Any]] = []
    # Each marker opens a turn that runs until the next marker (or the
    # end of the chat text).
    matches = list(_GLAIVE_ROLE_MARKER_RE.finditer(chat))
    for i, m in enumerate(matches):
        role = _GLAIVE_ROLE_NORM.get(m.group(1))
        if not role:
            continue
        start = m.end()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(chat)
        chunk = chat[start:end].replace("<|endoftext|>", "").strip()

        tool_calls: list[dict[str, Any]] = []
        if role == "assistant":
            # Pull each <functioncall> payload (balanced-brace) out.
            # `j` is where the next marker search starts; `text_start` is
            # where the next run of plain (non-call) text begins.
            j = 0
            cleaned: list[str] = []
            text_start = 0
            while True:
                fc = chunk.find("<functioncall>", j)
                if fc < 0:
                    cleaned.append(chunk[text_start:])
                    break
                cleaned.append(chunk[text_start:fc])
                # Find balanced {...} after the marker.
                # NOTE(review): this is the first "{" anywhere after the
                # marker, not necessarily adjacent to it.
                k = chunk.find("{", fc)
                if k < 0:
                    # Marker without a payload: drop the marker text and
                    # keep scanning.
                    text_start = fc + len("<functioncall>")
                    j = text_start
                    continue
                depth = 0
                end_k = k
                in_str = False
                # Advance end_k to the "}" that closes the "{" at k,
                # ignoring braces inside double-quoted strings. If it
                # never closes, end_k stops at len(chunk) and the slice
                # below takes the rest of the chunk.
                while end_k < len(chunk):
                    cc = chunk[end_k]
                    if in_str:
                        if cc == "\\" and end_k + 1 < len(chunk):
                            end_k += 2
                            continue
                        if cc == '"':
                            in_str = False
                    else:
                        if cc == '"':
                            in_str = True
                        elif cc == "{":
                            depth += 1
                        elif cc == "}":
                            depth -= 1
                            if depth == 0:
                                break
                    end_k += 1
                blob = chunk[k:end_k + 1]
                parsed = _glaive_parse_functioncall(blob)
                if parsed and parsed.get("name"):
                    tool_calls.append({
                        "type": "function",
                        "function": {
                            "name": str(parsed["name"]),
                            "arguments": str(parsed.get("arguments") or ""),
                        },
                    })
                # Resume after the payload; its text is excluded from
                # the message content.
                text_start = end_k + 1
                j = text_start

            # Rejoin the non-empty plain-text fragments with single spaces.
            chunk = " ".join(p for p in (s.strip() for s in cleaned) if p)

        msg_dict: dict[str, Any] = {"role": role, "content": chunk}
        if tool_calls:
            msg_dict["tool_calls"] = tool_calls
        msgs.append(msg_dict)

    return Conversation.from_dict({"messages": msgs, "tools": tools_norm})
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
def load_corpus_from_glaive_json(
    json_path: Path | str,
    model_source: str | None = None,
    max_rows: int | None = None,
    emit_no_tool_turns: bool = True,
) -> list[ConversationSample]:
    """Load ConversationSamples from a Glaive function-calling JSON file.

    The Glaive corpus is a single ~259 MB JSON array, small enough to
    load whole into memory. Each row is parsed via ``_parse_glaive_row``
    and then emitted as one sample per named assistant tool_call (trains
    the open-vocab head), or as one sample with
    ``target_tool='<no_tool>'`` for each assistant chat turn (so the
    trunk still gets signal).

    Args:
        json_path: path to the Glaive JSON array file (read as UTF-8).
        model_source: tag for every sample. Defaults to
            ``"glaive-function-calling-v2"``.
        max_rows: cap on dict rows processed. None reads everything.
        emit_no_tool_turns: if True, emit ``<no_tool>`` samples for
            assistant turns that have content but no tool calls.

    Returns:
        list[ConversationSample]. An empty list if the file is missing,
        unreadable, or not a JSON list. Rows that fail to parse are
        skipped and logged at DEBUG.
    """
    json_path = Path(json_path)
    if not json_path.exists():
        logger.warning(f"glaive json not found: {json_path}")
        return []

    model_source = model_source or "glaive-function-calling-v2"
    out: list[ConversationSample] = []

    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) would reject or mis-decode the corpus's non-ASCII rows.
        with json_path.open(encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logger.warning(f"glaive json read failed: {e}")
        return []
    if not isinstance(data, list):
        logger.warning(f"glaive json not a list: {json_path}")
        return []

    n_rows = 0
    n_calls = 0
    for row in data:
        if max_rows is not None and n_rows >= max_rows:
            break
        if not isinstance(row, dict):
            continue
        n_rows += 1
        try:
            conv = _parse_glaive_row(row)
        except Exception as e:
            logger.debug(f"glaive: skipping row {n_rows}: {e}")
            continue
        if not isinstance(conv, Conversation) or not conv.messages:
            continue

        for msg_idx, msg in enumerate(conv.messages):
            if msg.role != "assistant":
                continue
            if msg.tool_calls:
                for call_idx, tc in enumerate(msg.tool_calls):
                    if not tc.name:
                        continue
                    out.append(ConversationSample.from_conversation(
                        conv=conv,
                        prefix_len=msg_idx,
                        tool_call=tc,
                        tool_call_index=call_idx,
                        model_source=model_source,
                    ))
                    n_calls += 1
            elif emit_no_tool_turns and msg.content:
                out.append(ConversationSample(
                    conversation=list(conv.messages[:msg_idx]),
                    tools=list(conv.tools),
                    target_tool="<no_tool>",
                    target_arguments="",
                    target_tool_call_index=0,
                    model_source=model_source,
                ))
                n_calls += 1

    logger.info(
        f"glaive json loader: {json_path.name} → {n_rows} rows / "
        f"{n_calls} samples"
    )
    return out
|
|
1283
|
+
|
|
1284
|
+
|
|
1285
|
+
# ---------------------------------------------------------------------
|
|
1286
|
+
# Forge feature-outcome corpus (ADR 0006 Phase A)
|
|
1287
|
+
# ---------------------------------------------------------------------
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
def load_corpus_from_feature_outcomes(
    path: Path = Path("data/nn/feature_outcomes.jsonl"),
    dedupe_by_feature: bool = True,
) -> list[ConversationSample]:
    """Convert each Forge feature-outcome row into one ConversationSample.

    Schema (written by the orchestrator):
        feature_text, model_used, success (bool), n_turns, n_tool_calls,
        duration_s, project_id, feature_id, session_id, extras{}.

    The feature text becomes the only "message" (one-turn batch). The
    model id flows through ``model_source``, so the trunk's source
    embedding can bias predictions per teacher model. All other state
    (mood, tool history, images) is zero: Forge predictions are made
    *before* a session runs, so we deliberately do not condition on
    in-turn signals that don't exist yet at prediction time.

    Other heads' targets stay None, and those heads are skipped for
    these samples (HeadSpec.optional_target). The only label this
    corpus carries is ``target_feature_success`` (mirrored in
    ``target_value``).

    ``dedupe_by_feature`` (default True) collapses retries: every
    (project_id, feature_id) pair contributes ONE sample with the final
    outcome, i.e. the row with the greatest ``ts``. Ties, including rows
    that have no ``ts``, go to the later line in the append-only file.
    Without this, a stuck feature that retried 200 times before being
    parked dominates the corpus and pushes the head to "always predict
    fail". The orchestrator's decision question is "is this feature
    solvable?", not "will this particular retry succeed?", so the final
    outcome is the right label. Set False if you actually want
    per-session granularity (e.g. for predict_failure_repeats analysis).

    Rows that are not JSON objects, lack a non-empty string
    ``feature_text``, or lack ``success`` are dropped silently. Skipping
    bad rows is the right behaviour because the JSONL is append-only and
    a single malformed write must not bork the next training run.

    Args:
        path: JSONL file to read (``str`` paths are accepted too).
        dedupe_by_feature: keep only the latest row per feature.

    Returns:
        One ConversationSample per kept row; an empty list if the file is
        missing or unreadable.
    """
    path = Path(path)
    if not path.exists():
        return []
    try:
        # errors="replace": a torn write containing invalid UTF-8 must not
        # make the whole file unreadable; the damaged line then goes
        # through the same per-line validation as any other.
        lines = path.read_text(encoding="utf-8", errors="replace").splitlines()
    except OSError:
        return []

    raw_rows: list[dict[str, Any]] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            row = json.loads(line)
        except Exception:
            continue
        # Valid JSON that isn't an object (e.g. a bare number) has no .get().
        if not isinstance(row, dict):
            continue
        text = row.get("feature_text")
        if not isinstance(text, str) or not text.strip():
            continue
        if row.get("success") is None:
            continue
        raw_rows.append(row)

    if dedupe_by_feature:
        def _ts(r: dict[str, Any]) -> float:
            """Comparable timestamp: missing -> 0, None/unparsable -> -inf."""
            try:
                return float(r.get("ts", 0))
            except (TypeError, ValueError):
                return float("-inf")

        # Keep the latest row per (project_id, feature_id). If feature_id
        # is missing (older rows) we fall back to (project_id, feature_text)
        # so a feature with no id still gets one slot, not one slot per
        # retry.
        latest: dict[tuple, dict[str, Any]] = {}
        for row in raw_rows:
            key = (
                row.get("project_id"),
                row.get("feature_id") or row.get("feature_text"),
            )
            prev = latest.get(key)
            # `>=` so that on a tie the later line in the append-only file
            # (the more recent attempt) wins.
            if prev is None or _ts(row) >= _ts(prev):
                latest[key] = row
        raw_rows = list(latest.values())

    out: list[ConversationSample] = []
    for row in raw_rows:
        text = row["feature_text"].strip()
        label = 1.0 if bool(row.get("success")) else 0.0
        # Feature description as a one-turn user message: the model's
        # job here is "given just the feature text, predict success".
        # No tool calls, no tool list; feature_success is the only
        # head this row labels.
        out.append(ConversationSample(
            conversation=[ChatMessage(role="user", content=text)],
            tools=[],
            target_tool="<no_tool>",
            target_value=label,
            model_source=row.get("model_used") or "<unknown>",
            target_feature_success=label,
        ))
    return out
|
|
1380
|
+
|
|
1381
|
+
|
|
1382
|
+
# ---------------------------------------------------------------------
|
|
1383
|
+
# Replay buffer
|
|
1384
|
+
# ---------------------------------------------------------------------
|
|
1385
|
+
|
|
1386
|
+
|
|
1387
|
+
@dataclass
class ReplayBuffer:
    """Bounded FIFO buffer of online training samples.

    Holds at most ``capacity`` samples. Pushing past that evicts the
    oldest one (``collections.deque(maxlen=...)`` semantics). ``sample``
    draws with replacement, weighted by each item's ``surprise``
    attribute.
    """

    capacity: int = 8192
    # Backing store. repr=False keeps ``repr(buffer)`` from dumping up to
    # ``capacity`` samples into logs and tracebacks.
    _buf: collections.deque = field(
        default_factory=lambda: collections.deque(maxlen=8192),
        repr=False,
    )

    def __post_init__(self) -> None:
        # default_factory can't see ``capacity``; rebuild the deque with
        # the requested maxlen (a pre-filled deque keeps its newest items).
        if self._buf.maxlen != self.capacity:
            self._buf = collections.deque(self._buf, maxlen=self.capacity)

    def push(self, sample: ConversationSample) -> None:
        """Append ``sample``, evicting the oldest entry when full."""
        self._buf.append(sample)

    def __len__(self) -> int:
        return len(self._buf)

    @staticmethod
    def _priority(s: Any) -> float:
        """Sampling weight for one item: its ``surprise``, floored at 0.05.

        A missing, None, non-numeric or non-finite surprise falls back to
        the neutral 0.5 prior. ``float(None)`` would raise, and an
        infinite weight makes ``random.choices`` reject the whole draw.
        """
        import math

        raw = getattr(s, "surprise", 0.5)
        try:
            w = float(raw)
        except (TypeError, ValueError):
            w = 0.5
        if not math.isfinite(w):
            w = 0.5
        return max(0.05, w)

    def sample(self, n: int) -> list[ConversationSample]:
        """Draw ``n`` samples with replacement, weighted by surprise.

        Surprise-weighted sampling (Prioritized Experience Replay):
        samples Caudate guessed badly on appear more often. The weight
        floor keeps easy samples showing up occasionally (otherwise the
        model forgets what it already knows). Because sampling is with
        replacement, ``n`` may exceed the buffer size and high-priority
        items can repeat within a batch. Returns [] if ``n <= 0`` or the
        buffer is empty.
        """
        if n <= 0 or not self._buf:
            return []
        buf = list(self._buf)
        weights = [self._priority(s) for s in buf]
        return random.choices(buf, weights=weights, k=n)

    def all(self) -> list[ConversationSample]:
        """Snapshot of the buffer contents, oldest first."""
        return list(self._buf)
|
|
1420
|
+
|
|
1421
|
+
|
|
1422
|
+
# ---------------------------------------------------------------------
|
|
1423
|
+
# Batch collation
|
|
1424
|
+
# ---------------------------------------------------------------------
|
|
1425
|
+
|
|
1426
|
+
|
|
1427
|
+
_NO_TOOL_DESCRIPTION = (
|
|
1428
|
+
"answer the user directly without calling any tool"
|
|
1429
|
+
)
|
|
1430
|
+
|
|
1431
|
+
|
|
1432
|
+
def _format_tool_spec(t: "ToolDef") -> str:
    """Build the text the contrastive tool head's encoder sees for one tool.

    Returns ``"name: description"`` when the tool has a non-blank
    description, and just ``"name"`` otherwise. Both parts are stripped of
    surrounding whitespace, and ``None`` is treated as empty. The result is
    cut to at most 240 characters so batched embedding stays tractable on
    big tool registries.
    """
    max_chars = 240
    label = (t.name or "").strip()
    blurb = (t.description or "").strip()
    text = f"{label}: {blurb}" if blurb else label
    return text[:max_chars]
|
|
1441
|
+
|
|
1442
|
+
|
|
1443
|
+
def _left_pad_tail(items: Any, window: int, pad: Any) -> list[Any]:
    """Return the last ``window`` entries of ``items``, left-padded with ``pad``.

    The result always has exactly ``max(window, 0)`` entries. A non-positive
    window is handled explicitly because ``seq[-0:]`` equals ``seq[0:]``: a
    bare negative slice would return the WHOLE sequence for a 0 window.
    """
    if window <= 0:
        return []
    tail = list(items)[-window:]
    return [pad] * (window - len(tail)) + tail


def _zero_pad_to(values: Any, size: int) -> list[float]:
    """Truncate ``values`` to ``size`` floats, right-padding with 0.0 if short."""
    if size <= 0:
        return []
    row = [float(v) for v in list(values)[:size]]
    return row + [0.0] * (size - len(row))


def collate(
    samples: list[ConversationSample],
    cfg: NNConfig,
    vocab: ToolVocab,
    source_vocab: "SourceVocab | None" = None,
) -> dict[str, Any]:
    """Turn a list of ConversationSamples into model-ready tensors.

    The chat-schema conversation gets flattened into role-prefixed
    strings by ``conversation_to_strings``; tool history is derived
    from past assistant tool_calls. Tool candidates per sample are
    surfaced as ``tool_specs`` (a list of "{name}: {description}"
    strings) so the contrastive tool head can score them. The first
    candidate slot is always a synthetic ``<no_tool>`` entry — that's
    how "the assistant answered directly without calling anything"
    becomes a real class the head can predict.

    Args:
        samples: The batch to collate (``B = len(samples)``).
        cfg: Reads ``msg_window``, ``history_window``, ``image_window``,
            ``use_vision``, ``max_tools_per_sample``, ``mood_dim`` and
            ``tool_vocab_size``.
        vocab: Maps tool names to ids via ``get``; ``SPECIAL["<pad>"]`` is
            the history padding id.
        source_vocab: Optional; when given, ``add`` is called with each
            sample's ``model_source`` to obtain ``source_id``. When ``None``,
            every ``source_id`` is 0.

    Returns:
        A dict with:
          - ``messages``: B lists of ``msg_window`` strings, left-padded with "".
          - ``image_paths``: B lists of ``image_window`` paths, left-padded
            with "" (empty lists when vision is off).
          - ``tool_ids``: long tensor ``(B, history_window)``.
          - ``mood``: float tensor ``(B, mood_dim)``; short mood vectors are
            zero-padded.
          - ``target_tier`` and ``source_id``: long tensors ``(B,)``.
          - ``target_think`` and ``target_value``: float tensors ``(B,)``.
          - ``tool_specs``: B candidate lists; slot 0 is ``<no_tool>``.
          - ``target_tool_idx``: long tensor ``(B,)``; -1 when the target is
            not among the candidates.
          - ``tool_target_valid``: bool tensor ``(B,)``.
          - Extended-head targets, but only when every sample has the label.
    """
    B = len(samples)
    msg_window = cfg.msg_window
    hist_window = cfg.history_window
    img_window = cfg.image_window if cfg.use_vision else 0
    K = cfg.max_tools_per_sample
    pad_id = vocab.SPECIAL["<pad>"]
    no_tool_spec = f"<no_tool>: {_NO_TOOL_DESCRIPTION}"

    messages: list[list[str]] = []
    image_paths: list[list[str]] = []
    tool_ids = torch.zeros((B, hist_window), dtype=torch.long)
    mood = torch.zeros((B, cfg.mood_dim), dtype=torch.float32)

    target_tier = torch.zeros(B, dtype=torch.long)
    target_think = torch.zeros(B, dtype=torch.float32)
    target_value = torch.zeros(B, dtype=torch.float32)
    source_id = torch.zeros(B, dtype=torch.long)

    # Contrastive tool head inputs.
    #   tool_specs: per-sample list of "{name}: {description}" strings,
    #     first slot is always "<no_tool>: answer directly..."
    #   target_tool_idx: index of the called tool within that list, or -1
    #     if the called tool isn't among the candidates (skip-this-sample).
    tool_specs: list[list[str]] = []
    target_tool_idx = torch.full((B,), -1, dtype=torch.long)
    tool_target_valid = torch.zeros(B, dtype=torch.bool)

    for i, s in enumerate(samples):
        # Flatten the role-tagged conversation into text the encoder
        # already speaks, keeping the most recent msg_window entries.
        messages.append(
            _left_pad_tail(conversation_to_strings(s.conversation), msg_window, "")
        )
        image_paths.append(
            _left_pad_tail(s.image_paths, img_window, "") if cfg.use_vision else []
        )

        # Tool history = the sequence of tool names called in this
        # conversation prefix. Stable ordering preserves recency.
        # Clamp every id below tool_vocab_size — the encoder's tool
        # embedding table is fixed-size and a vocab that grows past
        # it (public corpora can hit 10K+ tools) would otherwise
        # crash with index-out-of-range. Modulo wraps high ids into
        # the table with some collisions (presumably including with
        # special-token ids); better than missing data.
        history = conversation_tool_history(s.conversation)
        recent = history[-hist_window:] if hist_window > 0 else []
        hist_ids = [vocab.get(t) % cfg.tool_vocab_size for t in recent]
        tool_ids[i] = torch.tensor(
            _left_pad_tail(hist_ids, hist_window, pad_id), dtype=torch.long
        )

        # Pad/truncate so a short mood vector can't break the row assignment.
        mood[i] = torch.tensor(_zero_pad_to(s.mood, cfg.mood_dim), dtype=torch.float32)
        target_tier[i] = int(s.target_tier)
        target_think[i] = float(s.target_think)
        target_value[i] = float(s.target_value)
        if source_vocab is not None:
            # `add` is idempotent — first occurrence assigns an id, later
            # ones look it up. Falls back to <unknown> if the cap fills.
            source_id[i] = source_vocab.add(getattr(s, "model_source", "<unknown>"))
        else:
            source_id[i] = 0  # <unknown> — keeps the model in baseline mode

        # Build the candidate tool list for this sample. The synthetic
        # <no_tool> entry occupies slot 0 so the head always has a
        # "don't call anything" option available — without it the
        # softmax would be forced to nominate something even when the
        # right answer is "respond directly".
        sample_tools = list(s.tools)[: max(0, K - 1)]
        tool_specs.append([no_tool_spec] + [_format_tool_spec(t) for t in sample_tools])

        # Resolve target_tool to an index in the specs list by name.
        # `or` already maps None and "" to "<no_tool>".
        target_name = s.target_tool or "<no_tool>"
        if target_name == "<no_tool>":
            target_tool_idx[i] = 0
            tool_target_valid[i] = True
        else:
            # start=1 accounts for the <no_tool> prefix slot.
            found = next(
                (k for k, t in enumerate(sample_tools, start=1) if t.name == target_name),
                -1,
            )
            if found > 0:
                target_tool_idx[i] = found
                tool_target_valid[i] = True
            # Else: target_tool isn't among candidates → leave
            # target_tool_idx=-1, tool_target_valid=False; the trainer
            # skips this sample's contribution to the tool loss.

    out = {
        "messages": messages,
        "image_paths": image_paths,
        "tool_ids": tool_ids,
        "mood": mood,
        "target_tier": target_tier,
        "target_think": target_think,
        "target_value": target_value,
        "source_id": source_id,
        # Contrastive tool head inputs/targets
        "tool_specs": tool_specs,
        "target_tool_idx": target_tool_idx,
        "tool_target_valid": tool_target_valid,
    }

    # Extended-head targets — only emit if EVERY sample in this batch
    # has the label. A partially-labeled batch would have to use a fake
    # value (e.g. 0.0) for the unlabeled rows, which would train the
    # head against a synthetic "no" signal and corrupt learning. So
    # if even one sample is missing the label, we drop the field
    # entirely and the trainer skips that head for this batch (via
    # HeadSpec.optional_target).
    _SCALAR_BCE_FIELDS = (
        "target_memory_write", "target_cache_hit", "target_permission",
        "target_refusal", "target_code_response", "target_stall",
        "target_stop_iter", "target_compaction",
        "target_subagent_spawn",
        "target_feature_success",
    )
    _SCALAR_MSE_FIELDS = (
        "target_latency_s", "target_token_budget", "target_reward_model",
    )
    for field_name in _SCALAR_BCE_FIELDS + _SCALAR_MSE_FIELDS:
        vals = [getattr(s, field_name, None) for s in samples]
        if all(v is not None for v in vals):
            out[field_name] = torch.tensor(
                [float(v) for v in vals], dtype=torch.float32
            )

    # difficulty is a 3-class CE target → int64
    diffs = [getattr(s, "target_difficulty", None) for s in samples]
    if all(d is not None for d in diffs):
        out["target_difficulty"] = torch.tensor(
            [int(d) for d in diffs], dtype=torch.long
        )

    # mood_pred is a 4-D vector: truncate/zero-pad each row to 4. The
    # reshape keeps the (B, 4) shape even for an empty batch.
    moods = [getattr(s, "target_mood_pred", None) for s in samples]
    if all(m is not None for m in moods):
        out["target_mood_pred"] = torch.tensor(
            [_zero_pad_to(m, 4) for m in moods],
            dtype=torch.float32,
        ).reshape(B, 4)
    return out
|
|
1608
|
+
|
|
1609
|
+
|
|
1610
|
+
def split_train_eval(
    samples: list[ConversationSample], eval_ratio: float, seed: int,
) -> tuple[list[ConversationSample], list[ConversationSample]]:
    """Deterministically shuffle ``samples`` and split them into ``(train, eval)``.

    A private ``random.Random(seed)`` drives the shuffle, so the same seed
    always produces the same split and the global RNG is left untouched.
    The caller's list is not modified.

    The eval slice is the first ``int(len(samples) * eval_ratio)`` shuffled
    items, but never fewer than one when ``samples`` is non-empty. An empty
    input returns ``([], [])``.
    """
    pool = list(samples)
    random.Random(seed).shuffle(pool)
    if not pool:
        return [], []
    # NOTE(review): the floor of one eval sample also applies when
    # eval_ratio == 0, and a single-sample input leaves train empty —
    # confirm callers expect this.
    cut = max(1, int(len(pool) * eval_ratio))
    return pool[cut:], pool[:cut]
|
|
1618
|
+
|
|
1619
|
+
|
|
1620
|
+
def iter_batches(
    samples: list[ConversationSample],
    batch_size: int,
    cfg: NNConfig,
    vocab: ToolVocab,
    shuffle: bool = True,
    source_vocab: "SourceVocab | None" = None,
) -> Iterator[dict[str, Any]]:
    """Yield ``collate``-d batches of at most ``batch_size`` samples.

    When ``shuffle`` is true, a shuffled copy is iterated and the caller's
    list is untouched. The copy is shuffled with the module-level ``random``
    RNG, so ``random.seed`` controls the order. The final batch may be
    smaller than ``batch_size``.

    This is a generator, so the shuffle and any ``range()`` error happen on
    first iteration. ``batch_size == 0`` raises ``ValueError``, and a
    negative ``batch_size`` yields nothing.
    """
    ordered = list(samples) if shuffle else samples
    if shuffle:
        random.shuffle(ordered)
    for start in range(0, len(ordered), batch_size):
        batch = ordered[start:start + batch_size]
        if batch:
            yield collate(batch, cfg, vocab, source_vocab=source_vocab)
|