chromadb 1.7.2 → 1.7.3-beta2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/chromadb.d.ts +230 -488
- package/dist/chromadb.legacy-esm.js +154 -879
- package/dist/chromadb.mjs +154 -879
- package/dist/chromadb.mjs.map +1 -1
- package/dist/cjs/chromadb.cjs +157 -885
- package/dist/cjs/chromadb.cjs.map +1 -1
- package/dist/cjs/chromadb.d.cts +230 -488
- package/package.json +2 -6
- package/src/ChromaClient.ts +41 -67
- package/src/Collection.ts +58 -17
- package/src/embeddings/DefaultEmbeddingFunction.ts +99 -0
- package/src/embeddings/TransformersEmbeddingFunction.ts +99 -0
- package/src/generated/api.ts +45 -563
- package/src/generated/configuration.ts +5 -5
- package/src/generated/index.ts +2 -2
- package/src/generated/models.ts +35 -79
- package/src/generated/runtime.ts +5 -5
- package/src/index.ts +3 -36
- package/src/types.ts +0 -75
- package/src/utils.ts +3 -17
- package/src/AdminClient.ts +0 -272
- package/src/CloudClient.ts +0 -46
- package/src/embeddings/GoogleGeminiEmbeddingFunction.ts +0 -69
- package/src/embeddings/HuggingFaceEmbeddingServerFunction.ts +0 -31
- package/src/embeddings/JinaEmbeddingFunction.ts +0 -46
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "chromadb",
|
|
3
|
-
"version": "1.7.2",
|
|
3
|
+
"version": "1.7.3-beta2",
|
|
4
4
|
"description": "A JavaScript interface for chroma",
|
|
5
5
|
"keywords": [],
|
|
6
6
|
"author": "",
|
|
@@ -43,7 +43,7 @@
|
|
|
43
43
|
"scripts": {
|
|
44
44
|
"test": "run-s db:clean db:cleanauth db:run test:runfull db:clean test:runfull-authonly db:cleanauth",
|
|
45
45
|
"testnoauth": "run-s db:clean db:run test:runfull db:clean",
|
|
46
|
-
"testauth": "run-s db:cleanauth test:runfull-authonly db:cleanauth",
|
|
46
|
+
"testauth": "run-s db:cleanauth db:run-auth test:runfull-authonly db:cleanauth",
|
|
47
47
|
"test:set-port": "cross-env URL=localhost:8001",
|
|
48
48
|
"test:run": "jest --runInBand --testPathIgnorePatterns=test/auth.*.test.ts",
|
|
49
49
|
"test:run-auth-basic": "jest --runInBand --testPathPattern=test/auth.basic.test.ts",
|
|
@@ -76,14 +76,10 @@
|
|
|
76
76
|
"cliui": "^8.0.1"
|
|
77
77
|
},
|
|
78
78
|
"peerDependencies": {
|
|
79
|
-
"@google/generative-ai": "^0.1.1",
|
|
80
79
|
"cohere-ai": "^5.0.0 || ^6.0.0",
|
|
81
80
|
"openai": "^3.0.0 || ^4.0.0"
|
|
82
81
|
},
|
|
83
82
|
"peerDependenciesMeta": {
|
|
84
|
-
"@google/generative-ai": {
|
|
85
|
-
"optional": true
|
|
86
|
-
},
|
|
87
83
|
"cohere-ai": {
|
|
88
84
|
"optional": true
|
|
89
85
|
},
|
package/src/ChromaClient.ts
CHANGED
|
@@ -2,16 +2,14 @@ import { IEmbeddingFunction } from './embeddings/IEmbeddingFunction';
|
|
|
2
2
|
import { Configuration, ApiApi as DefaultApi } from "./generated";
|
|
3
3
|
import { handleSuccess, handleError } from "./utils";
|
|
4
4
|
import { Collection } from './Collection';
|
|
5
|
-
import {
|
|
5
|
+
import { CollectionMetadata, CollectionType, ConfigOptions } from './types';
|
|
6
6
|
import {
|
|
7
7
|
AuthOptions,
|
|
8
8
|
ClientAuthProtocolAdapter,
|
|
9
9
|
IsomorphicFetchClientAuthProtocolAdapter
|
|
10
10
|
} from "./auth";
|
|
11
|
-
import {
|
|
11
|
+
import { DefaultEmbeddingFunction } from './embeddings/DefaultEmbeddingFunction';
|
|
12
12
|
|
|
13
|
-
const DEFAULT_TENANT = "default_tenant"
|
|
14
|
-
const DEFAULT_DATABASE = "default_database"
|
|
15
13
|
|
|
16
14
|
export class ChromaClient {
|
|
17
15
|
/**
|
|
@@ -19,9 +17,6 @@ export class ChromaClient {
|
|
|
19
17
|
*/
|
|
20
18
|
private api: DefaultApi & ConfigOptions;
|
|
21
19
|
private apiAdapter: ClientAuthProtocolAdapter<any>|undefined;
|
|
22
|
-
private tenant: string = DEFAULT_TENANT;
|
|
23
|
-
private database: string = DEFAULT_DATABASE;
|
|
24
|
-
private _adminClient?: AdminClient
|
|
25
20
|
|
|
26
21
|
/**
|
|
27
22
|
* Creates a new ChromaClient instance.
|
|
@@ -40,17 +35,15 @@ export class ChromaClient {
|
|
|
40
35
|
path,
|
|
41
36
|
fetchOptions,
|
|
42
37
|
auth,
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
38
|
+
}: {
|
|
39
|
+
path?: string,
|
|
40
|
+
fetchOptions?: RequestInit,
|
|
41
|
+
auth?: AuthOptions,
|
|
42
|
+
} = {}) {
|
|
46
43
|
if (path === undefined) path = "http://localhost:8000";
|
|
47
|
-
this.tenant = tenant;
|
|
48
|
-
this.database = database;
|
|
49
|
-
|
|
50
44
|
const apiConfig: Configuration = new Configuration({
|
|
51
45
|
basePath: path,
|
|
52
46
|
});
|
|
53
|
-
|
|
54
47
|
if (auth !== undefined) {
|
|
55
48
|
this.apiAdapter = new IsomorphicFetchClientAuthProtocolAdapter(new DefaultApi(apiConfig), auth);
|
|
56
49
|
this.api = this.apiAdapter.getApi();
|
|
@@ -58,19 +51,6 @@ export class ChromaClient {
|
|
|
58
51
|
this.api = new DefaultApi(apiConfig);
|
|
59
52
|
}
|
|
60
53
|
|
|
61
|
-
this._adminClient = new AdminClient({
|
|
62
|
-
path: path,
|
|
63
|
-
fetchOptions: fetchOptions,
|
|
64
|
-
auth: auth,
|
|
65
|
-
tenant: tenant,
|
|
66
|
-
database: database
|
|
67
|
-
});
|
|
68
|
-
|
|
69
|
-
// TODO: Validate tenant and database on client creation
|
|
70
|
-
// this got tricky because:
|
|
71
|
-
// - the constructor is sync but the generated api is async
|
|
72
|
-
// - we need to inject auth information so a simple rewrite/fetch does not work
|
|
73
|
-
|
|
74
54
|
this.api.options = fetchOptions ?? {};
|
|
75
55
|
}
|
|
76
56
|
|
|
@@ -143,9 +123,18 @@ export class ChromaClient {
|
|
|
143
123
|
name,
|
|
144
124
|
metadata,
|
|
145
125
|
embeddingFunction
|
|
146
|
-
}:
|
|
126
|
+
}: {
|
|
127
|
+
name: string,
|
|
128
|
+
metadata?: CollectionMetadata,
|
|
129
|
+
embeddingFunction?: IEmbeddingFunction
|
|
130
|
+
}): Promise<Collection> {
|
|
131
|
+
|
|
132
|
+
if (embeddingFunction === undefined) {
|
|
133
|
+
embeddingFunction = new DefaultEmbeddingFunction();
|
|
134
|
+
}
|
|
135
|
+
|
|
147
136
|
const newCollection = await this.api
|
|
148
|
-
.createCollection(
|
|
137
|
+
.createCollection({
|
|
149
138
|
name,
|
|
150
139
|
metadata,
|
|
151
140
|
}, this.api.options)
|
|
@@ -184,9 +173,18 @@ export class ChromaClient {
|
|
|
184
173
|
name,
|
|
185
174
|
metadata,
|
|
186
175
|
embeddingFunction
|
|
187
|
-
}:
|
|
176
|
+
}: {
|
|
177
|
+
name: string,
|
|
178
|
+
metadata?: CollectionMetadata,
|
|
179
|
+
embeddingFunction?: IEmbeddingFunction
|
|
180
|
+
}): Promise<Collection> {
|
|
181
|
+
|
|
182
|
+
if (embeddingFunction === undefined) {
|
|
183
|
+
embeddingFunction = new DefaultEmbeddingFunction();
|
|
184
|
+
}
|
|
185
|
+
|
|
188
186
|
const newCollection = await this.api
|
|
189
|
-
.createCollection(
|
|
187
|
+
.createCollection({
|
|
190
188
|
name,
|
|
191
189
|
metadata,
|
|
192
190
|
'get_or_create': true
|
|
@@ -211,44 +209,15 @@ export class ChromaClient {
|
|
|
211
209
|
* Lists all collections.
|
|
212
210
|
*
|
|
213
211
|
* @returns {Promise<CollectionType[]>} A promise that resolves to a list of collection names.
|
|
214
|
-
* @param {PositiveInteger} [params.limit] - Optional limit on the number of items to get.
|
|
215
|
-
* @param {PositiveInteger} [params.offset] - Optional offset on the items to get.
|
|
216
212
|
* @throws {Error} If there is an issue listing the collections.
|
|
217
213
|
*
|
|
218
214
|
* @example
|
|
219
215
|
* ```typescript
|
|
220
|
-
* const collections = await client.listCollections(
|
|
221
|
-
* limit: 10,
|
|
222
|
-
* offset: 0,
|
|
223
|
-
* });
|
|
224
|
-
* ```
|
|
225
|
-
*/
|
|
226
|
-
public async listCollections({
|
|
227
|
-
limit,
|
|
228
|
-
offset,
|
|
229
|
-
}: ListCollectionsParams = {}): Promise<CollectionType[]> {
|
|
230
|
-
const response = await this.api.listCollections(
|
|
231
|
-
this.tenant,
|
|
232
|
-
this.database,
|
|
233
|
-
limit,
|
|
234
|
-
offset,
|
|
235
|
-
this.api.options);
|
|
236
|
-
return handleSuccess(response);
|
|
237
|
-
}
|
|
238
|
-
|
|
239
|
-
/**
|
|
240
|
-
* Counts all collections.
|
|
241
|
-
*
|
|
242
|
-
* @returns {Promise<number>} A promise that resolves to the number of collections.
|
|
243
|
-
* @throws {Error} If there is an issue counting the collections.
|
|
244
|
-
*
|
|
245
|
-
* @example
|
|
246
|
-
* ```typescript
|
|
247
|
-
* const collections = await client.countCollections();
|
|
216
|
+
* const collections = await client.listCollections();
|
|
248
217
|
* ```
|
|
249
218
|
*/
|
|
250
|
-
public async
|
|
251
|
-
const response = await this.api.
|
|
219
|
+
public async listCollections(): Promise<CollectionType[]> {
|
|
220
|
+
const response = await this.api.listCollections(this.api.options);
|
|
252
221
|
return handleSuccess(response);
|
|
253
222
|
}
|
|
254
223
|
|
|
@@ -270,9 +239,12 @@ export class ChromaClient {
|
|
|
270
239
|
public async getCollection({
|
|
271
240
|
name,
|
|
272
241
|
embeddingFunction
|
|
273
|
-
}:
|
|
242
|
+
}: {
|
|
243
|
+
name: string;
|
|
244
|
+
embeddingFunction?: IEmbeddingFunction
|
|
245
|
+
}): Promise<Collection> {
|
|
274
246
|
const response = await this.api
|
|
275
|
-
.getCollection(name, this.
|
|
247
|
+
.getCollection(name, this.api.options)
|
|
276
248
|
.then(handleSuccess)
|
|
277
249
|
.catch(handleError);
|
|
278
250
|
|
|
@@ -306,9 +278,11 @@ export class ChromaClient {
|
|
|
306
278
|
*/
|
|
307
279
|
public async deleteCollection({
|
|
308
280
|
name
|
|
309
|
-
}:
|
|
281
|
+
}: {
|
|
282
|
+
name: string
|
|
283
|
+
}): Promise<void> {
|
|
310
284
|
return await this.api
|
|
311
|
-
.deleteCollection(name, this.
|
|
285
|
+
.deleteCollection(name, this.api.options)
|
|
312
286
|
.then(handleSuccess)
|
|
313
287
|
.catch(handleError);
|
|
314
288
|
}
|
package/src/Collection.ts
CHANGED
|
@@ -1,17 +1,21 @@
|
|
|
1
1
|
import {
|
|
2
|
+
IncludeEnum,
|
|
3
|
+
Metadata,
|
|
4
|
+
Metadatas,
|
|
5
|
+
Embedding,
|
|
6
|
+
Embeddings,
|
|
7
|
+
Document,
|
|
8
|
+
Documents,
|
|
9
|
+
Where,
|
|
10
|
+
WhereDocument,
|
|
11
|
+
ID,
|
|
12
|
+
IDs,
|
|
13
|
+
PositiveInteger,
|
|
2
14
|
GetResponse,
|
|
3
15
|
QueryResponse,
|
|
4
16
|
AddResponse,
|
|
5
17
|
CollectionMetadata,
|
|
6
|
-
ConfigOptions
|
|
7
|
-
GetParams,
|
|
8
|
-
AddParams,
|
|
9
|
-
UpsertParams,
|
|
10
|
-
ModifyCollectionParams,
|
|
11
|
-
UpdateParams,
|
|
12
|
-
QueryParams,
|
|
13
|
-
PeekParams,
|
|
14
|
-
DeleteParams
|
|
18
|
+
ConfigOptions
|
|
15
19
|
} from "./types";
|
|
16
20
|
import { IEmbeddingFunction } from './embeddings/IEmbeddingFunction';
|
|
17
21
|
import { ApiApi as DefaultApi } from "./generated";
|
|
@@ -169,7 +173,12 @@ export class Collection {
|
|
|
169
173
|
embeddings,
|
|
170
174
|
metadatas,
|
|
171
175
|
documents,
|
|
172
|
-
}:
|
|
176
|
+
}: {
|
|
177
|
+
ids: ID | IDs,
|
|
178
|
+
embeddings?: Embedding | Embeddings,
|
|
179
|
+
metadatas?: Metadata | Metadatas,
|
|
180
|
+
documents?: Document | Documents,
|
|
181
|
+
}): Promise<AddResponse> {
|
|
173
182
|
|
|
174
183
|
const [idsArray, embeddingsArray, metadatasArray, documentsArray] = await this.validate(
|
|
175
184
|
true,
|
|
@@ -219,7 +228,12 @@ export class Collection {
|
|
|
219
228
|
embeddings,
|
|
220
229
|
metadatas,
|
|
221
230
|
documents,
|
|
222
|
-
}:
|
|
231
|
+
}: {
|
|
232
|
+
ids: ID | IDs,
|
|
233
|
+
embeddings?: Embedding | Embeddings,
|
|
234
|
+
metadatas?: Metadata | Metadatas,
|
|
235
|
+
documents?: Document | Documents,
|
|
236
|
+
}): Promise<boolean> {
|
|
223
237
|
const [idsArray, embeddingsArray, metadatasArray, documentsArray] = await this.validate(
|
|
224
238
|
true,
|
|
225
239
|
ids,
|
|
@@ -279,7 +293,10 @@ export class Collection {
|
|
|
279
293
|
public async modify({
|
|
280
294
|
name,
|
|
281
295
|
metadata
|
|
282
|
-
}:
|
|
296
|
+
}: {
|
|
297
|
+
name?: string,
|
|
298
|
+
metadata?: CollectionMetadata
|
|
299
|
+
} = {}): Promise<void> {
|
|
283
300
|
const response = await this.api
|
|
284
301
|
.updateCollection(
|
|
285
302
|
this.id,
|
|
@@ -296,6 +313,7 @@ export class Collection {
|
|
|
296
313
|
this.setMetadata(metadata || this.metadata);
|
|
297
314
|
|
|
298
315
|
return response;
|
|
316
|
+
|
|
299
317
|
}
|
|
300
318
|
|
|
301
319
|
/**
|
|
@@ -328,7 +346,14 @@ export class Collection {
|
|
|
328
346
|
offset,
|
|
329
347
|
include,
|
|
330
348
|
whereDocument,
|
|
331
|
-
}:
|
|
349
|
+
}: {
|
|
350
|
+
ids?: ID | IDs,
|
|
351
|
+
where?: Where,
|
|
352
|
+
limit?: PositiveInteger,
|
|
353
|
+
offset?: PositiveInteger,
|
|
354
|
+
include?: IncludeEnum[],
|
|
355
|
+
whereDocument?: WhereDocument
|
|
356
|
+
} = {}): Promise<GetResponse> {
|
|
332
357
|
let idsArray = undefined;
|
|
333
358
|
if (ids !== undefined) idsArray = toArray(ids);
|
|
334
359
|
|
|
@@ -370,7 +395,12 @@ export class Collection {
|
|
|
370
395
|
embeddings,
|
|
371
396
|
metadatas,
|
|
372
397
|
documents,
|
|
373
|
-
}:
|
|
398
|
+
}: {
|
|
399
|
+
ids: ID | IDs,
|
|
400
|
+
embeddings?: Embedding | Embeddings,
|
|
401
|
+
metadatas?: Metadata | Metadatas,
|
|
402
|
+
documents?: Document | Documents,
|
|
403
|
+
}): Promise<boolean> {
|
|
374
404
|
if (
|
|
375
405
|
embeddings === undefined &&
|
|
376
406
|
documents === undefined &&
|
|
@@ -451,7 +481,14 @@ export class Collection {
|
|
|
451
481
|
queryTexts,
|
|
452
482
|
whereDocument,
|
|
453
483
|
include,
|
|
454
|
-
}:
|
|
484
|
+
}: {
|
|
485
|
+
queryEmbeddings?: Embedding | Embeddings,
|
|
486
|
+
nResults?: PositiveInteger,
|
|
487
|
+
where?: Where,
|
|
488
|
+
queryTexts?: string | string[],
|
|
489
|
+
whereDocument?: WhereDocument, // {"$contains":"search_string"}
|
|
490
|
+
include?: IncludeEnum[] // ["metadata", "document"]
|
|
491
|
+
}): Promise<QueryResponse> {
|
|
455
492
|
if (nResults === undefined) nResults = 10
|
|
456
493
|
if (queryEmbeddings === undefined && queryTexts === undefined) {
|
|
457
494
|
throw new Error(
|
|
@@ -499,7 +536,7 @@ export class Collection {
|
|
|
499
536
|
* });
|
|
500
537
|
* ```
|
|
501
538
|
*/
|
|
502
|
-
public async peek({ limit }:
|
|
539
|
+
public async peek({ limit }: { limit?: PositiveInteger } = {}): Promise<GetResponse> {
|
|
503
540
|
if (limit === undefined) limit = 10;
|
|
504
541
|
const response = await this.api.aGet(this.id, {
|
|
505
542
|
limit: limit,
|
|
@@ -529,7 +566,11 @@ export class Collection {
|
|
|
529
566
|
ids,
|
|
530
567
|
where,
|
|
531
568
|
whereDocument
|
|
532
|
-
}:
|
|
569
|
+
}: {
|
|
570
|
+
ids?: ID | IDs,
|
|
571
|
+
where?: Where,
|
|
572
|
+
whereDocument?: WhereDocument
|
|
573
|
+
} = {}): Promise<string[]> {
|
|
533
574
|
let idsArray = undefined;
|
|
534
575
|
if (ids !== undefined) idsArray = toArray(ids);
|
|
535
576
|
return await this.api
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import { IEmbeddingFunction } from "./IEmbeddingFunction";
|
|
2
|
+
|
|
3
|
+
/**
 * Local embedding function backed by the optional `chromadb-default-embed`
 * package, which is imported lazily on first use.
 *
 * The feature-extraction pipeline is built once per instance and reused by
 * every later `generate` call (rebuilding it per call would reload the model
 * for every batch). A build that fails is not cached, so a later call retries.
 */
export class DefaultEmbeddingFunction implements IEmbeddingFunction {
  /** In-flight or settled pipeline build; unset until the first non-empty `generate`. */
  private pipelinePromise?: Promise<any> | null;
  /** The `pipeline` factory exported by `chromadb-default-embed`, set by `loadClient`. */
  private transformersApi: any;
  private model: string;
  private revision: string;
  private quantized: boolean;
  private progress_callback: Function | null;

  /**
   * DefaultEmbeddingFunction constructor.
   * @param options The configuration options.
   * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
   * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
   * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
   * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
   */
  constructor({
    model = "Xenova/all-MiniLM-L6-v2",
    revision = "main",
    quantized = false,
    progress_callback = null,
  }: {
    model?: string;
    revision?: string;
    quantized?: boolean;
    progress_callback?: Function | null;
  } = {}) {
    this.model = model;
    this.revision = revision;
    this.quantized = quantized;
    this.progress_callback = progress_callback;
  }

  /**
   * Embeds `texts` with the feature-extraction pipeline using mean pooling and
   * normalization.
   *
   * @param texts Texts to embed. An empty array returns `[]` without loading the model.
   * @returns The pipeline output converted with `tolist()` (presumably one vector per input text).
   * @throws Error if `chromadb-default-embed` is not installed, or whatever the
   *   pipeline raises while loading or running the model.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) return [];
    const pipe = await this.getPipeline();
    const output = await pipe(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }

  /**
   * Returns the cached pipeline promise, starting the build on first use.
   * Concurrent callers share one in-flight build; a rejected build is cleared
   * so that the next call starts a fresh one instead of replaying the failure.
   */
  private getPipeline(): Promise<any> {
    if (this.pipelinePromise) return this.pipelinePromise;
    const pending = this.createPipeline();
    this.pipelinePromise = pending;
    pending.catch(() => {
      // Only clear if no newer build has replaced this one in the meantime.
      if (this.pipelinePromise === pending) this.pipelinePromise = null;
    });
    return pending;
  }

  /** Loads the package (if needed) and builds the feature-extraction pipeline. */
  private async createPipeline(): Promise<any> {
    await this.loadClient();
    return this.transformersApi("feature-extraction", this.model, {
      quantized: this.quantized,
      revision: this.revision,
      progress_callback: this.progress_callback,
    });
  }

  /** Resolves the package's `pipeline` factory once and stores it on the instance. */
  private async loadClient(): Promise<void> {
    if (this.transformersApi) return;
    // `import()` already turns a missing package into a friendly install hint.
    const { pipeline } = await DefaultEmbeddingFunction.import();
    this.transformersApi = pipeline;
  }

  /**
   * Dynamically imports `chromadb-default-embed`.
   * @throws Error with installation instructions when the package cannot be
   *   found; any other load error is rethrown unchanged so it is not misreported
   *   as a missing dependency.
   * @ignore
   */
  static async import(): Promise<{
    // @ts-ignore
    pipeline: typeof import("chromadb-default-embed");
  }> {
    try {
      // @ts-ignore
      const { pipeline } = await import("chromadb-default-embed");
      return { pipeline };
    } catch (e) {
      if (DefaultEmbeddingFunction.isModuleNotFound(e)) {
        throw new Error(
          "Please install chromadb-default-embed as a dependency with, e.g. `yarn add chromadb-default-embed`"
        );
      }
      throw e;
    }
  }

  /** True for Node's "cannot resolve module" errors (CommonJS and ESM codes). */
  private static isModuleNotFound(err: unknown): boolean {
    if (typeof err !== "object" || err === null || !("code" in err)) return false;
    const code = (err as { code: unknown }).code;
    return code === "MODULE_NOT_FOUND" || code === "ERR_MODULE_NOT_FOUND";
  }
}
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import { IEmbeddingFunction } from "./IEmbeddingFunction";
|
|
2
|
+
|
|
3
|
+
/**
 * Local embedding function backed by the optional `@xenova/transformers`
 * package, which is imported lazily on first use.
 *
 * The feature-extraction pipeline is built once per instance and reused by
 * every later `generate` call (rebuilding it per call would reload the model
 * for every batch). A build that fails is not cached, so a later call retries.
 */
export class TransformersEmbeddingFunction implements IEmbeddingFunction {
  /** In-flight or settled pipeline build; unset until the first non-empty `generate`. */
  private pipelinePromise?: Promise<any> | null;
  /** The `pipeline` factory exported by `@xenova/transformers`, set by `loadClient`. */
  private transformersApi: any;
  private model: string;
  private revision: string;
  private quantized: boolean;
  private progress_callback: Function | null;

  /**
   * TransformersEmbeddingFunction constructor.
   * @param options The configuration options.
   * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
   * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
   * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
   * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
   */
  constructor({
    model = "Xenova/all-MiniLM-L6-v2",
    revision = "main",
    quantized = false,
    progress_callback = null,
  }: {
    model?: string;
    revision?: string;
    quantized?: boolean;
    progress_callback?: Function | null;
  } = {}) {
    this.model = model;
    this.revision = revision;
    this.quantized = quantized;
    this.progress_callback = progress_callback;
  }

  /**
   * Embeds `texts` with the feature-extraction pipeline using mean pooling and
   * normalization.
   *
   * @param texts Texts to embed. An empty array returns `[]` without loading the model.
   * @returns The pipeline output converted with `tolist()` (presumably one vector per input text).
   * @throws Error if `@xenova/transformers` is not installed, or whatever the
   *   pipeline raises while loading or running the model.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) return [];
    const pipe = await this.getPipeline();
    const output = await pipe(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }

  /**
   * Returns the cached pipeline promise, starting the build on first use.
   * Concurrent callers share one in-flight build; a rejected build is cleared
   * so that the next call starts a fresh one instead of replaying the failure.
   */
  private getPipeline(): Promise<any> {
    if (this.pipelinePromise) return this.pipelinePromise;
    const pending = this.createPipeline();
    this.pipelinePromise = pending;
    pending.catch(() => {
      // Only clear if no newer build has replaced this one in the meantime.
      if (this.pipelinePromise === pending) this.pipelinePromise = null;
    });
    return pending;
  }

  /** Loads the package (if needed) and builds the feature-extraction pipeline. */
  private async createPipeline(): Promise<any> {
    await this.loadClient();
    return this.transformersApi("feature-extraction", this.model, {
      quantized: this.quantized,
      revision: this.revision,
      progress_callback: this.progress_callback,
    });
  }

  /** Resolves the package's `pipeline` factory once and stores it on the instance. */
  private async loadClient(): Promise<void> {
    if (this.transformersApi) return;
    // `import()` already turns a missing package into a friendly install hint.
    const { pipeline } = await TransformersEmbeddingFunction.import();
    this.transformersApi = pipeline;
  }

  /**
   * Dynamically imports `@xenova/transformers`.
   * @throws Error with installation instructions when the package cannot be
   *   found; any other load error is rethrown unchanged so it is not misreported
   *   as a missing dependency.
   * @ignore
   */
  static async import(): Promise<{
    // @ts-ignore
    pipeline: typeof import("@xenova/transformers");
  }> {
    try {
      // @ts-ignore
      const { pipeline } = await import("@xenova/transformers");
      return { pipeline };
    } catch (e) {
      if (TransformersEmbeddingFunction.isModuleNotFound(e)) {
        throw new Error(
          "Please install @xenova/transformers as a dependency with, e.g. `yarn add @xenova/transformers`"
        );
      }
      throw e;
    }
  }

  /** True for Node's "cannot resolve module" errors (CommonJS and ESM codes). */
  private static isModuleNotFound(err: unknown): boolean {
    if (typeof err !== "object" || err === null || !("code" in err)) return false;
    const code = (err as { code: unknown }).code;
    return code === "MODULE_NOT_FOUND" || code === "ERR_MODULE_NOT_FOUND";
  }
}
|