@dbx-tools/appkit-mastra 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/model.ts ADDED
@@ -0,0 +1,491 @@
1
+ /**
2
+ * Databricks Model Serving resolver for Mastra agents.
3
+ *
4
+ * Each agent step calls {@link buildModel} with the active
5
+ * `RequestContext`. The user stamped by `MastraServer` carries an
6
+ * AppKit `WorkspaceClient`; we ask it for the workspace host and a
7
+ * fresh bearer header, then point Mastra's OpenAI-compatible provider
8
+ * at `/serving-endpoints` on that host.
9
+ *
10
 * Model id resolution walks three explicit sources before falling back
 * to a built-in fallback list, **in this priority order**:
 *
 * 1. Per-request override stashed by the auth middleware under
 *    {@link MASTRA_MODEL_OVERRIDE_KEY} (header / query / body), honoured
 *    only when the serving config allows overrides.
 * 2. The static `modelId` passed in by the agent / plugin (string
 *    sugar on `def.model` or `config.defaultModel`).
 * 3. `DATABRICKS_SERVING_ENDPOINT_NAME` env var.
 * 4. {@link FALLBACK_MODEL_IDS} (or `config.defaultModelFallbacks`),
 *    walked in order until an id is found in the workspace listing.
19
+ *
20
 * When fuzzy matching is enabled, an explicit id (sources 1-3) is
 * matched against the live `/serving-endpoints` list
 * ({@link listServingEndpoints}), so loose names like `"claude sonnet"`
 * resolve to the real endpoint name. If nothing matches within the
 * threshold, the input is used verbatim and Databricks returns the
 * canonical error. Catalogue fetch failures (network blip, expired
 * token) are not swallowed: they propagate to the caller so the real
 * SDK error is visible.
26
+ */
27
+
28
+ import {
29
+ commonUtils,
30
+ httpUtils,
31
+ logUtils,
32
+ stringUtils,
33
+ } from "@dbx-tools/appkit-shared";
34
+ import type { MastraModelConfig } from "@mastra/core/llm";
35
+ import type { RequestContext } from "@mastra/core/request-context";
36
+
37
+ import { MASTRA_USER_KEY, type MastraPluginConfig, type User } from "./config.js";
38
+ import {
39
+ listServingEndpoints,
40
+ MASTRA_MODEL_OVERRIDE_KEY,
41
+ resolveModelId,
42
+ resolveServingConfig,
43
+ type ServingEndpointSummary,
44
+ } from "./serving.js";
45
+
46
/**
 * Capability tiers for Databricks Foundation Model API endpoints.
 *
 * - {@link ModelTier.Thinking}: deepest reasoning / "thinking" models
 *   (Claude Opus, GPT-5.5 Pro, Gemini Pro, Llama 4 Maverick, etc).
 *   Highest cost and latency; reserve for hard multi-step reasoning.
 * - {@link ModelTier.Balanced}: cost/latency sweet spot for general
 *   agent work (Claude Sonnet, GPT-5.x, Gemini Flash, Llama 3.3 70B,
 *   Qwen). The right default for most agents.
 * - {@link ModelTier.Fast}: cheap and quick; classification, routing,
 *   tool-arg extraction, simple summarisation (Claude Haiku, GPT-5
 *   mini/nano, Gemini Flash Lite, GPT-OSS 20B, Gemma 3 12B,
 *   Llama 3.1 8B).
 *
 * String enum so the value is the slug we use in cache keys, logs,
 * and as the value users see in serialized configs. Every member is
 * also a top-level key of `MODEL_CATALOG` (enforced there by
 * `satisfies Record<ModelTier, ...>`).
 */
