@framers/agentos-ext-topicality 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +23 -0
- package/dist/TopicDriftTracker.d.ts +152 -0
- package/dist/TopicDriftTracker.d.ts.map +1 -0
- package/dist/TopicDriftTracker.js +265 -0
- package/dist/TopicDriftTracker.js.map +1 -0
- package/dist/TopicEmbeddingIndex.d.ts +160 -0
- package/dist/TopicEmbeddingIndex.d.ts.map +1 -0
- package/dist/TopicEmbeddingIndex.js +291 -0
- package/dist/TopicEmbeddingIndex.js.map +1 -0
- package/dist/TopicalityGuardrail.d.ts +196 -0
- package/dist/TopicalityGuardrail.d.ts.map +1 -0
- package/dist/TopicalityGuardrail.js +426 -0
- package/dist/TopicalityGuardrail.js.map +1 -0
- package/dist/index.d.ts +87 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +259 -0
- package/dist/index.js.map +1 -0
- package/dist/tools/CheckTopicTool.d.ts +148 -0
- package/dist/tools/CheckTopicTool.d.ts.map +1 -0
- package/dist/tools/CheckTopicTool.js +202 -0
- package/dist/tools/CheckTopicTool.js.map +1 -0
- package/dist/types.d.ts +358 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +215 -0
- package/dist/types.js.map +1 -0
- package/package.json +42 -0
- package/src/TopicDriftTracker.ts +307 -0
- package/src/TopicEmbeddingIndex.ts +346 -0
- package/src/TopicalityGuardrail.ts +521 -0
- package/src/index.ts +302 -0
- package/src/tools/CheckTopicTool.ts +296 -0
- package/src/types.ts +565 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @fileoverview TopicDriftTracker — session-level EMA drift detection for topicality guardrails.
|
|
3
|
+
*
|
|
4
|
+
* This module tracks whether a conversation session is gradually drifting away
|
|
5
|
+
* from its allowed topics by maintaining a per-session **running embedding**
|
|
6
|
+
* that is updated with each new message using an Exponential Moving Average
|
|
7
|
+
* (EMA).
|
|
8
|
+
*
|
|
9
|
+
* ### Why EMA?
|
|
10
|
+
* A simple "last-message" check is too noisy: a single off-topic message in an
|
|
11
|
+
* otherwise on-topic conversation should not trigger a hard block. EMA
|
|
12
|
+
* smooths the signal so that sustained drift is detected while brief tangents
|
|
13
|
+
* are tolerated.
|
|
14
|
+
*
|
|
15
|
+
* The update formula is:
|
|
16
|
+
* ```
|
|
17
|
+
* running[i] = alpha * message[i] + (1 - alpha) * running[i]
|
|
18
|
+
* ```
|
|
19
|
+
* A smaller `alpha` means the running vector changes slowly (long memory);
|
|
20
|
+
* a larger `alpha` means it reacts quickly to each new message.
|
|
21
|
+
*
|
|
22
|
+
* ### Drift decision
|
|
23
|
+
* After each EMA update the tracker checks whether the running embedding is
|
|
24
|
+
* "on-topic" by calling {@link TopicEmbeddingIndex.isOnTopicByVector}. If the
|
|
25
|
+
* check fails the `driftStreak` counter is incremented; if it passes the
|
|
26
|
+
* streak resets to zero. When `driftStreak >= driftStreakLimit` the result
|
|
27
|
+
* `driftLimitExceeded` is set to `true`.
|
|
28
|
+
*
|
|
29
|
+
* ### Session management
|
|
30
|
+
* Sessions are stored in an in-memory `Map<sessionId, TopicState>`. To
|
|
31
|
+
* prevent unbounded memory growth:
|
|
32
|
+
* - Stale sessions (inactive for > `sessionTimeoutMs`) are pruned lazily
|
|
33
|
+
* whenever `update()` is about to create a new session and `map.size >= maxSessions`.
|
|
34
|
+
* - Callers can force a full clear via {@link clear}.
|
|
35
|
+
*
|
|
36
|
+
* @module topicality/TopicDriftTracker
|
|
37
|
+
*/
|
|
38
|
+
|
|
39
|
+
import type { TopicEmbeddingIndex } from './TopicEmbeddingIndex';
|
|
40
|
+
import {
|
|
41
|
+
DEFAULT_DRIFT_CONFIG,
|
|
42
|
+
type DriftConfig,
|
|
43
|
+
type DriftResult,
|
|
44
|
+
type TopicMatch,
|
|
45
|
+
type TopicState,
|
|
46
|
+
} from './types';
|
|
47
|
+
|
|
48
|
+
// ---------------------------------------------------------------------------
|
|
49
|
+
// TopicDriftTracker
|
|
50
|
+
// ---------------------------------------------------------------------------
|
|
51
|
+
|
|
52
|
+
/**
 * Tracks per-session topic drift using EMA-blended running embeddings.
 *
 * Instantiate once per agent process (not per conversation) since the tracker
 * manages many concurrent sessions internally.
 *
 * @example
 * ```ts
 * const tracker = new TopicDriftTracker({ alpha: 0.4, driftStreakLimit: 2 });
 *
 * // In your message handler:
 * const embedding = await embed(userMessage);
 * const result = tracker.update('session-abc', embedding, allowedIndex);
 *
 * if (result.driftLimitExceeded) {
 *   // Take configured action: redirect, warn, or block.
 * }
 * ```
 */
