smoltalk 0.0.55 → 0.0.56

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,7 @@ export declare class SmolAnthropic extends BaseClient implements SmolClient {
12
12
  getModel(): ModelName;
13
13
  private calculateUsageAndCost;
14
14
  private buildRequest;
15
+ private rethrowAsSmolError;
15
16
  _textSync(config: PromptConfig): Promise<Result<PromptResult>>;
16
17
  _textStream(config: PromptConfig): AsyncGenerator<StreamChunk>;
17
18
  }
@@ -4,6 +4,7 @@ import { SystemMessage, DeveloperMessage } from "../classes/message/index.js";
4
4
  import { getLogger } from "../logger.js";
5
5
  import { success, } from "../types.js";
6
6
  import { zodToAnthropicTool } from "../util/tool.js";
7
+ import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
7
8
  import { BaseClient } from "./baseClient.js";
8
9
  import { Model } from "../model.js";
9
10
  const DEFAULT_MAX_TOKENS = 4096;
@@ -82,6 +83,24 @@ export class SmolAnthropic extends BaseClient {
82
83
  : undefined;
83
84
  return { system, messages: anthropicMessages, tools, thinking };
84
85
  }
86
+ rethrowAsSmolError(error) {
87
+ if (error instanceof Anthropic.APIError) {
88
+ const msg = error.message.toLowerCase();
89
+ if (msg.includes("prompt is too long") ||
90
+ msg.includes("context length") ||
91
+ msg.includes("context window") ||
92
+ msg.includes("too many tokens")) {
93
+ throw new SmolContextWindowExceededError(error.message);
94
+ }
95
+ if (msg.includes("content policy") ||
96
+ msg.includes("usage policies") ||
97
+ msg.includes("content filtering") ||
98
+ msg.includes("violates our")) {
99
+ throw new SmolContentPolicyError(error.message);
100
+ }
101
+ }
102
+ throw error;
103
+ }
85
104
  async _textSync(config) {
86
105
  const { system, messages, tools, thinking } = this.buildRequest(config);
87
106
  let debugData = {
@@ -95,19 +114,25 @@ export class SmolAnthropic extends BaseClient {
95
114
  this.logger.debug("Sending request to Anthropic:", debugData);
96
115
  this.statelogClient?.promptRequest(debugData);
97
116
  const signal = this.getAbortSignal(config);
98
- const response = await this.client.messages.create({
99
- model: this.getModel(),
100
- max_tokens: config.maxTokens ?? DEFAULT_MAX_TOKENS,
101
- messages,
102
- ...(system && { system }),
103
- ...(tools && { tools }),
104
- ...(thinking && { thinking }),
105
- ...(config.temperature !== undefined && {
106
- temperature: config.temperature,
107
- }),
108
- ...(config.rawAttributes || {}),
109
- stream: false,
110
- }, { ...(signal && { signal }) });
117
+ let response;
118
+ try {
119
+ response = await this.client.messages.create({
120
+ model: this.getModel(),
121
+ max_tokens: config.maxTokens ?? DEFAULT_MAX_TOKENS,
122
+ messages,
123
+ ...(system && { system }),
124
+ ...(tools && { tools }),
125
+ ...(thinking && { thinking }),
126
+ ...(config.temperature !== undefined && {
127
+ temperature: config.temperature,
128
+ }),
129
+ ...(config.rawAttributes || {}),
130
+ stream: false,
131
+ }, { ...(signal && { signal }) });
132
+ }
133
+ catch (error) {
134
+ this.rethrowAsSmolError(error);
135
+ }
111
136
  this.logger.debug("Response from Anthropic:", response);
112
137
  this.statelogClient?.promptResponse(response);
113
138
  let output = null;
@@ -148,19 +173,25 @@ export class SmolAnthropic extends BaseClient {
148
173
  this.logger.debug("Sending streaming request to Anthropic:", streamDebugData);
149
174
  this.statelogClient?.promptRequest(streamDebugData);
150
175
  const signal = this.getAbortSignal(config);
151
- const stream = await this.client.messages.create({
152
- model: this.model,
153
- max_tokens: config.maxTokens ?? DEFAULT_MAX_TOKENS,
154
- messages,
155
- ...(system && { system }),
156
- ...(tools && { tools }),
157
- ...(thinking && { thinking }),
158
- ...(config.temperature !== undefined && {
159
- temperature: config.temperature,
160
- }),
161
- ...(config.rawAttributes || {}),
162
- stream: true,
163
- }, { ...(signal && { signal }) });
176
+ let stream;
177
+ try {
178
+ stream = await this.client.messages.create({
179
+ model: this.model,
180
+ max_tokens: config.maxTokens ?? DEFAULT_MAX_TOKENS,
181
+ messages,
182
+ ...(system && { system }),
183
+ ...(tools && { tools }),
184
+ ...(thinking && { thinking }),
185
+ ...(config.temperature !== undefined && {
186
+ temperature: config.temperature,
187
+ }),
188
+ ...(config.rawAttributes || {}),
189
+ stream: true,
190
+ }, { ...(signal && { signal }) });
191
+ }
192
+ catch (error) {
193
+ this.rethrowAsSmolError(error);
194
+ }
164
195
  let content = "";
165
196
  // Track tool blocks by index: index -> { id, name, arguments (partial JSON) }
166
197
  const toolBlocks = new Map();
@@ -24,6 +24,7 @@ export declare class BaseClient implements SmolClient {
24
24
  continue: boolean;
25
25
  newPromptConfig: PromptConfig;
26
26
  };
27
+ private recordLatency;
27
28
  extractResponse(promptConfig: PromptConfig, rawValue: any, schema: any, depth?: number): any;
28
29
  textWithRetry(promptConfig: PromptConfig, retries: number): Promise<Result<PromptResult>>;
29
30
  _textSync(promptConfig: PromptConfig): Promise<Result<PromptResult>>;
@@ -1,4 +1,5 @@
1
1
  import { AssistantMessage, userMessage, assistantMessage, } from "../classes/message/index.js";
2
+ import { latencyTracker } from "../latencyTracker.js";
2
3
  import { getLogger } from "../logger.js";
3
4
  import { getModel, isTextModel } from "../models.js";
4
5
  import { SmolStructuredOutputError } from "../smolError.js";
@@ -146,9 +147,11 @@ export class BaseClient {
146
147
  value: { output: null, toolCalls: [], model: this.config.model },
147
148
  };
148
149
  }
150
+ const startTime = performance.now();
149
151
  try {
150
152
  const result = await this.textWithRetry(newPromptConfig, newPromptConfig.responseFormatOptions?.numRetries ||
151
153
  DEFAULT_NUM_RETRIES);
154
+ this.recordLatency(startTime, result);
152
155
  return result;
153
156
  }
154
157
  catch (err) {
@@ -210,6 +213,15 @@ export class BaseClient {
210
213
  }
211
214
  return { continue: true, newPromptConfig: promptConfig };
212
215
  }
216
+ recordLatency(startTime, result) {
217
+ if (!result.success)
218
+ return;
219
+ const outputTokens = result.value.usage?.outputTokens;
220
+ if (!outputTokens || outputTokens <= 0)
221
+ return;
222
+ const elapsedMs = performance.now() - startTime;
223
+ latencyTracker.record(this.config.model, elapsedMs, outputTokens);
224
+ }
213
225
  extractResponse(promptConfig, rawValue, schema, depth = 0) {
214
226
  const MAX_DEPTH = 5;
215
227
  if (depth > MAX_DEPTH) {
@@ -374,8 +386,18 @@ export class BaseClient {
374
386
  };
375
387
  return;
376
388
  }
389
+ const startTime = performance.now();
377
390
  try {
378
- yield* this._textStream(newPromptConfig);
391
+ for await (const chunk of this._textStream(newPromptConfig)) {
392
+ if (chunk.type === "done") {
393
+ const outputTokens = chunk.result.usage?.outputTokens;
394
+ if (outputTokens && outputTokens > 0) {
395
+ const elapsedMs = performance.now() - startTime;
396
+ latencyTracker.record(this.config.model, elapsedMs, outputTokens);
397
+ }
398
+ }
399
+ yield chunk;
400
+ }
379
401
  }
380
402
  catch (err) {
381
403
  if (this.isAbortError(err)) {
@@ -3,6 +3,7 @@ import { ToolCall } from "../classes/ToolCall.js";
3
3
  import { getLogger } from "../logger.js";
4
4
  import { addCosts, addTokenUsage, success, } from "../types.js";
5
5
  import { zodToGoogleTool } from "../util/tool.js";
6
+ import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
6
7
  import { sanitizeAttributes } from "../util.js";
7
8
  import { BaseClient } from "./baseClient.js";
8
9
  import { Model } from "../model.js";
@@ -171,10 +172,28 @@ export class SmolGoogle extends BaseClient {
171
172
  async __textSync(request) {
172
173
  this.logger.debug("Sending request to Google Gemini:", JSON.stringify(request, null, 2));
173
174
  this.statelogClient?.promptRequest(request);
174
- // Send the prompt as the latest message
175
- const result = await this.client.models.generateContent(request);
175
+ let result;
176
+ try {
177
+ result = await this.client.models.generateContent(request);
178
+ }
179
+ catch (error) {
180
+ const msg = (error.message || "").toLowerCase();
181
+ if (msg.includes("token") &&
182
+ (msg.includes("exceed") ||
183
+ msg.includes("too long") ||
184
+ msg.includes("limit"))) {
185
+ throw new SmolContextWindowExceededError(error.message);
186
+ }
187
+ throw error;
188
+ }
176
189
  this.logger.debug("Response from Google Gemini:", JSON.stringify(result, null, 2));
177
190
  this.statelogClient?.promptResponse(result);
191
+ for (const candidate of result.candidates || []) {
192
+ const finishReason = candidate.finishReason;
193
+ if (finishReason === "SAFETY" || finishReason === "PROHIBITED_CONTENT") {
194
+ throw new SmolContentPolicyError(`Content blocked by Google safety filter: ${finishReason}`);
195
+ }
196
+ }
178
197
  const toolCalls = [];
179
198
  const thinkingBlocks = [];
180
199
  let textContent = "";
@@ -230,7 +249,20 @@ export class SmolGoogle extends BaseClient {
230
249
  }
231
250
  this.logger.debug("Sending streaming request to Google Gemini:", JSON.stringify(request, null, 2));
232
251
  this.statelogClient?.promptRequest(request);
233
- const stream = await this.client.models.generateContentStream(request);
252
+ let stream;
253
+ try {
254
+ stream = await this.client.models.generateContentStream(request);
255
+ }
256
+ catch (error) {
257
+ const msg = (error.message || "").toLowerCase();
258
+ if (msg.includes("token") &&
259
+ (msg.includes("exceed") ||
260
+ msg.includes("too long") ||
261
+ msg.includes("limit"))) {
262
+ throw new SmolContextWindowExceededError(error.message);
263
+ }
264
+ throw error;
265
+ }
234
266
  let content = "";
235
267
  const toolCallsMap = new Map();
236
268
  const thinkingBlocks = [];
@@ -5,6 +5,7 @@ import { success, } from "../types.js";
5
5
  import { zodToGoogleTool } from "../util/tool.js";
6
6
  import { sanitizeAttributes } from "../util.js";
7
7
  import { BaseClient } from "./baseClient.js";
8
+ import { SmolContextWindowExceededError } from "../smolError.js";
8
9
  import { Model } from "../model.js";
9
10
  export const DEFAULT_OLLAMA_HOST = "http://localhost:11434";
10
11
  export class SmolOllama extends BaseClient {
@@ -80,6 +81,13 @@ export class SmolOllama extends BaseClient {
80
81
  // @ts-ignore
81
82
  result = await this.client.chat(request);
82
83
  }
84
+ catch (error) {
85
+ const msg = (error.message || "").toLowerCase();
86
+ if (msg.includes("context length") || msg.includes("context window")) {
87
+ throw new SmolContextWindowExceededError(error.message);
88
+ }
89
+ throw error;
90
+ }
83
91
  finally {
84
92
  if (signal && abortHandler) {
85
93
  signal.removeEventListener("abort", abortHandler);
@@ -12,6 +12,7 @@ export declare class SmolOpenAi extends BaseClient implements SmolClient {
12
12
  getModel(): ModelName;
13
13
  private calculateUsageAndCost;
14
14
  private buildRequest;
15
+ private rethrowAsSmolError;
15
16
  _textSync(config: PromptConfig): Promise<Result<PromptResult>>;
16
17
  _textStream(config: PromptConfig): AsyncGenerator<StreamChunk>;
17
18
  }
@@ -4,6 +4,7 @@ import { ToolCall } from "../classes/ToolCall.js";
4
4
  import { isFunctionToolCall, sanitizeAttributes } from "../util.js";
5
5
  import { getLogger } from "../logger.js";
6
6
  import { BaseClient } from "./baseClient.js";
7
+ import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
7
8
  import { zodToOpenAITool } from "../util/tool.js";
8
9
  import { Model } from "../model.js";
9
10
  export class SmolOpenAi extends BaseClient {
@@ -68,17 +69,37 @@ export class SmolOpenAi extends BaseClient {
68
69
  }
69
70
  return request;
70
71
  }
72
+ rethrowAsSmolError(error) {
73
+ if (error instanceof OpenAI.APIError) {
74
+ if (error.code === "context_length_exceeded") {
75
+ throw new SmolContextWindowExceededError(error.message);
76
+ }
77
+ if (error.code === "content_policy_violation") {
78
+ throw new SmolContentPolicyError(error.message);
79
+ }
80
+ }
81
+ throw error;
82
+ }
71
83
  async _textSync(config) {
72
84
  const request = this.buildRequest(config);
73
85
  this.logger.debug("Sending request to OpenAI:", JSON.stringify(request, null, 2));
74
86
  this.statelogClient?.promptRequest(request);
75
87
  const signal = this.getAbortSignal(config);
76
- const completion = await this.client.chat.completions.create({
77
- ...request,
78
- stream: false,
79
- }, { ...(signal && { signal }) });
88
+ let completion;
89
+ try {
90
+ completion = await this.client.chat.completions.create({
91
+ ...request,
92
+ stream: false,
93
+ }, { ...(signal && { signal }) });
94
+ }
95
+ catch (error) {
96
+ this.rethrowAsSmolError(error);
97
+ }
80
98
  this.logger.debug("Response from OpenAI:", JSON.stringify(completion, null, 2));
81
99
  this.statelogClient?.promptResponse(completion);
100
+ if (completion.choices[0]?.finish_reason === "content_filter") {
101
+ throw new SmolContentPolicyError("Content blocked by OpenAI content filter");
102
+ }
82
103
  const message = completion.choices[0].message;
83
104
  const output = message.content;
84
105
  const _toolCalls = message.tool_calls;
@@ -109,11 +130,17 @@ export class SmolOpenAi extends BaseClient {
109
130
  this.logger.debug("Sending streaming request to OpenAI:", JSON.stringify(request, null, 2));
110
131
  this.statelogClient?.promptRequest(request);
111
132
  const signal = this.getAbortSignal(config);
112
- const completion = await this.client.chat.completions.create({
113
- ...request,
114
- stream: true,
115
- stream_options: { include_usage: true },
116
- }, { ...(signal && { signal }) });
133
+ let completion;
134
+ try {
135
+ completion = await this.client.chat.completions.create({
136
+ ...request,
137
+ stream: true,
138
+ stream_options: { include_usage: true },
139
+ }, { ...(signal && { signal }) });
140
+ }
141
+ catch (error) {
142
+ this.rethrowAsSmolError(error);
143
+ }
117
144
  let content = "";
118
145
  const toolCallsMap = new Map();
119
146
  let usage;
@@ -127,6 +154,9 @@ export class SmolOpenAi extends BaseClient {
127
154
  }
128
155
  if (!chunk.choices || chunk.choices.length === 0)
129
156
  continue;
157
+ if (chunk.choices[0]?.finish_reason === "content_filter") {
158
+ throw new SmolContentPolicyError("Content blocked by OpenAI content filter");
159
+ }
130
160
  const delta = chunk.choices[0]?.delta;
131
161
  if (!delta)
132
162
  continue;
@@ -13,6 +13,7 @@ export declare class SmolOpenAiResponses extends BaseClient implements SmolClien
13
13
  private convertMessages;
14
14
  private buildRequest;
15
15
  private calculateUsageAndCost;
16
+ private rethrowAsSmolError;
16
17
  _textSync(config: PromptConfig): Promise<Result<PromptResult>>;
17
18
  _textStream(config: PromptConfig): AsyncGenerator<StreamChunk>;
18
19
  }
@@ -6,6 +6,7 @@ import { BaseClient } from "./baseClient.js";
6
6
  import { zodToOpenAIResponsesTool } from "../util/tool.js";
7
7
  import { sanitizeAttributes } from "../util.js";
8
8
  import { Model } from "../model.js";
9
+ import { SmolContentPolicyError, SmolContextWindowExceededError, } from "../smolError.js";
9
10
  export class SmolOpenAiResponses extends BaseClient {
10
11
  client;
11
12
  logger;
@@ -101,15 +102,32 @@ export class SmolOpenAiResponses extends BaseClient {
101
102
  }
102
103
  return { usage, cost };
103
104
  }
105
+ rethrowAsSmolError(error) {
106
+ if (error instanceof OpenAI.APIError) {
107
+ if (error.code === "context_length_exceeded") {
108
+ throw new SmolContextWindowExceededError(error.message);
109
+ }
110
+ if (error.code === "content_policy_violation") {
111
+ throw new SmolContentPolicyError(error.message);
112
+ }
113
+ }
114
+ throw error;
115
+ }
104
116
  async _textSync(config) {
105
117
  const request = this.buildRequest(config);
106
118
  this.logger.debug("Sending request to OpenAI Responses API:", JSON.stringify(request, null, 2));
107
119
  this.statelogClient?.promptRequest(request);
108
120
  const signal = this.getAbortSignal(config);
109
- const response = await this.client.responses.create({
110
- ...request,
111
- stream: false,
112
- }, { ...(signal && { signal }) });
121
+ let response;
122
+ try {
123
+ response = await this.client.responses.create({
124
+ ...request,
125
+ stream: false,
126
+ }, { ...(signal && { signal }) });
127
+ }
128
+ catch (error) {
129
+ this.rethrowAsSmolError(error);
130
+ }
113
131
  this.logger.debug("Response from OpenAI Responses API:", JSON.stringify(response, null, 2));
114
132
  this.statelogClient?.promptResponse(response);
115
133
  const output = response.output_text || null;
@@ -133,9 +151,15 @@ export class SmolOpenAiResponses extends BaseClient {
133
151
  this.logger.debug("Sending streaming request to OpenAI Responses API:", JSON.stringify(request, null, 2));
134
152
  this.statelogClient?.promptRequest(request);
135
153
  const signal = this.getAbortSignal(config);
136
- const stream = this.client.responses.stream(request, {
137
- ...(signal && { signal }),
138
- });
154
+ let stream;
155
+ try {
156
+ stream = this.client.responses.stream(request, {
157
+ ...(signal && { signal }),
158
+ });
159
+ }
160
+ catch (error) {
161
+ this.rethrowAsSmolError(error);
162
+ }
139
163
  let content = "";
140
164
  const functionCalls = new Map();
141
165
  let usage;
package/dist/index.d.ts CHANGED
@@ -8,3 +8,5 @@ export * from "./classes/message/index.js";
8
8
  export * from "./functions.js";
9
9
  export * from "./classes/ToolCall.js";
10
10
  export * from "./strategies/index.js";
11
+ export { latencyTracker } from "./latencyTracker.js";
12
+ export type { LatencySample } from "./latencyTracker.js";
package/dist/index.js CHANGED
@@ -8,3 +8,4 @@ export * from "./classes/message/index.js";
8
8
  export * from "./functions.js";
9
9
  export * from "./classes/ToolCall.js";
10
10
  export * from "./strategies/index.js";
11
+ export { latencyTracker } from "./latencyTracker.js";
@@ -0,0 +1,32 @@
1
/** One recorded latency observation for a model. */
export type LatencySample = {
    /** Milliseconds per output token (elapsed time divided by output tokens). */
    msPerToken: number;
    /** `Date.now()` timestamp taken when the sample was recorded. */
    timestamp: number;
};
/**
 * Keeps a sliding window of per-model latency samples and derives speed
 * estimates (ms per token, tokens per second) from them.
 */
declare class LatencyTracker {
    private samples;
    private windowSize;
    /**
     * @param windowSize - Maximum number of samples kept per model.
     */
    constructor(windowSize?: number);
    /**
     * Record a latency sample for a model.
     *
     * Calls with a non-positive `elapsedMs` or `outputTokens` are ignored.
     * When the window is full, the oldest samples are dropped.
     */
    record(model: string, elapsedMs: number, outputTokens: number): void;
    /** Get the windowed mean ms-per-token for a model, or null if no samples. */
    getMeanMsPerToken(model: string): number | null;
    /**
     * Get estimated output tokens per second for a model based on tracked latency.
     * Returns null if no samples exist or if the number of samples is below
     * `minSamples` (which defaults to 1).
     */
    getTokensPerSecond(model: string, minSamples?: number): number | null;
    /** Get the number of samples currently held for a model. */
    getSampleCount(model: string): number;
    /**
     * Get all samples for a model, oldest first.
     * The returned array is a new array; adding or removing entries does not
     * affect the tracker.
     */
    getSamples(model: string): LatencySample[];
    /**
     * Clear the samples of one model, or of every model when `model` is
     * omitted.
     */
    clear(model?: string): void;
    /** Update the window size. Existing samples beyond the new size are trimmed (oldest first). */
    setWindowSize(size: number): void;
    /** Get the current maximum number of samples kept per model. */
    getWindowSize(): number;
}
/** Global singleton latency tracker. */
export declare const latencyTracker: LatencyTracker;
export {};
@@ -0,0 +1,73 @@
1
const DEFAULT_WINDOW_SIZE = 10;
/**
 * Reject window sizes that would break trimming.
 *
 * A fractional or NaN size makes `samples.length - size` truncate to 0 (or the
 * comparison always fail), so the per-model arrays would grow without bound.
 *
 * @throws RangeError if `size` is not a non-negative integer.
 */
function assertValidWindowSize(size) {
    if (!Number.isInteger(size) || size < 0) {
        throw new RangeError(`windowSize must be a non-negative integer, got ${String(size)}`);
    }
}
/**
 * Keeps a sliding window of latency samples per model and derives speed
 * estimates from them.
 */
class LatencyTracker {
    /** Map from model name to its samples, oldest first. */
    samples = new Map();
    /** Maximum number of samples kept per model. */
    windowSize;
    /**
     * @param windowSize - Maximum samples kept per model (default 10).
     * @throws RangeError if `windowSize` is not a non-negative integer.
     */
    constructor(windowSize = DEFAULT_WINDOW_SIZE) {
        assertValidWindowSize(windowSize);
        this.windowSize = windowSize;
    }
    /**
     * Record a latency sample for a model.
     *
     * Measurements that are non-finite (NaN / Infinity) or non-positive are
     * ignored so that a single bad reading cannot poison the windowed mean.
     */
    record(model, elapsedMs, outputTokens) {
        // NaN slips through `<= 0` comparisons, so check finiteness explicitly.
        if (!Number.isFinite(elapsedMs) || !Number.isFinite(outputTokens))
            return;
        if (outputTokens <= 0 || elapsedMs <= 0)
            return;
        const msPerToken = elapsedMs / outputTokens;
        const samples = this.samples.get(model) ?? [];
        samples.push({ msPerToken, timestamp: Date.now() });
        // Keep only the last windowSize samples
        if (samples.length > this.windowSize) {
            samples.splice(0, samples.length - this.windowSize);
        }
        this.samples.set(model, samples);
    }
    /** Get the windowed mean ms-per-token for a model, or null if no samples. */
    getMeanMsPerToken(model) {
        const samples = this.samples.get(model);
        if (!samples || samples.length === 0)
            return null;
        const sum = samples.reduce((acc, s) => acc + s.msPerToken, 0);
        return sum / samples.length;
    }
    /**
     * Get estimated output tokens per second for a model based on tracked latency.
     * Returns null if no samples exist or if the number of samples is below the minimum required.
     */
    getTokensPerSecond(model, minSamples = 1) {
        if (this.getSampleCount(model) < minSamples)
            return null;
        const msPerToken = this.getMeanMsPerToken(model);
        // Guard the division: a zero mean would yield Infinity.
        if (msPerToken === null || msPerToken === 0)
            return null;
        return 1000 / msPerToken;
    }
    /** Get the number of samples recorded for a model. */
    getSampleCount(model) {
        return this.samples.get(model)?.length ?? 0;
    }
    /**
     * Get all samples for a model, oldest first (defensive copy).
     * Both the array and the sample objects are copies, so callers cannot
     * alter the tracker's state through the returned value.
     */
    getSamples(model) {
        return (this.samples.get(model) ?? []).map((sample) => ({ ...sample }));
    }
    /**
     * Clear the samples of one model, or of every model when `model` is
     * omitted (or null/undefined).
     */
    clear(model = undefined) {
        // Compare against null/undefined rather than truthiness so that an
        // empty-string model name does not wipe every model.
        if (model != null) {
            this.samples.delete(model);
        }
        else {
            this.samples.clear();
        }
    }
    /**
     * Update the window size. Existing samples beyond the new size are trimmed.
     *
     * @throws RangeError if `size` is not a non-negative integer.
     */
    setWindowSize(size) {
        assertValidWindowSize(size);
        this.windowSize = size;
        for (const samples of this.samples.values()) {
            if (samples.length > size) {
                samples.splice(0, samples.length - size);
            }
        }
    }
    /** Get the current maximum number of samples kept per model. */
    getWindowSize() {
        return this.windowSize;
    }
}
/** Global singleton latency tracker. */
export const latencyTracker = new LatencyTracker();
package/dist/model.d.ts CHANGED
@@ -6,14 +6,14 @@ export declare class Model {
6
6
  private resolvedModel;
7
7
  private provider?;
8
8
  constructor(model: ModelName | ModelConfig | ModelNameAndProvider, provider?: Provider);
9
- getModel(): ModelName | ModelNameAndProvider | {
9
+ getModel(): string | ModelNameAndProvider | {
10
10
  optimizeFor: ("reasoning" | "speed" | "cost" | "large-context")[];
11
11
  providers: ("local" | "ollama" | "openai" | "openai-responses" | "anthropic" | "google" | "replicate" | "modal")[];
12
12
  limit?: {
13
13
  cost?: number | undefined;
14
14
  } | undefined;
15
15
  };
16
- getResolvedModel(): ModelName;
16
+ getResolvedModel(): string;
17
17
  getProvider(): Provider | undefined;
18
18
  setProvider(): Provider | undefined;
19
19
  resolveModel(models?: readonly TextModel[]): ModelName;
@@ -31,5 +31,6 @@ export declare class Model {
31
31
  currency: string;
32
32
  } | null;
33
33
  toString(): string;
34
+ toJSON(): ModelName | ModelNameAndProvider;
34
35
  static create(model: ModelLike, provider?: Provider): Model;
35
36
  }
package/dist/model.js CHANGED
@@ -1,3 +1,4 @@
1
+ import { latencyTracker } from "./latencyTracker.js";
1
2
  import { getModel, isTextModel, textModels, registeredTextModels, } from "./models.js";
2
3
  import { SmolError } from "./smolError.js";
3
4
  import { ModelConfigSchema, ModelNameAndProviderSchema, ModelNameSchema, } from "./strategies/types.js";
@@ -115,7 +116,10 @@ export class Model {
115
116
  case "cost":
116
117
  return (m.inputTokenCost ?? 0) + (m.outputTokenCost ?? 0);
117
118
  case "speed":
118
- return m.outputTokensPerSecond ?? 0;
119
+ // Prefer tracked latency over static estimates
120
+ return (latencyTracker.getTokensPerSecond(m.modelName) ??
121
+ m.outputTokensPerSecond ??
122
+ 0);
119
123
  case "reasoning":
120
124
  return (m.inputTokenCost ?? 0) + (m.outputTokenCost ?? 0);
121
125
  case "large-context":
@@ -149,6 +153,12 @@ export class Model {
149
153
  toString() {
150
154
  return `Model(${JSON.stringify(this.model)})`;
151
155
  }
156
+ toJSON() {
157
+ if (ModelNameAndProviderSchema.safeParse(this.model).success) {
158
+ return this.model;
159
+ }
160
+ return this.getResolvedModel();
161
+ }
152
162
  static create(model, provider) {
153
163
  if (model instanceof Model) {
154
164
  return model;
package/dist/models.d.ts CHANGED
@@ -688,7 +688,7 @@ export type TextModelName = (typeof textModels)[number]["modelName"];
688
688
  export type ImageModelName = (typeof imageModels)[number]["modelName"];
689
689
  export type SpeechToTextModelName = (typeof speechToTextModels)[number]["modelName"];
690
690
  export type EmbeddingsModelName = (typeof embeddingsModels)[number]["modelName"];
691
- export type ModelName = TextModelName | ImageModelName | SpeechToTextModelName;
691
+ export type ModelName = string;
692
692
  export declare const registeredTextModels: TextModel[];
693
693
  export declare function registerTextModel(model: Omit<TextModel, "type"> & {
694
694
  type?: "text";
@@ -7,3 +7,9 @@ export declare class SmolStructuredOutputError extends SmolError {
7
7
  export declare class SmolTimeoutError extends SmolError {
8
8
  constructor(message: string);
9
9
  }
10
/**
 * Raised when a provider rejects or blocks a request/response because of its
 * content policy or safety filtering.
 */
export declare class SmolContentPolicyError extends SmolError {
    constructor(message: string);
}
/**
 * Raised when a provider reports that the prompt does not fit into the model's
 * context window.
 */
export declare class SmolContextWindowExceededError extends SmolError {
    constructor(message: string);
}
package/dist/smolError.js CHANGED
@@ -16,3 +16,15 @@ export class SmolTimeoutError extends SmolError {
16
16
  this.name = "SmolTimeoutError";
17
17
  }
18
18
  }
19
/**
 * Error used when a provider rejects or blocks a request/response because of
 * its content policy or safety filtering.
 */
export class SmolContentPolicyError extends SmolError {
    /**
     * @param message - Human-readable description, passed to the base error.
     */
    constructor(message) {
        super(message);
        // Set explicitly so `error.name` identifies the concrete subclass.
        this.name = "SmolContentPolicyError";
    }
}
/**
 * Error used when a provider reports that the prompt does not fit into the
 * model's context window.
 */
export class SmolContextWindowExceededError extends SmolError {
    /**
     * @param message - Human-readable description, passed to the base error.
     */
    constructor(message) {
        super(message);
        // Set explicitly so `error.name` identifies the concrete subclass.
        this.name = "SmolContextWindowExceededError";
    }
}
@@ -1,4 +1,4 @@
1
- import { SmolStructuredOutputError, SmolTimeoutError } from "../smolError.js";
1
+ import { SmolContentPolicyError, SmolContextWindowExceededError, SmolStructuredOutputError, SmolTimeoutError, } from "../smolError.js";
2
2
  import { success, } from "../types.js";
3
3
  import { BaseStrategy } from "./baseStrategy.js";
4
4
  import { IDStrategy } from "./idStrategy.js";
@@ -59,6 +59,28 @@ export class FallbackStrategy extends BaseStrategy {
59
59
  });
60
60
  }
61
61
  }
62
+ else if (error instanceof SmolContentPolicyError) {
63
+ if (fallbackStrategies.contentPolicyViolation &&
64
+ fallbackStrategies.contentPolicyViolation.length > 0) {
65
+ this.statelogClient?.debug("FallbackStrategy: falling back due to content policy violation", {
66
+ failedStrategy: strategy.toString(),
67
+ });
68
+ return this._textWithFallbacks(config, fromJSON(fallbackStrategies.contentPolicyViolation[0]), {
69
+ contentPolicyViolation: fallbackStrategies.contentPolicyViolation.slice(1),
70
+ });
71
+ }
72
+ }
73
+ else if (error instanceof SmolContextWindowExceededError) {
74
+ if (fallbackStrategies.contextWindowExceeded &&
75
+ fallbackStrategies.contextWindowExceeded.length > 0) {
76
+ this.statelogClient?.debug("FallbackStrategy: falling back due to context window exceeded", {
77
+ failedStrategy: strategy.toString(),
78
+ });
79
+ return this._textWithFallbacks(config, fromJSON(fallbackStrategies.contextWindowExceeded[0]), {
80
+ contextWindowExceeded: fallbackStrategies.contextWindowExceeded.slice(1),
81
+ });
82
+ }
83
+ }
62
84
  if (fallbackStrategies.error && fallbackStrategies.error.length > 0) {
63
85
  this.statelogClient?.debug("FallbackStrategy: falling back due to error", {
64
86
  failedStrategy: strategy.toString(),
@@ -0,0 +1,17 @@
1
+ import { Model } from "../model.js";
2
+ import { PromptResult, Result, SmolPromptConfig } from "../types.js";
3
+ import { BaseStrategy } from "./baseStrategy.js";
4
+ import { ModelNameAndProvider, StrategyJSON } from "./types.js";
5
/**
 * Strategy that routes each request to the candidate model with the best
 * tracked output speed, occasionally picking a random candidate instead so
 * that latency data for the other models stays fresh.
 */
export declare class FastestStrategy extends BaseStrategy {
    /** Candidate models to choose from. */
    models: (string | ModelNameAndProvider | Model)[];
    /** Probability (0..1) of picking a random candidate instead of the fastest one. */
    epsilon: number;
    /**
     * @param models - Candidate models.
     * @param epsilon - Exploration probability; a built-in default is used when omitted.
     */
    constructor(models: (string | ModelNameAndProvider | Model)[], epsilon?: number);
    toString(): string;
    toShortString(): string;
    /** Pick one candidate and run the prompt against it. */
    _text(config: SmolPromptConfig): Promise<Result<PromptResult>>;
    /** Return the candidate with the highest tracked speed, or null when none has usable latency data. */
    private pickFastest;
    /**
     * Get tracked tokens/sec for a model, or null when the latency tracker
     * does not yet hold enough samples for it. There is no fallback to static
     * estimates.
     */
    private getSpeed;
    /** Serialise the strategy. Only the model list is included; `epsilon` is not. */
    toJSON(): StrategyJSON;
    /** Rebuild a strategy from its JSON form (the default epsilon is used). */
    static fromJSON(json: unknown): FastestStrategy;
}
@@ -0,0 +1,95 @@
1
+ import { latencyTracker } from "../latencyTracker.js";
2
+ import { getLogger } from "../logger.js";
3
+ import { Model } from "../model.js";
4
+ import { BaseStrategy } from "./baseStrategy.js";
5
+ import { IDStrategy } from "./idStrategy.js";
6
+ import { FastestStrategyJSONSchema, } from "./types.js";
7
// Fraction of calls that explore (pick a random model instead of the fastest).
// Exploring prevents us from getting stuck on a model that was fast in the
// past but has since become slow.
const DEFAULT_EPSILON = 0.1;
// Number of tracked samples a model needs before its measured speed is trusted.
const MIN_SAMPLES = 3;
/** Pick a uniformly random element of a non-empty array. */
function pickRandom(items) {
    return items[Math.floor(Math.random() * items.length)];
}
/**
 * Human-readable label for a configured model entry.
 * Strings are used as-is and Model instances use their own toString();
 * any other value (e.g. a plain name-and-provider object) is JSON-encoded
 * instead of being rendered as "[object Object]".
 */
function describeModel(entry) {
    if (typeof entry === "string")
        return entry;
    if (entry instanceof Model)
        return entry.toString();
    return JSON.stringify(entry);
}
/**
 * Strategy that sends each request to the candidate model with the best
 * tracked output speed (epsilon-greedy: with probability `epsilon` a random
 * candidate is used instead).
 */
export class FastestStrategy extends BaseStrategy {
    models;
    epsilon;
    /**
     * @param models - Candidate models (names, name-and-provider objects or Model instances).
     * @param epsilon - Probability of exploring a random candidate (default 0.1).
     */
    constructor(models, epsilon = DEFAULT_EPSILON) {
        super();
        this.models = models;
        this.epsilon = epsilon;
    }
    toString() {
        return `FastestStrategy([${this.models.map(describeModel).join(", ")}])`;
    }
    toShortString() {
        return `fastest([${this.models.map(describeModel).join(", ")}])`;
    }
    /**
     * Choose a candidate model and delegate the prompt to it via IDStrategy.
     *
     * @throws Error if the strategy was configured without any models.
     */
    async _text(config) {
        // Fail with a clear message instead of crashing on `undefined` below.
        if (this.models.length === 0) {
            throw new Error("FastestStrategy requires at least one model");
        }
        const resolved = this.models.map((model) => Model.create(model));
        const logger = getLogger(config.logLevel);
        let chosen;
        if (Math.random() < this.epsilon) {
            // Explore: pick a random model
            chosen = pickRandom(resolved);
            logger.debug("fastest strategy - exploring random model", {
                model: chosen.getResolvedModel(),
            });
            this.statelogClient?.debug("fastest strategy - picking random model", {
                model: chosen.getResolvedModel(),
            });
        }
        else {
            // Exploit: pick the fastest model by tracked latency
            chosen = this.pickFastest(resolved);
            if (chosen) {
                logger.debug("fastest strategy - exploiting fastest model", {
                    model: chosen.getResolvedModel(),
                });
                this.statelogClient?.debug("fastest strategy - using fastest model", {
                    model: chosen.getResolvedModel(),
                });
            }
            else {
                // we don't have latency data for any model, so just pick randomly
                chosen = pickRandom(resolved);
                // Both sinks get the same payload (resolved names, not Model objects).
                const details = {
                    models: resolved.map((m) => m.getResolvedModel()),
                    chosen: chosen.getResolvedModel(),
                };
                logger.debug("fastest strategy - no latency data, picking random model", details);
                this.statelogClient?.debug("fastest strategy - no latency data, picking random model", details);
            }
        }
        const strategy = new IDStrategy(chosen);
        return strategy.text(config);
    }
    /**
     * Return the model with the highest tracked speed, or null when no model
     * has usable latency data yet.
     */
    pickFastest(models) {
        let best = null;
        let bestSpeed = 0;
        for (const model of models) {
            const speed = this.getSpeed(model);
            if (speed && speed > bestSpeed) {
                bestSpeed = speed;
                best = model;
            }
        }
        return best;
    }
    /**
     * Get tracked tokens/sec for a model, or null when the latency tracker
     * holds fewer than MIN_SAMPLES samples for it. Static estimates are not
     * consulted here, which is what lets `_text` detect the "no latency data"
     * case.
     */
    getSpeed(model) {
        return latencyTracker.getTokensPerSecond(model.getResolvedModel(), MIN_SAMPLES);
    }
    /**
     * Serialise the strategy.
     * NOTE(review): `epsilon` is not part of the JSON form, so a
     * toJSON/fromJSON round trip resets it to the default — confirm whether
     * FastestStrategyJSONSchema should carry it.
     */
    toJSON() {
        return {
            type: "fastest",
            params: {
                models: this.models.map((s) => (s instanceof Model ? s.toJSON() : s)),
            },
        };
    }
    /** Rebuild a strategy from JSON validated by FastestStrategyJSONSchema. */
    static fromJSON(json) {
        const parsed = FastestStrategyJSONSchema.parse(json);
        return new FastestStrategy(parsed.params.models);
    }
}
@@ -1,11 +1,15 @@
1
+ import { Model } from "../model.js";
1
2
  import { ModelLike, ModelParam } from "../types.js";
2
- import { FallbackStrategyConfig, Strategy, StrategyJSON } from "./types.js";
3
+ import { FallbackStrategyConfig, ModelNameAndProvider, Strategy, StrategyJSON } from "./types.js";
3
4
  export * from "./baseStrategy.js";
4
5
  export * from "./fallbackStrategy.js";
5
6
  export * from "./idStrategy.js";
6
7
  export * from "./raceStrategy.js";
8
+ export * from "./randomStrategy.js";
7
9
  export * from "./types.js";
8
10
  export declare function race(...strategies: ModelParam[]): Strategy;
11
+ export declare function random(...strategies: ModelParam[]): Strategy;
12
+ export declare function fastest(models: (string | ModelNameAndProvider | Model)[], epsilon?: number): Strategy;
9
13
  export declare function id(model: ModelLike): Strategy;
10
14
  export declare function fallback(primaryStrategy: ModelParam, config: FallbackStrategyConfig | string | string[]): Strategy;
11
15
  export declare function fromJSON(json: StrategyJSON): Strategy;
@@ -1,15 +1,24 @@
1
1
  import { FallbackStrategy } from "./fallbackStrategy.js";
2
+ import { FastestStrategy } from "./fastestStrategy.js";
2
3
  import { IDStrategy } from "./idStrategy.js";
3
4
  import { RaceStrategy } from "./raceStrategy.js";
4
- import { FallbackStrategyJSONSchema, IDStrategyJSONSchema, ModelNameAndProviderSchema, RaceStrategyJSONSchema, } from "./types.js";
5
+ import { RandomStrategy } from "./randomStrategy.js";
6
+ import { FallbackStrategyJSONSchema, FastestStrategyJSONSchema, IDStrategyJSONSchema, ModelNameAndProviderSchema, RaceStrategyJSONSchema, RandomStrategyJSONSchema, } from "./types.js";
5
7
  export * from "./baseStrategy.js";
6
8
  export * from "./fallbackStrategy.js";
7
9
  export * from "./idStrategy.js";
8
10
  export * from "./raceStrategy.js";
11
+ export * from "./randomStrategy.js";
9
12
  export * from "./types.js";
10
13
  export function race(...strategies) {
11
14
  return new RaceStrategy(strategies);
12
15
  }
16
/**
 * Convenience factory for a RandomStrategy.
 *
 * @param strategies Entries to choose between; they are spread into the
 *   RandomStrategy constructor as individual arguments.
 * @returns The newly created RandomStrategy.
 */
export function random(...strategies) {
    const strategy = new RandomStrategy(...strategies);
    return strategy;
}
19
/**
 * Convenience factory for a FastestStrategy.
 *
 * @param models Candidate models, forwarded unchanged to the constructor.
 * @param epsilon Forwarded unchanged to the constructor. Presumably the
 *   probability of exploring a random model instead of the fastest tracked
 *   one — confirm against FastestStrategy.
 * @returns The newly created FastestStrategy.
 */
export function fastest(models, epsilon) {
    const strategy = new FastestStrategy(models, epsilon);
    return strategy;
}
13
22
  export function id(model) {
14
23
  return new IDStrategy(model);
15
24
  }
@@ -39,6 +48,12 @@ export function fromJSON(json) {
39
48
  else if (FallbackStrategyJSONSchema.safeParse(json).success) {
40
49
  return FallbackStrategy.fromJSON(json);
41
50
  }
51
+ else if (RandomStrategyJSONSchema.safeParse(json).success) {
52
+ return RandomStrategy.fromJSON(json);
53
+ }
54
+ else if (FastestStrategyJSONSchema.safeParse(json).success) {
55
+ return FastestStrategy.fromJSON(json);
56
+ }
42
57
  else if (typeof json === "string") {
43
58
  return id(json);
44
59
  }
@@ -0,0 +1,12 @@
1
+ import { ModelParam, PromptResult, Result, SmolPromptConfig } from "../types.js";
2
+ import { BaseStrategy } from "./baseStrategy.js";
3
+ import { Strategy, StrategyJSON } from "./types.js";
4
+ export declare class RandomStrategy extends BaseStrategy {
5
+ strategies: Strategy[];
6
+ constructor(...strategies: (Strategy | ModelParam)[]);
7
+ toString(): string;
8
+ toShortString(): string;
9
+ _text(config: SmolPromptConfig): Promise<Result<PromptResult>>;
10
+ toJSON(): StrategyJSON;
11
+ static fromJSON(json: unknown): RandomStrategy;
12
+ }
@@ -0,0 +1,39 @@
1
+ import { BaseStrategy } from "./baseStrategy.js";
2
+ import { IDStrategy } from "./idStrategy.js";
3
+ import { fromJSON } from "./index.js";
4
+ import { RandomStrategyJSONSchema } from "./types.js";
5
/**
 * Strategy that forwards each prompt to one of its child strategies, picking
 * a uniformly random index on every call.
 */
export class RandomStrategy extends BaseStrategy {
    /** Child strategies to choose from; populated by the constructor. */
    strategies;
    /**
     * @param strategies Child entries. Anything that is not already a
     *   BaseStrategy instance is wrapped in an IDStrategy.
     */
    constructor(...strategies) {
        super();
        this.strategies = strategies.map((s) => s instanceof BaseStrategy ? s : new IDStrategy(s));
    }
    /** Full description listing every child via its `toString()`. */
    toString() {
        return `RandomStrategy([${this.strategies.map((s) => s.toString()).join(", ")}])`;
    }
    /**
     * Compact description.
     *
     * NOTE(review): children are rendered with `toString()`, not
     * `toShortString()`, so nested strategies appear in their long form —
     * confirm this is intended.
     */
    toShortString() {
        return `random([${this.strategies.map((s) => s.toString()).join(", ")}])`;
    }
    /**
     * Run the prompt against one randomly selected child strategy.
     *
     * @param config Prompt configuration, passed through to the child's `text`.
     * @returns Whatever the selected child's `text` call resolves to.
     * @throws Error when there are no child strategies to choose from.
     */
    async _text(config) {
        // With an empty list the random index resolves to `undefined`, which
        // previously surfaced as an opaque "cannot read properties of
        // undefined" TypeError. Fail with an explicit message instead.
        if (this.strategies.length === 0) {
            throw new Error("RandomStrategy requires at least one strategy to choose from");
        }
        const randomIndex = Math.floor(Math.random() * this.strategies.length);
        const strategy = this.strategies[randomIndex];
        this.statelogClient?.debug("random strategy chosen", {
            strategy,
        });
        return strategy.text(config);
    }
    /** Serialize as `{ type: "random", params: { strategies } }`, using each child's `toJSON()`. */
    toJSON() {
        return {
            type: "random",
            params: {
                strategies: this.strategies.map((s) => s.toJSON()),
            },
        };
    }
    /**
     * Rebuild a RandomStrategy from its JSON form.
     *
     * The input is validated with `RandomStrategyJSONSchema.parse` (throws on
     * mismatch) and each child entry is revived through the module-level
     * `fromJSON`.
     */
    static fromJSON(json) {
        const parsed = RandomStrategyJSONSchema.parse(json);
        const strategies = parsed.params.strategies.map((s) => fromJSON(s));
        return new RandomStrategy(...strategies);
    }
}
@@ -14,15 +14,19 @@ export declare const FallbackReasonSchema: z.ZodEnum<{
14
14
  error: "error";
15
15
  timeout: "timeout";
16
16
  structuredOutputFailure: "structuredOutputFailure";
17
+ contentPolicyViolation: "contentPolicyViolation";
18
+ contextWindowExceeded: "contextWindowExceeded";
17
19
  }>;
18
20
  export declare const FallbackStrategyConfigSchema: z.ZodLazy<z.ZodRecord<z.ZodEnum<{
19
21
  error: "error";
20
22
  timeout: "timeout";
21
23
  structuredOutputFailure: "structuredOutputFailure";
24
+ contentPolicyViolation: "contentPolicyViolation";
25
+ contextWindowExceeded: "contextWindowExceeded";
22
26
  }> & z.core.$partial, z.ZodArray<z.ZodType<StrategyJSON, unknown, z.core.$ZodTypeInternals<StrategyJSON, unknown>>>>>;
23
27
  export type FallbackReason = z.infer<typeof FallbackReasonSchema>;
24
28
  export type FallbackStrategyConfig = z.infer<typeof FallbackStrategyConfigSchema>;
25
- export type StrategyJSON = string | ModelNameAndProvider | IDStrategyJSON | RaceStrategyJSON | FallbackStrategyJSON;
29
+ export type StrategyJSON = string | ModelNameAndProvider | IDStrategyJSON | RaceStrategyJSON | FallbackStrategyJSON | RandomStrategyJSON | FastestStrategyJSON;
26
30
  export declare const IDStrategyJSONSchema: z.ZodObject<{
27
31
  type: z.ZodLiteral<"id">;
28
32
  params: z.ZodObject<{
@@ -46,6 +50,20 @@ export type FallbackStrategyJSON = {
46
50
  config: FallbackStrategyConfig;
47
51
  };
48
52
  };
53
+ export declare const RandomStrategyJSONSchema: z.ZodType<RandomStrategyJSON>;
54
+ export type RandomStrategyJSON = {
55
+ type: "random";
56
+ params: {
57
+ strategies: StrategyJSON[];
58
+ };
59
+ };
60
+ export declare const FastestStrategyJSONSchema: z.ZodType<FastestStrategyJSON>;
61
+ export type FastestStrategyJSON = {
62
+ type: "fastest";
63
+ params: {
64
+ models: (ModelNameAndProvider | string)[];
65
+ };
66
+ };
49
67
  export type ModelNameAndProvider = {
50
68
  model: string;
51
69
  provider: string;
@@ -4,6 +4,8 @@ export const FallbackReasonSchema = z.enum([
4
4
  "error",
5
5
  "timeout",
6
6
  "structuredOutputFailure",
7
+ "contentPolicyViolation",
8
+ "contextWindowExceeded",
7
9
  ]);
8
10
  export const FallbackStrategyConfigSchema = z.lazy(() => z.partialRecord(FallbackReasonSchema, z.array(StrategyJSONSchema)));
9
11
  export const IDStrategyJSONSchema = z.object({
@@ -21,6 +23,16 @@ export const FallbackStrategyJSONSchema = z.lazy(() => z.object({
21
23
  config: FallbackStrategyConfigSchema,
22
24
  }),
23
25
  }));
26
/**
 * Schema for the JSON form of a random strategy:
 * `{ type: "random", params: { strategies: StrategyJSON[] } }`.
 *
 * Wrapped in `z.lazy` because `StrategyJSONSchema` is declared further down
 * in this module and itself includes this schema in its union.
 */
export const RandomStrategyJSONSchema = z.lazy(() => {
    const params = z.object({
        strategies: z.array(StrategyJSONSchema),
    });
    return z.object({
        type: z.literal("random"),
        params,
    });
});
30
/**
 * Schema for the JSON form of a fastest strategy:
 * `{ type: "fastest", params: { models } }`, where each model is either a
 * `{ model, provider }` object or a plain string.
 *
 * Wrapped in `z.lazy` because `ModelNameAndProviderSchema` is declared after
 * this constant in the module.
 */
export const FastestStrategyJSONSchema = z.lazy(() => {
    const modelEntry = z.union([ModelNameAndProviderSchema, z.string()]);
    return z.object({
        type: z.literal("fastest"),
        params: z.object({
            models: z.array(modelEntry),
        }),
    });
});
24
36
  export const ModelNameAndProviderSchema = z.object({
25
37
  model: z.string(),
26
38
  provider: z.string(),
@@ -49,6 +61,8 @@ export const StrategyJSONSchema = z.lazy(() => z.union([
49
61
  IDStrategyJSONSchema,
50
62
  RaceStrategyJSONSchema,
51
63
  FallbackStrategyJSONSchema,
64
+ RandomStrategyJSONSchema,
65
+ FastestStrategyJSONSchema,
52
66
  ]));
53
67
  // Helper to detect if a value is a StrategyJSON object (not a plain string)
54
68
  export function isStrategy(value) {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "smoltalk",
3
- "version": "0.0.55",
3
+ "version": "0.0.56",
4
4
  "description": "A common interface for LLM APIs",
5
5
  "homepage": "https://github.com/egonSchiele/smoltalk",
6
6
  "scripts": {