@threaded/ai 1.0.28 → 1.0.30
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only; it reflects the changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +13 -1
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.cts +5 -3
- package/dist/index.d.ts +5 -3
- package/dist/index.js +13 -1
- package/dist/index.js.map +1 -1
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -232,6 +232,7 @@ var addUsage = (existing, promptTokens, completionTokens, totalTokens) => ({
|
|
|
232
232
|
// src/embed.ts
|
|
233
233
|
var modelCache = /* @__PURE__ */ new Map();
|
|
234
234
|
var embed = async (model2, text, config) => {
|
|
235
|
+
const isBatch = Array.isArray(text);
|
|
235
236
|
if (model2.startsWith("openai/")) {
|
|
236
237
|
const modelName = model2.replace("openai/", "");
|
|
237
238
|
const apiKey = getKey("openai") || process.env.OPENAI_API_KEY;
|
|
@@ -258,6 +259,9 @@ var embed = async (model2, text, config) => {
|
|
|
258
259
|
throw new Error(`OpenAI API error: ${error}`);
|
|
259
260
|
}
|
|
260
261
|
const data = await response.json();
|
|
262
|
+
if (isBatch) {
|
|
263
|
+
return data.data.map((d) => d.embedding);
|
|
264
|
+
}
|
|
261
265
|
return data.data[0].embedding;
|
|
262
266
|
}
|
|
263
267
|
try {
|
|
@@ -269,6 +273,14 @@ var embed = async (model2, text, config) => {
|
|
|
269
273
|
modelCache.set(model2, extractor2);
|
|
270
274
|
}
|
|
271
275
|
const extractor = modelCache.get(model2);
|
|
276
|
+
if (isBatch) {
|
|
277
|
+
const results = [];
|
|
278
|
+
for (const t of text) {
|
|
279
|
+
const result2 = await extractor(t, { pooling: "mean", normalize: true });
|
|
280
|
+
results.push(Array.from(result2.data));
|
|
281
|
+
}
|
|
282
|
+
return results;
|
|
283
|
+
}
|
|
272
284
|
const result = await extractor(text, { pooling: "mean", normalize: true });
|
|
273
285
|
return Array.from(result.data);
|
|
274
286
|
} catch (error) {
|
|
@@ -1127,7 +1139,7 @@ var callHuggingFace = async (config, ctx) => {
|
|
|
1127
1139
|
const { pipeline } = await import("@huggingface/transformers");
|
|
1128
1140
|
if (!modelCache2.has(model2)) {
|
|
1129
1141
|
const generator2 = await pipeline("text-generation", model2, {
|
|
1130
|
-
dtype: "
|
|
1142
|
+
dtype: "q4"
|
|
1131
1143
|
});
|
|
1132
1144
|
modelCache2.set(model2, generator2);
|
|
1133
1145
|
}
|