export class TopicDriftTracker {
  /** Fully resolved drift configuration (defaults merged with caller overrides). */
  private readonly config: DriftConfig;

  /**
   * In-memory session store.
   * Key: caller-supplied session ID (e.g. conversation UUID).
   * Value: mutable {@link TopicState} for that session.
   */
  private readonly sessions: Map<string, TopicState> = new Map();

  // -------------------------------------------------------------------------
  // Constructor
  // -------------------------------------------------------------------------

  /**
   * Creates a new `TopicDriftTracker`.
   *
   * @param config - Optional partial override of {@link DEFAULT_DRIFT_CONFIG}.
   *   Any field that is omitted **or explicitly `undefined`** falls back to its
   *   default value. Pass an empty object `{}` or omit entirely to use all
   *   defaults.
   *
   * @example
   * ```ts
   * // Use defaults
   * const tracker = new TopicDriftTracker();
   *
   * // Override only alpha and streakLimit
   * const strictTracker = new TopicDriftTracker({ alpha: 0.5, driftStreakLimit: 2 });
   * ```
   */
  constructor(config?: Partial<DriftConfig>) {
    // Start from the defaults, then copy over only the overrides that carry a
    // value. A plain `{ ...DEFAULT_DRIFT_CONFIG, ...config }` spread would copy
    // `{ alpha: undefined }` verbatim, leaving `alpha` undefined and turning
    // every EMA update into NaN.
    const merged: DriftConfig = { ...DEFAULT_DRIFT_CONFIG };
    for (const [key, value] of Object.entries(config ?? {})) {
      if (value !== undefined) {
        Object.assign(merged, { [key]: value });
      }
    }
    this.config = merged;
  }

  // -------------------------------------------------------------------------
  // Public API — update
  // -------------------------------------------------------------------------

