@ixo/sqlite-saver 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. package/.eslintrc.js +9 -0
  2. package/.prettierignore +3 -0
  3. package/.prettierrc.js +4 -0
  4. package/.turbo/turbo-build.log +4 -0
  5. package/CHANGELOG.md +25 -0
  6. package/README.md +0 -0
  7. package/dist/index.d.ts +38 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +567 -0
  10. package/dist/index.js.map +1 -0
  11. package/dist/migrations/001_add_created_at_to_messages.d.ts +8 -0
  12. package/dist/migrations/001_add_created_at_to_messages.d.ts.map +1 -0
  13. package/dist/migrations/001_add_created_at_to_messages.js +32 -0
  14. package/dist/migrations/001_add_created_at_to_messages.js.map +1 -0
  15. package/dist/tests/agent-with-checkpoiner.test.d.ts +2 -0
  16. package/dist/tests/agent-with-checkpoiner.test.d.ts.map +1 -0
  17. package/dist/tests/agent-with-checkpoiner.test.js +206 -0
  18. package/dist/tests/agent-with-checkpoiner.test.js.map +1 -0
  19. package/dist/tests/checkpointer.test.d.ts +2 -0
  20. package/dist/tests/checkpointer.test.d.ts.map +1 -0
  21. package/dist/tests/checkpointer.test.js +426 -0
  22. package/dist/tests/checkpointer.test.js.map +1 -0
  23. package/dist/utils.d.ts +15 -0
  24. package/dist/utils.d.ts.map +1 -0
  25. package/dist/utils.js +284 -0
  26. package/dist/utils.js.map +1 -0
  27. package/jest.config.js +6 -0
  28. package/package.json +41 -0
  29. package/src/index.ts +929 -0
  30. package/src/migrations/001_add_created_at_to_messages.ts +48 -0
  31. package/src/tests/agent-with-checkpoiner.test.ts +264 -0
  32. package/src/tests/checkpointer.test.ts +628 -0
  33. package/src/utils.ts +358 -0
  34. package/tsconfig.json +11 -0
  35. package/tsconfig.tsbuildinfo +1 -0
package/src/index.ts ADDED
@@ -0,0 +1,929 @@
1
+ import type { RunnableConfig } from '@langchain/core/runnables';
2
+ import {
3
+ BaseCheckpointSaver,
4
+ type Checkpoint,
5
+ type CheckpointListOptions,
6
+ type CheckpointMetadata,
7
+ type CheckpointTuple,
8
+ copyCheckpoint,
9
+ maxChannelVersion,
10
+ type PendingWrite,
11
+ type SerializerProtocol,
12
+ TASKS,
13
+ } from '@langchain/langgraph-checkpoint';
14
+ import Database, { Database as DatabaseType, Statement } from 'better-sqlite3';
15
+ import { BaseMessage } from 'langchain';
16
+ import migration001 from './migrations/001_add_created_at_to_messages';
17
+ import {
18
+ _default,
19
+ CleanAdditionalKwargs,
20
+ cleanAdditionalKwargs,
21
+ stringify,
22
+ } from './utils';
23
+
24
/**
 * Map from channel name to its (opaque) stored value. Equivalent to a mapped
 * type over the channel-name union `C`.
 */
type ChannelValues<C extends string = string> = Record<C, unknown>;
27
+
28
/** Channel values that may additionally carry the conversation's messages. */
type MessageChannelValues<C extends string> = ChannelValues<C> & {
  messages?: BaseMessage[];
};

/**
 * A `Checkpoint` whose `channel_values` may include a `messages` list. Used
 * to narrow checkpoints before their messages are split out for storage.
 */
interface CheckpointWithMessages<
  N extends string = string,
  C extends string = string,
> extends Checkpoint<N, C> {
  channel_values: MessageChannelValues<C>;
}
36
+
37
/**
 * Shape of one row in the `messages` table.
 *
 * `message` holds the serialized message. `put` writes it as UTF-8 bytes (a
 * BLOB), and BLOB columns are read back as binary, so the field accepts both
 * strings and byte arrays instead of pretending to be a plain string.
 */
