@thispointon/kondi-chat 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. package/LICENSE +21 -0
  2. package/README.md +556 -0
  3. package/bin/kondi-chat +56 -0
  4. package/bin/kondi-chat.js +72 -0
  5. package/package.json +55 -0
  6. package/scripts/demo.tape +49 -0
  7. package/scripts/postinstall.cjs +103 -0
  8. package/src/audit/analytics.ts +261 -0
  9. package/src/audit/ledger.ts +253 -0
  10. package/src/audit/telemetry.ts +165 -0
  11. package/src/cli/backend.ts +675 -0
  12. package/src/cli/commands.ts +419 -0
  13. package/src/cli/help.ts +182 -0
  14. package/src/cli/submit-helpers.ts +159 -0
  15. package/src/cli/submit.ts +539 -0
  16. package/src/cli/wizard.ts +121 -0
  17. package/src/context/bootstrap.ts +138 -0
  18. package/src/context/budget.ts +100 -0
  19. package/src/context/manager.ts +666 -0
  20. package/src/context/memory.ts +160 -0
  21. package/src/context/preflight.ts +176 -0
  22. package/src/context/project-brain.ts +101 -0
  23. package/src/context/receipts.ts +108 -0
  24. package/src/context/skills.ts +154 -0
  25. package/src/context/symbol-index.ts +240 -0
  26. package/src/council/profiles.ts +137 -0
  27. package/src/council/tool.ts +138 -0
  28. package/src/council-engine/cli/council-artifacts.ts +230 -0
  29. package/src/council-engine/cli/council-config.ts +178 -0
  30. package/src/council-engine/cli/council-session-export.ts +116 -0
  31. package/src/council-engine/cli/kondi.ts +98 -0
  32. package/src/council-engine/cli/llm-caller.ts +229 -0
  33. package/src/council-engine/cli/localStorage-shim.ts +119 -0
  34. package/src/council-engine/cli/node-platform.ts +68 -0
  35. package/src/council-engine/cli/run-council.ts +481 -0
  36. package/src/council-engine/cli/run-pipeline.ts +772 -0
  37. package/src/council-engine/cli/session-export.ts +153 -0
  38. package/src/council-engine/configs/councils/analysis.json +101 -0
  39. package/src/council-engine/configs/councils/code-planning.json +86 -0
  40. package/src/council-engine/configs/councils/coding.json +89 -0
  41. package/src/council-engine/configs/councils/debate.json +97 -0
  42. package/src/council-engine/configs/councils/solo-claude.json +34 -0
  43. package/src/council-engine/configs/councils/solo-gpt.json +34 -0
  44. package/src/council-engine/council/coding-orchestrator.ts +1205 -0
  45. package/src/council-engine/council/context-bootstrap.ts +147 -0
  46. package/src/council-engine/council/context-inspection.ts +42 -0
  47. package/src/council-engine/council/context-store.ts +763 -0
  48. package/src/council-engine/council/deliberation-orchestrator.ts +2762 -0
  49. package/src/council-engine/council/factory.ts +164 -0
  50. package/src/council-engine/council/index.ts +201 -0
  51. package/src/council-engine/council/ledger-store.ts +438 -0
  52. package/src/council-engine/council/prompts.ts +1689 -0
  53. package/src/council-engine/council/storage-cleanup.ts +164 -0
  54. package/src/council-engine/council/store.ts +1110 -0
  55. package/src/council-engine/council/synthesis.ts +291 -0
  56. package/src/council-engine/council/types.ts +845 -0
  57. package/src/council-engine/council/validation.ts +613 -0
  58. package/src/council-engine/pipeline/build-detect.ts +73 -0
  59. package/src/council-engine/pipeline/executor.ts +1048 -0
  60. package/src/council-engine/pipeline/index.ts +9 -0
  61. package/src/council-engine/pipeline/install-detect.ts +84 -0
  62. package/src/council-engine/pipeline/memory-store.ts +182 -0
  63. package/src/council-engine/pipeline/output-parsers.ts +146 -0
  64. package/src/council-engine/pipeline/run-output.ts +149 -0
  65. package/src/council-engine/pipeline/session-import.ts +177 -0
  66. package/src/council-engine/pipeline/store.ts +753 -0
  67. package/src/council-engine/pipeline/test-detect.ts +82 -0
  68. package/src/council-engine/pipeline/types.ts +401 -0
  69. package/src/council-engine/services/deliberationSummary.ts +114 -0
  70. package/src/council-engine/tsconfig.json +16 -0
  71. package/src/council-engine/types/mcp.ts +122 -0
  72. package/src/council-engine/utils/filterTools.ts +73 -0
  73. package/src/engine/apply.ts +238 -0
  74. package/src/engine/checkpoints.ts +237 -0
  75. package/src/engine/consultants.ts +347 -0
  76. package/src/engine/diff.ts +171 -0
  77. package/src/engine/errors.ts +102 -0
  78. package/src/engine/git-tools.ts +246 -0
  79. package/src/engine/hooks.ts +181 -0
  80. package/src/engine/loop-guard.ts +155 -0
  81. package/src/engine/permissions.ts +293 -0
  82. package/src/engine/pipeline.ts +376 -0
  83. package/src/engine/sub-agents.ts +133 -0
  84. package/src/engine/task-card.ts +185 -0
  85. package/src/engine/task-router.ts +256 -0
  86. package/src/engine/task-store.ts +86 -0
  87. package/src/engine/tools.ts +783 -0
  88. package/src/engine/verify.ts +111 -0
  89. package/src/mcp/client.ts +225 -0
  90. package/src/mcp/config.ts +120 -0
  91. package/src/mcp/tool-manager.ts +192 -0
  92. package/src/mcp/types.ts +61 -0
  93. package/src/providers/llm-caller.ts +943 -0
  94. package/src/providers/rate-limiter.ts +238 -0
  95. package/src/router/NOTES.md +28 -0
  96. package/src/router/collector.ts +474 -0
  97. package/src/router/embeddings.ts +286 -0
  98. package/src/router/index.ts +299 -0
  99. package/src/router/intent-router.ts +225 -0
  100. package/src/router/nn-router.ts +205 -0
  101. package/src/router/profiles.ts +309 -0
  102. package/src/router/registry.ts +565 -0
  103. package/src/router/rules.ts +274 -0
  104. package/src/router/train.py +408 -0
  105. package/src/session/store.ts +211 -0
  106. package/src/test-utils/mock-llm.ts +39 -0
  107. package/src/types.ts +322 -0
  108. package/src/web/manager.ts +311 -0
