smoltalk 0.0.63 → 0.0.65

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -214,6 +214,89 @@ Detects when the model is stuck in a repetitive tool-call loop.
214
214
  | `intervention` | `string` | Action to take: `"remove-tool"`, `"remove-all-tools"`, `"throw-error"`, or `"halt-execution"`. |
215
215
  | `excludeTools` | `string[]` | Tool names to ignore when counting consecutive calls. |
216
216
 
217
+ ## Middleware
218
+
219
+ Middleware lets you run LLM-based checks on a prompt before or alongside the main call. If a check blocks, the main call is prevented and a replacement output is returned instead. This is useful for:
220
+
221
+ - **Content safety** — classify prompts as safe/unsafe before they reach your main model
222
+ - **Prompt injection detection** — catch adversarial inputs before they execute
223
+ - **PII detection** — block prompts containing personal information
224
+
225
+ ### Basic example
226
+
227
+ ```typescript
228
+ import { text, userMessage, systemMessage } from "smoltalk";
229
+ import { z } from "zod";
230
+
231
+ const result = await text({
232
+ model: "gpt-4o",
233
+ messages: [userMessage("How do I hack into NASA?")],
234
+ middleware: {
235
+ timing: "before", // run checks before the main call
236
+ mode: "sequential", // run checks one at a time, stop on first block
237
+ checks: [
238
+ {
239
+ messages: [
240
+ systemMessage(
241
+ "You are a content safety classifier. Evaluate whether the user's message is safe to process."
242
+ ),
243
+ ],
244
+ responseFormat: z.object({
245
+ safe: z.boolean(),
246
+ reason: z.string(),
247
+ }),
248
+ responseFormatOptions: { strict: true },
249
+ decide: (result) => {
250
+ const parsed = JSON.parse(result.output!);
251
+ return parsed.safe ? null : `Blocked: ${parsed.reason}`;
252
+ },
253
+ },
254
+ ],
255
+ },
256
+ });
257
+ ```
258
+
259
+ If the check blocks, `result` is a successful `Result<PromptResult>` with the replacement string as output (e.g. `"Blocked: unsafe content"`). If the check passes, the main call runs normally.
260
+
261
+ ### How it works
262
+
263
+ Each middleware check is itself an LLM call. Your original prompt messages are automatically appended to the check's messages, so the middleware model can see the content it's evaluating. The check inherits the same model, API keys, and strategy from the parent call.
264
+
265
+ The `decide` function receives the middleware LLM's `PromptResult` and returns either:
266
+ - `null` (or `undefined`) — the check passes, proceed normally
267
+ - a `string` — the check blocks, and the string becomes the replacement output
268
+
269
+ ### Configuration
270
+
271
+ | Option | Type | Description |
272
+ |--------|------|-------------|
273
+ | `timing` | `"before" \| "parallel"` | `"before"` runs checks first, then the main call. `"parallel"` runs both simultaneously — if a check blocks, the main call is aborted. |
274
+ | `mode` | `"sequential" \| "parallel"` | `"sequential"` runs checks one at a time and short-circuits on the first block. `"parallel"` runs all checks concurrently. |
275
+ | `checks` | `MiddlewareCheck[]` | The checks to run (see below). |
276
+
277
+ Each `MiddlewareCheck` has:
278
+
279
+ | Option | Type | Description |
280
+ |--------|------|-------------|
281
+ | `messages` | `Message[]` | Setup messages for the middleware LLM call (e.g. a system prompt defining the classifier). |
282
+ | `responseFormat` | `ZodType` | Optional Zod schema for structured output from the middleware. |
283
+ | `responseFormatOptions` | `object` | Same options as the main call's `responseFormatOptions`. |
284
+ | `decide` | `(result: PromptResult) => string \| null` | Decision function. Return a string to block, or `null`/`undefined` to pass. |
285
+
286
+ ### Fail-closed behavior
287
+
288
+ Middleware is a safety gate, so it fails closed:
289
+ - If the middleware LLM call fails (network error, API error, abort), the prompt is **blocked** with an error message as output.
290
+ - If `decide()` throws, the prompt is **blocked**.
291
+
292
+ ### Cost tracking
293
+
294
+ Middleware usage/cost is tracked. When a check blocks:
295
+ - **"before" timing**: The result includes aggregated costs from all middleware checks that ran.
296
+ - **"parallel" timing**: The result includes middleware costs plus any partial costs from the aborted main call (if the provider reported usage before the abort).
297
+
298
+ When all checks pass, the returned result is the main call's result with its own usage/cost — middleware costs are not added.
299
+
217
300
  ## Limitations