interface MessageRow {
  thread_id: string;
  checkpoint_ns: string;
  checkpoint_id: string;
  /** Primary key of the `messages` table. */
  message_id: string;
  message_type: string;
  /** Plain-text rendering of the message content. */
  message_content: string;
  /** Serialized message (BLOB on disk). */
  message: string | Uint8Array;
  /** ISO-8601 timestamp, or the column default when the row predates it. */
  created_at: string;
}
47
+
48
/**
 * Shape of a row returned by the checkpoint SELECT queries.
 *
 * `checkpoint` and `metadata` are BLOB columns (serializer output), so they
 * may come back as binary rather than text. `parent_checkpoint_id` and `type`
 * are nullable columns and are SQL `NULL` when unset.
 */
interface CheckpointRow {
  checkpoint: string | Uint8Array;
  metadata: string | Uint8Array;
  parent_checkpoint_id?: string | null;
  thread_id: string;
  checkpoint_id: string;
  checkpoint_ns?: string;
  /** Serializer type tag shared by `checkpoint` and `metadata`. */
  type?: string | null;
  /** JSON array of `PendingWriteColumn` objects built by the query. */
  pending_writes: string;
  /** JSON array of `PendingSendColumn` objects (only some queries select it). */
  pending_sends?: string;
}
58
+
59
/**
 * One element of the `pending_writes` JSON array built by the checkpoint
 * SELECT queries (one row of the `writes` table).
 *
 * `writes.type` and `writes.value` are nullable columns, so `json_object` can
 * emit `null` for them; readers fall back to `'json'` / `''` respectively.
 */
interface PendingWriteColumn {
  task_id: string;
  channel: string;
  type: string | null;
  /** The stored value, cast to TEXT by the query. */
  value: string | null;
}
65
+
66
/**
 * One element of the `pending_sends` JSON array: a write on the `TASKS`
 * channel, with its serializer tag and its value cast to TEXT.
 */
interface PendingSendColumn {
  /** Serialized value (TEXT cast of `writes.value`). */
  value: string;
  /** Serializer type tag from `writes.type`. */
  type: string;
}
70
+
71
/**
 * A single schema migration. `SqliteSaver.runMigrations` applies pending
 * migrations in ascending `version` order and records each applied version
 * (the primary key of `schema_migrations`).
 */
interface Migration {
  /** Human-readable label stored next to the version. */
  name: string;
  /** Ordering key; must be unique across migrations. */
  version: number;
  /** Applies the schema change to the given database. */
  up: (db: DatabaseType) => void;
}
76
+
77
// `SqliteSaver.list` drops any `options.filter` key that is not listed in
// `checkpointMetadataKeys`. The type machinery below makes the build fail if
// that list ever stops matching the keys of `CheckpointMetadata`.
const checkpointMetadataKeys = ['source', 'step', 'parents'] as const;

/** Resolves to `true` when every member of union `A` is also in union `B`. */
type IsSubsetOf<A, B> = [A] extends [B] ? true : false;

/**
 * Resolves to `K` when `K` lists exactly the keys of `T` (nothing missing,
 * nothing extra) and to `never` otherwise.
 */
type CheckKeys<T, K extends readonly (keyof T)[]> =
  IsSubsetOf<K[number], keyof T> extends true
    ? IsSubsetOf<keyof T, K[number]> extends true
      ? K
      : never
    : never;

/**
 * Identity at runtime; at compile time the argument only type-checks when it
 * satisfies `CheckKeys<T, K>`.
 */
const validateKeys = <T, K extends readonly (keyof T)[]>(
  keys: CheckKeys<T, K>,
): K => keys;

// A compile error on the next statement means `checkpointMetadataKeys` has
// drifted from `CheckpointMetadata`; update the list to match the type.
const validCheckpointMetadataKeys = validateKeys<
  CheckpointMetadata,
  typeof checkpointMetadataKeys
>(checkpointMetadataKeys);
103
+
104
/**
 * Prepares the SELECT used by `SqliteSaver.getTuple`.
 *
 * Besides the checkpoint columns, each row carries two JSON-array columns:
 * - `pending_writes`: writes recorded against this checkpoint, ordered by
 *   `task_id` then `idx`;
 * - `pending_sends`: `TASKS`-channel writes recorded against the parent
 *   checkpoint, ordered by `idx`.
 *
 * The ordering is expressed inside `json_group_array(... ORDER BY ...)`
 * (SQLite 3.44+). An `ORDER BY` on the enclosing aggregate subquery orders
 * its single result row, not the inputs fed to the aggregate.
 *
 * @param db - Database to prepare the statement on.
 * @param checkpointId - When true, the statement takes a third parameter and
 *   matches that exact `checkpoint_id`; otherwise it returns the row with the
 *   greatest `checkpoint_id` for the thread/namespace.
 * @returns Statement bound as `(thread_id, checkpoint_ns[, checkpoint_id])`.
 */