export enum ModelTier {
  /** Deepest reasoning tier; slug `"thinking"`. */
  Thinking = "thinking",
  /** General-purpose default tier; slug `"balanced"`. */
  Balanced = "balanced",
  /** Cheapest / lowest-latency tier; slug `"fast"`. */
  Fast = "fast",
}
67
+
68
/**
 * Catalogue of Databricks-hosted Foundation Model API endpoints,
 * grouped by capability {@link ModelTier} and then by provider. Each
 * inner array is priority-ordered (most powerful first within the
 * same provider+tier).
 *
 * Provider buckets:
 *
 * - `claude`: Anthropic Claude family (closed; flagship reasoning).
 * - `gpt`: OpenAI GPT-5 family (closed; "ChatGPT" on Databricks FMAPI).
 * - `gemini`: Google Gemini family (closed; multimodal + web-search).
 * - `openSource`: open-weights models (widest regional / SKU availability).
 *
 * Listing an id here does not mean the endpoint exists in a given
 * workspace. The resolver checks the live `/serving-endpoints` listing
 * before using a fallback id.
 *
 * The list is curated by hand; refresh from the Databricks "supported
 * foundation models" doc when new endpoints land. `as const` keeps the
 * literal ids, and `satisfies` makes the compiler reject a missing
 * tier without widening those literals.
 */
export const MODEL_CATALOG = {
  // Deep-reasoning flagships.
  [ModelTier.Thinking]: {
    claude: [
      "databricks-claude-opus-4-8",
      "databricks-claude-opus-4-7",
      "databricks-claude-opus-4-6",
      "databricks-claude-opus-4-5",
      "databricks-claude-opus-4-1",
    ],
    gpt: ["databricks-gpt-5-5-pro"],
    gemini: [
      "databricks-gemini-3-1-pro",
      "databricks-gemini-3-pro",
      "databricks-gemini-2-5-pro",
    ],
    openSource: [
      "databricks-llama-4-maverick",
      "databricks-gpt-oss-120b",
      "databricks-meta-llama-3-1-405b-instruct",
    ],
  },
  // General-purpose workhorses.
  [ModelTier.Balanced]: {
    claude: [
      "databricks-claude-sonnet-4-6",
      "databricks-claude-sonnet-4-5",
      "databricks-claude-sonnet-4",
    ],
    gpt: [
      "databricks-gpt-5-5",
      "databricks-gpt-5-4",
      "databricks-gpt-5-2",
      "databricks-gpt-5-1",
      "databricks-gpt-5",
    ],
    gemini: [
      "databricks-gemini-3-5-flash",
      "databricks-gemini-3-flash",
      "databricks-gemini-2-5-flash",
    ],
    openSource: [
      "databricks-meta-llama-3-3-70b-instruct",
      "databricks-qwen3-next-80b-a3b-instruct",
      "databricks-qwen35-122b-a10b",
    ],
  },
  // Small, cheap, low-latency models.
  [ModelTier.Fast]: {
    claude: ["databricks-claude-haiku-4-5"],
    gpt: [
      "databricks-gpt-5-4-mini",
      "databricks-gpt-5-4-nano",
      "databricks-gpt-5-mini",
      "databricks-gpt-5-nano",
    ],
    gemini: ["databricks-gemini-3-1-flash-lite"],
    openSource: [
      "databricks-gpt-oss-20b",
      "databricks-gemma-3-12b",
      "databricks-meta-llama-3-1-8b-instruct",
    ],
  },
} as const satisfies Record<ModelTier, Record<string, readonly string[]>>;
145
+
146
/**
 * Round-robin merge of several lists. Row by row, append the element
 * at that row from every list that is still long enough. Stop at the
 * first row where no list contributed anything. Shorter lists drop out
 * once exhausted; relative order inside each list is preserved.
 *
 * Used to alternate between provider buckets within a tier, so the
 * resolver tries one model per vendor before going deeper into any
 * single vendor.
 *
 * Example: `interleave(["a1","a2","a3"], ["b1","b2"])` ->
 * `["a1","b1","a2","b2","a3"]`.
 */
function interleave<T>(...lists: readonly (readonly T[])[]): T[] {
  const merged: T[] = [];
  let row = 0;
  let produced = true;
  while (produced) {
    produced = false;
    for (const list of lists) {
      if (row < list.length) {
        merged.push(list[row] as T);
        produced = true;
      }
    }
    row += 1;
  }
  return merged;
}
165
+
166
/**
 * Priority-ordered model ids for one capability {@link ModelTier}.
 *
 * The closed-weight buckets are round-robin interleaved in the order
 * Claude, GPT, Gemini. A workspace missing the top Claude therefore
 * reaches a flagship GPT / Gemini on the very next probe. The
 * open-weights bucket is appended afterwards, unchanged, as the
 * universal floor (widest regional availability).
 *
 * Returns a fresh array on every call.
 *
 * @example
 * ```ts
 * mastra({
 *   defaultModelFallbacks: modelsForTier(ModelTier.Fast),
 * });
 * ```
 */
