@threaded/ai 1.0.29 → 1.0.30

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -232,6 +232,7 @@ var addUsage = (existing, promptTokens, completionTokens, totalTokens) => ({
232
232
  // src/embed.ts
233
233
  var modelCache = /* @__PURE__ */ new Map();
234
234
  var embed = async (model2, text, config) => {
235
+ const isBatch = Array.isArray(text);
235
236
  if (model2.startsWith("openai/")) {
236
237
  const modelName = model2.replace("openai/", "");
237
238
  const apiKey = getKey("openai") || process.env.OPENAI_API_KEY;
@@ -258,6 +259,9 @@ var embed = async (model2, text, config) => {
258
259
  throw new Error(`OpenAI API error: ${error}`);
259
260
  }
260
261
  const data = await response.json();
262
+ if (isBatch) {
263
+ return data.data.map((d) => d.embedding);
264
+ }
261
265
  return data.data[0].embedding;
262
266
  }
263
267
  try {
@@ -269,6 +273,14 @@ var embed = async (model2, text, config) => {
269
273
  modelCache.set(model2, extractor2);
270
274
  }
271
275
  const extractor = modelCache.get(model2);
276
+ if (isBatch) {
277
+ const results = [];
278
+ for (const t of text) {
279
+ const result2 = await extractor(t, { pooling: "mean", normalize: true });
280
+ results.push(Array.from(result2.data));
281
+ }
282
+ return results;
283
+ }
272
284
  const result = await extractor(text, { pooling: "mean", normalize: true });
273
285
  return Array.from(result.data);
274
286
  } catch (error) {