@@ -0,0 +1,274 @@
1
+ /**
2
+ * Rule-Based Router — the "teacher" that makes routing decisions.
3
+ *
4
+ * Maps (phase, task_kind) to the best model from the registry.
5
+ * This is the initial routing strategy. Every decision it makes
6
+ * gets logged by the collector, which eventually trains an NN
7
+ * to replace it.
8
+ *
9
+ * Strategy:
10
+ * - discuss/dispatch → best reasoning model; reflect → the profile's reviewPreference when set, otherwise the same reasoning path
11
+ * - execute → cheapest coding model (promote on failure)
12
+ * - compress/state_update → cheapest summarization model
13
+ * - verify → no LLM call (local tools)
14
+ */
15
+
16
+ import type { LedgerPhase, ProviderId, TaskKind } from '../types.ts';
17
+ import { ModelRegistry, type ModelCapability, type ModelEntry } from './registry.ts';
18
+ import type { BudgetProfile } from './profiles.ts';
19
+
20
/**
 * The read-only slice of the ModelRegistry API that the routing strategy
 * helpers depend on. scopedRegistry() produces an object of this shape by
 * delegating to a real registry rather than by subclassing it.
 */
interface RegistryView {
  /** Models that advertise `capability`, in the order the registry returns them. */
  getByCapability(capability: ModelCapability): ModelEntry[];
  /** The view's "best" model for `capability`, or undefined when there is none. */
  getBest(capability: ModelCapability): ModelEntry | undefined;
  /** The view's "cheapest" model for `capability`, or undefined when there is none. */
  getCheapest(capability: ModelCapability): ModelEntry | undefined;
  /** Every model the view considers enabled. */
  getEnabled(): ModelEntry[];
}
30
+
31
/**
 * Wrap `registry` in a view that only exposes models whose provider appears
 * in `providers`. Every call is delegated to the live registry and filtered,
 * so later registry changes are visible through the view.
 *
 * NOTE(review): cheapest/best are taken as the first/last element of the
 * filtered getByCapability() list, which presumes that list is ordered
 * cheapest-first — confirm against ModelRegistry.
 */
function scopedRegistry(registry: ModelRegistry, providers: ProviderId[]): RegistryView {
  const permitted = new Set(providers);
  const isPermitted = (entry: ModelEntry): boolean => permitted.has(entry.provider);
  const byCapability = (capability: ModelCapability): ModelEntry[] =>
    registry.getByCapability(capability).filter(isPermitted);

  return {
    getEnabled() {
      return registry.getEnabled().filter(isPermitted);
    },
    getByCapability: byCapability,
    getCheapest(capability) {
      return byCapability(capability)[0];
    },
    getBest(capability) {
      const candidates = byCapability(capability);
      return candidates[candidates.length - 1];
    },
  };
}
44
+
45
+ // ---------------------------------------------------------------------------
46
+ // Route decision
47
+ // ---------------------------------------------------------------------------
48
+
49
/** Outcome of a routing call: which model to use and why it was chosen. */
export interface RouteDecision {
  /** The registry entry selected for the call. */
  model: ModelEntry;
  /** True when the model was chosen as a promotion (retry after failures). */
  promoted: boolean;
  /** Short human-readable explanation of the rule that fired. */
  reason: string;
}
55
+
56
+ // ---------------------------------------------------------------------------
57
+ // Rule-based router
58
+ // ---------------------------------------------------------------------------
59
+
60
/**
 * Rule-based router: chooses a model for each pipeline phase from fixed
 * rules, optionally steered by a budget profile or a manual override.
 */