  /**
   * Processes a new message embedding for the given session and returns the
   * current drift assessment.
   *
   * ### Steps performed
   * 1. **Validate** that an existing session's running embedding has the same
   *    length as `messageEmbedding`; otherwise throw without touching state.
   * 2. **Lazy-prune** stale sessions when a new session is about to be created
   *    and `map.size >= maxSessions` (never during updates of existing
   *    sessions, so the session being updated cannot be deleted).
   * 3. **Retrieve or create** the session state. On the very first message
   *    the running embedding is initialised to a copy of `messageEmbedding`
   *    (no EMA applied yet).
   * 4. **Apply EMA** (from the second message onwards):
   *    `running[i] = alpha * message[i] + (1 - alpha) * running[i]`
   * 5. **Check topic alignment** using
   *    `allowedIndex.isOnTopicByVector(running, driftThreshold)`.
   * 6. **Update streak** — increment `driftStreak` on off-topic, reset to 0
   *    on on-topic.
   * 7. **Persist** the updated state and return a {@link DriftResult}.
   *
   * @param sessionId - Unique identifier for the conversation session.
   *   Typically a UUID or user ID. Must be consistent across messages in the
   *   same conversation.
   * @param messageEmbedding - Pre-computed numeric embedding of the current
   *   message. Must have the same dimensionality as the topic centroids used
   *   by `allowedIndex` and as earlier messages in the same session.
   * @param allowedIndex - Built {@link TopicEmbeddingIndex} containing the
   *   allowed topics to check against. The tracker calls only
   *   `isOnTopicByVector` and `matchByVector` (no async operations, no extra
   *   embedding calls).
   * @returns A {@link DriftResult} describing whether the session is currently
   *   drifting and by how much.
   * @throws Error if `messageEmbedding` has a different length than the
   *   session's running embedding. The session state is left unchanged.
   */
  update(
    sessionId: string,
    messageEmbedding: number[],
    allowedIndex: TopicEmbeddingIndex,
  ): DriftResult {
    const now = Date.now();
    const existing = this.sessions.get(sessionId);

    // Reject dimension mismatches before mutating anything. Blending vectors of
    // different lengths would read `undefined` components and permanently
    // poison the running embedding with NaN (or silently drop trailing ones).
    if (existing !== undefined && existing.runningEmbedding.length !== messageEmbedding.length) {
      throw new Error(
        `TopicDriftTracker.update: embedding dimension mismatch for session "${sessionId}" ` +
          `(expected ${existing.runningEmbedding.length}, received ${messageEmbedding.length}).`,
      );
    }

    // Lazy prune: only trigger when a NEW session would take us to or past the
    // limit. Existing sessions never trigger pruning, so the session currently
    // being processed cannot be deleted here.
    if (existing === undefined && this.sessions.size >= this.config.maxSessions) {
      this.pruneStale();
    }

    let state: TopicState;

    if (existing === undefined) {
      // First message in this session — initialise running embedding to a copy
      // of the current message embedding. A copy prevents external mutation
      // of the array from silently corrupting the tracker state.
      state = {
        runningEmbedding: [...messageEmbedding],
        messageCount: 0, // will be incremented below
        lastTopicScore: 0,
        driftStreak: 0,
        lastSeenAt: now,
      };
    } else {
      state = existing;

      // Apply the EMA update in-place.
      // running[i] = alpha * message[i] + (1 - alpha) * running[i]
      const alpha = this.config.alpha;
      const oneMinusAlpha = 1 - alpha;

      for (let i = 0; i < state.runningEmbedding.length; i++) {
        state.runningEmbedding[i] =
          alpha * messageEmbedding[i] + oneMinusAlpha * state.runningEmbedding[i];
      }
    }

    // Increment message counter and timestamp.
    state.messageCount += 1;
    state.lastSeenAt = now;

    // -----------------------------------------------------------------------
    // Topic alignment check
    // -----------------------------------------------------------------------

    // Check whether the (now-updated) running embedding is on-topic.
    const onTopic = allowedIndex.isOnTopicByVector(
      state.runningEmbedding,
      this.config.driftThreshold,
    );

    // Retrieve the best-match details from the index for the result payload.
    // matchByVector is synchronous and does not re-embed anything.
    const topMatches = allowedIndex.matchByVector(state.runningEmbedding);
    const nearestTopic: TopicMatch | null = topMatches.length > 0 ? topMatches[0] : null;
    const currentSimilarity = nearestTopic?.similarity ?? 0;

    // Store the latest similarity score for observability.
    state.lastTopicScore = currentSimilarity;

    // -----------------------------------------------------------------------
    // Drift streak management
    // -----------------------------------------------------------------------

    if (onTopic) {
      // Good message — reset the drift counter.
      state.driftStreak = 0;
    } else {
      // Off-topic message — accumulate the streak.
      state.driftStreak += 1;
    }

    const driftLimitExceeded = state.driftStreak >= this.config.driftStreakLimit;

    // -----------------------------------------------------------------------
    // Persist state and return result
    // -----------------------------------------------------------------------

    // Always write back (existing sessions were mutated in-place; new sessions
    // need to be inserted).
    this.sessions.set(sessionId, state);

    return {
      onTopic,
      currentSimilarity,
      nearestTopic,
      driftStreak: state.driftStreak,
      driftLimitExceeded,
    };
  }

  // -------------------------------------------------------------------------
  // Public API — pruneStale
  // -------------------------------------------------------------------------

