@mastra/pg 0.3.4 → 0.3.5-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -1,5 +1,6 @@
1
1
  'use strict';
2
2
 
3
+ var utils = require('@mastra/core/utils');
3
4
  var vector = require('@mastra/core/vector');
4
5
  var asyncMutex = require('async-mutex');
5
6
  var pg = require('pg');
@@ -81,22 +82,26 @@ var PGFilterTranslator = class extends filter.BaseFilterTranslator {
81
82
  return { $regex: flags ? `(?${flags})${pattern}` : pattern };
82
83
  }
83
84
  };
84
-
85
- // src/vector/sql-builder.ts
86
85
  var createBasicOperator = (symbol) => {
87
- return (key, paramIndex) => ({
88
- sql: `CASE
89
- WHEN $${paramIndex}::text IS NULL THEN metadata#>>'{${handleKey(key)}}' IS ${symbol === "=" ? "" : "NOT"} NULL
90
- ELSE metadata#>>'{${handleKey(key)}}' ${symbol} $${paramIndex}::text
91
- END`,
92
- needsValue: true
93
- });
86
+ return (key, paramIndex) => {
87
+ const jsonPathKey = parseJsonPathKey(key);
88
+ return {
89
+ sql: `CASE
90
+ WHEN $${paramIndex}::text IS NULL THEN metadata#>>'{${jsonPathKey}}' IS ${symbol === "=" ? "" : "NOT"} NULL
91
+ ELSE metadata#>>'{${jsonPathKey}}' ${symbol} $${paramIndex}::text
92
+ END`,
93
+ needsValue: true
94
+ };
95
+ };
94
96
  };
95
97
  var createNumericOperator = (symbol) => {
96
- return (key, paramIndex) => ({
97
- sql: `(metadata#>>'{${handleKey(key)}}')::numeric ${symbol} $${paramIndex}`,
98
- needsValue: true
99
- });
98
+ return (key, paramIndex) => {
99
+ const jsonPathKey = parseJsonPathKey(key);
100
+ return {
101
+ sql: `(metadata#>>'{${jsonPathKey}}')::numeric ${symbol} $${paramIndex}`,
102
+ needsValue: true
103
+ };
104
+ };
100
105
  };