export function modelsForTier(tier: ModelTier): readonly string[] {
  const { claude, gpt, gemini, openSource } = MODEL_CATALOG[tier];
  const closedWeights = interleave<string>(claude, gpt, gemini);
  return [...closedWeights, ...openSource];
}
189
+
190
/**
 * Highest-priority model id for a {@link ModelTier}, i.e. the first
 * entry of {@link modelsForTier}. Synchronous. The agent-step
 * resolver fuzzy-matches explicit ids against the workspace catalogue
 * at call time (when fuzzy matching is enabled), so a loose top pick
 * can still snap to whatever the workspace actually serves.
 *
 * Handy when wiring a tier-appropriate model into an agent definition:
 *
 * @example
 * ```ts
 * const classifier = createAgent({
 *   instructions: "Classify this email",
 *   model: modelForTier(ModelTier.Fast), // cheap, quick
 * });
 *
 * const planner = createAgent({
 *   instructions: "Plan a multi-step migration",
 *   model: modelForTier(ModelTier.Thinking), // deep reasoning
 * });
 * ```
 */
export function modelForTier(tier: ModelTier): string {
  const [topChoice] = modelsForTier(tier);
  return topChoice!;
}
213
+
214
/**
 * Last-resort model ids, used when no explicit model was requested.
 * That means no per-request override, no agent / `config.defaultModel`
 * string and no `DATABRICKS_SERVING_ENDPOINT_NAME`.
 *
 * The resolver walks this list in order and uses the first id whose
 * endpoint appears in the workspace's `/serving-endpoints` listing.
 * Workspaces differ: not every region / SKU carries every model, and
 * the Foundation Model API line-up changes quickly. The walk therefore
 * degrades from the best thinking model down to the smallest commodity
 * Llama. If nothing matches, the first entry is used and Databricks
 * reports the error.
 *
 * Layout: the {@link modelsForTier} output for Thinking, then
 * Balanced, then Fast, concatenated. Inside each tier the providers
 * are round-robin-zipped (Claude, GPT, Gemini), followed by the
 * open-weights tail.
 *
 * Replace the whole list via `MastraPluginConfig.defaultModelFallbacks`,
 * for example to pin a regulated workspace to an approved subset or to
 * bias the priority toward a particular tier.
 */
export const FALLBACK_MODEL_IDS: readonly string[] = [
  ModelTier.Thinking,
  ModelTier.Balanced,
  ModelTier.Fast,
].flatMap((tier) => modelsForTier(tier));
237
+
238
/**
 * Optional overrides accepted by {@link buildModel}. Pass `{}` (the
 * default) when the agent has no static model preference.
 */
export interface BuildModelOverrides {
  /**
   * Static model id from the agent / plugin config (string sugar on
   * `def.model` or `config.defaultModel`). Loses to the per-request
   * override but wins over the `DATABRICKS_SERVING_ENDPOINT_NAME` env
   * var and the fallback list.
   */
  modelId?: string;
}
247
+
248
/**
 * Resolve a `MastraModelConfig` for the current agent step. Runs
 * while `agent.stream` is inside the `asUser(req)` scope, so tokens
 * are user-scoped. Outside an active user context the workspace
 * client falls back to the service principal.
 *
 * A fresh auth header is minted on every call via the workspace
 * client's `authenticate`, so expiring tokens are refreshed per step.
 *
 * @param config Plugin config; supplies `providerId` and the serving /
 *   fallback options consumed by the model-id resolver.
 * @param requestContext Active Mastra request context. It must carry
 *   the AppKit user under `MASTRA_USER_KEY`.
 * @param overrides Optional static model id (see {@link BuildModelOverrides}).
 * @returns An OpenAI-compatible model config pointing at
 *   `<host>/serving-endpoints` with the authenticated headers attached.
 * @throws Error when no user has been stamped on `requestContext`.
 */
