@mastra/chroma 0.1.6-alpha.0 → 0.1.6-alpha.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +11 -6
- package/CHANGELOG.md +37 -0
- package/README.md +26 -16
- package/dist/_tsup-dts-rollup.d.cts +56 -0
- package/dist/_tsup-dts-rollup.d.ts +24 -11
- package/dist/index.cjs +214 -0
- package/dist/index.d.cts +1 -0
- package/dist/index.js +31 -11
- package/docker-compose.yaml +7 -0
- package/eslint.config.js +6 -0
- package/package.json +12 -5
- package/src/vector/filter.ts +5 -10
- package/src/vector/index.test.ts +687 -192
- package/src/vector/index.ts +57 -29
package/src/vector/index.ts
CHANGED
|
@@ -1,9 +1,25 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import
|
|
1
|
+
import { MastraVector } from '@mastra/core/vector';
|
|
2
|
+
import type {
|
|
3
|
+
QueryResult,
|
|
4
|
+
IndexStats,
|
|
5
|
+
CreateIndexParams,
|
|
6
|
+
UpsertVectorParams,
|
|
7
|
+
QueryVectorParams,
|
|
8
|
+
ParamsToArgs,
|
|
9
|
+
} from '@mastra/core/vector';
|
|
10
|
+
import type { VectorFilter } from '@mastra/core/vector/filter';
|
|
3
11
|
import { ChromaClient } from 'chromadb';
|
|
4
12
|
|
|
5
13
|
import { ChromaFilterTranslator } from './filter';
|
|
6
14
|
|
|
15
|
+
interface ChromaUpsertVectorParams extends UpsertVectorParams {
|
|
16
|
+
documents?: string[];
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
interface ChromaQueryVectorParams extends QueryVectorParams {
|
|
20
|
+
documentFilter?: VectorFilter;
|
|
21
|
+
}
|
|
22
|
+
|
|
7
23
|
export class ChromaVector extends MastraVector {
|
|
8
24
|
private client: ChromaClient;
|
|
9
25
|
private collections: Map<string, any>;
|
|
@@ -26,11 +42,11 @@ export class ChromaVector extends MastraVector {
|
|
|
26
42
|
this.collections = new Map();
|
|
27
43
|
}
|
|
28
44
|
|
|
29
|
-
|
|
45
|
+
async getCollection(indexName: string, throwIfNotExists: boolean = true) {
|
|
30
46
|
try {
|
|
31
47
|
const collection = await this.client.getCollection({ name: indexName, embeddingFunction: undefined as any });
|
|
32
48
|
this.collections.set(indexName, collection);
|
|
33
|
-
} catch
|
|
49
|
+
} catch {
|
|
34
50
|
if (throwIfNotExists) {
|
|
35
51
|
throw new Error(`Index ${indexName} does not exist`);
|
|
36
52
|
}
|
|
@@ -49,12 +65,11 @@ export class ChromaVector extends MastraVector {
|
|
|
49
65
|
}
|
|
50
66
|
}
|
|
51
67
|
|
|
52
|
-
async upsert(
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
metadata
|
|
56
|
-
|
|
57
|
-
): Promise<string[]> {
|
|
68
|
+
async upsert(...args: ParamsToArgs<ChromaUpsertVectorParams>): Promise<string[]> {
|
|
69
|
+
const params = this.normalizeArgs<ChromaUpsertVectorParams>('upsert', args, ['documents']);
|
|
70
|
+
|
|
71
|
+
const { indexName, vectors, metadata, ids, documents } = params;
|
|
72
|
+
|
|
58
73
|
const collection = await this.getCollection(indexName);
|
|
59
74
|
|
|
60
75
|
// Get index stats to check dimension
|
|
@@ -73,50 +88,60 @@ export class ChromaVector extends MastraVector {
|
|
|
73
88
|
ids: generatedIds,
|
|
74
89
|
embeddings: vectors,
|
|
75
90
|
metadatas: normalizedMetadata,
|
|
91
|
+
documents: documents,
|
|
76
92
|
});
|
|
77
93
|
|
|
78
94
|
return generatedIds;
|
|
79
95
|
}
|
|
80
96
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
97
|
+
private HnswSpaceMap = {
|
|
98
|
+
cosine: 'cosine',
|
|
99
|
+
euclidean: 'l2',
|
|
100
|
+
dotproduct: 'ip',
|
|
101
|
+
l2: 'euclidean',
|
|
102
|
+
ip: 'dotproduct',
|
|
103
|
+
};
|
|
104
|
+
|
|
105
|
+
async createIndex(...args: ParamsToArgs<CreateIndexParams>): Promise<void> {
|
|
106
|
+
const params = this.normalizeArgs<CreateIndexParams>('createIndex', args);
|
|
107
|
+
|
|
108
|
+
const { indexName, dimension, metric = 'cosine' } = params;
|
|
109
|
+
|
|
86
110
|
if (!Number.isInteger(dimension) || dimension <= 0) {
|
|
87
111
|
throw new Error('Dimension must be a positive integer');
|
|
88
112
|
}
|
|
113
|
+
const hnswSpace = this.HnswSpaceMap[metric];
|
|
114
|
+
if (!['cosine', 'l2', 'ip'].includes(hnswSpace)) {
|
|
115
|
+
throw new Error(`Invalid metric: "${metric}". Must be one of: cosine, euclidean, dotproduct`);
|
|
116
|
+
}
|
|
89
117
|
await this.client.createCollection({
|
|
90
118
|
name: indexName,
|
|
91
119
|
metadata: {
|
|
92
120
|
dimension,
|
|
93
|
-
metric,
|
|
121
|
+
'hnsw:space': this.HnswSpaceMap[metric],
|
|
94
122
|
},
|
|
95
123
|
});
|
|
96
124
|
}
|
|
97
125
|
|
|
98
|
-
transformFilter(filter?:
|
|
99
|
-
const
|
|
100
|
-
|
|
101
|
-
return translatedFilter;
|
|
126
|
+
transformFilter(filter?: VectorFilter) {
|
|
127
|
+
const translator = new ChromaFilterTranslator();
|
|
128
|
+
return translator.translate(filter);
|
|
102
129
|
}
|
|
103
|
-
async query(
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
topK
|
|
107
|
-
|
|
108
|
-
includeVector: boolean = false,
|
|
109
|
-
): Promise<QueryResult[]> {
|
|
130
|
+
async query(...args: ParamsToArgs<ChromaQueryVectorParams>): Promise<QueryResult[]> {
|
|
131
|
+
const params = this.normalizeArgs<ChromaQueryVectorParams>('query', args, ['documentFilter']);
|
|
132
|
+
|
|
133
|
+
const { indexName, queryVector, topK = 10, filter, includeVector = false, documentFilter } = params;
|
|
134
|
+
|
|
110
135
|
const collection = await this.getCollection(indexName, true);
|
|
111
136
|
|
|
112
137
|
const defaultInclude = ['documents', 'metadatas', 'distances'];
|
|
113
138
|
|
|
114
139
|
const translatedFilter = this.transformFilter(filter);
|
|
115
|
-
|
|
116
140
|
const results = await collection.query({
|
|
117
141
|
queryEmbeddings: [queryVector],
|
|
118
142
|
nResults: topK,
|
|
119
143
|
where: translatedFilter,
|
|
144
|
+
whereDocument: documentFilter,
|
|
120
145
|
include: includeVector ? [...defaultInclude, 'embeddings'] : defaultInclude,
|
|
121
146
|
});
|
|
122
147
|
|
|
@@ -125,6 +150,7 @@ export class ChromaVector extends MastraVector {
|
|
|
125
150
|
id,
|
|
126
151
|
score: results.distances?.[0]?.[index] || 0,
|
|
127
152
|
metadata: results.metadatas?.[0]?.[index] || {},
|
|
153
|
+
document: results.documents?.[0]?.[index],
|
|
128
154
|
...(includeVector && { vector: results.embeddings?.[0]?.[index] || [] }),
|
|
129
155
|
}));
|
|
130
156
|
}
|
|
@@ -139,10 +165,12 @@ export class ChromaVector extends MastraVector {
|
|
|
139
165
|
const count = await collection.count();
|
|
140
166
|
const metadata = collection.metadata;
|
|
141
167
|
|
|
168
|
+
const hnswSpace = metadata?.['hnsw:space'] as 'cosine' | 'l2' | 'ip';
|
|
169
|
+
|
|
142
170
|
return {
|
|
143
171
|
dimension: metadata?.dimension || 0,
|
|
144
172
|
count,
|
|
145
|
-
metric:
|
|
173
|
+
metric: this.HnswSpaceMap[hnswSpace] as 'cosine' | 'euclidean' | 'dotproduct',
|
|
146
174
|
};
|
|
147
175
|
}
|
|
148
176
|
|