@mastra/pg 1.9.0 → 1.9.1-alpha.1

This diff shows the contents of two publicly available package versions as they were published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in that registry.
package/dist/index.js CHANGED
@@ -709,6 +709,25 @@ var PgVector = class extends MastraVector {
709
709
  return null;
710
710
  }
711
711
  }
712
+ /**
713
+ * Sets search_path on the client connection so that vector operators (e.g. <=>, vector_cosine_ops)
714
+ * are resolvable when the pgvector extension is installed in a non-default schema.
715
+ *
716
+ * PostgreSQL's default search_path is ("$user", public). If the extension lives in a custom schema
717
+ * (e.g. "myapp"), operator classes and distance operators won't resolve without this.
718
+ */
719
+ async ensureSearchPath(client) {
720
+ if (!this.vectorExtensionSchema) {
721
+ await this.detectVectorExtensionSchema(client);
722
+ }
723
+ if (this.vectorExtensionSchema && this.vectorExtensionSchema !== "public" && this.vectorExtensionSchema !== "pg_catalog") {
724
+ const schemas = /* @__PURE__ */ new Set();
725
+ schemas.add(this.vectorExtensionSchema);
726
+ if (this.schema) schemas.add(this.schema);
727
+ schemas.add("public");
728
+ await client.query(`SET search_path TO ${[...schemas].map((s) => `"${s}"`).join(", ")}`);
729
+ }
730
+ }
712
731
  /**
713
732
  * Checks if the installed pgvector version supports halfvec type.
714
733
  * halfvec was introduced in pgvector 0.7.0.
@@ -922,6 +941,7 @@ var PgVector = class extends MastraVector {
922
941
  }
923
942
  const client = await this.pool.connect();
924
943
  try {
944
+ await this.ensureSearchPath(client);
925
945
  await client.query("BEGIN");
926
946
  const translatedFilter = this.transformFilter(filter);
927
947
  const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore, topK);
@@ -941,7 +961,23 @@ var PgVector = class extends MastraVector {
941
961
  const qualifiedVectorType = this.getVectorTypeName(indexInfo.vectorType, indexInfo.dimension);
942
962
  const distanceExpr = `embedding ${ops.distanceOperator} '${vectorStr}'::${qualifiedVectorType}`;
943
963
  const scoreExpr = ops.scoreExpr(distanceExpr);
944
- const query = `
964
+ const hasFilter = filterQuery.trim().length > 0;
965
+ const useIndexedOrder = indexInfo.type === "hnsw" && !hasFilter && minScore <= 0;
966
+ const query = useIndexedOrder ? `
967
+ WITH vector_scores AS (
968
+ SELECT
969
+ vector_id as id,
970
+ ${scoreExpr} as score,
971
+ metadata
972
+ ${includeVector ? ", embedding" : ""}
973
+ FROM ${tableName}
974
+ ORDER BY ${distanceExpr}
975
+ LIMIT $2
976
+ )
977
+ SELECT *
978
+ FROM vector_scores
979
+ WHERE score > $1
980
+ ORDER BY score DESC` : `
945
981
  WITH vector_scores AS (
946
982
  SELECT
947
983
  vector_id as id,
@@ -994,6 +1030,7 @@ var PgVector = class extends MastraVector {
994
1030
  const { tableName } = this.getTableName(indexName);
995
1031
  const client = await this.pool.connect();
996
1032
  try {
1033
+ await this.ensureSearchPath(client);
997
1034
  await client.query("BEGIN");
998
1035
  if (deleteFilter) {
999
1036
  this.logger?.debug(`Deleting vectors matching filter before upsert`, { indexName, deleteFilter });
@@ -1244,9 +1281,6 @@ var PgVector = class extends MastraVector {
1244
1281
  }
1245
1282
  });
1246
1283
  }
1247
- if (this.schema && this.vectorExtensionSchema && this.schema !== this.vectorExtensionSchema && this.vectorExtensionSchema !== "pg_catalog") {
1248
- await client.query(`SET search_path TO ${this.getSchemaName()}, "${this.vectorExtensionSchema}"`);
1249
- }
1250
1284
  const qualifiedVectorType = this.getVectorTypeName(vectorType);
1251
1285
  await client.query(`
1252
1286
  CREATE TABLE IF NOT EXISTS ${tableName} (
@@ -1389,6 +1423,7 @@ var PgVector = class extends MastraVector {
1389
1423
  this.describeIndexCache.delete(indexName);
1390
1424
  return;
1391
1425
  }
1426
+ await this.ensureSearchPath(client);
1392
1427
  const effectiveVectorType = existingIndexInfo?.vectorType ?? vectorType;
1393
1428
  const metricOp = this.getVectorOps(effectiveVectorType, metric).operatorClass;
1394
1429
  let indexSQL;
@@ -1739,6 +1774,7 @@ var PgVector = class extends MastraVector {
1739
1774
  });
1740
1775
  }
1741
1776
  client = await this.pool.connect();
1777
+ await this.ensureSearchPath(client);
1742
1778
  const { tableName } = this.getTableName(indexName);
1743
1779
  const indexInfo = await this.getIndexInfo({ indexName });
1744
1780
  const qualifiedVectorType = this.getVectorTypeName(indexInfo.vectorType, indexInfo.dimension);
@@ -1977,6 +2013,7 @@ var PoolAdapter = class {
1977
2013
  constructor($pool) {
1978
2014
  this.$pool = $pool;
1979
2015
  }
2016
+ $pool;
1980
2017
  connect() {
1981
2018
  return this.$pool.connect();
1982
2019
  }
@@ -2045,6 +2082,7 @@ var TransactionClient = class {
2045
2082
  constructor(client) {
2046
2083
  this.client = client;
2047
2084
  }
2085
+ client;
2048
2086
  async none(query, values) {
2049
2087
  await this.client.query(query, values);
2050
2088
  return null;
@@ -2474,36 +2512,39 @@ var PgDB = class extends MastraBase {
2474
2512
  return getDefaultValue(type);
2475
2513
  }
2476
2514
  }
2477
- async insert({ tableName, record }) {
2478
- try {
2479
- this.addTimestampZColumns(record);
2480
- const filteredRecord = await this.filterRecordToKnownColumns(tableName, record);
2481
- const schemaName = getSchemaName(this.schemaName);
2482
- const columns = Object.keys(filteredRecord).map((col) => parseSqlIdentifier(col, "column name"));
2483
- if (columns.length === 0) return;
2484
- const values = this.prepareValuesForInsert(filteredRecord, tableName);
2485
- const placeholders = values.map((_, i) => `$${i + 1}`).join(", ");
2486
- const fullTableName = getTableName({ indexName: tableName, schemaName });
2487
- const columnList = columns.map((c) => `"${c}"`).join(", ");
2488
- if (tableName === TABLE_SPANS) {
2489
- const updateColumns = columns.filter((c) => c !== "traceId" && c !== "spanId");
2490
- if (updateColumns.length > 0) {
2491
- const updateClause = updateColumns.map((c) => `"${c}" = EXCLUDED."${c}"`).join(", ");
2492
- await this.client.none(
2493
- `INSERT INTO ${fullTableName} (${columnList}) VALUES (${placeholders})
2515
+ async executeInsert(client, { tableName, record }) {
2516
+ this.addTimestampZColumns(record);
2517
+ const filteredRecord = await this.filterRecordToKnownColumns(tableName, record);
2518
+ const schemaName = getSchemaName(this.schemaName);
2519
+ const columns = Object.keys(filteredRecord).map((col) => parseSqlIdentifier(col, "column name"));
2520
+ if (columns.length === 0) return;
2521
+ const values = this.prepareValuesForInsert(filteredRecord, tableName);
2522
+ const placeholders = values.map((_, i) => `$${i + 1}`).join(", ");
2523
+ const fullTableName = getTableName({ indexName: tableName, schemaName });
2524
+ const columnList = columns.map((c) => `"${c}"`).join(", ");
2525
+ if (tableName === TABLE_SPANS) {
2526
+ const updateColumns = columns.filter((c) => c !== "traceId" && c !== "spanId");
2527
+ if (updateColumns.length > 0) {
2528
+ const updateClause = updateColumns.map((c) => `"${c}" = EXCLUDED."${c}"`).join(", ");
2529
+ await client.none(
2530
+ `INSERT INTO ${fullTableName} (${columnList}) VALUES (${placeholders})
2494
2531
  ON CONFLICT ("traceId", "spanId") DO UPDATE SET ${updateClause}`,
2495
- values
2496
- );
2497
- } else {
2498
- await this.client.none(
2499
- `INSERT INTO ${fullTableName} (${columnList}) VALUES (${placeholders})
2500
- ON CONFLICT ("traceId", "spanId") DO NOTHING`,
2501
- values
2502
- );
2503
- }
2532
+ values
2533
+ );
2504
2534
  } else {
2505
- await this.client.none(`INSERT INTO ${fullTableName} (${columnList}) VALUES (${placeholders})`, values);
2535
+ await client.none(
2536
+ `INSERT INTO ${fullTableName} (${columnList}) VALUES (${placeholders})
2537
+ ON CONFLICT ("traceId", "spanId") DO NOTHING`,
2538
+ values
2539
+ );
2506
2540
  }
2541
+ } else {
2542
+ await client.none(`INSERT INTO ${fullTableName} (${columnList}) VALUES (${placeholders})`, values);
2543
+ }
2544
+ }
2545
+ async insert({ tableName, record }) {
2546
+ try {
2547
+ await this.executeInsert(this.client, { tableName, record });
2507
2548
  } catch (error) {
2508
2549
  throw new MastraError(
2509
2550
  {
@@ -2978,13 +3019,12 @@ Note: This migration may take some time for large tables.
2978
3019
  }
2979
3020
  async batchInsert({ tableName, records }) {
2980
3021
  try {
2981
- await this.client.query("BEGIN");
2982
- for (const record of records) {
2983
- await this.insert({ tableName, record });
2984
- }
2985
- await this.client.query("COMMIT");
3022
+ await this.client.tx(async (tx) => {
3023
+ for (const record of records) {
3024
+ await this.executeInsert(tx, { tableName, record });
3025
+ }
3026
+ });
2986
3027
  } catch (error) {
2987
- await this.client.query("ROLLBACK");
2988
3028
  throw new MastraError(
2989
3029
  {
2990
3030
  id: createStorageErrorId("PG", "BATCH_INSERT", "FAILED"),
@@ -3262,30 +3302,7 @@ Note: This migration may take some time for large tables.
3262
3302
  data
3263
3303
  }) {
3264
3304
  try {
3265
- const filteredData = await this.filterRecordToKnownColumns(tableName, data);
3266
- if (Object.keys(filteredData).length === 0) return;
3267
- const setColumns = [];
3268
- const setValues = [];
3269
- let paramIndex = 1;
3270
- Object.entries(filteredData).forEach(([key, value]) => {
3271
- const parsedKey = parseSqlIdentifier(key, "column name");
3272
- setColumns.push(`"${parsedKey}" = $${paramIndex++}`);
3273
- setValues.push(this.prepareValue(value, key, tableName));
3274
- });
3275
- const whereConditions = [];
3276
- const whereValues = [];
3277
- Object.entries(keys).forEach(([key, value]) => {
3278
- const parsedKey = parseSqlIdentifier(key, "column name");
3279
- whereConditions.push(`"${parsedKey}" = $${paramIndex++}`);
3280
- whereValues.push(this.prepareValue(value, key, tableName));
3281
- });
3282
- const tableName_ = getTableName({
3283
- indexName: tableName,
3284
- schemaName: getSchemaName(this.schemaName)
3285
- });
3286
- const sql = `UPDATE ${tableName_} SET ${setColumns.join(", ")} WHERE ${whereConditions.join(" AND ")}`;
3287
- const values = [...setValues, ...whereValues];
3288
- await this.client.none(sql, values);
3305
+ await this.executeUpdate(this.client, { tableName, keys, data });
3289
3306
  } catch (error) {
3290
3307
  throw new MastraError(
3291
3308
  {
@@ -3305,13 +3322,12 @@ Note: This migration may take some time for large tables.
3305
3322
  updates
3306
3323
  }) {
3307
3324
  try {
3308
- await this.client.query("BEGIN");
3309
- for (const { keys, data } of updates) {
3310
- await this.update({ tableName, keys, data });
3311
- }
3312
- await this.client.query("COMMIT");
3325
+ await this.client.tx(async (tx) => {
3326
+ for (const { keys, data } of updates) {
3327
+ await this.executeUpdate(tx, { tableName, keys, data });
3328
+ }
3329
+ });
3313
3330
  } catch (error) {
3314
- await this.client.query("ROLLBACK");
3315
3331
  throw new MastraError(
3316
3332
  {
3317
3333
  id: createStorageErrorId("PG", "BATCH_UPDATE", "FAILED"),
@@ -3370,6 +3386,36 @@ Note: This migration may take some time for large tables.
3370
3386
  async deleteData({ tableName }) {
3371
3387
  return this.clearTable({ tableName });
3372
3388
  }
3389
+ async executeUpdate(client, {
3390
+ tableName,
3391
+ keys,
3392
+ data
3393
+ }) {
3394
+ const filteredData = await this.filterRecordToKnownColumns(tableName, data);
3395
+ if (Object.keys(filteredData).length === 0) return;
3396
+ const setColumns = [];
3397
+ const setValues = [];
3398
+ let paramIndex = 1;
3399
+ Object.entries(filteredData).forEach(([key, value]) => {
3400
+ const parsedKey = parseSqlIdentifier(key, "column name");
3401
+ setColumns.push(`"${parsedKey}" = $${paramIndex++}`);
3402
+ setValues.push(this.prepareValue(value, key, tableName));
3403
+ });
3404
+ const whereConditions = [];
3405
+ const whereValues = [];
3406
+ Object.entries(keys).forEach(([key, value]) => {
3407
+ const parsedKey = parseSqlIdentifier(key, "column name");
3408
+ whereConditions.push(`"${parsedKey}" = $${paramIndex++}`);
3409
+ whereValues.push(this.prepareValue(value, key, tableName));
3410
+ });
3411
+ const tableName_ = getTableName({
3412
+ indexName: tableName,
3413
+ schemaName: getSchemaName(this.schemaName)
3414
+ });
3415
+ const sql = `UPDATE ${tableName_} SET ${setColumns.join(", ")} WHERE ${whereConditions.join(" AND ")}`;
3416
+ const values = [...setValues, ...whereValues];
3417
+ await client.none(sql, values);
3418
+ }
3373
3419
  };
3374
3420
  function getSchemaName2(schema) {
3375
3421
  return schema ? `"${parseSqlIdentifier(schema, "schema name")}"` : void 0;
@@ -5283,6 +5329,16 @@ var ExperimentsPG = class _ExperimentsPG extends ExperimentsStorage {
5283
5329
  async init() {
5284
5330
  await this.#db.createTable({ tableName: TABLE_EXPERIMENTS, schema: EXPERIMENTS_SCHEMA });
5285
5331
  await this.#db.createTable({ tableName: TABLE_EXPERIMENT_RESULTS, schema: EXPERIMENT_RESULTS_SCHEMA });
5332
+ await this.#db.alterTable({
5333
+ tableName: TABLE_EXPERIMENTS,
5334
+ schema: EXPERIMENTS_SCHEMA,
5335
+ ifNotExists: ["agentVersion"]
5336
+ });
5337
+ await this.#db.alterTable({
5338
+ tableName: TABLE_EXPERIMENT_RESULTS,
5339
+ schema: EXPERIMENT_RESULTS_SCHEMA,
5340
+ ifNotExists: ["status", "tags"]
5341
+ });
5286
5342
  await this.createDefaultIndexes();
5287
5343
  await this.createCustomIndexes();
5288
5344
  }
@@ -5522,6 +5578,22 @@ var ExperimentsPG = class _ExperimentsPG extends ExperimentsStorage {
5522
5578
  conditions.push(`"datasetId" = $${paramIndex++}`);
5523
5579
  queryParams.push(args.datasetId);
5524
5580
  }
5581
+ if (args.targetType) {
5582
+ conditions.push(`"targetType" = $${paramIndex++}`);
5583
+ queryParams.push(args.targetType);
5584
+ }
5585
+ if (args.targetId) {
5586
+ conditions.push(`"targetId" = $${paramIndex++}`);
5587
+ queryParams.push(args.targetId);
5588
+ }
5589
+ if (args.agentVersion) {
5590
+ conditions.push(`"agentVersion" = $${paramIndex++}`);
5591
+ queryParams.push(args.agentVersion);
5592
+ }
5593
+ if (args.status) {
5594
+ conditions.push(`"status" = $${paramIndex++}`);
5595
+ queryParams.push(args.status);
5596
+ }
5525
5597
  const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(" AND ")}` : "";
5526
5598
  const countResult = await this.#db.client.one(
5527
5599
  `SELECT COUNT(*) as count FROM ${tableName} ${whereClause}`,
@@ -5710,9 +5782,21 @@ var ExperimentsPG = class _ExperimentsPG extends ExperimentsStorage {
5710
5782
  try {
5711
5783
  const { page, perPage: perPageInput } = args.pagination;
5712
5784
  const tableName = getTableName2({ indexName: TABLE_EXPERIMENT_RESULTS, schemaName: getSchemaName2(this.#schema) });
5785
+ const conditions = ['"experimentId" = $1'];
5786
+ const queryParams = [args.experimentId];
5787
+ let paramIndex = 2;
5788
+ if (args.traceId) {
5789
+ conditions.push(`"traceId" = $${paramIndex++}`);
5790
+ queryParams.push(args.traceId);
5791
+ }
5792
+ if (args.status) {
5793
+ conditions.push(`"status" = $${paramIndex++}`);
5794
+ queryParams.push(args.status);
5795
+ }
5796
+ const whereClause = `WHERE ${conditions.join(" AND ")}`;
5713
5797
  const countResult = await this.#db.client.one(
5714
- `SELECT COUNT(*) as count FROM ${tableName} WHERE "experimentId" = $1`,
5715
- [args.experimentId]
5798
+ `SELECT COUNT(*) as count FROM ${tableName} ${whereClause}`,
5799
+ queryParams
5716
5800
  );
5717
5801
  const total = parseInt(countResult.count, 10);
5718
5802
  if (total === 0) {
@@ -5722,8 +5806,8 @@ var ExperimentsPG = class _ExperimentsPG extends ExperimentsStorage {
5722
5806
  const { offset, perPage: perPageForResponse } = calculatePagination(page, perPageInput, perPage);
5723
5807
  const limitValue = perPageInput === false ? total : perPage;
5724
5808
  const rows = await this.#db.client.manyOrNone(
5725
- `SELECT * FROM ${tableName} WHERE "experimentId" = $1 ORDER BY "startedAt" ASC LIMIT $2 OFFSET $3`,
5726
- [args.experimentId, limitValue, offset]
5809
+ `SELECT * FROM ${tableName} ${whereClause} ORDER BY "startedAt" ASC LIMIT $${paramIndex} OFFSET $${paramIndex + 1}`,
5810
+ [...queryParams, limitValue, offset]
5727
5811
  );
5728
5812
  return {
5729
5813
  results: (rows || []).map((row) => this.transformExperimentResultRow(row)),