export class RuleRouter {
  private registry: ModelRegistry;
  private profile?: BudgetProfile;
  private override?: ModelEntry;

  constructor(registry: ModelRegistry) {
    this.registry = registry;
  }

  /** Set the active budget profile — changes model selection priorities. */
  setProfile(profile: BudgetProfile): void {
    this.profile = profile;
  }

  /**
   * Registry view used for every lookup.
   *
   * When the active profile has `rolePinning`, lookups are limited to the
   * *providers* of the pinned models (not to the pinned models themselves).
   * Pinned ids the registry does not know are ignored. With no profile, no
   * rolePinning, or no resolvable pinned model, the full registry is used.
   */
  private reg(): RegistryView {
    const pinning = this.profile?.rolePinning;
    if (!pinning) return this.registry;

    const providers = new Set<ProviderId>();
    for (const modelId of Object.values(pinning)) {
      const pinned = this.registry.getById(modelId);
      if (pinned) providers.add(pinned.provider);
    }
    return providers.size > 0
      ? scopedRegistry(this.registry, [...providers])
      : this.registry;
  }

  /** Force all routing to a specific model. Pass undefined to clear. */
  setOverride(model: ModelEntry | undefined): void {
    this.override = model;
  }

  /** Get the current override, if any. */
  getOverride(): ModelEntry | undefined {
    return this.override;
  }

  /**
   * Select the model for a given phase and optional task context.
   *
   * Precedence: manual override, then promotion (execute phase only), then
   * the per-phase strategy. Phases without an explicit case use the
   * reasoning strategy.
   *
   * @param phase Pipeline phase (discuss, dispatch, execute, etc.)
   * @param taskKind Type of task being executed (if in a task context)
   * @param failures Number of prior failures for this task (triggers promotion)
   * @param promotionThreshold Failures before promoting to best model
   * @throws Error when a strategy has to fall back and no model is enabled
   */
  select(
    phase: LedgerPhase,
    taskKind?: TaskKind,
    failures = 0,
    promotionThreshold = 2,
  ): RouteDecision {
    // Manual override — user forced a specific model with /use
    if (this.override) {
      return { model: this.override, reason: `override: ${this.override.alias || this.override.id}`, promoted: false };
    }

    // Promotion: the cheap model failed often enough, so use the best coder.
    // If no coding model is available, fall through to normal routing.
    if (phase === 'execute' && failures >= promotionThreshold) {
      const best = this.reg().getBest('coding');
      if (best) {
        return { model: best, reason: `promoted after ${failures} failures`, promoted: true };
      }
    }

    switch (phase) {
      case 'discuss':
      case 'dispatch':
        return this.selectForReasoning();

      case 'reflect':
        return this.selectForReview();

      case 'execute':
        return this.selectForExecution(taskKind);

      case 'compress':
      case 'state_update':
        return this.selectForCheap();

      default:
        return this.selectForReasoning();
    }
  }

  // -------------------------------------------------------------------------
  // Strategy helpers
  // -------------------------------------------------------------------------

  /**
   * Reasoning phases: walk the profile's planningPreference (cheapest match
   * when the profile prefers local models, best match otherwise), then fall
   * back to best planning → reasoning → coding → first enabled model.
   */
  private selectForReasoning(): RouteDecision {
    // Build the (possibly scoped) view once instead of once per lookup.
    const reg = this.reg();
    const profile = this.profile;

    if (profile) {
      const pick = profile.preferLocal
        ? (cap: string) => reg.getCheapest(cap)
        : (cap: string) => reg.getBest(cap);
      for (const cap of profile.planningPreference) {
        const model = pick(cap);
        if (model) return { model, reason: `${profile.name}: ${cap}`, promoted: false };
      }
    }

    const model = reg.getBest('planning')
      ?? reg.getBest('reasoning')
      ?? reg.getBest('coding')
      ?? this.fallback();
    return { model, reason: 'reasoning phase — best planner', promoted: false };
  }