export async function buildModel(
  config: MastraPluginConfig,
  requestContext: RequestContext,
  overrides: BuildModelOverrides = {},
): Promise<MastraModelConfig> {
  // Memoized: only the first call in the process installs the fetch wrapper.
  void setupFetchInterceptor();

  const user = requestContext.get(MASTRA_USER_KEY) as User | undefined;
  if (!user) {
    // Fail with an actionable message instead of an opaque TypeError
    // from dereferencing `undefined.executionContext` below.
    throw new Error(
      "buildModel: no user found on the RequestContext under MASTRA_USER_KEY; " +
        "agent steps must run inside a request handled by MastraServer",
    );
  }

  const clientConfig = user.executionContext.client.config;
  const host = (await clientConfig.getHost()).toString();
  const headers = new Headers();
  await clientConfig.authenticate(headers);
  // The OpenAI Node SDK appends paths like `/chat/completions` to whatever
  // URL we hand it. Drop the trailing slash so the resulting URL stays
  // well-formed (`/serving-endpoints/chat/completions`).
  const url = new URL("/serving-endpoints", host).toString().replace(/\/$/, "");

  const modelId = await pickModelId(config, requestContext, overrides, user, host);

  return {
    providerId: config.providerId ?? "databricks",
    modelId,
    url,
    headers: Object.fromEntries(headers.entries()),
  };
}
279
+
280
/**
 * Walk the resolution ladder and pick a modelId.
 *
 * 1. **Explicit ask**: the first non-blank value among
 *    - the per-request override (only when `serving.allowOverride`),
 *    - `overrides.modelId` (agent `model` / `config.defaultModel` string),
 *    - `DATABRICKS_SERVING_ENDPOINT_NAME`.
 *
 *    Empty or whitespace-only strings count as unset, so e.g. an empty
 *    env var never becomes the model id. With fuzzy matching off, the
 *    value is returned as-is without touching the catalogue. With it
 *    on, the value is snapped to the closest live endpoint so loose
 *    names like `"claude sonnet"` resolve. If no endpoint matches
 *    within the threshold, the input is used verbatim and Databricks
 *    surfaces the canonical 404.
 *
 * 2. **No explicit ask**: walk `serving.fallbacks` (the configured
 *    `defaultModelFallbacks`, or {@link FALLBACK_MODEL_IDS} when unset).
 *    Return the first id whose endpoint is present in the workspace
 *    listing (see {@link pickFirstAvailable}). A workspace without
 *    Claude Opus still gets a sensible default by skipping ahead to
 *    whichever Sonnet / GPT-5 / Llama variant is wired up.
 *
 * Catalogue fetches fail loud: errors from `listServingEndpoints` are
 * not caught here. Callers therefore see the real SDK message instead
 * of a silent fallback to the top of the priority list.
 *
 * @param config Plugin config the serving options are resolved from.
 * @param requestContext Source of the per-request model override.
 * @param overrides Static model id from the agent / plugin.
 * @param user Active user; its workspace client lists endpoints.
 * @param host Workspace host used to key the endpoint listing.
 * @returns The model id to place on the `MastraModelConfig`.
 */
async function pickModelId(
  config: MastraPluginConfig,
  requestContext: RequestContext,
  overrides: BuildModelOverrides,
  user: User,
  host: string,
): Promise<string> {
  const serving = resolveServingConfig(config, FALLBACK_MODEL_IDS);
  const override = serving.allowOverride
    ? nonBlankString(requestContext.get(MASTRA_MODEL_OVERRIDE_KEY))
    : undefined;
  const explicit =
    override ??
    nonBlankString(overrides.modelId) ??
    nonBlankString(process.env.DATABRICKS_SERVING_ENDPOINT_NAME);

  // Cheap exit: when the caller named a specific model and fuzzy
  // matching is off, there's no reason to touch the catalogue at all.
  if (explicit !== undefined && !serving.fuzzy) return explicit;

  const endpoints = await listServingEndpoints(user.executionContext.client, host, {
    ttlMs: serving.ttlMs,
  });
  return explicit !== undefined
    ? resolveModelId(explicit, endpoints, { threshold: serving.threshold }).modelId
    : pickFirstAvailable(serving.fallbacks, endpoints);
}

/**
 * Normalise a candidate model id: return it trimmed when it is a
 * non-blank string, otherwise `undefined` (so `??` chains skip it).
 */
