caudate-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. api/__init__.py +5 -0
  2. api/anthropic_compat.py +1518 -0
  3. api/artifact_viewer.py +366 -0
  4. api/caudate_middleware.py +618 -0
  5. api/forge_bootstrapper_routes.py +377 -0
  6. api/forge_routes.py +630 -0
  7. api/forge_system_routes.py +294 -0
  8. api/openai_compat.py +1993 -0
  9. api/server.py +667 -0
  10. api/storyboard_page.py +677 -0
  11. caudate_cli-0.1.0.dist-info/METADATA +354 -0
  12. caudate_cli-0.1.0.dist-info/RECORD +153 -0
  13. caudate_cli-0.1.0.dist-info/WHEEL +5 -0
  14. caudate_cli-0.1.0.dist-info/entry_points.txt +2 -0
  15. caudate_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  16. caudate_cli-0.1.0.dist-info/top_level.txt +14 -0
  17. cognos_mcp/__init__.py +4 -0
  18. cognos_mcp/bridge.py +41 -0
  19. cognos_mcp/client.py +70 -0
  20. cognos_mcp/config.py +49 -0
  21. cognos_mcp/server.py +66 -0
  22. config.py +82 -0
  23. core/__init__.py +0 -0
  24. core/agent.py +468 -0
  25. core/agentic_loop.py +731 -0
  26. core/anthropic_auth.py +91 -0
  27. core/background.py +113 -0
  28. core/banner.py +134 -0
  29. core/bootstrap.py +292 -0
  30. core/citations.py +131 -0
  31. core/compaction.py +109 -0
  32. core/constitution.py +198 -0
  33. core/diff_viewer.py +87 -0
  34. core/export.py +85 -0
  35. core/file_refs.py +119 -0
  36. core/files.py +199 -0
  37. core/hooks.py +209 -0
  38. core/image.py +599 -0
  39. core/input.py +91 -0
  40. core/loop.py +238 -0
  41. core/memory_md.py +147 -0
  42. core/notifications.py +99 -0
  43. core/ownership.py +181 -0
  44. core/paste.py +81 -0
  45. core/permissions.py +210 -0
  46. core/plan_mode.py +215 -0
  47. core/sandbox_prompt.py +185 -0
  48. core/scheduler.py +195 -0
  49. core/schemas.py +202 -0
  50. core/session.py +90 -0
  51. core/settings.py +132 -0
  52. core/skills.py +398 -0
  53. core/slash_commands.py +977 -0
  54. core/statusline.py +61 -0
  55. core/subagent.py +300 -0
  56. core/thinking.py +50 -0
  57. core/updater.py +122 -0
  58. core/usage.py +109 -0
  59. core/worktree.py +93 -0
  60. execution/__init__.py +0 -0
  61. execution/executor.py +329 -0
  62. execution/plugins.py +108 -0
  63. execution/tools/__init__.py +0 -0
  64. execution/tools/agent_tool.py +107 -0
  65. execution/tools/agentic_tool.py +297 -0
  66. execution/tools/artifact_tool.py +191 -0
  67. execution/tools/ask_user_question_tool.py +137 -0
  68. execution/tools/base.py +81 -0
  69. execution/tools/calculator_tool.py +137 -0
  70. execution/tools/cognos_card_tool.py +124 -0
  71. execution/tools/cron_tool.py +215 -0
  72. execution/tools/datetime_tool.py +215 -0
  73. execution/tools/describe_image_tool.py +161 -0
  74. execution/tools/draw_tool.py +164 -0
  75. execution/tools/edit_image_tool.py +262 -0
  76. execution/tools/edit_tool.py +245 -0
  77. execution/tools/file_tool.py +90 -0
  78. execution/tools/find_anywhere_tool.py +255 -0
  79. execution/tools/forge_feature_tools.py +377 -0
  80. execution/tools/glob_tool.py +59 -0
  81. execution/tools/grep_tool.py +89 -0
  82. execution/tools/http_request_tool.py +224 -0
  83. execution/tools/load_skill_tool.py +104 -0
  84. execution/tools/longcat_avatar_tool.py +384 -0
  85. execution/tools/mcp_tool.py +100 -0
  86. execution/tools/notebook_tool.py +279 -0
  87. execution/tools/openapi_tool.py +440 -0
  88. execution/tools/plan_mode_tool.py +95 -0
  89. execution/tools/push_notification_tool.py +157 -0
  90. execution/tools/python_tool.py +61 -0
  91. execution/tools/respond_tool.py +40 -0
  92. execution/tools/sandbox_tool.py +378 -0
  93. execution/tools/search_tool.py +153 -0
  94. execution/tools/semantic_search_tool.py +106 -0
  95. execution/tools/shell_tool.py +283 -0
  96. execution/tools/speak_tool.py +134 -0
  97. execution/tools/storyboard_tool.py +727 -0
  98. execution/tools/system_info_tool.py +212 -0
  99. execution/tools/task_tool.py +323 -0
  100. execution/tools/think_tool.py +49 -0
  101. execution/tools/transcribe_audio_tool.py +86 -0
  102. execution/tools/update_memory_tool.py +92 -0
  103. execution/tools/web_fetch_tool.py +82 -0
  104. execution/tools/worktree_tool.py +174 -0
  105. llm/__init__.py +0 -0
  106. llm/fallback.py +116 -0
  107. llm/models.py +320 -0
  108. llm/provider.py +1356 -0
  109. llm/router.py +373 -0
  110. main.py +1889 -0
  111. memory/__init__.py +0 -0
  112. memory/episodic.py +99 -0
  113. memory/procedural.py +145 -0
  114. memory/semantic.py +71 -0
  115. memory/working.py +64 -0
  116. nn/__init__.py +43 -0
  117. nn/auto_evolve.py +245 -0
  118. nn/caudate.py +136 -0
  119. nn/config.py +141 -0
  120. nn/consolidator.py +81 -0
  121. nn/data.py +1635 -0
  122. nn/encoder.py +258 -0
  123. nn/forge_advisor.py +303 -0
  124. nn/format.py +235 -0
  125. nn/heads.py +432 -0
  126. nn/observer.py +994 -0
  127. nn/policy.py +214 -0
  128. nn/runtime.py +343 -0
  129. nn/scorer.py +175 -0
  130. nn/trainer.py +515 -0
  131. nn/vision.py +352 -0
  132. personality/__init__.py +23 -0
  133. personality/engine.py +129 -0
  134. personality/identity.py +144 -0
  135. personality/inner_voice.py +100 -0
  136. personality/mood.py +205 -0
  137. planning/__init__.py +0 -0
  138. planning/dev_server.py +221 -0
  139. planning/forge_models.py +718 -0
  140. planning/orchestrator.py +1363 -0
  141. planning/planner.py +451 -0
  142. planning/task_graph.py +61 -0
  143. reflection/__init__.py +0 -0
  144. reflection/meta_learner.py +156 -0
  145. reflection/reflector.py +127 -0
  146. ui/__init__.py +5 -0
  147. ui/display.py +88 -0
  148. voice/__init__.py +0 -0
  149. voice/conversation.py +125 -0
  150. voice/listener.py +111 -0
  151. voice/speaker.py +59 -0
  152. voice/stt.py +126 -0
  153. voice/tts.py +214 -0
nn/data.py ADDED
@@ -0,0 +1,1635 @@
1
+ """Dataset + replay buffer for the cognitive controller.
2
+
3
+ Two data sources:
4
+
5
+ 1. **Disk corpus** — every saved Session under `data/sessions/*.json`.
6
+ Each tool call inside a session becomes one (state, action, reward)
7
+ example. Reward defaults to 0.5 unless the session has an explicit
8
+ reflection score attached.
9
+
10
+ 2. **Live replay buffer** — circular in-memory buffer fed by the agent
11
+ during use. Lets the controller train online from current sessions
12
+ before they're saved to disk.
13
+
14
+ ToolVocab maps the dynamic set of registered tools to stable integer
15
+ ids that survive across runs. The vocab grows append-only — adding a
16
+ new tool doesn't invalidate older checkpoints; it just adds a row to
17
+ the embedding table at training time.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import collections
23
+ import json
24
+ import logging
25
+ import random
26
+ import re
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Iterator
30
+
31
+ import torch
32
+
33
+ from nn.config import NNConfig
34
+ from nn.format import ChatMessage, Conversation, ToolCall, ToolDef
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # ---------------------------------------------------------------------
40
+ # Tool vocab
41
+ # ---------------------------------------------------------------------
42
+
43
+
44
class ToolVocab:
    """Bidirectional id ↔ name map for tool tokens.

    Four special ids are reserved upfront:
        <pad>     — padding for batched sequences
        <unk>     — placeholder for missing/unmapped labels at inference
        <bos>     — start-of-sequence
        <no_tool> — explicit "no tool was called" — a real action class
                    (distinct from <unk>: a turn where the assistant
                    correctly chose to just answer without any tool)

    Keeping <no_tool> separate from <unk> matters for training: ~40% of
    real conversation turns don't call a tool, and that's a valid
    decision the model should learn to *predict*, not a degenerate
    fallback class.

    Ids are assigned append-only, so ``len(vocab)`` is always one past the
    largest id in use.
    """

    SPECIAL = {
        "<pad>": 0,
        "<unk>": 1,
        "<bos>": 2,
        "<no_tool>": 3,
    }

    def __init__(self):
        self._id2name: dict[int, str] = {v: k for k, v in self.SPECIAL.items()}
        self._name2id: dict[str, int] = dict(self.SPECIAL)
        self._next_id = max(self.SPECIAL.values()) + 1

    def add(self, name: str) -> int:
        """Return the id for ``name``, assigning the next free id if it is new."""
        if name in self._name2id:
            return self._name2id[name]
        idx = self._next_id
        self._next_id += 1
        self._name2id[name] = idx
        self._id2name[idx] = name
        return idx

    def get(self, name: str) -> int:
        """Return the id for ``name``, or the ``<unk>`` id if it was never added."""
        return self._name2id.get(name, self.SPECIAL["<unk>"])

    def name(self, idx: int) -> str:
        """Return the name for ``idx``, or ``"<unk>"`` if the id is unassigned."""
        return self._id2name.get(idx, "<unk>")

    def __len__(self) -> int:
        """Number of id slots in use (one past the largest assigned id)."""
        return self._next_id

    def to_dict(self) -> dict[str, int]:
        """Return a copy of the name -> id map, special tokens included."""
        return dict(self._name2id)

    @classmethod
    def from_dict(cls, data: dict[str, int]) -> "ToolVocab":
        """Rebuild a vocab from ``to_dict`` output.

        Special names are skipped; they are always re-seeded at their fixed
        ids by ``__init__``. Ids are coerced with ``int()``, matching
        ``SourceVocab.from_dict``. A map whose ids round-tripped through a
        string-typed store therefore loads correctly instead of raising
        ``TypeError`` on ``idx + 1``.
        """
        v = cls()
        for name, raw_idx in data.items():
            if name in v.SPECIAL:
                continue
            idx = int(raw_idx)
            v._name2id[name] = idx
            v._id2name[idx] = name
            v._next_id = max(v._next_id, idx + 1)
        return v
104
+
105
+
106
class SourceVocab:
    """Stable id ↔ name mapping for the teacher model behind each sample.

    Every ConversationSample carries a ``model_source`` string (for example
    "anthropic/claude-opus-4-7", "ollama/...", or "<unknown>"). Per Phase 2
    of CAUDATE_EVOLUTION.md, the model conditions on that source through a
    learned embedding lookup. Ids never reach ``cap``, so the embedding
    table stays a fixed size. Names arriving after the cap is full share
    the reserved ``<unknown>`` slot.

    Slot 0 is ``<unknown>``: the zero-bias baseline. Legacy untagged
    samples and any source that was never registered share it.
    """

    SPECIAL = {"<unknown>": 0}
    DEFAULT_CAP = 16

    def __init__(self, cap: int = DEFAULT_CAP):
        # At least one slot is always kept, for <unknown> itself.
        self.cap = max(1, int(cap))
        self._id2name: dict[int, str] = {slot: label for label, slot in self.SPECIAL.items()}
        self._name2id: dict[str, int] = dict(self.SPECIAL)
        self._next_id = max(self.SPECIAL.values()) + 1

    def add(self, name: str) -> int:
        """Register ``name`` and return its id.

        Empty or non-string names map to the ``<unknown>`` slot, and so do
        new names once the cap is full. This avoids raising, because the
        embedding table cannot grow at runtime.
        """
        fallback = self.SPECIAL["<unknown>"]
        if isinstance(name, str) and name:
            known = self._name2id.get(name)
            if known is not None:
                return known
            if self._next_id < self.cap:
                slot = self._next_id
                self._next_id = slot + 1
                self._name2id[name] = slot
                self._id2name[slot] = name
                return slot
        return fallback

    def get(self, name: str) -> int:
        """Look up ``name`` without registering it; unknown names map to slot 0."""
        fallback = self.SPECIAL["<unknown>"]
        return self._name2id.get(name, fallback)

    def name(self, idx: int) -> str:
        """Return the source name for ``idx``, or ``"<unknown>"`` if unassigned."""
        fallback = "<unknown>"
        return self._id2name.get(idx, fallback)

    def __len__(self) -> int:
        """Number of slots in use (one past the largest assigned id)."""
        return self._next_id

    def to_dict(self) -> dict[str, int]:
        """Return a copy of the name -> id map, ``<unknown>`` included."""
        snapshot = dict(self._name2id)
        return snapshot

    @classmethod
    def from_dict(cls, data: dict[str, int], cap: int = DEFAULT_CAP) -> "SourceVocab":
        """Rebuild a vocab from ``to_dict`` output, dropping ids ``>= cap``."""
        vocab = cls(cap=cap)
        for label, raw in data.items():
            if label in vocab.SPECIAL:
                continue
            slot = int(raw)
            if slot >= vocab.cap:
                continue  # persisted id does not fit under the (possibly smaller) cap
            vocab._name2id[label] = slot
            vocab._id2name[slot] = label
            vocab._next_id = max(vocab._next_id, slot + 1)
        return vocab
169
+
170
+
171
+ # ---------------------------------------------------------------------
172
+ # Sample structure — ConversationSample is THE training row.
173
+ # ---------------------------------------------------------------------
174
+ # Standard chat-tool-call schema (OpenAI shape) — same format used by
175
+ # every public function-calling dataset (xLAM, ToolBench, ToolAlpaca,
176
+ # Gorilla, API-Bank). Cognos sessions, external HuggingFace datasets,
177
+ # and the agent's live replay buffer all produce this single type, so
178
+ # the training path doesn't need to know where samples came from.
179
+ #
180
+ # Layout:
181
+ # - conversation: list[ChatMessage] ← the role-tagged turn list
182
+ # - tools: list[ToolDef] ← what was available to call
183
+ # - target_tool: str ← name of the tool actually called
184
+ # - target_arguments: str ← JSON-encoded args (reserved for
185
+ # a future generative arg head)
186
+ # - target_tool_call_index: int ← which tool_call within the
187
+ # target assistant message
188
+ # - Cognos-specific signals (mood, model_source, tier/think/value, the
189
+ # 14 optional D-heads) layer on top. External data leaves them at
190
+ # their default values; Cognos sessions populate them.
191
+
192
+
193
@dataclass
class ConversationSample:
    """A single training row in the standard chat-tool-call layout.

    ``conversation`` plus ``tools`` is exactly what any OpenAI-style
    tool-use dataset provides. The Cognos-specific signals further down
    are extras. External loaders leave them at their defaults: the scalar
    signals keep the values declared here, and the optional heads stay
    ``None``. That lets the trainer's optional-target machinery skip heads
    that carry no label.
    """

    # Standard-format payload: the context messages, and the tool catalog
    # available when the target call was made.
    conversation: list[ChatMessage] = field(default_factory=list)
    tools: list[ToolDef] = field(default_factory=list)

    # Targets. ``target_tool`` names the tool called at the target turn.
    # Both "" and "<no_tool>" mark a turn answered without any tool; the
    # collate path normalises both through the tool vocab.
    target_tool: str = "<no_tool>"
    target_arguments: str = ""  # JSON text; reserved for a future argument head (Phase 2)
    target_tool_call_index: int = 0  # position of the call inside its assistant message

    # Cognos signals. External data keeps these defaults.
    mood: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 0.5])
    image_paths: list[str] = field(default_factory=list)
    model_source: str = "<unknown>"
    surprise: float = 0.5
    target_tier: int = 0
    target_think: float = 0.0
    target_value: float = 0.0
    # Optional D-heads and Tier-1/2/4b heads; None means "no label for this head".
    target_memory_write: float | None = None
    target_cache_hit: float | None = None
    target_permission: float | None = None
    target_refusal: float | None = None
    target_code_response: float | None = None
    target_stall: float | None = None
    target_difficulty: int | None = None
    target_stop_iter: float | None = None
    target_compaction: float | None = None
    target_latency_s: float | None = None
    target_token_budget: float | None = None
    target_mood_pred: list[float] | None = None
    target_subagent_spawn: float | None = None
    target_reward_model: float | None = None
    target_feature_success: float | None = None

    # -- Construction helpers ---------------------------------------------

    @classmethod
    def from_conversation(
        cls,
        conv: Conversation,
        prefix_len: int,
        tool_call: ToolCall,
        tool_call_index: int = 0,
        **cognos_signals: Any,
    ) -> "ConversationSample":
        """Build a sample whose context is the first ``prefix_len`` messages.

        ``prefix_len`` should count the messages that precede the assistant
        message emitting ``tool_call``. That prefix is what the model sees,
        and the call's name is what it must predict. An empty call name
        becomes ``"<no_tool>"``, and empty or missing arguments become
        ``""``. Cognos signals (mood, model_source, ...) are forwarded as
        keyword arguments, so each loader decides which ones to set.
        """
        context = conv.messages[:prefix_len]
        predicted = tool_call.name if tool_call.name else "<no_tool>"
        raw_args = tool_call.arguments if tool_call.arguments else ""
        return cls(
            conversation=list(context),
            tools=list(conv.tools),
            target_tool=predicted,
            target_arguments=raw_args,
            target_tool_call_index=tool_call_index,
            **cognos_signals,
        )
270
+
271
+
272
+ # ---------------------------------------------------------------------
273
+ # Conversion helpers used by collate & every loader
274
+ # ---------------------------------------------------------------------
275
+
276
+
277
def conversation_to_strings(conv: list[ChatMessage]) -> list[str]:
    """Render a list[ChatMessage] as the role-prefixed strings the StateEncoder consumes.

    Rendering rules:
      - ``role="tool"`` messages become ``"tool {name}: {content}"``.
        ``name`` falls back to ``"tool"``.
      - Messages with tool calls become
        ``"{role}: {content} calls: name(args), ..."``. The content part is
        omitted when empty. Each argument preview has newlines flattened
        and is cut to 60 characters.
      - Any other message becomes ``"{role}: {content}"``.

    Content is whitespace-stripped. The text after the role prefix is cut
    to 400 characters. Messages whose stripped content is empty and that
    carry no tool calls are dropped, since a no-op turn carries no signal.

    Tool-call arguments that are not strings (e.g. a dict an upstream
    loader left unserialised) are JSON-encoded for the preview. Previously
    they raised ``AttributeError`` on ``.replace``.

    The encoder operates on text; this helper is the boundary that keeps
    the chat schema clean upstream and the encoder API unchanged downstream.
    """
    out: list[str] = []
    for m in conv:
        content = (m.content or "").strip()
        if m.role == "tool":
            if content:
                out.append(f"tool {m.name or 'tool'}: {content[:400]}")
            continue
        if m.tool_calls:
            call_strs = []
            for tc in m.tool_calls:
                args = tc.arguments or ""
                if not isinstance(args, str):
                    # Preview only: stringify anything json can't encode natively.
                    args = json.dumps(args, ensure_ascii=False, default=str)
                arg_preview = args.replace("\n", " ")[:60]
                call_strs.append(f"{tc.name}({arg_preview})")
            text = (content + " " if content else "") + "calls: " + ", ".join(call_strs)
            out.append(f"{m.role}: {text[:400]}")
            continue
        if content:
            out.append(f"{m.role}: {content[:400]}")
    return out
311
+
312
+
313
def conversation_tool_history(conv: list[ChatMessage]) -> list[str]:
    """Return the names of the tools called by assistant messages, in order.

    Calls with an empty name are skipped. An assistant message whose
    ``tool_calls`` is ``None`` counts as having no calls. This matches the
    truthiness guard in ``conversation_to_strings``; previously it raised
    ``TypeError``.
    """
    return [
        tc.name
        for m in conv
        if m.role == "assistant"
        for tc in (m.tool_calls or [])
        if tc.name
    ]
322
+
323
+
324
+ # ---------------------------------------------------------------------
325
+ # Cognos-session loader — reads data/sessions/*.json into ConversationSamples
326
+ # ---------------------------------------------------------------------
327
+
328
+
329
def load_corpus_from_sessions(
    sessions_dir: Path = Path("data/sessions"),
) -> list[ConversationSample]:
    """Convert each saved Cognos session into ConversationSample(s).

    Sessions are stored as OpenAI-shape JSON. One ConversationSample is
    emitted per named assistant tool call. Its context is the messages
    before that assistant message.

    Per-sample values:
      - Reward: ``metadata.reflection_scores[call_id]`` becomes
        ``target_value``; otherwise it is 0.5.
      - ``metadata.last_tier`` and ``metadata.thinking_used`` populate
        ``target_tier`` and ``target_think``.
      - ``metadata.model_source`` is copied through.

    Sessions that recorded their tool registry under ``metadata.tools``
    (or a top-level ``tools``) have that list carried through. Absent
    tools means an empty list; the open-vocab head still trains, just
    without negative candidates.

    Loading is best-effort. These files are skipped:
      - unreadable or non-UTF-8 files, and invalid JSON;
      - JSON whose top level is not an object;
      - sessions rejected by ``Conversation.from_dict``.

    Malformed metadata falls back to the defaults above instead of
    aborting the whole corpus. This covers a non-object ``metadata`` or
    ``reflection_scores`` and non-numeric score/tier/thinking values.

    Args:
        sessions_dir: directory scanned (non-recursively) for ``*.json``.

    Returns:
        List of ConversationSample. It is empty when the directory does
        not exist.
    """

    def _as_float(value: Any, default: float) -> float:
        # Non-numeric metadata degrades to the default instead of raising.
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _as_int(value: Any, default: int) -> int:
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    out: list[ConversationSample] = []
    if not sessions_dir.exists():
        return out

    for path in sorted(sessions_dir.glob("*.json")):
        try:
            # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
            # Windows) could mis-decode or reject non-ASCII session text.
            data = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue  # best-effort: unreadable file or invalid JSON
        if not isinstance(data, dict):
            continue

        meta = data.get("metadata")
        if not isinstance(meta, dict):
            meta = {}
        scores = meta.get("reflection_scores")
        if not isinstance(scores, dict):
            scores = {}
        raw_msgs = data.get("messages") or []
        raw_tools = meta.get("tools") or data.get("tools") or []
        try:
            conv = Conversation.from_dict(
                {"messages": raw_msgs, "tools": raw_tools},
            )
        except Exception:
            continue

        # Session-level signals are the same for every call, so resolve them once.
        model_source = meta.get("model_source", "<unknown>")
        target_tier = _as_int(meta.get("last_tier", 0), 0)
        target_think = _as_float(meta.get("thinking_used", 0.0), 0.0)

        for msg_idx, msg in enumerate(conv.messages):
            if msg.role != "assistant" or not msg.tool_calls:
                continue
            for call_idx, tc in enumerate(msg.tool_calls):
                if not tc.name:
                    continue
                score = _as_float(scores.get(tc.id, 0.5), 0.5) if tc.id else 0.5
                out.append(ConversationSample.from_conversation(
                    conv=conv,
                    prefix_len=msg_idx,
                    tool_call=tc,
                    tool_call_index=call_idx,
                    model_source=model_source,
                    target_tier=target_tier,
                    target_think=target_think,
                    target_value=score,
                ))
    return out
388
+
389
+
390
+ # ---------------------------------------------------------------------
391
+ # OpenAI-format JSONL loader (external datasets: xLAM, ToolBench, etc.)
392
+ # ---------------------------------------------------------------------
393
+
394
+
395
def load_corpus_from_oai_jsonl(
    path: Path,
    model_source: str = "<external>",
    max_rows: int | None = None,
) -> list[ConversationSample]:
    """Load a JSONL file in standard chat-tool-call format.

    One line per conversation. Each conversation may contain multiple
    assistant tool calls; we emit one ConversationSample per named tool
    call, with the prefix-of-messages-before-the-call as context.

    Rows with a top-level ``target_tool`` (the cognos-toolbox
    amplified/replay rows) are handled differently. They emit a single
    sample whose context is the full message list. That sample's
    ``model_source`` is ``meta.teacher_model`` when present, else the
    ``model_source`` argument.

    The standard format works for: xLAM, ToolBench, ToolAlpaca, Gorilla,
    API-Bank, and any OpenAI-fine-tuning-shaped corpus. Some datasets
    use ``conversations`` instead of ``messages`` for the turn list —
    ``Conversation.from_dict`` handles both.

    Malformed rows are skipped, not raised: external corpora are noisy,
    and one bad row must not abort the whole training pass. Skipped rows
    include:
      - lines that are not JSON;
      - JSON values that are not objects;
      - rows that ``Conversation.from_dict`` rejects.

    Invalid UTF-8 bytes are decoded with U+FFFD replacement instead of
    raising ``UnicodeDecodeError`` part-way through the file.

    Args:
        path: JSONL file path. A missing file logs a warning and yields [].
        model_source: tagged on every emitted sample so the
            source-embedding (Phase 2 of CAUDATE_EVOLUTION) can bias
            predictions for external-data turns.
        max_rows: cap on rows with at least one message (blank or
            malformed lines don't count); None = read everything.

    Returns:
        List of ConversationSample: one per assistant tool call, or one
        per row-level-target row.
    """
    out: list[ConversationSample] = []
    if not path.exists():
        logger.warning("oai jsonl not found: %s", path)
        return out

    n_rows = 0
    n_calls = 0
    with path.open(encoding="utf-8", errors="replace") as f:
        for lineno, line in enumerate(f, start=1):
            if max_rows is not None and n_rows >= max_rows:
                break
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except (ValueError, RecursionError) as e:
                logger.debug("skipping non-JSON oai row %s:%d: %s", path, lineno, e)
                continue
            if not isinstance(row, dict):
                logger.debug("skipping non-object oai row %s:%d", path, lineno)
                continue
            try:
                conv = Conversation.from_dict(row)
            except Exception as e:
                logger.debug("skipping malformed oai row: %s", e)
                continue
            if not conv.messages:
                continue
            n_rows += 1

            # Row-level target_tool: the cognos-toolbox amplified/replay
            # rows are "predict the next tool given user turn + tool
            # catalog"; no assistant message exists yet. Emit one sample
            # with the row-level target and the full message list as
            # context.
            row_target = row.get("target_tool")
            if row_target is not None:
                row_meta = row.get("meta")
                teacher = (
                    row_meta.get("teacher_model")
                    if isinstance(row_meta, dict) else None
                )
                out.append(ConversationSample(
                    conversation=list(conv.messages),
                    tools=list(conv.tools),
                    target_tool=str(row_target) or "<no_tool>",
                    model_source=str(teacher) if teacher else model_source,
                ))
                n_calls += 1
                continue

            # Standard OAI shape: emit one sample per tool_call within
            # each assistant message. The prefix is everything *before*
            # that assistant message — the state the model sees when
            # deciding which tool to call.
            for msg_idx, msg in enumerate(conv.messages):
                if msg.role != "assistant" or not msg.tool_calls:
                    continue
                for call_idx, tc in enumerate(msg.tool_calls):
                    if not tc.name:
                        continue
                    out.append(ConversationSample.from_conversation(
                        conv=conv,
                        prefix_len=msg_idx,
                        tool_call=tc,
                        tool_call_index=call_idx,
                        model_source=model_source,
                    ))
                    n_calls += 1

    logger.info(
        "loaded %d rows / %d tool-call samples from %s", n_rows, n_calls, path,
    )
    return out
497
+
498
+
499
+ # ---------------------------------------------------------------------
500
+ # HuggingFace dataset loader — direct path from HF Hub to ConversationSample
501
+ # ---------------------------------------------------------------------
502
+ #
503
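# HF function-calling datasets don't share one schema. Shapes recognised
# by `_detect_hf_shape`:
#
#   OpenAI chat:    row = {"messages": [...], "tools": [...]}
#   Block content:  row = {"messages": [{"role": ..., "content": [{type, value}, ...]}],
#                          "tools": "[...]"}   (llamafactory; tools is a JSON string)
#   xLAM-style:     row = {"query": "...", "tools": "[...]",
#                          "answers": "[{...tool_call...}]"}   (JSON strings)
#   Conversations:  row = {"conversations": [...same as messages...],
#                          "tools": [...]}
#   ShareGPT:       row = {"conversations": [{"from": ..., "value": ...}, ...]}
#   Glaive:         row = {"system": "...", "chat": "..."}   (both strings)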
510
+ #
511
+ # `load_corpus_from_hf` auto-detects which one applies per dataset. For
512
+ # anything off the beaten path, pass a `row_to_conversation` callable
513
+ # that returns a Conversation given a single row — bypasses detection
514
+ # entirely.
515
+
516
+
517
def _detect_hf_shape(row: dict[str, Any]) -> str:
    """Classify a HuggingFace row's layout.

    Returns one of 'openai', 'block_content', 'conversations', 'sharegpt',
    'xlam', 'glaive' or 'unknown':

      - ``openai``: a non-empty ``messages`` list whose first entry is a
        dict with a ``role`` key, and no block-typed content.
      - ``block_content``: the same ``messages`` shape, but at least one
        message's ``content`` is a list of ``{type, value}`` blocks
        (``text``/``reasoning``/``tool_call``). This is llamafactory's
        reason-tool-use-demo style, where ``tools`` is a JSON string.
        Every message is inspected, not just ``messages[0]``: a row may
        open with a plain-string system prompt while later turns carry
        tool_call blocks. Classifying that row as ``openai`` would
        silently drop every tool_call.
      - ``conversations``: a ``conversations`` list whose first entry has ``role``.
      - ``sharegpt``: a ``conversations`` list whose first entry has
        ``from`` and ``value``.
      - ``xlam``: both ``query`` and ``answers`` keys are present.
      - ``glaive``: string ``system`` and ``chat`` fields.
    """

    def _is_block_list(content: Any) -> bool:
        # Block-typed content: a non-empty list whose first item is a {type, value} dict.
        return (
            isinstance(content, list)
            and bool(content)
            and isinstance(content[0], dict)
            and "value" in content[0]
        )

    messages = row.get("messages")
    if (
        isinstance(messages, list)
        and messages
        and isinstance(messages[0], dict)
        and "role" in messages[0]
    ):
        if any(isinstance(m, dict) and _is_block_list(m.get("content")) for m in messages):
            return "block_content"
        return "openai"

    conversations = row.get("conversations")
    if isinstance(conversations, list) and conversations and isinstance(conversations[0], dict):
        first = conversations[0]
        if "role" in first:
            return "conversations"
        if "from" in first and "value" in first:
            return "sharegpt"

    if "query" in row and "answers" in row:
        return "xlam"
    if isinstance(row.get("system"), str) and isinstance(row.get("chat"), str):
        return "glaive"
    return "unknown"
552
+
553
+
554
def _parse_block_content_row(row: dict[str, Any]) -> Conversation:
    """Parse a llamafactory-style row where each message's ``content``
    is a list of typed blocks.

    Block-type mapping:
      - ``text`` and ``reasoning`` → appended to the message content. The
        trunk sees reasoning as additional context: a "chain of thought"
        rendered into the same text channel.
      - ``tool_call`` → parsed as JSON and becomes a ToolCall on this
        assistant message. Calls without a ``name`` are dropped.
      - Any other type → kept as text. Noisy context is better than
        losing a labelled message.

    A block's payload is read from ``value``, falling back to ``text``.
    Plain-string ``content`` is kept as-is. Tool-call ``arguments`` are
    always emitted as text:
      - strings pass through unchanged;
      - ``None`` becomes ``""``;
      - anything else (dict, list, number) is JSON-encoded. ``str()`` would
        give a Python repr such as ``['a']``, which is not valid JSON.

    The row's ``tools`` field is normally a JSON string; we decode it
    before handing it to ``Conversation.from_dict``. A list is accepted as
    well. When decoding fails or the field is empty, the tools list is
    left empty.
    """

    def _tool_call_from_block(value: Any) -> dict[str, Any] | None:
        # value is usually a JSON string like '{"name": "...", "arguments": {...}}'
        try:
            parsed = json.loads(value) if isinstance(value, str) else value
        except (ValueError, RecursionError):
            return None
        if not isinstance(parsed, dict) or not parsed.get("name"):
            return None
        args = parsed.get("arguments")
        if args is None:
            args = ""
        elif not isinstance(args, str):
            try:
                args = json.dumps(args)
            except (TypeError, ValueError):
                args = str(args)  # last resort for non-JSON-able payloads
        return {
            "type": "function",
            "function": {"name": str(parsed["name"]), "arguments": args},
        }

    raw_msgs: list[dict[str, Any]] = []
    for m in row.get("messages") or []:
        if not isinstance(m, dict):
            continue
        role = str(m.get("role") or "user")
        content_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        blocks = m.get("content")
        if isinstance(blocks, list):
            for blk in blocks:
                if not isinstance(blk, dict):
                    continue
                btype = str(blk.get("type") or "").lower()
                value = blk.get("value") or blk.get("text") or ""
                if btype == "tool_call":
                    call = _tool_call_from_block(value)
                    if call is not None:
                        tool_calls.append(call)
                elif isinstance(value, str) and value:
                    # text, reasoning and unknown block types share the text channel.
                    content_parts.append(value)
        elif isinstance(blocks, str):
            content_parts.append(blocks)

        msg_dict: dict[str, Any] = {
            "role": role,
            "content": " ".join(content_parts).strip(),
        }
        if tool_calls:
            msg_dict["tool_calls"] = tool_calls
        raw_msgs.append(msg_dict)

    # tools is a JSON-encoded string in this shape; decode it.
    raw_tools_field = row.get("tools")
    raw_tools: list[dict[str, Any]] = []
    if isinstance(raw_tools_field, str) and raw_tools_field.strip():
        try:
            decoded = json.loads(raw_tools_field)
        except (ValueError, RecursionError) as e:
            logger.debug("block_content row has an undecodable tools field: %s", e)
            decoded = None
        if isinstance(decoded, list):
            raw_tools = [t for t in decoded if isinstance(t, dict)]
    elif isinstance(raw_tools_field, list):
        raw_tools = [t for t in raw_tools_field if isinstance(t, dict)]

    return Conversation.from_dict({"messages": raw_msgs, "tools": raw_tools})
638
+
639
+
640
# Map ShareGPT speaker tags to standard role names. ShareGPT historically
# uses `human`/`gpt`. Modern variants also include:
#   - `system`;
#   - `tool`/`function`/`observation` for tool results;
#   - `function_call` for an assistant turn that invokes a tool (LLaMA-Factory's
#     sharegpt layout; the call is carried as the assistant's text).
# Unknown tags fall back to "user" rather than silently dropping the turn,
# because keeping noisy context is better than losing a labelled message.
_SHAREGPT_ROLE_MAP: dict[str, str] = {
    "human": "user",
    "user": "user",
    "gpt": "assistant",
    "chatgpt": "assistant",
    "bing": "assistant",
    "assistant": "assistant",
    "function_call": "assistant",
    "system": "system",
    "tool": "tool",
    "function": "tool",
    "observation": "tool",
}
657
+
658
+
659
def _parse_sharegpt_row(row: dict[str, Any]) -> Conversation:
    """Convert a ShareGPT row (``conversations`` of ``{from, value}``) to a Conversation.

    Speaker tags are lower-cased and looked up in ``_SHAREGPT_ROLE_MAP``;
    tags missing from the map become ``"user"``. Non-dict turns are
    skipped, and a missing or empty ``value`` becomes ``""``. The
    resulting Conversation has no tool definitions.

    Used by datasets like Kimi-K2.5-Reasoning, Vicuna, WizardLM and
    Alpaca-fc-*. These are typically reasoning corpora without tool calls.
    The loader still emits one ConversationSample per assistant turn
    (target_tool=``<no_tool>``) when ``emit_no_tool_turns=True``, so the
    trunk gets training signal.
    """
    turns = row.get("conversations") or []
    messages = [
        {
            "role": _SHAREGPT_ROLE_MAP.get(str(turn.get("from") or "").lower(), "user"),
            "content": turn.get("value") or "",
        }
        for turn in turns
        if isinstance(turn, dict)
    ]
    return Conversation.from_dict({"messages": messages, "tools": []})
676
+
677
+
678
def _parse_xlam_row(row: dict[str, Any]) -> Conversation:
    """Rebuild an xLAM row as a 2-turn Conversation.

    xLAM packs the tool catalog and the assistant's tool calls into JSON
    strings on separate columns (``tools``, ``answers``). We rebuild:

        user: <query>  ─►  assistant: <tool_calls=...>

    so the standard pipeline sees one assistant tool-call turn. Tool-result
    content is deliberately absent, because xLAM doesn't simulate tool
    results. The model only learns to PREDICT the call, which is what
    Caudate's advisor head does anyway.

    Normalisation:
      - Call arguments are emitted as JSON text: dict/list values are
        ``json.dumps``-ed and strings pass through. This matches
        ``_parse_block_content_row``, so a ToolCall's arguments are text
        however ``Conversation.from_dict`` treats dicts. String consumers
        such as ``conversation_to_strings`` call ``.replace`` on them.
      - Tool defs that are already OpenAI-shaped
        (``{"type": "function", "function": {...}}``) pass through
        unchanged instead of being wrapped a second time.
      - Non-dict tool defs and non-dict answers are dropped. Previously a
        non-dict tool def became an empty, nameless function entry.

    Undecodable or non-list ``tools``/``answers`` columns are treated as empty.
    """

    def _maybe_json(v: Any) -> Any:
        # Decode JSON-string columns; undecodable strings and falsy values become [].
        if isinstance(v, str):
            try:
                return json.loads(v)
            except (ValueError, RecursionError):
                return []
        return v or []

    def _as_json_text(v: Any) -> str:
        if isinstance(v, str):
            return v
        try:
            return json.dumps(v)
        except (TypeError, ValueError):
            return str(v)

    tools_raw = _maybe_json(row.get("tools"))
    answers = _maybe_json(row.get("answers"))
    if not isinstance(answers, list):
        answers = []
    if not isinstance(tools_raw, list):
        tools_raw = []

    tool_calls_raw: list[dict[str, Any]] = []
    for a in answers:
        if not isinstance(a, dict):
            continue
        # xLAM uses {"name": ..., "arguments": {...}}; some dumps use tool/args.
        args = a.get("arguments") or a.get("args") or {}
        tool_calls_raw.append({
            "type": "function",
            "function": {
                "name": str(a.get("name") or a.get("tool") or ""),
                "arguments": _as_json_text(args),
            },
        })

    msgs: list[dict[str, Any]] = [
        {"role": "user", "content": str(row.get("query") or "")},
        {"role": "assistant", "content": "", "tool_calls": tool_calls_raw},
    ]

    # xLAM's tool defs usually look like {"name", "description", "parameters"};
    # wrap them in OpenAI shape and let Conversation.from_dict normalise.
    tools_norm: list[dict[str, Any]] = []
    for t in tools_raw:
        if not isinstance(t, dict):
            continue
        if t.get("type") == "function" and isinstance(t.get("function"), dict):
            tools_norm.append(t)  # already wrapped; avoid double wrapping
        else:
            tools_norm.append({"type": "function", "function": t})
    return Conversation.from_dict({"messages": msgs, "tools": tools_norm})
728
+
729
+
730
def load_corpus_from_hf(
    dataset_name: str,
    split: str = "train",
    model_source: str | None = None,
    max_rows: int | None = None,
    row_to_conversation: "Any | None" = None,
    streaming: bool = True,
    emit_no_tool_turns: bool = True,
    **load_dataset_kwargs: Any,
) -> list[ConversationSample]:
    """Stream a HuggingFace dataset into ConversationSamples.

    Auto-detects the row layout (via ``_detect_hf_shape`` on the first
    row; the detected shape is reused for every later row). Supported
    out-of-the-box:

      - **OpenAI chat** — ``messages``/``tools`` columns, dicts with
        ``role``. Examples: anything saved via the OpenAI fine-tuning
        format or Hermes-Function-Calling.
      - **Conversations alias** — same as OpenAI but the column is
        called ``conversations`` (ToolBench-style).
      - **xLAM** — ``query``/``tools``/``answers`` columns where
        ``tools`` and ``answers`` are JSON-encoded strings. Examples:
        Salesforce/xlam-function-calling-60k.
      - The ``block_content``, ``sharegpt`` and ``glaive`` shapes, parsed
        by ``_parse_block_content_row``, ``_parse_sharegpt_row`` and
        ``_parse_glaive_row`` respectively (the same set the local Arrow
        loader handles).

    For any other shape, pass ``row_to_conversation=fn`` where ``fn``
    takes a row dict and returns a ``Conversation``. The auto-detector
    is bypassed in that case. If detection yields an unknown shape, a
    warning is logged and loading stops.

    Args:
        dataset_name: HF Hub identifier, e.g.
            ``"Salesforce/xlam-function-calling-60k"``.
        split: split name (``"train"``, ``"validation"``, etc.)
        model_source: tag set on every sample's ``model_source``.
            Defaults to ``"hf:<dataset_name>"``.
        max_rows: cap for smoke runs, counted over dict rows consumed
            (including rows later skipped because they failed to parse).
            None = read everything.
        row_to_conversation: optional custom parser. Skip auto-detect.
        streaming: if True (default), uses ``datasets.load_dataset(...,
            streaming=True)`` to avoid materialising the whole dataset.
        emit_no_tool_turns: if True, every assistant turn with content but
            no tool calls also yields a ``target_tool="<no_tool>"`` sample.
        **load_dataset_kwargs: forwarded to ``load_dataset``.

    Returns:
        list[ConversationSample] — one per named assistant tool call, plus
        one ``<no_tool>`` sample per tool-free assistant reply when
        ``emit_no_tool_turns`` is True.

    Raises:
        RuntimeError: the ``datasets`` library is not installed.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except ImportError as exc:
        raise RuntimeError(
            "datasets library not installed. pip install datasets, or "
            "export the corpus to JSONL and use load_corpus_from_oai_jsonl."
        ) from exc

    model_source = model_source or f"hf:{dataset_name}"
    ds = load_dataset(
        dataset_name, split=split, streaming=streaming, **load_dataset_kwargs,
    )

    out: list[ConversationSample] = []
    detected_shape: str | None = None
    n_rows = 0
    n_calls = 0
    for row in ds:
        if max_rows is not None and n_rows >= max_rows:
            break
        if not isinstance(row, dict):
            continue
        n_rows += 1

        # Build the Conversation. If the caller supplied a custom
        # parser, use it. Otherwise detect once and reuse.
        try:
            if row_to_conversation is not None:
                conv = row_to_conversation(row)
            else:
                if detected_shape is None:
                    detected_shape = _detect_hf_shape(row)
                    logger.info(
                        f"hf loader: detected shape '{detected_shape}' "
                        f"for {dataset_name}"
                    )
                if detected_shape == "openai":
                    conv = Conversation.from_dict(row)
                elif detected_shape == "block_content":
                    conv = _parse_block_content_row(row)
                elif detected_shape == "conversations":
                    conv = Conversation.from_dict({
                        "messages": row["conversations"],
                        "tools": row.get("tools") or [],
                    })
                elif detected_shape == "sharegpt":
                    conv = _parse_sharegpt_row(row)
                elif detected_shape == "xlam":
                    conv = _parse_xlam_row(row)
                elif detected_shape == "glaive":
                    conv = _parse_glaive_row(row)
                else:
                    logger.warning(
                        f"hf loader: unknown shape for {dataset_name}; "
                        f"pass row_to_conversation= to handle it. "
                        f"keys={list(row.keys())[:8]}"
                    )
                    break
        except Exception as e:
            # Best-effort: one malformed row must not abort the load.
            logger.debug(f"hf loader: skipping row {n_rows}: {e}")
            continue
        if not isinstance(conv, Conversation) or not conv.messages:
            continue

        # One sample per named tool call; reasoning-only corpora (no
        # tool_calls) still produce training samples via the synthetic
        # <no_tool> target when emit_no_tool_turns=True.
        for msg_idx, msg in enumerate(conv.messages):
            if msg.role != "assistant":
                continue
            if msg.tool_calls:
                for call_idx, tc in enumerate(msg.tool_calls):
                    if not tc.name:
                        continue
                    out.append(ConversationSample.from_conversation(
                        conv=conv,
                        prefix_len=msg_idx,
                        tool_call=tc,
                        tool_call_index=call_idx,
                        model_source=model_source,
                    ))
                    n_calls += 1
            elif emit_no_tool_turns and msg.content:
                # Plain assistant response → "<no_tool>" target. The
                # trunk still learns from the conversation prefix; the
                # contrastive head learns when NOT to call anything.
                out.append(ConversationSample(
                    conversation=list(conv.messages[:msg_idx]),
                    tools=list(conv.tools),
                    target_tool="<no_tool>",
                    target_arguments="",
                    target_tool_call_index=0,
                    model_source=model_source,
                ))
                n_calls += 1

    logger.info(
        f"hf loader: {dataset_name}#{split} → {n_rows} rows / "
        f"{n_calls} tool-call samples"
    )
    return out
873
+
874
+
875
+ # ---------------------------------------------------------------------
876
+ # Local HuggingFace cache loader — read .arrow shards directly
877
+ # ---------------------------------------------------------------------
878
+
879
+
880
def load_corpus_from_local_arrow(
    cache_dir: Path | str,
    model_source: str | None = None,
    max_rows: int | None = None,
    row_to_conversation: "Any | None" = None,
    emit_no_tool_turns: bool = True,
) -> list[ConversationSample]:
    """Load ConversationSamples from a local HuggingFace cache directory.

    Unlike ``load_corpus_from_hf``, this skips the ``datasets`` library
    entirely and reads the on-disk Arrow IPC shards directly. Useful
    when the dataset is already cached and you don't need the lib's
    streaming/version-resolution layer.

    Shards are the ``*.arrow`` files directly inside ``cache_dir`` (no
    recursion), read in sorted order, each opened as an Arrow IPC
    *stream*. Shards that fail to read are logged and skipped.

    The row layout is auto-detected ONCE — from the first dict row
    encountered — with the same ``_detect_hf_shape`` logic as
    ``load_corpus_from_hf``, and that shape is then reused for every
    later row and shard, so all shards are expected to share a layout.
    An unknown shape logs a warning and stops the load.

    Args:
        cache_dir: directory holding the ``.arrow`` shards. A missing
            directory or one with no shards yields ``[]``.
        model_source: fallback source tag; defaults to
            ``"hf-local:<dir name>"``. A row's ``meta.teacher_model``
            takes precedence when present.
        max_rows: cap on dict rows consumed across all shards (rows that
            fail to parse still count). None = read everything.
        row_to_conversation: optional custom parser; skips auto-detection.
        emit_no_tool_turns: if True, tool-free assistant replies with
            content also yield ``target_tool="<no_tool>"`` samples.

    Returns:
        list[ConversationSample] — one per named assistant tool call plus
        the optional ``<no_tool>`` samples.
    """
    import pyarrow as pa
    import pyarrow.ipc as ipc

    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        logger.warning(f"local arrow cache not found: {cache_dir}")
        return []

    shards = sorted(cache_dir.glob("*.arrow"))
    if not shards:
        logger.warning(f"no .arrow shards in {cache_dir}")
        return []

    model_source = model_source or f"hf-local:{cache_dir.name}"
    out: list[ConversationSample] = []
    detected_shape: str | None = None
    n_rows = 0
    n_calls = 0
    stop = False

    for shard in shards:
        if stop:
            break
        try:
            with pa.memory_map(str(shard), "r") as src:
                table = ipc.open_stream(src).read_all()
        except Exception as e:
            logger.warning(f"failed to read {shard}: {e}")
            continue

        # Convert column-major Arrow to row dicts one record batch at a
        # time, instead of table.to_pylist() building Python objects for
        # the entire shard up front.
        rows = (r for batch in table.to_batches() for r in batch.to_pylist())
        for row in rows:
            if max_rows is not None and n_rows >= max_rows:
                stop = True
                break
            if not isinstance(row, dict):
                continue
            n_rows += 1

            try:
                if row_to_conversation is not None:
                    conv = row_to_conversation(row)
                else:
                    if detected_shape is None:
                        detected_shape = _detect_hf_shape(row)
                        logger.info(
                            f"local arrow loader: detected shape "
                            f"'{detected_shape}' for {cache_dir.name}"
                        )
                    if detected_shape == "openai":
                        conv = Conversation.from_dict(row)
                    elif detected_shape == "block_content":
                        conv = _parse_block_content_row(row)
                    elif detected_shape == "conversations":
                        conv = Conversation.from_dict({
                            "messages": row["conversations"],
                            "tools": row.get("tools") or [],
                        })
                    elif detected_shape == "sharegpt":
                        conv = _parse_sharegpt_row(row)
                    elif detected_shape == "xlam":
                        conv = _parse_xlam_row(row)
                    elif detected_shape == "glaive":
                        conv = _parse_glaive_row(row)
                    else:
                        logger.warning(
                            f"local arrow loader: unknown shape; "
                            f"keys={list(row.keys())[:8]}"
                        )
                        stop = True
                        break
            except Exception as e:
                # Best-effort: skip the row, keep loading.
                logger.debug(f"local arrow: skipping row {n_rows}: {e}")
                continue

            if not isinstance(conv, Conversation) or not conv.messages:
                continue

            # Prefer the per-row teacher model id when the row carries
            # meta.teacher_model (e.g. Kimi-K2.5-Reasoning); otherwise
            # fall back to the loader-level tag.
            row_source = model_source
            meta = row.get("meta")
            if isinstance(meta, dict) and meta.get("teacher_model"):
                row_source = str(meta["teacher_model"])

            for msg_idx, msg in enumerate(conv.messages):
                if msg.role != "assistant":
                    continue
                if msg.tool_calls:
                    for call_idx, tc in enumerate(msg.tool_calls):
                        if not tc.name:
                            continue
                        out.append(ConversationSample.from_conversation(
                            conv=conv,
                            prefix_len=msg_idx,
                            tool_call=tc,
                            tool_call_index=call_idx,
                            model_source=row_source,
                        ))
                        n_calls += 1
                elif emit_no_tool_turns and msg.content:
                    out.append(ConversationSample(
                        conversation=list(conv.messages[:msg_idx]),
                        tools=list(conv.tools),
                        target_tool="<no_tool>",
                        target_arguments="",
                        target_tool_call_index=0,
                        model_source=row_source,
                    ))
                    n_calls += 1

    logger.info(
        f"local arrow loader: {cache_dir.name} → {n_rows} rows / "
        f"{n_calls} samples"
    )
    return out
1019
+
1020
+
1021
+ # ---------------------------------------------------------------------
1022
+ # Glaive function-calling-v2 — custom text-embedded JSON shape
1023
+ # ---------------------------------------------------------------------
1024
+
1025
+ _GLAIVE_ROLE_MARKER_RE = re.compile(
1026
+ r"^(USER|ASSISTANT|FUNCTION RESPONSE):", re.MULTILINE
1027
+ )
1028
+ _GLAIVE_ROLE_NORM = {
1029
+ "USER": "user", "ASSISTANT": "assistant", "FUNCTION RESPONSE": "tool",
1030
+ }
1031
+
1032
+
1033
+ def _glaive_extract_json_objects(text: str) -> list[dict[str, Any]]:
1034
+ """Walk balanced-brace JSON objects out of free text.
1035
+
1036
+ Glaive's `system` field packs zero or more tool defs as separate
1037
+ JSON objects glued together with newlines and an English preamble.
1038
+ Standard json.loads can't parse that; we scan for top-level `{...}`
1039
+ spans and decode each independently. Strings (including those with
1040
+ backslash escapes) are honoured so braces inside string values
1041
+ don't break the depth counter.
1042
+ """
1043
+ out: list[dict[str, Any]] = []
1044
+ depth = 0
1045
+ start = -1
1046
+ in_str = False
1047
+ i = 0
1048
+ n = len(text)
1049
+ while i < n:
1050
+ c = text[i]
1051
+ if in_str:
1052
+ if c == "\\" and i + 1 < n:
1053
+ i += 2
1054
+ continue
1055
+ if c == '"':
1056
+ in_str = False
1057
+ else:
1058
+ if c == '"':
1059
+ in_str = True
1060
+ elif c == "{":
1061
+ if depth == 0:
1062
+ start = i
1063
+ depth += 1
1064
+ elif c == "}":
1065
+ depth -= 1
1066
+ if depth == 0 and start >= 0:
1067
+ blob = text[start:i + 1]
1068
+ try:
1069
+ parsed = json.loads(blob)
1070
+ if isinstance(parsed, dict):
1071
+ out.append(parsed)
1072
+ except Exception:
1073
+ pass
1074
+ start = -1
1075
+ i += 1
1076
+ return out
1077
+
1078
+
1079
+ def _glaive_parse_functioncall(blob: str) -> dict[str, Any] | None:
1080
+ """Extract {name, arguments} from a `<functioncall>` payload.
1081
+
1082
+ The payload looks like `{"name": "x", "arguments": '{...}'}` where
1083
+ `arguments` is a single-quoted JSON string — not valid JSON. Hand-roll
1084
+ the extraction so we don't have to rewrite the blob.
1085
+ """
1086
+ name_m = re.search(r'"name"\s*:\s*"([^"]+)"', blob)
1087
+ if not name_m:
1088
+ return None
1089
+ name = name_m.group(1)
1090
+ args = ""
1091
+ # Try single-quoted arguments first (the common case)
1092
+ args_m = re.search(
1093
+ r'"arguments"\s*:\s*\'(.*?)\'(?=\s*[,}])', blob, re.DOTALL,
1094
+ )
1095
+ if args_m:
1096
+ args = args_m.group(1)
1097
+ else:
1098
+ # Fallback: double-quoted string args
1099
+ args_m2 = re.search(r'"arguments"\s*:\s*"((?:[^"\\]|\\.)*)"', blob)
1100
+ if args_m2:
1101
+ args = args_m2.group(1).encode().decode("unicode_escape")
1102
+ else:
1103
+ # Or an inline dict
1104
+ args_m3 = re.search(
1105
+ r'"arguments"\s*:\s*(\{.*?\})\s*[,}]', blob, re.DOTALL,
1106
+ )
1107
+ if args_m3:
1108
+ args = args_m3.group(1)
1109
+ return {"name": name, "arguments": args}
1110
+
1111
+
1112
def _parse_glaive_row(row: dict[str, Any]) -> Conversation:
    """Convert a Glaive function-calling-v2 row into a ``Conversation``.

    Expected row keys:
        ``system`` — prose preamble followed by JSON tool definitions;
        ``chat``   — a transcript with ``USER:`` / ``ASSISTANT:`` /
        ``FUNCTION RESPONSE:`` line markers.

    Tool defs are the JSON objects found in ``system`` that have a
    ``name``. The chat is split at the role markers; ``<|endoftext|>``
    is removed from every turn. Within assistant turns, each
    ``<functioncall>`` payload (the brace-balanced object after the
    marker) becomes a tool call when it yields a name. The remaining
    prose pieces are stripped and joined with single spaces.
    """
    marker = "<functioncall>"

    def _brace_span_end(text: str, open_at: int) -> int:
        """Index of the ``}`` closing the ``{`` at ``open_at``; ``len(text)`` if unclosed."""
        depth = 0
        pos = open_at
        quoted = False
        limit = len(text)
        while pos < limit:
            ch = text[pos]
            if quoted:
                if ch == "\\" and pos + 1 < limit:
                    pos += 2
                    continue
                if ch == '"':
                    quoted = False
            elif ch == '"':
                quoted = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return pos
            pos += 1
        return pos

    def _split_functioncalls(text: str) -> tuple[str, list[dict[str, Any]]]:
        """Separate assistant prose from the ``<functioncall>`` payloads in ``text``."""
        prose: list[str] = []
        calls: list[dict[str, Any]] = []
        cursor = 0
        while True:
            hit = text.find(marker, cursor)
            if hit < 0:
                prose.append(text[cursor:])
                break
            prose.append(text[cursor:hit])
            brace = text.find("{", hit)
            if brace < 0:
                # Marker with no payload: drop the marker, keep scanning.
                cursor = hit + len(marker)
                continue
            close = _brace_span_end(text, brace)
            parsed = _glaive_parse_functioncall(text[brace:close + 1])
            if parsed and parsed.get("name"):
                calls.append({
                    "type": "function",
                    "function": {
                        "name": str(parsed["name"]),
                        "arguments": str(parsed.get("arguments") or ""),
                    },
                })
            cursor = close + 1
        joined = " ".join(piece for piece in (p.strip() for p in prose) if piece)
        return joined, calls

    tools_norm: list[dict[str, Any]] = [
        {"type": "function", "function": spec}
        for spec in _glaive_extract_json_objects(str(row.get("system") or ""))
        if spec.get("name")
    ]

    chat = str(row.get("chat") or "")
    headers = list(_GLAIVE_ROLE_MARKER_RE.finditer(chat))
    # Each turn runs until the next marker (or the end of the chat).
    bounds = [h.start() for h in headers[1:]] + [len(chat)]
    messages: list[dict[str, Any]] = []
    for header, stop in zip(headers, bounds):
        role = _GLAIVE_ROLE_NORM.get(header.group(1))
        if not role:
            continue
        body = chat[header.end():stop].replace("<|endoftext|>", "").strip()
        calls: list[dict[str, Any]] = []
        if role == "assistant":
            body, calls = _split_functioncalls(body)
        entry: dict[str, Any] = {"role": role, "content": body}
        if calls:
            entry["tool_calls"] = calls
        messages.append(entry)

    return Conversation.from_dict({"messages": messages, "tools": tools_norm})
1202
+
1203
+
1204
def load_corpus_from_glaive_json(
    json_path: Path | str,
    model_source: str | None = None,
    max_rows: int | None = None,
    emit_no_tool_turns: bool = True,
) -> list[ConversationSample]:
    """Load ConversationSamples from a Glaive function-calling JSON file.

    The file is expected to be a single JSON array, which is loaded whole
    into memory and decoded as UTF-8 regardless of the platform's default
    encoding. Each dict row is parsed via ``_parse_glaive_row``. Every
    named assistant tool_call is emitted as one sample (open-vocab head
    trains). Each tool-free assistant turn with content becomes one
    sample with ``target_tool='<no_tool>'`` when ``emit_no_tool_turns``
    is True, so the trunk still gets signal.

    Args:
        json_path: path to the JSON array file.
        model_source: source tag for every sample; defaults to
            ``"glaive-function-calling-v2"``.
        max_rows: cap on dict rows consumed (rows that fail to parse
            still count). None = read everything.
        emit_no_tool_turns: see above.

    Returns:
        The samples; ``[]`` if the file is missing, unreadable, not valid
        JSON, or not a JSON array (each case logs a warning).
    """
    json_path = Path(json_path)
    if not json_path.exists():
        logger.warning(f"glaive json not found: {json_path}")
        return []

    model_source = model_source or "glaive-function-calling-v2"
    out: list[ConversationSample] = []

    try:
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows)
        # would fail on or mis-decode non-ASCII text.
        with json_path.open(encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON + Unicode decode errors
        logger.warning(f"glaive json read failed: {e}")
        return []
    if not isinstance(data, list):
        logger.warning(f"glaive json not a list: {json_path}")
        return []

    n_rows = 0
    n_calls = 0
    for row in data:
        if max_rows is not None and n_rows >= max_rows:
            break
        if not isinstance(row, dict):
            continue
        n_rows += 1
        try:
            conv = _parse_glaive_row(row)
        except Exception as e:
            # Best-effort: skip the malformed row, keep loading.
            logger.debug(f"glaive: skipping row {n_rows}: {e}")
            continue
        if not isinstance(conv, Conversation) or not conv.messages:
            continue

        for msg_idx, msg in enumerate(conv.messages):
            if msg.role != "assistant":
                continue
            if msg.tool_calls:
                for call_idx, tc in enumerate(msg.tool_calls):
                    if not tc.name:
                        continue
                    out.append(ConversationSample.from_conversation(
                        conv=conv,
                        prefix_len=msg_idx,
                        tool_call=tc,
                        tool_call_index=call_idx,
                        model_source=model_source,
                    ))
                    n_calls += 1
            elif emit_no_tool_turns and msg.content:
                out.append(ConversationSample(
                    conversation=list(conv.messages[:msg_idx]),
                    tools=list(conv.tools),
                    target_tool="<no_tool>",
                    target_arguments="",
                    target_tool_call_index=0,
                    model_source=model_source,
                ))
                n_calls += 1

    logger.info(
        f"glaive json loader: {json_path.name} → {n_rows} rows / "
        f"{n_calls} samples"
    )
    return out
1283
+
1284
+
1285
+ # ---------------------------------------------------------------------
1286
+ # Forge feature-outcome corpus (ADR 0006 Phase A)
1287
+ # ---------------------------------------------------------------------
1288
+
1289
+
1290
def load_corpus_from_feature_outcomes(
    path: Path = Path("data/nn/feature_outcomes.jsonl"),
    dedupe_by_feature: bool = True,
) -> list[ConversationSample]:
    """Convert each Forge feature-outcome row into one ConversationSample.

    Schema (written by the orchestrator):
        feature_text, model_used, success (bool), n_turns, n_tool_calls,
        duration_s, project_id, feature_id, session_id, extras{}.

    The feature text becomes the only "message" (one-turn batch); the
    model id flows through `model_source` so the trunk's source-
    embedding can bias predictions per teacher model. All other state
    (mood, tool history, images) is zero — Forge predictions are made
    *before* a session runs, so we deliberately do not condition on
    in-turn signals that don't exist yet at prediction time.

    Other heads' targets stay None — those heads are skipped for these
    samples (HeadSpec.optional_target). The only label this corpus
    carries is `target_feature_success` (mirrored into `target_value`).

    ``dedupe_by_feature`` (default True) collapses retries — every
    (project_id, feature_id) pair contributes ONE sample with the
    final outcome. "Final" means the highest ``ts``. Ties, including
    rows with no ``ts`` at all, go to the later line, since the file is
    append-only. When ``ts`` values cannot be compared (e.g. null, or a
    string vs a number), the later line wins. Without deduping, a stuck
    feature that retried 200 times before being parked dominates the
    corpus and pushes the head to "always predict fail". The
    orchestrator's decision question is "is this feature solvable?", not
    "will this particular retry succeed?", so the final outcome is the
    right label. Set False if you actually want per-session granularity
    (e.g. for predict_failure_repeats analysis).

    Args:
        path: JSONL file to read; a ``str`` path is accepted too.
        dedupe_by_feature: see above.

    Returns:
        One sample per kept row; ``[]`` if the file is missing or
        unreadable.

    Malformed rows are dropped silently: lines that aren't JSON objects,
    and rows without a non-empty string ``feature_text`` or with a null
    or missing ``success``. Undecodable bytes are replaced, so they only
    spoil their own line. Skipping bad rows is the right behaviour
    because the JSONL is append-only and a single malformed write must
    not bork the next training run.
    """
    def _supersedes(row: dict[str, Any], prev: dict[str, Any]) -> bool:
        """True when ``row`` should replace ``prev`` as the feature's final outcome."""
        try:
            # ">=" so ties (incl. rows without ts) go to the later line.
            return row.get("ts", 0) >= prev.get("ts", 0)
        except TypeError:
            # Incomparable ts values: trust append order.
            return True

    path = Path(path)
    if not path.exists():
        return []
    try:
        lines = path.read_text(encoding="utf-8", errors="replace").splitlines()
    except OSError:
        return []

    raw_rows: list[dict[str, Any]] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            row = json.loads(line)
        except (ValueError, RecursionError):
            continue
        # Valid JSON that isn't an object (bare list/number/string) would
        # crash on `.get` below — treat it as malformed.
        if not isinstance(row, dict):
            continue
        text = row.get("feature_text")
        if not isinstance(text, str) or not text.strip():
            continue
        if row.get("success") is None:
            continue
        raw_rows.append(row)

    if dedupe_by_feature:
        # Keep the final row per (project_id, feature_id). If feature_id
        # is missing (older rows) we fall back to (project_id, feature_text)
        # so a feature with no id still gets one slot, not one slot per
        # retry.
        latest: dict[tuple, dict[str, Any]] = {}
        for row in raw_rows:
            key = (
                row.get("project_id"),
                row.get("feature_id") or row.get("feature_text"),
            )
            prev = latest.get(key)
            if prev is None or _supersedes(row, prev):
                latest[key] = row
        raw_rows = list(latest.values())

    out: list[ConversationSample] = []
    for row in raw_rows:
        text = row["feature_text"].strip()
        success = bool(row.get("success"))
        # Feature description as a one-turn user message — the model's
        # job here is "given just the feature text, predict success".
        # No tool calls, no tool list — feature_success is the only
        # head this row labels.
        out.append(ConversationSample(
            conversation=[ChatMessage(role="user", content=text)],
            tools=[],
            target_tool="<no_tool>",
            target_value=1.0 if success else 0.0,
            model_source=row.get("model_used") or "<unknown>",
            target_feature_success=1.0 if success else 0.0,
        ))
    return out
1380
+
1381
+
1382
+ # ---------------------------------------------------------------------
1383
+ # Replay buffer
1384
+ # ---------------------------------------------------------------------
1385
+
1386
+
1387
@dataclass
class ReplayBuffer:
    """Bounded FIFO store of ConversationSamples for online training.

    At most ``capacity`` samples are kept. Once full, each ``push``
    silently evicts the oldest sample (``collections.deque`` with
    ``maxlen``).
    """

    capacity: int = 8192
    _buf: collections.deque = field(default_factory=lambda: collections.deque(maxlen=8192))

    def __post_init__(self) -> None:
        # The default factory cannot see ``capacity``, so rebuild the deque
        # (keeping its current contents) whenever the bound disagrees.
        if self._buf.maxlen == self.capacity:
            return
        self._buf = collections.deque(self._buf, maxlen=self.capacity)

    def push(self, sample: ConversationSample) -> None:
        """Append ``sample``, dropping the oldest entry when at capacity."""
        self._buf.append(sample)

    def __len__(self) -> int:
        """Number of samples currently held."""
        return len(self._buf)

    def sample(self, n: int) -> list[ConversationSample]:
        """Draw ``n`` samples with replacement, favouring surprising ones.

        Prioritized-experience-replay style: each sample's weight is its
        ``surprise`` attribute (0.5 when absent), floored at 0.05 so
        well-predicted samples still show up now and then. Because the
        draw is with replacement, ``n`` may exceed ``len(self)`` and
        high-priority samples can repeat. Returns ``[]`` for ``n <= 0``
        or an empty buffer.
        """
        if n > 0 and self._buf:
            pool = list(self._buf)
            priorities = [max(0.05, float(getattr(item, "surprise", 0.5))) for item in pool]
            return random.choices(pool, weights=priorities, k=n)
        return []

    def all(self) -> list[ConversationSample]:
        """Snapshot of the buffer contents, oldest first."""
        return list(self._buf)
1420
+
1421
+
1422
+ # ---------------------------------------------------------------------
1423
+ # Batch collation
1424
+ # ---------------------------------------------------------------------
1425
+
1426
+
1427
+ _NO_TOOL_DESCRIPTION = (
1428
+ "answer the user directly without calling any tool"
1429
+ )
1430
+
1431
+
1432
+ def _format_tool_spec(t: "ToolDef") -> str:
1433
+ """Render a ToolDef as the "{name}: {description}" string that the
1434
+ contrastive tool head's text encoder consumes. Capped at 240 chars
1435
+ to keep batched embedding tractable on big tool registries."""
1436
+ name = (t.name or "").strip()
1437
+ desc = (t.description or "").strip()
1438
+ if desc:
1439
+ return f"{name}: {desc}"[:240]
1440
+ return name[:240]
1441
+
1442
+
1443
def _tail_padded(items: Any, n: int, pad: Any) -> list[Any]:
    """Return the last ``n`` entries of ``items``, left-padded with ``pad`` to length ``n``.

    ``n <= 0`` yields ``[]``. A bare ``items[-0:]`` would return the
    WHOLE sequence rather than an empty tail.
    """
    if n <= 0:
        return []
    tail = list(items[-n:])
    return [pad] * (n - len(tail)) + tail


def collate(
    samples: list[ConversationSample],
    cfg: NNConfig,
    vocab: ToolVocab,
    source_vocab: "SourceVocab | None" = None,
) -> dict[str, Any]:
    """Turn a list of ConversationSamples into model-ready tensors.

    The chat-schema conversation is flattened into role-prefixed strings
    by ``conversation_to_strings``. Tool history is derived from past
    assistant tool_calls. Each sample's tool candidates are surfaced as
    ``tool_specs`` (a list of "{name}: {description}" strings) so the
    contrastive tool head can score them. The first candidate slot is
    always a synthetic ``<no_tool>`` entry — that's how "the assistant
    answered directly without calling anything" becomes a real class the
    head can predict.

    Args:
        samples: the batch (``B = len(samples)``).
        cfg: supplies ``msg_window``, ``history_window``, ``image_window``
            (used only when ``use_vision``), ``mood_dim``,
            ``tool_vocab_size`` and ``max_tools_per_sample``.
        vocab: maps tool names to ids for the history tensor.
        source_vocab: when given, each sample's ``model_source`` is
            registered/looked up via ``add``. When ``None``, every
            ``source_id`` is 0.

    Returns:
        dict containing:
          - ``messages`` / ``image_paths``: per-sample string lists,
            left-padded with ``""``.
          - ``tool_ids``: (B, history_window) long.
          - ``mood``: (B, mood_dim) float32; values beyond ``mood_dim``
            are dropped and short moods are zero-padded.
          - ``target_tier`` / ``target_think`` / ``target_value`` /
            ``source_id``: shape (B,).
          - ``tool_specs``: candidate strings per sample.
          - ``target_tool_idx``: (B,) long, -1 when the target isn't a
            candidate.
          - ``tool_target_valid``: (B,) bool.
          - any extended-head targets that every sample in the batch
            carries.

        A window size of 0 yields empty per-sample lists and zero-width
        tensor rows.
    """
    B = len(samples)
    msg_window = cfg.msg_window
    hist_window = cfg.history_window
    img_window = cfg.image_window if cfg.use_vision else 0
    K = cfg.max_tools_per_sample

    messages: list[list[str]] = []
    image_paths: list[list[str]] = []
    tool_ids = torch.zeros((B, hist_window), dtype=torch.long)
    mood = torch.zeros((B, cfg.mood_dim), dtype=torch.float32)

    target_tier = torch.zeros(B, dtype=torch.long)
    target_think = torch.zeros(B, dtype=torch.float32)
    target_value = torch.zeros(B, dtype=torch.float32)
    source_id = torch.zeros(B, dtype=torch.long)

    # Contrastive tool head inputs.
    #   tool_specs: per-sample list of "{name}: {description}" strings,
    #     first slot is always "<no_tool>: answer directly..."
    #   target_tool_idx: index of the called tool within that list, or -1
    #     if the called tool isn't among the candidates (skip-this-sample).
    tool_specs: list[list[str]] = []
    target_tool_idx = torch.full((B,), -1, dtype=torch.long)
    tool_target_valid = torch.zeros(B, dtype=torch.bool)

    for i, s in enumerate(samples):
        # Flatten the role-tagged conversation into text the encoder
        # already speaks; keep the most recent msg_window entries.
        messages.append(_tail_padded(conversation_to_strings(s.conversation), msg_window, ""))

        if cfg.use_vision:
            image_paths.append(_tail_padded(list(s.image_paths), img_window, ""))
        else:
            image_paths.append([])

        # Tool history = the sequence of tool names called in this
        # conversation prefix, most recent last. Every id is reduced
        # modulo tool_vocab_size: the encoder's tool embedding table is
        # fixed-size and a vocab that grows past it (public corpora can
        # hit 10K+ tools) would otherwise index out of range. Collisions
        # are preferable to missing data.
        names = conversation_tool_history(s.conversation)
        recent = names[-hist_window:] if hist_window > 0 else []
        hist = _tail_padded(
            [vocab.get(t) % cfg.tool_vocab_size for t in recent],
            hist_window,
            vocab.SPECIAL["<pad>"],
        )
        tool_ids[i] = torch.tensor(hist, dtype=torch.long)

        # Copy at most mood_dim values; a shorter mood leaves the rest of
        # the zero-initialised row untouched instead of raising a shape
        # mismatch.
        mood_vals = torch.as_tensor(s.mood[: cfg.mood_dim], dtype=torch.float32).reshape(-1)
        mood[i, : mood_vals.numel()] = mood_vals
        target_tier[i] = int(s.target_tier)
        target_think[i] = float(s.target_think)
        target_value[i] = float(s.target_value)
        if source_vocab is not None:
            # `add` returns the id for the source — looked up or newly
            # assigned; see SourceVocab for its overflow behaviour.
            source_id[i] = source_vocab.add(getattr(s, "model_source", "<unknown>"))
        else:
            source_id[i] = 0  # <unknown> — keeps the model in baseline mode

        # Candidate tool list. The synthetic <no_tool> entry occupies
        # slot 0 so the head always has a "don't call anything" option —
        # without it the softmax would be forced to nominate something
        # even when the right answer is "respond directly".
        sample_tools = list(s.tools)[: max(0, K - 1)]
        specs = [f"<no_tool>: {_NO_TOOL_DESCRIPTION}"]
        specs.extend(_format_tool_spec(t) for t in sample_tools)
        tool_specs.append(specs)

        # Resolve target_tool to an index in `specs` by name.
        target_name = (s.target_tool or "<no_tool>")
        if target_name == "<no_tool>" or target_name == "":
            target_tool_idx[i] = 0
            tool_target_valid[i] = True
        else:
            found = -1
            for k, t in enumerate(sample_tools):
                if t.name == target_name:
                    found = k + 1  # +1 for the <no_tool> prefix slot
                    break
            if found >= 0:
                target_tool_idx[i] = found
                tool_target_valid[i] = True
            # Else: target_tool isn't among candidates → leave
            # target_tool_idx=-1, tool_target_valid=False; the trainer
            # skips this sample's contribution to the tool loss.

    out = {
        "messages": messages,
        "image_paths": image_paths,
        "tool_ids": tool_ids,
        "mood": mood,
        "target_tier": target_tier,
        "target_think": target_think,
        "target_value": target_value,
        "source_id": source_id,
        # Contrastive tool head inputs/targets
        "tool_specs": tool_specs,
        "target_tool_idx": target_tool_idx,
        "tool_target_valid": tool_target_valid,
    }

    # Extended-head targets — only emit if EVERY sample in this batch
    # has the label. A partially-labeled batch would have to use a fake
    # value (e.g. 0.0) for the unlabeled rows, which would train the
    # head against a synthetic "no" signal and corrupt learning. So
    # if even one sample is missing the label, we drop the field
    # entirely and the trainer skips that head for this batch (via
    # HeadSpec.optional_target).
    _SCALAR_BCE_FIELDS = (
        "target_memory_write", "target_cache_hit", "target_permission",
        "target_refusal", "target_code_response", "target_stall",
        "target_stop_iter", "target_compaction",
        "target_subagent_spawn",
        "target_feature_success",
    )
    _SCALAR_MSE_FIELDS = (
        "target_latency_s", "target_token_budget", "target_reward_model",
    )
    for field_name in _SCALAR_BCE_FIELDS + _SCALAR_MSE_FIELDS:
        vals = [getattr(s, field_name, None) for s in samples]
        if all(v is not None for v in vals):
            out[field_name] = torch.tensor(
                [float(v) for v in vals], dtype=torch.float32
            )

    # difficulty is a class-index CE target → int64
    diffs = [getattr(s, "target_difficulty", None) for s in samples]
    if all(d is not None for d in diffs):
        out["target_difficulty"] = torch.tensor(
            [int(d) for d in diffs], dtype=torch.long
        )

    # mood_pred is forced to a 4-D vector (zero-padded or truncated).
    moods = [getattr(s, "target_mood_pred", None) for s in samples]
    if all(m is not None for m in moods):
        out["target_mood_pred"] = torch.tensor(
            [list(m) + [0.0] * (4 - len(m)) if len(m) < 4 else list(m)[:4] for m in moods],
            dtype=torch.float32,
        )
    return out
1608
+
1609
+
1610
def split_train_eval(
    samples: list[ConversationSample], eval_ratio: float, seed: int,
) -> tuple[list[ConversationSample], list[ConversationSample]]:
    """Shuffle *samples* reproducibly and split them into ``(train, eval)``.

    The shuffle uses a private ``random.Random(seed)`` instance, so the split
    is deterministic for a given seed and never touches the global RNG.

    Args:
        samples: The full sample list; it is copied, never reordered in place.
        eval_ratio: Fraction of samples routed to the eval split. The count is
            ``int(len(samples) * eval_ratio)``, raised to at least 1 whenever
            *samples* is non-empty.
        seed: Seed for the private RNG driving the shuffle.

    Returns:
        ``(train, eval)``. The eval split is the first ``n_eval`` shuffled
        samples and the train split is everything after them. Both are empty
        when *samples* is empty.
    """
    pool = list(samples)
    random.Random(seed).shuffle(pool)
    if not pool:
        return [], []
    # Always hold out at least one sample so evaluation has something to score.
    cut = max(1, int(len(pool) * eval_ratio))
    return pool[cut:], pool[:cut]
1618
+
1619
+
1620
def iter_batches(
    samples: list[ConversationSample],
    batch_size: int,
    cfg: NNConfig,
    vocab: ToolVocab,
    shuffle: bool = True,
    source_vocab: "SourceVocab | None" = None,
    rng: "random.Random | None" = None,
) -> Iterator[dict[str, Any]]:
    """Yield collated batches of at most *batch_size* samples.

    Each yielded value is whatever ``collate(chunk, cfg, vocab,
    source_vocab=source_vocab)`` returns for a consecutive slice of the
    (optionally shuffled) samples. The final batch may be smaller than
    *batch_size*.

    Args:
        samples: Samples to batch. When *shuffle* is true, they are copied
            before shuffling so the caller's list is not reordered.
        batch_size: Maximum number of samples per batch; must be positive.
        cfg: Passed through unchanged to ``collate``.
        vocab: Passed through unchanged to ``collate``.
        shuffle: Whether to randomize sample order before batching.
        source_vocab: Optional source vocabulary forwarded to ``collate``.
        rng: Optional ``random.Random`` used for the shuffle, giving
            reproducible batch order. When ``None`` (the default), the
            module-level global RNG is used, matching previous behavior.

    Raises:
        ValueError: If *batch_size* is not positive. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if batch_size <= 0:
        # Previously a negative size silently produced no batches and 0
        # raised an opaque range() error; fail loudly with a clear message.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if shuffle:
        samples = list(samples)
        shuffler = rng.shuffle if rng is not None else random.shuffle
        shuffler(samples)
    for start in range(0, len(samples), batch_size):
        # A positive step with start < len(samples) guarantees a non-empty
        # slice, so no emptiness check is needed.
        chunk = samples[start:start + batch_size]
        yield collate(chunk, cfg, vocab, source_vocab=source_vocab)