pi-cursor-agent 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +19 -0
- package/README.md +45 -0
- package/package.json +50 -0
- package/src/__generated__/agent/v1/agent_pb.ts +4642 -0
- package/src/__generated__/agent/v1/agent_service_connect.ts +71 -0
- package/src/__generated__/agent/v1/apply_agent_diff_tool_pb.ts +317 -0
- package/src/__generated__/agent/v1/ask_question_tool_pb.ts +588 -0
- package/src/__generated__/agent/v1/background_shell_exec_pb.ts +245 -0
- package/src/__generated__/agent/v1/computer_use_tool_pb.ts +959 -0
- package/src/__generated__/agent/v1/control_service_connect.ts +144 -0
- package/src/__generated__/agent/v1/control_service_pb.ts +1308 -0
- package/src/__generated__/agent/v1/create_plan_tool_pb.ts +366 -0
- package/src/__generated__/agent/v1/cursor_packages_pb.ts +278 -0
- package/src/__generated__/agent/v1/cursor_rules_pb.ts +301 -0
- package/src/__generated__/agent/v1/delete_exec_pb.ts +443 -0
- package/src/__generated__/agent/v1/delete_tool_pb.ts +52 -0
- package/src/__generated__/agent/v1/diagnostics_exec_pb.ts +399 -0
- package/src/__generated__/agent/v1/edit_tool_pb.ts +497 -0
- package/src/__generated__/agent/v1/exa_fetch_tool_pb.ts +472 -0
- package/src/__generated__/agent/v1/exa_search_tool_pb.ts +484 -0
- package/src/__generated__/agent/v1/exec_pb.ts +1271 -0
- package/src/__generated__/agent/v1/exec_service_connect.ts +14 -0
- package/src/__generated__/agent/v1/fetch_tool_pb.ts +242 -0
- package/src/__generated__/agent/v1/generate_image_tool_pb.ts +230 -0
- package/src/__generated__/agent/v1/glob_tool_pb.ts +248 -0
- package/src/__generated__/agent/v1/grep_exec_pb.ts +690 -0
- package/src/__generated__/agent/v1/grep_tool_pb.ts +52 -0
- package/src/__generated__/agent/v1/kv_pb.ts +281 -0
- package/src/__generated__/agent/v1/ls_exec_pb.ts +295 -0
- package/src/__generated__/agent/v1/ls_tool_pb.ts +52 -0
- package/src/__generated__/agent/v1/mcp_pb.ts +302 -0
- package/src/__generated__/agent/v1/mcp_resource_tool_pb.ts +688 -0
- package/src/__generated__/agent/v1/mcp_tool_pb.ts +630 -0
- package/src/__generated__/agent/v1/private_worker_bridge_external_connect.ts +26 -0
- package/src/__generated__/agent/v1/read_exec_pb.ts +412 -0
- package/src/__generated__/agent/v1/read_lints_tool_pb.ts +384 -0
- package/src/__generated__/agent/v1/read_tool_pb.ts +342 -0
- package/src/__generated__/agent/v1/record_screen_tool_pb.ts +376 -0
- package/src/__generated__/agent/v1/reflect_tool_pb.ts +236 -0
- package/src/__generated__/agent/v1/repo_pb.ts +154 -0
- package/src/__generated__/agent/v1/report_bugfix_results_tool_pb.ts +305 -0
- package/src/__generated__/agent/v1/request_context_exec_pb.ts +528 -0
- package/src/__generated__/agent/v1/sandbox_pb.ts +125 -0
- package/src/__generated__/agent/v1/selected_context_pb.ts +2272 -0
- package/src/__generated__/agent/v1/semsearch_tool_pb.ts +230 -0
- package/src/__generated__/agent/v1/setup_vm_environment_tool_pb.ts +168 -0
- package/src/__generated__/agent/v1/shell_exec_pb.ts +1195 -0
- package/src/__generated__/agent/v1/shell_tool_pb.ts +176 -0
- package/src/__generated__/agent/v1/start_grind_execution_tool_pb.ts +212 -0
- package/src/__generated__/agent/v1/start_grind_planning_tool_pb.ts +212 -0
- package/src/__generated__/agent/v1/subagents_pb.ts +1106 -0
- package/src/__generated__/agent/v1/switch_mode_tool_pb.ts +429 -0
- package/src/__generated__/agent/v1/todo_tool_pb.ts +551 -0
- package/src/__generated__/agent/v1/utils_pb.ts +348 -0
- package/src/__generated__/agent/v1/web_fetch_tool_pb.ts +429 -0
- package/src/__generated__/agent/v1/web_search_tool_pb.ts +466 -0
- package/src/__generated__/agent/v1/write_exec_pb.ts +379 -0
- package/src/__generated__/agent/v1/write_shell_stdin_tool_pb.ts +224 -0
- package/src/__generated__/aiserver/v1/aiserver_service_connect.ts +40 -0
- package/src/api/agent-service.ts +55 -0
- package/src/api/ai-service.ts +42 -0
- package/src/api/auth.ts +74 -0
- package/src/index.ts +101 -0
- package/src/lib/agent-store/disk.ts +139 -0
- package/src/lib/agent-store/index.ts +72 -0
- package/src/lib/agent-store/json-blob-store.ts +47 -0
- package/src/lib/auth.ts +135 -0
- package/src/lib/backoff.ts +32 -0
- package/src/lib/env.ts +3 -0
- package/src/lib/heartbeat.ts +21 -0
- package/src/pi/agent-store.ts +102 -0
- package/src/pi/env.ts +11 -0
- package/src/pi/executors/delete.ts +129 -0
- package/src/pi/executors/grep.ts +238 -0
- package/src/pi/executors/hook.ts +64 -0
- package/src/pi/executors/ls.ts +107 -0
- package/src/pi/executors/read.ts +73 -0
- package/src/pi/executors/request-context.ts +120 -0
- package/src/pi/executors/shell-stream.ts +136 -0
- package/src/pi/executors/shell.ts +157 -0
- package/src/pi/executors/stubs.ts +173 -0
- package/src/pi/executors/write.ts +189 -0
- package/src/pi/local-resource-provider/index.ts +10 -0
- package/src/pi/local-resource-provider/provider.ts +98 -0
- package/src/pi/local-resource-provider/types.ts +110 -0
- package/src/pi/model-mapping.ts +115 -0
- package/src/pi/model-override.ts +110 -0
- package/src/pi/model.ts +61 -0
- package/src/pi/request-builder.ts +279 -0
- package/src/pi/utils/tool-result.ts +35 -0
- package/src/stream.ts +386 -0
- package/src/tool-host.ts +44 -0
- package/src/vendor/agent-client/checkpoint-controller.ts +34 -0
- package/src/vendor/agent-client/connect.ts +348 -0
- package/src/vendor/agent-client/exec-controller.ts +102 -0
- package/src/vendor/agent-client/index.ts +25 -0
- package/src/vendor/agent-client/interaction-controller.ts +96 -0
- package/src/vendor/agent-client/split-stream.ts +143 -0
- package/src/vendor/agent-core/index.ts +9 -0
- package/src/vendor/agent-core/interaction-conversion.ts +558 -0
- package/src/vendor/agent-exec/controlled.ts +104 -0
- package/src/vendor/agent-exec/index.ts +45 -0
- package/src/vendor/agent-exec/registry-resource-accessor.ts +39 -0
- package/src/vendor/agent-exec/resources.ts +121 -0
- package/src/vendor/agent-exec/serialization.ts +22 -0
- package/src/vendor/agent-exec/simple-controlled-exec-manager.ts +161 -0
- package/src/vendor/agent-kv/agent-store.ts +115 -0
- package/src/vendor/agent-kv/blob-store.ts +36 -0
- package/src/vendor/agent-kv/controlled.ts +117 -0
- package/src/vendor/agent-kv/index.ts +15 -0
- package/src/vendor/agent-kv/serde.ts +44 -0
- package/src/vendor/local-exec/common.ts +19 -0
- package/src/vendor/local-exec/git-executor.ts +37 -0
- package/src/vendor/local-exec/git-helpers.ts +79 -0
- package/src/vendor/local-exec/index.ts +8 -0
- package/src/vendor/utils/index.ts +5 -0
- package/src/vendor/utils/map-writable.ts +34 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import {
|
|
2
|
+
createClient,
|
|
3
|
+
type Client,
|
|
4
|
+
type Interceptor,
|
|
5
|
+
} from "@connectrpc/connect";
|
|
6
|
+
import { createConnectTransport } from "@connectrpc/connect-node";
|
|
7
|
+
import { AiService as AiServiceDef } from "../__generated__/aiserver/v1/aiserver_service_connect";
|
|
8
|
+
|
|
9
|
+
/**
 * Connection settings for `AiService`.
 *
 * Each field ends up as a request header on every call made by the service.
 */
