@mastra/astra 0.10.2 → 0.10.3-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +24 -0
- package/dist/_tsup-dts-rollup.d.cts +20 -5
- package/dist/_tsup-dts-rollup.d.ts +20 -5
- package/dist/index.cjs +146 -44
- package/dist/index.js +144 -42
- package/package.json +7 -7
- package/src/vector/filter.test.ts +38 -49
- package/src/vector/filter.ts +26 -4
- package/src/vector/index.test.ts +9 -17
- package/src/vector/index.ts +154 -51
package/src/vector/index.ts
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import type { Db } from '@datastax/astra-db-ts';
|
|
2
2
|
import { DataAPIClient, UUID } from '@datastax/astra-db-ts';
|
|
3
|
+
import { MastraError, ErrorDomain, ErrorCategory } from '@mastra/core/error';
|
|
3
4
|
import { MastraVector } from '@mastra/core/vector';
|
|
4
5
|
import type {
|
|
5
6
|
QueryResult,
|
|
@@ -12,8 +13,7 @@ import type {
|
|
|
12
13
|
DeleteVectorParams,
|
|
13
14
|
UpdateVectorParams,
|
|
14
15
|
} from '@mastra/core/vector';
|
|
15
|
-
import type {
|
|
16
|
-
|
|
16
|
+
import type { AstraVectorFilter } from './filter';
|
|
17
17
|
import { AstraFilterTranslator } from './filter';
|
|
18
18
|
|
|
19
19
|
// Mastra and Astra DB agree on cosine and euclidean, but Astra DB uses dot_product instead of dotproduct.
|
|
@@ -29,7 +29,9 @@ export interface AstraDbOptions {
|
|
|
29
29
|
keyspace?: string;
|
|
30
30
|
}
|
|
31
31
|
|
|
32
|
-
|
|
32
|
+
type AstraQueryVectorParams = QueryVectorParams<AstraVectorFilter>;
|
|
33
|
+
|
|
34
|
+
export class AstraVector extends MastraVector<AstraVectorFilter> {
|
|
33
35
|
readonly #db: Db;
|
|
34
36
|
|
|
35
37
|
constructor({ token, endpoint, keyspace }: AstraDbOptions) {
|
|
@@ -48,15 +50,32 @@ export class AstraVector extends MastraVector {
|
|
|
48
50
|
*/
|
|
49
51
|
async createIndex({ indexName, dimension, metric = 'cosine' }: CreateIndexParams): Promise<void> {
|
|
50
52
|
if (!Number.isInteger(dimension) || dimension <= 0) {
|
|
51
|
-
throw new
|
|
53
|
+
throw new MastraError({
|
|
54
|
+
id: 'ASTRA_VECTOR_CREATE_INDEX_INVALID_DIMENSION',
|
|
55
|
+
text: 'Dimension must be a positive integer',
|
|
56
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
57
|
+
category: ErrorCategory.USER,
|
|
58
|
+
});
|
|
59
|
+
}
|
|
60
|
+
try {
|
|
61
|
+
await this.#db.createCollection(indexName, {
|
|
62
|
+
vector: {
|
|
63
|
+
dimension,
|
|
64
|
+
metric: metricMap[metric],
|
|
65
|
+
},
|
|
66
|
+
checkExists: false,
|
|
67
|
+
});
|
|
68
|
+
} catch (error: any) {
|
|
69
|
+
new MastraError(
|
|
70
|
+
{
|
|
71
|
+
id: 'ASTRA_VECTOR_CREATE_INDEX_DB_ERROR',
|
|
72
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
73
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
74
|
+
details: { indexName },
|
|
75
|
+
},
|
|
76
|
+
error,
|
|
77
|
+
);
|
|
52
78
|
}
|
|
53
|
-
await this.#db.createCollection(indexName, {
|
|
54
|
-
vector: {
|
|
55
|
-
dimension,
|
|
56
|
-
metric: metricMap[metric],
|
|
57
|
-
},
|
|
58
|
-
checkExists: false,
|
|
59
|
-
});
|
|
60
79
|
}
|
|
61
80
|
|
|
62
81
|
/**
|
|
@@ -80,11 +99,23 @@ export class AstraVector extends MastraVector {
|
|
|
80
99
|
metadata: metadata?.[i] || {},
|
|
81
100
|
}));
|
|
82
101
|
|
|
83
|
-
|
|
102
|
+
try {
|
|
103
|
+
await collection.insertMany(records);
|
|
104
|
+
} catch (error: any) {
|
|
105
|
+
throw new MastraError(
|
|
106
|
+
{
|
|
107
|
+
id: 'ASTRA_VECTOR_UPSERT_DB_ERROR',
|
|
108
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
109
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
110
|
+
details: { indexName },
|
|
111
|
+
},
|
|
112
|
+
error,
|
|
113
|
+
);
|
|
114
|
+
}
|
|
84
115
|
return vectorIds;
|
|
85
116
|
}
|
|
86
117
|
|
|
87
|
-
transformFilter(filter?:
|
|
118
|
+
transformFilter(filter?: AstraVectorFilter) {
|
|
88
119
|
const translator = new AstraFilterTranslator();
|
|
89
120
|
return translator.translate(filter);
|
|
90
121
|
}
|
|
@@ -105,28 +136,40 @@ export class AstraVector extends MastraVector {
|
|
|
105
136
|
topK = 10,
|
|
106
137
|
filter,
|
|
107
138
|
includeVector = false,
|
|
108
|
-
}:
|
|
139
|
+
}: AstraQueryVectorParams): Promise<QueryResult[]> {
|
|
109
140
|
const collection = this.#db.collection(indexName);
|
|
110
141
|
|
|
111
142
|
const translatedFilter = this.transformFilter(filter);
|
|
112
143
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
144
|
+
try {
|
|
145
|
+
const cursor = collection.find(translatedFilter ?? {}, {
|
|
146
|
+
sort: { $vector: queryVector },
|
|
147
|
+
limit: topK,
|
|
148
|
+
includeSimilarity: true,
|
|
149
|
+
projection: {
|
|
150
|
+
$vector: includeVector ? true : false,
|
|
151
|
+
},
|
|
152
|
+
});
|
|
153
|
+
|
|
154
|
+
const results = await cursor.toArray();
|
|
155
|
+
|
|
156
|
+
return results.map(result => ({
|
|
157
|
+
id: result.id,
|
|
158
|
+
score: result.$similarity,
|
|
159
|
+
metadata: result.metadata,
|
|
160
|
+
...(includeVector && { vector: result.$vector }),
|
|
161
|
+
}));
|
|
162
|
+
} catch (error: any) {
|
|
163
|
+
throw new MastraError(
|
|
164
|
+
{
|
|
165
|
+
id: 'ASTRA_VECTOR_QUERY_DB_ERROR',
|
|
166
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
167
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
168
|
+
details: { indexName },
|
|
169
|
+
},
|
|
170
|
+
error,
|
|
171
|
+
);
|
|
172
|
+
}
|
|
130
173
|
}
|
|
131
174
|
|
|
132
175
|
/**
|
|
@@ -134,8 +177,19 @@ export class AstraVector extends MastraVector {
|
|
|
134
177
|
*
|
|
135
178
|
* @returns {Promise<string[]>} A promise that resolves to an array of collection names.
|
|
136
179
|
*/
|
|
137
|
-
listIndexes(): Promise<string[]> {
|
|
138
|
-
|
|
180
|
+
async listIndexes(): Promise<string[]> {
|
|
181
|
+
try {
|
|
182
|
+
return await this.#db.listCollections({ nameOnly: true });
|
|
183
|
+
} catch (error: any) {
|
|
184
|
+
throw new MastraError(
|
|
185
|
+
{
|
|
186
|
+
id: 'ASTRA_VECTOR_LIST_INDEXES_DB_ERROR',
|
|
187
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
188
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
189
|
+
},
|
|
190
|
+
error,
|
|
191
|
+
);
|
|
192
|
+
}
|
|
139
193
|
}
|
|
140
194
|
|
|
141
195
|
/**
|
|
@@ -146,17 +200,30 @@ export class AstraVector extends MastraVector {
|
|
|
146
200
|
*/
|
|
147
201
|
async describeIndex({ indexName }: DescribeIndexParams): Promise<IndexStats> {
|
|
148
202
|
const collection = this.#db.collection(indexName);
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
203
|
+
try {
|
|
204
|
+
const optionsPromise = collection.options();
|
|
205
|
+
const countPromise = collection.countDocuments({}, 100);
|
|
206
|
+
const [options, count] = await Promise.all([optionsPromise, countPromise]);
|
|
207
|
+
|
|
208
|
+
const keys = Object.keys(metricMap) as (keyof typeof metricMap)[];
|
|
209
|
+
const metric = keys.find(key => metricMap[key] === options.vector?.metric);
|
|
210
|
+
return {
|
|
211
|
+
dimension: options.vector?.dimension!,
|
|
212
|
+
metric,
|
|
213
|
+
count: count,
|
|
214
|
+
};
|
|
215
|
+
} catch (error: any) {
|
|
216
|
+
if (error instanceof MastraError) throw error;
|
|
217
|
+
throw new MastraError(
|
|
218
|
+
{
|
|
219
|
+
id: 'ASTRA_VECTOR_DESCRIBE_INDEX_DB_ERROR',
|
|
220
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
221
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
222
|
+
details: { indexName },
|
|
223
|
+
},
|
|
224
|
+
error,
|
|
225
|
+
);
|
|
226
|
+
}
|
|
160
227
|
}
|
|
161
228
|
|
|
162
229
|
/**
|
|
@@ -167,7 +234,19 @@ export class AstraVector extends MastraVector {
|
|
|
167
234
|
*/
|
|
168
235
|
async deleteIndex({ indexName }: DeleteIndexParams): Promise<void> {
|
|
169
236
|
const collection = this.#db.collection(indexName);
|
|
170
|
-
|
|
237
|
+
try {
|
|
238
|
+
await collection.drop();
|
|
239
|
+
} catch (error: any) {
|
|
240
|
+
throw new MastraError(
|
|
241
|
+
{
|
|
242
|
+
id: 'ASTRA_VECTOR_DELETE_INDEX_DB_ERROR',
|
|
243
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
244
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
245
|
+
details: { indexName },
|
|
246
|
+
},
|
|
247
|
+
error,
|
|
248
|
+
);
|
|
249
|
+
}
|
|
171
250
|
}
|
|
172
251
|
|
|
173
252
|
/**
|
|
@@ -181,11 +260,17 @@ export class AstraVector extends MastraVector {
|
|
|
181
260
|
* @throws Will throw an error if no updates are provided or if the update operation fails.
|
|
182
261
|
*/
|
|
183
262
|
async updateVector({ indexName, id, update }: UpdateVectorParams): Promise<void> {
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
263
|
+
if (!update.vector && !update.metadata) {
|
|
264
|
+
throw new MastraError({
|
|
265
|
+
id: 'ASTRA_VECTOR_UPDATE_NO_PAYLOAD',
|
|
266
|
+
text: 'No updates provided for vector',
|
|
267
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
268
|
+
category: ErrorCategory.USER,
|
|
269
|
+
details: { indexName, id },
|
|
270
|
+
});
|
|
271
|
+
}
|
|
188
272
|
|
|
273
|
+
try {
|
|
189
274
|
const collection = this.#db.collection(indexName);
|
|
190
275
|
const updateDoc: Record<string, any> = {};
|
|
191
276
|
|
|
@@ -199,7 +284,16 @@ export class AstraVector extends MastraVector {
|
|
|
199
284
|
|
|
200
285
|
await collection.findOneAndUpdate({ id }, { $set: updateDoc });
|
|
201
286
|
} catch (error: any) {
|
|
202
|
-
|
|
287
|
+
if (error instanceof MastraError) throw error;
|
|
288
|
+
throw new MastraError(
|
|
289
|
+
{
|
|
290
|
+
id: 'ASTRA_VECTOR_UPDATE_FAILED_UNHANDLED',
|
|
291
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
292
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
293
|
+
details: { indexName, id },
|
|
294
|
+
},
|
|
295
|
+
error,
|
|
296
|
+
);
|
|
203
297
|
}
|
|
204
298
|
}
|
|
205
299
|
|
|
@@ -215,7 +309,16 @@ export class AstraVector extends MastraVector {
|
|
|
215
309
|
const collection = this.#db.collection(indexName);
|
|
216
310
|
await collection.deleteOne({ id });
|
|
217
311
|
} catch (error: any) {
|
|
218
|
-
|
|
312
|
+
if (error instanceof MastraError) throw error;
|
|
313
|
+
throw new MastraError(
|
|
314
|
+
{
|
|
315
|
+
id: 'ASTRA_VECTOR_DELETE_FAILED',
|
|
316
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
317
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
318
|
+
details: { indexName, id },
|
|
319
|
+
},
|
|
320
|
+
error,
|
|
321
|
+
);
|
|
219
322
|
}
|
|
220
323
|
}
|
|
221
324
|
}
|