  /**
   * Execute phase: prefer a model whose capability equals the task kind,
   * then the profile's executionPreference, then a fixed task-kind →
   * capability mapping, and finally the first enabled model.
   */
  private selectForExecution(taskKind?: TaskKind): RouteDecision {
    const reg = this.reg();
    const profile = this.profile;

    if (profile) {
      // Try a direct task kind match first.
      // NOTE(review): the non-local branch takes the first getByCapability()
      // entry rather than getBest(); kept as in the original — confirm intent.
      if (taskKind) {
        const directMatch = profile.preferLocal
          ? reg.getCheapest(taskKind)
          : reg.getByCapability(taskKind)[0];
        if (directMatch) {
          return { model: directMatch, reason: `${profile.name}: ${taskKind} match`, promoted: false };
        }
      }

      // Then the profile's execution preferences, cheapest match first.
      for (const cap of profile.executionPreference) {
        const model = reg.getCheapest(cap);
        if (model) return { model, reason: `${profile.name}: ${cap}`, promoted: false };
      }
    }

    // Default: try to match the task kind directly to a capability.
    if (taskKind) {
      const directMatch = reg.getCheapest(taskKind);
      if (directMatch) {
        return { model: directMatch, reason: `${taskKind} task — direct capability match`, promoted: false };
      }
    }

    // Known task kind → capability mapping. Each case body is wrapped in a
    // block so its `const` is scoped to that case, not to the whole switch.
    switch (taskKind) {
      case 'analysis':
      case 'code-review': {
        const reviewer = reg.getBest('code-review')
          ?? reg.getBest('analysis')
          ?? reg.getBest('reasoning')
          ?? this.fallback();
        return { model: reviewer, reason: `${taskKind} task — best reviewer`, promoted: false };
      }

      case 'marketing':
      case 'writing': {
        const writer = reg.getCheapest('marketing')
          ?? reg.getCheapest('writing')
          ?? reg.getCheapest('general')
          ?? this.fallback();
        return { model: writer, reason: `${taskKind} task — best writer`, promoted: false };
      }

      case 'test':
      case 'fix': {
        const fixer = reg.getCheapest('fast-coding')
          ?? reg.getCheapest('coding')
          ?? this.fallback();
        return { model: fixer, reason: `${taskKind} task — cheapest coder`, promoted: false };
      }

      case 'implementation':
      case 'refactor':
      case 'refactoring': {
        const coder = reg.getCheapest('coding')
          ?? this.fallback();
        return { model: coder, reason: `${taskKind} task — cheapest coder`, promoted: false };
      }

      default: {
        // Unknown kind — use the cheapest coding model as the execution default.
        const defaultModel = reg.getCheapest('coding')
          ?? reg.getCheapest('general')
          ?? this.fallback();
        return { model: defaultModel, reason: `${taskKind || 'unknown'} task — default`, promoted: false };
      }
    }
  }

  /**
   * Reflect phase: honor the profile's reviewPreference capability list when
   * it is non-empty; otherwise (including profiles with
   * `reviewPreference: []`) use the reasoning strategy so reflect still gets
   * a strong model.
   */
  private selectForReview(): RouteDecision {
    const profile = this.profile;
    if (profile && profile.reviewPreference.length > 0) {
      const reg = this.reg();
      const pick = profile.preferLocal
        ? (cap: string) => reg.getCheapest(cap)
        : (cap: string) => reg.getBest(cap);
      for (const cap of profile.reviewPreference) {
        const model = pick(cap);
        if (model) return { model, reason: `${profile.name}: review ${cap}`, promoted: false };
      }
    }
    return this.selectForReasoning();
  }

  /** Compress/state-update phases: cheapest summarization, then general model. */
  private selectForCheap(): RouteDecision {
    const reg = this.reg();
    const model = reg.getCheapest('summarization')
      ?? reg.getCheapest('general')
      ?? this.fallback();
    return { model, reason: 'cheap phase — summarization', promoted: false };
  }

