@hebo-ai/gateway 0.3.0 → 0.4.0-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -288,11 +288,23 @@ const gw = gateway({
288
288
  * @returns Optional RequestPatch to merge into headers / override body.
289
289
  * Returning a Response stops execution of the endpoint.
290
290
  */
291
- before: async (ctx: { request: Request }): Promise<RequestPatch | Response | void> => {
291
+ onRequest: async (ctx: { request: Request }): Promise<RequestPatch | Response | void> => {
292
292
  // Example Use Cases:
293
- // - Transform request body
294
293
  // - Verify authentication
295
294
  // - Enforce rate limits
295
+ return undefined;
296
+ },
297
+ /**
298
+ * Runs after body is parsed & validated.
299
+ * @param ctx.body Parsed request body.
300
+ * @returns Replacement parsed body, or undefined to keep the original body.
301
+ */
302
+ before: async (ctx: {
303
+ body: ChatCompletionsBody | EmbeddingsBody;
304
+ operation: "text" | "embeddings";
305
+ }): Promise<ChatCompletionsBody | EmbeddingsBody | void> => {
306
+ // Example Use Cases:
307
+ // - Transform request body
296
308
  // - Observability integration
297
309
  return undefined;
298
310
  },
@@ -337,18 +349,29 @@ const gw = gateway({
337
349
  * @returns Modified result, or undefined to keep original.
338
350
  */
339
351
  after: async (ctx: {
340
- result: object | ReadableStream<Uint8Array>
341
- }): Promise<object | ReadableStream<Uint8Array> | void> => {
352
+ result: ChatCompletions | ReadableStream<ChatCompletionsChunk | OpenAIError> | Embeddings | object
353
+ }): Promise<ChatCompletions | ReadableStream<ChatCompletionsChunk | OpenAIError> | Embeddings | object | void> => {
342
354
  // Example Use Cases:
343
355
  // - Transform result
344
356
  // - Result logging
345
357
  return undefined;
346
358
  },
359
+ /**
360
+ * Runs after the gateway has produced the final Response.
361
+ * @param ctx.response Response object returned by the lifecycle.
362
+ * @returns Replacement response, or undefined to keep original.
363
+ */
364
+ onResponse: async (ctx: { response: Response }): Promise<Response | void> => {
365
+ // Example Use Cases:
366
+ // - Add response headers
367
+ // - Replace or redact response payload
368
+ return undefined;
369
+ },
347
370
  },
348
371
  });
349
372
  ```
350
373
 
351
- The `ctx` object is **readonly for core fields**. Use return values to override request / result and to provide modelId / provider instances.
374
+ The `ctx` object is **readonly for core fields**. Use return values to override request / parsed body / result / response and to provide modelId / provider instances.
352
375
 
353
376
  > [!TIP]
354
377
  > To pass data between hooks, use `ctx.state`. It’s a per-request mutable bag in which you can stash things like auth info, routing decisions, timers, or trace IDs and read them later in any of the other hooks.
@@ -25,7 +25,7 @@ export declare const convertToToolSet: (tools: ChatCompletionsTool[] | undefined
25
25
  export declare const convertToToolChoice: (toolChoice: ChatCompletionsToolChoice | undefined) => ToolChoice<ToolSet> | undefined;
26
26
  export declare function toChatCompletions(result: GenerateTextResult<ToolSet, Output.Output>, model: string): ChatCompletions;
27
27
  export declare function toChatCompletionsResponse(result: GenerateTextResult<ToolSet, Output.Output>, model: string, responseInit?: ResponseInit): Response;
28
- export declare function toChatCompletionsStream(result: StreamTextResult<ToolSet, Output.Output>, model: string): ReadableStream<Uint8Array>;
28
+ export declare function toChatCompletionsStream(result: StreamTextResult<ToolSet, Output.Output>, model: string): ReadableStream<ChatCompletionsChunk | OpenAIError>;
29
29
  export declare function toChatCompletionsStreamResponse(result: StreamTextResult<ToolSet, Output.Output>, model: string, responseInit?: ResponseInit): Response;
30
30
  export declare class ChatCompletionsStream extends TransformStream<TextStreamPart<ToolSet>, ChatCompletionsChunk | OpenAIError> {
31
31
  constructor(model: string);
@@ -1,5 +1,5 @@
1
1
  import { convertBase64ToUint8Array } from "@ai-sdk/provider-utils";
2
- import { jsonSchema, JsonToSseTransformStream, tool } from "ai";
2
+ import { jsonSchema, tool } from "ai";
3
3
  import { GatewayError } from "../../errors/gateway";
4
4
  import { OpenAIError, toOpenAIError } from "../../errors/openai";
5
5
  import { toResponse } from "../../utils/response";
@@ -241,10 +241,7 @@ export function toChatCompletionsResponse(result, model, responseInit) {
241
241
  return toResponse(toChatCompletions(result, model), responseInit);
242
242
  }
243
243
  export function toChatCompletionsStream(result, model) {
244
- return result.fullStream
245
- .pipeThrough(new ChatCompletionsStream(model))
246
- .pipeThrough(new JsonToSseTransformStream())
247
- .pipeThrough(new TextEncoderStream());
244
+ return result.fullStream.pipeThrough(new ChatCompletionsStream(model));
248
245
  }
249
246
  export function toChatCompletionsStreamResponse(result, model, responseInit) {
250
247
  return toResponse(toChatCompletionsStream(result, model), responseInit);
@@ -302,6 +299,7 @@ export class ChatCompletionsStream extends TransformStream {
302
299
  }
303
300
  case "error": {
304
301
  const error = part.error;
302
+ // FUTURE: mask in production mode and return responseID
305
303
  controller.enqueue(toOpenAIError(error));
306
304
  break;
307
305
  }
@@ -30,13 +30,14 @@ export const chatCompletions = (config) => {
30
30
  throw new GatewayError(z.prettifyError(parsed.error), 400);
31
31
  }
32
32
  ctx.body = parsed.data;
33
+ ctx.operation = "text";
34
+ ctx.body = (await hooks?.before?.(ctx)) ?? ctx.body;
33
35
  // Resolve model + provider (hooks may override defaults).
34
36
  let inputs, stream;
35
- ({ model: ctx.modelId, stream, ...inputs } = parsed.data);
37
+ ({ model: ctx.modelId, stream, ...inputs } = ctx.body);
36
38
  ctx.resolvedModelId =
37
39
  (await hooks?.resolveModelId?.(ctx)) ?? ctx.modelId;
38
40
  logger.debug(`[chat] resolved ${ctx.modelId} to ${ctx.resolvedModelId}`);
39
- ctx.operation = "text";
40
41
  const override = await hooks?.resolveProvider?.(ctx);
41
42
  ctx.provider =
42
43
  override ??
@@ -79,7 +80,7 @@ export const chatCompletions = (config) => {
79
80
  throw new DOMException("Upstream failed", "AbortError");
80
81
  },
81
82
  timeout: {
82
- chunkMs: 5 * 60 * 1000,
83
+ totalMs: 5 * 60 * 1000,
83
84
  },
84
85
  experimental_include: {
85
86
  requestBody: false,
@@ -88,7 +89,8 @@ export const chatCompletions = (config) => {
88
89
  ...textOptions,
89
90
  });
90
91
  markPerf(ctx.request, "aiSdkEnd");
91
- return toChatCompletionsStream(result, ctx.modelId);
92
+ ctx.result = toChatCompletionsStream(result, ctx.modelId);
93
+ return (await hooks?.after?.(ctx)) ?? ctx.result;
92
94
  }
93
95
  const result = await generateText({
94
96
  model: languageModelWithMiddleware,
@@ -104,7 +106,8 @@ export const chatCompletions = (config) => {
104
106
  });
105
107
  markPerf(ctx.request, "aiSdkEnd");
106
108
  logger.trace({ requestId: resolveRequestId(ctx.request), result }, "[chat] AI SDK result");
107
- return toChatCompletions(result, ctx.modelId);
109
+ ctx.result = toChatCompletions(result, ctx.modelId);
110
+ return (await hooks?.after?.(ctx)) ?? ctx.result;
108
111
  };
109
112
  return { handler: winterCgHandler(handler, config) };
110
113
  };
@@ -30,13 +30,14 @@ export const embeddings = (config) => {
30
30
  throw new GatewayError(z.prettifyError(parsed.error), 400);
31
31
  }
32
32
  ctx.body = parsed.data;
33
+ ctx.operation = "embeddings";
34
+ ctx.body = (await hooks?.before?.(ctx)) ?? ctx.body;
33
35
  // Resolve model + provider (hooks may override defaults).
34
36
  let inputs;
35
- ({ model: ctx.modelId, ...inputs } = parsed.data);
37
+ ({ model: ctx.modelId, ...inputs } = ctx.body);
36
38
  ctx.resolvedModelId =
37
39
  (await hooks?.resolveModelId?.(ctx)) ?? ctx.modelId;
38
40
  logger.debug(`[embeddings] resolved ${ctx.modelId} to ${ctx.resolvedModelId}`);
39
- ctx.operation = "embeddings";
40
41
  const override = await hooks?.resolveProvider?.(ctx);
41
42
  ctx.provider =
42
43
  override ??
@@ -67,7 +68,8 @@ export const embeddings = (config) => {
67
68
  });
68
69
  markPerf(ctx.request, "aiSdkEnd");
69
70
  logger.trace({ requestId: resolveRequestId(ctx.request), result }, "[embeddings] AI SDK result");
70
- return toEmbeddings(result, ctx.modelId);
71
+ ctx.result = toEmbeddings(result, ctx.modelId);
72
+ return (await hooks?.after?.(ctx)) ?? ctx.result;
71
73
  };
72
74
  return { handler: winterCgHandler(handler, config) };
73
75
  };
@@ -27,6 +27,7 @@ export function toOpenAIErrorResponse(error, responseInit) {
27
27
  let message;
28
28
  if (shouldMask) {
29
29
  const requestId = resolveRequestId(responseInit);
30
+ // FUTURE: always attach requestId to errors (masked and unmasked)
30
31
  message = `${STATUS_CODE(meta.status)} (${requestId})`;
31
32
  }
32
33
  else {
@@ -1,2 +1,2 @@
1
1
  import type { GatewayConfig, GatewayContext } from "./types";
2
- export declare const winterCgHandler: (run: (ctx: GatewayContext) => Promise<object | ReadableStream<Uint8Array>>, config: GatewayConfig) => (request: Request, state?: Record<string, unknown>) => Promise<Response>;
2
+ export declare const winterCgHandler: (run: (ctx: GatewayContext) => Promise<object | ReadableStream<object>>, config: GatewayConfig) => (request: Request, state?: Record<string, unknown>) => Promise<Response>;
package/dist/lifecycle.js CHANGED
@@ -9,23 +9,19 @@ export const winterCgHandler = (run, config) => {
9
9
  const parsedConfig = parseConfig(config);
10
10
  const core = async (ctx) => {
11
11
  try {
12
- const before = await parsedConfig.hooks?.before?.(ctx);
13
- if (before) {
14
- if (before instanceof Response) {
15
- ctx.response = before;
12
+ const onRequest = await parsedConfig.hooks?.onRequest?.(ctx);
13
+ if (onRequest) {
14
+ if (onRequest instanceof Response) {
15
+ ctx.response = onRequest;
16
16
  return;
17
17
  }
18
- ctx.request = maybeApplyRequestPatch(ctx.request, before);
19
- }
20
- ctx.result = await run(ctx);
21
- const after = await parsedConfig.hooks?.after?.(ctx);
22
- if (after)
23
- ctx.result = after;
24
- if (ctx.result instanceof Response) {
25
- ctx.response = ctx.result;
26
- return;
18
+ ctx.request = maybeApplyRequestPatch(ctx.request, onRequest);
27
19
  }
20
+ ctx.result = (await run(ctx));
28
21
  ctx.response = toResponse(ctx.result, prepareResponseInit(ctx.request));
22
+ const onResponse = await parsedConfig.hooks?.onResponse?.(ctx);
23
+ if (onResponse)
24
+ ctx.response = onResponse;
29
25
  }
30
26
  catch (error) {
31
27
  logger.error({
package/dist/types.d.ts CHANGED
@@ -1,11 +1,12 @@
1
1
  import type { ProviderV3 } from "@ai-sdk/provider";
2
- import type { ChatCompletionsBody } from "./endpoints/chat-completions/schema";
3
- import type { EmbeddingsBody } from "./endpoints/embeddings/schema";
2
+ import type { ChatCompletions, ChatCompletionsBody, ChatCompletionsChunk } from "./endpoints/chat-completions/schema";
3
+ import type { Embeddings, EmbeddingsBody } from "./endpoints/embeddings/schema";
4
+ import type { OpenAIError } from "./errors/openai";
4
5
  import type { Logger, LoggerConfig } from "./logger";
5
6
  import type { ModelCatalog, ModelId } from "./models/types";
6
7
  import type { ProviderId, ProviderRegistry } from "./providers/types";
7
8
  /**
8
- * Request overrides returned from the `before` hook.
9
+ * Request overrides returned from the `onRequest` hook.
9
10
  */
10
11
  export type RequestPatch = {
11
12
  /**
@@ -64,7 +65,7 @@ export type GatewayContext = {
64
65
  /**
65
66
  * Result returned by the handler (pre-response).
66
67
  */
67
- result?: object | ReadableStream<Uint8Array>;
68
+ result?: ChatCompletions | ReadableStream<ChatCompletionsChunk | OpenAIError> | Embeddings | object;
68
69
  /**
69
70
  * Final response returned by the lifecycle.
70
71
  */
@@ -77,10 +78,12 @@ export type HookContext = Omit<Readonly<GatewayContext>, "state"> & {
77
78
  state: GatewayContext["state"];
78
79
  };
79
80
  type RequiredHookContext<K extends keyof GatewayContext> = Omit<HookContext, K> & Required<Pick<HookContext, K>>;
80
- export type BeforeHookContext = RequiredHookContext<"request">;
81
+ export type OnRequestHookContext = RequiredHookContext<"request">;
82
+ export type BeforeHookContext = RequiredHookContext<"request" | "body" | "operation">;
81
83
  export type ResolveModelHookContext = RequiredHookContext<"request" | "body" | "modelId">;
82
84
  export type ResolveProviderHookContext = RequiredHookContext<"request" | "body" | "modelId" | "resolvedModelId" | "operation">;
83
85
  export type AfterHookContext = RequiredHookContext<"request" | "result" | "provider" | "resolvedModelId" | "operation">;
86
+ export type OnResponseHookContext = RequiredHookContext<"request" | "response">;
84
87
  /**
85
88
  * Hooks to plugin to the gateway lifecycle.
86
89
  */
@@ -90,7 +93,12 @@ export type GatewayHooks = {
90
93
  * @returns Optional RequestPatch to merge into headers / override body,
91
94
  * or Response to short-circuit the request.
92
95
  */
93
- before?: (ctx: BeforeHookContext) => void | RequestPatch | Response | Promise<void | RequestPatch | Response>;
96
+ onRequest?: (ctx: OnRequestHookContext) => void | RequestPatch | Response | Promise<void | RequestPatch | Response>;
97
+ /**
98
+ * Runs after request JSON is parsed and validated for chat completions / embeddings.
99
+ * @returns Replacement parsed body, or undefined to keep original.
100
+ */
101
+ before?: (ctx: BeforeHookContext) => void | ChatCompletionsBody | EmbeddingsBody | Promise<void | ChatCompletionsBody | EmbeddingsBody>;
94
102
  /**
95
103
  * Maps a user-provided model ID or alias to a canonical ID.
96
104
  * @returns Canonical model ID or undefined to keep original.
@@ -103,9 +111,14 @@ export type GatewayHooks = {
103
111
  resolveProvider?: (ctx: ResolveProviderHookContext) => ProviderV3 | void | Promise<ProviderV3 | void>;
104
112
  /**
105
113
  * Runs after the endpoint handler.
106
- * @returns Response to replace, or undefined to keep original.
114
+ * @returns Result to replace, or undefined to keep original.
115
+ */
116
+ after?: (ctx: AfterHookContext) => void | ChatCompletions | ReadableStream<ChatCompletionsChunk | OpenAIError> | Embeddings | Promise<void | ChatCompletions | ReadableStream<ChatCompletionsChunk | OpenAIError> | Embeddings>;
117
+ /**
118
+ * Runs after the lifecycle has produced the final Response.
119
+ * @returns Replacement Response, or undefined to keep original.
107
120
  */
108
- after?: (ctx: AfterHookContext) => void | object | ReadableStream<Uint8Array> | Promise<void | object | ReadableStream<Uint8Array>>;
121
+ onResponse?: (ctx: OnResponseHookContext) => void | Response | Promise<void | Response>;
109
122
  };
110
123
  /**
111
124
  * Main configuration object for the gateway.
@@ -1,3 +1,3 @@
1
1
  export declare const prepareResponseInit: (request: Request) => ResponseInit;
2
2
  export declare const mergeResponseInit: (defaultHeaders: HeadersInit, responseInit?: ResponseInit) => ResponseInit;
3
- export declare const toResponse: (result: ReadableStream<Uint8Array> | Uint8Array<ArrayBuffer> | object | string, responseInit?: ResponseInit) => Response;
3
+ export declare const toResponse: (result: ReadableStream | Uint8Array<ArrayBuffer> | object | string, responseInit?: ResponseInit) => Response;
@@ -1,5 +1,17 @@
1
1
  import { REQUEST_ID_HEADER, resolveRequestId } from "./headers";
2
2
  const TEXT_ENCODER = new TextEncoder();
3
+ class JsonToSseTransformStream extends TransformStream {
4
+ constructor() {
5
+ super({
6
+ transform(part, controller) {
7
+ controller.enqueue(`data: ${JSON.stringify(part)}\n\n`);
8
+ },
9
+ flush(controller) {
10
+ controller.enqueue("data: [DONE]\n\n");
11
+ },
12
+ });
13
+ }
14
+ }
3
15
  export const prepareResponseInit = (request) => ({
4
16
  headers: { [REQUEST_ID_HEADER]: resolveRequestId(request.headers) },
5
17
  });
@@ -20,7 +32,10 @@ export const mergeResponseInit = (defaultHeaders, responseInit) => {
20
32
  export const toResponse = (result, responseInit) => {
21
33
  let body;
22
34
  const isStream = result instanceof ReadableStream;
23
- if (isStream || result instanceof Uint8Array) {
35
+ if (isStream) {
36
+ body = result.pipeThrough(new JsonToSseTransformStream()).pipeThrough(new TextEncoderStream());
37
+ }
38
+ else if (result instanceof Uint8Array) {
24
39
  body = result;
25
40
  }
26
41
  else if (typeof result === "string") {
@@ -44,7 +59,7 @@ export const toResponse = (result, responseInit) => {
44
59
  ? {
45
60
  "content-type": "text/event-stream",
46
61
  "cache-control": "no-cache",
47
- Connection: "keep-alive",
62
+ connection: "keep-alive",
48
63
  }
49
64
  : {
50
65
  "content-type": "application/json",
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@hebo-ai/gateway",
3
- "version": "0.3.0",
3
+ "version": "0.4.0-alpha.1",
4
4
  "description": "AI gateway as a framework. For full control over models, routing & lifecycle. OpenAI-compatible /chat/completions, /embeddings & /models.",
5
5
  "keywords": [
6
6
  "ai",
@@ -18,7 +18,7 @@ import type {
18
18
  } from "ai";
19
19
 
20
20
  import { convertBase64ToUint8Array } from "@ai-sdk/provider-utils";
21
- import { jsonSchema, JsonToSseTransformStream, tool } from "ai";
21
+ import { jsonSchema, tool } from "ai";
22
22
 
23
23
  import type {
24
24
  ChatCompletionsToolCall,
@@ -368,11 +368,8 @@ export function toChatCompletionsResponse(
368
368
  export function toChatCompletionsStream(
369
369
  result: StreamTextResult<ToolSet, Output.Output>,
370
370
  model: string,
371
- ): ReadableStream<Uint8Array> {
372
- return result.fullStream
373
- .pipeThrough(new ChatCompletionsStream(model))
374
- .pipeThrough(new JsonToSseTransformStream())
375
- .pipeThrough(new TextEncoderStream());
371
+ ): ReadableStream<ChatCompletionsChunk | OpenAIError> {
372
+ return result.fullStream.pipeThrough(new ChatCompletionsStream(model));
376
373
  }
377
374
 
378
375
  export function toChatCompletionsStreamResponse(
@@ -476,6 +473,7 @@ export class ChatCompletionsStream extends TransformStream<
476
473
 
477
474
  case "error": {
478
475
  const error = part.error;
476
+ // FUTURE: mask in production mode and return responseID
479
477
  controller.enqueue(toOpenAIError(error));
480
478
  break;
481
479
  }
@@ -2,6 +2,8 @@ import { generateText, streamText, wrapLanguageModel } from "ai";
2
2
  import * as z from "zod/mini";
3
3
 
4
4
  import type {
5
+ AfterHookContext,
6
+ BeforeHookContext,
5
7
  GatewayConfig,
6
8
  Endpoint,
7
9
  GatewayContext,
@@ -43,15 +45,17 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
43
45
  }
44
46
  ctx.body = parsed.data;
45
47
 
48
+ ctx.operation = "text";
49
+ ctx.body = (await hooks?.before?.(ctx as BeforeHookContext)) ?? ctx.body;
50
+
46
51
  // Resolve model + provider (hooks may override defaults).
47
52
  let inputs, stream;
48
- ({ model: ctx.modelId, stream, ...inputs } = parsed.data);
53
+ ({ model: ctx.modelId, stream, ...inputs } = ctx.body);
49
54
 
50
55
  ctx.resolvedModelId =
51
56
  (await hooks?.resolveModelId?.(ctx as ResolveModelHookContext)) ?? ctx.modelId;
52
57
  logger.debug(`[chat] resolved ${ctx.modelId} to ${ctx.resolvedModelId}`);
53
58
 
54
- ctx.operation = "text";
55
59
  const override = await hooks?.resolveProvider?.(ctx as ResolveProviderHookContext);
56
60
  ctx.provider =
57
61
  override ??
@@ -101,7 +105,7 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
101
105
  throw new DOMException("Upstream failed", "AbortError");
102
106
  },
103
107
  timeout: {
104
- chunkMs: 5 * 60 * 1000,
108
+ totalMs: 5 * 60 * 1000,
105
109
  },
106
110
  experimental_include: {
107
111
  requestBody: false,
@@ -111,7 +115,9 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
111
115
  });
112
116
  markPerf(ctx.request, "aiSdkEnd");
113
117
 
114
- return toChatCompletionsStream(result, ctx.modelId);
118
+ ctx.result = toChatCompletionsStream(result, ctx.modelId);
119
+
120
+ return (await hooks?.after?.(ctx as AfterHookContext)) ?? ctx.result;
115
121
  }
116
122
 
117
123
  const result = await generateText({
@@ -130,7 +136,9 @@ export const chatCompletions = (config: GatewayConfig): Endpoint => {
130
136
 
131
137
  logger.trace({ requestId: resolveRequestId(ctx.request), result }, "[chat] AI SDK result");
132
138
 
133
- return toChatCompletions(result, ctx.modelId);
139
+ ctx.result = toChatCompletions(result, ctx.modelId);
140
+
141
+ return (await hooks?.after?.(ctx as AfterHookContext)) ?? ctx.result;
134
142
  };
135
143
 
136
144
  return { handler: winterCgHandler(handler, config) };
@@ -2,6 +2,8 @@ import { embedMany, wrapEmbeddingModel } from "ai";
2
2
  import * as z from "zod/mini";
3
3
 
4
4
  import type {
5
+ AfterHookContext,
6
+ BeforeHookContext,
5
7
  GatewayConfig,
6
8
  Endpoint,
7
9
  GatewayContext,
@@ -43,15 +45,17 @@ export const embeddings = (config: GatewayConfig): Endpoint => {
43
45
  }
44
46
  ctx.body = parsed.data;
45
47
 
48
+ ctx.operation = "embeddings";
49
+ ctx.body = (await hooks?.before?.(ctx as BeforeHookContext)) ?? ctx.body;
50
+
46
51
  // Resolve model + provider (hooks may override defaults).
47
52
  let inputs;
48
- ({ model: ctx.modelId, ...inputs } = parsed.data);
53
+ ({ model: ctx.modelId, ...inputs } = ctx.body);
49
54
 
50
55
  ctx.resolvedModelId =
51
56
  (await hooks?.resolveModelId?.(ctx as ResolveModelHookContext)) ?? ctx.modelId;
52
57
  logger.debug(`[embeddings] resolved ${ctx.modelId} to ${ctx.resolvedModelId}`);
53
58
 
54
- ctx.operation = "embeddings";
55
59
  const override = await hooks?.resolveProvider?.(ctx as ResolveProviderHookContext);
56
60
  ctx.provider =
57
61
  override ??
@@ -94,7 +98,9 @@ export const embeddings = (config: GatewayConfig): Endpoint => {
94
98
  "[embeddings] AI SDK result",
95
99
  );
96
100
 
97
- return toEmbeddings(result, ctx.modelId);
101
+ ctx.result = toEmbeddings(result, ctx.modelId);
102
+
103
+ return (await hooks?.after?.(ctx as AfterHookContext)) ?? ctx.result;
98
104
  };
99
105
 
100
106
  return { handler: winterCgHandler(handler, config) };
@@ -35,6 +35,7 @@ export function toOpenAIErrorResponse(error: unknown, responseInit?: ResponseIni
35
35
  let message;
36
36
  if (shouldMask) {
37
37
  const requestId = resolveRequestId(responseInit);
38
+ // FUTURE: always attach requestId to errors (masked and unmasked)
38
39
  message = `${STATUS_CODE(meta.status)} (${requestId})`;
39
40
  } else {
40
41
  message = meta.message;
package/src/lifecycle.ts CHANGED
@@ -1,4 +1,9 @@
1
- import type { AfterHookContext, BeforeHookContext, GatewayConfig, GatewayContext } from "./types";
1
+ import type {
2
+ GatewayConfig,
3
+ GatewayContext,
4
+ OnRequestHookContext,
5
+ OnResponseHookContext,
6
+ } from "./types";
2
7
 
3
8
  import { parseConfig } from "./config";
4
9
  import { toOpenAIErrorResponse } from "./errors/openai";
@@ -9,32 +14,27 @@ import { maybeApplyRequestPatch, prepareRequestHeaders } from "./utils/request";
9
14
  import { prepareResponseInit, toResponse } from "./utils/response";
10
15
 
11
16
  export const winterCgHandler = (
12
- run: (ctx: GatewayContext) => Promise<object | ReadableStream<Uint8Array>>,
17
+ run: (ctx: GatewayContext) => Promise<object | ReadableStream<object>>,
13
18
  config: GatewayConfig,
14
19
  ) => {
15
20
  const parsedConfig = parseConfig(config);
16
21
 
17
22
  const core = async (ctx: GatewayContext): Promise<void> => {
18
23
  try {
19
- const before = await parsedConfig.hooks?.before?.(ctx as BeforeHookContext);
20
- if (before) {
21
- if (before instanceof Response) {
22
- ctx.response = before;
24
+ const onRequest = await parsedConfig.hooks?.onRequest?.(ctx as OnRequestHookContext);
25
+ if (onRequest) {
26
+ if (onRequest instanceof Response) {
27
+ ctx.response = onRequest;
23
28
  return;
24
29
  }
25
- ctx.request = maybeApplyRequestPatch(ctx.request, before);
30
+ ctx.request = maybeApplyRequestPatch(ctx.request, onRequest);
26
31
  }
27
32
 
28
- ctx.result = await run(ctx);
33
+ ctx.result = (await run(ctx)) as typeof ctx.result;
34
+ ctx.response = toResponse(ctx.result!, prepareResponseInit(ctx.request));
29
35
 
30
- const after = await parsedConfig.hooks?.after?.(ctx as AfterHookContext);
31
- if (after) ctx.result = after;
32
-
33
- if (ctx.result instanceof Response) {
34
- ctx.response = ctx.result;
35
- return;
36
- }
37
- ctx.response = toResponse(ctx.result, prepareResponseInit(ctx.request));
36
+ const onResponse = await parsedConfig.hooks?.onResponse?.(ctx as OnResponseHookContext);
37
+ if (onResponse) ctx.response = onResponse;
38
38
  } catch (error) {
39
39
  logger.error({
40
40
  requestId: resolveRequestId(ctx.request)!,
package/src/types.ts CHANGED
@@ -1,13 +1,18 @@
1
1
  import type { ProviderV3 } from "@ai-sdk/provider";
2
2
 
3
- import type { ChatCompletionsBody } from "./endpoints/chat-completions/schema";
4
- import type { EmbeddingsBody } from "./endpoints/embeddings/schema";
3
+ import type {
4
+ ChatCompletions,
5
+ ChatCompletionsBody,
6
+ ChatCompletionsChunk,
7
+ } from "./endpoints/chat-completions/schema";
8
+ import type { Embeddings, EmbeddingsBody } from "./endpoints/embeddings/schema";
9
+ import type { OpenAIError } from "./errors/openai";
5
10
  import type { Logger, LoggerConfig } from "./logger";
6
11
  import type { ModelCatalog, ModelId } from "./models/types";
7
12
  import type { ProviderId, ProviderRegistry } from "./providers/types";
8
13
 
9
14
  /**
10
- * Request overrides returned from the `before` hook.
15
+ * Request overrides returned from the `onRequest` hook.
11
16
  */
12
17
  export type RequestPatch = {
13
18
  /**
@@ -67,7 +72,11 @@ export type GatewayContext = {
67
72
  /**
68
73
  * Result returned by the handler (pre-response).
69
74
  */
70
- result?: object | ReadableStream<Uint8Array>;
75
+ result?:
76
+ | ChatCompletions
77
+ | ReadableStream<ChatCompletionsChunk | OpenAIError>
78
+ | Embeddings
79
+ | object;
71
80
  /**
72
81
  * Final response returned by the lifecycle.
73
82
  */
@@ -83,7 +92,8 @@ export type HookContext = Omit<Readonly<GatewayContext>, "state"> & {
83
92
 
84
93
  type RequiredHookContext<K extends keyof GatewayContext> = Omit<HookContext, K> &
85
94
  Required<Pick<HookContext, K>>;
86
- export type BeforeHookContext = RequiredHookContext<"request">;
95
+ export type OnRequestHookContext = RequiredHookContext<"request">;
96
+ export type BeforeHookContext = RequiredHookContext<"request" | "body" | "operation">;
87
97
  export type ResolveModelHookContext = RequiredHookContext<"request" | "body" | "modelId">;
88
98
  export type ResolveProviderHookContext = RequiredHookContext<
89
99
  "request" | "body" | "modelId" | "resolvedModelId" | "operation"
@@ -91,6 +101,7 @@ export type ResolveProviderHookContext = RequiredHookContext<
91
101
  export type AfterHookContext = RequiredHookContext<
92
102
  "request" | "result" | "provider" | "resolvedModelId" | "operation"
93
103
  >;
104
+ export type OnResponseHookContext = RequiredHookContext<"request" | "response">;
94
105
 
95
106
  /**
96
107
  * Hooks to plug into the gateway lifecycle.
@@ -101,9 +112,20 @@ export type GatewayHooks = {
101
112
  * @returns Optional RequestPatch to merge into headers / override body,
102
113
  * or Response to short-circuit the request.
103
114
  */
115
+ onRequest?: (
116
+ ctx: OnRequestHookContext,
117
+ ) => void | RequestPatch | Response | Promise<void | RequestPatch | Response>;
118
+ /**
119
+ * Runs after request JSON is parsed and validated for chat completions / embeddings.
120
+ * @returns Replacement parsed body, or undefined to keep original.
121
+ */
104
122
  before?: (
105
123
  ctx: BeforeHookContext,
106
- ) => void | RequestPatch | Response | Promise<void | RequestPatch | Response>;
124
+ ) =>
125
+ | void
126
+ | ChatCompletionsBody
127
+ | EmbeddingsBody
128
+ | Promise<void | ChatCompletionsBody | EmbeddingsBody>;
107
129
  /**
108
130
  * Maps a user-provided model ID or alias to a canonical ID.
109
131
  * @returns Canonical model ID or undefined to keep original.
@@ -118,15 +140,23 @@ export type GatewayHooks = {
118
140
  ) => ProviderV3 | void | Promise<ProviderV3 | void>;
119
141
  /**
120
142
  * Runs after the endpoint handler.
121
- * @returns Response to replace, or undefined to keep original.
143
+ * @returns Result to replace, or undefined to keep original.
122
144
  */
123
145
  after?: (
124
146
  ctx: AfterHookContext,
125
147
  ) =>
126
148
  | void
127
- | object
128
- | ReadableStream<Uint8Array>
129
- | Promise<void | object | ReadableStream<Uint8Array>>;
149
+ | ChatCompletions
150
+ | ReadableStream<ChatCompletionsChunk | OpenAIError>
151
+ | Embeddings
152
+ | Promise<
153
+ void | ChatCompletions | ReadableStream<ChatCompletionsChunk | OpenAIError> | Embeddings
154
+ >;
155
+ /**
156
+ * Runs after the lifecycle has produced the final Response.
157
+ * @returns Replacement Response, or undefined to keep original.
158
+ */
159
+ onResponse?: (ctx: OnResponseHookContext) => void | Response | Promise<void | Response>;
130
160
  };
131
161
 
132
162
  /**
@@ -2,6 +2,19 @@ import { REQUEST_ID_HEADER, resolveRequestId } from "./headers";
2
2
 
3
3
  const TEXT_ENCODER = new TextEncoder();
4
4
 
5
+ class JsonToSseTransformStream extends TransformStream<unknown, string> {
6
+ constructor() {
7
+ super({
8
+ transform(part, controller) {
9
+ controller.enqueue(`data: ${JSON.stringify(part)}\n\n`);
10
+ },
11
+ flush(controller) {
12
+ controller.enqueue("data: [DONE]\n\n");
13
+ },
14
+ });
15
+ }
16
+ }
17
+
5
18
  export const prepareResponseInit = (request: Request): ResponseInit => ({
6
19
  headers: { [REQUEST_ID_HEADER]: resolveRequestId(request.headers)! },
7
20
  });
@@ -25,13 +38,15 @@ export const mergeResponseInit = (
25
38
  };
26
39
 
27
40
  export const toResponse = (
28
- result: ReadableStream<Uint8Array> | Uint8Array<ArrayBuffer> | object | string,
41
+ result: ReadableStream | Uint8Array<ArrayBuffer> | object | string,
29
42
  responseInit?: ResponseInit,
30
43
  ): Response => {
31
44
  let body: BodyInit;
32
45
 
33
46
  const isStream = result instanceof ReadableStream;
34
- if (isStream || result instanceof Uint8Array) {
47
+ if (isStream) {
48
+ body = result.pipeThrough(new JsonToSseTransformStream()).pipeThrough(new TextEncoderStream());
49
+ } else if (result instanceof Uint8Array) {
35
50
  body = result;
36
51
  } else if (typeof result === "string") {
37
52
  body = TEXT_ENCODER.encode(result);
@@ -57,7 +72,7 @@ export const toResponse = (
57
72
  ? {
58
73
  "content-type": "text/event-stream",
59
74
  "cache-control": "no-cache",
60
- Connection: "keep-alive",
75
+ connection: "keep-alive",
61
76
  }
62
77
  : {
63
78
  "content-type": "application/json",