smoltalk 0.0.55 → 0.0.56
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/clients/anthropic.d.ts +1 -0
- package/dist/clients/anthropic.js +57 -26
- package/dist/clients/baseClient.d.ts +1 -0
- package/dist/clients/baseClient.js +23 -1
- package/dist/clients/google.js +35 -3
- package/dist/clients/ollama.js +8 -0
- package/dist/clients/openai.d.ts +1 -0
- package/dist/clients/openai.js +39 -9
- package/dist/clients/openaiResponses.d.ts +1 -0
- package/dist/clients/openaiResponses.js +31 -7
- package/dist/index.d.ts +2 -0
- package/dist/index.js +1 -0
- package/dist/latencyTracker.d.ts +32 -0
- package/dist/latencyTracker.js +73 -0
- package/dist/model.d.ts +3 -2
- package/dist/model.js +11 -1
- package/dist/models.d.ts +1 -1
- package/dist/smolError.d.ts +6 -0
- package/dist/smolError.js +12 -0
- package/dist/strategies/fallbackStrategy.js +23 -1
- package/dist/strategies/fastestStrategy.d.ts +17 -0
- package/dist/strategies/fastestStrategy.js +95 -0
- package/dist/strategies/index.d.ts +5 -1
- package/dist/strategies/index.js +16 -1
- package/dist/strategies/randomStrategy.d.ts +12 -0
- package/dist/strategies/randomStrategy.js +39 -0
- package/dist/strategies/types.d.ts +19 -1
- package/dist/strategies/types.js +14 -0
- package/package.json +1 -1
|
@@ -12,6 +12,7 @@ export declare class SmolAnthropic extends BaseClient implements SmolClient {
|
|
|
12
12
|
getModel(): ModelName;
|
|
13
13
|
private calculateUsageAndCost;
|
|
14
14
|
private buildRequest;
|
|
15
|
+
private rethrowAsSmolError;
|
|
15
16
|
_textSync(config: PromptConfig): Promise<Result<PromptResult>>;
|
|
16
17
|
_textStream(config: PromptConfig): AsyncGenerator<StreamChunk>;
|
|
17
18
|
}
|
|
@@ -4,6 +4,7 @@ import { SystemMessage, DeveloperMessage } from "../classes/message/index.js";
|
|
|
4
4
|
import { getLogger } from "../logger.js";
|
|
5
5
|
import { success, } from "../types.js";
|
|
6
6
|
import { zodToAnthropicTool } from "../util/tool.js";
|
|
7
|
+
import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
|
|
7
8
|
import { BaseClient } from "./baseClient.js";
|
|
8
9
|
import { Model } from "../model.js";
|
|
9
10
|
const DEFAULT_MAX_TOKENS = 4096;
|
|
@@ -82,6 +83,24 @@ export class SmolAnthropic extends BaseClient {
|
|
|
82
83
|
: undefined;
|
|
83
84
|
return { system, messages: anthropicMessages, tools, thinking };
|
|
84
85
|
}
|
|
86
|
+
rethrowAsSmolError(error) {
|
|
87
|
+
if (error instanceof Anthropic.APIError) {
|
|
88
|
+
const msg = error.message.toLowerCase();
|
|
89
|
+
if (msg.includes("prompt is too long") ||
|
|
90
|
+
msg.includes("context length") ||
|
|
91
|
+
msg.includes("context window") ||
|
|
92
|
+
msg.includes("too many tokens")) {
|
|
93
|
+
throw new SmolContextWindowExceededError(error.message);
|
|
94
|
+
}
|
|
95
|
+
if (msg.includes("content policy") ||
|
|
96
|
+
msg.includes("usage policies") ||
|
|
97
|
+
msg.includes("content filtering") ||
|
|
98
|
+
msg.includes("violates our")) {
|
|
99
|
+
throw new SmolContentPolicyError(error.message);
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
throw error;
|
|
103
|
+
}
|
|
85
104
|
async _textSync(config) {
|
|
86
105
|
const { system, messages, tools, thinking } = this.buildRequest(config);
|
|
87
106
|
let debugData = {
|
|
@@ -95,19 +114,25 @@ export class SmolAnthropic extends BaseClient {
|
|
|
95
114
|
this.logger.debug("Sending request to Anthropic:", debugData);
|
|
96
115
|
this.statelogClient?.promptRequest(debugData);
|
|
97
116
|
const signal = this.getAbortSignal(config);
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
117
|
+
let response;
|
|
118
|
+
try {
|
|
119
|
+
response = await this.client.messages.create({
|
|
120
|
+
model: this.getModel(),
|
|
121
|
+
max_tokens: config.maxTokens ?? DEFAULT_MAX_TOKENS,
|
|
122
|
+
messages,
|
|
123
|
+
...(system && { system }),
|
|
124
|
+
...(tools && { tools }),
|
|
125
|
+
...(thinking && { thinking }),
|
|
126
|
+
...(config.temperature !== undefined && {
|
|
127
|
+
temperature: config.temperature,
|
|
128
|
+
}),
|
|
129
|
+
...(config.rawAttributes || {}),
|
|
130
|
+
stream: false,
|
|
131
|
+
}, { ...(signal && { signal }) });
|
|
132
|
+
}
|
|
133
|
+
catch (error) {
|
|
134
|
+
this.rethrowAsSmolError(error);
|
|
135
|
+
}
|
|
111
136
|
this.logger.debug("Response from Anthropic:", response);
|
|
112
137
|
this.statelogClient?.promptResponse(response);
|
|
113
138
|
let output = null;
|
|
@@ -148,19 +173,25 @@ export class SmolAnthropic extends BaseClient {
|
|
|
148
173
|
this.logger.debug("Sending streaming request to Anthropic:", streamDebugData);
|
|
149
174
|
this.statelogClient?.promptRequest(streamDebugData);
|
|
150
175
|
const signal = this.getAbortSignal(config);
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
176
|
+
let stream;
|
|
177
|
+
try {
|
|
178
|
+
stream = await this.client.messages.create({
|
|
179
|
+
model: this.model,
|
|
180
|
+
max_tokens: config.maxTokens ?? DEFAULT_MAX_TOKENS,
|
|
181
|
+
messages,
|
|
182
|
+
...(system && { system }),
|
|
183
|
+
...(tools && { tools }),
|
|
184
|
+
...(thinking && { thinking }),
|
|
185
|
+
...(config.temperature !== undefined && {
|
|
186
|
+
temperature: config.temperature,
|
|
187
|
+
}),
|
|
188
|
+
...(config.rawAttributes || {}),
|
|
189
|
+
stream: true,
|
|
190
|
+
}, { ...(signal && { signal }) });
|
|
191
|
+
}
|
|
192
|
+
catch (error) {
|
|
193
|
+
this.rethrowAsSmolError(error);
|
|
194
|
+
}
|
|
164
195
|
let content = "";
|
|
165
196
|
// Track tool blocks by index: index -> { id, name, arguments (partial JSON) }
|
|
166
197
|
const toolBlocks = new Map();
|
|
@@ -24,6 +24,7 @@ export declare class BaseClient implements SmolClient {
|
|
|
24
24
|
continue: boolean;
|
|
25
25
|
newPromptConfig: PromptConfig;
|
|
26
26
|
};
|
|
27
|
+
private recordLatency;
|
|
27
28
|
extractResponse(promptConfig: PromptConfig, rawValue: any, schema: any, depth?: number): any;
|
|
28
29
|
textWithRetry(promptConfig: PromptConfig, retries: number): Promise<Result<PromptResult>>;
|
|
29
30
|
_textSync(promptConfig: PromptConfig): Promise<Result<PromptResult>>;
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import { AssistantMessage, userMessage, assistantMessage, } from "../classes/message/index.js";
|
|
2
|
+
import { latencyTracker } from "../latencyTracker.js";
|
|
2
3
|
import { getLogger } from "../logger.js";
|
|
3
4
|
import { getModel, isTextModel } from "../models.js";
|
|
4
5
|
import { SmolStructuredOutputError } from "../smolError.js";
|
|
@@ -146,9 +147,11 @@ export class BaseClient {
|
|
|
146
147
|
value: { output: null, toolCalls: [], model: this.config.model },
|
|
147
148
|
};
|
|
148
149
|
}
|
|
150
|
+
const startTime = performance.now();
|
|
149
151
|
try {
|
|
150
152
|
const result = await this.textWithRetry(newPromptConfig, newPromptConfig.responseFormatOptions?.numRetries ||
|
|
151
153
|
DEFAULT_NUM_RETRIES);
|
|
154
|
+
this.recordLatency(startTime, result);
|
|
152
155
|
return result;
|
|
153
156
|
}
|
|
154
157
|
catch (err) {
|
|
@@ -210,6 +213,15 @@ export class BaseClient {
|
|
|
210
213
|
}
|
|
211
214
|
return { continue: true, newPromptConfig: promptConfig };
|
|
212
215
|
}
|
|
216
|
+
recordLatency(startTime, result) {
|
|
217
|
+
if (!result.success)
|
|
218
|
+
return;
|
|
219
|
+
const outputTokens = result.value.usage?.outputTokens;
|
|
220
|
+
if (!outputTokens || outputTokens <= 0)
|
|
221
|
+
return;
|
|
222
|
+
const elapsedMs = performance.now() - startTime;
|
|
223
|
+
latencyTracker.record(this.config.model, elapsedMs, outputTokens);
|
|
224
|
+
}
|
|
213
225
|
extractResponse(promptConfig, rawValue, schema, depth = 0) {
|
|
214
226
|
const MAX_DEPTH = 5;
|
|
215
227
|
if (depth > MAX_DEPTH) {
|
|
@@ -374,8 +386,18 @@ export class BaseClient {
|
|
|
374
386
|
};
|
|
375
387
|
return;
|
|
376
388
|
}
|
|
389
|
+
const startTime = performance.now();
|
|
377
390
|
try {
|
|
378
|
-
|
|
391
|
+
for await (const chunk of this._textStream(newPromptConfig)) {
|
|
392
|
+
if (chunk.type === "done") {
|
|
393
|
+
const outputTokens = chunk.result.usage?.outputTokens;
|
|
394
|
+
if (outputTokens && outputTokens > 0) {
|
|
395
|
+
const elapsedMs = performance.now() - startTime;
|
|
396
|
+
latencyTracker.record(this.config.model, elapsedMs, outputTokens);
|
|
397
|
+
}
|
|
398
|
+
}
|
|
399
|
+
yield chunk;
|
|
400
|
+
}
|
|
379
401
|
}
|
|
380
402
|
catch (err) {
|
|
381
403
|
if (this.isAbortError(err)) {
|
package/dist/clients/google.js
CHANGED
|
@@ -3,6 +3,7 @@ import { ToolCall } from "../classes/ToolCall.js";
|
|
|
3
3
|
import { getLogger } from "../logger.js";
|
|
4
4
|
import { addCosts, addTokenUsage, success, } from "../types.js";
|
|
5
5
|
import { zodToGoogleTool } from "../util/tool.js";
|
|
6
|
+
import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
|
|
6
7
|
import { sanitizeAttributes } from "../util.js";
|
|
7
8
|
import { BaseClient } from "./baseClient.js";
|
|
8
9
|
import { Model } from "../model.js";
|
|
@@ -171,10 +172,28 @@ export class SmolGoogle extends BaseClient {
|
|
|
171
172
|
async __textSync(request) {
|
|
172
173
|
this.logger.debug("Sending request to Google Gemini:", JSON.stringify(request, null, 2));
|
|
173
174
|
this.statelogClient?.promptRequest(request);
|
|
174
|
-
|
|
175
|
-
|
|
175
|
+
let result;
|
|
176
|
+
try {
|
|
177
|
+
result = await this.client.models.generateContent(request);
|
|
178
|
+
}
|
|
179
|
+
catch (error) {
|
|
180
|
+
const msg = (error.message || "").toLowerCase();
|
|
181
|
+
if (msg.includes("token") &&
|
|
182
|
+
(msg.includes("exceed") ||
|
|
183
|
+
msg.includes("too long") ||
|
|
184
|
+
msg.includes("limit"))) {
|
|
185
|
+
throw new SmolContextWindowExceededError(error.message);
|
|
186
|
+
}
|
|
187
|
+
throw error;
|
|
188
|
+
}
|
|
176
189
|
this.logger.debug("Response from Google Gemini:", JSON.stringify(result, null, 2));
|
|
177
190
|
this.statelogClient?.promptResponse(result);
|
|
191
|
+
for (const candidate of result.candidates || []) {
|
|
192
|
+
const finishReason = candidate.finishReason;
|
|
193
|
+
if (finishReason === "SAFETY" || finishReason === "PROHIBITED_CONTENT") {
|
|
194
|
+
throw new SmolContentPolicyError(`Content blocked by Google safety filter: ${finishReason}`);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
178
197
|
const toolCalls = [];
|
|
179
198
|
const thinkingBlocks = [];
|
|
180
199
|
let textContent = "";
|
|
@@ -230,7 +249,20 @@ export class SmolGoogle extends BaseClient {
|
|
|
230
249
|
}
|
|
231
250
|
this.logger.debug("Sending streaming request to Google Gemini:", JSON.stringify(request, null, 2));
|
|
232
251
|
this.statelogClient?.promptRequest(request);
|
|
233
|
-
|
|
252
|
+
let stream;
|
|
253
|
+
try {
|
|
254
|
+
stream = await this.client.models.generateContentStream(request);
|
|
255
|
+
}
|
|
256
|
+
catch (error) {
|
|
257
|
+
const msg = (error.message || "").toLowerCase();
|
|
258
|
+
if (msg.includes("token") &&
|
|
259
|
+
(msg.includes("exceed") ||
|
|
260
|
+
msg.includes("too long") ||
|
|
261
|
+
msg.includes("limit"))) {
|
|
262
|
+
throw new SmolContextWindowExceededError(error.message);
|
|
263
|
+
}
|
|
264
|
+
throw error;
|
|
265
|
+
}
|
|
234
266
|
let content = "";
|
|
235
267
|
const toolCallsMap = new Map();
|
|
236
268
|
const thinkingBlocks = [];
|
package/dist/clients/ollama.js
CHANGED
|
@@ -5,6 +5,7 @@ import { success, } from "../types.js";
|
|
|
5
5
|
import { zodToGoogleTool } from "../util/tool.js";
|
|
6
6
|
import { sanitizeAttributes } from "../util.js";
|
|
7
7
|
import { BaseClient } from "./baseClient.js";
|
|
8
|
+
import { SmolContextWindowExceededError } from "../smolError.js";
|
|
8
9
|
import { Model } from "../model.js";
|
|
9
10
|
export const DEFAULT_OLLAMA_HOST = "http://localhost:11434";
|
|
10
11
|
export class SmolOllama extends BaseClient {
|
|
@@ -80,6 +81,13 @@ export class SmolOllama extends BaseClient {
|
|
|
80
81
|
// @ts-ignore
|
|
81
82
|
result = await this.client.chat(request);
|
|
82
83
|
}
|
|
84
|
+
catch (error) {
|
|
85
|
+
const msg = (error.message || "").toLowerCase();
|
|
86
|
+
if (msg.includes("context length") || msg.includes("context window")) {
|
|
87
|
+
throw new SmolContextWindowExceededError(error.message);
|
|
88
|
+
}
|
|
89
|
+
throw error;
|
|
90
|
+
}
|
|
83
91
|
finally {
|
|
84
92
|
if (signal && abortHandler) {
|
|
85
93
|
signal.removeEventListener("abort", abortHandler);
|
package/dist/clients/openai.d.ts
CHANGED
|
@@ -12,6 +12,7 @@ export declare class SmolOpenAi extends BaseClient implements SmolClient {
|
|
|
12
12
|
getModel(): ModelName;
|
|
13
13
|
private calculateUsageAndCost;
|
|
14
14
|
private buildRequest;
|
|
15
|
+
private rethrowAsSmolError;
|
|
15
16
|
_textSync(config: PromptConfig): Promise<Result<PromptResult>>;
|
|
16
17
|
_textStream(config: PromptConfig): AsyncGenerator<StreamChunk>;
|
|
17
18
|
}
|
package/dist/clients/openai.js
CHANGED
|
@@ -4,6 +4,7 @@ import { ToolCall } from "../classes/ToolCall.js";
|
|
|
4
4
|
import { isFunctionToolCall, sanitizeAttributes } from "../util.js";
|
|
5
5
|
import { getLogger } from "../logger.js";
|
|
6
6
|
import { BaseClient } from "./baseClient.js";
|
|
7
|
+
import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
|
|
7
8
|
import { zodToOpenAITool } from "../util/tool.js";
|
|
8
9
|
import { Model } from "../model.js";
|
|
9
10
|
export class SmolOpenAi extends BaseClient {
|
|
@@ -68,17 +69,37 @@ export class SmolOpenAi extends BaseClient {
|
|
|
68
69
|
}
|
|
69
70
|
return request;
|
|
70
71
|
}
|
|
72
|
+
rethrowAsSmolError(error) {
|
|
73
|
+
if (error instanceof OpenAI.APIError) {
|
|
74
|
+
if (error.code === "context_length_exceeded") {
|
|
75
|
+
throw new SmolContextWindowExceededError(error.message);
|
|
76
|
+
}
|
|
77
|
+
if (error.code === "content_policy_violation") {
|
|
78
|
+
throw new SmolContentPolicyError(error.message);
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
throw error;
|
|
82
|
+
}
|
|
71
83
|
async _textSync(config) {
|
|
72
84
|
const request = this.buildRequest(config);
|
|
73
85
|
this.logger.debug("Sending request to OpenAI:", JSON.stringify(request, null, 2));
|
|
74
86
|
this.statelogClient?.promptRequest(request);
|
|
75
87
|
const signal = this.getAbortSignal(config);
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
88
|
+
let completion;
|
|
89
|
+
try {
|
|
90
|
+
completion = await this.client.chat.completions.create({
|
|
91
|
+
...request,
|
|
92
|
+
stream: false,
|
|
93
|
+
}, { ...(signal && { signal }) });
|
|
94
|
+
}
|
|
95
|
+
catch (error) {
|
|
96
|
+
this.rethrowAsSmolError(error);
|
|
97
|
+
}
|
|
80
98
|
this.logger.debug("Response from OpenAI:", JSON.stringify(completion, null, 2));
|
|
81
99
|
this.statelogClient?.promptResponse(completion);
|
|
100
|
+
if (completion.choices[0]?.finish_reason === "content_filter") {
|
|
101
|
+
throw new SmolContentPolicyError("Content blocked by OpenAI content filter");
|
|
102
|
+
}
|
|
82
103
|
const message = completion.choices[0].message;
|
|
83
104
|
const output = message.content;
|
|
84
105
|
const _toolCalls = message.tool_calls;
|
|
@@ -109,11 +130,17 @@ export class SmolOpenAi extends BaseClient {
|
|
|
109
130
|
this.logger.debug("Sending streaming request to OpenAI:", JSON.stringify(request, null, 2));
|
|
110
131
|
this.statelogClient?.promptRequest(request);
|
|
111
132
|
const signal = this.getAbortSignal(config);
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
133
|
+
let completion;
|
|
134
|
+
try {
|
|
135
|
+
completion = await this.client.chat.completions.create({
|
|
136
|
+
...request,
|
|
137
|
+
stream: true,
|
|
138
|
+
stream_options: { include_usage: true },
|
|
139
|
+
}, { ...(signal && { signal }) });
|
|
140
|
+
}
|
|
141
|
+
catch (error) {
|
|
142
|
+
this.rethrowAsSmolError(error);
|
|
143
|
+
}
|
|
117
144
|
let content = "";
|
|
118
145
|
const toolCallsMap = new Map();
|
|
119
146
|
let usage;
|
|
@@ -127,6 +154,9 @@ export class SmolOpenAi extends BaseClient {
|
|
|
127
154
|
}
|
|
128
155
|
if (!chunk.choices || chunk.choices.length === 0)
|
|
129
156
|
continue;
|
|
157
|
+
if (chunk.choices[0]?.finish_reason === "content_filter") {
|
|
158
|
+
throw new SmolContentPolicyError("Content blocked by OpenAI content filter");
|
|
159
|
+
}
|
|
130
160
|
const delta = chunk.choices[0]?.delta;
|
|
131
161
|
if (!delta)
|
|
132
162
|
continue;
|
|
@@ -13,6 +13,7 @@ export declare class SmolOpenAiResponses extends BaseClient implements SmolClien
|
|
|
13
13
|
private convertMessages;
|
|
14
14
|
private buildRequest;
|
|
15
15
|
private calculateUsageAndCost;
|
|
16
|
+
private rethrowAsSmolError;
|
|
16
17
|
_textSync(config: PromptConfig): Promise<Result<PromptResult>>;
|
|
17
18
|
_textStream(config: PromptConfig): AsyncGenerator<StreamChunk>;
|
|
18
19
|
}
|
|
@@ -6,6 +6,7 @@ import { BaseClient } from "./baseClient.js";
|
|
|
6
6
|
import { zodToOpenAIResponsesTool } from "../util/tool.js";
|
|
7
7
|
import { sanitizeAttributes } from "../util.js";
|
|
8
8
|
import { Model } from "../model.js";
|
|
9
|
+
import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
|
|
9
10
|
export class SmolOpenAiResponses extends BaseClient {
|
|
10
11
|
client;
|
|
11
12
|
logger;
|
|
@@ -101,15 +102,32 @@ export class SmolOpenAiResponses extends BaseClient {
|
|
|
101
102
|
}
|
|
102
103
|
return { usage, cost };
|
|
103
104
|
}
|
|
105
|
+
rethrowAsSmolError(error) {
|
|
106
|
+
if (error instanceof OpenAI.APIError) {
|
|
107
|
+
if (error.code === "context_length_exceeded") {
|
|
108
|
+
throw new SmolContextWindowExceededError(error.message);
|
|
109
|
+
}
|
|
110
|
+
if (error.code === "content_policy_violation") {
|
|
111
|
+
throw new SmolContentPolicyError(error.message);
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
throw error;
|
|
115
|
+
}
|
|
104
116
|
async _textSync(config) {
|
|
105
117
|
const request = this.buildRequest(config);
|
|
106
118
|
this.logger.debug("Sending request to OpenAI Responses API:", JSON.stringify(request, null, 2));
|
|
107
119
|
this.statelogClient?.promptRequest(request);
|
|
108
120
|
const signal = this.getAbortSignal(config);
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
121
|
+
let response;
|
|
122
|
+
try {
|
|
123
|
+
response = await this.client.responses.create({
|
|
124
|
+
...request,
|
|
125
|
+
stream: false,
|
|
126
|
+
}, { ...(signal && { signal }) });
|
|
127
|
+
}
|
|
128
|
+
catch (error) {
|
|
129
|
+
this.rethrowAsSmolError(error);
|
|
130
|
+
}
|
|
113
131
|
this.logger.debug("Response from OpenAI Responses API:", JSON.stringify(response, null, 2));
|
|
114
132
|
this.statelogClient?.promptResponse(response);
|
|
115
133
|
const output = response.output_text || null;
|
|
@@ -133,9 +151,15 @@ export class SmolOpenAiResponses extends BaseClient {
|
|
|
133
151
|
this.logger.debug("Sending streaming request to OpenAI Responses API:", JSON.stringify(request, null, 2));
|
|
134
152
|
this.statelogClient?.promptRequest(request);
|
|
135
153
|
const signal = this.getAbortSignal(config);
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
154
|
+
let stream;
|
|
155
|
+
try {
|
|
156
|
+
stream = this.client.responses.stream(request, {
|
|
157
|
+
...(signal && { signal }),
|
|
158
|
+
});
|
|
159
|
+
}
|
|
160
|
+
catch (error) {
|
|
161
|
+
this.rethrowAsSmolError(error);
|
|
162
|
+
}
|
|
139
163
|
let content = "";
|
|
140
164
|
const functionCalls = new Map();
|
|
141
165
|
let usage;
|
package/dist/index.d.ts
CHANGED
|
@@ -8,3 +8,5 @@ export * from "./classes/message/index.js";
|
|
|
8
8
|
export * from "./functions.js";
|
|
9
9
|
export * from "./classes/ToolCall.js";
|
|
10
10
|
export * from "./strategies/index.js";
|
|
11
|
+
export { latencyTracker } from "./latencyTracker.js";
|
|
12
|
+
export type { LatencySample } from "./latencyTracker.js";
|
package/dist/index.js
CHANGED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
export type LatencySample = {
|
|
2
|
+
/** Milliseconds per output token */
|
|
3
|
+
msPerToken: number;
|
|
4
|
+
/** Timestamp when sample was recorded */
|
|
5
|
+
timestamp: number;
|
|
6
|
+
};
|
|
7
|
+
declare class LatencyTracker {
|
|
8
|
+
private samples;
|
|
9
|
+
private windowSize;
|
|
10
|
+
constructor(windowSize?: number);
|
|
11
|
+
/** Record a latency sample for a model. */
|
|
12
|
+
record(model: string, elapsedMs: number, outputTokens: number): void;
|
|
13
|
+
/** Get the windowed mean ms-per-token for a model, or null if no samples. */
|
|
14
|
+
getMeanMsPerToken(model: string): number | null;
|
|
15
|
+
/**
|
|
16
|
+
* Get estimated output tokens per second for a model based on tracked latency.
|
|
17
|
+
* Returns null if no samples exist or if the number of samples is below the minimum required.
|
|
18
|
+
*/
|
|
19
|
+
getTokensPerSecond(model: string, minSamples?: number): number | null;
|
|
20
|
+
/** Get the number of samples recorded for a model. */
|
|
21
|
+
getSampleCount(model: string): number;
|
|
22
|
+
/** Get all samples for a model (defensive copy). */
|
|
23
|
+
getSamples(model: string): LatencySample[];
|
|
24
|
+
/** Clear all samples for a model. */
|
|
25
|
+
clear(model?: string): void;
|
|
26
|
+
/** Update the window size. Existing samples beyond the new size are trimmed. */
|
|
27
|
+
setWindowSize(size: number): void;
|
|
28
|
+
getWindowSize(): number;
|
|
29
|
+
}
|
|
30
|
+
/** Global singleton latency tracker. */
|
|
31
|
+
export declare const latencyTracker: LatencyTracker;
|
|
32
|
+
export {};
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
const DEFAULT_WINDOW_SIZE = 10;
|
|
2
|
+
class LatencyTracker {
|
|
3
|
+
samples = new Map();
|
|
4
|
+
windowSize;
|
|
5
|
+
constructor(windowSize = DEFAULT_WINDOW_SIZE) {
|
|
6
|
+
this.windowSize = windowSize;
|
|
7
|
+
}
|
|
8
|
+
/** Record a latency sample for a model. */
|
|
9
|
+
record(model, elapsedMs, outputTokens) {
|
|
10
|
+
if (outputTokens <= 0 || elapsedMs <= 0)
|
|
11
|
+
return;
|
|
12
|
+
const msPerToken = elapsedMs / outputTokens;
|
|
13
|
+
const samples = this.samples.get(model) ?? [];
|
|
14
|
+
samples.push({ msPerToken, timestamp: Date.now() });
|
|
15
|
+
// Keep only the last windowSize samples
|
|
16
|
+
if (samples.length > this.windowSize) {
|
|
17
|
+
samples.splice(0, samples.length - this.windowSize);
|
|
18
|
+
}
|
|
19
|
+
this.samples.set(model, samples);
|
|
20
|
+
}
|
|
21
|
+
/** Get the windowed mean ms-per-token for a model, or null if no samples. */
|
|
22
|
+
getMeanMsPerToken(model) {
|
|
23
|
+
const samples = this.samples.get(model);
|
|
24
|
+
if (!samples || samples.length === 0)
|
|
25
|
+
return null;
|
|
26
|
+
const sum = samples.reduce((acc, s) => acc + s.msPerToken, 0);
|
|
27
|
+
return sum / samples.length;
|
|
28
|
+
}
|
|
29
|
+
/**
|
|
30
|
+
* Get estimated output tokens per second for a model based on tracked latency.
|
|
31
|
+
* Returns null if no samples exist or if the number of samples is below the minimum required.
|
|
32
|
+
*/
|
|
33
|
+
getTokensPerSecond(model, minSamples = 1) {
|
|
34
|
+
const sampleCount = this.getSampleCount(model);
|
|
35
|
+
if (sampleCount < minSamples)
|
|
36
|
+
return null;
|
|
37
|
+
const msPerToken = this.getMeanMsPerToken(model);
|
|
38
|
+
if (msPerToken === null || msPerToken === 0)
|
|
39
|
+
return null;
|
|
40
|
+
return 1000 / msPerToken;
|
|
41
|
+
}
|
|
42
|
+
/** Get the number of samples recorded for a model. */
|
|
43
|
+
getSampleCount(model) {
|
|
44
|
+
return this.samples.get(model)?.length ?? 0;
|
|
45
|
+
}
|
|
46
|
+
/** Get all samples for a model (defensive copy). */
|
|
47
|
+
getSamples(model) {
|
|
48
|
+
return [...(this.samples.get(model) ?? [])];
|
|
49
|
+
}
|
|
50
|
+
/** Clear all samples for a model. */
|
|
51
|
+
clear(model) {
|
|
52
|
+
if (model) {
|
|
53
|
+
this.samples.delete(model);
|
|
54
|
+
}
|
|
55
|
+
else {
|
|
56
|
+
this.samples.clear();
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
/** Update the window size. Existing samples beyond the new size are trimmed. */
|
|
60
|
+
setWindowSize(size) {
|
|
61
|
+
this.windowSize = size;
|
|
62
|
+
for (const [model, samples] of this.samples) {
|
|
63
|
+
if (samples.length > size) {
|
|
64
|
+
samples.splice(0, samples.length - size);
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
getWindowSize() {
|
|
69
|
+
return this.windowSize;
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
/** Global singleton latency tracker. */
|
|
73
|
+
export const latencyTracker = new LatencyTracker();
|
package/dist/model.d.ts
CHANGED
|
@@ -6,14 +6,14 @@ export declare class Model {
|
|
|
6
6
|
private resolvedModel;
|
|
7
7
|
private provider?;
|
|
8
8
|
constructor(model: ModelName | ModelConfig | ModelNameAndProvider, provider?: Provider);
|
|
9
|
-
getModel():
|
|
9
|
+
getModel(): string | ModelNameAndProvider | {
|
|
10
10
|
optimizeFor: ("reasoning" | "speed" | "cost" | "large-context")[];
|
|
11
11
|
providers: ("local" | "ollama" | "openai" | "openai-responses" | "anthropic" | "google" | "replicate" | "modal")[];
|
|
12
12
|
limit?: {
|
|
13
13
|
cost?: number | undefined;
|
|
14
14
|
} | undefined;
|
|
15
15
|
};
|
|
16
|
-
getResolvedModel():
|
|
16
|
+
getResolvedModel(): string;
|
|
17
17
|
getProvider(): Provider | undefined;
|
|
18
18
|
setProvider(): Provider | undefined;
|
|
19
19
|
resolveModel(models?: readonly TextModel[]): ModelName;
|
|
@@ -31,5 +31,6 @@ export declare class Model {
|
|
|
31
31
|
currency: string;
|
|
32
32
|
} | null;
|
|
33
33
|
toString(): string;
|
|
34
|
+
toJSON(): ModelName | ModelNameAndProvider;
|
|
34
35
|
static create(model: ModelLike, provider?: Provider): Model;
|
|
35
36
|
}
|
package/dist/model.js
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { latencyTracker } from "./latencyTracker.js";
|
|
1
2
|
import { getModel, isTextModel, textModels, registeredTextModels, } from "./models.js";
|
|
2
3
|
import { SmolError } from "./smolError.js";
|
|
3
4
|
import { ModelConfigSchema, ModelNameAndProviderSchema, ModelNameSchema, } from "./strategies/types.js";
|
|
@@ -115,7 +116,10 @@ export class Model {
|
|
|
115
116
|
case "cost":
|
|
116
117
|
return (m.inputTokenCost ?? 0) + (m.outputTokenCost ?? 0);
|
|
117
118
|
case "speed":
|
|
118
|
-
|
|
119
|
+
// Prefer tracked latency over static estimates
|
|
120
|
+
return (latencyTracker.getTokensPerSecond(m.modelName) ??
|
|
121
|
+
m.outputTokensPerSecond ??
|
|
122
|
+
0);
|
|
119
123
|
case "reasoning":
|
|
120
124
|
return (m.inputTokenCost ?? 0) + (m.outputTokenCost ?? 0);
|
|
121
125
|
case "large-context":
|
|
@@ -149,6 +153,12 @@ export class Model {
|
|
|
149
153
|
toString() {
|
|
150
154
|
return `Model(${JSON.stringify(this.model)})`;
|
|
151
155
|
}
|
|
156
|
+
toJSON() {
|
|
157
|
+
if (ModelNameAndProviderSchema.safeParse(this.model).success) {
|
|
158
|
+
return this.model;
|
|
159
|
+
}
|
|
160
|
+
return this.getResolvedModel();
|
|
161
|
+
}
|
|
152
162
|
static create(model, provider) {
|
|
153
163
|
if (model instanceof Model) {
|
|
154
164
|
return model;
|
package/dist/models.d.ts
CHANGED
|
@@ -688,7 +688,7 @@ export type TextModelName = (typeof textModels)[number]["modelName"];
|
|
|
688
688
|
export type ImageModelName = (typeof imageModels)[number]["modelName"];
|
|
689
689
|
export type SpeechToTextModelName = (typeof speechToTextModels)[number]["modelName"];
|
|
690
690
|
export type EmbeddingsModelName = (typeof embeddingsModels)[number]["modelName"];
|
|
691
|
-
export type ModelName =
|
|
691
|
+
export type ModelName = string;
|
|
692
692
|
export declare const registeredTextModels: TextModel[];
|
|
693
693
|
export declare function registerTextModel(model: Omit<TextModel, "type"> & {
|
|
694
694
|
type?: "text";
|
package/dist/smolError.d.ts
CHANGED
|
@@ -7,3 +7,9 @@ export declare class SmolStructuredOutputError extends SmolError {
|
|
|
7
7
|
export declare class SmolTimeoutError extends SmolError {
|
|
8
8
|
constructor(message: string);
|
|
9
9
|
}
|
|
10
|
+
export declare class SmolContentPolicyError extends SmolError {
|
|
11
|
+
constructor(message: string);
|
|
12
|
+
}
|
|
13
|
+
export declare class SmolContextWindowExceededError extends SmolError {
|
|
14
|
+
constructor(message: string);
|
|
15
|
+
}
|
package/dist/smolError.js
CHANGED
|
@@ -16,3 +16,15 @@ export class SmolTimeoutError extends SmolError {
|
|
|
16
16
|
this.name = "SmolTimeoutError";
|
|
17
17
|
}
|
|
18
18
|
}
|
|
19
|
+
export class SmolContentPolicyError extends SmolError {
|
|
20
|
+
constructor(message) {
|
|
21
|
+
super(message);
|
|
22
|
+
this.name = "SmolContentPolicyError";
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
export class SmolContextWindowExceededError extends SmolError {
|
|
26
|
+
constructor(message) {
|
|
27
|
+
super(message);
|
|
28
|
+
this.name = "SmolContextWindowExceededError";
|
|
29
|
+
}
|
|
30
|
+
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { SmolStructuredOutputError, SmolTimeoutError } from "../smolError.js";
|
|
1
|
+
import { SmolContentPolicyError, SmolContextWindowExceededError, SmolStructuredOutputError, SmolTimeoutError, } from "../smolError.js";
|
|
2
2
|
import { success, } from "../types.js";
|
|
3
3
|
import { BaseStrategy } from "./baseStrategy.js";
|
|
4
4
|
import { IDStrategy } from "./idStrategy.js";
|
|
@@ -59,6 +59,28 @@ export class FallbackStrategy extends BaseStrategy {
|
|
|
59
59
|
});
|
|
60
60
|
}
|
|
61
61
|
}
|
|
62
|
+
else if (error instanceof SmolContentPolicyError) {
|
|
63
|
+
if (fallbackStrategies.contentPolicyViolation &&
|
|
64
|
+
fallbackStrategies.contentPolicyViolation.length > 0) {
|
|
65
|
+
this.statelogClient?.debug("FallbackStrategy: falling back due to content policy violation", {
|
|
66
|
+
failedStrategy: strategy.toString(),
|
|
67
|
+
});
|
|
68
|
+
return this._textWithFallbacks(config, fromJSON(fallbackStrategies.contentPolicyViolation[0]), {
|
|
69
|
+
contentPolicyViolation: fallbackStrategies.contentPolicyViolation.slice(1),
|
|
70
|
+
});
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
else if (error instanceof SmolContextWindowExceededError) {
|
|
74
|
+
if (fallbackStrategies.contextWindowExceeded &&
|
|
75
|
+
fallbackStrategies.contextWindowExceeded.length > 0) {
|
|
76
|
+
this.statelogClient?.debug("FallbackStrategy: falling back due to context window exceeded", {
|
|
77
|
+
failedStrategy: strategy.toString(),
|
|
78
|
+
});
|
|
79
|
+
return this._textWithFallbacks(config, fromJSON(fallbackStrategies.contextWindowExceeded[0]), {
|
|
80
|
+
contextWindowExceeded: fallbackStrategies.contextWindowExceeded.slice(1),
|
|
81
|
+
});
|
|
82
|
+
}
|
|
83
|
+
}
|
|
62
84
|
if (fallbackStrategies.error && fallbackStrategies.error.length > 0) {
|
|
63
85
|
this.statelogClient?.debug("FallbackStrategy: falling back due to error", {
|
|
64
86
|
failedStrategy: strategy.toString(),
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import { Model } from "../model.js";
|
|
2
|
+
import { PromptResult, Result, SmolPromptConfig } from "../types.js";
|
|
3
|
+
import { BaseStrategy } from "./baseStrategy.js";
|
|
4
|
+
import { ModelNameAndProvider, StrategyJSON } from "./types.js";
|
|
5
|
+
/**
 * Epsilon-greedy strategy that routes each call to the candidate model with
 * the best tracked throughput, exploring a random model `epsilon` of the time.
 */
export declare class FastestStrategy extends BaseStrategy {
    /** Candidate models: names, {model, provider} pairs, or Model instances. */
    models: (string | ModelNameAndProvider | Model)[];
    /** Exploration rate in [0, 1]: fraction of calls that pick a random model. */
    epsilon: number;
    constructor(models: (string | ModelNameAndProvider | Model)[], epsilon?: number);
    toString(): string;
    toShortString(): string;
    _text(config: SmolPromptConfig): Promise<Result<PromptResult>>;
    /** Pick the model with the highest tracked tokens/sec; null when no model has data. */
    private pickFastest;
    /** Get tokens/sec for a model from the latency tracker; falsy when there are too few samples. */
    private getSpeed;
    toJSON(): StrategyJSON;
    static fromJSON(json: unknown): FastestStrategy;
}
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import { latencyTracker } from "../latencyTracker.js";
|
|
2
|
+
import { getLogger } from "../logger.js";
|
|
3
|
+
import { Model } from "../model.js";
|
|
4
|
+
import { BaseStrategy } from "./baseStrategy.js";
|
|
5
|
+
import { IDStrategy } from "./idStrategy.js";
|
|
6
|
+
import { FastestStrategyJSONSchema, } from "./types.js";
|
|
7
|
+
// What fraction of the time to explore (pick a random model instead of the
// fastest) - this prevents us from getting stuck on a model that was fast in
// the past but has since become slow.
const DEFAULT_EPSILON = 0.1;
/**
 * Epsilon-greedy strategy: routes each request to the candidate model with the
 * best tracked throughput (tokens/sec), exploring a random model `epsilon` of
 * the time. When no model has enough latency samples yet, picks randomly.
 */
export class FastestStrategy extends BaseStrategy {
    models;
    epsilon;
    /**
     * @param models  Candidate models (names, {model, provider} pairs, or Model instances).
     * @param epsilon Exploration rate in [0, 1]; defaults to DEFAULT_EPSILON.
     */
    constructor(models, epsilon = DEFAULT_EPSILON) {
        super();
        this.models = models;
        this.epsilon = epsilon;
    }
    toString() {
        return `FastestStrategy([${this.models.map((s) => s.toString()).join(", ")}])`;
    }
    toShortString() {
        return `fastest([${this.models.map((s) => s.toString()).join(", ")}])`;
    }
    async _text(config) {
        const resolved = this.models.map((model) => Model.create(model));
        const logger = getLogger(config.logLevel);
        // Uniform random pick over the resolved candidates (used by both the
        // explore branch and the no-data fallback).
        const pickRandom = () => resolved[Math.floor(Math.random() * resolved.length)];
        let chosen = null;
        if (Math.random() < this.epsilon) {
            // Explore: pick a random model
            chosen = pickRandom();
            logger.debug("fastest strategy - exploring random model", {
                model: chosen.getResolvedModel(),
            });
            this.statelogClient?.debug("fastest strategy - picking random model", {
                model: chosen.getResolvedModel(),
            });
        }
        else {
            // Exploit: pick the fastest model by tracked latency
            chosen = this.pickFastest(resolved);
            if (chosen) {
                logger.debug("fastest strategy - exploiting fastest model", {
                    model: chosen.getResolvedModel(),
                });
                this.statelogClient?.debug("fastest strategy - using fastest model", {
                    model: chosen.getResolvedModel(),
                });
            }
            else {
                // we don't have latency data for any model, so just pick randomly
                chosen = pickRandom();
                logger.debug("fastest strategy - no latency data, picking random model", {
                    models: resolved.map((m) => m.getResolvedModel()),
                    chosen: chosen.getResolvedModel(),
                });
                // Fix: log the resolved model name here (previously the raw
                // Model object was logged), matching the logger.debug payload
                // above and every other log site in this class.
                this.statelogClient?.debug("fastest strategy - no latency data, picking random model", {
                    models: resolved.map((m) => m.getResolvedModel()),
                    chosen: chosen.getResolvedModel(),
                });
            }
        }
        // Delegate the actual request to an IDStrategy for the chosen model.
        const strategy = new IDStrategy(chosen);
        return strategy.text(config);
    }
    /** Return the model with the highest tracked tokens/sec, or null if no model has data. */
    pickFastest(models) {
        let best = null;
        let bestSpeed = 0;
        for (const model of models) {
            const speed = this.getSpeed(model);
            if (speed && speed > bestSpeed) {
                bestSpeed = speed;
                best = model;
            }
        }
        return best;
    }
    /**
     * Tokens/sec for a model from the latency tracker, requiring at least
     * MIN_SAMPLES observations; falsy when there is not enough data.
     */
    getSpeed(model) {
        const MIN_SAMPLES = 3;
        const tracked = latencyTracker.getTokensPerSecond(model.getResolvedModel(), MIN_SAMPLES);
        return tracked;
    }
    toJSON() {
        // NOTE(review): epsilon is not serialized (FastestStrategyJSONSchema
        // has no field for it), so a toJSON/fromJSON round trip resets it to
        // DEFAULT_EPSILON — confirm this is intended.
        return {
            type: "fastest",
            params: {
                models: this.models.map((s) => (s instanceof Model ? s.toJSON() : s)),
            },
        };
    }
    static fromJSON(json) {
        const parsed = FastestStrategyJSONSchema.parse(json);
        const models = parsed.params.models;
        return new FastestStrategy(models);
    }
}
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
+
import { Model } from "../model.js";
import { ModelLike, ModelParam } from "../types.js";
import { FallbackStrategyConfig, ModelNameAndProvider, Strategy, StrategyJSON } from "./types.js";
export * from "./baseStrategy.js";
export * from "./fallbackStrategy.js";
export * from "./idStrategy.js";
export * from "./raceStrategy.js";
export * from "./randomStrategy.js";
export * from "./types.js";
/** Build a RaceStrategy over the given strategies/models. */
export declare function race(...strategies: ModelParam[]): Strategy;
/** Build a RandomStrategy that picks one of the given strategies/models at random per call. */
export declare function random(...strategies: ModelParam[]): Strategy;
/** Build a FastestStrategy over the given models; epsilon is the exploration rate. */
export declare function fastest(models: (string | ModelNameAndProvider | Model)[], epsilon?: number): Strategy;
/** Wrap a single model in an IDStrategy. */
export declare function id(model: ModelLike): Strategy;
/** Build a FallbackStrategy around a primary strategy with per-reason fallbacks. */
export declare function fallback(primaryStrategy: ModelParam, config: FallbackStrategyConfig | string | string[]): Strategy;
/** Reconstruct a Strategy from its serialized StrategyJSON form. */
export declare function fromJSON(json: StrategyJSON): Strategy;
|
package/dist/strategies/index.js
CHANGED
|
@@ -1,15 +1,24 @@
|
|
|
1
1
|
import { FallbackStrategy } from "./fallbackStrategy.js";
|
|
2
|
+
import { FastestStrategy } from "./fastestStrategy.js";
|
|
2
3
|
import { IDStrategy } from "./idStrategy.js";
|
|
3
4
|
import { RaceStrategy } from "./raceStrategy.js";
|
|
4
|
-
import {
|
|
5
|
+
import { RandomStrategy } from "./randomStrategy.js";
|
|
6
|
+
import { FallbackStrategyJSONSchema, FastestStrategyJSONSchema, IDStrategyJSONSchema, ModelNameAndProviderSchema, RaceStrategyJSONSchema, RandomStrategyJSONSchema, } from "./types.js";
|
|
5
7
|
export * from "./baseStrategy.js";
|
|
6
8
|
export * from "./fallbackStrategy.js";
|
|
7
9
|
export * from "./idStrategy.js";
|
|
8
10
|
export * from "./raceStrategy.js";
|
|
11
|
+
export * from "./randomStrategy.js";
|
|
9
12
|
export * from "./types.js";
|
|
10
13
|
/** Build a RaceStrategy over the given strategies/models. */
export function race(...strategies) {
    const strategy = new RaceStrategy(strategies);
    return strategy;
}
|
|
16
|
+
/** Build a RandomStrategy that picks one of the given strategies/models at random per call. */
export function random(...strategies) {
    const strategy = new RandomStrategy(...strategies);
    return strategy;
}
|
|
19
|
+
/** Build a FastestStrategy over the given models; epsilon is the exploration rate. */
export function fastest(models, epsilon) {
    const strategy = new FastestStrategy(models, epsilon);
    return strategy;
}
|
|
13
22
|
/** Wrap a single model in an IDStrategy. */
export function id(model) {
    const strategy = new IDStrategy(model);
    return strategy;
}
|
|
@@ -39,6 +48,12 @@ export function fromJSON(json) {
|
|
|
39
48
|
else if (FallbackStrategyJSONSchema.safeParse(json).success) {
|
|
40
49
|
return FallbackStrategy.fromJSON(json);
|
|
41
50
|
}
|
|
51
|
+
else if (RandomStrategyJSONSchema.safeParse(json).success) {
|
|
52
|
+
return RandomStrategy.fromJSON(json);
|
|
53
|
+
}
|
|
54
|
+
else if (FastestStrategyJSONSchema.safeParse(json).success) {
|
|
55
|
+
return FastestStrategy.fromJSON(json);
|
|
56
|
+
}
|
|
42
57
|
else if (typeof json === "string") {
|
|
43
58
|
return id(json);
|
|
44
59
|
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import { ModelParam, PromptResult, Result, SmolPromptConfig } from "../types.js";
|
|
2
|
+
import { BaseStrategy } from "./baseStrategy.js";
|
|
3
|
+
import { Strategy, StrategyJSON } from "./types.js";
|
|
4
|
+
/** Strategy that forwards each call to one of its child strategies chosen uniformly at random. */
export declare class RandomStrategy extends BaseStrategy {
    /** Normalized children; plain model params passed to the constructor are wrapped in an IDStrategy. */
    strategies: Strategy[];
    constructor(...strategies: (Strategy | ModelParam)[]);
    toString(): string;
    toShortString(): string;
    _text(config: SmolPromptConfig): Promise<Result<PromptResult>>;
    toJSON(): StrategyJSON;
    static fromJSON(json: unknown): RandomStrategy;
}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import { BaseStrategy } from "./baseStrategy.js";
|
|
2
|
+
import { IDStrategy } from "./idStrategy.js";
|
|
3
|
+
import { fromJSON } from "./index.js";
|
|
4
|
+
import { RandomStrategyJSONSchema } from "./types.js";
|
|
5
|
+
/**
 * Strategy that forwards each request to one of its child strategies chosen
 * uniformly at random.
 */
export class RandomStrategy extends BaseStrategy {
    strategies;
    /**
     * @param strategies Child strategies; anything that is not already a
     *                   strategy is wrapped in an IDStrategy.
     */
    constructor(...strategies) {
        super();
        this.strategies = strategies.map((s) => s instanceof BaseStrategy ? s : new IDStrategy(s));
    }
    toString() {
        return `RandomStrategy([${this.strategies.map((s) => s.toString()).join(", ")}])`;
    }
    toShortString() {
        return `random([${this.strategies.map((s) => s.toString()).join(", ")}])`;
    }
    async _text(config) {
        // Guard: an empty list previously crashed below with an opaque
        // TypeError (strategies[randomIndex] is undefined); fail with a
        // clear message instead.
        if (this.strategies.length === 0) {
            throw new Error("RandomStrategy requires at least one strategy");
        }
        const randomIndex = Math.floor(Math.random() * this.strategies.length);
        const strategy = this.strategies[randomIndex];
        this.statelogClient?.debug("random strategy chosen", {
            strategy,
        });
        const result = await strategy.text(config);
        return result;
    }
    toJSON() {
        return {
            type: "random",
            params: {
                strategies: this.strategies.map((s) => s.toJSON()),
            },
        };
    }
    static fromJSON(json) {
        const parsed = RandomStrategyJSONSchema.parse(json);
        const strategies = parsed.params.strategies.map((s) => fromJSON(s));
        return new RandomStrategy(...strategies);
    }
}
|
|
@@ -14,15 +14,19 @@ export declare const FallbackReasonSchema: z.ZodEnum<{
|
|
|
14
14
|
error: "error";
|
|
15
15
|
timeout: "timeout";
|
|
16
16
|
structuredOutputFailure: "structuredOutputFailure";
|
|
17
|
+
contentPolicyViolation: "contentPolicyViolation";
|
|
18
|
+
contextWindowExceeded: "contextWindowExceeded";
|
|
17
19
|
}>;
|
|
18
20
|
export declare const FallbackStrategyConfigSchema: z.ZodLazy<z.ZodRecord<z.ZodEnum<{
|
|
19
21
|
error: "error";
|
|
20
22
|
timeout: "timeout";
|
|
21
23
|
structuredOutputFailure: "structuredOutputFailure";
|
|
24
|
+
contentPolicyViolation: "contentPolicyViolation";
|
|
25
|
+
contextWindowExceeded: "contextWindowExceeded";
|
|
22
26
|
}> & z.core.$partial, z.ZodArray<z.ZodType<StrategyJSON, unknown, z.core.$ZodTypeInternals<StrategyJSON, unknown>>>>>;
|
|
23
27
|
export type FallbackReason = z.infer<typeof FallbackReasonSchema>;
|
|
24
28
|
export type FallbackStrategyConfig = z.infer<typeof FallbackStrategyConfigSchema>;
|
|
25
|
-
export type StrategyJSON = string | ModelNameAndProvider | IDStrategyJSON | RaceStrategyJSON | FallbackStrategyJSON;
|
|
29
|
+
export type StrategyJSON = string | ModelNameAndProvider | IDStrategyJSON | RaceStrategyJSON | FallbackStrategyJSON | RandomStrategyJSON | FastestStrategyJSON;
|
|
26
30
|
export declare const IDStrategyJSONSchema: z.ZodObject<{
|
|
27
31
|
type: z.ZodLiteral<"id">;
|
|
28
32
|
params: z.ZodObject<{
|
|
@@ -46,6 +50,20 @@ export type FallbackStrategyJSON = {
|
|
|
46
50
|
config: FallbackStrategyConfig;
|
|
47
51
|
};
|
|
48
52
|
};
|
|
53
|
+
/** Runtime validator for RandomStrategyJSON. */
export declare const RandomStrategyJSONSchema: z.ZodType<RandomStrategyJSON>;
/** Serialized form of RandomStrategy: a list of child strategy JSONs. */
export type RandomStrategyJSON = {
    type: "random";
    params: {
        strategies: StrategyJSON[];
    };
};
/** Runtime validator for FastestStrategyJSON. */
export declare const FastestStrategyJSONSchema: z.ZodType<FastestStrategyJSON>;
/**
 * Serialized form of FastestStrategy. NOTE(review): carries only `models` —
 * epsilon has no field here, so it is not round-tripped.
 */
export type FastestStrategyJSON = {
    type: "fastest";
    params: {
        models: (ModelNameAndProvider | string)[];
    };
};
|
|
49
67
|
export type ModelNameAndProvider = {
|
|
50
68
|
model: string;
|
|
51
69
|
provider: string;
|
package/dist/strategies/types.js
CHANGED
|
@@ -4,6 +4,8 @@ export const FallbackReasonSchema = z.enum([
|
|
|
4
4
|
"error",
|
|
5
5
|
"timeout",
|
|
6
6
|
"structuredOutputFailure",
|
|
7
|
+
"contentPolicyViolation",
|
|
8
|
+
"contextWindowExceeded",
|
|
7
9
|
]);
|
|
8
10
|
export const FallbackStrategyConfigSchema = z.lazy(() => z.partialRecord(FallbackReasonSchema, z.array(StrategyJSONSchema)));
|
|
9
11
|
export const IDStrategyJSONSchema = z.object({
|
|
@@ -21,6 +23,16 @@ export const FallbackStrategyJSONSchema = z.lazy(() => z.object({
|
|
|
21
23
|
config: FallbackStrategyConfigSchema,
|
|
22
24
|
}),
|
|
23
25
|
}));
|
|
26
|
+
// Validates the serialized form of RandomStrategy. Wrapped in z.lazy because
// StrategyJSONSchema is declared later in this module (the two schemas are
// mutually recursive).
export const RandomStrategyJSONSchema = z.lazy(() => z.object({
    type: z.literal("random"),
    params: z.object({ strategies: z.array(StrategyJSONSchema) }),
}));
// Validates the serialized form of FastestStrategy. z.lazy defers the
// reference to ModelNameAndProviderSchema, which is declared further down
// in this file.
export const FastestStrategyJSONSchema = z.lazy(() => z.object({
    type: z.literal("fastest"),
    params: z.object({
        models: z.array(z.union([ModelNameAndProviderSchema, z.string()])),
    }),
}));
|
|
24
36
|
export const ModelNameAndProviderSchema = z.object({
|
|
25
37
|
model: z.string(),
|
|
26
38
|
provider: z.string(),
|
|
@@ -49,6 +61,8 @@ export const StrategyJSONSchema = z.lazy(() => z.union([
|
|
|
49
61
|
IDStrategyJSONSchema,
|
|
50
62
|
RaceStrategyJSONSchema,
|
|
51
63
|
FallbackStrategyJSONSchema,
|
|
64
|
+
RandomStrategyJSONSchema,
|
|
65
|
+
FastestStrategyJSONSchema,
|
|
52
66
|
]));
|
|
53
67
|
// Helper to detect if a value is a StrategyJSON object (not a plain string)
|
|
54
68
|
export function isStrategy(value) {
|