@ixo/sqlite-saver 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.js +9 -0
- package/.prettierignore +3 -0
- package/.prettierrc.js +4 -0
- package/.turbo/turbo-build.log +4 -0
- package/CHANGELOG.md +25 -0
- package/README.md +0 -0
- package/dist/index.d.ts +38 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +567 -0
- package/dist/index.js.map +1 -0
- package/dist/migrations/001_add_created_at_to_messages.d.ts +8 -0
- package/dist/migrations/001_add_created_at_to_messages.d.ts.map +1 -0
- package/dist/migrations/001_add_created_at_to_messages.js +32 -0
- package/dist/migrations/001_add_created_at_to_messages.js.map +1 -0
- package/dist/tests/agent-with-checkpoiner.test.d.ts +2 -0
- package/dist/tests/agent-with-checkpoiner.test.d.ts.map +1 -0
- package/dist/tests/agent-with-checkpoiner.test.js +206 -0
- package/dist/tests/agent-with-checkpoiner.test.js.map +1 -0
- package/dist/tests/checkpointer.test.d.ts +2 -0
- package/dist/tests/checkpointer.test.d.ts.map +1 -0
- package/dist/tests/checkpointer.test.js +426 -0
- package/dist/tests/checkpointer.test.js.map +1 -0
- package/dist/utils.d.ts +15 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +284 -0
- package/dist/utils.js.map +1 -0
- package/jest.config.js +6 -0
- package/package.json +41 -0
- package/src/index.ts +929 -0
- package/src/migrations/001_add_created_at_to_messages.ts +48 -0
- package/src/tests/agent-with-checkpoiner.test.ts +264 -0
- package/src/tests/checkpointer.test.ts +628 -0
- package/src/utils.ts +358 -0
- package/tsconfig.json +11 -0
- package/tsconfig.tsbuildinfo +1 -0
package/src/index.ts
ADDED
|
@@ -0,0 +1,929 @@
|
|
|
1
|
+
import type { RunnableConfig } from '@langchain/core/runnables';
|
|
2
|
+
import {
|
|
3
|
+
BaseCheckpointSaver,
|
|
4
|
+
type Checkpoint,
|
|
5
|
+
type CheckpointListOptions,
|
|
6
|
+
type CheckpointMetadata,
|
|
7
|
+
type CheckpointTuple,
|
|
8
|
+
copyCheckpoint,
|
|
9
|
+
maxChannelVersion,
|
|
10
|
+
type PendingWrite,
|
|
11
|
+
type SerializerProtocol,
|
|
12
|
+
TASKS,
|
|
13
|
+
} from '@langchain/langgraph-checkpoint';
|
|
14
|
+
import Database, { Database as DatabaseType, Statement } from 'better-sqlite3';
|
|
15
|
+
import { BaseMessage } from 'langchain';
|
|
16
|
+
import migration001 from './migrations/001_add_created_at_to_messages';
|
|
17
|
+
import {
|
|
18
|
+
_default,
|
|
19
|
+
CleanAdditionalKwargs,
|
|
20
|
+
cleanAdditionalKwargs,
|
|
21
|
+
stringify,
|
|
22
|
+
} from './utils';
|
|
23
|
+
|
|
24
|
+
/**
 * Shape of a checkpoint's `channel_values`: one entry per channel name in `C`,
 * with values left untyped.
 */
type ChannelValues<C extends string = string> = Record<C, unknown>;
|
|
27
|
+
|
|
28
|
+
/**
 * A `Checkpoint` whose `channel_values` may carry a `messages` array.
 * This is the narrowed type produced by the `isCheckpointWithMessages` guard.
 */
interface CheckpointWithMessages<N extends string = string, C extends string = string>
  extends Checkpoint<N, C> {
  channel_values: ChannelValues<C> & { messages?: BaseMessage[] };
}
|
|
36
|
+
|
|
37
|
+
/**
 * One row of the `messages` table created in `SqliteSaver.setup`.
 */
interface MessageRow {
  // Identity and kind of the message (`message_id` is the table's primary key).
  message_id: string;
  message_type: string;
  // Checkpoint the row is currently attached to.
  thread_id: string;
  checkpoint_ns: string;
  checkpoint_id: string;
  // Text form of the content; not used when messages are rebuilt on read.
  message_content: string;
  // Serialized message payload. `put` writes UTF-8 encoded bytes here even
  // though the field is typed as `string`.
  message: string;
  // Message `timestamp` kwarg, or the insertion time as an ISO-8601 string.
  created_at: string;
}
|
|
47
|
+
|
|
48
|
+
/**
 * Row shape returned by the checkpoint SELECT statements (`prepareSql` and
 * `SqliteSaver.list`).
 */
