@mastra/chroma 0.1.6-alpha.0 → 0.1.6-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,25 @@
1
- import { type Filter } from '@mastra/core/filter';
2
- import { MastraVector, type QueryResult, type IndexStats } from '@mastra/core/vector';
1
+ import { MastraVector } from '@mastra/core/vector';
2
+ import type {
3
+ QueryResult,
4
+ IndexStats,
5
+ CreateIndexParams,
6
+ UpsertVectorParams,
7
+ QueryVectorParams,
8
+ ParamsToArgs,
9
+ } from '@mastra/core/vector';
10
+ import type { VectorFilter } from '@mastra/core/vector/filter';
3
11
  import { ChromaClient } from 'chromadb';
4
12
 
5
13
  import { ChromaFilterTranslator } from './filter';
6
14
 
15
+ interface ChromaUpsertVectorParams extends UpsertVectorParams {
16
+ documents?: string[];
17
+ }
18
+
19
+ interface ChromaQueryVectorParams extends QueryVectorParams {
20
+ documentFilter?: VectorFilter;
21
+ }
22
+
7
23
  export class ChromaVector extends MastraVector {
8
24
  private client: ChromaClient;
9
25
  private collections: Map<string, any>;
@@ -26,11 +42,11 @@ export class ChromaVector extends MastraVector {
26
42
  this.collections = new Map();
27
43
  }
28
44
 
29
- private async getCollection(indexName: string, throwIfNotExists: boolean = true) {
45
+ async getCollection(indexName: string, throwIfNotExists: boolean = true) {
30
46
  try {
31
47
  const collection = await this.client.getCollection({ name: indexName, embeddingFunction: undefined as any });
32
48
  this.collections.set(indexName, collection);
33
- } catch (error) {
49
+ } catch {
34
50
  if (throwIfNotExists) {
35
51
  throw new Error(`Index ${indexName} does not exist`);
36
52
  }
@@ -49,12 +65,11 @@ export class ChromaVector extends MastraVector {
49
65
  }
50
66
  }
51
67
 
52
- async upsert(
53
- indexName: string,
54
- vectors: number[][],
55
- metadata?: Record<string, any>[],
56
- ids?: string[],
57
- ): Promise<string[]> {
68
+ async upsert(...args: ParamsToArgs<ChromaUpsertVectorParams>): Promise<string[]> {
69
+ const params = this.normalizeArgs<ChromaUpsertVectorParams>('upsert', args, ['documents']);
70
+
71
+ const { indexName, vectors, metadata, ids, documents } = params;
72
+
58
73
  const collection = await this.getCollection(indexName);
59
74
 
60
75
  // Get index stats to check dimension
@@ -73,50 +88,60 @@ export class ChromaVector extends MastraVector {
73
88
  ids: generatedIds,
74
89
  embeddings: vectors,
75
90
  metadatas: normalizedMetadata,
91
+ documents: documents,
76
92
  });
77
93
 
78
94
  return generatedIds;
79
95
  }
80
96
 
81
- async createIndex(
82
- indexName: string,
83
- dimension: number,
84
- metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
85
- ): Promise<void> {
97
+ private HnswSpaceMap = {
98
+ cosine: 'cosine',
99
+ euclidean: 'l2',
100
+ dotproduct: 'ip',
101
+ l2: 'euclidean',
102
+ ip: 'dotproduct',
103
+ };
104
+
105
+ async createIndex(...args: ParamsToArgs<CreateIndexParams>): Promise<void> {
106
+ const params = this.normalizeArgs<CreateIndexParams>('createIndex', args);
107
+
108
+ const { indexName, dimension, metric = 'cosine' } = params;
109
+
86
110
  if (!Number.isInteger(dimension) || dimension <= 0) {
87
111
  throw new Error('Dimension must be a positive integer');
88
112
  }
113
+ const hnswSpace = this.HnswSpaceMap[metric];
114
+ if (!['cosine', 'l2', 'ip'].includes(hnswSpace)) {
115
+ throw new Error(`Invalid metric: "${metric}". Must be one of: cosine, euclidean, dotproduct`);
116
+ }
89
117
  await this.client.createCollection({
90
118
  name: indexName,
91
119
  metadata: {
92
120
  dimension,
93
- metric,
121
+ 'hnsw:space': this.HnswSpaceMap[metric],
94
122
  },
95
123
  });
96
124
  }
97
125
 
98
- transformFilter(filter?: Filter) {
99
- const chromaFilter = new ChromaFilterTranslator();
100
- const translatedFilter = chromaFilter.translate(filter);
101
- return translatedFilter;
126
+ transformFilter(filter?: VectorFilter) {
127
+ const translator = new ChromaFilterTranslator();
128
+ return translator.translate(filter);
102
129
  }
103
- async query(
104
- indexName: string,
105
- queryVector: number[],
106
- topK: number = 10,
107
- filter?: Filter,
108
- includeVector: boolean = false,
109
- ): Promise<QueryResult[]> {
130
+ async query(...args: ParamsToArgs<ChromaQueryVectorParams>): Promise<QueryResult[]> {
131
+ const params = this.normalizeArgs<ChromaQueryVectorParams>('query', args, ['documentFilter']);
132
+
133
+ const { indexName, queryVector, topK = 10, filter, includeVector = false, documentFilter } = params;
134
+
110
135
  const collection = await this.getCollection(indexName, true);
111
136
 
112
137
  const defaultInclude = ['documents', 'metadatas', 'distances'];
113
138
 
114
139
  const translatedFilter = this.transformFilter(filter);
115
-
116
140
  const results = await collection.query({
117
141
  queryEmbeddings: [queryVector],
118
142
  nResults: topK,
119
143
  where: translatedFilter,
144
+ whereDocument: documentFilter,
120
145
  include: includeVector ? [...defaultInclude, 'embeddings'] : defaultInclude,
121
146
  });
122
147
 
@@ -125,6 +150,7 @@ export class ChromaVector extends MastraVector {
125
150
  id,
126
151
  score: results.distances?.[0]?.[index] || 0,
127
152
  metadata: results.metadatas?.[0]?.[index] || {},
153
+ document: results.documents?.[0]?.[index],
128
154
  ...(includeVector && { vector: results.embeddings?.[0]?.[index] || [] }),
129
155
  }));
130
156
  }
@@ -139,10 +165,12 @@ export class ChromaVector extends MastraVector {
139
165
  const count = await collection.count();
140
166
  const metadata = collection.metadata;
141
167
 
168
+ const hnswSpace = metadata?.['hnsw:space'] as 'cosine' | 'l2' | 'ip';
169
+
142
170
  return {
143
171
  dimension: metadata?.dimension || 0,
144
172
  count,
145
- metric: metadata?.metric as 'cosine' | 'euclidean' | 'dotproduct',
173
+ metric: this.HnswSpaceMap[hnswSpace] as 'cosine' | 'euclidean' | 'dotproduct',
146
174
  };
147
175
  }
148
176