pi-rlm 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/engine.ts ADDED
@@ -0,0 +1,462 @@
1
+ import { promises as fs } from "node:fs";
2
+ import { tmpdir } from "node:os";
3
+ import { join } from "node:path";
4
+ import type { ExtensionContext } from "@mariozechner/pi-coding-agent";
5
+ import { completeWithBackend } from "./backends";
6
+ import { plannerPrompt, solverPrompt, synthesisPrompt } from "./prompts";
7
+ import { PlannerDecision, RlmNode, RlmRunResult, RunArtifacts, StartRunInput } from "./types";
8
+ import { extractFirstJsonObject, normalizeTask, shortTask, toErrorMessage } from "./utils";
9
+
10
/** Engine input: the public StartRunInput plus the engine-assigned run id. */
interface EngineInput extends StartRunInput {
  runId: string;
}

/** Mutable per-run counters shared by every node of one run. */
interface EngineState {
  // Monotonically increasing source of node ids (`n1`, `n2`, ...).
  nodeCounter: number;
  // Total nodes created so far; compared against input.maxNodes.
  nodesVisited: number;
  // Deepest depth reached; reported in the final run stats.
  maxDepthSeen: number;
}

/** Callback for human-readable, line-oriented progress updates. */
type ProgressFn = (line: string) => void;
21
+
22
+ export async function runRlmEngine(
23
+ input: EngineInput,
24
+ ctx: ExtensionContext,
25
+ signal?: AbortSignal,
26
+ progress?: ProgressFn
27
+ ): Promise<RlmRunResult> {
28
+ const startedAt = Date.now();
29
+ const artifacts = await createArtifacts(input.runId);
30
+ const log = createEventLogger(artifacts.eventsPath);
31
+ const state: EngineState = {
32
+ nodeCounter: 0,
33
+ nodesVisited: 0,
34
+ maxDepthSeen: 0
35
+ };
36
+ const activeSignal = signal ?? new AbortController().signal;
37
+
38
+ progress?.(`RLM run ${input.runId} started (${input.backend}, mode=${input.mode})`);
39
+ log("run_start", {
40
+ runId: input.runId,
41
+ backend: input.backend,
42
+ mode: input.mode,
43
+ maxDepth: input.maxDepth,
44
+ maxNodes: input.maxNodes,
45
+ maxBranching: input.maxBranching,
46
+ concurrency: input.concurrency
47
+ });
48
+
49
+ try {
50
+ const root = await runNode({ task: input.task, depth: 0, lineage: [] });
51
+
52
+ const finalOutput = root.result ?? "(no final output)";
53
+
54
+ await fs.writeFile(artifacts.treePath, JSON.stringify(root, null, 2), "utf8");
55
+ await fs.writeFile(artifacts.outputPath, finalOutput, "utf8");
56
+
57
+ const durationMs = Date.now() - startedAt;
58
+ const result: RlmRunResult = {
59
+ runId: input.runId,
60
+ backend: input.backend,
61
+ final: finalOutput,
62
+ root,
63
+ artifacts,
64
+ stats: {
65
+ nodesVisited: state.nodesVisited,
66
+ maxDepthSeen: state.maxDepthSeen,
67
+ durationMs
68
+ }
69
+ };
70
+
71
+ log("run_end", {
72
+ runId: input.runId,
73
+ durationMs,
74
+ nodesVisited: state.nodesVisited,
75
+ maxDepthSeen: state.maxDepthSeen,
76
+ finalChars: finalOutput.length
77
+ });
78
+
79
+ if (root.status === "failed") {
80
+ throw new Error(root.error ?? "RLM root node failed");
81
+ }
82
+
83
+ progress?.(`RLM run ${input.runId} completed in ${durationMs}ms`);
84
+ return result;
85
+ } finally {
86
+ await log.flush();
87
+ }
88
+
89
+ async function runNode(params: {
90
+ task: string;
91
+ depth: number;
92
+ lineage: string[];
93
+ }): Promise<RlmNode> {
94
+ const nodeId = `n${++state.nodeCounter}`;
95
+ state.nodesVisited += 1;
96
+ state.maxDepthSeen = Math.max(state.maxDepthSeen, params.depth);
97
+
98
+ const node: RlmNode = {
99
+ id: nodeId,
100
+ depth: params.depth,
101
+ task: params.task,
102
+ status: "running",
103
+ startedAt: Date.now(),
104
+ children: []
105
+ };
106
+
107
+ progress?.(`[${node.id}] depth=${params.depth} ${shortTask(params.task, 72)}`);
108
+ log("node_start", {
109
+ nodeId: node.id,
110
+ depth: params.depth,
111
+ task: params.task,
112
+ nodesVisited: state.nodesVisited
113
+ });
114
+
115
+ if (activeSignal.aborted) {
116
+ node.status = "cancelled";
117
+ node.error = "Run cancelled";
118
+ node.finishedAt = Date.now();
119
+ log("node_cancelled", { nodeId: node.id, depth: node.depth });
120
+ throw new Error("RLM run cancelled");
121
+ }
122
+
123
+ const normalized = normalizeTask(params.task);
124
+ const remainingNodeBudget = Math.max(0, input.maxNodes - state.nodesVisited);
125
+
126
+ try {
127
+ const forcedReason = getForcedSolveReason({
128
+ depth: params.depth,
129
+ normalizedTask: normalized,
130
+ lineage: params.lineage
131
+ });
132
+
133
+ if (forcedReason || input.mode === "solve") {
134
+ const reason = forcedReason ?? "mode=solve";
135
+ node.decision = { action: "solve", reason };
136
+ node.result = await solveNode(node, reason);
137
+ node.status = "completed";
138
+ node.finishedAt = Date.now();
139
+ log("node_end", {
140
+ nodeId: node.id,
141
+ action: "solve",
142
+ reason,
143
+ chars: node.result.length,
144
+ durationMs: node.finishedAt - node.startedAt
145
+ });
146
+ return node;
147
+ }
148
+
149
+ const decision = await planNode({
150
+ task: params.task,
151
+ depth: params.depth,
152
+ maxDepth: input.maxDepth,
153
+ maxBranching: input.maxBranching,
154
+ remainingNodeBudget
155
+ });
156
+
157
+ node.decision = {
158
+ action: decision.action,
159
+ reason: decision.reason
160
+ };
161
+
162
+ if (decision.action === "solve") {
163
+ node.result = await solveNode(node, decision.reason);
164
+ node.status = "completed";
165
+ node.finishedAt = Date.now();
166
+ log("node_end", {
167
+ nodeId: node.id,
168
+ action: "solve",
169
+ reason: decision.reason,
170
+ chars: node.result.length,
171
+ durationMs: node.finishedAt - node.startedAt
172
+ });
173
+ return node;
174
+ }
175
+
176
+ const subtasks = sanitizeSubtasks(decision.subtasks ?? [], params.task).slice(
177
+ 0,
178
+ input.maxBranching
179
+ );
180
+
181
+ if (subtasks.length < 2) {
182
+ node.decision = {
183
+ action: "solve",
184
+ reason: "planner returned insufficient valid subtasks"
185
+ };
186
+ node.result = await solveNode(node, node.decision.reason);
187
+ node.status = "completed";
188
+ node.finishedAt = Date.now();
189
+ log("node_end", {
190
+ nodeId: node.id,
191
+ action: "solve",
192
+ reason: node.decision.reason,
193
+ chars: node.result.length,
194
+ durationMs: node.finishedAt - node.startedAt
195
+ });
196
+ return node;
197
+ }
198
+
199
+ progress?.(`[${node.id}] decomposing into ${subtasks.length} subtasks`);
200
+ log("node_decompose", {
201
+ nodeId: node.id,
202
+ subtasks,
203
+ reason: decision.reason
204
+ });
205
+
206
+ node.children = await mapConcurrent(subtasks, input.concurrency, async (subtask) => {
207
+ return runNode({
208
+ task: subtask,
209
+ depth: params.depth + 1,
210
+ lineage: [...params.lineage, normalized]
211
+ });
212
+ });
213
+
214
+ node.result = await synthesizeNode(node);
215
+ node.status = "completed";
216
+ node.finishedAt = Date.now();
217
+ log("node_end", {
218
+ nodeId: node.id,
219
+ action: "decompose",
220
+ chars: node.result.length,
221
+ children: node.children.length,
222
+ durationMs: node.finishedAt - node.startedAt
223
+ });
224
+ return node;
225
+ } catch (error) {
226
+ const message = toErrorMessage(error);
227
+ if (activeSignal.aborted || message.toLowerCase().includes("cancel")) {
228
+ node.status = "cancelled";
229
+ node.error = message;
230
+ node.finishedAt = Date.now();
231
+ log("node_cancelled", {
232
+ nodeId: node.id,
233
+ error: message,
234
+ durationMs: node.finishedAt - node.startedAt
235
+ });
236
+ throw error;
237
+ }
238
+
239
+ node.status = "failed";
240
+ node.error = message;
241
+ node.finishedAt = Date.now();
242
+ log("node_error", {
243
+ nodeId: node.id,
244
+ error: message,
245
+ durationMs: node.finishedAt - node.startedAt
246
+ });
247
+
248
+ node.result = `Node failed: ${message}`;
249
+ return node;
250
+ }
251
+ }
252
+
253
+ async function planNode(args: {
254
+ task: string;
255
+ depth: number;
256
+ maxDepth: number;
257
+ maxBranching: number;
258
+ remainingNodeBudget: number;
259
+ }): Promise<PlannerDecision> {
260
+ if (input.mode === "decompose") {
261
+ const forced = await callModel("planner", plannerPrompt(args));
262
+ const parsedForced = parsePlannerDecision(forced);
263
+ if (parsedForced.action === "decompose") {
264
+ return parsedForced;
265
+ }
266
+ return {
267
+ action: "decompose",
268
+ reason: "mode=decompose requested, but planner output was invalid",
269
+ subtasks: []
270
+ };
271
+ }
272
+
273
+ const raw = await callModel("planner", plannerPrompt(args));
274
+ return parsePlannerDecision(raw);
275
+ }
276
+
277
+ async function solveNode(node: RlmNode, forceReason: string): Promise<string> {
278
+ const prompt = solverPrompt({
279
+ task: node.task,
280
+ depth: node.depth,
281
+ maxDepth: input.maxDepth,
282
+ forceReason
283
+ });
284
+ return callModel("solver", prompt, node.id);
285
+ }
286
+
287
+ async function synthesizeNode(node: RlmNode): Promise<string> {
288
+ const prompt = synthesisPrompt({
289
+ task: node.task,
290
+ depth: node.depth,
291
+ children: node.children
292
+ });
293
+ return callModel("synthesizer", prompt, node.id);
294
+ }
295
+
296
+ async function callModel(stage: string, promptText: string, nodeId?: string): Promise<string> {
297
+ if (activeSignal.aborted) {
298
+ throw new Error("RLM run cancelled");
299
+ }
300
+
301
+ log("backend_call", {
302
+ nodeId,
303
+ stage,
304
+ backend: input.backend,
305
+ model: input.model,
306
+ promptChars: promptText.length
307
+ });
308
+
309
+ const output = await completeWithBackend(
310
+ {
311
+ backend: input.backend,
312
+ prompt: promptText,
313
+ cwd: input.cwd,
314
+ model: input.model,
315
+ toolsProfile: input.toolsProfile,
316
+ timeoutMs: input.timeoutMs,
317
+ signal: activeSignal
318
+ },
319
+ ctx
320
+ );
321
+
322
+ log("backend_result", {
323
+ nodeId,
324
+ stage,
325
+ outputChars: output.length
326
+ });
327
+
328
+ return output;
329
+ }
330
+
331
+ function getForcedSolveReason(args: {
332
+ depth: number;
333
+ normalizedTask: string;
334
+ lineage: string[];
335
+ }): string | undefined {
336
+ if (args.depth >= input.maxDepth) {
337
+ return "maxDepth reached";
338
+ }
339
+
340
+ if (state.nodesVisited >= input.maxNodes) {
341
+ return "maxNodes reached";
342
+ }
343
+
344
+ if (args.lineage.includes(args.normalizedTask)) {
345
+ return "cycle detected in task lineage";
346
+ }
347
+
348
+ return undefined;
349
+ }
350
+
351
+ }
352
+
353
+ function parsePlannerDecision(raw: string): PlannerDecision {
354
+ const jsonCandidate = extractFirstJsonObject(raw);
355
+ if (!jsonCandidate) {
356
+ return {
357
+ action: "solve",
358
+ reason: "planner JSON parse failed"
359
+ };
360
+ }
361
+
362
+ try {
363
+ const parsed = JSON.parse(jsonCandidate) as {
364
+ action?: unknown;
365
+ reason?: unknown;
366
+ subtasks?: unknown;
367
+ };
368
+
369
+ const action = parsed.action === "decompose" ? "decompose" : "solve";
370
+ const reason = typeof parsed.reason === "string" ? parsed.reason : "planner did not provide reason";
371
+ const subtasks = Array.isArray(parsed.subtasks)
372
+ ? parsed.subtasks.filter((item): item is string => typeof item === "string")
373
+ : undefined;
374
+
375
+ return { action, reason, subtasks };
376
+ } catch {
377
+ return {
378
+ action: "solve",
379
+ reason: "planner JSON was invalid"
380
+ };
381
+ }
382
+ }
383
+
384
+ function sanitizeSubtasks(subtasks: string[], parentTask: string): string[] {
385
+ const parentNormalized = normalizeTask(parentTask);
386
+ const deduped = new Set<string>();
387
+ const cleaned: string[] = [];
388
+
389
+ for (const subtask of subtasks) {
390
+ const value = subtask.trim();
391
+ if (!value) continue;
392
+
393
+ const normalized = normalizeTask(value);
394
+ if (!normalized || normalized === parentNormalized) continue;
395
+ if (deduped.has(normalized)) continue;
396
+
397
+ deduped.add(normalized);
398
+ cleaned.push(value);
399
+ }
400
+
401
+ return cleaned;
402
+ }
403
+
404
+ async function mapConcurrent<T, R>(
405
+ items: T[],
406
+ concurrency: number,
407
+ worker: (item: T, index: number) => Promise<R>
408
+ ): Promise<R[]> {
409
+ if (items.length === 0) return [];
410
+
411
+ const limit = Math.max(1, concurrency);
412
+ const results = new Array<R>(items.length);
413
+ let nextIndex = 0;
414
+
415
+ await Promise.all(
416
+ Array.from({ length: Math.min(limit, items.length) }, async () => {
417
+ while (nextIndex < items.length) {
418
+ const current = nextIndex;
419
+ nextIndex += 1;
420
+ results[current] = await worker(items[current], current);
421
+ }
422
+ })
423
+ );
424
+
425
+ return results;
426
+ }
427
+
428
+ async function createArtifacts(runId: string): Promise<RunArtifacts> {
429
+ const dir = join(tmpdir(), "pi-rlm-runs", runId);
430
+ await fs.mkdir(dir, { recursive: true });
431
+ return {
432
+ dir,
433
+ eventsPath: join(dir, "events.jsonl"),
434
+ treePath: join(dir, "tree.json"),
435
+ outputPath: join(dir, "output.md")
436
+ };
437
+ }
438
+
439
+ function createEventLogger(path: string): {
440
+ (type: string, payload: Record<string, unknown>): void;
441
+ flush: () => Promise<void>;
442
+ } {
443
+ let tail = Promise.resolve();
444
+
445
+ const write = (type: string, payload: Record<string, unknown>): void => {
446
+ const line = `${JSON.stringify({
447
+ ts: new Date().toISOString(),
448
+ type,
449
+ ...payload
450
+ })}\n`;
451
+
452
+ tail = tail
453
+ .then(() => fs.appendFile(path, line, "utf8"))
454
+ .catch(() => undefined);
455
+ };
456
+
457
+ write.flush = async (): Promise<void> => {
458
+ await tail;
459
+ };
460
+
461
+ return write;
462
+ }
package/src/prompts.ts ADDED
@@ -0,0 +1,83 @@
1
+ import { RlmNode } from "./types";
2
+
3
+ export function plannerPrompt(input: {
4
+ task: string;
5
+ depth: number;
6
+ maxDepth: number;
7
+ maxBranching: number;
8
+ remainingNodeBudget: number;
9
+ }): string {
10
+ return [
11
+ "You are a recursion controller for a recursive language model run.",
12
+ "Decide whether the task should be solved directly, or decomposed into subtasks.",
13
+ "",
14
+ "Return ONLY a JSON object with this schema:",
15
+ '{"action":"solve"|"decompose","reason":"...","subtasks":["..."]}',
16
+ "",
17
+ "Rules:",
18
+ "- Use action=solve if the task is atomic enough for one model pass.",
19
+ "- Use action=decompose only when decomposition is clearly beneficial.",
20
+ `- If action=decompose, return 2 to ${input.maxBranching} subtasks (never more).`,
21
+ "- Subtasks must be clear, non-empty strings.",
22
+ "- Do not include markdown or prose outside JSON.",
23
+ "",
24
+ `Current depth: ${input.depth} / ${input.maxDepth}`,
25
+ `Remaining node budget: ${input.remainingNodeBudget}`,
26
+ "",
27
+ "Task:",
28
+ input.task
29
+ ].join("\n");
30
+ }
31
+
32
+ export function solverPrompt(input: {
33
+ task: string;
34
+ depth: number;
35
+ maxDepth: number;
36
+ forceReason?: string;
37
+ }): string {
38
+ return [
39
+ "You are a worker node in a recursive language model run.",
40
+ "Solve the task directly and return a concrete answer.",
41
+ "",
42
+ `Depth: ${input.depth} / ${input.maxDepth}`,
43
+ input.forceReason ? `Note: forced direct solve because ${input.forceReason}` : "",
44
+ "",
45
+ "Task:",
46
+ input.task
47
+ ]
48
+ .filter(Boolean)
49
+ .join("\n");
50
+ }
51
+
52
+ export function synthesisPrompt(input: {
53
+ task: string;
54
+ depth: number;
55
+ children: RlmNode[];
56
+ }): string {
57
+ const childBlocks = input.children
58
+ .map((child, idx) => {
59
+ const status = child.status.toUpperCase();
60
+ const result = child.result ?? child.error ?? "(no output)";
61
+ return [
62
+ `### Child ${idx + 1} (${status})`,
63
+ `Subtask: ${child.task}`,
64
+ "Output:",
65
+ result
66
+ ].join("\n");
67
+ })
68
+ .join("\n\n");
69
+
70
+ return [
71
+ "You are the synthesizer node in a recursive language model run.",
72
+ "Combine child results into one final response to the parent task.",
73
+ "Be explicit about uncertainties if child outputs conflict.",
74
+ "",
75
+ `Depth: ${input.depth}`,
76
+ "",
77
+ "Parent task:",
78
+ input.task,
79
+ "",
80
+ "Child outputs:",
81
+ childBlocks
82
+ ].join("\n");
83
+ }
package/src/runs.ts ADDED
@@ -0,0 +1,128 @@
1
+ import { RlmRunResult, RunRecord, RunStatus, StartRunInput } from "./types";
2
+ import { createRunId, toErrorMessage } from "./utils";
3
+
4
+ const maxRecords = 200;
5
+
6
+ export class RunStore {
7
+ private readonly records = new Map<string, RunRecord>();
8
+
9
+ start(
10
+ input: StartRunInput,
11
+ executor: (runId: string, signal: AbortSignal) => Promise<RlmRunResult>,
12
+ externalSignal?: AbortSignal
13
+ ): RunRecord {
14
+ const id = createRunId();
15
+ const controller = new AbortController();
16
+
17
+ if (externalSignal) {
18
+ if (externalSignal.aborted) {
19
+ controller.abort();
20
+ } else {
21
+ externalSignal.addEventListener(
22
+ "abort",
23
+ () => {
24
+ controller.abort();
25
+ },
26
+ { once: true }
27
+ );
28
+ }
29
+ }
30
+
31
+ const record: RunRecord = {
32
+ id,
33
+ input,
34
+ status: "running",
35
+ createdAt: Date.now(),
36
+ startedAt: Date.now(),
37
+ controller,
38
+ promise: Promise.resolve(null as unknown as RlmRunResult)
39
+ };
40
+
41
+ record.promise = (async () => {
42
+ try {
43
+ const result = await executor(id, controller.signal);
44
+ record.status = "completed";
45
+ record.finishedAt = Date.now();
46
+ record.result = result;
47
+ return result;
48
+ } catch (error) {
49
+ const message = toErrorMessage(error);
50
+ record.finishedAt = Date.now();
51
+ if (controller.signal.aborted || message.toLowerCase().includes("cancel")) {
52
+ record.status = "cancelled";
53
+ } else {
54
+ record.status = "failed";
55
+ }
56
+ record.error = message;
57
+ throw error;
58
+ } finally {
59
+ this.prune();
60
+ }
61
+ })();
62
+
63
+ this.records.set(id, record);
64
+ this.prune();
65
+ return record;
66
+ }
67
+
68
+ get(id: string): RunRecord | undefined {
69
+ return this.records.get(id);
70
+ }
71
+
72
+ list(): RunRecord[] {
73
+ return Array.from(this.records.values()).sort((a, b) => b.createdAt - a.createdAt);
74
+ }
75
+
76
+ async wait(id: string, timeoutMs: number): Promise<{
77
+ status: RunStatus;
78
+ record: RunRecord;
79
+ done: boolean;
80
+ }> {
81
+ const record = this.records.get(id);
82
+ if (!record) {
83
+ throw new Error(`Unknown run id: ${id}`);
84
+ }
85
+
86
+ if (record.status !== "running") {
87
+ return { status: record.status, record, done: true };
88
+ }
89
+
90
+ let timeoutHandle: NodeJS.Timeout | undefined;
91
+ try {
92
+ await Promise.race([
93
+ record.promise.then(() => undefined).catch(() => undefined),
94
+ new Promise((resolve) => {
95
+ timeoutHandle = setTimeout(resolve, timeoutMs);
96
+ })
97
+ ]);
98
+ } finally {
99
+ if (timeoutHandle) clearTimeout(timeoutHandle);
100
+ }
101
+
102
+ const done = record.status !== "running";
103
+ return { status: record.status, record, done };
104
+ }
105
+
106
+ cancel(id: string): RunRecord {
107
+ const record = this.records.get(id);
108
+ if (!record) {
109
+ throw new Error(`Unknown run id: ${id}`);
110
+ }
111
+
112
+ if (record.status === "running") {
113
+ record.controller.abort();
114
+ }
115
+
116
+ return record;
117
+ }
118
+
119
+ private prune(): void {
120
+ const records = this.list();
121
+ if (records.length <= maxRecords) return;
122
+
123
+ for (const record of records.slice(maxRecords)) {
124
+ if (record.status === "running") continue;
125
+ this.records.delete(record.id);
126
+ }
127
+ }
128
+ }
package/src/schema.ts ADDED
@@ -0,0 +1,29 @@
1
+ import { StringEnum } from "@mariozechner/pi-ai";
2
+ import { Static, Type } from "@sinclair/typebox";
3
+
4
+ const opSchema = StringEnum(["start", "status", "wait", "cancel"] as const);
5
+ const backendSchema = StringEnum(["sdk", "cli", "tmux"] as const);
6
+ const modeSchema = StringEnum(["auto", "solve", "decompose"] as const);
7
+ const toolsProfileSchema = StringEnum(["coding", "read-only"] as const);
8
+
9
+ export const rlmToolParamsSchema = Type.Object({
10
+ op: Type.Optional(opSchema),
11
+ id: Type.Optional(Type.String({ description: "Run ID for status/wait/cancel" })),
12
+ task: Type.Optional(Type.String({ description: "Task to solve recursively" })),
13
+
14
+ backend: Type.Optional(backendSchema),
15
+ mode: Type.Optional(modeSchema),
16
+ async: Type.Optional(Type.Boolean({ description: "Return immediately and run in background" })),
17
+ model: Type.Optional(Type.String({ description: "Optional provider/model[:thinking] override" })),
18
+ cwd: Type.Optional(Type.String({ description: "Working directory for subcalls" })),
19
+ toolsProfile: Type.Optional(toolsProfileSchema),
20
+
21
+ maxDepth: Type.Optional(Type.Integer({ minimum: 0, maximum: 8 })),
22
+ maxNodes: Type.Optional(Type.Integer({ minimum: 1, maximum: 300 })),
23
+ maxBranching: Type.Optional(Type.Integer({ minimum: 1, maximum: 8 })),
24
+ concurrency: Type.Optional(Type.Integer({ minimum: 1, maximum: 8 })),
25
+ timeoutMs: Type.Optional(Type.Integer({ minimum: 1000, maximum: 3600000 })),
26
+ waitTimeoutMs: Type.Optional(Type.Integer({ minimum: 100, maximum: 3600000 }))
27
+ });
28
+
29
+ export type RlmToolParams = Static<typeof rlmToolParamsSchema>;