interface CheckpointRow {
  thread_id: string;
  checkpoint_ns?: string;
  checkpoint_id: string;
  parent_checkpoint_id?: string;
  // Serializer type tag; `put` requires checkpoint and metadata to share it.
  type?: string;
  checkpoint: string;
  metadata: string;
  // JSON array (from json_group_array) of this checkpoint's `writes` rows.
  pending_writes: string;
}
|
|
58
|
+
|
|
59
|
+
/** One element of the `pending_writes` JSON array built by the checkpoint queries. */
interface PendingWriteColumn {
  channel: string;
  task_id: string;
  // Serializer type tag and serialized value (CAST to TEXT in SQL).
  value: string;
  type: string;
}
|
|
65
|
+
|
|
66
|
+
/** One element of a `pending_sends` JSON array: a serialized TASKS-channel write. */
interface PendingSendColumn {
  value: string;
  type: string;
}
|
|
70
|
+
|
|
71
|
+
/**
 * A schema migration run by `SqliteSaver.runMigrations`. After `up` returns,
 * its `version` and `name` are recorded in `schema_migrations`.
 */
interface Migration {
  name: string;
  version: number;
  up: (db: DatabaseType) => void;
}
|
|
76
|
+
|
|
77
|
+
// `SqliteSaver.list` only accepts `options.filter` keys that belong to
// `CheckpointMetadata`; this is that allow-list. `validCheckpointMetadataKeys`
// (further down) turns any drift between this list and the type into a
// compile-time error.
const checkpointMetadataKeys = [
  'source',
  'step',
  'parents',
] as const;
|
|
81
|
+
|
|
82
|
+
/**
 * Resolves to `K` when the key list `K` covers every key of `T`, otherwise to
 * `never`. `K` can never name keys outside `T` because of its constraint, so
 * only missing keys need to be detected.
 */
type CheckKeys<T, K extends readonly (keyof T)[]> = [
  Exclude<keyof T, K[number]>,
] extends [never]
  ? K
  : never;
|
|
89
|
+
|
|
90
|
+
/**
 * Identity function whose parameter only accepts key lists covering every key
 * of `T` (see `CheckKeys`). It exists purely for that compile-time check and
 * returns its argument unchanged.
 */
function validateKeys<T, K extends readonly (keyof T)[]>(
  keys: CheckKeys<T, K>,
): K {
  const checkedKeys: K = keys;
  return checkedKeys;
}
|
|
95
|
+
|
|
96
|
+
// Compile-time guard: if this declaration stops compiling, `checkpointMetadataKeys`
// no longer lists exactly the keys of `CheckpointMetadata`, so update that list.
// At runtime this is the same array; `SqliteSaver.list` uses it to sanitize
// `options.filter`.
const validCheckpointMetadataKeys = validateKeys<CheckpointMetadata, typeof checkpointMetadataKeys>(checkpointMetadataKeys);
|
|
103
|
+
|
|
104
|
+
/**
 * Prepare the single-checkpoint SELECT used by `SqliteSaver.getTuple`.
 *
 * Positional parameters: `thread_id`, `checkpoint_ns`, and then `checkpoint_id`
 * only when `checkpointId` is true. When it is false, the statement instead picks
 * the row with the greatest `checkpoint_id` for the thread/namespace.
 *
 * Two extra JSON-array columns are computed with `json_group_array`:
 * - `pending_writes`: every `writes` row recorded against this checkpoint;
 * - `pending_sends`: `TASKS`-channel writes recorded against the parent checkpoint.
 */
