@threaded/ai 1.0.28 → 1.0.30

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -208,13 +208,15 @@ declare const addUsage: (existing: TokenUsage | undefined, promptTokens: number,
208
208
  * openai models use the prefix "openai/" (e.g., "openai/text-embedding-3-small")
209
209
  * all other models use huggingface transformers
210
210
  *
211
+ * accepts a single string or an array of strings for batch embedding
212
+ *
211
213
  * @example
212
214
  * const vector = await embed("openai/text-embedding-3-small", "hello world");
213
- * const vector2 = await embed("Xenova/all-MiniLM-L6-v2", "hello world");
215
+ * const vectors = await embed("openai/text-embedding-3-small", ["hello", "world"]);
214
216
  */
215
- declare const embed: (model: string, text: string, config?: {
217
+ declare const embed: (model: string, text: string | string[], config?: {
216
218
  dimensions?: number;
217
- }) => Promise<number[]>;
219
+ }) => Promise<number[] | number[][]>;
218
220
 
219
221
  declare const generateImage: (model: string, prompt: string, config?: ImageConfig) => Promise<ImageResult>;
220
222
 
package/dist/index.d.ts CHANGED
@@ -208,13 +208,15 @@ declare const addUsage: (existing: TokenUsage | undefined, promptTokens: number,
208
208
  * openai models use the prefix "openai/" (e.g., "openai/text-embedding-3-small")
209
209
  * all other models use huggingface transformers
210
210
  *
211
+ * accepts a single string or an array of strings for batch embedding
212
+ *
211
213
  * @example
212
214
  * const vector = await embed("openai/text-embedding-3-small", "hello world");
213
- * const vector2 = await embed("Xenova/all-MiniLM-L6-v2", "hello world");
215
+ * const vectors = await embed("openai/text-embedding-3-small", ["hello", "world"]);
214
216
  */
215
- declare const embed: (model: string, text: string, config?: {
217
+ declare const embed: (model: string, text: string | string[], config?: {
216
218
  dimensions?: number;
217
- }) => Promise<number[]>;
219
+ }) => Promise<number[] | number[][]>;
218
220
 
219
221
  declare const generateImage: (model: string, prompt: string, config?: ImageConfig) => Promise<ImageResult>;
220
222
 
package/dist/index.js CHANGED
@@ -158,6 +158,7 @@ var addUsage = (existing, promptTokens, completionTokens, totalTokens) => ({
158
158
  // src/embed.ts
159
159
  var modelCache = /* @__PURE__ */ new Map();
160
160
  var embed = async (model2, text, config) => {
161
+ const isBatch = Array.isArray(text);
161
162
  if (model2.startsWith("openai/")) {
162
163
  const modelName = model2.replace("openai/", "");
163
164
  const apiKey = getKey("openai") || process.env.OPENAI_API_KEY;
@@ -184,6 +185,9 @@ var embed = async (model2, text, config) => {
184
185
  throw new Error(`OpenAI API error: ${error}`);
185
186
  }
186
187
  const data = await response.json();
188
+ if (isBatch) {
189
+ return data.data.map((d) => d.embedding);
190
+ }
187
191
  return data.data[0].embedding;
188
192
  }
189
193
  try {
@@ -195,6 +199,14 @@ var embed = async (model2, text, config) => {
195
199
  modelCache.set(model2, extractor2);
196
200
  }
197
201
  const extractor = modelCache.get(model2);
202
+ if (isBatch) {
203
+ const results = [];
204
+ for (const t of text) {
205
+ const result2 = await extractor(t, { pooling: "mean", normalize: true });
206
+ results.push(Array.from(result2.data));
207
+ }
208
+ return results;
209
+ }
198
210
  const result = await extractor(text, { pooling: "mean", normalize: true });
199
211
  return Array.from(result.data);
200
212
  } catch (error) {
@@ -1053,7 +1065,7 @@ var callHuggingFace = async (config, ctx) => {
1053
1065
  const { pipeline } = await import("@huggingface/transformers");
1054
1066
  if (!modelCache2.has(model2)) {
1055
1067
  const generator2 = await pipeline("text-generation", model2, {
1056
- dtype: "q4f16"
1068
+ dtype: "q4"
1057
1069
  });
1058
1070
  modelCache2.set(model2, generator2);
1059
1071
  }