@mastra/chroma 0.11.3 → 0.11.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +1 -1
- package/CHANGELOG.md +45 -0
- package/README.md +47 -21
- package/dist/index.cjs +113 -49
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +114 -50
- package/dist/index.js.map +1 -1
- package/dist/vector/filter.d.ts +0 -3
- package/dist/vector/filter.d.ts.map +1 -1
- package/dist/vector/index.d.ts +33 -13
- package/dist/vector/index.d.ts.map +1 -1
- package/package.json +7 -6
- package/src/vector/filter.test.ts +3 -0
- package/src/vector/filter.ts +4 -12
- package/src/vector/index.test.ts +138 -16
- package/src/vector/index.ts +144 -66
package/src/vector/filter.ts
CHANGED
|
@@ -21,17 +21,6 @@ export type ChromaVectorFilter = VectorFilter<
|
|
|
21
21
|
ChromaBlacklisted
|
|
22
22
|
>;
|
|
23
23
|
|
|
24
|
-
type ChromaDocumentOperatorValueMap = ChromaOperatorValueMap;
|
|
25
|
-
|
|
26
|
-
type ChromaDocumentBlacklisted = Exclude<ChromaBlacklisted, '$contains'>;
|
|
27
|
-
|
|
28
|
-
export type ChromaVectorDocumentFilter = VectorFilter<
|
|
29
|
-
keyof ChromaDocumentOperatorValueMap,
|
|
30
|
-
ChromaDocumentOperatorValueMap,
|
|
31
|
-
ChromaLogicalOperatorValueMap,
|
|
32
|
-
ChromaDocumentBlacklisted
|
|
33
|
-
>;
|
|
34
|
-
|
|
35
24
|
/**
|
|
36
25
|
* Translator for Chroma filter queries.
|
|
37
26
|
* Maintains MongoDB-compatible syntax while ensuring proper validation
|
|
@@ -59,7 +48,7 @@ export class ChromaFilterTranslator extends BaseFilterTranslator<ChromaVectorFil
|
|
|
59
48
|
private translateNode(node: ChromaVectorFilter, currentPath: string = ''): any {
|
|
60
49
|
// Handle primitive values and arrays
|
|
61
50
|
if (this.isRegex(node)) {
|
|
62
|
-
throw new Error('Regex is
|
|
51
|
+
throw new Error('Regex is supported in Chroma via the `documentFilter` argument');
|
|
63
52
|
}
|
|
64
53
|
if (this.isPrimitive(node)) return this.normalizeComparisonValue(node);
|
|
65
54
|
if (Array.isArray(node)) return { $in: this.normalizeArrayValues(node) };
|
|
@@ -70,6 +59,9 @@ export class ChromaFilterTranslator extends BaseFilterTranslator<ChromaVectorFil
|
|
|
70
59
|
if (entries.length === 1 && firstEntry && this.isOperator(firstEntry[0])) {
|
|
71
60
|
const [operator, value] = firstEntry;
|
|
72
61
|
const translated = this.translateOperator(operator, value);
|
|
62
|
+
if (this.isLogicalOperator(operator) && Array.isArray(translated) && translated.length === 1) {
|
|
63
|
+
return translated[0];
|
|
64
|
+
}
|
|
73
65
|
return this.isLogicalOperator(operator) ? { [operator]: translated } : translated;
|
|
74
66
|
}
|
|
75
67
|
|
package/src/vector/index.test.ts
CHANGED
|
@@ -1,12 +1,11 @@
|
|
|
1
|
+
import { createVectorTestSuite } from '@internal/storage-test-utils';
|
|
1
2
|
import type { QueryResult, IndexStats } from '@mastra/core/vector';
|
|
2
3
|
import { describe, expect, beforeEach, afterEach, it, beforeAll, afterAll, vi } from 'vitest';
|
|
3
4
|
|
|
4
5
|
import { ChromaVector } from './';
|
|
5
6
|
|
|
6
7
|
describe('ChromaVector Integration Tests', () => {
|
|
7
|
-
let vectorDB = new ChromaVector(
|
|
8
|
-
path: 'http://localhost:8000',
|
|
9
|
-
});
|
|
8
|
+
let vectorDB = new ChromaVector();
|
|
10
9
|
|
|
11
10
|
const testIndexName = 'test-index';
|
|
12
11
|
const testIndexName2 = 'test-index-2';
|
|
@@ -122,7 +121,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
122
121
|
const ids = await vectorDB.upsert({ indexName: testIndexName, vectors: testVectors });
|
|
123
122
|
expect(ids).toHaveLength(3);
|
|
124
123
|
|
|
125
|
-
const idToBeUpdated = ids[0];
|
|
124
|
+
const idToBeUpdated = ids[0] as string;
|
|
126
125
|
const newVector = [1, 2, 3];
|
|
127
126
|
const newMetaData = {
|
|
128
127
|
test: 'updates',
|
|
@@ -150,7 +149,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
150
149
|
const ids = await vectorDB.upsert({ indexName: testIndexName, vectors: testVectors });
|
|
151
150
|
expect(ids).toHaveLength(3);
|
|
152
151
|
|
|
153
|
-
const idToBeUpdated = ids[0];
|
|
152
|
+
const idToBeUpdated = ids[0] as string;
|
|
154
153
|
const newMetaData = {
|
|
155
154
|
test: 'updates',
|
|
156
155
|
};
|
|
@@ -163,7 +162,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
163
162
|
|
|
164
163
|
const results: QueryResult[] = await vectorDB.query({
|
|
165
164
|
indexName: testIndexName,
|
|
166
|
-
queryVector: testVectors[0],
|
|
165
|
+
queryVector: testVectors[0] as number[],
|
|
167
166
|
topK: 2,
|
|
168
167
|
includeVector: true,
|
|
169
168
|
});
|
|
@@ -176,7 +175,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
176
175
|
const ids = await vectorDB.upsert({ indexName: testIndexName, vectors: testVectors });
|
|
177
176
|
expect(ids).toHaveLength(3);
|
|
178
177
|
|
|
179
|
-
const idToBeUpdated = ids[0];
|
|
178
|
+
const idToBeUpdated = ids[0] as string;
|
|
180
179
|
const newVector = [1, 2, 3];
|
|
181
180
|
|
|
182
181
|
const update = {
|
|
@@ -204,7 +203,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
204
203
|
it('should delete the vector by id', async () => {
|
|
205
204
|
const ids = await vectorDB.upsert({ indexName: testIndexName, vectors: testVectors });
|
|
206
205
|
expect(ids).toHaveLength(3);
|
|
207
|
-
const idToBeDeleted = ids[0];
|
|
206
|
+
const idToBeDeleted = ids[0] as string;
|
|
208
207
|
|
|
209
208
|
await vectorDB.deleteVector({ indexName: testIndexName, id: idToBeDeleted });
|
|
210
209
|
|
|
@@ -1285,7 +1284,8 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1285
1284
|
const results = await vectorDB.query({ indexName: testIndexName3, queryVector: [1.0, 0.0, 0.0], topK: 3 });
|
|
1286
1285
|
expect(results).toHaveLength(3);
|
|
1287
1286
|
// Verify documents are returned
|
|
1288
|
-
|
|
1287
|
+
|
|
1288
|
+
expect(results[0]!.document).toBe(testDocuments[0]);
|
|
1289
1289
|
});
|
|
1290
1290
|
|
|
1291
1291
|
it('should filter documents using $contains', async () => {
|
|
@@ -1317,8 +1317,19 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1317
1317
|
documentFilter: { $contains: 'fox' },
|
|
1318
1318
|
});
|
|
1319
1319
|
expect(results).toHaveLength(1);
|
|
1320
|
-
expect(results[0]
|
|
1321
|
-
expect(results[0]
|
|
1320
|
+
expect(results[0]!.metadata?.source).toBe('pangram1');
|
|
1321
|
+
expect(results[0]!.document).toContain('fox');
|
|
1322
|
+
});
|
|
1323
|
+
|
|
1324
|
+
it('should get records with metadata and document filters', async () => {
|
|
1325
|
+
const results = await vectorDB.get({
|
|
1326
|
+
indexName: testIndexName3,
|
|
1327
|
+
filter: { source: 'pangram1' },
|
|
1328
|
+
documentFilter: { $contains: 'fox' },
|
|
1329
|
+
});
|
|
1330
|
+
expect(results).toHaveLength(1);
|
|
1331
|
+
expect(results[0]!.metadata?.source).toBe('pangram1');
|
|
1332
|
+
expect(results[0]!.document).toContain('fox');
|
|
1322
1333
|
});
|
|
1323
1334
|
});
|
|
1324
1335
|
|
|
@@ -1331,8 +1342,8 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1331
1342
|
documentFilter: { $and: [{ $contains: 'quick' }, { $not_contains: 'fox' }] },
|
|
1332
1343
|
});
|
|
1333
1344
|
expect(results).toHaveLength(1);
|
|
1334
|
-
expect(results[0]
|
|
1335
|
-
expect(results[0]
|
|
1345
|
+
expect(results[0]!.document).toContain('quick');
|
|
1346
|
+
expect(results[0]!.document).not.toContain('fox');
|
|
1336
1347
|
});
|
|
1337
1348
|
|
|
1338
1349
|
it('should handle $or conditions', async () => {
|
|
@@ -1343,8 +1354,8 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1343
1354
|
documentFilter: { $or: [{ $contains: 'fox' }, { $contains: 'zebras' }] },
|
|
1344
1355
|
});
|
|
1345
1356
|
expect(results).toHaveLength(2);
|
|
1346
|
-
expect(results[0]
|
|
1347
|
-
expect(results[1]
|
|
1357
|
+
expect(results[0]!.document).toContain('fox');
|
|
1358
|
+
expect(results[1]!.document).toContain('zebras');
|
|
1348
1359
|
});
|
|
1349
1360
|
});
|
|
1350
1361
|
|
|
@@ -1395,7 +1406,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1395
1406
|
documentFilter: { $contains: 'quick brown' }, // Test multi-word match
|
|
1396
1407
|
});
|
|
1397
1408
|
expect(results.length).toBe(1);
|
|
1398
|
-
expect(results[0]
|
|
1409
|
+
expect(results[0]!.document).toContain('quick brown');
|
|
1399
1410
|
});
|
|
1400
1411
|
|
|
1401
1412
|
it('should handle deeply nested logical operators', async () => {
|
|
@@ -1442,6 +1453,7 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1442
1453
|
vectorDB.query({
|
|
1443
1454
|
indexName: testIndexName3,
|
|
1444
1455
|
queryVector: [1, 0, 0],
|
|
1456
|
+
// @ts-ignore
|
|
1445
1457
|
documentFilter: {},
|
|
1446
1458
|
}),
|
|
1447
1459
|
).rejects.toThrow();
|
|
@@ -1548,3 +1560,113 @@ describe('ChromaVector Integration Tests', () => {
|
|
|
1548
1560
|
}, 30000);
|
|
1549
1561
|
});
|
|
1550
1562
|
});
|
|
1563
|
+
|
|
1564
|
+
// Metadata filtering tests for Memory system
|
|
1565
|
+
describe('Chroma Metadata Filtering', () => {
|
|
1566
|
+
const chromaVector = new ChromaVector();
|
|
1567
|
+
|
|
1568
|
+
createVectorTestSuite({
|
|
1569
|
+
vector: chromaVector,
|
|
1570
|
+
createIndex: async (indexName: string) => {
|
|
1571
|
+
// Using dimension 4 as required by the metadata filtering test vectors
|
|
1572
|
+
await chromaVector.createIndex({ indexName, dimension: 4 });
|
|
1573
|
+
},
|
|
1574
|
+
deleteIndex: async (indexName: string) => {
|
|
1575
|
+
await chromaVector.deleteIndex({ indexName });
|
|
1576
|
+
},
|
|
1577
|
+
waitForIndexing: async () => {
|
|
1578
|
+
// Chroma may need a short wait for indexing
|
|
1579
|
+
await new Promise(resolve => setTimeout(resolve, 2000));
|
|
1580
|
+
},
|
|
1581
|
+
});
|
|
1582
|
+
});
|
|
1583
|
+
|
|
1584
|
+
// ChromaCloudVector fork functionality tests (requires CHROMA_API_KEY)
|
|
1585
|
+
describe.skipIf(!process.env.CHROMA_API_KEY)('ChromaCloudVector Fork Tests', () => {
|
|
1586
|
+
let cloudVector: ChromaVector;
|
|
1587
|
+
const testIndexName = 'fork-test-index';
|
|
1588
|
+
const forkedIndexName = 'forked-test-index';
|
|
1589
|
+
const dimension = 3;
|
|
1590
|
+
|
|
1591
|
+
beforeEach(async () => {
|
|
1592
|
+
cloudVector = new ChromaVector({
|
|
1593
|
+
apiKey: process.env.CHROMA_API_KEY,
|
|
1594
|
+
});
|
|
1595
|
+
|
|
1596
|
+
// Clean up any existing test indexes
|
|
1597
|
+
try {
|
|
1598
|
+
await cloudVector.deleteIndex({ indexName: testIndexName });
|
|
1599
|
+
} catch {
|
|
1600
|
+
// Ignore errors if index doesn't exist
|
|
1601
|
+
}
|
|
1602
|
+
try {
|
|
1603
|
+
await cloudVector.deleteIndex({ indexName: forkedIndexName });
|
|
1604
|
+
} catch {
|
|
1605
|
+
// Ignore errors if index doesn't exist
|
|
1606
|
+
}
|
|
1607
|
+
});
|
|
1608
|
+
|
|
1609
|
+
afterEach(async () => {
|
|
1610
|
+
// Clean up test indexes
|
|
1611
|
+
try {
|
|
1612
|
+
await cloudVector.deleteIndex({ indexName: testIndexName });
|
|
1613
|
+
} catch {
|
|
1614
|
+
// Ignore cleanup errors
|
|
1615
|
+
}
|
|
1616
|
+
try {
|
|
1617
|
+
await cloudVector.deleteIndex({ indexName: forkedIndexName });
|
|
1618
|
+
} catch {
|
|
1619
|
+
// Ignore cleanup errors
|
|
1620
|
+
}
|
|
1621
|
+
});
|
|
1622
|
+
|
|
1623
|
+
it('should fork an index successfully', async () => {
|
|
1624
|
+
// Create initial index with some data
|
|
1625
|
+
await cloudVector.createIndex({ indexName: testIndexName, dimension });
|
|
1626
|
+
|
|
1627
|
+
const testVectors = [
|
|
1628
|
+
[1.0, 0.0, 0.0],
|
|
1629
|
+
[0.0, 1.0, 0.0],
|
|
1630
|
+
[0.0, 0.0, 1.0],
|
|
1631
|
+
];
|
|
1632
|
+
const testMetadata = [{ label: 'x-axis' }, { label: 'y-axis' }, { label: 'z-axis' }];
|
|
1633
|
+
const testIds = ['vec1', 'vec2', 'vec3'];
|
|
1634
|
+
|
|
1635
|
+
await cloudVector.upsert({
|
|
1636
|
+
indexName: testIndexName,
|
|
1637
|
+
vectors: testVectors,
|
|
1638
|
+
ids: testIds,
|
|
1639
|
+
metadata: testMetadata,
|
|
1640
|
+
});
|
|
1641
|
+
|
|
1642
|
+
// Fork the index
|
|
1643
|
+
await cloudVector.forkIndex({
|
|
1644
|
+
indexName: testIndexName,
|
|
1645
|
+
newIndexName: forkedIndexName,
|
|
1646
|
+
});
|
|
1647
|
+
|
|
1648
|
+
// Verify both indexes exist and have the same data
|
|
1649
|
+
let originalStats = await cloudVector.describeIndex({ indexName: testIndexName });
|
|
1650
|
+
let forkedStats = await cloudVector.describeIndex({ indexName: forkedIndexName });
|
|
1651
|
+
|
|
1652
|
+
expect(originalStats.count).toBe(3);
|
|
1653
|
+
expect(forkedStats.count).toBe(3);
|
|
1654
|
+
|
|
1655
|
+
await cloudVector.deleteVector({ indexName: forkedIndexName, id: 'vec1' });
|
|
1656
|
+
|
|
1657
|
+
originalStats = await cloudVector.describeIndex({ indexName: testIndexName });
|
|
1658
|
+
forkedStats = await cloudVector.describeIndex({ indexName: forkedIndexName });
|
|
1659
|
+
|
|
1660
|
+
expect(originalStats.count).toBe(3);
|
|
1661
|
+
expect(forkedStats.count).toBe(2);
|
|
1662
|
+
});
|
|
1663
|
+
|
|
1664
|
+
it('should throw error when forking non-existent index', async () => {
|
|
1665
|
+
await expect(
|
|
1666
|
+
cloudVector.forkIndex({
|
|
1667
|
+
indexName: 'non-existent-index',
|
|
1668
|
+
newIndexName: forkedIndexName,
|
|
1669
|
+
}),
|
|
1670
|
+
).rejects.toThrow();
|
|
1671
|
+
});
|
|
1672
|
+
});
|
package/src/vector/index.ts
CHANGED
|
@@ -11,9 +11,9 @@ import type {
|
|
|
11
11
|
DeleteVectorParams,
|
|
12
12
|
UpdateVectorParams,
|
|
13
13
|
} from '@mastra/core/vector';
|
|
14
|
-
import { ChromaClient } from 'chromadb';
|
|
15
|
-
import type {
|
|
16
|
-
import type {
|
|
14
|
+
import { ChromaClient, CloudClient } from 'chromadb';
|
|
15
|
+
import type { ChromaClientArgs, RecordSet, Where, WhereDocument, Collection, Metadata } from 'chromadb';
|
|
16
|
+
import type { ChromaVectorFilter } from './filter';
|
|
17
17
|
import { ChromaFilterTranslator } from './filter';
|
|
18
18
|
|
|
19
19
|
interface ChromaUpsertVectorParams extends UpsertVectorParams {
|
|
@@ -21,42 +21,68 @@ interface ChromaUpsertVectorParams extends UpsertVectorParams {
|
|
|
21
21
|
}
|
|
22
22
|
|
|
23
23
|
interface ChromaQueryVectorParams extends QueryVectorParams<ChromaVectorFilter> {
|
|
24
|
-
documentFilter?:
|
|
24
|
+
documentFilter?: WhereDocument | null;
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
+
interface ChromaGetRecordsParams {
|
|
28
|
+
indexName: string;
|
|
29
|
+
ids?: string[];
|
|
30
|
+
filter?: ChromaVectorFilter;
|
|
31
|
+
documentFilter?: WhereDocument | null;
|
|
32
|
+
includeVector?: boolean;
|
|
33
|
+
limit?: number;
|
|
34
|
+
offset?: number;
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
type MastraMetadata = {
|
|
38
|
+
dimension?: number;
|
|
39
|
+
};
|
|
40
|
+
|
|
41
|
+
type ChromaVectorArgs = ChromaClientArgs & { apiKey?: string };
|
|
42
|
+
|
|
43
|
+
const spaceMappings = {
|
|
44
|
+
cosine: 'cosine',
|
|
45
|
+
euclidean: 'l2',
|
|
46
|
+
dotproduct: 'ip',
|
|
47
|
+
l2: 'euclidean',
|
|
48
|
+
ip: 'dotproduct',
|
|
49
|
+
};
|
|
50
|
+
|
|
27
51
|
export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
28
52
|
private client: ChromaClient;
|
|
29
|
-
private collections: Map<string,
|
|
30
|
-
|
|
31
|
-
constructor({
|
|
32
|
-
path,
|
|
33
|
-
auth,
|
|
34
|
-
}: {
|
|
35
|
-
path: string;
|
|
36
|
-
auth?: {
|
|
37
|
-
provider: string;
|
|
38
|
-
credentials: string;
|
|
39
|
-
};
|
|
40
|
-
}) {
|
|
53
|
+
private collections: Map<string, Collection>;
|
|
54
|
+
|
|
55
|
+
constructor(chromaClientArgs?: ChromaVectorArgs) {
|
|
41
56
|
super();
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
57
|
+
if (chromaClientArgs?.apiKey) {
|
|
58
|
+
this.client = new CloudClient({
|
|
59
|
+
apiKey: chromaClientArgs.apiKey,
|
|
60
|
+
tenant: chromaClientArgs.tenant,
|
|
61
|
+
database: chromaClientArgs.database,
|
|
62
|
+
});
|
|
63
|
+
} else {
|
|
64
|
+
this.client = new ChromaClient(chromaClientArgs);
|
|
65
|
+
}
|
|
46
66
|
this.collections = new Map();
|
|
47
67
|
}
|
|
48
68
|
|
|
49
|
-
async getCollection(indexName:
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
69
|
+
async getCollection({ indexName, forceUpdate = false }: { indexName: string; forceUpdate?: boolean }) {
|
|
70
|
+
let collection = this.collections.get(indexName);
|
|
71
|
+
if (forceUpdate || !collection) {
|
|
72
|
+
try {
|
|
73
|
+
collection = await this.client.getCollection({ name: indexName });
|
|
74
|
+
this.collections.set(indexName, collection);
|
|
75
|
+
return collection;
|
|
76
|
+
} catch {
|
|
77
|
+
throw new MastraError({
|
|
78
|
+
id: 'CHROMA_COLLECTION_GET_FAILED',
|
|
79
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
80
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
81
|
+
details: { indexName },
|
|
82
|
+
});
|
|
56
83
|
}
|
|
57
|
-
return null;
|
|
58
84
|
}
|
|
59
|
-
return
|
|
85
|
+
return collection;
|
|
60
86
|
}
|
|
61
87
|
|
|
62
88
|
private validateVectorDimensions(vectors: number[][], dimension: number): void {
|
|
@@ -71,17 +97,16 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
71
97
|
|
|
72
98
|
async upsert({ indexName, vectors, metadata, ids, documents }: ChromaUpsertVectorParams): Promise<string[]> {
|
|
73
99
|
try {
|
|
74
|
-
const collection = await this.getCollection(indexName);
|
|
100
|
+
const collection = await this.getCollection({ indexName });
|
|
75
101
|
|
|
76
102
|
const stats = await this.describeIndex({ indexName });
|
|
77
103
|
this.validateVectorDimensions(vectors, stats.dimension);
|
|
78
104
|
const generatedIds = ids || vectors.map(() => crypto.randomUUID());
|
|
79
|
-
const normalizedMetadata = metadata || vectors.map(() => ({}));
|
|
80
105
|
|
|
81
106
|
await collection.upsert({
|
|
82
107
|
ids: generatedIds,
|
|
83
108
|
embeddings: vectors,
|
|
84
|
-
metadatas:
|
|
109
|
+
metadatas: metadata,
|
|
85
110
|
documents: documents,
|
|
86
111
|
});
|
|
87
112
|
|
|
@@ -100,14 +125,6 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
100
125
|
}
|
|
101
126
|
}
|
|
102
127
|
|
|
103
|
-
private HnswSpaceMap = {
|
|
104
|
-
cosine: 'cosine',
|
|
105
|
-
euclidean: 'l2',
|
|
106
|
-
dotproduct: 'ip',
|
|
107
|
-
l2: 'euclidean',
|
|
108
|
-
ip: 'dotproduct',
|
|
109
|
-
};
|
|
110
|
-
|
|
111
128
|
async createIndex({ indexName, dimension, metric = 'cosine' }: CreateIndexParams): Promise<void> {
|
|
112
129
|
if (!Number.isInteger(dimension) || dimension <= 0) {
|
|
113
130
|
throw new MastraError({
|
|
@@ -118,7 +135,9 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
118
135
|
details: { dimension },
|
|
119
136
|
});
|
|
120
137
|
}
|
|
121
|
-
|
|
138
|
+
|
|
139
|
+
const hnswSpace = spaceMappings[metric] as 'cosine' | 'l2' | 'ip' | undefined;
|
|
140
|
+
|
|
122
141
|
if (!hnswSpace || !['cosine', 'l2', 'ip'].includes(hnswSpace)) {
|
|
123
142
|
throw new MastraError({
|
|
124
143
|
id: 'CHROMA_VECTOR_CREATE_INDEX_INVALID_METRIC',
|
|
@@ -128,14 +147,15 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
128
147
|
details: { metric },
|
|
129
148
|
});
|
|
130
149
|
}
|
|
150
|
+
|
|
131
151
|
try {
|
|
132
|
-
await this.client.createCollection({
|
|
152
|
+
const collection = await this.client.createCollection({
|
|
133
153
|
name: indexName,
|
|
134
|
-
metadata: {
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
},
|
|
154
|
+
metadata: { dimension },
|
|
155
|
+
configuration: { hnsw: { space: hnswSpace } },
|
|
156
|
+
embeddingFunction: null,
|
|
138
157
|
});
|
|
158
|
+
this.collections.set(indexName, collection);
|
|
139
159
|
} catch (error: any) {
|
|
140
160
|
// Check for 'already exists' error
|
|
141
161
|
const message = error?.message || error?.toString();
|
|
@@ -158,9 +178,11 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
158
178
|
|
|
159
179
|
transformFilter(filter?: ChromaVectorFilter) {
|
|
160
180
|
const translator = new ChromaFilterTranslator();
|
|
161
|
-
|
|
181
|
+
const translatedFilter = translator.translate(filter);
|
|
182
|
+
return translatedFilter ? (translatedFilter as Where) : undefined;
|
|
162
183
|
}
|
|
163
|
-
|
|
184
|
+
|
|
185
|
+
async query<T extends Metadata = Metadata>({
|
|
164
186
|
indexName,
|
|
165
187
|
queryVector,
|
|
166
188
|
topK = 10,
|
|
@@ -169,16 +191,16 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
169
191
|
documentFilter,
|
|
170
192
|
}: ChromaQueryVectorParams): Promise<QueryResult[]> {
|
|
171
193
|
try {
|
|
172
|
-
const collection = await this.getCollection(indexName
|
|
194
|
+
const collection = await this.getCollection({ indexName });
|
|
173
195
|
|
|
174
|
-
const defaultInclude = ['documents', 'metadatas', 'distances'];
|
|
196
|
+
const defaultInclude: ['documents', 'metadatas', 'distances'] = ['documents', 'metadatas', 'distances'];
|
|
175
197
|
|
|
176
198
|
const translatedFilter = this.transformFilter(filter);
|
|
177
|
-
const results = await collection.query({
|
|
199
|
+
const results = await collection.query<T>({
|
|
178
200
|
queryEmbeddings: [queryVector],
|
|
179
201
|
nResults: topK,
|
|
180
|
-
where: translatedFilter,
|
|
181
|
-
whereDocument: documentFilter,
|
|
202
|
+
where: translatedFilter ?? undefined,
|
|
203
|
+
whereDocument: documentFilter ?? undefined,
|
|
182
204
|
include: includeVector ? [...defaultInclude, 'embeddings'] : defaultInclude,
|
|
183
205
|
});
|
|
184
206
|
|
|
@@ -186,7 +208,7 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
186
208
|
id,
|
|
187
209
|
score: results.distances?.[0]?.[index] || 0,
|
|
188
210
|
metadata: results.metadatas?.[0]?.[index] || {},
|
|
189
|
-
document: results.documents?.[0]?.[index],
|
|
211
|
+
document: results.documents?.[0]?.[index] ?? undefined,
|
|
190
212
|
...(includeVector && { vector: results.embeddings?.[0]?.[index] || [] }),
|
|
191
213
|
}));
|
|
192
214
|
} catch (error: any) {
|
|
@@ -203,10 +225,48 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
203
225
|
}
|
|
204
226
|
}
|
|
205
227
|
|
|
228
|
+
async get<T extends Metadata = Metadata>({
|
|
229
|
+
indexName,
|
|
230
|
+
ids,
|
|
231
|
+
filter,
|
|
232
|
+
includeVector = false,
|
|
233
|
+
documentFilter,
|
|
234
|
+
offset,
|
|
235
|
+
limit,
|
|
236
|
+
}: ChromaGetRecordsParams) {
|
|
237
|
+
try {
|
|
238
|
+
const collection = await this.getCollection({ indexName });
|
|
239
|
+
|
|
240
|
+
const defaultInclude: ['documents', 'metadatas'] = ['documents', 'metadatas'];
|
|
241
|
+
const translatedFilter = this.transformFilter(filter);
|
|
242
|
+
|
|
243
|
+
const result = await collection.get<T>({
|
|
244
|
+
ids,
|
|
245
|
+
where: translatedFilter ?? undefined,
|
|
246
|
+
whereDocument: documentFilter ?? undefined,
|
|
247
|
+
offset,
|
|
248
|
+
limit,
|
|
249
|
+
include: includeVector ? [...defaultInclude, 'embeddings'] : defaultInclude,
|
|
250
|
+
});
|
|
251
|
+
return result.rows();
|
|
252
|
+
} catch (error: any) {
|
|
253
|
+
if (error instanceof MastraError) throw error;
|
|
254
|
+
throw new MastraError(
|
|
255
|
+
{
|
|
256
|
+
id: 'CHROMA_VECTOR_GET_FAILED',
|
|
257
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
258
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
259
|
+
details: { indexName },
|
|
260
|
+
},
|
|
261
|
+
error,
|
|
262
|
+
);
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
|
|
206
266
|
async listIndexes(): Promise<string[]> {
|
|
207
267
|
try {
|
|
208
268
|
const collections = await this.client.listCollections();
|
|
209
|
-
return collections.map(collection => collection);
|
|
269
|
+
return collections.map(collection => collection.name);
|
|
210
270
|
} catch (error: any) {
|
|
211
271
|
throw new MastraError(
|
|
212
272
|
{
|
|
@@ -227,16 +287,15 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
227
287
|
*/
|
|
228
288
|
async describeIndex({ indexName }: DescribeIndexParams): Promise<IndexStats> {
|
|
229
289
|
try {
|
|
230
|
-
const collection = await this.getCollection(indexName);
|
|
290
|
+
const collection = await this.getCollection({ indexName });
|
|
231
291
|
const count = await collection.count();
|
|
232
|
-
const metadata = collection.metadata;
|
|
233
|
-
|
|
234
|
-
const hnswSpace = metadata?.['hnsw:space'] as 'cosine' | 'l2' | 'ip';
|
|
292
|
+
const metadata = collection.metadata as MastraMetadata | undefined;
|
|
293
|
+
const space = collection.configuration.hnsw?.space || collection.configuration.spann?.space || undefined;
|
|
235
294
|
|
|
236
295
|
return {
|
|
237
296
|
dimension: metadata?.dimension || 0,
|
|
238
297
|
count,
|
|
239
|
-
metric:
|
|
298
|
+
metric: space ? (spaceMappings[space] as 'cosine' | 'euclidean' | 'dotproduct') : undefined,
|
|
240
299
|
};
|
|
241
300
|
} catch (error: any) {
|
|
242
301
|
if (error instanceof MastraError) throw error;
|
|
@@ -269,6 +328,25 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
269
328
|
}
|
|
270
329
|
}
|
|
271
330
|
|
|
331
|
+
async forkIndex({ indexName, newIndexName }: { indexName: string; newIndexName: string }): Promise<void> {
|
|
332
|
+
try {
|
|
333
|
+
const collection = await this.getCollection({ indexName, forceUpdate: true });
|
|
334
|
+
const forkedCollection = await collection.fork({ name: newIndexName });
|
|
335
|
+
this.collections.set(newIndexName, forkedCollection);
|
|
336
|
+
} catch (error: any) {
|
|
337
|
+
if (error instanceof MastraError) throw error;
|
|
338
|
+
throw new MastraError(
|
|
339
|
+
{
|
|
340
|
+
id: 'CHROMA_INDEX_FORK_FAILED',
|
|
341
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
342
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
343
|
+
details: { indexName },
|
|
344
|
+
},
|
|
345
|
+
error,
|
|
346
|
+
);
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
|
|
272
350
|
/**
|
|
273
351
|
* Updates a vector by its ID with the provided vector and/or metadata.
|
|
274
352
|
* @param indexName - The name of the index containing the vector.
|
|
@@ -291,21 +369,21 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
291
369
|
}
|
|
292
370
|
|
|
293
371
|
try {
|
|
294
|
-
const collection: Collection = await this.getCollection(indexName
|
|
372
|
+
const collection: Collection = await this.getCollection({ indexName });
|
|
295
373
|
|
|
296
|
-
const
|
|
374
|
+
const updateRecordSet: RecordSet = { ids: [id] };
|
|
297
375
|
|
|
298
376
|
if (update?.vector) {
|
|
299
377
|
const stats = await this.describeIndex({ indexName });
|
|
300
378
|
this.validateVectorDimensions([update.vector], stats.dimension);
|
|
301
|
-
|
|
379
|
+
updateRecordSet.embeddings = [update.vector];
|
|
302
380
|
}
|
|
303
381
|
|
|
304
382
|
if (update?.metadata) {
|
|
305
|
-
|
|
383
|
+
updateRecordSet.metadatas = [update.metadata];
|
|
306
384
|
}
|
|
307
385
|
|
|
308
|
-
return await collection.update(
|
|
386
|
+
return await collection.update(updateRecordSet);
|
|
309
387
|
} catch (error: any) {
|
|
310
388
|
if (error instanceof MastraError) throw error;
|
|
311
389
|
throw new MastraError(
|
|
@@ -322,7 +400,7 @@ export class ChromaVector extends MastraVector<ChromaVectorFilter> {
|
|
|
322
400
|
|
|
323
401
|
async deleteVector({ indexName, id }: DeleteVectorParams): Promise<void> {
|
|
324
402
|
try {
|
|
325
|
-
const collection: Collection = await this.getCollection(indexName
|
|
403
|
+
const collection: Collection = await this.getCollection({ indexName });
|
|
326
404
|
await collection.delete({ ids: [id] });
|
|
327
405
|
} catch (error: any) {
|
|
328
406
|
if (error instanceof MastraError) throw error;
|