function nonBlankString(value: unknown): string | undefined {
  if (typeof value !== "string") return undefined;
  const trimmed = value.trim();
  return trimmed === "" ? undefined : trimmed;
}
330
+
331
/**
 * Return the first id in `fallbacks` whose endpoint name appears in
 * `endpoints`.
 *
 * When the workspace serves none of them, return the top fallback. If
 * `fallbacks` is empty, return the top built-in id instead. Callers
 * therefore always receive a string, and an unsupported workspace gets
 * a clean 404 from Databricks rather than a malformed config.
 */
function pickFirstAvailable(
  fallbacks: readonly string[],
  endpoints: readonly ServingEndpointSummary[],
): string {
  const deployed = new Set(endpoints.map((endpoint) => endpoint.name));
  const available = fallbacks.find((candidate) => deployed.has(candidate));
  if (available !== undefined) return available;
  // Nothing deployed matched: prefer the caller's top pick, then ours.
  return fallbacks[0] ?? FALLBACK_MODEL_IDS[0]!;
}
348
+
349
/**
 * Path prefix that identifies a Databricks Model Serving REST call.
 * The trailing slash means only sub-paths match (e.g.
 * `/serving-endpoints/chat/completions`), not the bare prefix.
 */
const SERVING_ENDPOINTS_PATH_PREFIX = "/serving-endpoints/";
351
+
352
/**
 * OpenAI-flavoured chat message shape we need to mutate. We do not
 * import the OpenAI / AI SDK types because both packages keep these
 * fields under internal namespaces; the wire payload is the contract
 * here and it's stable enough to inline.
 *
 * NOTE(review): only the fields that the message sanitizer reads or
 * writes are modelled. On the wire, `content` may presumably also be
 * `null` or an array of content parts, which this type does not
 * capture. Confirm before relying on `content` always being a string.
 */
interface ChatMessage {
  /** Speaker of the message. */
  role: "system" | "user" | "assistant" | "tool";
  /** Message text, when present. */
  content?: string;
  /** Tool invocations requested by an assistant turn. */
  tool_calls?: Array<{ id: string; type: string; function: unknown }>;
  /** On `tool` result messages, presumably the id of the call being answered. */
  tool_call_id?: string;
}
364
+
365
/**
 * Interpret an on/off environment flag. Unset, empty, `0`, `false`,
 * `no` and `off` are off (case-insensitive, surrounding whitespace
 * ignored); any other value is on.
 */
function isEnvFlagEnabled(value: string | undefined): boolean {
  if (value === undefined) return false;
  const normalized = value.trim().toLowerCase();
  return !["", "0", "false", "no", "off"].includes(normalized);
}

/**
 * Install a single shared `globalThis.fetch` wrapper. It inspects every
 * request whose URL path starts with `/serving-endpoints/` and whose
 * `init.body` is a string. Every other request is forwarded untouched.
 * For matching requests the wrapper does two things:
 *
 * 1. Rewrites the outgoing `messages` array to repair Mastra/AI SDK
 *    stream-replay quirks that Databricks-hosted Claude rejects (see
 *    {@link sanitizeServingMessages}). The repair is best-effort: if it
 *    throws, the error is logged (without the body) and the request is
 *    sent with its original body.
 * 2. When `MASTRA_DEBUG_LLM` is set to an "on" value (e.g. `1`, `true`;
 *    but not `0` / `false` / `no` / `off` / empty), dumps the
 *    (post-sanitize) JSON body to stderr. 4xx debugging then doesn't
 *    have to fight AI SDK's `[Array]` formatter. The flag is read once,
 *    at install time.
 *
 * Safe to call from any hot path: {@link commonUtils.memoize} ensures
 * the wrapper is installed at most once per process. Repeated calls
 * from {@link buildModel} on every agent step are no-ops after the
 * first.
 */