  /**
   * Removes sessions that have been inactive for longer than `sessionTimeoutMs`.
   *
   * This is called lazily inside {@link update} when a new session is about
   * to be created and the session map already holds `maxSessions` or more
   * entries. Callers may also invoke it directly to trigger an immediate
   * cleanup (e.g. in a scheduled maintenance job).
   *
   * Only stale sessions are removed; if every session is still active the map
   * is left as-is.
   *
   * Pruning is O(n) in the number of active sessions.
   */
  pruneStale(): void {
    const now = Date.now();
    const timeoutMs = this.config.sessionTimeoutMs;

    // Deleting the current key while iterating a Map is safe in JavaScript.
    for (const [id, state] of this.sessions) {
      if (now - state.lastSeenAt > timeoutMs) {
        this.sessions.delete(id);
      }
    }
  }

  // -------------------------------------------------------------------------
  // Public API — clear
  // -------------------------------------------------------------------------

  /**
   * Removes all sessions from the tracker unconditionally.
   *
   * Useful for graceful shutdown, testing teardown, or resetting the agent
   * context between evaluation runs.
   */
  clear(): void {
    this.sessions.clear();
  }

  // -------------------------------------------------------------------------
  // Internal helpers (exposed for testing via package-private pattern)
  // -------------------------------------------------------------------------

  /**
   * Returns the current number of active sessions in the internal map.
   * Useful for observability and testing.
   *
   * @internal
   */
  get sessionCount(): number {
    return this.sessions.size;
  }

  /**
   * Returns a copy of the {@link TopicState} for the given session, or
   * `undefined` if the session does not exist.
   *
   * Exposed for unit-testing state inspection. The returned object and its
   * `runningEmbedding` array are copies — mutating them does not affect the
   * tracker's internal state.
   *
   * @internal
   */
  getState(sessionId: string): TopicState | undefined {
    const state = this.sessions.get(sessionId);
    if (!state) return undefined;
    // Copy the object and its embedding array so callers cannot mutate tracker state.
    return { ...state, runningEmbedding: [...state.runningEmbedding] };
  }
}
|
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @fileoverview TopicEmbeddingIndex — semantic similarity lookup for topic guardrails.
|
|
3
|
+
*
|
|
4
|
+
* This module implements a lightweight in-memory embedding index that:
|
|
5
|
+
*
|
|
6
|
+
* 1. **Builds** per-topic centroid embeddings from descriptions + examples.
|
|
7
|
+
* 2. **Matches** an arbitrary embedding or text string against all topic centroids
|
|
8
|
+
* using cosine similarity.
|
|
9
|
+
* 3. **Answers** boolean on-topic queries at a configurable similarity threshold.
|
|
10
|
+
*
|
|
11
|
+
* ### How centroids are built
|
|
12
|
+
* For each {@link TopicDescriptor} the index concatenates:
|
|
13
|
+
* ```
|
|
14
|
+
* texts = [descriptor.description, ...descriptor.examples]
|
|
15
|
+
* ```
|
|
16
|
+
* All topics are embedded in a single batch call to `embeddingFn` to minimise
|
|
17
|
+
* round-trips. The centroid for a topic is the component-wise average (mean)
|
|
18
|
+
* of all its embedding vectors.
|
|
19
|
+
*
|
|
20
|
+
* ### Similarity scoring
|
|
21
|
+
* Raw cosine similarity can be negative when vectors point in opposite directions.
|
|
22
|
+
* `matchByVector` clamps scores to `Math.max(0, similarity)` so that all
|
|
23
|
+
* {@link TopicMatch} values represent non-negative relevance scores.
|
|
24
|
+
*
|
|
25
|
+
* @module topicality/TopicEmbeddingIndex
|
|
26
|
+
*/
|
|
27
|
+
|
|
28
|
+
import { cosineSimilarity } from '@framers/agentos/core/utils/text-utils';
|
|
29
|
+
import type { TopicDescriptor, TopicMatch } from './types';
|
|
30
|
+
|
|
31
|
+
// ---------------------------------------------------------------------------
|
|
32
|
+
// Internal storage shape
|
|
33
|
+
// ---------------------------------------------------------------------------
|
|
34
|
+
|
|
35
|
+
/**
 * Record kept for each indexed topic.
 *
 * Not exported; it only describes the values held in the index's internal
 * `topicId → entry` map.
 *
 * @internal
 */
