smoltalk 0.0.52 → 0.0.53

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48):
  1. package/dist/classes/ToolCall.d.ts +8 -6
  2. package/dist/classes/ToolCall.js +9 -1
  3. package/dist/classes/message/AssistantMessage.d.ts +37 -13
  4. package/dist/classes/message/AssistantMessage.js +27 -19
  5. package/dist/classes/message/BaseMessage.d.ts +1 -0
  6. package/dist/classes/message/BaseMessage.js +5 -0
  7. package/dist/classes/message/DeveloperMessage.d.ts +12 -6
  8. package/dist/classes/message/DeveloperMessage.js +13 -6
  9. package/dist/classes/message/SystemMessage.d.ts +12 -6
  10. package/dist/classes/message/SystemMessage.js +13 -6
  11. package/dist/classes/message/ToolMessage.d.ts +13 -7
  12. package/dist/classes/message/ToolMessage.js +15 -7
  13. package/dist/classes/message/UserMessage.d.ts +9 -6
  14. package/dist/classes/message/UserMessage.js +11 -3
  15. package/dist/classes/message/index.js +1 -1
  16. package/dist/client.d.ts +4 -7
  17. package/dist/client.js +10 -7
  18. package/dist/clients/baseClient.d.ts +4 -5
  19. package/dist/clients/baseClient.js +26 -26
  20. package/dist/clients/google.js +2 -1
  21. package/dist/clients/ollama.js +78 -72
  22. package/dist/clients/openai.js +5 -3
  23. package/dist/clients/openaiResponses.js +2 -3
  24. package/dist/functions.d.ts +6 -2
  25. package/dist/functions.js +11 -32
  26. package/dist/model.d.ts +14 -12
  27. package/dist/model.js +36 -12
  28. package/dist/models.d.ts +13 -1
  29. package/dist/models.js +12 -0
  30. package/dist/statelogClient.d.ts +3 -4
  31. package/dist/statelogClient.js +6 -4
  32. package/dist/strategies/baseStrategy.js +2 -2
  33. package/dist/strategies/fallbackStrategy.d.ts +6 -7
  34. package/dist/strategies/fallbackStrategy.js +61 -48
  35. package/dist/strategies/idStrategy.d.ts +5 -3
  36. package/dist/strategies/idStrategy.js +44 -10
  37. package/dist/strategies/index.d.ts +3 -3
  38. package/dist/strategies/index.js +20 -22
  39. package/dist/strategies/raceStrategy.d.ts +3 -2
  40. package/dist/strategies/raceStrategy.js +8 -1
  41. package/dist/strategies/types.d.ts +68 -13
  42. package/dist/strategies/types.js +57 -1
  43. package/dist/types.d.ts +32 -11
  44. package/dist/types.js +32 -0
  45. package/dist/util/tool.js +4 -0
  46. package/dist/util.d.ts +10 -0
  47. package/dist/util.js +34 -0
  48. package/package.json +4 -5