function prepareSql(db: DatabaseType, checkpointId: boolean) {
  const rowSelector = checkpointId
    ? 'AND checkpoint_id = ?'
    : 'ORDER BY checkpoint_id DESC LIMIT 1';

  const pendingWritesColumn = `(
      SELECT
        json_group_array(
          json_object(
            'task_id', pw.task_id,
            'channel', pw.channel,
            'type', pw.type,
            'value', CAST(pw.value AS TEXT)
          )
        )
      FROM writes as pw
      WHERE pw.thread_id = checkpoints.thread_id
        AND pw.checkpoint_ns = checkpoints.checkpoint_ns
        AND pw.checkpoint_id = checkpoints.checkpoint_id
    ) as pending_writes`;

  // NOTE(review): the ORDER BY below is outside the aggregate, so it does not
  // control element order inside json_group_array. `getTuple` does not read
  // `pending_sends` from this row.
  const pendingSendsColumn = `(
      SELECT
        json_group_array(
          json_object(
            'type', ps.type,
            'value', CAST(ps.value AS TEXT)
          )
        )
      FROM writes as ps
      WHERE ps.thread_id = checkpoints.thread_id
        AND ps.checkpoint_ns = checkpoints.checkpoint_ns
        AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
        AND ps.channel = '${TASKS}'
      ORDER BY ps.idx
    ) as pending_sends`;

  const columns = [
    'thread_id',
    'checkpoint_ns',
    'checkpoint_id',
    'parent_checkpoint_id',
    'type',
    'checkpoint',
    'metadata',
    pendingWritesColumn,
    pendingSendsColumn,
  ].join(',\n    ');

  return db.prepare(
    `SELECT\n    ${columns}\n  FROM checkpoints\n  WHERE thread_id = ? AND checkpoint_ns = ? ${rowSelector}`,
  );
}
|
|
153
|
+
|
|
154
|
+
/**
 * LangGraph checkpoint saver backed by a better-sqlite3 database.
 *
 * Messages found in `checkpoint.channel_values.messages` are not stored inside
 * the serialized checkpoint. `put` writes each one as a row of the `messages`
 * table, and `getTuple`/`list` re-attach the stored rows on read.
 *
 * `setup()` lazily creates tables, runs migrations, creates indexes and caches
 * prepared statements. Every data method calls it first.
 *
 * NOTE(review): `messages` is keyed by `message_id` alone and written with
 * INSERT OR REPLACE. A later `put` that contains the same message therefore
 * moves that row to the newer checkpoint, and older checkpoints no longer get
 * it back on read.
 */
export class SqliteSaver extends BaseCheckpointSaver {
  db: DatabaseType;

  protected isSetup: boolean;

  /** Newest checkpoint for (thread_id, checkpoint_ns). */
  protected withoutCheckpoint: Statement;

  /** Exact checkpoint for (thread_id, checkpoint_ns, checkpoint_id). */
  protected withCheckpoint: Statement;

  protected putCheckpointStmt: Statement;

  protected putWritesStmt: Statement;

  protected deleteCheckpointsStmt: Statement;

  protected putMessageStmt: Statement;

  protected getMessageStmt: Statement;

  protected deleteWritesStmt: Statement;

  protected deleteMessagesStmt: Statement;

  constructor(db: DatabaseType, serde?: SerializerProtocol) {
    super(serde);
    this.db = db;
    this.isSetup = false;
  }

  /** Open a better-sqlite3 database at the given path and wrap it in a saver. */
  static fromConnString(connStringOrLocalPath: string): SqliteSaver {
    return new SqliteSaver(new Database(connStringOrLocalPath));
  }

  /** Wrap an already-open better-sqlite3 database. */
  static fromDatabase(db: DatabaseType, serde?: SerializerProtocol): SqliteSaver {
    return new SqliteSaver(db, serde);
  }

  /** Close the underlying database if it is still open. */
  close(): void {
    if (this.db.open) {
      this.db.close();
    }
  }

  /**
   * Create schema_migrations table to track applied migrations
   */
  protected createSchemaMigrationsTable(): void {
    this.db.exec(`
      CREATE TABLE IF NOT EXISTS schema_migrations (
        version INTEGER NOT NULL PRIMARY KEY,
        name TEXT NOT NULL,
        applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
      );
    `);
  }

  /**
   * Get list of applied migration versions (empty if the table cannot be read).
   */
  protected getAppliedMigrations(): number[] {
    try {
      const rows = this.db
        .prepare('SELECT version FROM schema_migrations ORDER BY version')
        .all() as Array<{ version: number }>;
      return rows.map((row) => row.version);
    } catch {
      // Table doesn't exist yet, return empty array
      return [];
    }
  }

  /**
   * Record that a migration has been applied
   */
  protected recordMigration(migration: Migration): void {
    this.db
      .prepare(
        'INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, CURRENT_TIMESTAMP)',
      )
      .run(migration.version, migration.name);
  }