type TopicEntry = {
  /** The descriptor this entry was built from (its `name` is reported in matches). */
  descriptor: TopicDescriptor;
  /** Component-wise mean of the embeddings of the topic's description and examples. */
  centroid: number[];
};
|
|
49
|
+
|
|
50
|
+
// ---------------------------------------------------------------------------
|
|
51
|
+
// TopicEmbeddingIndex
|
|
52
|
+
// ---------------------------------------------------------------------------
|
|
53
|
+
|
|
54
|
+
/**
 * Semantic embedding index for topicality guardrail matching.
 *
 * The index is intentionally **lazy** — it holds no embeddings until
 * {@link build} is called. This makes instantiation cheap and lets the
 * caller defer the (potentially expensive) batch embedding call until the
 * agent's first message.
 *
 * @example
 * ```ts
 * const index = new TopicEmbeddingIndex(async (texts) => {
 *   const res = await openai.embeddings.create({ model: 'text-embedding-3-small', input: texts });
 *   return res.data.map(d => d.embedding);
 * });
 *
 * await index.build(TOPIC_PRESETS.customerSupport);
 *
 * const matches = await index.match('How do I cancel my subscription?');
 * // → [{ topicId: 'billing', topicName: 'Billing & Payments', similarity: 0.82 }, ...]
 *
 * const onTopic = await index.isOnTopic('Tell me a joke', 0.35);
 * // → false (a joke doesn't match any customer-support topic)
 * ```
 */
export class TopicEmbeddingIndex {
  /**
   * Caller-supplied batch embedding function.
   * Invoked once during {@link build} with all topic texts concatenated.
   */
  private readonly embeddingFn: (texts: string[]) => Promise<number[][]>;

  /**
   * Internal store mapping `topicId → TopicEntry`.
   * Populated by {@link build}; empty until then.
   */
  private readonly entries: Map<string, TopicEntry> = new Map();

  /** Whether {@link build} has been called and completed successfully. */
  private built: boolean = false;

  // -------------------------------------------------------------------------
  // Constructor
  // -------------------------------------------------------------------------

  /**
   * Creates a new `TopicEmbeddingIndex`.
   *
   * @param embeddingFn - Async function that converts an array of text strings
   *   into corresponding numeric embedding vectors. All returned vectors must
   *   share the same dimensionality. The function is called exactly **once**
   *   per {@link build} invocation (for a non-empty topic list) with all texts
   *   for all topics batched together.
   */
  constructor(embeddingFn: (texts: string[]) => Promise<number[][]>) {
    this.embeddingFn = embeddingFn;
  }

  // -------------------------------------------------------------------------
  // Public API — build
  // -------------------------------------------------------------------------

  /**
   * Embeds all topic descriptions and examples, computes per-topic centroid
   * embeddings, and stores them in the internal index.
   *
   * Calling `build()` a second time replaces the existing index entirely,
   * allowing hot-reloading of topic configurations without recreating the
   * instance. While the new embeddings are being computed, the previous index
   * keeps serving queries. The swap to the new entries happens synchronously
   * once they are ready, so readers never observe an empty or half-populated
   * index during a successful rebuild.
   *
   * If the build fails for any reason, the index is left empty and
   * {@link isBuilt} returns `false`.
   *
   * ### Centroid computation
   * For each topic we collect `[description, ...examples]` as a list of
   * strings, embed them all in one batch, then average the resulting vectors
   * component-wise to produce a single representative centroid.
   *
   * All topics are embedded in a **single batch call** to minimise latency.
   *
   * @param topics - Array of {@link TopicDescriptor} objects to index.
   *   An empty array is valid — the index will simply return no matches and
   *   `embeddingFn` is not called.
   * @returns A promise that resolves once all embeddings are computed and
   *   stored.
   * @throws Error (as a rejection) if `embeddingFn` throws, returns a different
   *   number of vectors than texts supplied, or returns vectors of differing
   *   dimensionality.
   */
  async build(topics: TopicDescriptor[]): Promise<void> {
    let nextEntries: Map<string, TopicEntry>;
    try {
      nextEntries = await this.computeEntries(topics);
    } catch (err) {
      // Failure contract: a failed build leaves the index empty rather than
      // serving stale or partial data.
      this.entries.clear();
      this.built = false;
      throw err;
    }

    // No `await` between the clear and the repopulation, so synchronous
    // readers (matchByVector) see either the old entries or the new ones —
    // never an empty map or a mix of two builds.
    this.entries.clear();
    for (const [topicId, entry] of nextEntries) {
      this.entries.set(topicId, entry);
    }
    this.built = true;
  }