function prepareSql(db: DatabaseType, checkpointId: boolean) {
  const selector = checkpointId
    ? 'AND checkpoint_id = ?'
    : 'ORDER BY checkpoint_id DESC LIMIT 1';

  const sql = `
    SELECT
      thread_id,
      checkpoint_ns,
      checkpoint_id,
      parent_checkpoint_id,
      type,
      checkpoint,
      metadata,
      (
        SELECT
          json_group_array(
            json_object(
              'task_id', pw.task_id,
              'channel', pw.channel,
              'type', pw.type,
              'value', CAST(pw.value AS TEXT)
            )
            ORDER BY pw.task_id, pw.idx
          )
        FROM writes as pw
        WHERE pw.thread_id = checkpoints.thread_id
          AND pw.checkpoint_ns = checkpoints.checkpoint_ns
          AND pw.checkpoint_id = checkpoints.checkpoint_id
      ) as pending_writes,
      (
        SELECT
          json_group_array(
            json_object(
              'type', ps.type,
              'value', CAST(ps.value AS TEXT)
            )
            ORDER BY ps.idx
          )
        FROM writes as ps
        WHERE ps.thread_id = checkpoints.thread_id
          AND ps.checkpoint_ns = checkpoints.checkpoint_ns
          AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
          AND ps.channel = '${TASKS}'
      ) as pending_sends
    FROM checkpoints
    WHERE thread_id = ? AND checkpoint_ns = ? ${selector}`;

  return db.prepare(sql);
}
153
+
154
/**
 * LangGraph checkpoint saver backed by a better-sqlite3 database.
 *
 * Besides the `checkpoints` and `writes` tables, this saver keeps chat
 * messages in a dedicated `messages` table: `put` strips
 * `channel_values.messages` out of the serialized checkpoint and stores one
 * row per message, and `getTuple` / `list` re-attach those rows on read.
 *
 * Tables, migrations, indexes and the cached prepared statements are created
 * lazily by `setup()`, which every data-access method calls first.
 */
export class SqliteSaver extends BaseCheckpointSaver {
  db: DatabaseType;

  /** Set once `setup()` has created the schema and prepared the statements. */
  protected isSetup: boolean;

  /** Latest checkpoint for `(thread_id, checkpoint_ns)`. */
  protected withoutCheckpoint: Statement;

  /** Exact checkpoint for `(thread_id, checkpoint_ns, checkpoint_id)`. */
  protected withCheckpoint: Statement;

  protected putCheckpointStmt: Statement;

  protected putWritesStmt: Statement;

  protected deleteCheckpointsStmt: Statement;

  protected putMessageStmt: Statement;

  protected getMessageStmt: Statement;

  protected deleteWritesStmt: Statement;

  /** Deletes every stored message of a thread (used by `deleteThread`). */
  protected deleteMessagesStmt: Statement;

  /**
   * @param db - better-sqlite3 database to store checkpoints in.
   * @param serde - Optional serializer, passed through to `BaseCheckpointSaver`.
   */
  constructor(db: DatabaseType, serde?: SerializerProtocol) {
    super(serde);
    this.db = db;
    this.isSetup = false;
  }

  /** Opens (or creates) the SQLite database at `connStringOrLocalPath`. */
  static fromConnString(connStringOrLocalPath: string): SqliteSaver {
    return new SqliteSaver(new Database(connStringOrLocalPath));
  }

  /** Wraps an already-open better-sqlite3 database. */
  static fromDatabase(db: DatabaseType, serde?: SerializerProtocol): SqliteSaver {
    return new SqliteSaver(db, serde);
  }

  /** Closes the underlying database if it is still open. */
  close(): void {
    if (this.db.open) {
      this.db.close();
    }
  }

  /**
   * Create schema_migrations table to track applied migrations
   */
  protected createSchemaMigrationsTable(): void {
    this.db.exec(`
      CREATE TABLE IF NOT EXISTS schema_migrations (
        version INTEGER NOT NULL PRIMARY KEY,
        name TEXT NOT NULL,
        applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
      );
    `);
  }