101
106
  function buildElemMatchConditions(value, paramIndex) {
102
107
  if (typeof value !== "object" || Array.isArray(value)) {
@@ -147,46 +152,56 @@ var FILTER_OPERATORS = {
147
152
  $lt: createNumericOperator("<"),
148
153
  $lte: createNumericOperator("<="),
149
154
  // Array Operators
150
- $in: (key, paramIndex) => ({
151
- sql: `(
152
- CASE
153
- WHEN jsonb_typeof(metadata->'${handleKey(key)}') = 'array' THEN
154
- EXISTS (
155
- SELECT 1 FROM jsonb_array_elements_text(metadata->'${handleKey(key)}') as elem
156
- WHERE elem = ANY($${paramIndex}::text[])
157
- )
158
- ELSE metadata#>>'{${handleKey(key)}}' = ANY($${paramIndex}::text[])
159
- END
160
- )`,
161
- needsValue: true
162
- }),
163
- $nin: (key, paramIndex) => ({
164
- sql: `(
165
- CASE
166
- WHEN jsonb_typeof(metadata->'${handleKey(key)}') = 'array' THEN
167
- NOT EXISTS (
168
- SELECT 1 FROM jsonb_array_elements_text(metadata->'${handleKey(key)}') as elem
169
- WHERE elem = ANY($${paramIndex}::text[])
170
- )
171
- ELSE metadata#>>'{${handleKey(key)}}' != ALL($${paramIndex}::text[])
172
- END
173
- )`,
174
- needsValue: true
175
- }),
176
- $all: (key, paramIndex) => ({
177
- sql: `CASE WHEN array_length($${paramIndex}::text[], 1) IS NULL THEN false
178
- ELSE (metadata#>'{${handleKey(key)}}')::jsonb ?& $${paramIndex}::text[] END`,
179
- needsValue: true
180
- }),
155
+ $in: (key, paramIndex) => {
156
+ const jsonPathKey = parseJsonPathKey(key);
157
+ return {
158
+ sql: `(
159
+ CASE
160
+ WHEN jsonb_typeof(metadata->'${jsonPathKey}') = 'array' THEN
161
+ EXISTS (
162
+ SELECT 1 FROM jsonb_array_elements_text(metadata->'${jsonPathKey}') as elem
163
+ WHERE elem = ANY($${paramIndex}::text[])
164
+ )
165
+ ELSE metadata#>>'{${jsonPathKey}}' = ANY($${paramIndex}::text[])
166
+ END
167
+ )`,
168
+ needsValue: true
169
+ };
170
+ },
171
+ $nin: (key, paramIndex) => {
172
+ const jsonPathKey = parseJsonPathKey(key);
173
+ return {
174
+ sql: `(
175
+ CASE
176
+ WHEN jsonb_typeof(metadata->'${jsonPathKey}') = 'array' THEN
177
+ NOT EXISTS (
178
+ SELECT 1 FROM jsonb_array_elements_text(metadata->'${jsonPathKey}') as elem
179
+ WHERE elem = ANY($${paramIndex}::text[])
180
+ )
181
+ ELSE metadata#>>'{${jsonPathKey}}' != ALL($${paramIndex}::text[])
182
+ END
183
+ )`,
184
+ needsValue: true
185
+ };
186
+ },
187
+ $all: (key, paramIndex) => {
188
+ const jsonPathKey = parseJsonPathKey(key);
189
+ return {
190
+ sql: `CASE WHEN array_length($${paramIndex}::text[], 1) IS NULL THEN false
191
+ ELSE (metadata#>'{${jsonPathKey}}')::jsonb ?& $${paramIndex}::text[] END`,
192
+ needsValue: true
193
+ };
194
+ },
181
195
  $elemMatch: (key, paramIndex, value) => {
182
196
  const { sql, values } = buildElemMatchConditions(value, paramIndex);
197
+ const jsonPathKey = parseJsonPathKey(key);
183
198
  return {
184
199
  sql: `(
185
200
  CASE
186
- WHEN jsonb_typeof(metadata->'${handleKey(key)}') = 'array' THEN
201
+ WHEN jsonb_typeof(metadata->'${jsonPathKey}') = 'array' THEN
187
202
  EXISTS (
188
203
  SELECT 1
189
- FROM jsonb_array_elements(metadata->'${handleKey(key)}') as elem
204
+ FROM jsonb_array_elements(metadata->'${jsonPathKey}') as elem
190
205
  WHERE ${sql}
191
206
  )
192
207
  ELSE FALSE
@@ -197,33 +212,40 @@ var FILTER_OPERATORS = {
197
212
  };
198
213
  },
199
214
  // Element Operators
200
- $exists: (key) => ({
201
- sql: `metadata ? '${key}'`,
202
- needsValue: false
203
- }),
215
+ $exists: (key) => {
216
+ const jsonPathKey = parseJsonPathKey(key);
217
+ return {
218
+ sql: `metadata ? '${jsonPathKey}'`,
219
+ needsValue: false
220
+ };
221
+ },
204
222
  // Logical Operators
205
223
  $and: (key) => ({ sql: `(${key})`, needsValue: false }),
206
224
  $or: (key) => ({ sql: `(${key})`, needsValue: false }),
207
225
  $not: (key) => ({ sql: `NOT (${key})`, needsValue: false }),
208
226
  $nor: (key) => ({ sql: `NOT (${key})`, needsValue: false }),
209
227
  // Regex Operators
210
- $regex: (key, paramIndex) => ({
211
- sql: `metadata#>>'{${handleKey(key)}}' ~ $${paramIndex}`,
212
- needsValue: true
213
- }),
228
+ $regex: (key, paramIndex) => {
229
+ const jsonPathKey = parseJsonPathKey(key);
230
+ return {
231
+ sql: `metadata#>>'{${jsonPathKey}}' ~ $${paramIndex}`,
232
+ needsValue: true
233
+ };
234
+ },
214
235
  $contains: (key, paramIndex, value) => {
236
+ const jsonPathKey = parseJsonPathKey(key);
215
237
  let sql;
216
238
  if (Array.isArray(value)) {
217
- sql = `(metadata->'${handleKey(key)}') ?& $${paramIndex}`;
239
+ sql = `(metadata->'${jsonPathKey}') ?& $${paramIndex}`;
218
240
  } else if (typeof value === "string") {
219
- sql = `metadata->>'${handleKey(key)}' ILIKE '%' || $${paramIndex} || '%'`;
241
+ sql = `metadata->>'${jsonPathKey}' ILIKE '%' || $${paramIndex} || '%' ESCAPE '\\'`;
220
242
  } else {
221
- sql = `metadata->>'${handleKey(key)}' = $${paramIndex}`;
243
+ sql = `metadata->>'${jsonPathKey}' = $${paramIndex}`;
222
244
  }
223
245
  return {
224
246
  sql,
225
247
  needsValue: true,
226
- transformValue: () => Array.isArray(value) ? value.map(String) : value
248
+ transformValue: () => Array.isArray(value) ? value.map(String) : typeof value === "string" ? escapeLikePattern(value) : value
227
249
  };
228
250
  },
229
251
  /**
@@ -238,29 +260,36 @@ var FILTER_OPERATORS = {
238
260
  // return JSON.stringify(parts.reduceRight((value, key) => ({ [key]: value }), value));
239
261
  // },
240
262
  // }),
241
- $size: (key, paramIndex) => ({
242
- sql: `(
263
+ $size: (key, paramIndex) => {
264
+ const jsonPathKey = parseJsonPathKey(key);
265
+ return {
266
+ sql: `(
243
267
  CASE
244
- WHEN jsonb_typeof(metadata#>'{${handleKey(key)}}') = 'array' THEN
245
- jsonb_array_length(metadata#>'{${handleKey(key)}}') = $${paramIndex}
268
+ WHEN jsonb_typeof(metadata#>'{${jsonPathKey}}') = 'array' THEN
269
+ jsonb_array_length(metadata#>'{${jsonPathKey}}') = $${paramIndex}
246
270
  ELSE FALSE
247
271
  END
248
272
  )`,
249
- needsValue: true
250
- })
273
+ needsValue: true
274
+ };
275
+ }
251
276
  };
252
- var handleKey = (key) => {
253
- return key.replace(/\./g, ",");
277
+ var parseJsonPathKey = (key) => {
278
+ const parsedKey = key !== "" ? utils.parseFieldKey(key) : "";
279
+ return parsedKey.replace(/\./g, ",");
254
280
  };
255
- function buildFilterQuery(filter, minScore) {
256
- const values = [minScore];
281
+ function escapeLikePattern(str) {
282
+ return str.replace(/([%_\\])/g, "\\$1");
283
+ }
284
+ function buildFilterQuery(filter, minScore, topK) {
285
+ const values = [minScore, topK];
257
286
  function buildCondition(key, value, parentPath) {
258
287
  if (["$and", "$or", "$not", "$nor"].includes(key)) {
259
288
  return handleLogicalOperator(key, value);
260
289
  }
261
290
  if (!value || typeof value !== "object") {
262
291
  values.push(value);
263
- return `metadata#>>'{${handleKey(key)}}' = $${values.length}`;
292
+ return `metadata#>>'{${parseJsonPathKey(key)}}' = $${values.length}`;
264
293
  }
265
294
  const [[operator, operatorValue] = []] = Object.entries(value);
266
295
  if (operator === "$not") {
@@ -270,7 +299,7 @@ function buildFilterQuery(filter, minScore) {
270
299
  throw new Error(`Invalid operator in $not condition: ${nestedOp}`);
271
300
  }
272
301
  const operatorFn2 = FILTER_OPERATORS[nestedOp];
273
- const operatorResult2 = operatorFn2(key, values.length + 1);
302
+ const operatorResult2 = operatorFn2(key, values.length + 1, nestedValue);
274
303
  if (operatorResult2.needsValue) {
275
304
  values.push(nestedValue);
276
305
  }
@@ -389,7 +418,7 @@ var PgVector = class extends vector.MastraVector {
389
418
  void (async () => {
390
419
  const existingIndexes = await this.listIndexes();
391
420
  void existingIndexes.map(async (indexName) => {
392
- const info = await this.getIndexInfo(indexName);
421
+ const info = await this.getIndexInfo({ indexName });
393
422
  const key = await this.getIndexCacheKey({
394
423
  indexName,
395
424
  metric: info.metric,
@@ -405,15 +434,19 @@ var PgVector = class extends vector.MastraVector {
405
434
  return this.mutexesByName.get(indexName);
406
435
  }
407
436
  getTableName(indexName) {
408
- return this.schema ? `${this.schema}.${indexName}` : indexName;
437
+ const parsedIndexName = utils.parseSqlIdentifier(indexName, "index name");
438
+ const parsedSchemaName = this.schema ? utils.parseSqlIdentifier(this.schema, "schema name") : void 0;
439
+ return parsedSchemaName ? `${parsedSchemaName}.${parsedIndexName}` : parsedIndexName;
409
440
  }
410
441
  transformFilter(filter) {
411
442
  const translator = new PGFilterTranslator();
412
443
  return translator.translate(filter);
413
444
  }
414
- async getIndexInfo(indexName) {
445
+ async getIndexInfo(...args) {
446
+ const params = this.normalizeArgs("getIndexInfo", args);
447
+ const { indexName } = params;
415
448
  if (!this.describeIndexCache.has(indexName)) {
416
- this.describeIndexCache.set(indexName, await this.describeIndex(indexName));
449
+ this.describeIndexCache.set(indexName, await this.describeIndex({ indexName }));
417
450
  }
418
451
  return this.describeIndexCache.get(indexName);
419
452
  }
@@ -424,12 +457,18 @@ var PgVector = class extends vector.MastraVector {
424
457
  "probes"
425
458
  ]);
426
459
  const { indexName, queryVector, topK = 10, filter, includeVector = false, minScore = 0, ef, probes } = params;
460
+ if (!Number.isInteger(topK) || topK <= 0) {
461
+ throw new Error("topK must be a positive integer");
462
+ }
463
+ if (!Array.isArray(queryVector) || !queryVector.every((x) => typeof x === "number" && Number.isFinite(x))) {
464
+ throw new Error("queryVector must be an array of finite numbers");
465
+ }
427
466
  const client = await this.pool.connect();
428
467
  try {
429
468
  const vectorStr = `[${queryVector.join(",")}]`;
430
469
  const translatedFilter = this.transformFilter(filter);
431
- const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);
432
- const indexInfo = await this.getIndexInfo(indexName);
470
+ const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore, topK);
471
+ const indexInfo = await this.getIndexInfo({ indexName });
433
472
  if (indexInfo.type === "hnsw") {
434
473
  const calculatedEf = ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
435
474
  const searchEf = Math.min(1e3, Math.max(1, calculatedEf));
@@ -453,7 +492,7 @@ var PgVector = class extends vector.MastraVector {
453
492
  FROM vector_scores
454
493
  WHERE score > $1
455
494
  ORDER BY score DESC
456
- LIMIT ${topK}`;
495
+ LIMIT $2`;
457
496
  const result = await client.query(query, filterValues);
458
497
  return result.rows.map(({ id, score, metadata, embedding }) => ({
459
498
  id,
@@ -716,7 +755,16 @@ var PgVector = class extends vector.MastraVector {
716
755
  client.release();
717
756
  }
718
757
  }
719
- async describeIndex(indexName) {
758
+ /**
759
+ * Retrieves statistics about a vector index.
760
+ *
761
+ * @param params - The parameters for describing an index
762
+ * @param params.indexName - The name of the index to describe
763
+ * @returns A promise that resolves to the index statistics including dimension, count and metric
764
+ */
765
+ async describeIndex(...args) {
766
+ const params = this.normalizeArgs("describeIndex", args);
767
+ const { indexName } = params;
720
768
  const client = await this.pool.connect();
721
769
  try {
722
770
  const tableName = this.getTableName(indexName);
@@ -790,7 +838,9 @@ var PgVector = class extends vector.MastraVector {
790
838
  client.release();
791
839
  }
792
840
  }
793
- async deleteIndex(indexName) {
841
+ async deleteIndex(...args) {
842
+ const params = this.normalizeArgs("deleteIndex", args);
843
+ const { indexName } = params;
794
844
  const client = await this.pool.connect();
795
845
  try {
796
846
  const tableName = this.getTableName(indexName);
@@ -803,7 +853,9 @@ var PgVector = class extends vector.MastraVector {
803
853
  client.release();
804
854
  }
805
855
  }
806
- async truncateIndex(indexName) {
856
+ async truncateIndex(...args) {
857
+ const params = this.normalizeArgs("truncateIndex", args);
858
+ const { indexName } = params;
807
859
  const client = await this.pool.connect();
808
860
  try {
809
861
  const tableName = this.getTableName(indexName);
@@ -836,7 +888,7 @@ var PgVector = class extends vector.MastraVector {
836
888
  Please use updateVector() instead.
837
889
  updateIndexById() will be removed on May 20th, 2025.`
838
890
  );
839
- await this.updateVector(indexName, id, update);
891
+ await this.updateVector({ indexName, id, update });
840
892
  }
841
893
  /**
842
894
  * Updates a vector by its ID with the provided vector and/or metadata.
@@ -848,7 +900,9 @@ var PgVector = class extends vector.MastraVector {
848
900
  * @returns A promise that resolves when the update is complete.
849
901
  * @throws Will throw an error if no updates are provided or if the update operation fails.
850
902
  */
851
- async updateVector(indexName, id, update) {
903
+ async updateVector(...args) {
904
+ const params = this.normalizeArgs("updateVector", args);
905
+ const { indexName, id, update } = params;
852
906
  if (!update.vector && !update.metadata) {
853
907
  throw new Error("No updates provided");
854
908
  }
@@ -897,7 +951,7 @@ var PgVector = class extends vector.MastraVector {
897
951
  Please use deleteVector() instead.
898
952
  deleteIndexById() will be removed on May 20th, 2025.`
899
953
  );
900
- await this.deleteVector(indexName, id);
954
+ await this.deleteVector({ indexName, id });
901
955
  }
902
956
  /**
903
957
  * Deletes a vector by its ID.
@@ -906,7 +960,9 @@ var PgVector = class extends vector.MastraVector {
906
960
  * @returns A promise that resolves when the deletion is complete.
907
961
  * @throws Will throw an error if the deletion operation fails.
908
962
  */
909
- async deleteVector(indexName, id) {
963
+ async deleteVector(...args) {
964
+ const params = this.normalizeArgs("deleteVector", args);
965
+ const { indexName, id } = params;
910
966
  const client = await this.pool.connect();
911
967
  try {
912
968
  const tableName = this.getTableName(indexName);
@@ -965,7 +1021,9 @@ var PostgresStore = class extends storage.MastraStorage {
965
1021
  );
966
1022
  }
967
1023
  getTableName(indexName) {
968
- return this.schema ? `${this.schema}."${indexName}"` : `"${indexName}"`;
1024
+ const parsedIndexName = utils.parseSqlIdentifier(indexName, "table name");
1025
+ const parsedSchemaName = this.schema ? utils.parseSqlIdentifier(this.schema, "schema name") : void 0;
1026
+ return parsedSchemaName ? `${parsedSchemaName}."${parsedIndexName}"` : `"${parsedIndexName}"`;
969
1027
  }
970
1028
  async getEvalsByAgentName(agentName, type) {
971
1029
  try {
@@ -1040,12 +1098,14 @@ var PostgresStore = class extends storage.MastraStorage {
1040
1098
  }
1041
1099
  if (attributes) {
1042
1100
  Object.keys(attributes).forEach((key) => {
1043
- conditions.push(`attributes->>'${key}' = $${idx++}`);
1101
+ const parsedKey = utils.parseSqlIdentifier(key, "attribute key");
1102
+ conditions.push(`attributes->>'${parsedKey}' = $${idx++}`);
1044
1103
  });
1045
1104
  }
1046
1105
  if (filters) {
1047
1106
  Object.entries(filters).forEach(([key]) => {
1048
- conditions.push(`${key} = $${idx++}`);
1107
+ const parsedKey = utils.parseSqlIdentifier(key, "filter key");
1108
+ conditions.push(`${parsedKey} = $${idx++}`);
1049
1109
  });
1050
1110
  }
1051
1111
  if (fromDate) {
@@ -1147,10 +1207,11 @@ var PostgresStore = class extends storage.MastraStorage {
1147
1207
  }) {
1148
1208
  try {
1149
1209
  const columns = Object.entries(schema).map(([name, def]) => {
1210
+ const parsedName = utils.parseSqlIdentifier(name, "column name");
1150
1211
  const constraints = [];
1151
1212
  if (def.primaryKey) constraints.push("PRIMARY KEY");
1152
1213
  if (!def.nullable) constraints.push("NOT NULL");
1153
- return `"${name}" ${def.type.toUpperCase()} ${constraints.join(" ")}`;
1214
+ return `"${parsedName}" ${def.type.toUpperCase()} ${constraints.join(" ")}`;
1154
1215
  }).join(",\n");
1155
1216
  if (this.schema) {
1156
1217
  await this.setupSchema();
@@ -1187,7 +1248,7 @@ var PostgresStore = class extends storage.MastraStorage {
1187
1248
  }
1188
1249
  async insert({ tableName, record }) {
1189
1250
  try {
1190
- const columns = Object.keys(record);
1251
+ const columns = Object.keys(record).map((col) => utils.parseSqlIdentifier(col, "column name"));
1191
1252
  const values = Object.values(record);
1192
1253
  const placeholders = values.map((_, i) => `$${i + 1}`).join(", ");
1193
1254
  await this.db.none(
@@ -1201,7 +1262,7 @@ var PostgresStore = class extends storage.MastraStorage {
1201
1262
  }
1202
1263
  async load({ tableName, keys }) {
1203
1264
  try {
1204
- const keyEntries = Object.entries(keys);
1265
+ const keyEntries = Object.entries(keys).map(([key, value]) => [utils.parseSqlIdentifier(key, "column name"), value]);
1205
1266
  const conditions = keyEntries.map(([key], index) => `"${key}" = $${index + 1}`).join(" AND ");
1206
1267
  const values = keyEntries.map(([_, value]) => value);
1207
1268
  const result = await this.db.oneOrNone(