  /**
   * Embeds every topic's texts in one batch and returns the resulting
   * `topicId → TopicEntry` map without mutating the index.
   *
   * @param topics - Topics to embed. An empty array yields an empty map and
   *   makes no embedding call.
   * @returns Freshly computed entries. A later descriptor with a duplicate id
   *   replaces an earlier one.
   * @throws Error when the number or dimensionality of the vectors returned by
   *   `embeddingFn` is inconsistent with the request.
   */
  private async computeEntries(topics: TopicDescriptor[]): Promise<Map<string, TopicEntry>> {
    const result = new Map<string, TopicEntry>();
    if (topics.length === 0) {
      return result;
    }

    // Collect, per topic, the list of texts to embed and the range of
    // indices they will occupy in the flat batch array.
    //
    // Layout: [topic0_desc, topic0_ex0, topic0_ex1, …, topic1_desc, …]
    const allTexts: string[] = [];
    const topicRanges: Array<{ topic: TopicDescriptor; start: number; end: number }> = [];

    for (const topic of topics) {
      const start = allTexts.length;
      // Always include the description as the first text for this topic.
      allTexts.push(topic.description);
      // Then all examples (may be empty — the centroid will just be the description).
      for (const example of topic.examples) {
        allTexts.push(example);
      }
      topicRanges.push({ topic, start, end: allTexts.length }); // end is exclusive
    }

    // Single batch embedding call — one round-trip regardless of how many
    // topics or examples are configured.
    const allEmbeddings = await this.embeddingFn(allTexts);

    // Validate that the embedding function returned the right number of vectors.
    if (allEmbeddings.length !== allTexts.length) {
      throw new Error(
        `TopicEmbeddingIndex.build: embeddingFn returned ${allEmbeddings.length} vectors ` +
          `but ${allTexts.length} texts were provided.`,
      );
    }

    // computeCentroid sizes its accumulator from the first vector, so ragged
    // vectors would silently truncate or yield NaN components. Reject them.
    // allTexts is non-empty here (each topic contributes its description),
    // so allEmbeddings[0] exists.
    const dim = allEmbeddings[0].length;
    for (let i = 1; i < allEmbeddings.length; i++) {
      if (allEmbeddings[i].length !== dim) {
        throw new Error(
          `TopicEmbeddingIndex.build: embeddingFn returned a vector of length ${allEmbeddings[i].length} ` +
            `at index ${i}, expected ${dim} (all vectors must share the same dimensionality).`,
        );
      }
    }

    // Compute centroid for each topic from its slice of the batch result.
    for (const { topic, start, end } of topicRanges) {
      result.set(topic.id, {
        descriptor: topic,
        centroid: computeCentroid(allEmbeddings.slice(start, end)),
      });
    }

    return result;
  }

  // -------------------------------------------------------------------------
  // Public API — matchByVector
  // -------------------------------------------------------------------------

  /**
   * Computes similarity between a pre-computed embedding vector and all topic
   * centroids **without** making any additional embedding calls.
   *
   * This is the hot path invoked by {@link TopicDriftTracker}, which maintains
   * its own running embedding and never needs to re-embed.
   *
   * Negative cosine scores are clamped to 0, and results are sorted
   * descending by similarity.
   *
   * @param embedding - A numeric vector with the same dimensionality as the
   *   centroids produced during {@link build}.
   * @returns Array of {@link TopicMatch} objects sorted by similarity
   *   descending. Returns an empty array if the index was not yet built or
   *   contains no topics.
   */
  matchByVector(embedding: number[]): TopicMatch[] {
    if (!this.built || this.entries.size === 0) {
      // Return empty rather than throwing — callers can treat no match as off-topic.
      return [];
    }

    const matches: TopicMatch[] = [];

    for (const [topicId, entry] of this.entries) {
      const raw = cosineSimilarity(embedding, entry.centroid);
      // Clamp negatives to 0 — "opposite direction" is no more useful than
      // "unrelated" for topic matching.
      const similarity = Math.max(0, raw);

      matches.push({
        topicId,
        topicName: entry.descriptor.name,
        similarity,
      });
    }

    // Sort descending so the best match is first.
    matches.sort((a, b) => b.similarity - a.similarity);
    return matches;
  }