  /**
   * Last resort: the first enabled model in the current view.
   * @throws Error when the view has no enabled models
   */
  private fallback(): ModelEntry {
    const enabled = this.reg().getEnabled();
    if (enabled.length === 0) {
      throw new Error('No models enabled in registry. Run /models to configure.');
    }
    return enabled[0];
  }
}
@@ -0,0 +1,408 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train a lightweight neural network router from kondi-chat routing data.
4
+
5
+ Consumes the JSONL training data collected by the orchestrator's rule-based
6
+ router and trains an NN that predicts which model will succeed for a given
7
+ task. The orchestrator is the teacher; this NN is the student.
8
+
9
+ Usage:
10
+ python src/router/train.py [--data-dir .kondi-chat] [--out router_model.json]
11
+
12
+ The trained model is exported as JSON (weights + config) so it can be
13
+ loaded in TypeScript without a Python runtime.
14
+ """
15
+
16
+ import json
17
+ import sys
18
+ import argparse
19
+ from pathlib import Path
20
+ import numpy as np
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Data loading
24
+ # ---------------------------------------------------------------------------
25
+
26
def load_samples(data_dir: str) -> list[dict]:
    """Load routing samples from the collector's JSONL file.

    Reads ``<data_dir>/routing-data.jsonl`` as UTF-8, one JSON value per
    non-blank line. Lines that are not valid JSON (for example a final line
    truncated by an interrupted write) are skipped and reported in a warning
    instead of aborting the whole run.

    Args:
        data_dir: Directory that contains ``routing-data.jsonl``.

    Returns:
        The decoded samples, in file order.

    Exits the process with status 1 when the file does not exist.
    """
    path = Path(data_dir) / "routing-data.jsonl"
    if not path.exists():
        print(f"No routing data found at {path}")
        sys.exit(1)

    samples = []
    skipped = 0
    # Explicit UTF-8: the platform default encoding is not guaranteed to be
    # UTF-8. Iterating the handle splits on real newlines only, whereas
    # str.splitlines() would also split on characters such as U+2028 that may
    # appear unescaped inside a JSON string.
    with path.open(encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError:
                skipped += 1

    if skipped:
        print(f"Skipped {skipped} malformed line(s) in {path}")
    print(f"Loaded {len(samples)} routing samples")
    return samples
40
+
41
+
42
def encode_features(samples: list[dict]) -> tuple[np.ndarray, dict]:
    """Encode samples into a float32 feature matrix.

    Categories (phases, task kinds) are discovered from the data, not
    hardcoded. When more than half of the samples carry an embedding, each
    row is laid out as ``[embedding | phase_onehot | kind_onehot | scalars]``;
    otherwise only the structured part is used. The embedding width is taken
    from the first sample that has one.

    Robustness:
      * A missing embedding, or one whose length differs from the detected
        width, is replaced by a zero vector so rows never become ragged.
      * A missing or ``null`` numeric field counts as 0, mirroring how a
        missing/``null`` ``taskKind`` counts as ``"none"``.
      * The matrix is always 2-D, ``(len(samples), inputDim)``, even for an
        empty sample list.

    Args:
        samples: Routing samples; each must have a ``"phase"`` key.

    Returns:
        ``(feature_matrix, feature_info)`` where ``feature_info`` holds the
        encoding schema needed to build the same vector at inference time.
    """
    # Discover categories from data
    phases = sorted({s["phase"] for s in samples})
    task_kinds = sorted({s.get("taskKind") or "none" for s in samples})

    # Embeddings are only used when a strict majority of samples has one
    samples_with_embeddings = [s for s in samples if s.get("embedding")]
    has_embeddings = len(samples_with_embeddings) > len(samples) * 0.5  # Need >50%
    embedding_dim = 0

    if has_embeddings:
        embedding_dim = len(samples_with_embeddings[0]["embedding"])
        print(f"Using embeddings: {embedding_dim}D ({len(samples_with_embeddings)}/{len(samples)} samples)")
    elif samples_with_embeddings:
        print(f"Too few embeddings ({len(samples_with_embeddings)}/{len(samples)}), using structured features only")
    else:
        print("No embeddings found, using structured features only")

    structured_names = (
        [f"phase:{p}" for p in phases]
        + [f"kind:{k}" for k in task_kinds]
        + ["prompt_length", "context_tokens", "failures"]
    )
    feature_names = [f"emb_{i}" for i in range(embedding_dim)] + structured_names

    zero_embedding = [0.0] * embedding_dim
    rows = []
    for s in samples:
        kind = s.get("taskKind") or "none"
        structured = (
            [1 if p == s["phase"] else 0 for p in phases]
            + [1 if k == kind else 0 for k in task_kinds]
            + [
                # Scalars are scaled by a fixed constant and capped at 1.0
                min((s.get("promptLength") or 0) / 10_000, 1.0),
                min((s.get("contextTokens") or 0) / 100_000, 1.0),
                min((s.get("failures") or 0) / 5.0, 1.0),
            ]
        )

        if has_embeddings:
            emb = s.get("embedding")
            if not emb or len(emb) != embedding_dim:
                emb = zero_embedding
            rows.append(list(emb) + structured)
        else:
            rows.append(structured)

    feature_info = {
        "phases": phases,
        "taskKinds": task_kinds,
        "featureNames": feature_names,
        "inputDim": len(feature_names),
        "embeddingDim": embedding_dim,
        "hasEmbeddings": has_embeddings,
    }

    # reshape keeps the result 2-D when there are no rows at all
    matrix = np.array(rows, dtype=np.float32).reshape(len(samples), len(feature_names))
    return matrix, feature_info
108
+
109
+
110
def encode_labels(samples: list[dict], model_names: list[str]) -> np.ndarray:
    """Build the per-model label matrix.

    Each row has one column per entry of ``model_names``. The column whose
    name equals the sample's ``modelId`` holds 1.0 when the sample's
    ``succeeded`` value is truthy and 0.0 otherwise; every other column holds
    -1.0, meaning "unknown — that model was not tried on this sample", which
    is excluded from the loss downstream.
    """
    def label_row(sample: dict) -> list[float]:
        outcome = 1.0 if sample.get("succeeded", False) else 0.0
        return [outcome if sample["modelId"] == candidate else -1.0 for candidate in model_names]

    return np.array([label_row(sample) for sample in samples], dtype=np.float32)
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Neural Network (numpy only — no PyTorch dependency)
129
+ # ---------------------------------------------------------------------------
130
+
131
def relu(x: np.ndarray) -> np.ndarray:
    """Rectified linear unit: element-wise maximum of 0 and ``x``."""
    floor = 0
    return np.maximum(floor, x)
133
+
134
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function ``1 / (1 + exp(-x))``, evaluated without overflow.

    Only ``exp(-|x|)`` is ever computed, and that value lies in ``(0, 1]``,
    so the exponential cannot overflow in any float dtype. The previous
    ``clip(x, -500, 500)`` guard was sized for float64; float32 ``exp``
    already overflows near 88.7, and the network's activations are float32.

    Returns an array with the floating dtype of ``x`` (float32 stays float32).
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x); x < 0: the equivalent form e^x / (1 + e^x)
    return np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
136
+
137
def relu_derivative(x: np.ndarray) -> np.ndarray:
    """Derivative of ReLU as float32: 1.0 where ``x > 0``, else 0.0."""
    is_active = np.greater(x, 0)
    return is_active.astype(np.float32)
139
+
140
+
141
class SimpleNN:
    """
    Multi-layer perceptron trained with full-batch gradient descent.

    Hidden layers use ReLU and the output layer uses sigmoid. Implemented in
    pure numpy so training needs no deep-learning framework.
    """

    def __init__(self, layer_dims: list[int]):
        """Create randomly initialized layers.

        Args:
            layer_dims: ``[input_dim, hidden1, hidden2, ..., output_dim]``.
        """
        # weights[i] has shape (layer_dims[i], layer_dims[i + 1])
        self.weights: list[np.ndarray] = []
        # biases[i] has shape (layer_dims[i + 1],) and starts at zero
        self.biases: list[np.ndarray] = []
        # Caches written by forward() and consumed by backward()
        self._activations: list[np.ndarray] = []
        self._pre_activations: list[np.ndarray] = []

        for fan_in, fan_out in zip(layer_dims[:-1], layer_dims[1:]):
            # He initialization: std = sqrt(2 / fan_in), the usual choice for ReLU
            scale = np.sqrt(2.0 / fan_in)
            self.weights.append(np.random.randn(fan_in, fan_out).astype(np.float32) * scale)
            self.biases.append(np.zeros(fan_out, dtype=np.float32))

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward pass; caches per-layer values for a following backward().

        Hidden layers use ReLU, the output layer uses sigmoid.
        """
        self._activations = [x]
        self._pre_activations = []
        last = len(self.weights) - 1
        for i, (w, b) in enumerate(zip(self.weights, self.biases)):
            z = x @ w + b
            self._pre_activations.append(z)
            x = relu(z) if i < last else sigmoid(z)
            self._activations.append(x)
        return x

    def backward(self, y_true: np.ndarray, mask: np.ndarray, lr: float = 0.001):
        """One gradient-descent step on masked binary cross-entropy.

        Uses the activations cached by the most recent forward() call, so
        forward() must have been called on the training batch first.

        Args:
            y_true: Targets, same shape as the network output.
            mask: Same shape as ``y_true``; 1 where a label exists, 0 where
                it is unknown. Masked-out cells contribute no gradient.
            lr: Learning rate.

        Raises:
            RuntimeError: If forward() has not been called yet.
        """
        if not self._activations:
            raise RuntimeError("forward() must be called before backward()")

        # Average over labelled cells; max() avoids division by zero
        n = max(mask.sum(), 1)
        y_pred = self._activations[-1]

        # Sigmoid + BCE: the gradient w.r.t. the output pre-activation is (p - y)
        delta = (y_pred - y_true) * mask / n

        for i in range(len(self.weights) - 1, -1, -1):
            a_prev = self._activations[i]
            dw = a_prev.T @ delta
            db = delta.sum(axis=0)

            # Propagate the error through this layer BEFORE updating its
            # weights: the gradient for layer i-1 must be computed with the
            # same weights that produced the forward pass.
            delta_prev = None
            if i > 0:
                delta_prev = (delta @ self.weights[i].T) * relu_derivative(self._pre_activations[i - 1])

            self.weights[i] -= lr * dw
            self.biases[i] -= lr * db

            delta = delta_prev

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Return the network output for ``x``.

        Delegates to forward(), so it also overwrites the cached activations.
        """
        return self.forward(x)

    def to_json(self) -> dict:
        """Export weights, biases and layer sizes as a JSON-serializable dict."""
        return {
            "weights": [w.tolist() for w in self.weights],
            "biases": [b.tolist() for b in self.biases],
            "layerDims": [self.weights[0].shape[0]] + [w.shape[1] for w in self.weights],
        }

    @classmethod
    def from_json(cls, data: dict) -> "SimpleNN":
        """Rebuild a network from the dict produced by to_json()."""
        dims = data["layerDims"]
        nn = cls(dims)
        # Replace the random initialization with the stored parameters
        nn.weights = [np.array(w, dtype=np.float32) for w in data["weights"]]
        nn.biases = [np.array(b, dtype=np.float32) for b in data["biases"]]
        return nn
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Training
215
+ # ---------------------------------------------------------------------------
216
+
217
def train(
    X: np.ndarray,
    Y: np.ndarray,
    model_names: list[str],
    hidden_dims: list[int] = [64, 32],
    epochs: int = 200,
    lr: float = 0.005,
    val_split: float = 0.15,
) -> tuple[SimpleNN, dict]:
    """Train the router NN and return ``(model, metrics)``.

    The data is shuffled and split into train/validation parts. Each epoch is
    one full-batch gradient step. The monitored loss is evaluated on the
    first epoch and every 20th epoch; the best weights seen at those
    checkpoints are restored at the end, and training stops early after 20
    consecutive checkpoints without improvement.

    When the validation part contains no labelled cells (for example with
    very small datasets, where ``int(n * val_split)`` is 0), the training
    loss is monitored instead. Otherwise the monitored loss would be a
    constant 0 and the final restore would throw away every step after the
    first one.

    Args:
        X: Feature matrix, one row per sample.
        Y: Label matrix; -1 marks "unknown" cells, which are masked out.
        model_names: Column names of ``Y``.
        hidden_dims: Hidden layer sizes (never mutated here).
        epochs: Maximum number of epochs.
        lr: Learning rate.
        val_split: Fraction of samples held out for validation.

    Returns:
        The trained network and the metrics from evaluate() on the
        validation part.
    """
    # Train/val split
    n = len(X)
    idx = np.random.permutation(n)
    val_n = int(n * val_split)
    val_idx, train_idx = idx[:val_n], idx[val_n:]

    X_train, X_val = X[train_idx], X[val_idx]
    Y_train, Y_val = Y[train_idx], Y[val_idx]

    # Mask: 1 where we have labels, 0 where unknown (-1)
    mask_train = (Y_train >= 0).astype(np.float32)
    mask_val = (Y_val >= 0).astype(np.float32)

    # Replace -1 with 0 for computation (masked out anyway)
    Y_train_clean = np.maximum(Y_train, 0)
    Y_val_clean = np.maximum(Y_val, 0)

    input_dim = X.shape[1]
    output_dim = Y.shape[1]
    layer_dims = [input_dim, *hidden_dims, output_dim]

    nn = SimpleNN(layer_dims)
    best_loss = float("inf")
    best_weights = None
    patience = 20
    patience_counter = 0

    # Choose what to monitor for checkpointing and early stopping
    if mask_val.sum() > 0:
        monitor_name = "val_loss"
        X_mon, Y_mon, mask_mon = X_val, Y_val_clean, mask_val
    else:
        monitor_name = "train_loss"
        X_mon, Y_mon, mask_mon = X_train, Y_train_clean, mask_train

    print(f"\nTraining: {len(X_train)} train, {len(X_val)} val")
    print(f"Architecture: {layer_dims}")
    print(f"Models: {model_names}\n")

    eps = 1e-8  # keeps log() finite at predictions of exactly 0 or 1
    for epoch in range(epochs):
        # forward() caches the activations that backward() consumes
        nn.forward(X_train)
        nn.backward(Y_train_clean, mask_train, lr=lr)

        if (epoch + 1) % 20 != 0 and epoch != 0:
            continue

        # Masked BCE on the monitored set
        mon_pred = nn.predict(X_mon)
        bce = -(Y_mon * np.log(mon_pred + eps) + (1 - Y_mon) * np.log(1 - mon_pred + eps))
        loss = (bce * mask_mon).sum() / max(mask_mon.sum(), 1)

        print(f" Epoch {epoch + 1:4d}/{epochs}: {monitor_name}={loss:.4f}")

        if loss < best_loss:
            best_loss = loss
            best_weights = ([w.copy() for w in nn.weights], [b.copy() for b in nn.biases])
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f" Early stopping at epoch {epoch + 1}")
                break

    # Restore best weights
    if best_weights is not None:
        nn.weights, nn.biases = best_weights

    # Evaluate
    val_pred = nn.predict(X_val)
    metrics = evaluate(val_pred, Y_val, mask_val, model_names)

    return nn, metrics
292
+
293
+
294
def evaluate(
    pred: np.ndarray,
    y_true: np.ndarray,
    mask: np.ndarray,
    model_names: list[str],
) -> dict:
    """Score predictions per model and for the router as a whole.

    Per model (column), only rows with ``mask > 0`` are scored; predictions
    are thresholded at 0.5. Models with no labelled rows are omitted. The
    ``"_system"`` entry simulates routing: for each row the highest-scoring
    model is chosen, and the row counts only when a label exists for that
    chosen model.
    """
    report = {}

    for col, model in enumerate(model_names):
        labelled = mask[:, col] > 0
        if not labelled.any():
            continue
        truth = y_true[labelled, col]
        hard = (pred[labelled, col] >= 0.5).astype(float)
        report[model] = {
            "accuracy": float((hard == truth).mean()),
            "samples": int(labelled.sum()),
            "positive_rate": float(truth.mean()),
        }

    # System accuracy: route every row to its highest-probability model
    row_ids = np.arange(len(pred))
    top_choice = np.argmax(pred, axis=1)
    choice_known = mask[row_ids, top_choice] > 0
    choice_label = y_true[row_ids, top_choice]

    correct = 0
    counted = 0
    for known, label in zip(choice_known, choice_label):
        if known:
            correct += label
            counted += 1

    report["_system"] = {
        "accuracy": float(correct / max(counted, 1)),
        "evaluated": int(counted),
    }

    return report
334
+
335
+
336
+ # ---------------------------------------------------------------------------
337
+ # Main
338
+ # ---------------------------------------------------------------------------
339
+
340
def main():
    """Command-line entry point: load data, train the router, export JSON.

    Exits with status 1 when the data contains samples from fewer than two
    distinct models.
    """
    cli = argparse.ArgumentParser(description="Train kondi-chat routing NN")
    cli.add_argument("--data-dir", default=".kondi-chat", help="Directory with routing-data.jsonl")
    cli.add_argument("--out", default=".kondi-chat/router-model.json", help="Output model path")
    cli.add_argument("--hidden", default="auto", help="Hidden layer dimensions (comma-separated, or 'auto')")
    cli.add_argument("--epochs", type=int, default=200)
    cli.add_argument("--lr", type=float, default=0.005)
    args = cli.parse_args()

    # Load and encode data
    samples = load_samples(args.data_dir)
    model_names = sorted({sample["modelId"] for sample in samples})

    if len(model_names) < 2:
        print(f"Need samples from at least 2 models to train. Found: {model_names}")
        sys.exit(1)

    X, feature_info = encode_features(samples)
    Y = encode_labels(samples, model_names)

    print(f"Features: {X.shape[1]} dimensions")
    print(f"Models to route between: {model_names}")

    # Per-model label coverage: -1 cells are "unknown" and not counted
    for col, name in enumerate(model_names):
        column = Y[:, col]
        known = (column >= 0).sum()
        positive = (column == 1).sum()
        print(f" {name}: {known} samples, {positive} successes ({positive/max(known,1)*100:.0f}%)")

    # Train
    if args.hidden != "auto":
        hidden_dims = [int(part) for part in args.hidden.split(",")]
    else:
        # Auto-size: larger hidden layers when embeddings are present
        hidden_dims = [256, 128] if feature_info.get("hasEmbeddings") else [64, 32]
        print(f"Auto-selected hidden dims: {hidden_dims}")
    nn, metrics = train(X, Y, model_names, hidden_dims=hidden_dims, epochs=args.epochs, lr=args.lr)

    # Print results
    print("\nResults:")
    print("=" * 60)
    for name, m in metrics.items():
        if name != "_system":
            print(f" {name:35s}: acc={m['accuracy']:.3f} (n={m['samples']}, pos_rate={m['positive_rate']:.2f})")
        else:
            print(f" System accuracy: {m['accuracy']:.3f} ({m['evaluated']} samples)")

    # Export
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    model_data = {
        "nn": nn.to_json(),
        "featureInfo": feature_info,
        "modelNames": model_names,
        "metrics": metrics,
        "trainedAt": str(np.datetime64("now")),
        "sampleCount": len(samples),
    }

    serialized = json.dumps(model_data, indent=2)
    out_path.write_text(serialized)
    print(f"\nModel saved to {out_path}")
    print(f"Load in TypeScript with: JSON.parse(readFileSync('{out_path}', 'utf-8'))")


if __name__ == "__main__":
    main()