const setupFetchInterceptor = commonUtils.memoize((): void => {
  const debug = isEnvFlagEnabled(process.env.MASTRA_DEBUG_LLM);
  const original = globalThis.fetch.bind(globalThis);
  globalThis.fetch = (async (input, init) => {
    const url = httpUtils.toURL(input);
    if (
      !url ||
      !url.pathname.startsWith(SERVING_ENDPOINTS_PATH_PREFIX) ||
      typeof init?.body !== "string"
    ) {
      return original(input, init);
    }
    const body = init.body;
    let rewritten = body;
    try {
      rewritten = rewriteServingBody(body);
    } catch (err) {
      // Never let a sanitizer bug turn into a failed model call.
      console.error(
        "[mastra] serving-endpoint message sanitizer failed; sending original body:",
        err instanceof Error ? err.message : String(err),
      );
    }
    const forwardedInit = rewritten === body ? init : { ...init, body: rewritten };
    if (debug) {
      try {
        console.error("[mastra:llm-debug] -> POST", url.toString());
        console.error(JSON.stringify(JSON.parse(rewritten), null, 2));
      } catch {
        console.error("[mastra:llm-debug] -> POST", url.toString(), "(non-JSON body)");
      }
    }
    return original(input, forwardedInit);
  }) as typeof globalThis.fetch;
});
408
+
409
/**
 * Parse, sanitize, and re-serialize a `/serving-endpoints/...` POST
 * body.
 *
 * Returns the original string verbatim when the body is not JSON, is
 * not a JSON object, has no `messages` array, or needed no rewrite.
 * This lets the caller skip allocating a new `init` object in the
 * common pass-through case.
 *
 * @param body Raw request body as sent by the OpenAI-compatible client.
 * @returns The (possibly re-serialized) body.
 */
function rewriteServingBody(body: string): string {
  let parsed: unknown;
  try {
    parsed = JSON.parse(body);
  } catch {
    return body;
  }
  // `JSON.parse` can return null, a primitive, or an array. Only a plain
  // object can carry `messages`. Without this guard a body of "null"
  // would throw a TypeError from inside the global fetch wrapper.
  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    return body;
  }
  const { messages } = parsed as { messages?: unknown };
  if (!Array.isArray(messages)) return body;
  // Mutates `messages` (and therefore `parsed`) in place.
  const changed = sanitizeServingMessages(messages as ChatMessage[]);
  return changed ? JSON.stringify(parsed) : body;
}
427
+
428
/**
 * Repair a Mastra/AI SDK message replay that Databricks-hosted Claude
 * rejects with `"This model does not support assistant message
 * prefill. The conversation must end with a user message."`.
 *
 * Failure mode: an assistant turn that streams both text and a
 * `tool_call` is persisted as two assistant entries (text-only and
 * tool-call-only). On the next step the tool-call entry is replayed
 * before the tool results and the text entry after them. The
 * conversation therefore ends with a trailing assistant text message,
 * which Anthropic treats as prefill. The Databricks (Bedrock) route
 * refuses prefill.
 *
 * Repair: all three conditions must hold.
 * - The final message is an assistant message without `tool_calls`.
 * - Skipping backwards over any run of `tool` messages lands on an
 *   assistant message.
 * - That assistant message has a non-empty `tool_calls`.
 *
 * When they do, the opener's and the trailing message's trimmed,
 * non-blank texts are joined with a blank line into the opener's
 * `content`, and the trailing message is removed. The result is the
 * canonical `[..., user, assistant(text + tool_calls), tool(...)]`.
 *
 * Mutates `messages` in place.
 *
 * @returns `true` when the array was changed (caller must re-serialize).
 */
function sanitizeServingMessages(messages: ChatMessage[]): boolean {
  const lastIndex = messages.length - 1;
  if (lastIndex < 1) return false;

  const trailing = messages[lastIndex];
  if (!trailing) return false;
  if (trailing.role !== "assistant") return false;
  if (trailing.tool_calls && trailing.tool_calls.length > 0) return false;

  // Step back over the contiguous block of tool results, if any.
  let openerIndex = lastIndex - 1;
  for (; openerIndex >= 0; openerIndex--) {
    if (messages[openerIndex]?.role !== "tool") break;
  }
  if (openerIndex < 0) return false;

  const opener = messages[openerIndex];
  if (!opener) return false;
  if (opener.role !== "assistant") return false;
  if (!opener.tool_calls || opener.tool_calls.length === 0) return false;

  // Keep only non-blank, trimmed fragments so the join introduces no
  // stray whitespace; opener text comes first.
  const fragments: string[] = [];
  for (const raw of [opener.content, trailing.content]) {
    const text = stringUtils.trimToNull(raw);
    if (text !== null) fragments.push(text);
  }
  opener.content = fragments.join("\n\n");
  messages.pop();
  return true;
}