export interface AgentServiceOptions {
  /** Sent as the `x-cursor-client-type` header (the extension passes "cli"). */
  clientType: string;
  /** Sent as the `x-cursor-client-version` header. */
  clientVersion: string;
  /** Sent as `authorization: Bearer <accessToken>`. */
  accessToken: string;
}
|
|
14
|
+
|
|
15
|
+
/**
 * Thin Connect-RPC client for the Cursor `AiService`.
 *
 * Every outgoing request is stamped with auth and client-identification
 * headers plus a freshly generated request id.
 */
class AiService {
  private readonly client: Client<typeof AiServiceDef>;

  constructor(baseUrl: string, options: AgentServiceOptions) {
    this.client = createClient(
      AiServiceDef,
      createConnectTransport({
        baseUrl,
        httpVersion: "1.1",
        interceptors: [AiService.headerInterceptor(options)],
      }),
    );
  }

  /** Fetches the models the authenticated user is allowed to use. */
  public async getUsableModels() {
    return this.client.getUsableModels({});
  }

  /**
   * Builds the interceptor that writes the auth/client headers onto each
   * request before handing it to the next link in the chain.
   */
  private static headerInterceptor(options: AgentServiceOptions): Interceptor {
    return (next) => async (req) => {
      // Built per request so the request id is unique for every call.
      const headers: Array<[string, string]> = [
        ["authorization", `Bearer ${options.accessToken}`],
        ["x-cursor-client-type", options.clientType],
        ["x-cursor-client-version", options.clientVersion],
        ["x-ghost-mode", "true"],
        ["x-request-id", crypto.randomUUID()],
      ];
      for (const [name, value] of headers) {
        req.header.set(name, value);
      }
      return next(req);
    };
  }
}