  /**
   * Get list of applied migration versions (empty if the table cannot be read).
   */
  protected getAppliedMigrations(): number[] {
    try {
      const rows = this.db
        .prepare('SELECT version FROM schema_migrations ORDER BY version')
        .all() as Array<{ version: number }>;
      return rows.map((row) => row.version);
    } catch {
      // Table doesn't exist yet, return empty array
      return [];
    }
  }

  /**
   * Record that a migration has been applied
   */
  protected recordMigration(migration: Migration): void {
    this.db
      .prepare(
        'INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, CURRENT_TIMESTAMP)',
      )
      .run(migration.version, migration.name);
  }

  /**
   * Load all migration files from the migrations folder
   *
   * Note: When adding new migrations:
   * 1. Import them at the top of this file
   * 2. Add them to the migrations array below
   */
  protected loadMigrations(): Migration[] {
    const migrations: Migration[] = [
      migration001,
      // Add future migrations here:
      // migration002,
      // migration003,
    ];

    // Sort by version to ensure correct order
    migrations.sort((a, b) => a.version - b.version);

    return migrations;
  }

  /**
   * Runs every migration whose version is not yet recorded, in version order.
   *
   * @throws Rethrows the first migration error after logging it.
   */
  protected runMigrations(): void {
    this.createSchemaMigrationsTable();

    const appliedVersions = new Set(this.getAppliedMigrations());
    const pendingMigrations = this.loadMigrations().filter(
      (migration) => !appliedVersions.has(migration.version),
    );

    if (pendingMigrations.length === 0) {
      return;
    }

    console.log(`Running ${pendingMigrations.length} pending migration(s)...`);

    for (const migration of pendingMigrations) {
      try {
        console.log(
          `Applying migration ${migration.version}: ${migration.name}`,
        );
        migration.up(this.db);
        this.recordMigration(migration);
        console.log(
          `✓ Migration ${migration.version}: ${migration.name} applied successfully`,
        );
      } catch (error) {
        console.error(
          `✗ Failed to apply migration ${migration.version}: ${migration.name}`,
          error,
        );
        throw error;
      }
    }
  }