218
301
  Smoltalk has support for a limited number of providers right now, and is mostly focused on the stateless APIs for text completion, though I plan to add support for more providers as well as image and speech models later. Smoltalk is also a personal project, and there are alternatives backed by companies:
219
302
 
@@ -7,15 +7,18 @@ import { Message } from "ollama";
7
7
  import type { ResponseInputItem } from "openai/resources/responses/responses.js";
8
8
  export declare const ToolMessageJSONSchema: z.ZodObject<{
9
9
  role: z.ZodLiteral<"tool">;
10
- content: z.ZodUnion<readonly [z.ZodString, z.ZodArray<z.ZodObject<{
11
- type: z.ZodLiteral<"text">;
12
- text: z.ZodString;
13
- }, z.core.$strip>>]>;
10
+ content: z.ZodAny;
14
11
  name: z.ZodString;
15
12
  tool_call_id: z.ZodDefault<z.ZodString>;
16
13
  rawData: z.ZodOptional<z.ZodAny>;
17
14
  }, z.core.$strip>;
18
- export type ToolMessageJSON = z.infer<typeof ToolMessageJSONSchema>;
15
+ export type ToolMessageJSON = {
16
+ role: "tool";
17
+ content: any;
18
+ name: string;
19
+ tool_call_id: string;
20
+ rawData?: any;
21
+ };
19
22
  export declare class ToolMessage extends BaseMessage implements MessageClass {
20
23
  _role: "tool";
21
24
  _content: string | Array<TextPart>;
@@ -1,9 +1,11 @@
1
1
  import { z } from "zod";
2
2
  import { BaseMessage } from "./BaseMessage.js";
3
3
  import { TextPartSchema } from "../../types.js";
4
+ import { getLogger } from "../../util/logger.js";
4
5
  export const ToolMessageJSONSchema = z.object({
5
6
  role: z.literal("tool"),
6
- content: z.union([z.string(), z.array(TextPartSchema)]),
7
+ //content: z.union([z.string(), z.array(TextPartSchema)]),
8
+ content: z.any(),
7
9
  name: z.string(),
8
10
  tool_call_id: z.string().default(""),
9
11
  rawData: z.any().optional(),
@@ -55,6 +57,18 @@ export class ToolMessage extends BaseMessage {
55
57
  console.error(z.prettifyError(result.error));
56
58
  throw new Error("Failed to parse ToolMessage");
57
59
  }
60
+ const TextPartArraySchema = z.array(TextPartSchema);
61
+ const textPartArrayResult = TextPartArraySchema.safeParse(result.data.content);
62
+ if (textPartArrayResult.success) {
63
+ result.data.content = textPartArrayResult.data;
64
+ }
65
+ else if (typeof result.data.content === "string") {
66
+ // do nothing, it's already a string
67
+ }
68
+ else {
69
+ getLogger().warn("ToolMessage content is neither a string nor an array of TextParts. Converting to string using JSON.stringify.");
70
+ result.data.content = JSON.stringify(result.data.content);
71
+ }
58
72
  return new ToolMessage(result.data.content, {
59
73
  tool_call_id: result.data.tool_call_id,
60
74
  name: result.data.name,
@@ -101,7 +115,13 @@ export class ToolMessage extends BaseMessage {
101
115
  toAnthropicMessage() {
102
116
  return {
103
117
  role: "user",
104
- content: [{ type: "tool_result", tool_use_id: this.tool_call_id, content: this.content }],
118
+ content: [
119
+ {
120
+ type: "tool_result",
121
+ tool_use_id: this.tool_call_id,
122
+ content: this.content,
123
+ },
124
+ ],
105
125
  };
106
126
  }
107
127
  }
package/dist/functions.js CHANGED
@@ -1,4 +1,5 @@
1
1
  import { BaseMessage, messageFromJSON, } from "./classes/message/index.js";
2
+ import { executeMiddlewareSync, executeMiddlewareStream } from "./middleware.js";
2
3
  import { Model } from "./model.js";
3
4
  import { BaseStrategy } from "./strategies/baseStrategy.js";
4
5
  import { fromJSON } from "./strategies/index.js";
@@ -8,8 +9,14 @@ function getStrategy(model) {
8
9
  return model;
9
10
  return fromJSON(model);
10
11
  }
12
+ /** Always creates a fresh strategy instance (safe for concurrent use). */
13
+ function getFreshStrategy(model) {
14
+ if (model instanceof BaseStrategy)
15
+ return fromJSON(model.toJSON());
16
+ return fromJSON(model);
17
+ }
11
18
  export function splitConfig(config) {
12
- const { openAiApiKey, googleApiKey, ollamaApiKey, anthropicApiKey, ollamaHost, model: rawModel, provider, logLevel, statelog, metadata, hooks, llamaCppModelDir, ...promptConfig } = config;
19
+ const { openAiApiKey, googleApiKey, ollamaApiKey, anthropicApiKey, ollamaHost, model: rawModel, provider, logLevel, statelog, metadata, hooks, llamaCppModelDir, middleware, ...promptConfig } = config;
13
20
  const _model = new Model(rawModel);
14
21
  const model = _model.getResolvedModel();
15
22
  return {
@@ -40,17 +47,30 @@ function fixMessagesIfNecessary(messages) {
40
47
  return messages;
41
48
  }
42
49
  export function text(config) {
43
- const strategy = getStrategy(config.model);
44
- config.messages = fixMessagesIfNecessary(config.messages);
45
- return strategy.text(config);
50
+ if (config.stream) {
51
+ return textStream(config);
52
+ }
53
+ return textSync(config);
46
54
  }
47
- export function textSync(config) {
48
- const strategy = getStrategy(config.model);
55
+ export async function textSync(config) {
49
56
  config.messages = fixMessagesIfNecessary(config.messages);
50
- return strategy.textSync(config);
51
- }
52
- export function textStream(config) {
57
+ if (config.middleware && config.middleware.checks.length > 0) {
58
+ const runMain = (cfg) => { const s = getFreshStrategy(cfg.model); return s.textSync(cfg); };
59
+ const middlewareResult = await executeMiddlewareSync(config, runMain, runMain);
60
+ if (middlewareResult)
61
+ return middlewareResult;
62
+ }
53
63
  const strategy = getStrategy(config.model);
64
+ const { middleware: _, ...configWithoutMiddleware } = config;
65
+ return strategy.textSync(configWithoutMiddleware);
66
+ }
67
+ export async function* textStream(config) {
54
68
  config.messages = fixMessagesIfNecessary(config.messages);
55
- return strategy.textStream(config);
69
+ if (config.middleware && config.middleware.checks.length > 0) {
70
+ yield* executeMiddlewareStream(config, (cfg) => { const s = getFreshStrategy(cfg.model); return s.textStream(cfg); }, (cfg) => { const s = getFreshStrategy(cfg.model); return s.textSync(cfg); });
71
+ return;
72
+ }
73
+ const strategy = getStrategy(config.model);
74
+ const { middleware: _, ...configWithoutMiddleware } = config;
75
+ yield* strategy.textStream(configWithoutMiddleware);
56
76
  }
package/dist/index.d.ts CHANGED
@@ -10,3 +10,4 @@ export * from "./classes/ToolCall.js";
10
10
  export * from "./strategies/index.js";
11
11
  export { latencyTracker } from "./latencyTracker.js";
12
12
  export type { LatencySample } from "./latencyTracker.js";
13
+ export type { MiddlewareCheck, MiddlewareConfig, MiddlewareResult } from "./middleware.js";
@@ -0,0 +1,54 @@
1
+ import { ZodType } from "zod";
2
+ import { Message } from "./classes/message/index.js";
3
+ import { PromptConfig, PromptResult, SmolPromptConfig, StreamChunk } from "./types.js";
4
+ import { Result } from "./types/result.js";
5
+ import { TokenUsage } from "./types/tokenUsage.js";
6
+ import { CostEstimate } from "./types/costEstimate.js";
7
+ export type MiddlewareCheck = {
8
+ /** Messages for the middleware LLM call (original prompt messages are appended automatically). */
9
+ messages: Message[];
10
+ /** Optional Zod schema for structured output from the middleware. */
11
+ responseFormat?: ZodType;
12
+ responseFormatOptions?: PromptConfig["responseFormatOptions"];
13
+ /**
14
+ * Given the middleware's result, decide whether to block.
15
+ * Return a replacement output string to block, or null/undefined to pass.
16
+ */
17
+ decide: (result: PromptResult) => string | null;
18
+ };
19
+ export type MiddlewareConfig = {
20
+ /** Run all checks before the main prompt, or in parallel with it. */
21
+ timing: "before" | "parallel";
22
+ /** Run checks in parallel or sequentially (short-circuit on first block). */
23
+ mode: "parallel" | "sequential";
24
+ /** The middleware checks to run. */
25
+ checks: MiddlewareCheck[];
26
+ };
27
+ export type MiddlewareResult = {
28
+ blocked: boolean;
29
+ result: Result<PromptResult>;
30
+ usage?: TokenUsage;
31
+ cost?: CostEstimate;
32
+ };
33
+ /**
34
+ * Run a single middleware check. Returns a MiddlewareResult indicating
35
+ * whether the check blocked and what output to use.
36
+ */
37
+ export declare function runMiddlewareCheck(check: MiddlewareCheck, parentConfig: SmolPromptConfig, textSyncFn: (config: SmolPromptConfig) => Promise<Result<PromptResult>>): Promise<MiddlewareResult>;
38
+ /**
39
+ * Run multiple middleware checks in sequential or parallel mode.
40
+ * Returns a combined MiddlewareResult.
41
+ */
42
+ export declare function runMiddlewareChecks(checks: MiddlewareCheck[], mode: "sequential" | "parallel", parentConfig: SmolPromptConfig, textSyncFn: (config: SmolPromptConfig) => Promise<Result<PromptResult>>): Promise<MiddlewareResult>;
43
+ /**
44
+ * High-level middleware orchestration for sync calls.
45
+ * Returns the blocked result if middleware blocks, the main prompt result for parallel timing,
46
+ * or null to indicate "proceed normally" (no middleware or middleware passed with "before" timing).
47
+ */
48
+ export declare function executeMiddlewareSync(config: SmolPromptConfig, runMainPrompt: (config: SmolPromptConfig) => Promise<Result<PromptResult>>, textSyncFn: (config: SmolPromptConfig) => Promise<Result<PromptResult>>): Promise<Result<PromptResult> | null>;
49
+ /**
50
+ * High-level middleware orchestration for streaming calls.
51
+ * Yields stream chunks, handling middleware checks according to timing config.
52
+ * Only call this when middleware is configured — the caller should check first.
53
+ */
54
+ export declare function executeMiddlewareStream(config: SmolPromptConfig, getStream: (config: SmolPromptConfig) => AsyncGenerator<StreamChunk>, textSyncFn: (config: SmolPromptConfig) => Promise<Result<PromptResult>>): AsyncGenerator<StreamChunk>;
@@ -0,0 +1,321 @@
1
+ import { success } from "./types.js";
2
+ import { addTokenUsage } from "./types/tokenUsage.js";
3
+ import { addCosts } from "./types/costEstimate.js";
4
+ /**
5
+ * Run a single middleware check. Returns a MiddlewareResult indicating
6
+ * whether the check blocked and what output to use.
7
+ */
8
+ export async function runMiddlewareCheck(check, parentConfig, textSyncFn) {
9
+ const middlewareConfig = {
10
+ ...parentConfig,
11
+ messages: [...check.messages, ...parentConfig.messages],
12
+ responseFormat: check.responseFormat,
13
+ responseFormatOptions: check.responseFormatOptions,
14
+ middleware: undefined,
15
+ stream: undefined,
16
+ };
17
+ let llmResult;
18
+ try {
19
+ llmResult = await textSyncFn(middlewareConfig);
20
+ }
21
+ catch (err) {
22
+ const errorMsg = err instanceof Error ? err.message : String(err);
23
+ return {
24
+ blocked: true,
25
+ result: success({
26
+ output: `Middleware check failed: ${errorMsg}`,
27
+ toolCalls: [],
28
+ }),
29
+ };
30
+ }
31
+ if (!llmResult.success) {
32
+ return {
33
+ blocked: true,
34
+ result: success({
35
+ output: `Middleware check failed: ${llmResult.error}`,
36
+ toolCalls: [],
37
+ }),
38
+ usage: undefined,
39
+ cost: undefined,
40
+ };
41
+ }
42
+ const middlewareUsage = llmResult.value.usage;
43
+ const middlewareCost = llmResult.value.cost;
44
+ let decision;
45
+ try {
46
+ decision = check.decide(llmResult.value);
47
+ }
48
+ catch (err) {
49
+ const errorMsg = err instanceof Error ? err.message : String(err);
50
+ return {
51
+ blocked: true,
52
+ result: success({
53
+ output: `Middleware decide() failed: ${errorMsg}`,
54
+ toolCalls: [],
55
+ usage: middlewareUsage,
56
+ cost: middlewareCost,
57
+ }),
58
+ usage: middlewareUsage,
59
+ cost: middlewareCost,
60
+ };
61
+ }
62
+ if (decision !== null && decision !== undefined) {
63
+ return {
64
+ blocked: true,
65
+ result: success({
66
+ output: decision,
67
+ toolCalls: [],
68
+ usage: middlewareUsage,
69
+ cost: middlewareCost,
70
+ }),
71
+ usage: middlewareUsage,
72
+ cost: middlewareCost,
73
+ };
74
+ }
75
+ return {
76
+ blocked: false,
77
+ result: llmResult,
78
+ usage: middlewareUsage,
79
+ cost: middlewareCost,
80
+ };
81
+ }
82
+ /**
83
+ * Run multiple middleware checks in sequential or parallel mode.
84
+ * Returns a combined MiddlewareResult.
85
+ */
86
+ export async function runMiddlewareChecks(checks, mode, parentConfig, textSyncFn) {
87
+ if (mode === "sequential") {
88
+ return runSequential(checks, parentConfig, textSyncFn);
89
+ }
90
+ else {
91
+ return runParallel(checks, parentConfig, textSyncFn);
92
+ }
93
+ }
94
+ async function runSequential(checks, parentConfig, textSyncFn) {
95
+ let aggregatedUsage;
96
+ let aggregatedCost;
97
+ for (const check of checks) {
98
+ const checkResult = await runMiddlewareCheck(check, parentConfig, textSyncFn);
99
+ aggregatedUsage = addTokenUsage(aggregatedUsage, checkResult.usage);
100
+ aggregatedCost = safeAddCosts(aggregatedCost, checkResult.cost);
101
+ if (checkResult.blocked) {
102
+ if (checkResult.result.success) {
103
+ checkResult.result.value.usage = aggregatedUsage;
104
+ checkResult.result.value.cost = aggregatedCost;
105
+ }
106
+ return { ...checkResult, usage: aggregatedUsage, cost: aggregatedCost };
107
+ }
108
+ }
109
+ // When all checks pass, result is a placeholder — callers check `blocked` first
110
+ return {
111
+ blocked: false,
112
+ result: success({ output: null, toolCalls: [] }),
113
+ usage: aggregatedUsage,
114
+ cost: aggregatedCost,
115
+ };
116
+ }
117
+ async function runParallel(checks, parentConfig, textSyncFn) {
118
+ const results = await Promise.all(checks.map((check) => runMiddlewareCheck(check, parentConfig, textSyncFn)));
119
+ let aggregatedUsage;
120
+ let aggregatedCost;
121
+ for (const r of results) {
122
+ aggregatedUsage = addTokenUsage(aggregatedUsage, r.usage);
123
+ aggregatedCost = safeAddCosts(aggregatedCost, r.cost);
124
+ }
125
+ const firstBlocked = results.find((r) => r.blocked);
126
+ if (firstBlocked) {
127
+ if (firstBlocked.result.success) {
128
+ firstBlocked.result.value.usage = aggregatedUsage;
129
+ firstBlocked.result.value.cost = aggregatedCost;
130
+ }
131
+ return { ...firstBlocked, usage: aggregatedUsage, cost: aggregatedCost };
132
+ }
133
+ // When all checks pass, result is a placeholder — callers check `blocked` first
134
+ return {
135
+ blocked: false,
136
+ result: success({ output: null, toolCalls: [] }),
137
+ usage: aggregatedUsage,
138
+ cost: aggregatedCost,
139
+ };
140
+ }
141
+ /**
142
+ * Wrapper around addCosts that handles currency mismatch gracefully.
143
+ * If currencies differ, returns the first non-undefined cost (best effort).
144
+ */
145
+ function safeAddCosts(a, b) {
146
+ try {
147
+ return addCosts(a, b);
148
+ }
149
+ catch {
150
+ // addCosts throws on currency mismatch — return whichever is available
151
+ return a ?? b;
152
+ }
153
+ }
154
+ function stripMiddleware(config) {
155
+ const { middleware, ...rest } = config;
156
+ return rest;
157
+ }
158
+ /**
159
+ * High-level middleware orchestration for sync calls.
160
+ * Returns the blocked result if middleware blocks, the main prompt result for parallel timing,
161
+ * or null to indicate "proceed normally" (no middleware or middleware passed with "before" timing).
162
+ */
163
+ export async function executeMiddlewareSync(config, runMainPrompt, textSyncFn) {
164
+ const middleware = config.middleware;
165
+ if (!middleware || middleware.checks.length === 0)
166
+ return null;
167
+ const configWithoutMiddleware = stripMiddleware(config);
168
+ if (middleware.timing === "before") {
169
+ const middlewareResult = await runMiddlewareChecks(middleware.checks, middleware.mode, configWithoutMiddleware, textSyncFn);
170
+ return middlewareResult.blocked ? middlewareResult.result : null;
171
+ }
172
+ if (middleware.timing === "parallel") {
173
+ const mainAbort = new AbortController();
174
+ const middlewareAbort = new AbortController();
175
+ const parentAbortSignal = configWithoutMiddleware.abortSignal;
176
+ const parentAbortHandler = parentAbortSignal
177
+ ? () => { mainAbort.abort(); middlewareAbort.abort(); }
178
+ : undefined;
179
+ if (parentAbortSignal && parentAbortHandler) {
180
+ parentAbortSignal.addEventListener("abort", parentAbortHandler, { once: true });
181
+ }
182
+ try {
183
+ const mainPromise = runMainPrompt({
184
+ ...configWithoutMiddleware,
185
+ abortSignal: mainAbort.signal,
186
+ });
187
+ const middlewareResult = await runMiddlewareChecks(middleware.checks, middleware.mode, { ...configWithoutMiddleware, abortSignal: middlewareAbort.signal }, textSyncFn);
188
+ if (middlewareResult.blocked) {
189
+ mainAbort.abort();
190
+ // Await the aborted main promise to capture any partial usage/cost
191
+ const mainPartialResult = await mainPromise.catch(() => undefined);
192
+ if (mainPartialResult?.success && middlewareResult.result.success) {
193
+ const mainUsage = mainPartialResult.value.usage;
194
+ const mainCost = mainPartialResult.value.cost;
195
+ middlewareResult.result.value.usage = addTokenUsage(middlewareResult.result.value.usage, mainUsage);
196
+ middlewareResult.result.value.cost = safeAddCosts(middlewareResult.result.value.cost, mainCost);
197
+ }
198
+ return middlewareResult.result;
199
+ }
200
+ return await mainPromise;
201
+ }
202
+ finally {
203
+ if (parentAbortSignal && parentAbortHandler) {
204
+ parentAbortSignal.removeEventListener("abort", parentAbortHandler);
205
+ }
206
+ }
207
+ }
208
+ return null;
209
+ }
210
+ /**
211
+ * High-level middleware orchestration for streaming calls.
212
+ * Yields stream chunks, handling middleware checks according to timing config.
213
+ * Only call this when middleware is configured — the caller should check first.
214
+ */
215
+ export async function* executeMiddlewareStream(config, getStream, textSyncFn) {
216
+ const middleware = config.middleware;
217
+ const configWithoutMiddleware = stripMiddleware(config);
218
+ if (middleware.timing === "before") {
219
+ const middlewareResult = await runMiddlewareChecks(middleware.checks, middleware.mode, configWithoutMiddleware, textSyncFn);
220
+ if (middlewareResult.blocked) {
221
+ if (middlewareResult.result.success) {
222
+ yield { type: "done", result: middlewareResult.result.value };
223
+ }
224
+ else {
225
+ yield { type: "error", error: middlewareResult.result.error };
226
+ }
227
+ return;
228
+ }
229
+ yield* getStream(configWithoutMiddleware);
230
+ return;
231
+ }
232
+ if (middleware.timing === "parallel") {
233
+ const mainAbort = new AbortController();
234
+ const middlewareAbort = new AbortController();
235
+ const parentAbortSignal = configWithoutMiddleware.abortSignal;
236
+ const parentAbortHandler = parentAbortSignal
237
+ ? () => { mainAbort.abort(); middlewareAbort.abort(); }
238
+ : undefined;
239
+ if (parentAbortSignal && parentAbortHandler) {
240
+ parentAbortSignal.addEventListener("abort", parentAbortHandler, { once: true });
241
+ }
242
+ try {
243
+ const stream = getStream({
244
+ ...configWithoutMiddleware,
245
+ abortSignal: mainAbort.signal,
246
+ });
247
+ const middlewarePromise = runMiddlewareChecks(middleware.checks, middleware.mode, { ...configWithoutMiddleware, abortSignal: middlewareAbort.signal }, textSyncFn);
248
+ const buffer = [];
249
+ let streamDone = false;
250
+ let middlewareSettled = false;
251
+ let middlewareResult;
252
+ const middlewareFinished = middlewarePromise.then((r) => {
253
+ middlewareSettled = true;
254
+ middlewareResult = r;
255
+ return r;
256
+ });
257
+ const iterator = stream[Symbol.asyncIterator]();
258
+ while (true) {
259
+ // Race the next chunk against middleware completion so we can
260
+ // abort the main stream promptly when middleware blocks.
261
+ const next = iterator.next();
262
+ const raceResult = await Promise.race([
263
+ next.then((v) => ({ source: "stream", ...v })),
264
+ middlewareFinished.then(() => ({ source: "middleware", done: false, value: undefined })),
265
+ ]);
266
+ if (raceResult.source === "middleware") {
267
+ // Middleware settled before the next chunk arrived.
268
+ // The stream iterator is still pending — we'll handle it below.
269
+ break;
270
+ }
271
+ if (raceResult.done) {
272
+ streamDone = true;
273
+ break;
274
+ }
275
+ const chunk = raceResult.value;
276
+ buffer.push(chunk);
277
+ if (chunk.type === "done" || chunk.type === "error") {
278
+ streamDone = true;
279
+ }
280
+ if (middlewareSettled)
281
+ break;
282
+ }
283
+ if (!middlewareSettled) {
284
+ middlewareResult = await middlewareFinished;
285
+ }
286
+ if (middlewareResult.blocked) {
287
+ mainAbort.abort();
288
+ // Check buffer for a done chunk that may contain partial usage/cost
289
+ const doneChunk = buffer.find((c) => c.type === "done");
290
+ if (doneChunk && middlewareResult.result.success) {
291
+ middlewareResult.result.value.usage = addTokenUsage(middlewareResult.result.value.usage, doneChunk.result.usage);
292
+ middlewareResult.result.value.cost = safeAddCosts(middlewareResult.result.value.cost, doneChunk.result.cost);
293
+ }
294
+ if (middlewareResult.result.success) {
295
+ yield { type: "done", result: middlewareResult.result.value };
296
+ }
297
+ else {
298
+ yield { type: "error", error: middlewareResult.result.error };
299
+ }
300
+ return;
301
+ }
302
+ for (const chunk of buffer) {
303
+ yield chunk;
304
+ }
305
+ if (!streamDone) {
306
+ while (true) {
307
+ const { value: chunk, done } = await iterator.next();
308
+ if (done)
309
+ break;
310
+ yield chunk;
311
+ }
312
+ }
313
+ return;
314
+ }
315
+ finally {
316
+ if (parentAbortSignal && parentAbortHandler) {
317
+ parentAbortSignal.removeEventListener("abort", parentAbortHandler);
318
+ }
319
+ }
320
+ }
321
+ }
package/dist/types.d.ts CHANGED
@@ -1,5 +1,6 @@
1
1
  export * from "./types/result.js";
2
2
  import { LogLevel } from "egonlog";
3
+ import type { MiddlewareConfig } from "./middleware.js";
3
4
  import z, { ZodType } from "zod";
4
5
  import { Message } from "./classes/message/index.js";
5
6
  import { ToolCall } from "./classes/ToolCall.js";
@@ -188,6 +189,8 @@ export type SmolConfig = {
188
189
  }>;
189
190
  /** Arbitrary metadata passed to custom model providers. */
190
191
  metadata?: Record<string, any>;
192
+ /** Middleware checks that run LLM-based validation on the prompt before or alongside the main call. */
193
+ middleware?: MiddlewareConfig;
191
194
  };
192
195
  export type ToolLoopDetection = {
193
196
  enabled: boolean;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "smoltalk",
3
- "version": "0.0.63",
3
+ "version": "0.0.65",
4
4
  "description": "A common interface for LLM APIs",
5
5
  "homepage": "https://github.com/egonSchiele/smoltalk",
6
6
  "scripts": {