export default AiService;
|
package/src/api/auth.ts
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
/** Token pair returned by the Cursor auth endpoints. */
type AuthResult = {
  accessToken: string;
  refreshToken: string;
};

/**
 * Minimal HTTP client for the Cursor authentication endpoints.
 *
 * Non-2xx responses reject with an `Error` whose message has the form
 * `Fetch failed <path> for <status>: <body>`. Keep that format stable: the
 * login poller decides whether to keep polling by looking for `/auth/poll`
 * and `404` in the message.
 */
class Auth {
  constructor(private readonly baseUrl: string) {}

  /**
   * Asks whether the browser login identified by `uuid` has completed.
   *
   * @param uuid - id of the pending login
   * @param verifier - secret generated together with the login URL
   * @param signal - optional abort signal for the request
   * @returns the issued token pair
   * @throws Error on a non-2xx status or a response that is not a token pair
   */
  public async poll({
    uuid,
    verifier,
    signal,
  }: {
    uuid: string;
    verifier: string;
    signal?: AbortSignal | undefined;
  }) {
    const params = new URLSearchParams({ uuid, verifier });
    return this.fetchJson<AuthResult>(`/auth/poll?${params.toString()}`, {
      headers: { "content-type": "application/json" },
      signal: signal ?? null,
      validator: this.isAuthResult,
    });
  }

  /**
   * Exchanges `token` (sent as a bearer token) for a fresh token pair.
   *
   * @param token - API key, refresh token or access token to exchange
   * @param signal - optional abort signal for the request
   * @returns the issued token pair
   * @throws Error on a non-2xx status or a response that is not a token pair
   */
  public async exchangeUserApiKey({
    token,
    signal,
  }: {
    token: string;
    signal?: AbortSignal | undefined;
  }) {
    return this.fetchJson<AuthResult>("/auth/exchange_user_api_key", {
      method: "POST",
      headers: {
        authorization: `Bearer ${token}`,
        "content-type": "application/json",
      },
      body: JSON.stringify({}),
      signal: signal ?? null,
      validator: this.isAuthResult,
    });
  }

