@mastra/pg 1.0.0-beta.2 → 1.0.0-beta.3

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects those package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -5,7 +5,7 @@ import { Mutex } from 'async-mutex';
5
5
  import * as pg from 'pg';
6
6
  import xxhash from 'xxhash-wasm';
7
7
  import { BaseFilterTranslator } from '@mastra/core/vector/filter';
8
- import { MastraStorage, StoreOperations, TABLE_SCHEMAS, TABLE_WORKFLOW_SNAPSHOT, TABLE_SPANS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage, safelyParseJSON } from '@mastra/core/storage';
8
+ import { MastraStorage, StoreOperations, TABLE_SCHEMAS, TABLE_WORKFLOW_SNAPSHOT, TABLE_SPANS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
9
9
  import pgPromise from 'pg-promise';
10
10
  import { MessageList } from '@mastra/core/agent';
11
11
  import { saveScorePayloadSchema } from '@mastra/core/evals';
@@ -596,9 +596,6 @@ var PgVector = class extends MastraVector {
596
596
  if (this.vectorExtensionSchema === "pg_catalog") {
597
597
  return "vector";
598
598
  }
599
- if (this.vectorExtensionSchema === (this.schema || "public")) {
600
- return "vector";
601
- }
602
599
  const validatedSchema = parseSqlIdentifier(this.vectorExtensionSchema, "vector extension schema");
603
600
  return `${validatedSchema}.vector`;
604
601
  }
@@ -1888,6 +1885,20 @@ var MemoryPG = class extends MemoryStorage {
1888
1885
  const threadTableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
1889
1886
  await this.client.tx(async (t) => {
1890
1887
  await t.none(`DELETE FROM ${tableName} WHERE thread_id = $1`, [threadId]);
1888
+ const schemaName = this.schema || "public";
1889
+ const vectorTables = await t.manyOrNone(
1890
+ `
1891
+ SELECT tablename
1892
+ FROM pg_tables
1893
+ WHERE schemaname = $1
1894
+ AND (tablename = 'memory_messages' OR tablename LIKE 'memory_messages_%')
1895
+ `,
1896
+ [schemaName]
1897
+ );
1898
+ for (const { tablename } of vectorTables) {
1899
+ const vectorTableName = getTableName({ indexName: tablename, schemaName: getSchemaName(this.schema) });
1900
+ await t.none(`DELETE FROM ${vectorTableName} WHERE metadata->>'thread_id' = $1`, [threadId]);
1901
+ }
1891
1902
  await t.none(`DELETE FROM ${threadTableName} WHERE id = $1`, [threadId]);
1892
1903
  });
1893
1904
  } catch (error) {
@@ -1904,28 +1915,26 @@ var MemoryPG = class extends MemoryStorage {
1904
1915
  );
1905
1916
  }
1906
1917
  }