@@ -11,6 +11,9 @@ export class BaseClient {
11
11
  statelogClient;
12
12
  constructor(config) {
13
13
  this.config = config || {};
14
+ if (this.config.logLevel) {
15
+ getLogger(this.config.logLevel);
16
+ }
14
17
  if (this.config.statelog) {
15
18
  this.statelogClient = getStatelogClient(this.config.statelog);
16
19
  }
@@ -165,7 +168,7 @@ export class BaseClient {
165
168
  }
166
169
  }
167
170
  checkForToolLoops(promptConfig) {
168
- if (!this.config.toolLoopDetection?.enabled) {
171
+ if (!promptConfig.toolLoopDetection?.enabled) {
169
172
  return { continue: true, newPromptConfig: promptConfig };
170
173
  }
171
174
  const toolCallCounts = {};
@@ -175,9 +178,9 @@ export class BaseClient {
175
178
  toolCallCounts[msg.name] += 1;
176
179
  }
177
180
  for (const [toolName, count] of Object.entries(toolCallCounts)) {
178
- if (count >= this.config.toolLoopDetection.maxConsecutive &&
179
- !(this.config.toolLoopDetection.excludeTools ?? []).includes(toolName)) {
180
- const intervention = this.config.toolLoopDetection.intervention || "remove-tool";
181
+ if (count >= promptConfig.toolLoopDetection.maxCalls &&
182
+ !(promptConfig.toolLoopDetection.excludeTools ?? []).includes(toolName)) {
183
+ const intervention = promptConfig.toolLoopDetection.intervention || "remove-tool";
181
184
  const logger = getLogger();
182
185
  logger.warn(`Tool loop detected for tool "${toolName}" called ${count} times. Intervention: ${intervention}`);
183
186
  this.statelogClient?.debug("Tool loop detected", {
@@ -207,7 +210,11 @@ export class BaseClient {
207
210
  }
208
211
  return { continue: true, newPromptConfig: promptConfig };
209
212
  }
210
- extractResponse(promptConfig, rawValue, schema) {
213
+ extractResponse(promptConfig, rawValue, schema, depth = 0) {
214
+ const MAX_DEPTH = 5;
215
+ if (depth > MAX_DEPTH) {
216
+ throw new Error("extractResponse exceeded maximum depth");
217
+ }
211
218
  // 1. Direct match — try parsing as-is
212
219
  const direct = schema.safeParse(rawValue);
213
220
  if (direct.success) {
@@ -227,9 +234,16 @@ export class BaseClient {
227
234
  .replace(/^```json\s*/, "")
228
235
  .replace(/```\s*$/, "");
229
236
  try {
230
- return this.extractResponse(promptConfig, JSON.parse(stripped), schema);
237
+ return this.extractResponse(promptConfig, JSON.parse(stripped), schema, depth + 1);
238
+ }
239
+ catch (err) {
240
+ const logger = getLogger();
241
+ logger.debug("extractResponse: failed to parse JSON from string", {
242
+ error: err.message,
243
+ rawValue: stripped,
244
+ });
245
+ this.statelogClient?.debug("extractResponse: failed to parse JSON from string", { error: err.message });
231
246
  }
232
- catch { }
233
247
  return rawValue;
234
248
  }
235
249
  // 3. Null/undefined/primitive — nothing to unwrap
@@ -250,10 +264,10 @@ export class BaseClient {
250
264
  if (key in rawValue) {
251
265
  const inner = schema.safeParse(rawValue[key]);
252
266
  if (inner.success)
253
- return inner.data[key];
267
+ return inner.data;
254
268
  }
255
269
  }
256
- // 6. Shallow search — check every value of the object
270
+ // 6. Shallow search — check every value of the object (own keys only)
257
271
  for (const key of Object.keys(rawValue)) {
258
272
  const inner = schema.safeParse(rawValue[key]);
259
273
  if (inner.success)
@@ -304,7 +318,7 @@ export class BaseClient {
304
318
  catch (err) {
305
319
  const errorMessage = err.message;
306
320
  const logger = getLogger();
307
- logger.warn(`Response format validation failed (retries left: ${retries}): `, errorMessage, "output:", JSON.stringify(output, null, 2), "responseFormat:", JSON.stringify(promptConfig.responseFormat, null, 2));
321
+ logger.warn(`Response format validation failed (retries left: ${retries}): `, errorMessage);
308
322
  if (err instanceof z.ZodError) {
309
323
  logger.warn("Zod error details:", z.prettifyError(err));
310
324
  }
@@ -312,11 +326,6 @@ export class BaseClient {
312
326
  retriesLeft: retries,
313
327
  error: errorMessage,
314
328
  });
315
- this.statelogClient?.diff({
316
- message: "Response format validation failed",
317
- itemA: promptConfig.responseFormat,
318
- itemB: output,
319
- });
320
329
  const retryMessages = [
321
330
  ...promptConfig.messages,
322
331
  assistantMessage(output),
@@ -326,21 +335,12 @@ export class BaseClient {
326
335
  }
327
336
  }
328
337
  }
329
- throw new SmolStructuredOutputError(`Failed to get valid response after ${DEFAULT_NUM_RETRIES} attempts: ${result.success ? "Output did not match expected format" : result.error}`);
338
+ const numRetries = promptConfig.responseFormatOptions?.numRetries || DEFAULT_NUM_RETRIES;
339
+ throw new SmolStructuredOutputError(`Failed to get valid response after ${numRetries} attempts: ${result.success ? "Output did not match expected format" : result.error}`);
330
340
  }
331
341
  async _textSync(promptConfig) {
332
342
  throw new Error("Method not implemented.");
333
343
  }
334
- prompt(text, promptConfig) {
335
- const msg = userMessage(text);
336
- const newPromptConfig = {
337
- ...promptConfig,
338
- messages: promptConfig?.messages
339
- ? [...promptConfig.messages, msg]
340
- : [msg],
341
- };
342
- return this.text(newPromptConfig);
343
- }
344
344
  async *textStream(config) {
345
345
  const messageLimitResult = this.checkMessageLimit(config);
346
346
  if (messageLimitResult) {
@@ -3,6 +3,7 @@ import { ToolCall } from "../classes/ToolCall.js";
3
3
  import { getLogger } from "../logger.js";
4
4
  import { addCosts, addTokenUsage, success, } from "../types.js";
5
5
  import { zodToGoogleTool } from "../util/tool.js";
6
+ import { sanitizeAttributes } from "../util.js";
6
7
  import { BaseClient } from "./baseClient.js";
7
8
  import { Model } from "../model.js";
8
9
  import { userMessage } from "../classes/message/index.js";
@@ -80,7 +81,7 @@ export class SmolGoogle extends BaseClient {
80
81
  contents: messages,
81
82
  model: this.getModel(),
82
83
  config: genConfig,
83
- ...(config.rawAttributes || {}),
84
+ ...sanitizeAttributes(config.rawAttributes),
84
85
  };
85
86
  }
86
87
  async _textSync(config) {
@@ -3,6 +3,7 @@ import { ToolCall } from "../classes/ToolCall.js";
3
3
  import { getLogger } from "../logger.js";
4
4
  import { success, } from "../types.js";
5
5
  import { zodToGoogleTool } from "../util/tool.js";
6
+ import { sanitizeAttributes } from "../util.js";
6
7
  import { BaseClient } from "./baseClient.js";
7
8
  import { Model } from "../model.js";
8
9
  export const DEFAULT_OLLAMA_HOST = "http://localhost:11434";
@@ -66,9 +67,7 @@ export class SmolOllama extends BaseClient {
66
67
  if (config.responseFormat) {
67
68
  request.format = config.responseFormat.toJSONSchema();
68
69
  }
69
- if (config.rawAttributes) {
70
- Object.assign(request, config.rawAttributes);
71
- }
70
+ Object.assign(request, sanitizeAttributes(config.rawAttributes));
72
71
  this.logger.debug("Sending request to Ollama:", JSON.stringify(request, null, 2));
73
72
  this.statelogClient?.promptRequest(request);
74
73
  const signal = this.getAbortSignal(config);
@@ -76,10 +75,15 @@ export class SmolOllama extends BaseClient {
76
75
  if (signal && abortHandler) {
77
76
  signal.addEventListener("abort", abortHandler, { once: true });
78
77
  }
79
- // @ts-ignore
80
- const result = await this.client.chat(request);
81
- if (signal && abortHandler) {
82
- signal.removeEventListener("abort", abortHandler);
78
+ let result;
79
+ try {
80
+ // @ts-ignore
81
+ result = await this.client.chat(request);
82
+ }
83
+ finally {
84
+ if (signal && abortHandler) {
85
+ signal.removeEventListener("abort", abortHandler);
86
+ }
83
87
  }
84
88
  this.logger.debug("Response from Ollama:", JSON.stringify(result, null, 2));
85
89
  this.statelogClient?.promptResponse(result);
@@ -114,9 +118,7 @@ export class SmolOllama extends BaseClient {
114
118
  if (config.responseFormat) {
115
119
  request.format = config.responseFormat.toJSONSchema();
116
120
  }
117
- if (config.rawAttributes) {
118
- Object.assign(request, config.rawAttributes);
119
- }
121
+ Object.assign(request, sanitizeAttributes(config.rawAttributes));
120
122
  this.logger.debug("Sending streaming request to Ollama:", JSON.stringify(request, null, 2));
121
123
  this.statelogClient?.promptRequest(request);
122
124
  const signal = this.getAbortSignal(config);
@@ -124,73 +126,77 @@ export class SmolOllama extends BaseClient {
124
126
  if (signal && abortHandler) {
125
127
  signal.addEventListener("abort", abortHandler, { once: true });
126
128
  }
127
- // @ts-ignore
128
- const stream = await this.client.chat(request);
129
- let content = "";
130
- const toolCallsMap = new Map();
131
- let usage;
132
- let cost;
133
- let lastChunk;
134
- for await (const chunk of stream) {
135
- lastChunk = chunk;
136
- // Handle text content
137
- if (chunk.message?.content) {
138
- content += chunk.message.content;
139
- yield { type: "text", text: chunk.message.content };
140
- }
141
- // Handle tool calls
142
- if (chunk.message?.tool_calls) {
143
- for (const tc of chunk.message.tool_calls) {
144
- const tool_call = tc;
145
- const id = tool_call.id || tool_call.function.name || "";
146
- const name = tool_call.function.name || "";
147
- if (!toolCallsMap.has(id)) {
148
- toolCallsMap.set(id, {
149
- id: id,
150
- name: name,
151
- arguments: tool_call.function.arguments || {},
152
- });
153
- }
154
- else {
155
- // Merge arguments if tool call is split across chunks
156
- const existing = toolCallsMap.get(id);
157
- if (tool_call.function.arguments) {
158
- existing.arguments = {
159
- ...existing.arguments,
160
- ...tool_call.function.arguments,
161
- };
129
+ try {
130
+ // @ts-ignore
131
+ const stream = await this.client.chat(request);
132
+ let content = "";
133
+ const toolCallsMap = new Map();
134
+ let usage;
135
+ let cost;
136
+ let lastChunk;
137
+ for await (const chunk of stream) {
138
+ lastChunk = chunk;
139
+ // Handle text content
140
+ if (chunk.message?.content) {
141
+ content += chunk.message.content;
142
+ yield { type: "text", text: chunk.message.content };
143
+ }
144
+ // Handle tool calls
145
+ if (chunk.message?.tool_calls) {
146
+ for (const tc of chunk.message.tool_calls) {
147
+ const tool_call = tc;
148
+ const id = tool_call.id || tool_call.function.name || "";
149
+ const name = tool_call.function.name || "";
150
+ if (!toolCallsMap.has(id)) {
151
+ toolCallsMap.set(id, {
152
+ id: id,
153
+ name: name,
154
+ arguments: tool_call.function.arguments || {},
155
+ });
156
+ }
157
+ else {
158
+ // Merge arguments if tool call is split across chunks
159
+ const existing = toolCallsMap.get(id);
160
+ if (tool_call.function.arguments) {
161
+ existing.arguments = {
162
+ ...existing.arguments,
163
+ ...tool_call.function.arguments,
164
+ };
165
+ }
162
166
  }
163
167
  }
164
168
  }
165
169
  }
170
+ this.logger.debug("Streaming response completed from Ollama");
171
+ // Extract usage from the last chunk
172
+ if (lastChunk) {
173
+ const usageAndCost = this.calculateUsageAndCost(lastChunk);
174
+ usage = usageAndCost.usage;
175
+ cost = usageAndCost.cost;
176
+ }
177
+ this.statelogClient?.promptResponse({ content, usage, cost });
178
+ // Yield tool calls
179
+ const toolCalls = [];
180
+ for (const tc of toolCallsMap.values()) {
181
+ const toolCall = new ToolCall(tc.id, tc.name, tc.arguments);
182
+ toolCalls.push(toolCall);
183
+ yield { type: "tool_call", toolCall };
184
+ }
185
+ yield {
186
+ type: "done",
187
+ result: {
188
+ output: content || null,
189
+ toolCalls,
190
+ usage,
191
+ cost,
192
+ model: this.getModel(),
193
+ },
194
+ };
166
195
  }
167
- if (signal && abortHandler) {
168
- signal.removeEventListener("abort", abortHandler);
169
- }
170
- this.logger.debug("Streaming response completed from Ollama");
171
- // Extract usage from the last chunk
172
- if (lastChunk) {
173
- const usageAndCost = this.calculateUsageAndCost(lastChunk);
174
- usage = usageAndCost.usage;
175
- cost = usageAndCost.cost;
176
- }
177
- this.statelogClient?.promptResponse({ content, usage, cost });
178
- // Yield tool calls
179
- const toolCalls = [];
180
- for (const tc of toolCallsMap.values()) {
181
- const toolCall = new ToolCall(tc.id, tc.name, tc.arguments);
182
- toolCalls.push(toolCall);
183
- yield { type: "tool_call", toolCall };
196
+ finally {
197
+ if (signal && abortHandler) {
198
+ signal.removeEventListener("abort", abortHandler);
199
+ }
184
200
  }
185
- yield {
186
- type: "done",
187
- result: {
188
- output: content || null,
189
- toolCalls,
190
- usage,
191
- cost,
192
- model: this.getModel(),
193
- },
194
- };
195
201
  }
196
202
  }
@@ -1,7 +1,7 @@
1
1
  import OpenAI from "openai";
2
2
  import { success, } from "../types.js";
3
3
  import { ToolCall } from "../classes/ToolCall.js";
4
- import { isFunctionToolCall } from "../util.js";
4
+ import { isFunctionToolCall, sanitizeAttributes } from "../util.js";
5
5
  import { getLogger } from "../logger.js";
6
6
  import { BaseClient } from "./baseClient.js";
7
7
  import { zodToOpenAITool } from "../util/tool.js";
@@ -55,7 +55,7 @@ export class SmolOpenAi extends BaseClient {
55
55
  ...(config.reasoningEffort && {
56
56
  reasoning_effort: config.reasoningEffort,
57
57
  }),
58
- ...(config.rawAttributes || {}),
58
+ ...sanitizeAttributes(config.rawAttributes),
59
59
  };
60
60
  if (config.responseFormat) {
61
61
  request.response_format = {
@@ -119,13 +119,15 @@ export class SmolOpenAi extends BaseClient {
119
119
  let usage;
120
120
  let cost;
121
121
  for await (const chunk of completion) {
122
- const delta = chunk.choices[0]?.delta;
123
122
  // Extract usage from the final chunk
124
123
  if (chunk.usage) {
125
124
  const usageAndCost = this.calculateUsageAndCost(chunk.usage);
126
125
  usage = usageAndCost.usage;
127
126
  cost = usageAndCost.cost;
128
127
  }
128
+ if (!chunk.choices || chunk.choices.length === 0)
129
+ continue;
130
+ const delta = chunk.choices[0]?.delta;
129
131
  if (!delta)
130
132
  continue;
131
133
  if (delta.content) {
@@ -4,6 +4,7 @@ import { ToolCall } from "../classes/ToolCall.js";
4
4
  import { getLogger } from "../logger.js";
5
5
  import { BaseClient } from "./baseClient.js";
6
6
  import { zodToOpenAIResponsesTool } from "../util/tool.js";
7
+ import { sanitizeAttributes } from "../util.js";
7
8
  import { Model } from "../model.js";
8
9
  export class SmolOpenAiResponses extends BaseClient {
9
10
  client;
@@ -80,9 +81,7 @@ export class SmolOpenAiResponses extends BaseClient {
80
81
  if (config.reasoningEffort) {
81
82
  request.reasoning = { effort: config.reasoningEffort };
82
83
  }
83
- if (config.rawAttributes) {
84
- Object.assign(request, config.rawAttributes);
85
- }
84
+ Object.assign(request, sanitizeAttributes(config.rawAttributes));
86
85
  return request;
87
86
  }
88
87
  calculateUsageAndCost(usageData) {
@@ -1,5 +1,10 @@
1
- import { SmolPromptConfig, PromptResult, StreamChunk } from "./types.js";
1
+ import { getClient } from "./client.js";
2
+ import { PromptConfig, PromptResult, SmolPromptConfig, StreamChunk } from "./types.js";
2
3
  import { Result } from "./types/result.js";
4
+ export declare function splitConfig(config: SmolPromptConfig): {
5
+ smolConfig: Parameters<typeof getClient>[0];
6
+ promptConfig: PromptConfig;
7
+ };
3
8
  export declare function text(config: SmolPromptConfig & {
4
9
  stream: true;
5
10
  }): AsyncGenerator<StreamChunk>;
@@ -8,4 +13,3 @@ export declare function text(config: SmolPromptConfig & {
8
13
  }): Promise<Result<PromptResult>>;
9
14
  export declare function textSync(config: SmolPromptConfig): Promise<Result<PromptResult>>;
10
15
  export declare function textStream(config: SmolPromptConfig): AsyncGenerator<StreamChunk>;
11
- export declare function prompt(promptText: string, config: SmolPromptConfig): Promise<Result<PromptResult>> | AsyncGenerator<StreamChunk>;
package/dist/functions.js CHANGED
@@ -2,14 +2,13 @@ import { getClient } from "./client.js";
2
2
  import { Model } from "./model.js";
3
3
  import { BaseStrategy } from "./strategies/baseStrategy.js";
4
4
  import { fromJSON } from "./strategies/index.js";
5
- function hydrateStrategy(config) {
6
- if (config.strategy && !(config.strategy instanceof BaseStrategy)) {
7
- return { ...config, strategy: fromJSON(config.strategy) };
8
- }
9
- return config;
5
+ function getStrategy(model) {
6
+ if (model instanceof BaseStrategy)
7
+ return model;
8
+ return fromJSON(model);
10
9
  }
11
- function splitConfig(config) {
12
- const { openAiApiKey, googleApiKey, ollamaApiKey, anthropicApiKey, ollamaHost, model: rawModel, provider, logLevel, toolLoopDetection, statelog, ...promptConfig } = config;
10
+ export function splitConfig(config) {
11
+ const { openAiApiKey, googleApiKey, ollamaApiKey, anthropicApiKey, ollamaHost, model: rawModel, provider, logLevel, statelog, ...promptConfig } = config;
13
12
  const _model = new Model(rawModel);
14
13
  const model = _model.getResolvedModel();
15
14
  return {
@@ -22,41 +21,21 @@ function splitConfig(config) {
22
21
  model,
23
22
  provider,
24
23
  logLevel,
25
- toolLoopDetection,
26
24
  statelog,
27
25
  },
28
26
  promptConfig,
29
27
  };
30
28
  }
31
29
  export function text(config) {
32
- config = hydrateStrategy(config);
33
- if (config.strategy) {
34
- return config.strategy.text(config);
35
- }
36
- const { smolConfig, promptConfig } = splitConfig(config);
37
- const client = getClient(smolConfig);
38
- return client.text(promptConfig);
30
+ const strategy = getStrategy(config.model);
31
+ return strategy.text(config);
39
32
  }
40
33
  export function textSync(config) {
41
- config = hydrateStrategy(config);
42
- if (config.strategy) {
43
- return config.strategy.textSync(config);
44
- }
45
- const { smolConfig, promptConfig } = splitConfig(config);
46
- const client = getClient(smolConfig);
47
- return client.textSync(promptConfig);
34
+ const strategy = getStrategy(config.model);
35
+ return strategy.textSync(config);
48
36
  }
49
37
  export function textStream(config) {
50
- config = hydrateStrategy(config);
51
- /* if (config.strategy) {
52
- return (config.strategy as import("./strategies/types.js").Strategy).textStream(config);
53
- }
54
- */ const { smolConfig, promptConfig } = splitConfig(config);
55
- const client = getClient(smolConfig);
56
- return client.textStream(promptConfig);
57
- }
58
- export function prompt(promptText, config) {
59
38
  const { smolConfig, promptConfig } = splitConfig(config);
60
39
  const client = getClient(smolConfig);
61
- return client.prompt(promptText, promptConfig);
40
+ return client.textStream(promptConfig);
62
41
  }
package/dist/model.d.ts CHANGED
@@ -1,20 +1,21 @@
1
1
  import { ModelName, Provider, TextModel } from "./models.js";
2
+ import { ModelConfig, ModelNameAndProvider } from "./strategies/types.js";
2
3
  import { ModelLike } from "./types.js";
3
- export type Optimization = "speed" | "reasoning" | "cost" | "large-context";
4
- export type ModelConfig = {
5
- optimizeFor: Optimization[];
6
- providers: Provider[];
7
- limit?: {
8
- cost?: number;
9
- };
10
- };
11
4
  export declare class Model {
12
5
  private model;
13
6
  private resolvedModel;
14
- constructor(model: ModelName | ModelConfig);
15
- getModel(): ModelName | ModelConfig;
7
+ private provider?;
8
+ constructor(model: ModelName | ModelConfig | ModelNameAndProvider, provider?: Provider);
9
+ getModel(): ModelName | ModelNameAndProvider | {
10
+ optimizeFor: ("reasoning" | "speed" | "cost" | "large-context")[];
11
+ providers: ("local" | "ollama" | "openai" | "openai-responses" | "anthropic" | "google" | "replicate" | "modal")[];
12
+ limit?: {
13
+ cost?: number | undefined;
14
+ } | undefined;
15
+ };
16
16
  getResolvedModel(): ModelName;
17
- isModelConfig(model: ModelName | ModelConfig): model is ModelConfig;
17
+ getProvider(): Provider | undefined;
18
+ setProvider(): Provider | undefined;
18
19
  resolveModel(models?: readonly TextModel[]): ModelName;
19
20
  private getRawMetric;
20
21
  private isLowerBetter;
@@ -29,5 +30,6 @@ export declare class Model {
29
30
  totalCost: number;
30
31
  currency: string;
31
32
  } | null;
32
- static create(model: ModelLike): Model;
33
+ toString(): string;
34
+ static create(model: ModelLike, provider?: Provider): Model;
33
35
  }
package/dist/model.js CHANGED
@@ -1,5 +1,6 @@
1
1
  import { getModel, isTextModel, textModels, } from "./models.js";
2
2
  import { SmolError } from "./smolError.js";
3
+ import { ModelConfigSchema, ModelNameAndProviderSchema, ModelNameSchema, } from "./strategies/types.js";
3
4
  import { round } from "./util.js";
4
5
  const WEIGHTS = {
5
6
  1: [1],
@@ -10,9 +11,11 @@ const WEIGHTS = {
10
11
  export class Model {
11
12
  model;
12
13
  resolvedModel;
13
- constructor(model) {
14
+ provider;
15
+ constructor(model, provider) {
14
16
  this.model = model;
15
17
  this.resolvedModel = this.resolveModel();
18
+ this.provider = provider || this.setProvider();
16
19
  }
17
20
  getModel() {
18
21
  return this.model;
@@ -20,17 +23,35 @@ export class Model {
20
23
  getResolvedModel() {
21
24
  return this.resolvedModel;
22
25
  }
23
- isModelConfig(model) {
24
- return typeof model === "object" && "optimizeFor" in model;
26
+ getProvider() {
27
+ if (this.provider) {
28
+ return this.provider;
29
+ }
30
+ return undefined;
31
+ }
32
+ setProvider() {
33
+ if (ModelNameAndProviderSchema.safeParse(this.model).success) {
34
+ const { model, provider } = this.model;
35
+ return provider;
36
+ }
37
+ const resolved = this.getResolvedModel();
38
+ const modelInfo = getModel(resolved);
39
+ if (modelInfo) {
40
+ return modelInfo.provider;
41
+ }
42
+ return undefined;
25
43
  }
26
44
  resolveModel(models = textModels) {
27
- if (!this.isModelConfig(this.model)) {
28
- const modelName = this.model;
29
- const model = getModel(modelName);
30
- if (!model) {
31
- throw new SmolError(`Model ${modelName} is not recognized. Please specify a known model or a valid ModelConfig.`);
32
- }
33
- return modelName;
45
+ if (ModelNameSchema.safeParse(this.model).success) {
46
+ return this.model;
47
+ }
48
+ if (ModelNameAndProviderSchema.safeParse(this.model).success) {
49
+ const { model, provider } = this.model;
50
+ return model;
51
+ }
52
+ const isModelConfig = ModelConfigSchema.safeParse(this.model).success;
53
+ if (!isModelConfig) {
54
+ throw new SmolError(`Model ${JSON.stringify(this.model)} is not recognized. Please specify a known model or a valid ModelConfig.`);
34
55
  }
35
56
  const model = this.model;
36
57
  let candidates = models.filter((m) => model.providers.includes(m.provider) &&
@@ -123,10 +144,13 @@ export class Model {
123
144
  currency: "USD",
124
145
  };
125
146
  }
126
- static create(model) {
147
+ toString() {
148
+ return `Model(${JSON.stringify(this.model)})`;
149
+ }
150
+ static create(model, provider) {
127
151
  if (model instanceof Model) {
128
152
  return model;
129
153
  }
130
- return new Model(model);
154
+ return new Model(model, provider);
131
155
  }
132
156
  }
package/dist/models.d.ts CHANGED
@@ -1,4 +1,16 @@
1
- export type Provider = "local" | "ollama" | "openai" | "openai-responses" | "anthropic" | "google" | "replicate" | "modal";
1
+ import { z } from "zod";
2
+ export declare const providers: readonly ["local", "ollama", "openai", "openai-responses", "anthropic", "google", "replicate", "modal"];
3
+ export declare const ProviderSchema: z.ZodEnum<{
4
+ local: "local";
5
+ ollama: "ollama";
6
+ openai: "openai";
7
+ "openai-responses": "openai-responses";
8
+ anthropic: "anthropic";
9
+ google: "google";
10
+ replicate: "replicate";
11
+ modal: "modal";
12
+ }>;
13
+ export type Provider = z.infer<typeof ProviderSchema>;
2
14
  export type BaseModel = {
3
15
  modelName: string;
4
16
  provider: Provider;
package/dist/models.js CHANGED
@@ -1,3 +1,15 @@
1
+ import { z } from "zod";
2
+ export const providers = [
3
+ "local",
4
+ "ollama",
5
+ "openai",
6
+ "openai-responses",
7
+ "anthropic",
8
+ "google",
9
+ "replicate",
10
+ "modal",
11
+ ];
12
+ export const ProviderSchema = z.enum(providers);
1
13
  export const speechToTextModels = [
2
14
  { type: "speech-to-text", modelName: "whisper-local", provider: "local" },
3
15
  {
@@ -1,6 +1,5 @@
1
+ import { ModelLike } from "./types.js";
1
2
  import { Result } from "./types/result.js";
2
- import { ModelName } from "./models.js";
3
- import { ModelConfig } from "./model.js";
4
3
  export type AgencyFile = {
5
4
  name: string;
6
5
  contents: string;
@@ -62,7 +61,7 @@ export declare class StatelogClient {
62
61
  promptCompletion({ messages, completion, model, timeTaken, tools, responseFormat, }: {
63
62
  messages: any[];
64
63
  completion: any;
65
- model?: ModelName | ModelConfig | string;
64
+ model?: ModelLike;
66
65
  timeTaken?: number;
67
66
  tools?: {
68
67
  name: string;
@@ -75,7 +74,7 @@ export declare class StatelogClient {
75
74
  toolName: string;
76
75
  args: any;
77
76
  output: any;
78
- model?: ModelName | ModelConfig;
77
+ model?: ModelLike;
79
78
  timeTaken?: number;
80
79
  }): Promise<void>;
81
80
  diff({ itemA, itemB, message, }: {