  /**
   * Requests `url` (relative to `baseUrl`), parses the JSON body and checks
   * its shape with `validator` before returning it as `T`.
   */
  private async fetchJson<T>(
    url: string,
    {
      validator,
      ...init
    }: RequestInit & { validator: (data: unknown) => data is T },
  ): Promise<T> {
    const response = await fetch(`${this.baseUrl}${url}`, init);
    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Fetch failed ${url} for ${response.status}: ${error}`);
    }

    // The body is untrusted until the validator has narrowed it.
    const data: unknown = await response.json();
    if (!validator(data)) {
      const error = JSON.stringify(data);
      throw new Error(`Fetch failed ${url} for invalid response: ${error}`);
    }

    return data;
  }

  /**
   * Runtime check that `data` really is an `AuthResult`: an object whose
   * `accessToken` and `refreshToken` are both strings (not merely present).
   */
  private isAuthResult(data: unknown): data is AuthResult {
    if (typeof data !== "object" || data === null) {
      return false;
    }
    const candidate = data as Record<string, unknown>;
    return (
      typeof candidate.accessToken === "string" &&
      typeof candidate.refreshToken === "string"
    );
  }
}

export default Auth;
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import type {
|
|
2
|
+
ExtensionAPI,
|
|
3
|
+
ExtensionContext,
|
|
4
|
+
} from "@mariozechner/pi-coding-agent";
|
|
5
|
+
import type {
|
|
6
|
+
Api,
|
|
7
|
+
OAuthCredentials,
|
|
8
|
+
OAuthLoginCallbacks,
|
|
9
|
+
} from "@mariozechner/pi-ai";
|
|
10
|
+
import AiService from "./api/ai-service";
|
|
11
|
+
import Auth from "./api/auth";
|
|
12
|
+
import AuthManager from "./lib/auth";
|
|
13
|
+
import { restoreAgentStoreFromBranch } from "./pi/agent-store";
|
|
14
|
+
import { streamCursorAgent } from "./stream";
|
|
15
|
+
import {
|
|
16
|
+
CURSOR_API_URL,
|
|
17
|
+
CURSOR_CLIENT_VERSION,
|
|
18
|
+
CURSOR_WEBSITE_URL,
|
|
19
|
+
} from "./lib/env";
|
|
20
|
+
import { getCachedPiModels, updateCachedPiModels } from "./pi/model";
|
|
21
|
+
|
|
22
|
+
/** Shared login/refresh coordinator used by the OAuth hooks below. */
const auth = new AuthManager(
  new Auth(CURSOR_API_URL),
  CURSOR_WEBSITE_URL,
);

/**
 * OAuth `login` hook: runs the browser sign-in flow, then refreshes the
 * cached model list using the newly issued access token.
 */
const login = async (
  callbacks: OAuthLoginCallbacks,
): Promise<OAuthCredentials> => {
  const credentials = await auth.login(callbacks);
  const { access: accessToken } = credentials;
  await updateCachedPiModels(
    new AiService(CURSOR_API_URL, {
      accessToken,
      clientVersion: CURSOR_CLIENT_VERSION,
      clientType: "cli",
    }),
  );
  return credentials;
};
|
|
36
|
+
|
|
37
|
+
/**
 * OAuth `refreshToken` hook: exchanges the stored credentials for a new token
 * pair, then refreshes the cached model list.
 *
 * The model list must be fetched with the *new* access token. A refresh is
 * requested because the old one is (about to be) expired, so reusing it could
 * make the model request — and therefore the whole refresh — fail.
 *
 * @param credentials - the currently stored credentials
 * @returns the refreshed credentials
 */
const refreshToken = async (
  credentials: OAuthCredentials,
): Promise<OAuthCredentials> => {
  const refreshed = await auth.refresh(credentials);
  const ai = new AiService(CURSOR_API_URL, {
    accessToken: refreshed.access,
    clientVersion: CURSOR_CLIENT_VERSION,
    clientType: "cli",
  });
  await updateCachedPiModels(ai);
  return refreshed;
};
|
|
49
|
+
|
|
50
|
+
/**
 * Pi extension entry point.
 *
 * Keeps the Cursor agent store in sync with the active session branch and
 * registers the "cursor" model provider (streaming + OAuth login/refresh).
 */
export default (pi: ExtensionAPI) => {
  // Most recent context seen from any lifecycle event; the provider's stream
  // function reads it lazily through `getCtx`.
  let lastCtx: ExtensionContext | null = null;
  const getCtx = () => lastCtx;
  const rememberCtx = (ctx: ExtensionContext) => {
    lastCtx = ctx;
  };

  /**
   * Records `ctx` and restores the persisted agent store for the session's
   * current branch.
   */
  const refreshBranchState = async (ctx: ExtensionContext) => {
    rememberCtx(ctx);
    try {
      const { sessionManager } = ctx;
      await restoreAgentStoreFromBranch(
        sessionManager.getSessionId(),
        sessionManager.getBranch(),
      );
    } catch {
      // Deliberately best-effort: restore failures are ignored.
    }
  };

  // Re-sync the agent store whenever the active session or branch changes.
  pi.on("session_start", async (_event, ctx) => {
    await refreshBranchState(ctx);
  });
  pi.on("session_switch", async (_event, ctx) => {
    await refreshBranchState(ctx);
  });
  pi.on("session_tree", async (_event, ctx) => {
    await refreshBranchState(ctx);
  });

  // Keep the freshest context around before each agent run.
  pi.on("before_agent_start", async (_event, ctx) => {
    rememberCtx(ctx);
  });
  pi.on("agent_start", async (_event, ctx) => {
    rememberCtx(ctx);
  });

  pi.registerProvider("cursor", {
    baseUrl: CURSOR_API_URL,
    apiKey: "CURSOR_ACCESS_TOKEN",
    // "cursor-agent" is not one of pi-ai's known Api ids, hence the cast.
    api: "cursor-agent" as unknown as Api,
    streamSimple: (model, context, options) =>
      streamCursorAgent(pi, getCtx, model, context, options),
    // Evaluated once, at registration time.
    models: getCachedPiModels(),
    oauth: {
      name: "Cursor Agent",
      login,
      refreshToken,
      getApiKey: (cred) => cred.access,
    },
  });
};
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import fs from "node:fs/promises";
|
|
2
|
+
import path from "node:path";
|
|
3
|
+
import { toHex, fromHex, type AgentMetadata } from "../../vendor/agent-kv";
|
|
4
|
+
|
|
5
|
+
/** One persisted blob: its map key plus the raw bytes encoded as base64. */
type BlobEntry = {
  id: string;
  data: string;
};

/** On-disk layout of `blobs.json` (format version 1). */
type BlobsFile = {
  version: 1;
  blobs: BlobEntry[];
};

/**
 * On-disk layout of `meta.json` (format version 1).
 * `latestRootBlobId` is stored hex-encoded; `lastUsedModel` is omitted when unset.
 */
type MetaFile = {
  version: 1;
  agentId: string;
  latestRootBlobId: string;
  name: string;
  createdAt: number;
  mode: string;
  lastUsedModel?: string;
};
|
|
24
|
+
|
|
25
|
+
/** Directory holding every persisted file for one session. */
const getSessionDir = (baseDir: string, sessionId: string): string =>
  path.join(baseDir, "chats", sessionId);

/** Resolves `fileName` inside the session's directory. */
const sessionFilePath = (
  baseDir: string,
  sessionId: string,
  fileName: string,
): string => path.join(getSessionDir(baseDir, sessionId), fileName);

/** Location of the session's serialized blob map. */
const getBlobsFilePath = (baseDir: string, sessionId: string): string =>
  sessionFilePath(baseDir, sessionId, "blobs.json");

/** Location of the session's agent metadata. */
const getMetaFilePath = (baseDir: string, sessionId: string): string =>
  sessionFilePath(baseDir, sessionId, "meta.json");
|
|
33
|
+
|
|
34
|
+
/**
 * Reads a session's `blobs.json` into a map of blob id → bytes.
 *
 * Best-effort by design. A missing, unreadable or unparseable file, or an
 * unexpected format version, yields an empty map. Entries without a string
 * `id`/`data` are skipped.
 */
export const loadBlobsFromDisk = async (
  baseDir: string,
  sessionId: string,
): Promise<Map<string, Uint8Array>> => {
  try {
    const raw = await fs.readFile(getBlobsFilePath(baseDir, sessionId), "utf-8");
    const parsed = JSON.parse(raw) as BlobsFile;
    const supported =
      parsed && parsed.version === 1 && Array.isArray(parsed.blobs);
    if (!supported) {
      return new Map();
    }

    const decoded = new Map<string, Uint8Array>();
    for (const entry of parsed.blobs) {
      const wellFormed =
        entry && typeof entry.id === "string" && typeof entry.data === "string";
      if (!wellFormed) {
        continue;
      }
      try {
        decoded.set(entry.id, new Uint8Array(Buffer.from(entry.data, "base64")));
      } catch {
        // Corrupt entry: drop it and keep loading the rest.
      }
    }
    return decoded;
  } catch {
    return new Map();
  }
};
|
|
66
|
+
|
|
67
|
+
/**
 * Writes a session's blobs to `blobs.json` as
 * `{ version: 1, blobs: [{ id, data }] }`, with each payload base64-encoded.
 *
 * The JSON is written to `blobs.json.tmp` first and then renamed over the
 * real file. An interrupted write therefore leaves the previous `blobs.json`
 * untouched.
 */
export const saveBlobsToDisk = async (
  baseDir: string,
  sessionId: string,
  blobs: Map<string, Uint8Array>,
): Promise<void> => {
  await fs.mkdir(getSessionDir(baseDir, sessionId), { recursive: true });

  const payload: BlobsFile = {
    version: 1,
    blobs: [...blobs].map(([id, bytes]) => ({
      id,
      data: Buffer.from(bytes).toString("base64"),
    })),
  };

  const target = getBlobsFilePath(baseDir, sessionId);
  const staging = `${target}.tmp`;
  await fs.writeFile(staging, JSON.stringify(payload), "utf-8");
  await fs.rename(staging, target);
};
|
|
86
|
+
|
|
87
|
+
/**
 * Reads a session's `meta.json` and converts it to `AgentMetadata`.
 *
 * Returns null when any of these hold:
 * - the file is missing or unreadable;
 * - the file is not valid JSON;
 * - the format version is unsupported;
 * - `agentId` is not a string;
 * - the conversion throws.
 *
 * Absent fields fall back to defaults: an empty root blob id, "New Agent",
 * the current time, and "default" mode.
 */
export const loadMetaFromDisk = async (
  baseDir: string,
  sessionId: string,
): Promise<AgentMetadata | null> => {
  try {
    const raw = await fs.readFile(getMetaFilePath(baseDir, sessionId), "utf-8");
    const meta = JSON.parse(raw) as MetaFile;
    const supported =
      meta && meta.version === 1 && typeof meta.agentId === "string";
    if (!supported) {
      return null;
    }

    // The root blob id is persisted hex-encoded; empty/missing → empty id.
    const rootBlobHex = meta.latestRootBlobId;
    return {
      agentId: meta.agentId,
      latestRootBlobId: rootBlobHex ? fromHex(rootBlobHex) : new Uint8Array(),
      name: meta.name ?? "New Agent",
      createdAt: meta.createdAt ?? Date.now(),
      mode: (meta.mode as AgentMetadata["mode"]) ?? "default",
      ...(meta.lastUsedModel != null && { lastUsedModel: meta.lastUsedModel }),
    };
  } catch {
    return null;
  }
};
|
|
116
|
+
|
|
117
|
+
/**
 * Persists agent metadata to the session's `meta.json` (format version 1).
 *
 * The root blob id is stored hex-encoded, and `lastUsedModel` is only written
 * when set. The file is written to `meta.json.tmp` and then renamed over the
 * target, so an interrupted write leaves the previous file untouched.
 */
export const saveMetaToDisk = async (
  baseDir: string,
  sessionId: string,
  metadata: AgentMetadata,
): Promise<void> => {
  await fs.mkdir(getSessionDir(baseDir, sessionId), { recursive: true });

  const { agentId, latestRootBlobId, name, createdAt, mode, lastUsedModel } =
    metadata;
  const file: MetaFile = {
    version: 1,
    agentId,
    latestRootBlobId: toHex(latestRootBlobId),
    name,
    createdAt,
    mode,
    ...(lastUsedModel != null && { lastUsedModel }),
  };

  const target = getMetaFilePath(baseDir, sessionId);
  const staging = `${target}.tmp`;
  await fs.writeFile(staging, JSON.stringify(file), "utf-8");
  await fs.rename(staging, target);
};
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import { AgentStore, getDefaultAgentMetadata } from "../../vendor/agent-kv";
|
|
2
|
+
import {
|
|
3
|
+
loadBlobsFromDisk,
|
|
4
|
+
loadMetaFromDisk,
|
|
5
|
+
saveBlobsToDisk,
|
|
6
|
+
saveMetaToDisk,
|
|
7
|
+
} from "./disk";
|
|
8
|
+
import { JsonBlobStoreWithMetadata } from "./json-blob-store";
|
|
9
|
+
|
|
10
|
+
/**
 * A session's agent store together with the in-memory blob/metadata store it
 * was built on. The latter's `blobs`/`metadata` are what gets written to disk.
 */
export interface StoreEntry {
  /** Backing blob + metadata store. */
  jsonStore: JsonBlobStoreWithMetadata;
  /** Agent store constructed over `jsonStore`. */
  store: AgentStore;
}

/** Loaded stores keyed by session id; this module never evicts entries. */
const sessionStores: Map<string, StoreEntry> = new Map();
|
|
16
|
+
|
|
17
|
+
/**
 * Loads currently in progress, keyed by session id. Concurrent callers for
 * the same session share one load. Otherwise each caller would build its own
 * store, and the last one to finish would silently replace the others in the
 * cache.
 */
const pendingStores = new Map<string, Promise<StoreEntry>>();

/** Reads blobs + metadata from disk and builds a fresh store entry. */
const loadStoreEntry = async (
  baseDir: string,
  sessionId: string,
): Promise<StoreEntry> => {
  const [blobs, meta] = await Promise.all([
    loadBlobsFromDisk(baseDir, sessionId),
    loadMetaFromDisk(baseDir, sessionId),
  ]);

  const metadata = meta ?? getDefaultAgentMetadata();
  const jsonStore = new JsonBlobStoreWithMetadata(blobs, metadata);
  const store = new AgentStore(jsonStore, jsonStore);

  // resetFromDb is only invoked when a root blob id is recorded (presumably
  // it rebuilds in-memory state starting from that root).
  if (metadata.latestRootBlobId.length > 0) {
    await store.resetFromDb(null);
  }

  return { store, jsonStore };
};

/**
 * Returns the cached store entry for `sessionId`, loading it from `baseDir`
 * on first use.
 *
 * Concurrent calls for the same session resolve to the same entry. A failed
 * load is not cached, so a later call retries it.
 */
export const ensureAgentStore = async (
  baseDir: string,
  sessionId: string,
): Promise<StoreEntry> => {
  const existing = sessionStores.get(sessionId);
  if (existing) {
    return existing;
  }

  const inFlight = pendingStores.get(sessionId);
  if (inFlight) {
    return inFlight;
  }

  const loading = loadStoreEntry(baseDir, sessionId);
  pendingStores.set(sessionId, loading);
  try {
    const entry = await loading;
    sessionStores.set(sessionId, entry);
    return entry;
  } finally {
    pendingStores.delete(sessionId);
  }
};
|
|
43
|
+
|
|
44
|
+
/**
 * Writes the cached store for `sessionId` (blobs and metadata, in parallel)
 * under `baseDir`.
 *
 * @returns the persisted entry, or null when no store is loaded for the session
 */
export const persistAgentStore = async (
  baseDir: string,
  sessionId: string,
): Promise<StoreEntry | null> => {
  const entry = sessionStores.get(sessionId);
  if (entry === undefined) {
    return null;
  }

  const { blobs, metadata } = entry.jsonStore;
  await Promise.all([
    saveBlobsToDisk(baseDir, sessionId, blobs),
    saveMetaToDisk(baseDir, sessionId, metadata),
  ]);
  return entry;
};
|
|
60
|
+
|
|
61
|
+
/**
 * Points an existing store entry at a different agent/root blob.
 *
 * Updates the metadata in place. When the new root blob id is non-empty,
 * `resetFromDb` is then called on the store; otherwise the store is left
 * as-is.
 */
export const applySnapshotToStore = async (
  entry: StoreEntry,
  agentId: string,
  latestRootBlobId: Uint8Array,
): Promise<void> => {
  Object.assign(entry.jsonStore.metadata, { agentId, latestRootBlobId });

  if (latestRootBlobId.length === 0) {
    return;
  }
  await entry.store.resetFromDb(null);
};
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import {
|
|
2
|
+
toHex,
|
|
3
|
+
type AgentMetadata,
|
|
4
|
+
type BlobStore,
|
|
5
|
+
type MetadataStore,
|
|
6
|
+
} from "../../vendor/agent-kv";
|
|
7
|
+
|
|
8
|
+
/**
 * Purely in-memory `BlobStore` + `MetadataStore`.
 *
 * Blobs live in a Map keyed by the hex form of the blob id; metadata is a
 * single mutable object. Nothing here touches disk.
 */
export class JsonBlobStoreWithMetadata implements BlobStore, MetadataStore {
  constructor(
    readonly blobs: Map<string, Uint8Array>,
    readonly metadata: AgentMetadata,
  ) {}

  /** Reads one metadata field. */
  public get<K extends keyof AgentMetadata>(key: K): AgentMetadata[K] {
    const { metadata } = this;
    return metadata[key];
  }

  /** Overwrites one metadata field in place; no listeners are notified. */
  public set<K extends keyof AgentMetadata>(
    key: K,
    value: AgentMetadata[K],
  ): void {
    const { metadata } = this;
    metadata[key] = value;
  }

  /** Change notifications are not supported: returns a no-op unsubscribe. */
  public subscribe(_: keyof AgentMetadata, __: () => void): () => void {
    const unsubscribe = (): void => {};
    return unsubscribe;
  }

  /** Looks up a blob by id; resolves to undefined when it is not stored. */
  public async getBlob(
    _: unknown,
    blobId: Uint8Array,
  ): Promise<Uint8Array | undefined> {
    return this.blobs.get(toHex(blobId));
  }

  /** Stores (or replaces) a blob under the hex form of its id. */
  public setBlob(
    _ctx: unknown,
    blobId: Uint8Array,
    blobData: Uint8Array,
  ): Promise<void> {
    const key = toHex(blobId);
    this.blobs.set(key, blobData);
    return Promise.resolve();
  }
}
|
package/src/lib/auth.ts
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import { createHash, randomBytes } from "node:crypto";
|
|
2
|
+
import { OAuthLoginCallbacks } from "@mariozechner/pi-ai";
|
|
3
|
+
import Auth from "../api/auth";
|
|
4
|
+
import { backoff } from "./backoff";
|
|
5
|
+
|
|
6
|
+
type OnAuth = (info: { url: string; instructions: string }) => void;
|
|
7
|
+
type OnProgress = (message: string) => void;
|
|
8
|
+
|
|
9
|
+
/**
 * Drives the Cursor browser sign-in and token refresh on top of the raw
 * `Auth` HTTP client.
 */
class AuthManager {
  constructor(
    private readonly auth: Auth,
    private readonly websiteUrl: string,
  ) {}

  /**
   * Starts a browser login and waits for the backend to hand out tokens.
   *
   * `onAuth` is called once with the URL the user must open. Resolves to
   * `{ access, refresh, expires }` once polling succeeds.
   */
  public async login({
    onAuth,
    onProgress,
    signal,
  }: {
    onAuth: OnAuth;
    onProgress?: OnProgress;
    signal?: AbortSignal;
  }) {
    const params = this.generateAuthParams();
    onAuth({
      url: params.loginUrl,
      instructions: "Complete the sign-in in your browser.",
    });
    return await this.pollAuthenticationStatus({
      uuid: params.uuid,
      verifier: params.verifier,
      onProgress,
      signal,
    });
  }

  /**
   * Exchanges stored credentials for a new token pair.
   *
   * The refresh token is preferred. If that exchange fails and an access
   * token is also available, one more attempt is made with the access token
   * alone.
   *
   * @throws Error("No credentials provided") when both tokens are empty
   * @throws Error("Failed to refresh credentials") when every attempt fails
   */
  public async refresh(credentials: {
    access: string;
    refresh: string;
  }): Promise<{ access: string; refresh: string; expires: number }> {
    const hasAnyToken = Boolean(credentials.access || credentials.refresh);
    if (!hasAnyToken) {
      throw new Error("No credentials provided");
    }

    try {
      const tokens = await this.auth.exchangeUserApiKey({
        token: credentials.refresh || credentials.access,
      });
      return {
        access: tokens.accessToken,
        refresh: tokens.refreshToken,
        expires: getTokenExpiry(tokens.accessToken),
      };
    } catch {
      const canRetryWithAccess = Boolean(
        credentials.access && credentials.refresh,
      );
      if (!canRetryWithAccess) {
        throw new Error("Failed to refresh credentials");
      }
      // The refresh token was rejected: retry using only the access token.
      return this.refresh({ access: credentials.access, refresh: "" });
    }
  }

  /**
   * Creates a random verifier, its SHA-256 challenge (base64url), a login
   * uuid, and the browser login URL that embeds the challenge and uuid.
   */
  private generateAuthParams() {
    const verifier = base64URLEncode(randomBytes(32));
    const digest = createHash("sha256").update(verifier).digest();
    const challenge = base64URLEncode(digest);
    const uuid = crypto.randomUUID();
    const query = [
      `challenge=${challenge}`,
      `uuid=${uuid}`,
      "mode=login",
      "redirectTarget=cli",
    ].join("&");
    const loginUrl = `${this.websiteUrl}/loginDeepControl?${query}`;
    return { challenge, uuid, verifier, loginUrl };
  }

  /**
   * Polls `/auth/poll` once per second, up to 150 attempts. An attempt is
   * retried only while it fails with an error mentioning `/auth/poll` and
   * `404`; any other error is rethrown immediately.
   */
  private async pollAuthenticationStatus({
    uuid,
    verifier,
    onProgress,
    signal,
  }: {
    uuid: string;
    verifier: string;
    onProgress?: OAuthLoginCallbacks["onProgress"];
    signal?: AbortSignal | undefined;
  }) {
    const attempt = async () => {
      onProgress?.("Polling authentication status...");
      const { accessToken, refreshToken } = await this.auth.poll({
        uuid,
        verifier,
        signal,
      });
      return {
        access: accessToken,
        refresh: refreshToken,
        expires: getTokenExpiry(accessToken),
      };
    };

    const isStillPending = (error: unknown) =>
      error instanceof Error &&
      error.message.includes("/auth/poll") &&
      error.message.includes("404");

    return backoff(attempt, {
      retries: 150,
      delay: 1000,
      shouldRetry: isStillPending,
    });
  }
}
|
|
100
|
+
|
|
101
|
+
/**
 * Claims decoded from a JWT payload.
 *
 * Produced by an unchecked cast: `decodeJwt` only guarantees a non-array
 * object, so `exp` may be missing or non-numeric at runtime.
 */
type JwtPayload = {
  exp: number;
  [key: string]: unknown;
};

/** Encodes bytes as unpadded base64url (RFC 4648 §5), e.g. for PKCE values. */
export const base64URLEncode = (buffer: Buffer) => {
  return buffer.toString("base64url");
};

/**
 * Decodes (without verifying the signature) the payload segment of a JWT.
 *
 * JWT segments are base64url-encoded UTF-8 JSON. `atob` only understands the
 * standard alphabet, so it rejected any payload containing `-` or `_`, and it
 * mis-decoded non-ASCII claims. `Buffer` handles both properly.
 *
 * @returns the payload object, or null if `token` is not a three-part JWT or
 *   its payload is not a JSON object
 */
export const decodeJwt = (token: string): JwtPayload | null => {
  try {
    const parts = token.split(".");
    if (parts.length !== 3) return null;
    const json = Buffer.from(parts[1] ?? "", "base64url").toString("utf8");
    const payload: unknown = JSON.parse(json);
    if (typeof payload !== "object" || payload === null || Array.isArray(payload)) {
      return null;
    }
    return payload as JwtPayload;
  } catch {
    return null;
  }
};

/** Treat tokens as expired this long before their real `exp`. */
const EXPIRY_SAFETY_MARGIN_MS = 5 * 60 * 1000;
/** Assumed lifetime when the token carries no usable `exp` claim. */
const FALLBACK_TOKEN_LIFETIME_MS = 3600 * 1000;

/**
 * Computes when a token should be considered expired, in epoch milliseconds.
 *
 * Uses the JWT `exp` claim minus a 5-minute margin. Falls back to one hour
 * from now when the token cannot be decoded or `exp` is not a finite number.
 * A malformed claim previously produced `NaN`.
 */
const getTokenExpiry = (token: string): number => {
  const exp = decodeJwt(token)?.exp;
  if (typeof exp === "number" && Number.isFinite(exp)) {
    return exp * 1000 - EXPIRY_SAFETY_MARGIN_MS;
  }
  return Date.now() + FALLBACK_TOKEN_LIFETIME_MS;
};
|
|
134
|
+
|
|
135
|
+
// AuthManager is the module's default export.
export { AuthManager as default };
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
/** Retry policy for `backoff`. */
type BackoffConfig = {
  /** Maximum number of attempts; values below 1 (or NaN) are treated as 1. */
  retries: number;
  /** Fixed pause in milliseconds between consecutive attempts. */
  delay: number;
  /**
   * Decides whether a failed attempt may be retried. Returning false rethrows
   * that error immediately. When omitted, every error is retried.
   */
  shouldRetry?: (error: unknown, attempt: number) => boolean;
};

const sleep = (ms: number): Promise<void> => {
  return new Promise((resolve) => setTimeout(resolve, ms));
};

/**
 * Runs `fn` until it resolves, retrying failures with a fixed delay.
 *
 * Behaviour:
 * - `fn` is always called at least once. A non-positive `retries` used to
 *   reject with `undefined` without calling it.
 * - No delay is spent after the final failed attempt; its error is rethrown
 *   straight away.
 *
 * @returns the first successful result of `fn`
 * @throws the last error from `fn`, or the first error `shouldRetry` rejects
 */
export const backoff = async <T>(
  fn: () => Promise<T>,
  { retries, delay, shouldRetry }: BackoffConfig,
): Promise<T> => {
  const attempts = retries >= 1 ? retries : 1;

  for (let attempt = 0; ; attempt++) {
    try {
      return await fn();
    } catch (error) {
      if (shouldRetry && !shouldRetry(error, attempt)) {
        throw error;
      }
      if (attempt >= attempts - 1) {
        throw error;
      }
      await sleep(delay);
    }
  }
};
|
package/src/lib/heartbeat.ts
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
/**
 * Calls `fn` repeatedly. Each call starts `interval` ms after the previous
 * one resolved, so calls never overlap.
 *
 * If `fn` rejects, the loop ends and is not restarted.
 *
 * @returns a stop function. It is safe to call at any time, including while
 *   `fn` is in flight; no further call is scheduled after it has been
 *   invoked. Previously, stopping during an in-flight call was undone when
 *   that call resolved and re-armed the timer.
 */
export const heartbeat = (
  fn: () => Promise<unknown>,
  interval: number,
): (() => void) => {
  let timeout: ReturnType<typeof setTimeout> | undefined;
  let stopped = false;

  const schedule = () => {
    if (stopped) {
      return;
    }
    timeout = setTimeout(() => {
      timeout = undefined;
      fn().then(schedule, () => {
        // A failed beat ends the loop.
      });
    }, interval);
  };

  schedule();

  return () => {
    stopped = true;
    if (timeout !== undefined) {
      clearTimeout(timeout);
      timeout = undefined;
    }
  };
};
|