1907
- async _getIncludedMessages({
1908
- threadId,
1909
- include
1910
- }) {
1911
- if (!threadId.trim()) throw new Error("threadId must be a non-empty string");
1912
- if (!include) return null;
1918
+ async _getIncludedMessages({ include }) {
1919
+ if (!include || include.length === 0) return null;
1913
1920
  const unionQueries = [];
1914
1921
  const params = [];
1915
1922
  let paramIdx = 1;
1916
1923
  const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
1917
1924
  for (const inc of include) {
1918
1925
  const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
1919
- const searchId = inc.threadId || threadId;
1920
1926
  unionQueries.push(
1921
1927
  `
1922
1928
  SELECT * FROM (
1923
- WITH ordered_messages AS (
1929
+ WITH target_thread AS (
1930
+ SELECT thread_id FROM ${tableName} WHERE id = $${paramIdx}
1931
+ ),
1932
+ ordered_messages AS (
1924
1933
  SELECT
1925
1934
  *,
1926
1935
  ROW_NUMBER() OVER (ORDER BY "createdAt" ASC) as row_num
1927
1936
  FROM ${tableName}
1928
- WHERE thread_id = $${paramIdx}
1937
+ WHERE thread_id = (SELECT thread_id FROM target_thread)
1929
1938
  )
1930
1939
  SELECT
1931
1940
  m.id,
@@ -1937,24 +1946,24 @@ var MemoryPG = class extends MemoryStorage {
1937
1946
  m.thread_id AS "threadId",
1938
1947
  m."resourceId"
1939
1948
  FROM ordered_messages m
1940
- WHERE m.id = $${paramIdx + 1}
1949
+ WHERE m.id = $${paramIdx}
1941
1950
  OR EXISTS (
1942
1951
  SELECT 1 FROM ordered_messages target
1943
- WHERE target.id = $${paramIdx + 1}
1952
+ WHERE target.id = $${paramIdx}
1944
1953
  AND (
1945
1954
  -- Get previous messages (messages that come BEFORE the target)
1946
- (m.row_num < target.row_num AND m.row_num >= target.row_num - $${paramIdx + 2})
1955
+ (m.row_num < target.row_num AND m.row_num >= target.row_num - $${paramIdx + 1})
1947
1956
  OR
1948
1957
  -- Get next messages (messages that come AFTER the target)
1949
- (m.row_num > target.row_num AND m.row_num <= target.row_num + $${paramIdx + 3})
1958
+ (m.row_num > target.row_num AND m.row_num <= target.row_num + $${paramIdx + 2})
1950
1959
  )
1951
1960
  )
1952
1961
  ) AS query_${paramIdx}
1953
1962
  `
1954
1963
  // Keep ASC for final sorting after fetching context
1955
1964
  );
1956
- params.push(searchId, id, withPreviousMessages, withNextMessages);
1957
- paramIdx += 4;
1965
+ params.push(id, withPreviousMessages, withNextMessages);
1966
+ paramIdx += 3;
1958
1967
  }
1959
1968
  const finalQuery = unionQueries.join(" UNION ALL ") + ' ORDER BY "createdAt" ASC';
1960
1969
  const includedRows = await this.client.manyOrNone(finalQuery, params);
@@ -2018,15 +2027,18 @@ var MemoryPG = class extends MemoryStorage {
2018
2027
  }
2019
2028
  async listMessages(args) {
2020
2029
  const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
2021
- if (!threadId.trim()) {
2030
+ const threadIds = (Array.isArray(threadId) ? threadId : [threadId]).filter(
2031
+ (id) => typeof id === "string"
2032
+ );
2033
+ if (threadIds.length === 0 || threadIds.some((id) => !id.trim())) {
2022
2034
  throw new MastraError(
2023
2035
  {
2024
2036
  id: "STORAGE_PG_LIST_MESSAGES_INVALID_THREAD_ID",
2025
2037
  domain: ErrorDomain.STORAGE,
2026
2038
  category: ErrorCategory.THIRD_PARTY,
2027
- details: { threadId }
2039
+ details: { threadId: Array.isArray(threadId) ? String(threadId) : String(threadId) }
2028
2040
  },
2029
- new Error("threadId must be a non-empty string")
2041
+ new Error("threadId must be a non-empty string or array of non-empty strings")
2030
2042
  );
2031
2043
  }
2032
2044
  if (page < 0) {
@@ -2036,7 +2048,7 @@ var MemoryPG = class extends MemoryStorage {
2036
2048
  category: ErrorCategory.USER,
2037
2049
  text: "Page number must be non-negative",
2038
2050
  details: {
2039
- threadId,
2051
+ threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
2040
2052
  page
2041
2053
  }
2042
2054
  });
@@ -2048,9 +2060,10 @@ var MemoryPG = class extends MemoryStorage {
2048
2060
  const orderByStatement = `ORDER BY "${field}" ${direction}`;
2049
2061
  const selectStatement = `SELECT id, content, role, type, "createdAt", "createdAtZ", thread_id AS "threadId", "resourceId"`;
2050
2062
  const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
2051
- const conditions = [`thread_id = $1`];
2052
- const queryParams = [threadId];
2053
- let paramIndex = 2;
2063
+ const threadPlaceholders = threadIds.map((_, i) => `$${i + 1}`).join(", ");
2064
+ const conditions = [`thread_id IN (${threadPlaceholders})`];
2065
+ const queryParams = [...threadIds];
2066
+ let paramIndex = threadIds.length + 1;
2054
2067
  if (resourceId) {
2055
2068
  conditions.push(`"resourceId" = $${paramIndex++}`);
2056
2069
  queryParams.push(resourceId);
@@ -2082,7 +2095,7 @@ var MemoryPG = class extends MemoryStorage {
2082
2095
  }
2083
2096
  const messageIds = new Set(messages.map((m) => m.id));
2084
2097
  if (include && include.length > 0) {
2085
- const includeMessages = await this._getIncludedMessages({ threadId, include });
2098
+ const includeMessages = await this._getIncludedMessages({ include });
2086
2099
  if (includeMessages) {
2087
2100
  for (const includeMsg of includeMessages) {
2088
2101
  if (!messageIds.has(includeMsg.id)) {
@@ -2109,7 +2122,10 @@ var MemoryPG = class extends MemoryStorage {
2109
2122
  }
2110
2123
  return direction === "ASC" ? String(aValue).localeCompare(String(bValue)) : String(bValue).localeCompare(String(aValue));
2111
2124
  });
2112
- const returnedThreadMessageIds = new Set(finalMessages.filter((m) => m.threadId === threadId).map((m) => m.id));
2125
+ const threadIdSet = new Set(threadIds);
2126
+ const returnedThreadMessageIds = new Set(
2127
+ finalMessages.filter((m) => m.threadId && threadIdSet.has(m.threadId)).map((m) => m.id)
2128
+ );
2113
2129
  const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
2114
2130
  const hasMore = perPageInput !== false && !allThreadMessagesReturned && offset + perPage < total;
2115
2131
  return {
@@ -2126,7 +2142,7 @@ var MemoryPG = class extends MemoryStorage {
2126
2142
  domain: ErrorDomain.STORAGE,
2127
2143
  category: ErrorCategory.THIRD_PARTY,
2128
2144
  details: {
2129
- threadId,
2145
+ threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
2130
2146
  resourceId: resourceId ?? ""
2131
2147
  }
2132
2148
  },
@@ -3581,20 +3597,12 @@ var StoreOperationsPG = class extends StoreOperations {
3581
3597
  }
3582
3598
  };
3583
3599
  function transformScoreRow(row) {
3584
- return {
3585
- ...row,
3586
- input: safelyParseJSON(row.input),
3587
- scorer: safelyParseJSON(row.scorer),
3588
- preprocessStepResult: safelyParseJSON(row.preprocessStepResult),
3589
- analyzeStepResult: safelyParseJSON(row.analyzeStepResult),
3590
- metadata: safelyParseJSON(row.metadata),
3591
- output: safelyParseJSON(row.output),
3592
- additionalContext: safelyParseJSON(row.additionalContext),
3593
- requestContext: safelyParseJSON(row.requestContext),
3594
- entity: safelyParseJSON(row.entity),
3595
- createdAt: row.createdAtZ || row.createdAt,
3596
- updatedAt: row.updatedAtZ || row.updatedAt
3597
- };
3600
+ return transformScoreRow$1(row, {
3601
+ preferredTimestampFields: {
3602
+ createdAt: "createdAtZ",
3603
+ updatedAt: "updatedAtZ"
3604
+ }
3605
+ });
3598
3606
  }
3599
3607
  var ScoresPG = class extends ScoresStorage {
3600
3608
  client;
@@ -3741,8 +3749,6 @@ var ScoresPG = class extends ScoresStorage {
3741
3749
  scorer: scorer ? JSON.stringify(scorer) : null,
3742
3750
  preprocessStepResult: preprocessStepResult ? JSON.stringify(preprocessStepResult) : null,
3743
3751
  analyzeStepResult: analyzeStepResult ? JSON.stringify(analyzeStepResult) : null,
3744
- metadata: metadata ? JSON.stringify(metadata) : null,
3745
- additionalContext: additionalContext ? JSON.stringify(additionalContext) : null,
3746
3752
  requestContext: requestContext ? JSON.stringify(requestContext) : null,
3747
3753
  entity: entity ? JSON.stringify(entity) : null,
3748
3754
  createdAt: (/* @__PURE__ */ new Date()).toISOString(),