@mastra/chroma 0.10.3 → 0.10.4-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +24 -0
- package/dist/_tsup-dts-rollup.d.cts +23 -13
- package/dist/_tsup-dts-rollup.d.ts +23 -13
- package/dist/index.cjs +162 -62
- package/dist/index.js +158 -58
- package/package.json +3 -3
- package/src/vector/filter.test.ts +24 -19
- package/src/vector/filter.ts +35 -4
- package/src/vector/index.test.ts +4 -4
- package/src/vector/index.ts +166 -75
package/src/vector/index.ts
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { MastraError, ErrorDomain, ErrorCategory } from '@mastra/core/error';
|
|
1
2
|
import { MastraVector } from '@mastra/core/vector';
|
|
2
3
|
import type {
|
|
3
4
|
QueryResult,
|
|
@@ -10,21 +11,20 @@ import type {
|
|
|
10
11
|
DeleteVectorParams,
|
|
11
12
|
UpdateVectorParams,
|
|
12
13
|
} from '@mastra/core/vector';
|
|
13
|
-
|
|
14
|
-
import type { VectorFilter } from '@mastra/core/vector/filter';
|
|
15
14
|
import { ChromaClient } from 'chromadb';
|
|
16
15
|
import type { UpdateRecordsParams, Collection } from 'chromadb';
|
|
16
|
+
import type { ChromaVectorDocumentFilter, ChromaVectorFilter } from './filter';
|
|
17
17
|
import { ChromaFilterTranslator } from './filter';
|
|
18
18
|
|
|
19
19
|
/**
 * Parameters for {@link ChromaVector.upsert}: the standard Mastra upsert fields
 * plus optional raw document texts stored next to the vectors.
 */
interface ChromaUpsertVectorParams extends UpsertVectorParams {
  /** Optional document texts, forwarded unchanged as Chroma's `documents` field. */
  documents?: string[];
}
|
|
22
22
|
|
|
23
|
-
/**
 * Parameters for {@link ChromaVector.query}: the standard Mastra query fields,
 * with metadata filters typed as {@link ChromaVectorFilter}, plus an optional
 * filter on stored document contents.
 */
interface ChromaQueryVectorParams extends QueryVectorParams<ChromaVectorFilter> {
  /** Filter on stored document text; forwarded to Chroma as `whereDocument`. */
  documentFilter?: ChromaVectorDocumentFilter;
}
|
|
26
26
|
|
|
27
|
-
export class ChromaVector extends MastraVector {
|
|
27
|
+
export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
28
28
|
private client: ChromaClient;
|
|
29
29
|
private collections: Map<string, any>;
|
|
30
30
|
|
|
@@ -70,28 +70,34 @@ export class ChromaVector extends MastraVector {
|
|
|
70
70
|
}
|
|
71
71
|
|
|
72
72
|
/**
 * Inserts or replaces vectors in the named Chroma collection.
 *
 * Vector dimensions are validated against the index before anything is written.
 * When `ids` is omitted, a random UUID is generated per vector. When `metadata`
 * is omitted, each vector gets an empty metadata object.
 *
 * @returns The ids that were written, whether caller-supplied or generated.
 * @throws MastraError `CHROMA_VECTOR_UPSERT_FAILED` when the underlying call fails;
 *   any MastraError caught (e.g. from describeIndex) is rethrown unchanged.
 */
async upsert({ indexName, vectors, metadata, ids, documents }: ChromaUpsertVectorParams): Promise<string[]> {
  try {
    const collection = await this.getCollection(indexName);

    // Compare incoming vectors with the index dimension (helper defined outside this view)
    // before sending anything to Chroma.
    const { dimension } = await this.describeIndex({ indexName });
    this.validateVectorDimensions(vectors, dimension);

    const recordIds = ids || vectors.map(() => crypto.randomUUID());
    const recordMetadata = metadata || vectors.map(() => ({}));

    await collection.upsert({
      ids: recordIds,
      embeddings: vectors,
      metadatas: recordMetadata,
      documents,
    });
    return recordIds;
  } catch (error: any) {
    if (error instanceof MastraError) {
      throw error;
    }
    throw new MastraError(
      {
        id: 'CHROMA_VECTOR_UPSERT_FAILED',
        domain: ErrorDomain.MASTRA_VECTOR,
        category: ErrorCategory.THIRD_PARTY,
        details: { indexName },
      },
      error,
    );
  }
}
|
|
96
102
|
|
|
97
103
|
private HnswSpaceMap = {
|
|
@@ -104,11 +110,23 @@ export class ChromaVector extends MastraVector {
|
|
|
104
110
|
|
|
105
111
|
async createIndex({ indexName, dimension, metric = 'cosine' }: CreateIndexParams): Promise<void> {
|
|
106
112
|
if (!Number.isInteger(dimension) || dimension <= 0) {
|
|
107
|
-
throw new
|
|
113
|
+
throw new MastraError({
|
|
114
|
+
id: 'CHROMA_VECTOR_CREATE_INDEX_INVALID_DIMENSION',
|
|
115
|
+
text: 'Dimension must be a positive integer',
|
|
116
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
117
|
+
category: ErrorCategory.USER,
|
|
118
|
+
details: { dimension },
|
|
119
|
+
});
|
|
108
120
|
}
|
|
109
121
|
const hnswSpace = this.HnswSpaceMap[metric];
|
|
110
|
-
if (!['cosine', 'l2', 'ip'].includes(hnswSpace)) {
|
|
111
|
-
throw new
|
|
122
|
+
if (!hnswSpace || !['cosine', 'l2', 'ip'].includes(hnswSpace)) {
|
|
123
|
+
throw new MastraError({
|
|
124
|
+
id: 'CHROMA_VECTOR_CREATE_INDEX_INVALID_METRIC',
|
|
125
|
+
text: `Invalid metric: "${metric}". Must be one of: cosine, euclidean, dotproduct`,
|
|
126
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
127
|
+
category: ErrorCategory.USER,
|
|
128
|
+
details: { metric },
|
|
129
|
+
});
|
|
112
130
|
}
|
|
113
131
|
try {
|
|
114
132
|
await this.client.createCollection({
|
|
@@ -126,11 +144,19 @@ export class ChromaVector extends MastraVector {
|
|
|
126
144
|
await this.validateExistingIndex(indexName, dimension, metric);
|
|
127
145
|
return;
|
|
128
146
|
}
|
|
129
|
-
throw
|
|
147
|
+
throw new MastraError(
|
|
148
|
+
{
|
|
149
|
+
id: 'CHROMA_VECTOR_CREATE_INDEX_FAILED',
|
|
150
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
151
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
152
|
+
details: { indexName },
|
|
153
|
+
},
|
|
154
|
+
error,
|
|
155
|
+
);
|
|
130
156
|
}
|
|
131
157
|
}
|
|
132
158
|
|
|
133
|
-
/**
 * Converts a Mastra-style metadata filter into the object passed to Chroma as `where`.
 *
 * @param filter - Optional filter, handed straight to a fresh {@link ChromaFilterTranslator}.
 * @returns The translator's output for `filter`.
 */
transformFilter(filter?: ChromaVectorFilter) {
  return new ChromaFilterTranslator().translate(filter);
}
|
|
@@ -142,32 +168,55 @@ export class ChromaVector extends MastraVector {
|
|
|
142
168
|
includeVector = false,
|
|
143
169
|
documentFilter,
|
|
144
170
|
}: ChromaQueryVectorParams): Promise<QueryResult[]> {
|
|
145
|
-
|
|
171
|
+
try {
|
|
172
|
+
const collection = await this.getCollection(indexName, true);
|
|
146
173
|
|
|
147
|
-
|
|
174
|
+
const defaultInclude = ['documents', 'metadatas', 'distances'];
|
|
148
175
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
176
|
+
const translatedFilter = this.transformFilter(filter);
|
|
177
|
+
const results = await collection.query({
|
|
178
|
+
queryEmbeddings: [queryVector],
|
|
179
|
+
nResults: topK,
|
|
180
|
+
where: translatedFilter,
|
|
181
|
+
whereDocument: documentFilter,
|
|
182
|
+
include: includeVector ? [...defaultInclude, 'embeddings'] : defaultInclude,
|
|
183
|
+
});
|
|
157
184
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
})
|
|
185
|
+
return (results.ids[0] || []).map((id: string, index: number) => ({
|
|
186
|
+
id,
|
|
187
|
+
score: results.distances?.[0]?.[index] || 0,
|
|
188
|
+
metadata: results.metadatas?.[0]?.[index] || {},
|
|
189
|
+
document: results.documents?.[0]?.[index],
|
|
190
|
+
...(includeVector && { vector: results.embeddings?.[0]?.[index] || [] }),
|
|
191
|
+
}));
|
|
192
|
+
} catch (error: any) {
|
|
193
|
+
if (error instanceof MastraError) throw error;
|
|
194
|
+
throw new MastraError(
|
|
195
|
+
{
|
|
196
|
+
id: 'CHROMA_VECTOR_QUERY_FAILED',
|
|
197
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
198
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
199
|
+
details: { indexName },
|
|
200
|
+
},
|
|
201
|
+
error,
|
|
202
|
+
);
|
|
203
|
+
}
|
|
166
204
|
}
|
|
167
205
|
|
|
168
206
|
/**
 * Lists every collection (index) reported by the Chroma client.
 *
 * @returns A new array copied from the client's `listCollections()` result.
 * @throws MastraError `CHROMA_VECTOR_LIST_INDEXES_FAILED` if the client call fails.
 */
async listIndexes(): Promise<string[]> {
  try {
    const names = await this.client.listCollections();
    // Return a shallow copy instead of the client's own array instance.
    return Array.from(names);
  } catch (error: any) {
    throw new MastraError(
      {
        id: 'CHROMA_VECTOR_LIST_INDEXES_FAILED',
        domain: ErrorDomain.MASTRA_VECTOR,
        category: ErrorCategory.THIRD_PARTY,
      },
      error,
    );
  }
}
|
|
172
221
|
|
|
173
222
|
/**
|
|
@@ -177,22 +226,47 @@ export class ChromaVector extends MastraVector {
|
|
|
177
226
|
* @returns A promise that resolves to the index statistics including dimension, count and metric
|
|
178
227
|
*/
|
|
179
228
|
async describeIndex({ indexName }: DescribeIndexParams): Promise<IndexStats> {
  try {
    const collection = await this.getCollection(indexName);
    const count = await collection.count();
    const metadata = collection.metadata;

    // A collection created without an explicit `hnsw:space` (e.g. outside Mastra)
    // uses Chroma's default `l2` space. Fall back to it so `metric` is never undefined.
    // NOTE(review): relies on HnswSpaceMap having an `l2` key, which the
    // `'cosine' | 'l2' | 'ip'` index type used here already implies.
    const hnswSpace = (metadata?.['hnsw:space'] ?? 'l2') as 'cosine' | 'l2' | 'ip';

    return {
      // Reported as 0 when the collection metadata carries no `dimension` entry.
      dimension: metadata?.dimension || 0,
      count,
      // Translate Chroma's space name back to Mastra's metric naming.
      metric: this.HnswSpaceMap[hnswSpace] as 'cosine' | 'euclidean' | 'dotproduct',
    };
  } catch (error: any) {
    // Errors already wrapped (e.g. by getCollection) pass through unchanged.
    if (error instanceof MastraError) throw error;
    throw new MastraError(
      {
        id: 'CHROMA_VECTOR_DESCRIBE_INDEX_FAILED',
        domain: ErrorDomain.MASTRA_VECTOR,
        category: ErrorCategory.THIRD_PARTY,
        details: { indexName },
      },
      error,
    );
  }
}
|
|
192
254
|
|
|
193
255
|
/**
 * Drops the named Chroma collection and removes its entry from the in-memory
 * `collections` map (presumably a handle cache — confirm against getCollection).
 *
 * @param indexName - Name of the collection to delete.
 * @throws MastraError `CHROMA_VECTOR_DELETE_INDEX_FAILED` if deletion fails. The map
 *   entry is only removed after Chroma confirms the delete.
 */
async deleteIndex({ indexName }: DeleteIndexParams): Promise<void> {
  try {
    await this.client.deleteCollection({ name: indexName });
    // Reached only when the remote delete succeeded.
    this.collections.delete(indexName);
  } catch (error: any) {
    const details = { indexName };
    throw new MastraError(
      {
        id: 'CHROMA_VECTOR_DELETE_INDEX_FAILED',
        domain: ErrorDomain.MASTRA_VECTOR,
        category: ErrorCategory.THIRD_PARTY,
        details,
      },
      error,
    );
  }
}
|
|
197
271
|
|
|
198
272
|
/**
|
|
@@ -206,11 +280,17 @@ export class ChromaVector extends MastraVector {
|
|
|
206
280
|
* @throws Will throw an error if no updates are provided or if the update operation fails.
|
|
207
281
|
*/
|
|
208
282
|
async updateVector({ indexName, id, update }: UpdateVectorParams): Promise<void> {
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
283
|
+
if (!update.vector && !update.metadata) {
|
|
284
|
+
throw new MastraError({
|
|
285
|
+
id: 'CHROMA_VECTOR_UPDATE_NO_PAYLOAD',
|
|
286
|
+
text: 'No updates provided for vector',
|
|
287
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
288
|
+
category: ErrorCategory.USER,
|
|
289
|
+
details: { indexName, id },
|
|
290
|
+
});
|
|
291
|
+
}
|
|
213
292
|
|
|
293
|
+
try {
|
|
214
294
|
const collection: Collection = await this.getCollection(indexName, true);
|
|
215
295
|
|
|
216
296
|
const updateOptions: UpdateRecordsParams = { ids: [id] };
|
|
@@ -227,23 +307,34 @@ export class ChromaVector extends MastraVector {
|
|
|
227
307
|
|
|
228
308
|
return await collection.update(updateOptions);
|
|
229
309
|
} catch (error: any) {
|
|
230
|
-
|
|
310
|
+
if (error instanceof MastraError) throw error;
|
|
311
|
+
throw new MastraError(
|
|
312
|
+
{
|
|
313
|
+
id: 'CHROMA_VECTOR_UPDATE_FAILED',
|
|
314
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
315
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
316
|
+
details: { indexName, id },
|
|
317
|
+
},
|
|
318
|
+
error,
|
|
319
|
+
);
|
|
231
320
|
}
|
|
232
321
|
}
|
|
233
322
|
|
|
234
|
-
/**
 * Deletes a single record, identified by id, from a Chroma collection.
 *
 * @param indexName - Name of the collection holding the record.
 * @param id - Id of the record to remove.
 * @throws MastraError `CHROMA_VECTOR_DELETE_FAILED` when the operation fails;
 *   any MastraError caught is rethrown unchanged.
 */
async deleteVector({ indexName, id }: DeleteVectorParams): Promise<void> {
  try {
    const collection: Collection = await this.getCollection(indexName, true);
    await collection.delete({ ids: [id] });
  } catch (error: any) {
    if (!(error instanceof MastraError)) {
      throw new MastraError(
        {
          id: 'CHROMA_VECTOR_DELETE_FAILED',
          domain: ErrorDomain.MASTRA_VECTOR,
          category: ErrorCategory.THIRD_PARTY,
          details: { indexName, id },
        },
        error,
      );
    }
    throw error;
  }
}
|
|
249
340
|
}
|