  /**
   * Creates the schema, runs pending migrations, creates indexes and caches
   * prepared statements. Runs at most once per instance.
   */
  protected setup(): void {
    if (this.isSetup) {
      return;
    }

    // Enable WAL mode for concurrent read/write support and set busy timeout
    this.db.pragma('journal_mode = WAL');
    this.db.pragma('busy_timeout = 5000');

    // Create base tables
    this.db.exec(`
      CREATE TABLE IF NOT EXISTS checkpoints (
        thread_id TEXT NOT NULL,
        checkpoint_ns TEXT NOT NULL DEFAULT '',
        checkpoint_id TEXT NOT NULL,
        parent_checkpoint_id TEXT,
        type TEXT,
        checkpoint BLOB,
        metadata BLOB,
        PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
      );`);
    this.db.exec(`
      CREATE TABLE IF NOT EXISTS writes (
        thread_id TEXT NOT NULL,
        checkpoint_ns TEXT NOT NULL DEFAULT '',
        checkpoint_id TEXT NOT NULL,
        task_id TEXT NOT NULL,
        idx INTEGER NOT NULL,
        channel TEXT NOT NULL,
        type TEXT,
        value BLOB,
        PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
      );`);

    // Create messages table
    // For new databases, created_at is included
    // For existing databases, the migration will add it if missing
    this.db.exec(`
      CREATE TABLE IF NOT EXISTS messages (
        thread_id TEXT NOT NULL,
        checkpoint_ns TEXT NOT NULL DEFAULT '',
        checkpoint_id TEXT NOT NULL,
        message_id TEXT NOT NULL,
        message_type TEXT NOT NULL,
        message_content TEXT NOT NULL,
        message BLOB,
        created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
        PRIMARY KEY (message_id)
      );
    `);

    // Create schema_migrations table and run migrations
    this.runMigrations();

    // Create indexes (after migrations have run)
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_thread_id
      ON messages(thread_id);
    `);
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_checkpoint_id
      ON messages(checkpoint_id);
    `);
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_lookup
      ON messages(thread_id, checkpoint_ns, checkpoint_id);
    `);
    // Create index on created_at (migration ensures column exists)
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_thread_created
      ON messages(thread_id, created_at);
    `);
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_writes_channel
      ON writes(thread_id, checkpoint_id, channel);
    `);

    this.withoutCheckpoint = prepareSql(this.db, false);
    this.withCheckpoint = prepareSql(this.db, true);

    // Cache prepared statements for write operations
    this.putCheckpointStmt = this.db.prepare(
      `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)`,
    );
    this.putWritesStmt = this.db.prepare(`
      INSERT OR REPLACE INTO writes
      (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    `);
    this.deleteCheckpointsStmt = this.db.prepare(
      `DELETE FROM checkpoints WHERE thread_id = ?`,
    );
    this.deleteWritesStmt = this.db.prepare(
      `DELETE FROM writes WHERE thread_id = ?`,
    );
    this.deleteMessagesStmt = this.db.prepare(
      `DELETE FROM messages WHERE thread_id = ?`,
    );

    this.putMessageStmt = this.db.prepare(
      `INSERT OR REPLACE INTO messages (thread_id, checkpoint_ns, checkpoint_id, message_id, message_type, message_content, message, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
    );

    this.getMessageStmt = this.db.prepare(
      `SELECT * FROM messages WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`,
    );

    this.isSetup = true;
  }

  /**
   * Deserializes a checkpoint row into its checkpoint (with stored messages
   * re-attached and, for checkpoints older than v4, pending sends migrated),
   * its metadata and its pending writes.
   *
   * @param row - Row produced by one of the checkpoint SELECT queries.
   * @param checkpointNs - Namespace the row's messages are stored under.
   */
  protected async loadRowParts(
    row: CheckpointRow,
    checkpointNs: string,
  ): Promise<{
    checkpoint: Checkpoint;
    metadata: CheckpointMetadata;
    pendingWrites: [string, string, unknown][];
  }> {
    const pendingWrites = await Promise.all(
      (JSON.parse(row.pending_writes) as PendingWriteColumn[]).map(
        async (write): Promise<[string, string, unknown]> => [
          write.task_id,
          write.channel,
          await this.serde.loadsTyped(write.type ?? 'json', write.value ?? ''),
        ],
      ),
    );

    const messageRows = this.getMessageStmt.all(
      row.thread_id,
      checkpointNs,
      row.checkpoint_id,
    ) as MessageRow[];
    const messages: BaseMessage[] = await Promise.all(
      messageRows.map((messageRow) =>
        this.serde.loadsTyped('json', messageRow.message),
      ),
    );

    const checkpoint = (await this.serde.loadsTyped(
      row.type ?? 'json',
      row.checkpoint,
    )) as Checkpoint;

    if (messages.length > 0) {
      checkpoint.channel_values.messages = messages;
    }

    if (checkpoint.v < 4 && row.parent_checkpoint_id != null) {
      await this.migratePendingSends(
        checkpoint,
        row.thread_id,
        row.parent_checkpoint_id,
      );
    }

    const metadata = (await this.serde.loadsTyped(
      row.type ?? 'json',
      row.metadata,
    )) as CheckpointMetadata;

    return { checkpoint, metadata, pendingWrites };
  }

  /**
   * Loads one checkpoint tuple.
   *
   * With `config.configurable.checkpoint_id` set, that exact checkpoint is
   * loaded; otherwise the checkpoint with the greatest `checkpoint_id` for the
   * thread/namespace is returned. `checkpoint_ns` defaults to `''`.
   *
   * @returns The tuple, or `undefined` when no matching checkpoint exists.
   * @throws If the resolved config lacks `thread_id` or `checkpoint_id`.
   */
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    this.setup();
    const {
      thread_id,
      checkpoint_ns = '',
      checkpoint_id,
    } = config.configurable ?? {};

    const args = [thread_id, checkpoint_ns];
    if (checkpoint_id) args.push(checkpoint_id);

    const stmt = checkpoint_id ? this.withCheckpoint : this.withoutCheckpoint;
    const row = stmt.get(...args) as CheckpointRow | undefined;
    if (row === undefined) return undefined;

    const finalConfig: RunnableConfig = checkpoint_id
      ? config
      : {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        };

    if (
      finalConfig.configurable?.thread_id === undefined ||
      finalConfig.configurable?.checkpoint_id === undefined
    ) {
      throw new Error('Missing thread_id or checkpoint_id');
    }

    // Use the defaulted namespace: `put` stores messages under '' when no
    // namespace is given, and the raw config value may be undefined here.
    const { checkpoint, metadata, pendingWrites } = await this.loadRowParts(
      row,
      checkpoint_ns,
    );

    return {
      checkpoint,
      config: finalConfig,
      metadata,
      parentConfig: row.parent_checkpoint_id
        ? {
            configurable: {
              thread_id: row.thread_id,
              checkpoint_ns,
              checkpoint_id: row.parent_checkpoint_id,
            },
          }
        : undefined,
      pendingWrites,
    };
  }

  /**
   * Lists checkpoints newest-first, optionally filtered by thread, namespace,
   * an exclusive `before` checkpoint id, and whitelisted metadata keys.
   */
  async *list(
    config: RunnableConfig,
    options?: CheckpointListOptions,
  ): AsyncGenerator<CheckpointTuple> {
    const { limit, before, filter } = options ?? {};
    this.setup();
    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns;
    const beforeCheckpointId = before?.configurable?.checkpoint_id;

    // Each condition is pushed together with its bound value so the number of
    // placeholders and parameters can never get out of step.
    const whereClause: string[] = [];
    const args: unknown[] = [];

    if (thread_id !== undefined && thread_id !== null) {
      whereClause.push('thread_id = ?');
      args.push(thread_id);
    }

    if (checkpoint_ns !== undefined && checkpoint_ns !== null) {
      whereClause.push('checkpoint_ns = ?');
      args.push(checkpoint_ns);
    }

    if (beforeCheckpointId !== undefined && beforeCheckpointId !== null) {
      whereClause.push('checkpoint_id < ?');
      args.push(beforeCheckpointId);
    }

    for (const [key, value] of Object.entries(filter ?? {})) {
      if (
        value === undefined ||
        !validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata)
      ) {
        continue;
      }
      // `key` is one of the whitelisted metadata keys, so interpolating it
      // into the JSON path cannot inject SQL.
      whereClause.push(`jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`);
      args.push(JSON.stringify(value));
    }

    let sql = `
      SELECT
        thread_id,
        checkpoint_ns,
        checkpoint_id,
        parent_checkpoint_id,
        type,
        checkpoint,
        metadata,
        (
          SELECT
            json_group_array(
              json_object(
                'task_id', pw.task_id,
                'channel', pw.channel,
                'type', pw.type,
                'value', CAST(pw.value AS TEXT)
              )
              ORDER BY pw.task_id, pw.idx
            )
          FROM writes as pw
          WHERE pw.thread_id = checkpoints.thread_id
            AND pw.checkpoint_ns = checkpoints.checkpoint_ns
            AND pw.checkpoint_id = checkpoints.checkpoint_id
        ) as pending_writes
      FROM checkpoints\n`;

    if (whereClause.length > 0) {
      sql += `WHERE\n  ${whereClause.join(' AND\n  ')}\n`;
    }

    sql += '\nORDER BY checkpoint_id DESC';

    if (limit) {
      // limit may be user-provided: parseInt keeps only an integer in the SQL
      // text, and a non-numeric value is ignored instead of emitting `LIMIT NaN`.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const parsedLimit = parseInt(limit as any, 10);
      if (Number.isFinite(parsedLimit)) {
        sql += ` LIMIT ${parsedLimit}`;
      }
    }

    const rows = this.db.prepare(sql).all(...args) as CheckpointRow[];

    for (const row of rows) {
      const { checkpoint, metadata, pendingWrites } = await this.loadRowParts(
        row,
        row.checkpoint_ns ?? '',
      );

      yield {
        config: {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns: row.checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        },
        checkpoint,
        metadata,
        parentConfig: row.parent_checkpoint_id
          ? {
              configurable: {
                thread_id: row.thread_id,
                checkpoint_ns: row.checkpoint_ns,
                checkpoint_id: row.parent_checkpoint_id,
              },
            }
          : undefined,
        pendingWrites,
      };
    }
  }

  /**
   * Builds the `messages` row for one message of a checkpoint.
   *
   * NOTE: this rewrites `message.additional_kwargs` in place (cleaned via
   * `cleanAdditionalKwargs`, with `reasoning`/`reasoningDetails` kept only on
   * `ai` messages), so the caller's message objects are modified too.
   */
  protected buildMessageRow(
    message: BaseMessage,
    threadId: string,
    checkpointNs: string,
    checkpointId: string,
    encoder: TextEncoder,
  ): MessageRow {
    const msgFromMatrixRoom = message.additional_kwargs
      .msgFromMatrixRoom as boolean;
    const originalKwargs = message.additional_kwargs;
    const cleanedAdditionalKwargs = cleanAdditionalKwargs(
      message.additional_kwargs,
      msgFromMatrixRoom ?? false,
    );

    message.additional_kwargs = {
      ...cleanedAdditionalKwargs,
      reasoning:
        cleanedAdditionalKwargs.reasoning ?? originalKwargs.reasoning,
      reasoningDetails:
        cleanedAdditionalKwargs.reasoningDetails ??
        originalKwargs.reasoningDetails,
    };

    if (message.type !== 'ai') {
      delete message.additional_kwargs.reasoning;
      delete message.additional_kwargs.reasoningDetails;
    }

    const serializedMessage = encoder.encode(
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      stringify(message, (_: string, value: any) => _default(value)),
    );

    const { content } = message;

    return {
      thread_id: threadId,
      checkpoint_ns: checkpointNs,
      checkpoint_id: checkpointId,
      message_id: message.id ?? message.lc_kwargs?.id,
      message_type: message.type,
      // Array (multimodal) content is stored as JSON instead of the
      // "[object Object]" that toString() would produce.
      message_content:
        typeof content === 'string' ? content : JSON.stringify(content),
      // Bound as a BLOB; the cast only satisfies the row type.
      message: serializedMessage as unknown as string,
      created_at:
        (message.additional_kwargs as CleanAdditionalKwargs)?.timestamp ??
        new Date().toISOString(),
    };
  }

  /**
   * Stores a checkpoint and, in the same transaction, its messages (one row
   * each in `messages`; the serialized checkpoint itself excludes them).
   *
   * @returns Config pointing at the stored checkpoint.
   * @throws If `config.configurable` or its `thread_id` is missing, or if the
   *   checkpoint and metadata serialize to different types.
   */
  async put(
    config: RunnableConfig,
    _checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
  ): Promise<RunnableConfig> {
    this.setup();

    if (!config.configurable) {
      throw new Error('Empty configuration supplied.');
    }

    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns ?? '';
    // The incoming checkpoint_id (if any) is the parent of the new checkpoint.
    const parent_checkpoint_id = config.configurable?.checkpoint_id;

    if (!thread_id) {
      throw new Error(
        `Missing "thread_id" field in passed "config.configurable".`,
      );
    }

    const { checkpoint, messages } = removeMessagesFromCheckpoint(_checkpoint);

    const [[type1, serializedCheckpoint], [type2, serializedMetadata]] =
      await Promise.all([
        this.serde.dumpsTyped(checkpoint),
        this.serde.dumpsTyped(metadata),
      ]);

    if (type1 !== type2) {
      throw new Error(
        'Failed to serialized checkpoint and metadata to the same type.',
      );
    }

    const row = [
      thread_id,
      checkpoint_ns,
      checkpoint.id,
      parent_checkpoint_id,
      type1,
      serializedCheckpoint,
      serializedMetadata,
    ];
    const encoder = new TextEncoder();

    const transaction = this.db.transaction(() => {
      this.putCheckpointStmt.run(...row);
      if (messages) {
        for (const message of messages) {
          const messageRow = this.buildMessageRow(
            message,
            thread_id,
            checkpoint_ns,
            checkpoint.id,
            encoder,
          );
          this.putMessageStmt.run(
            messageRow.thread_id,
            messageRow.checkpoint_ns,
            messageRow.checkpoint_id,
            messageRow.message_id,
            messageRow.message_type,
            messageRow.message_content,
            messageRow.message,
            messageRow.created_at,
          );
        }
      }
    });
    transaction();

    return {
      configurable: {
        thread_id,
        checkpoint_ns,
        checkpoint_id: checkpoint.id,
      },
    };
  }

  /**
   * Stores the pending writes of task `taskId` against the checkpoint in
   * `config`. When `thread_id` is missing it is recovered from the
   * `checkpoints` row and written back into `config.configurable`.
   *
   * @throws If the config is empty, the thread cannot be resolved, or
   *   `checkpoint_id` is missing.
   */
  async putWrites(
    config: RunnableConfig,
    writes: PendingWrite[],
    taskId: string,
  ): Promise<void> {
    this.setup();

    if (!config.configurable) {
      throw new Error('Empty configuration supplied.');
    }

    if (!config.configurable?.thread_id) {
      console.error('Missing thread_id field in config.configurable.', {
        configurable: config.configurable,
      });

      // get thread id using the checkpoint_id
      const threadId = this.db
        .prepare('SELECT thread_id FROM checkpoints WHERE checkpoint_id = ?')
        .get(config.configurable?.checkpoint_id) as
        | { thread_id: string }
        | undefined;
      if (!threadId) {
        throw new Error(
          'Missing thread_id field in config.configurable. config: ' +
            JSON.stringify(config.configurable),
        );
      }
      config.configurable.thread_id = threadId.thread_id;
    }

    if (!config.configurable?.checkpoint_id) {
      console.error('Missing checkpoint_id field in config.configurable.', {
        configurable: config.configurable,
      });
      throw new Error(
        'Missing checkpoint_id field in config.configurable. config: ' +
          JSON.stringify(config.configurable),
      );
    }

    const { thread_id, checkpoint_id } = config.configurable;
    // Default the namespace the same way `put` does.
    const checkpoint_ns = config.configurable.checkpoint_ns ?? '';

    const rows = await Promise.all(
      writes.map(async ([channel, value], idx) => {
        const [type, serializedWrite] = await this.serde.dumpsTyped(value);
        return [
          thread_id,
          checkpoint_ns,
          checkpoint_id,
          taskId,
          idx,
          channel,
          type,
          serializedWrite,
        ];
      }),
    );

    const transaction = this.db.transaction((batch: unknown[][]) => {
      for (const row of batch) {
        this.putWritesStmt.run(...row);
      }
    });
    transaction(rows);
  }

  /**
   * Deletes every checkpoint, pending write and stored message of a thread
   * in a single transaction.
   */
  async deleteThread(threadId: string) {
    // The delete statements are prepared lazily, so make sure they exist.
    this.setup();

    const transaction = this.db.transaction(() => {
      this.deleteCheckpointsStmt.run(threadId);
      this.deleteWritesStmt.run(threadId);
      this.deleteMessagesStmt.run(threadId);
    });

    transaction();
  }

  /**
   * Upgrades a pre-v4 checkpoint in place: copies the parent checkpoint's
   * `TASKS`-channel writes (ordered by `idx`) into `channel_values[TASKS]` and
   * assigns that channel a version.
   */
  protected async migratePendingSends(
    checkpoint: Checkpoint,
    threadId: string,
    parentCheckpointId: string,
  ) {
    const { pending_sends } = this.db
      .prepare(
        `
        SELECT
          json_group_array(
            json_object(
              'type', ps.type,
              'value', CAST(ps.value AS TEXT)
            )
            ORDER BY ps.idx
          ) as pending_sends
        FROM writes as ps
        WHERE ps.thread_id = ?
          AND ps.checkpoint_id = ?
          AND ps.channel = '${TASKS}'
        `,
      )
      .get(threadId, parentCheckpointId) as { pending_sends: string };

    const mutableCheckpoint = checkpoint;

    // add pending sends to checkpoint
    mutableCheckpoint.channel_values ??= {};
    mutableCheckpoint.channel_values[TASKS] = await Promise.all(
      (JSON.parse(pending_sends) as PendingSendColumn[]).map(
        ({ type, value }) => this.serde.loadsTyped(type, value),
      ),
    );

    // add to versions
    mutableCheckpoint.channel_versions[TASKS] =
      Object.keys(checkpoint.channel_versions).length > 0
        ? maxChannelVersion(...Object.values(checkpoint.channel_versions))
        : this.getNextVersion(undefined);
  }
}
904
+
905
/**
 * Type guard: true when the checkpoint's `channel_values` has a `messages`
 * key (checked with `in`). The value's shape is not inspected.
 */
function isCheckpointWithMessages(
  checkpoint: Checkpoint,
): checkpoint is CheckpointWithMessages {
  const { channel_values: channelValues } = checkpoint;
  return 'messages' in channelValues;
}
910
+
911
/**
 * Splits chat messages out of a checkpoint so they can be stored separately.
 *
 * When `channel_values` has a `messages` key, returns a copy of the
 * checkpoint (via `copyCheckpoint`) without that key, plus the original
 * messages array. Otherwise returns the checkpoint unchanged and
 * `messages: undefined`. The input checkpoint's `channel_values` is left as is.
 */
function removeMessagesFromCheckpoint(checkpoint: Checkpoint): {
  checkpoint: Checkpoint;
  messages?: BaseMessage[];
} {
  if (!isCheckpointWithMessages(checkpoint)) {
    return { checkpoint, messages: undefined };
  }

  const stripped = copyCheckpoint(checkpoint);
  delete stripped.channel_values.messages;
  const messages = checkpoint.channel_values.messages;

  return { checkpoint: stripped, messages };
}