  /**
   * Load all migration files from the migrations folder
   *
   * Note: When adding new migrations:
   * 1. Import them at the top of this file
   * 2. Add them to the migrations array below
   */
  protected loadMigrations(): Migration[] {
    const migrations: Migration[] = [
      migration001,
      // Add future migrations here:
      // migration002,
      // migration003,
    ];

    // Sort by version to ensure correct order
    migrations.sort((a, b) => a.version - b.version);

    return migrations;
  }

  /**
   * Apply, in version order, every migration not yet recorded in
   * `schema_migrations`. A failing migration is logged and rethrown.
   */
  protected runMigrations(): void {
    this.createSchemaMigrationsTable();

    const applied = new Set(this.getAppliedMigrations());
    const pendingMigrations = this.loadMigrations().filter(
      (migration) => !applied.has(migration.version),
    );

    if (pendingMigrations.length === 0) {
      return;
    }

    console.log(`Running ${pendingMigrations.length} pending migration(s)...`);

    for (const migration of pendingMigrations) {
      try {
        console.log(
          `Applying migration ${migration.version}: ${migration.name}`,
        );
        migration.up(this.db);
        this.recordMigration(migration);
        console.log(
          `✓ Migration ${migration.version}: ${migration.name} applied successfully`,
        );
      } catch (error) {
        console.error(
          `✗ Failed to apply migration ${migration.version}: ${migration.name}`,
          error,
        );
        throw error;
      }
    }
  }