  // -------------------------------------------------------------------------
  // Public API — match
  // -------------------------------------------------------------------------

  /**
   * Embeds `text` and returns similarity scores against all topic centroids.
   *
   * This is a convenience wrapper that handles the embedding step. If you
   * already have an embedding (e.g. from the drift tracker's running vector)
   * prefer {@link matchByVector} to avoid a redundant embedding call.
   *
   * @param text - The user message or assistant output to evaluate.
   * @returns A promise resolving to {@link TopicMatch}[] sorted descending.
   */
  async match(text: string): Promise<TopicMatch[]> {
    // Embed the single query text.
    const [embedding] = await this.embeddingFn([text]);
    return this.matchByVector(embedding);
  }

  // -------------------------------------------------------------------------
  // Public API — isOnTopicByVector
  // -------------------------------------------------------------------------

  /**
   * Returns `true` if the given embedding vector scores strictly above
   * `threshold` against **at least one** topic in the index.
   *
   * Uses {@link matchByVector} internally so no additional embedding call is
   * made.
   *
   * @param embedding - Pre-computed numeric vector.
   * @param threshold - Similarity (in `[0, 1]`) that a topic's score must
   *   strictly exceed to count as a match.
   * @returns `true` if any topic centroid has similarity > threshold; otherwise `false`.
   */
  isOnTopicByVector(embedding: number[], threshold: number): boolean {
    const matches = this.matchByVector(embedding);
    // The list is sorted descending, so we only need to check the first entry.
    return matches.length > 0 && matches[0].similarity > threshold;
  }

  // -------------------------------------------------------------------------
  // Public API — isOnTopic
  // -------------------------------------------------------------------------

  /**
   * Embeds `text` and returns `true` if it scores strictly above `threshold`
   * against at least one allowed topic.
   *
   * @param text - The text to evaluate.
   * @param threshold - Similarity that the best topic score must strictly
   *   exceed for the text to be considered on-topic.
   * @returns A promise resolving to `true` if on-topic, `false` otherwise.
   */
  async isOnTopic(text: string, threshold: number): Promise<boolean> {
    const [embedding] = await this.embeddingFn([text]);
    return this.isOnTopicByVector(embedding, threshold);
  }

  // -------------------------------------------------------------------------
  // Getter
  // -------------------------------------------------------------------------

  /**
   * Whether {@link build} has been called and completed successfully.
   *
   * Use this to guard against calling {@link match} or {@link matchByVector}
   * before the index is ready. It stays `true` while a rebuild of an already
   * built index is in flight, and becomes `false` if a build fails.
   *
   * @example
   * ```ts
   * if (!index.isBuilt) await index.build(topics);
   * ```
   */
  get isBuilt(): boolean {
    return this.built;
  }
}
|
|
313
|
+
|
|
314
|
+
// ---------------------------------------------------------------------------
|
|
315
|
+
// Internal helpers
|
|
316
|
+
// ---------------------------------------------------------------------------
|
|
317
|
+
|
|
318
|
+
/**
 * Averages a list of embedding vectors component by component.
 *
 * The output length is taken from the first vector, and every vector is
 * expected to share that length. Components beyond it are ignored, while a
 * shorter vector contributes `undefined`, which turns the affected components
 * into `NaN`. An empty input yields an empty array.
 *
 * @param vectors - Embedding vectors to average.
 * @returns The component-wise mean of `vectors`.
 *
 * @internal
 */
function computeCentroid(vectors: number[][]): number[] {
  const count = vectors.length;
  if (count === 0) {
    return [];
  }

  const width = vectors[0].length;
  const totals: number[] = Array.from({ length: width }, () => 0);

  vectors.forEach((vector) => {
    for (let d = 0; d < width; d += 1) {
      totals[d] += vector[d];
    }
  });

  return totals.map((total) => total / count);
}
|