@mastra/mssql 1.0.0-beta.2 → 1.0.0-beta.3

This diff shows the contents of publicly released package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in that registry.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { MastraError, ErrorCategory, ErrorDomain } from '@mastra/core/error';
2
- import { MastraStorage, StoreOperations, TABLE_WORKFLOW_SNAPSHOT, TABLE_SCHEMAS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, TABLE_SPANS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage, safelyParseJSON } from '@mastra/core/storage';
2
+ import { MastraStorage, StoreOperations, TABLE_WORKFLOW_SNAPSHOT, TABLE_SCHEMAS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, TABLE_SPANS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
3
3
  import sql2 from 'mssql';
4
4
  import { MessageList } from '@mastra/core/agent';
5
5
  import { parseSqlIdentifier } from '@mastra/core/utils';
@@ -26,24 +26,62 @@ function buildDateRangeFilter(dateRange, fieldName) {
26
26
  }
27
27
  return filters;
28
28
  }
29
+ function isInOperator(value) {
30
+ return typeof value === "object" && value !== null && "$in" in value && Array.isArray(value.$in);
31
+ }
29
32
  function prepareWhereClause(filters, _schema) {
30
33
  const conditions = [];
31
34
  const params = {};
32
35
  let paramIndex = 1;
33
36
  Object.entries(filters).forEach(([key, value]) => {
34
37
  if (value === void 0) return;
35
- const paramName = `p${paramIndex++}`;
36
38
  if (key.endsWith("_gte")) {
39
+ const paramName = `p${paramIndex++}`;
37
40
  const fieldName = key.slice(0, -4);
38
41
  conditions.push(`[${parseSqlIdentifier(fieldName, "field name")}] >= @${paramName}`);
39
42
  params[paramName] = value instanceof Date ? value.toISOString() : value;
40
43
  } else if (key.endsWith("_lte")) {
44
+ const paramName = `p${paramIndex++}`;
41
45
  const fieldName = key.slice(0, -4);
42
46
  conditions.push(`[${parseSqlIdentifier(fieldName, "field name")}] <= @${paramName}`);
43
47
  params[paramName] = value instanceof Date ? value.toISOString() : value;
44
48
  } else if (value === null) {
45
49
  conditions.push(`[${parseSqlIdentifier(key, "field name")}] IS NULL`);
50
+ } else if (isInOperator(value)) {
51
+ const inValues = value.$in;
52
+ if (inValues.length === 0) {
53
+ conditions.push("1 = 0");
54
+ } else if (inValues.length === 1) {
55
+ const paramName = `p${paramIndex++}`;
56
+ conditions.push(`[${parseSqlIdentifier(key, "field name")}] = @${paramName}`);
57
+ params[paramName] = inValues[0] instanceof Date ? inValues[0].toISOString() : inValues[0];
58
+ } else {
59
+ const inParamNames = [];
60
+ for (const item of inValues) {
61
+ const paramName = `p${paramIndex++}`;
62
+ inParamNames.push(`@${paramName}`);
63
+ params[paramName] = item instanceof Date ? item.toISOString() : item;
64
+ }
65
+ conditions.push(`[${parseSqlIdentifier(key, "field name")}] IN (${inParamNames.join(", ")})`);
66
+ }
67
+ } else if (Array.isArray(value)) {
68
+ if (value.length === 0) {
69
+ conditions.push("1 = 0");
70
+ } else if (value.length === 1) {
71
+ const paramName = `p${paramIndex++}`;
72
+ conditions.push(`[${parseSqlIdentifier(key, "field name")}] = @${paramName}`);
73
+ params[paramName] = value[0] instanceof Date ? value[0].toISOString() : value[0];
74
+ } else {
75
+ const inParamNames = [];
76
+ for (const item of value) {
77
+ const paramName = `p${paramIndex++}`;
78
+ inParamNames.push(`@${paramName}`);
79
+ params[paramName] = item instanceof Date ? item.toISOString() : item;
80
+ }
81
+ conditions.push(`[${parseSqlIdentifier(key, "field name")}] IN (${inParamNames.join(", ")})`);
82
+ }
46
83
  } else {
84
+ const paramName = `p${paramIndex++}`;
47
85
  conditions.push(`[${parseSqlIdentifier(key, "field name")}] = @${paramName}`);
48
86
  params[paramName] = value instanceof Date ? value.toISOString() : value;
49
87
  }
@@ -381,23 +419,18 @@ var MemoryMSSQL = class extends MemoryStorage {
381
419
  );
382
420
  }
383
421
  }
384
- async _getIncludedMessages({
385
- threadId,
386
- include
387
- }) {
388
- if (!threadId.trim()) throw new Error("threadId must be a non-empty string");
389
- if (!include) return null;
422
+ async _getIncludedMessages({ include }) {
423
+ if (!include || include.length === 0) return null;
390
424
  const unionQueries = [];
391
425
  const paramValues = [];
392
426
  let paramIdx = 1;
393
427
  const paramNames = [];
428
+ const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
394
429
  for (const inc of include) {
395
430
  const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
396
- const searchId = inc.threadId || threadId;
397
- const pThreadId = `@p${paramIdx}`;
398
- const pId = `@p${paramIdx + 1}`;
399
- const pPrev = `@p${paramIdx + 2}`;
400
- const pNext = `@p${paramIdx + 3}`;
431
+ const pId = `@p${paramIdx}`;
432
+ const pPrev = `@p${paramIdx + 1}`;
433
+ const pNext = `@p${paramIdx + 2}`;
401
434
  unionQueries.push(
402
435
  `
403
436
  SELECT
@@ -411,16 +444,16 @@ var MemoryMSSQL = class extends MemoryStorage {
411
444
  m.seq_id
412
445
  FROM (
413
446
  SELECT *, ROW_NUMBER() OVER (ORDER BY [createdAt] ASC) as row_num
414
- FROM ${getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) })}
415
- WHERE [thread_id] = ${pThreadId}
447
+ FROM ${tableName}
448
+ WHERE [thread_id] = (SELECT thread_id FROM ${tableName} WHERE id = ${pId})
416
449
  ) AS m
417
450
  WHERE m.id = ${pId}
418
451
  OR EXISTS (
419
452
  SELECT 1
420
453
  FROM (
421
454
  SELECT *, ROW_NUMBER() OVER (ORDER BY [createdAt] ASC) as row_num
422
- FROM ${getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) })}
423
- WHERE [thread_id] = ${pThreadId}
455
+ FROM ${tableName}
456
+ WHERE [thread_id] = (SELECT thread_id FROM ${tableName} WHERE id = ${pId})
424
457
  ) AS target
425
458
  WHERE target.id = ${pId}
426
459
  AND (
@@ -433,9 +466,9 @@ var MemoryMSSQL = class extends MemoryStorage {
433
466
  )
434
467
  `
435
468
  );
436
- paramValues.push(searchId, id, withPreviousMessages, withNextMessages);
437
- paramNames.push(`p${paramIdx}`, `p${paramIdx + 1}`, `p${paramIdx + 2}`, `p${paramIdx + 3}`);
438
- paramIdx += 4;
469
+ paramValues.push(id, withPreviousMessages, withNextMessages);
470
+ paramNames.push(`p${paramIdx}`, `p${paramIdx + 1}`, `p${paramIdx + 2}`);
471
+ paramIdx += 3;
439
472
  }
440
473
  const finalQuery = `
441
474
  SELECT * FROM (
@@ -506,15 +539,16 @@ var MemoryMSSQL = class extends MemoryStorage {
506
539
  }
507
540
  async listMessages(args) {
508
541
  const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
509
- if (!threadId.trim()) {
542
+ const threadIds = Array.isArray(threadId) ? threadId : [threadId];
543
+ if (threadIds.length === 0 || threadIds.some((id) => !id.trim())) {
510
544
  throw new MastraError(
511
545
  {
512
546
  id: "STORAGE_MSSQL_LIST_MESSAGES_INVALID_THREAD_ID",
513
547
  domain: ErrorDomain.STORAGE,
514
548
  category: ErrorCategory.THIRD_PARTY,
515
- details: { threadId }
549
+ details: { threadId: Array.isArray(threadId) ? threadId.join(",") : threadId }
516
550
  },
517
- new Error("threadId must be a non-empty string")
551
+ new Error("threadId must be a non-empty string or array of non-empty strings")
518
552
  );
519
553
  }
520
554
  if (page < 0) {
@@ -524,7 +558,7 @@ var MemoryMSSQL = class extends MemoryStorage {
524
558
  category: ErrorCategory.USER,
525
559
  text: "Page number must be non-negative",
526
560
  details: {
527
- threadId,
561
+ threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
528
562
  page
529
563
  }
530
564
  });
@@ -537,7 +571,7 @@ var MemoryMSSQL = class extends MemoryStorage {
537
571
  const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
538
572
  const baseQuery = `SELECT seq_id, id, content, role, type, [createdAt], thread_id AS threadId, resourceId FROM ${tableName}`;
539
573
  const filters = {
540
- thread_id: threadId,
574
+ thread_id: threadIds.length === 1 ? threadIds[0] : { $in: threadIds },
541
575
  ...resourceId ? { resourceId } : {},
542
576
  ...buildDateRangeFilter(filter?.dateRange, "createdAt")
543
577
  };
@@ -581,7 +615,7 @@ var MemoryMSSQL = class extends MemoryStorage {
581
615
  }
582
616
  if (include?.length) {
583
617
  const messageIds = new Set(messages.map((m) => m.id));
584
- const includeMessages = await this._getIncludedMessages({ threadId, include });
618
+ const includeMessages = await this._getIncludedMessages({ include });
585
619
  includeMessages?.forEach((msg) => {
586
620
  if (!messageIds.has(msg.id)) {
587
621
  messages.push(msg);
@@ -604,7 +638,8 @@ var MemoryMSSQL = class extends MemoryStorage {
604
638
  const seqB = seqById.get(b.id);
605
639
  return seqA != null && seqB != null ? (seqA - seqB) * mult : a.id.localeCompare(b.id);
606
640
  });
607
- const returnedThreadMessageCount = finalMessages.filter((m) => m.threadId === threadId).length;
641
+ const threadIdSet = new Set(threadIds);
642
+ const returnedThreadMessageCount = finalMessages.filter((m) => m.threadId && threadIdSet.has(m.threadId)).length;
608
643
  const hasMore = perPageInput !== false && returnedThreadMessageCount < total && offset + perPage < total;
609
644
  return {
610
645
  messages: finalMessages,
@@ -620,7 +655,7 @@ var MemoryMSSQL = class extends MemoryStorage {
620
655
  domain: ErrorDomain.STORAGE,
621
656
  category: ErrorCategory.THIRD_PARTY,
622
657
  details: {
623
- threadId,
658
+ threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
624
659
  resourceId: resourceId ?? ""
625
660
  }
626
661
  },
@@ -2255,20 +2290,9 @@ ${columns}
2255
2290
  }
2256
2291
  };
2257
2292
  function transformScoreRow(row) {
2258
- return {
2259
- ...row,
2260
- input: safelyParseJSON(row.input),
2261
- scorer: safelyParseJSON(row.scorer),
2262
- preprocessStepResult: safelyParseJSON(row.preprocessStepResult),
2263
- analyzeStepResult: safelyParseJSON(row.analyzeStepResult),
2264
- metadata: safelyParseJSON(row.metadata),
2265
- output: safelyParseJSON(row.output),
2266
- additionalContext: safelyParseJSON(row.additionalContext),
2267
- requestContext: safelyParseJSON(row.requestContext),
2268
- entity: safelyParseJSON(row.entity),
2269
- createdAt: row.createdAt,
2270
- updatedAt: row.updatedAt
2271
- };
2293
+ return transformScoreRow$1(row, {
2294
+ convertTimestamps: true
2295
+ });
2272
2296
  }
2273
2297
  var ScoresMSSQL = class extends ScoresStorage {
2274
2298
  pool;