  /**
   * One-time initialisation: pragmas, tables, migrations, indexes and cached
   * statements. Subsequent calls return immediately.
   */
  protected setup(): void {
    if (this.isSetup) {
      return;
    }

    // Enable WAL mode for concurrent read/write support and set busy timeout
    this.db.pragma('journal_mode = WAL');
    this.db.pragma('busy_timeout = 5000');

    // Create base tables
    this.db.exec(`
CREATE TABLE IF NOT EXISTS checkpoints (
  thread_id TEXT NOT NULL,
  checkpoint_ns TEXT NOT NULL DEFAULT '',
  checkpoint_id TEXT NOT NULL,
  parent_checkpoint_id TEXT,
  type TEXT,
  checkpoint BLOB,
  metadata BLOB,
  PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);`);
    this.db.exec(`
CREATE TABLE IF NOT EXISTS writes (
  thread_id TEXT NOT NULL,
  checkpoint_ns TEXT NOT NULL DEFAULT '',
  checkpoint_id TEXT NOT NULL,
  task_id TEXT NOT NULL,
  idx INTEGER NOT NULL,
  channel TEXT NOT NULL,
  type TEXT,
  value BLOB,
  PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);`);

    // Create messages table
    // For new databases, created_at is included
    // For existing databases, the migration will add it if missing
    this.db.exec(`
      CREATE TABLE IF NOT EXISTS messages (
        thread_id TEXT NOT NULL,
        checkpoint_ns TEXT NOT NULL DEFAULT '',
        checkpoint_id TEXT NOT NULL,
        message_id TEXT NOT NULL,
        message_type TEXT NOT NULL,
        message_content TEXT NOT NULL,
        message BLOB,
        created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
        PRIMARY KEY (message_id)
      );
    `);

    // Create schema_migrations table and run migrations
    this.runMigrations();

    // Create indexes (after migrations have run)
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_thread_id
      ON messages(thread_id);
    `);
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_checkpoint_id
      ON messages(checkpoint_id);
    `);

    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_lookup
      ON messages(thread_id, checkpoint_ns, checkpoint_id);
    `);

    // Create index on created_at (migration ensures column exists)
    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_messages_thread_created
      ON messages(thread_id, created_at);
    `);

    this.db.exec(`
      CREATE INDEX IF NOT EXISTS idx_writes_channel
      ON writes(thread_id, checkpoint_id, channel);
    `);

    this.withoutCheckpoint = prepareSql(this.db, false);
    this.withCheckpoint = prepareSql(this.db, true);

    // Cache prepared statements for write operations
    this.putCheckpointStmt = this.db.prepare(
      `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)`,
    );
    this.putWritesStmt = this.db.prepare(`
      INSERT OR REPLACE INTO writes
      (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    `);
    this.deleteCheckpointsStmt = this.db.prepare(
      `DELETE FROM checkpoints WHERE thread_id = ?`,
    );
    this.deleteWritesStmt = this.db.prepare(
      `DELETE FROM writes WHERE thread_id = ?`,
    );
    this.deleteMessagesStmt = this.db.prepare(
      `DELETE FROM messages WHERE thread_id = ?`,
    );

    this.putMessageStmt = this.db.prepare(
      `INSERT OR REPLACE INTO messages (thread_id, checkpoint_ns, checkpoint_id, message_id, message_type, message_content, message, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
    );

    // ORDER BY rowid makes the read order explicit. Each put re-inserts its
    // messages in array order, so rowid order is the original message order.
    this.getMessageStmt = this.db.prepare(
      `SELECT * FROM messages WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY rowid`,
    );

    this.isSetup = true;
  }

  /**
   * Load one checkpoint tuple.
   *
   * When `configurable.checkpoint_id` is set, that exact checkpoint is loaded.
   * Otherwise the newest checkpoint (greatest `checkpoint_id`) of
   * `thread_id`/`checkpoint_ns` is used. `checkpoint_ns` defaults to `''`.
   * Stored messages are re-attached under `channel_values.messages`.
   *
   * @returns the tuple, or `undefined` when no matching checkpoint exists.
   * @throws Error when the resolved config lacks `thread_id` or `checkpoint_id`.
   */
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    this.setup();
    const {
      thread_id,
      checkpoint_ns = '',
      checkpoint_id,
    } = config.configurable ?? {};

    const row = (
      checkpoint_id
        ? this.withCheckpoint.get(thread_id, checkpoint_ns, checkpoint_id)
        : this.withoutCheckpoint.get(thread_id, checkpoint_ns)
    ) as CheckpointRow | undefined;
    if (row === undefined) return undefined;

    const finalConfig: RunnableConfig = checkpoint_id
      ? config
      : {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        };

    if (
      finalConfig.configurable?.thread_id === undefined ||
      finalConfig.configurable?.checkpoint_id === undefined
    ) {
      throw new Error('Missing thread_id or checkpoint_id');
    }

    // Use the defaulted namespace (the one the row was matched with) for the
    // message lookup. The raw config value may be undefined, which would bind
    // NULL and silently drop every message.
    return this.buildTuple(row, finalConfig, checkpoint_ns);
  }

  /**
   * Yield checkpoint tuples newest-first.
   *
   * Optional narrowing: `configurable.thread_id`, `configurable.checkpoint_ns`,
   * `options.before` (strictly older `checkpoint_id`), `options.filter`
   * (metadata equality, restricted to `CheckpointMetadata` keys) and
   * `options.limit`.
   */
  async *list(
    config: RunnableConfig,
    options?: CheckpointListOptions,
  ): AsyncGenerator<CheckpointTuple> {
    const { limit, before, filter } = options ?? {};
    this.setup();
    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns;
    const beforeCheckpointId = before?.configurable?.checkpoint_id;

    let sql = `
      SELECT
        thread_id,
        checkpoint_ns,
        checkpoint_id,
        parent_checkpoint_id,
        type,
        checkpoint,
        metadata,
        (
          SELECT
            json_group_array(
              json_object(
                'task_id', pw.task_id,
                'channel', pw.channel,
                'type', pw.type,
                'value', CAST(pw.value AS TEXT)
              )
            )
          FROM writes as pw
          WHERE pw.thread_id = checkpoints.thread_id
            AND pw.checkpoint_ns = checkpoints.checkpoint_ns
            AND pw.checkpoint_id = checkpoints.checkpoint_id
        ) as pending_writes,
        (
          SELECT
            json_group_array(
              json_object(
                'type', ps.type,
                'value', CAST(ps.value AS TEXT)
              )
            )
          FROM writes as ps
          WHERE ps.thread_id = checkpoints.thread_id
            AND ps.checkpoint_ns = checkpoints.checkpoint_ns
            AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
            AND ps.channel = '${TASKS}'
          ORDER BY ps.idx
        ) as pending_sends
      FROM checkpoints\n`;

    // Each condition is pushed together with its bound value, so the
    // placeholder count always matches the argument count.
    const conditions: string[] = [];
    const args: unknown[] = [];

    if (thread_id) {
      conditions.push('thread_id = ?');
      args.push(thread_id);
    }

    if (checkpoint_ns !== undefined && checkpoint_ns !== null) {
      conditions.push('checkpoint_ns = ?');
      args.push(checkpoint_ns);
    }

    if (beforeCheckpointId !== undefined && beforeCheckpointId !== null) {
      conditions.push('checkpoint_id < ?');
      args.push(beforeCheckpointId);
    }

    for (const [key, value] of Object.entries(filter ?? {})) {
      if (
        value === undefined ||
        !validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata)
      ) {
        continue;
      }
      // `key` is from the fixed allow-list, so interpolating it is safe.
      conditions.push(`jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`);
      args.push(JSON.stringify(value));
    }

    if (conditions.length > 0) {
      sql += `WHERE\n  ${conditions.join(' AND\n  ')}\n`;
    }

    sql += '\nORDER BY checkpoint_id DESC';

    if (limit) {
      // parseInt sanitizes the value before it is interpolated (limit may be user-provided).
      sql += ` LIMIT ${parseInt(String(limit), 10)}`;
    }

    const rows = this.db.prepare(sql).all(...args) as CheckpointRow[];

    for (const row of rows) {
      yield await this.buildTuple(
        row,
        {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns: row.checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        },
        row.checkpoint_ns ?? '',
      );
    }
  }

  /**
   * Store a checkpoint. Its `channel_values.messages` are removed from the
   * serialized blob and written, one row each, to the `messages` table in the
   * same transaction.
   *
   * NOTE: message objects are normalised in place (`additional_kwargs` is
   * replaced on the caller's message instances) before being serialized.
   *
   * @returns a config pointing at the stored checkpoint.
   * @throws Error when `config.configurable` or its `thread_id` is missing, or
   *   when checkpoint and metadata serialize to different types.
   */
  async put(
    config: RunnableConfig,
    _checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
  ): Promise<RunnableConfig> {
    this.setup();

    if (!config.configurable) {
      throw new Error('Empty configuration supplied.');
    }

    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns ?? '';
    const parent_checkpoint_id = config.configurable?.checkpoint_id;

    if (!thread_id) {
      throw new Error(
        `Missing "thread_id" field in passed "config.configurable".`,
      );
    }

    const { checkpoint, messages } = removeMessagesFromCheckpoint(_checkpoint);

    const [[type1, serializedCheckpoint], [type2, serializedMetadata]] =
      await Promise.all([
        this.serde.dumpsTyped(checkpoint),
        this.serde.dumpsTyped(metadata),
      ]);

    if (type1 !== type2) {
      throw new Error(
        'Failed to serialized checkpoint and metadata to the same type.',
      );
    }

    const encoder = new TextEncoder();

    const transaction = this.db.transaction(() => {
      this.putCheckpointStmt.run(
        thread_id,
        checkpoint_ns,
        checkpoint.id,
        parent_checkpoint_id,
        type1,
        serializedCheckpoint,
        serializedMetadata,
      );

      for (const message of messages ?? []) {
        const serializedMessage = this.encodeMessage(message, encoder);

        // Array (multimodal) content would stringify to "[object Object]",
        // so store it as JSON instead.
        const messageContent =
          typeof message.content === 'string'
            ? message.content
            : JSON.stringify(message.content);

        // Read after encodeMessage: the timestamp comes from the cleaned kwargs.
        const createdAt =
          (message.additional_kwargs as CleanAdditionalKwargs)?.timestamp ??
          new Date().toISOString();

        this.putMessageStmt.run(
          thread_id,
          checkpoint_ns,
          checkpoint.id,
          message.id ?? message.lc_kwargs?.id,
          message.type,
          messageContent,
          serializedMessage,
          createdAt,
        );
      }
    });
    transaction();

    return {
      configurable: {
        thread_id,
        checkpoint_ns,
        checkpoint_id: checkpoint.id,
      },
    };
  }

  /**
   * Store intermediate writes produced by `taskId` for the checkpoint in
   * `config`. Each write's index in `writes` becomes its `idx`.
   *
   * If `thread_id` is missing, it is looked up from the checkpoint row by
   * `checkpoint_id` and written back into `config.configurable`.
   *
   * @throws Error when the configuration, thread id or checkpoint id cannot be resolved.
   */
  async putWrites(
    config: RunnableConfig,
    writes: PendingWrite[],
    taskId: string,
  ): Promise<void> {
    this.setup();

    const configurable = config.configurable;
    if (!configurable) {
      throw new Error('Empty configuration supplied.');
    }

    if (!configurable.thread_id) {
      console.error('Missing thread_id field in config.configurable.', {
        configurable: config.configurable,
      });

      // get thread id using the checkpoint_id
      const owner = this.db
        .prepare('SELECT thread_id FROM checkpoints WHERE checkpoint_id = ?')
        .get(configurable.checkpoint_id) as { thread_id: string } | undefined;
      if (!owner) {
        throw new Error(
          'Missing thread_id field in config.configurable. config: ' +
            JSON.stringify(config.configurable),
        );
      }
      configurable.thread_id = owner.thread_id;
    }

    if (!configurable.checkpoint_id) {
      console.error('Missing checkpoint_id field in config.configurable.', {
        configurable: config.configurable,
      });
      throw new Error(
        'Missing checkpoint_id field in config.configurable. config: ' +
          JSON.stringify(config.configurable),
      );
    }

    const threadId: string = configurable.thread_id;
    // Explicit '' matches the column default that INSERT OR REPLACE would apply to NULL.
    const checkpointNs: string = configurable.checkpoint_ns ?? '';
    const checkpointId: string = configurable.checkpoint_id;

    const rows = await Promise.all(
      writes.map(async ([channel, value], idx) => {
        const [type, serializedWrite] = await this.serde.dumpsTyped(value);
        return [
          threadId,
          checkpointNs,
          checkpointId,
          taskId,
          idx,
          channel,
          type,
          serializedWrite,
        ];
      }),
    );

    const insertAll = this.db.transaction((batch: unknown[][]) => {
      for (const params of batch) {
        this.putWritesStmt.run(...params);
      }
    });
    insertAll(rows);
  }

  /**
   * Delete every checkpoint, write and stored message of `threadId` in a
   * single transaction.
   */
  async deleteThread(threadId: string): Promise<void> {
    // Statements are created lazily; without this a fresh saver would crash here.
    this.setup();

    const transaction = this.db.transaction(() => {
      this.deleteCheckpointsStmt.run(threadId);
      this.deleteWritesStmt.run(threadId);
      this.deleteMessagesStmt.run(threadId);
    });

    transaction();
  }

  /**
   * Upgrade a pre-v4 checkpoint in place. TASKS-channel writes recorded
   * against its parent checkpoint are copied into `channel_values[TASKS]`,
   * and `channel_versions[TASKS]` is set to the current max version (or the
   * first version when there are none).
   */
  protected async migratePendingSends(
    checkpoint: Checkpoint,
    threadId: string,
    parentCheckpointId: string,
  ) {
    // The ORDER BY lives in the inner query; an ORDER BY on the aggregate
    // query itself does not order json_group_array's input.
    const { pending_sends } = this.db
      .prepare(
        `
        SELECT
          json_group_array(
            json_object(
              'type', ps.type,
              'value', CAST(ps.value AS TEXT)
            )
          ) as pending_sends
        FROM (
          SELECT type, value
          FROM writes
          WHERE thread_id = ?
            AND checkpoint_id = ?
            AND channel = '${TASKS}'
          ORDER BY idx
        ) as ps
      `,
      )
      .get(threadId, parentCheckpointId) as { pending_sends: string };

    const mutableCheckpoint = checkpoint;

    // add pending sends to checkpoint
    mutableCheckpoint.channel_values ??= {};
    mutableCheckpoint.channel_values[TASKS] = await Promise.all(
      JSON.parse(pending_sends).map(({ type, value }: PendingSendColumn) =>
        this.serde.loadsTyped(type, value),
      ),
    );

    // add to versions
    mutableCheckpoint.channel_versions[TASKS] =
      Object.keys(checkpoint.channel_versions).length > 0
        ? maxChannelVersion(...Object.values(checkpoint.channel_versions))
        : this.getNextVersion(undefined);
  }

  /** Deserialize a `pending_writes` JSON column into `[taskId, channel, value]` tuples. */
  private async parsePendingWrites(
    pendingWritesJson: string,
  ): Promise<[string, string, unknown][]> {
    const writes = JSON.parse(pendingWritesJson) as PendingWriteColumn[];
    return Promise.all(
      writes.map(
        async (write): Promise<[string, string, unknown]> => [
          write.task_id,
          write.channel,
          await this.serde.loadsTyped(write.type ?? 'json', write.value ?? ''),
        ],
      ),
    );
  }

  /** Load and deserialize the messages stored for one checkpoint, in insertion order. */
  private async loadMessages(
    threadId: string,
    checkpointNs: string,
    checkpointId: string,
  ): Promise<BaseMessage[]> {
    const rows = this.getMessageStmt.all(
      threadId,
      checkpointNs,
      checkpointId,
    ) as MessageRow[];
    return Promise.all(
      rows.map(async (row) => this.serde.loadsTyped('json', row.message)),
    );
  }

  /**
   * Turn a checkpoint row into a `CheckpointTuple`: deserialize the checkpoint
   * and metadata, re-attach stored messages, migrate legacy pending sends, and
   * build the parent config.
   */
  private async buildTuple(
    row: CheckpointRow,
    config: RunnableConfig,
    checkpointNs: string,
  ): Promise<CheckpointTuple> {
    const serializationType = row.type ?? 'json';

    const pendingWrites = await this.parsePendingWrites(row.pending_writes);
    const messages = await this.loadMessages(
      row.thread_id,
      checkpointNs,
      row.checkpoint_id,
    );

    const checkpoint = (await this.serde.loadsTyped(
      serializationType,
      row.checkpoint,
    )) as Checkpoint;

    if (messages.length > 0) {
      checkpoint.channel_values.messages = messages;
    }

    if (checkpoint.v < 4 && row.parent_checkpoint_id != null) {
      await this.migratePendingSends(
        checkpoint,
        row.thread_id,
        row.parent_checkpoint_id,
      );
    }

    const metadata = (await this.serde.loadsTyped(
      serializationType,
      row.metadata,
    )) as CheckpointMetadata;

    return {
      config,
      checkpoint,
      metadata,
      parentConfig: row.parent_checkpoint_id
        ? {
            configurable: {
              thread_id: row.thread_id,
              checkpoint_ns: checkpointNs,
              checkpoint_id: row.parent_checkpoint_id,
            },
          }
        : undefined,
      pendingWrites,
    };
  }

  /**
   * Normalise `message.additional_kwargs` IN PLACE and return the serialized
   * message as UTF-8 bytes.
   *
   * The kwargs are replaced by `cleanAdditionalKwargs(...)`, keeping the
   * original `reasoning`/`reasoningDetails` when cleaning dropped them. Those
   * two keys are then removed for any message whose type is not `'ai'`.
   */
  private encodeMessage(message: BaseMessage, encoder: TextEncoder): Uint8Array {
    const originalKwargs = message.additional_kwargs;
    const msgFromMatrixRoom = originalKwargs.msgFromMatrixRoom as boolean;
    const cleaned = cleanAdditionalKwargs(
      originalKwargs,
      msgFromMatrixRoom ?? false,
    );

    message.additional_kwargs = {
      ...cleaned,
      reasoning: cleaned.reasoning ?? originalKwargs.reasoning,
      reasoningDetails:
        cleaned.reasoningDetails ?? originalKwargs.reasoningDetails,
    };

    if (message.type !== 'ai') {
      delete message.additional_kwargs.reasoning;
      delete message.additional_kwargs.reasoningDetails;
    }

    return encoder.encode(
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      stringify(message, (_: string, value: any) => _default(value)),
    );
  }
}
|
|
904
|
+
|
|
905
|
+
/**
 * Type guard: true when the checkpoint's `channel_values` has a `messages` key
 * (checked with the `in` operator). The value under that key is not inspected.
 */
const isCheckpointWithMessages = (
  checkpoint: Checkpoint,
): checkpoint is CheckpointWithMessages => {
  const { channel_values: channelValues } = checkpoint;
  return 'messages' in channelValues;
};
|
|
910
|
+
|
|
911
|
+
/**
 * Separate a checkpoint's messages from the rest of its state.
 *
 * If `channel_values` has no `messages` key, the input checkpoint is returned
 * as-is with `messages: undefined`. Otherwise the `messages` key is deleted
 * from a `copyCheckpoint` copy, and the original messages array is returned
 * by reference, not cloned.
 */
const removeMessagesFromCheckpoint = (
  checkpoint: Checkpoint,
): {
  checkpoint: Checkpoint;
  messages?: BaseMessage[];
} => {
  if (!isCheckpointWithMessages(checkpoint)) {
    return { checkpoint, messages: undefined };
  }

  const { messages } = checkpoint.channel_values;
  const stripped = copyCheckpoint(checkpoint);
  delete stripped.channel_values.messages;

  return { checkpoint: stripped, messages };
};
|