pi-llama-cpp 0.3.4 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -7
- package/package.json +4 -5
- package/src/commands/models.ts +41 -4
- package/src/constants.ts +5 -5
- package/src/events.ts +5 -2
- package/src/index.ts +26 -20
- package/src/interfaces/endpoints/models.ts +6 -0
- package/src/manager.ts +93 -0
- package/src/models/baseModel.ts +11 -30
- package/src/models/routerModel.ts +11 -8
- package/src/models/singleModel.ts +10 -0
- package/src/tools/retriever.ts +1 -1
- package/tests/commandManager.test.ts +152 -0
- package/tests/handlers.test.ts +7 -2
- package/tests/modelsCommand.test.ts +270 -0
- package/tests/routerModel.test.ts +25 -84
- package/tests/singleModel.test.ts +16 -29
- package/src/tools/provider.ts +0 -28
package/README.md
CHANGED
|
@@ -99,17 +99,21 @@ llama-server --model path/to/model.gguf ...
|
|
|
99
99
|
```
|
|
100
100
|
|
|
101
101
|
The extension determines the context size as follows:
|
|
102
|
-
- **Router mode**
|
|
102
|
+
- **Router mode**
|
|
103
|
+
- When loaded, reads `meta.n_ctx` from the `/models` endpoint
|
|
104
|
+
- When not loaded, reads `--ctx-size` and/or `--fit-ctx` from the server arguments, or `ctx-size` and/or `fit-ctx` keys from the **presets.ini** file.
|
|
103
105
|
- **Single mode** — reads `meta.n_ctx` from the `/models` endpoint
|
|
104
106
|
- Falls back to `128000` if not available
|
|
105
107
|
|
|
106
108
|
### Commands
|
|
107
109
|
|
|
108
|
-
| Command
|
|
109
|
-
|
|
|
110
|
-
| `/models`
|
|
110
|
+
| Command | Description |
|
|
111
|
+
| ---------------- | ------------------------------------------------------------------------------------------ |
|
|
112
|
+
| `/models` | Browse your models with live status. Select a model to load, switch, or unload it. |
|
|
113
|
+
| `/models info` | Show detailed information for all available models at once. |
|
|
114
|
+
| `/models unload` | Unload all loaded models at once (Note: this only makes sense in router mode). |
|
|
111
115
|
|
|
112
|
-
> **Note:** When the llama.cpp server is unreachable, `/models`
|
|
116
|
+
> **Note:** When the llama.cpp server is unreachable, `/models` displays an error notification with the configured server URL.
|
|
113
117
|
|
|
114
118
|
### Model Actions
|
|
115
119
|
|
|
@@ -139,9 +143,9 @@ When you trigger a load, switch, or retry action, the extension polls the server
|
|
|
139
143
|
|
|
140
144
|
Each model exposed to Pi includes the following defaults:
|
|
141
145
|
|
|
142
|
-
- **`maxTokens`** —
|
|
146
|
+
- **`maxTokens`** — dynamically set to the model's context window (detected from llama-server)
|
|
143
147
|
- **`reasoning`** — `true` (assumed, as llama.cpp's `/models` endpoint does not expose it)
|
|
144
|
-
- **`cost`** — all zero (local
|
|
148
|
+
- **`cost`** — all zero (local models)
|
|
145
149
|
|
|
146
150
|
## Dependencies
|
|
147
151
|
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "pi-llama-cpp",
|
|
3
|
-
"version": "0.3.4",
|
|
3
|
+
"version": "0.5.0",
|
|
4
4
|
"description": "Pi extension for llama.cpp integration. Supports both router and single modes.",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"pi",
|
|
@@ -24,8 +24,7 @@
|
|
|
24
24
|
]
|
|
25
25
|
},
|
|
26
26
|
"scripts": {
|
|
27
|
-
"test": "vitest",
|
|
28
|
-
"test:run": "vitest run"
|
|
27
|
+
"test": "vitest run"
|
|
29
28
|
},
|
|
30
29
|
"prettier": {
|
|
31
30
|
"plugins": [
|
|
@@ -36,8 +35,8 @@
|
|
|
36
35
|
"@earendil-works/pi-coding-agent": "*"
|
|
37
36
|
},
|
|
38
37
|
"devDependencies": {
|
|
39
|
-
"@types/node": "^25.
|
|
38
|
+
"@types/node": "^25.9.1",
|
|
40
39
|
"prettier-plugin-organize-imports": "^4.3.0",
|
|
41
|
-
"vitest": "^4.1.
|
|
40
|
+
"vitest": "^4.1.7"
|
|
42
41
|
}
|
|
43
42
|
}
|
package/src/commands/models.ts
CHANGED
|
@@ -1,14 +1,43 @@
|
|
|
1
1
|
import type {
|
|
2
2
|
ExtensionAPI,
|
|
3
3
|
ExtensionCommandContext,
|
|
4
|
+
ExtensionContext,
|
|
5
|
+
SessionBeforeSwitchEvent,
|
|
4
6
|
} from "@earendil-works/pi-coding-agent";
|
|
5
|
-
import { PROVIDER_ID, PROVIDER_NAME } from "../constants";
|
|
7
|
+
import { PROVIDER_ID, PROVIDER_NAME, READABLE_TIMEOUT } from "../constants";
|
|
6
8
|
import { Action } from "../enums/action";
|
|
7
9
|
import { Mode } from "../enums/mode";
|
|
8
10
|
import { Status } from "../enums/status";
|
|
9
11
|
import { BaseModel } from "../models/baseModel";
|
|
10
12
|
import { resolveUrl } from "../tools/resolver";
|
|
11
13
|
|
|
14
|
+
// In-flight model reference — handler gates on this.
|
|
15
|
+
let inflightModel: BaseModel | null = null;
|
|
16
|
+
|
|
17
|
+
export const resetInflightModel = () => (inflightModel = null);
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* Session-switch handler. Registered once at extension init.
|
|
21
|
+
* Only notifies if a model load is actually in-flight.
|
|
22
|
+
*/
|
|
23
|
+
export const onSessionBeforeSwitch = async (
|
|
24
|
+
_event: SessionBeforeSwitchEvent,
|
|
25
|
+
ctx: ExtensionContext,
|
|
26
|
+
) => {
|
|
27
|
+
if (!inflightModel) return;
|
|
28
|
+
|
|
29
|
+
const messages = [
|
|
30
|
+
`Session change detected while model '${inflightModel.name}' was still loading.`,
|
|
31
|
+
"Model load will continue in the background, but UI might not update.",
|
|
32
|
+
"",
|
|
33
|
+
"Verify that your new model is loaded, or use /models to re-select it afterwards.",
|
|
34
|
+
];
|
|
35
|
+
ctx.ui.notify(messages.join("\n"), "warning");
|
|
36
|
+
|
|
37
|
+
// Show the notification for a reasonable amount of time
|
|
38
|
+
await new Promise((r) => setTimeout(r, READABLE_TIMEOUT));
|
|
39
|
+
};
|
|
40
|
+
|
|
12
41
|
/**
|
|
13
42
|
* Select a model from the list. Returns null if user cancels.
|
|
14
43
|
*
|
|
@@ -130,8 +159,10 @@ export const notFoundCommand = async (
|
|
|
130
159
|
/**
|
|
131
160
|
* Handles the /models command
|
|
132
161
|
*
|
|
162
|
+
* @param args Arguments passed to the command
|
|
133
163
|
* @param ctx The context used by Pi
|
|
134
164
|
* @param pi The Pi extension
|
|
165
|
+
* @param models List of available models
|
|
135
166
|
*/
|
|
136
167
|
export const modelsCommand = async (
|
|
137
168
|
ctx: ExtensionCommandContext,
|
|
@@ -165,6 +196,7 @@ export const modelsCommand = async (
|
|
|
165
196
|
const loadActions = [Action.LOAD, Action.SWITCH, Action.RETRY];
|
|
166
197
|
if (loadActions.includes(action)) {
|
|
167
198
|
ctx.ui.notify(`Loading ${model.name}...`, "info");
|
|
199
|
+
inflightModel = model;
|
|
168
200
|
|
|
169
201
|
const onSuccess = async () => {
|
|
170
202
|
const piModel = ctx.modelRegistry.find(PROVIDER_ID, model.id);
|
|
@@ -173,7 +205,7 @@ export const modelsCommand = async (
|
|
|
173
205
|
}
|
|
174
206
|
|
|
175
207
|
if ((await model.getStatus()) === Status.FAILED) {
|
|
176
|
-
throw new Error(
|
|
208
|
+
throw new Error(`Failed to load model ${model.name}`);
|
|
177
209
|
}
|
|
178
210
|
|
|
179
211
|
await pi.setModel(piModel);
|
|
@@ -182,10 +214,15 @@ export const modelsCommand = async (
|
|
|
182
214
|
|
|
183
215
|
const onFailure = (err: any) => {
|
|
184
216
|
const message = err instanceof Error ? err.message : String(err);
|
|
185
|
-
|
|
217
|
+
|
|
218
|
+
try {
|
|
219
|
+
ctx.ui.notify(message, "error");
|
|
220
|
+
} catch {
|
|
221
|
+
// ctx went stale between error and notification
|
|
222
|
+
}
|
|
186
223
|
};
|
|
187
224
|
|
|
188
225
|
// Load the model without blocking the UI
|
|
189
|
-
model.load().then(onSuccess).catch(onFailure);
|
|
226
|
+
model.load().then(onSuccess).catch(onFailure).finally(resetInflightModel);
|
|
190
227
|
}
|
|
191
228
|
};
|
package/src/constants.ts
CHANGED
|
@@ -23,11 +23,6 @@ export const API_KEY_PLACEHOLDER = "sk-placeholder";
|
|
|
23
23
|
*/
|
|
24
24
|
export const DEFAULT_CTX = 128000;
|
|
25
25
|
|
|
26
|
-
/**
|
|
27
|
-
* Maximum number of tokens a model can generate in a single response
|
|
28
|
-
*/
|
|
29
|
-
export const MAX_TOKENS = 32000;
|
|
30
|
-
|
|
31
26
|
/**
|
|
32
27
|
* Polling interval (ms) for checking model load status
|
|
33
28
|
*/
|
|
@@ -37,3 +32,8 @@ export const POLLING_INTERVAL = 500;
|
|
|
37
32
|
* Maximum time (ms) to wait for model loading before giving up
|
|
38
33
|
*/
|
|
39
34
|
export const POLLING_TIMEOUT = 60000;
|
|
35
|
+
|
|
36
|
+
/**
|
|
37
|
+
* Reasonable time to read notifications if context goes stale
|
|
38
|
+
*/
|
|
39
|
+
export const READABLE_TIMEOUT = 15000;
|
package/src/events.ts
CHANGED
|
@@ -18,6 +18,9 @@ export const onModelSelect = async (
|
|
|
18
18
|
const model = models.find((m) => m.id === event.model.id);
|
|
19
19
|
if (!model) return;
|
|
20
20
|
|
|
21
|
-
ctx.ui.notify(
|
|
22
|
-
await model
|
|
21
|
+
ctx.ui.notify(`Loading ${model.name}...`, "info");
|
|
22
|
+
await model
|
|
23
|
+
.load()
|
|
24
|
+
.then(() => ctx.ui.notify(`Model ${model.name} ready`, "info"))
|
|
25
|
+
.catch(() => ctx.ui.notify(`Failed to load model ${model.name}`, "error"));
|
|
23
26
|
};
|
package/src/index.ts
CHANGED
|
@@ -2,35 +2,41 @@ import type {
|
|
|
2
2
|
ExtensionAPI,
|
|
3
3
|
ExtensionCommandContext,
|
|
4
4
|
} from "@earendil-works/pi-coding-agent";
|
|
5
|
-
import {
|
|
5
|
+
import type { AutocompleteItem } from "@earendil-works/pi-tui";
|
|
6
|
+
import { onSessionBeforeSwitch } from "./commands/models";
|
|
6
7
|
import { PROVIDER_NAME } from "./constants";
|
|
7
8
|
import { onModelSelect } from "./events";
|
|
8
|
-
import {
|
|
9
|
-
import { isServerReady } from "./tools/retriever";
|
|
9
|
+
import { CommandManager } from "./manager";
|
|
10
10
|
|
|
11
11
|
export default async function (pi: ExtensionAPI) {
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
pi.registerCommand("models", {
|
|
15
|
-
description: `${PROVIDER_NAME} models (offline)`,
|
|
16
|
-
handler: async (_: string, ctx: ExtensionCommandContext) => {
|
|
17
|
-
await notFoundCommand(ctx);
|
|
18
|
-
},
|
|
19
|
-
});
|
|
20
|
-
|
|
21
|
-
return;
|
|
22
|
-
}
|
|
23
|
-
|
|
24
|
-
// Provider registration
|
|
25
|
-
const serverModels = await registerLlamaCppProvider(pi);
|
|
12
|
+
const manager = new CommandManager(pi);
|
|
13
|
+
await manager.initialize();
|
|
26
14
|
|
|
27
15
|
// Command: /models
|
|
28
16
|
pi.registerCommand("models", {
|
|
29
|
-
description: `Browse ${PROVIDER_NAME} models
|
|
30
|
-
|
|
31
|
-
|
|
17
|
+
description: `Browse ${PROVIDER_NAME} models`,
|
|
18
|
+
getArgumentCompletions: (prefix: string): AutocompleteItem[] | null => {
|
|
19
|
+
const available = [
|
|
20
|
+
{
|
|
21
|
+
value: "info",
|
|
22
|
+
label: "info",
|
|
23
|
+
description: "Show information of all models",
|
|
24
|
+
},
|
|
25
|
+
{
|
|
26
|
+
value: "unload",
|
|
27
|
+
label: "unload",
|
|
28
|
+
description: "Unload all models",
|
|
29
|
+
},
|
|
30
|
+
];
|
|
31
|
+
|
|
32
|
+
const filtered = available.filter((a) => a.value.startsWith(prefix));
|
|
33
|
+
return filtered.length > 0 ? filtered : null;
|
|
34
|
+
},
|
|
35
|
+
handler: async (args: string, ctx: ExtensionCommandContext) =>
|
|
36
|
+
await manager.run(args, ctx),
|
|
32
37
|
});
|
|
33
38
|
|
|
34
39
|
// Events registration
|
|
35
40
|
pi.on("model_select", onModelSelect);
|
|
41
|
+
pi.on("session_before_switch", onSessionBeforeSwitch);
|
|
36
42
|
}
|
|
@@ -39,6 +39,7 @@ export interface DataProperty {
|
|
|
39
39
|
owned_by: string;
|
|
40
40
|
created: number;
|
|
41
41
|
status?: StatusProperty;
|
|
42
|
+
architecture?: ArchitectureProperty;
|
|
42
43
|
meta?: MetaProperty;
|
|
43
44
|
}
|
|
44
45
|
|
|
@@ -50,6 +51,11 @@ interface StatusProperty {
|
|
|
50
51
|
failed?: boolean;
|
|
51
52
|
}
|
|
52
53
|
|
|
54
|
+
interface ArchitectureProperty {
|
|
55
|
+
input_modalities: ("text" | "image" | "audio")[];
|
|
56
|
+
output_modalities: ["text"];
|
|
57
|
+
}
|
|
58
|
+
|
|
53
59
|
interface MetaProperty {
|
|
54
60
|
vocab_type: number;
|
|
55
61
|
n_vocab: number;
|
package/src/manager.ts
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import type {
|
|
2
|
+
ExtensionAPI,
|
|
3
|
+
ExtensionCommandContext,
|
|
4
|
+
ProviderModelConfig,
|
|
5
|
+
} from "@earendil-works/pi-coding-agent";
|
|
6
|
+
import { modelsCommand, notFoundCommand } from "./commands/models";
|
|
7
|
+
import {
|
|
8
|
+
DEFAULT_LLAMA_SERVER_URL,
|
|
9
|
+
PROVIDER_ID,
|
|
10
|
+
PROVIDER_NAME,
|
|
11
|
+
} from "./constants";
|
|
12
|
+
import { BaseModel } from "./models/baseModel";
|
|
13
|
+
import { resolveApiKey, resolveUrl } from "./tools/resolver";
|
|
14
|
+
import { isServerReady, listModels } from "./tools/retriever";
|
|
15
|
+
|
|
16
|
+
export class CommandManager {
|
|
17
|
+
private baseUrl: string = DEFAULT_LLAMA_SERVER_URL;
|
|
18
|
+
private serverModels: BaseModel[] = [];
|
|
19
|
+
|
|
20
|
+
constructor(private readonly pi: ExtensionAPI) {}
|
|
21
|
+
|
|
22
|
+
/**
|
|
23
|
+
* Sets up the initial state of the provider
|
|
24
|
+
*/
|
|
25
|
+
async initialize() {
|
|
26
|
+
if (await isServerReady()) {
|
|
27
|
+
await this.update();
|
|
28
|
+
} else {
|
|
29
|
+
await this.register([]);
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
/**
|
|
34
|
+
* Ensures the models are up-to-date with the server
|
|
35
|
+
*/
|
|
36
|
+
async update() {
|
|
37
|
+
this.baseUrl = `${await resolveUrl(process.cwd())}`;
|
|
38
|
+
|
|
39
|
+
this.serverModels = await listModels();
|
|
40
|
+
const modelConfigs = await Promise.all(
|
|
41
|
+
this.serverModels.map((m) => m.toProviderConfig()),
|
|
42
|
+
);
|
|
43
|
+
|
|
44
|
+
await this.register(modelConfigs);
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
/**
|
|
48
|
+
* Registers the provider in Pi with the given configurations
|
|
49
|
+
* Note: Registrations overload previous provider
|
|
50
|
+
*
|
|
51
|
+
* @param models Provider configurations for the models
|
|
52
|
+
*/
|
|
53
|
+
async register(models: ProviderModelConfig[]) {
|
|
54
|
+
this.pi.registerProvider(PROVIDER_ID, {
|
|
55
|
+
name: PROVIDER_NAME,
|
|
56
|
+
baseUrl: this.baseUrl,
|
|
57
|
+
api: "openai-completions",
|
|
58
|
+
apiKey: await resolveApiKey(),
|
|
59
|
+
models,
|
|
60
|
+
});
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* Dispatches the /models command
|
|
65
|
+
*
|
|
66
|
+
* @param args Arguments passed to the command
|
|
67
|
+
* @param ctx The context used by Pi
|
|
68
|
+
* @param pi The Pi extension
|
|
69
|
+
*/
|
|
70
|
+
async run(args: string, ctx: ExtensionCommandContext) {
|
|
71
|
+
if (!(await isServerReady())) {
|
|
72
|
+
return await notFoundCommand(ctx);
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
// Command: `/models info`
|
|
76
|
+
if (args === "info") {
|
|
77
|
+
const info = await Promise.all(this.serverModels.map((m) => m.getInfo()));
|
|
78
|
+
const message = ctx.ui.theme.fg("accent", info.join("\n"));
|
|
79
|
+
ctx.ui.notify(message, "info");
|
|
80
|
+
return;
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
// Command: `/models unload`
|
|
84
|
+
if (args === "unload") {
|
|
85
|
+
await Promise.all(this.serverModels.map((m) => m.unload()));
|
|
86
|
+
ctx.ui.notify(`Unloaded all ${PROVIDER_NAME} models`, "info");
|
|
87
|
+
return;
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
// Command: `/models` (interactive menu)
|
|
91
|
+
return await modelsCommand(ctx, this.pi, this.serverModels);
|
|
92
|
+
}
|
|
93
|
+
}
|
package/src/models/baseModel.ts
CHANGED
|
@@ -1,13 +1,8 @@
|
|
|
1
1
|
import type { ProviderModelConfig } from "@earendil-works/pi-coding-agent";
|
|
2
|
-
import {
|
|
3
|
-
DEFAULT_CTX,
|
|
4
|
-
MAX_TOKENS,
|
|
5
|
-
POLLING_INTERVAL,
|
|
6
|
-
POLLING_TIMEOUT,
|
|
7
|
-
} from "../constants";
|
|
2
|
+
import { POLLING_INTERVAL, POLLING_TIMEOUT } from "../constants";
|
|
8
3
|
import { Mode } from "../enums/mode";
|
|
9
4
|
import { Status } from "../enums/status";
|
|
10
|
-
import { DataProperty } from "../interfaces/endpoints/models";
|
|
5
|
+
import { DataProperty, ModelsEndpoint } from "../interfaces/endpoints/models";
|
|
11
6
|
import { PropsEndpoint } from "../interfaces/endpoints/props";
|
|
12
7
|
import { rpc } from "../tools/retriever";
|
|
13
8
|
|
|
@@ -55,17 +50,7 @@ export abstract class BaseModel {
|
|
|
55
50
|
*
|
|
56
51
|
* @returns An array of capabilities, as expected by Pi
|
|
57
52
|
*/
|
|
58
|
-
|
|
59
|
-
try {
|
|
60
|
-
const { modalities } = await rpc<PropsEndpoint>(
|
|
61
|
-
`/props?model=${this.id}`,
|
|
62
|
-
);
|
|
63
|
-
|
|
64
|
-
return modalities.vision ? ["image"] : ["text"];
|
|
65
|
-
} catch {
|
|
66
|
-
return ["text"];
|
|
67
|
-
}
|
|
68
|
-
}
|
|
53
|
+
abstract getCapabilities(): Promise<("text" | "image")[]>;
|
|
69
54
|
|
|
70
55
|
/**
|
|
71
56
|
* Gets the load status of the model
|
|
@@ -75,7 +60,7 @@ export abstract class BaseModel {
|
|
|
75
60
|
public async getStatus(): Promise<Status> {
|
|
76
61
|
try {
|
|
77
62
|
const { is_sleeping, error } = await rpc<PropsEndpoint>(
|
|
78
|
-
`/props?model=${this.id}`,
|
|
63
|
+
`/props?model=${this.id}&autoload=false`,
|
|
79
64
|
);
|
|
80
65
|
|
|
81
66
|
if (is_sleeping) return Status.SLEEPING;
|
|
@@ -96,15 +81,10 @@ export abstract class BaseModel {
|
|
|
96
81
|
* @returns The detected context size
|
|
97
82
|
*/
|
|
98
83
|
async getContextSize(): Promise<number> {
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
const { n_ctx } = default_generation_settings;
|
|
104
|
-
return n_ctx;
|
|
105
|
-
} catch {
|
|
106
|
-
return DEFAULT_CTX;
|
|
107
|
-
}
|
|
84
|
+
const { data } = await rpc<ModelsEndpoint>("/models");
|
|
85
|
+
const { n_ctx } = data.find((m) => m.id === this.id)?.meta!;
|
|
86
|
+
|
|
87
|
+
return n_ctx;
|
|
108
88
|
}
|
|
109
89
|
|
|
110
90
|
/**
|
|
@@ -147,7 +127,7 @@ export abstract class BaseModel {
|
|
|
147
127
|
input: await this.getCapabilities(),
|
|
148
128
|
contextWindow: await this.getContextSize(),
|
|
149
129
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
|
150
|
-
maxTokens:
|
|
130
|
+
maxTokens: await this.getContextSize(),
|
|
151
131
|
};
|
|
152
132
|
|
|
153
133
|
return response;
|
|
@@ -157,7 +137,8 @@ export abstract class BaseModel {
|
|
|
157
137
|
* Loads the model in llama-server
|
|
158
138
|
*/
|
|
159
139
|
async load(): Promise<void> {
|
|
160
|
-
|
|
140
|
+
const status = await this.getStatus();
|
|
141
|
+
if (status === Status.LOADED || status === Status.SLEEPING) return;
|
|
161
142
|
|
|
162
143
|
await rpc("/models/load", { model: this.id });
|
|
163
144
|
await this.pollStatus();
|
|
@@ -50,7 +50,7 @@ export class RouterModel extends BaseModel {
|
|
|
50
50
|
// Grab the glitch
|
|
51
51
|
while (Date.now() - startTime <= limit) {
|
|
52
52
|
try {
|
|
53
|
-
await rpc<PropsEndpoint>(`/props?model=${this.id}`);
|
|
53
|
+
await rpc<PropsEndpoint>(`/props?model=${this.id}&autoload=false`);
|
|
54
54
|
break;
|
|
55
55
|
} catch {
|
|
56
56
|
elapsed += POLLING_INTERVAL;
|
|
@@ -62,14 +62,17 @@ export class RouterModel extends BaseModel {
|
|
|
62
62
|
return await super.pollStatus(startTime, timeout);
|
|
63
63
|
}
|
|
64
64
|
|
|
65
|
-
async getCapabilities(): Promise<
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
}
|
|
65
|
+
async getCapabilities(): Promise<("text" | "image")[]> {
|
|
66
|
+
const { data } = await rpc<ModelsEndpoint>(`/models`);
|
|
67
|
+
const model = data.find((d) => d.id === this.id);
|
|
68
|
+
if (!model) return ["text"];
|
|
70
69
|
|
|
71
|
-
const
|
|
72
|
-
|
|
70
|
+
const { input_modalities } = model.architecture!;
|
|
71
|
+
const response = input_modalities.filter(
|
|
72
|
+
(mod) => mod === "text" || mod === "image",
|
|
73
|
+
);
|
|
74
|
+
|
|
75
|
+
return response;
|
|
73
76
|
}
|
|
74
77
|
|
|
75
78
|
async getContextSize(): Promise<number> {
|
|
@@ -1,8 +1,18 @@
|
|
|
1
1
|
import { Mode } from "../enums/mode";
|
|
2
|
+
import { ModelsEndpoint } from "../interfaces/endpoints/models";
|
|
3
|
+
import { rpc } from "../tools/retriever";
|
|
2
4
|
import { BaseModel } from "./baseModel";
|
|
3
5
|
|
|
4
6
|
export class SingleModel extends BaseModel {
|
|
5
7
|
get mode(): Mode {
|
|
6
8
|
return Mode.SINGLE;
|
|
7
9
|
}
|
|
10
|
+
|
|
11
|
+
async getCapabilities(): Promise<("text" | "image")[]> {
|
|
12
|
+
const { models } = await rpc<ModelsEndpoint>(`/models`);
|
|
13
|
+
const [model] = models!;
|
|
14
|
+
|
|
15
|
+
const hasImage = model.capabilities.includes("multimodal");
|
|
16
|
+
return hasImage ? ["text", "image"] : ["text"];
|
|
17
|
+
}
|
|
8
18
|
}
|
package/src/tools/retriever.ts
CHANGED
|
@@ -28,7 +28,7 @@ export const isServerReady = async (): Promise<boolean> => {
|
|
|
28
28
|
export const rpc = async <T>(
|
|
29
29
|
endpoint: string,
|
|
30
30
|
body?: Record<string, unknown>,
|
|
31
|
-
) => {
|
|
31
|
+
): Promise<T> => {
|
|
32
32
|
const base = await resolveUrl(process.cwd());
|
|
33
33
|
const url = `${base}${endpoint}`;
|
|
34
34
|
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import { beforeEach, describe, expect, it, vi } from "vitest";
|
|
2
|
+
import { PROVIDER_ID, PROVIDER_NAME } from "../src/constants";
|
|
3
|
+
import { CommandManager } from "../src/manager";
|
|
4
|
+
|
|
5
|
+
// Mock modules at top level (vi.mock is hoisted)
|
|
6
|
+
vi.mock("../src/tools/retriever", () => ({
|
|
7
|
+
isServerReady: vi.fn(),
|
|
8
|
+
listModels: vi.fn(),
|
|
9
|
+
}));
|
|
10
|
+
|
|
11
|
+
vi.mock("../src/tools/resolver", () => ({
|
|
12
|
+
resolveUrl: vi.fn(),
|
|
13
|
+
resolveApiKey: vi.fn(),
|
|
14
|
+
}));
|
|
15
|
+
|
|
16
|
+
// Import mocked functions after vi.mock
|
|
17
|
+
import { resolveApiKey, resolveUrl } from "../src/tools/resolver";
|
|
18
|
+
import { isServerReady, listModels } from "../src/tools/retriever";
|
|
19
|
+
|
|
20
|
+
const mockPi = {
|
|
21
|
+
registerProvider: vi.fn(),
|
|
22
|
+
};
|
|
23
|
+
|
|
24
|
+
beforeEach(() => {
|
|
25
|
+
vi.clearAllMocks();
|
|
26
|
+
(resolveUrl as any).mockResolvedValue("http://127.0.0.1:8080");
|
|
27
|
+
(resolveApiKey as any).mockResolvedValue("test-key");
|
|
28
|
+
});
|
|
29
|
+
|
|
30
|
+
describe("CommandManager", () => {
|
|
31
|
+
it("should register empty models when server is not ready", async () => {
|
|
32
|
+
(isServerReady as any).mockResolvedValue(false);
|
|
33
|
+
|
|
34
|
+
const manager = new CommandManager(mockPi as any);
|
|
35
|
+
await manager.initialize();
|
|
36
|
+
|
|
37
|
+
expect(mockPi.registerProvider).toHaveBeenCalledWith(PROVIDER_ID, {
|
|
38
|
+
name: PROVIDER_NAME,
|
|
39
|
+
baseUrl: "http://127.0.0.1:8080",
|
|
40
|
+
api: "openai-completions",
|
|
41
|
+
apiKey: "test-key",
|
|
42
|
+
models: [],
|
|
43
|
+
});
|
|
44
|
+
});
|
|
45
|
+
|
|
46
|
+
it("should update and register models when server is ready", async () => {
|
|
47
|
+
const mockModel = {
|
|
48
|
+
name: "test-model",
|
|
49
|
+
id: "test-model",
|
|
50
|
+
toProviderConfig: vi
|
|
51
|
+
.fn()
|
|
52
|
+
.mockResolvedValue({ id: "test-model", maxTokens: 32000 }),
|
|
53
|
+
};
|
|
54
|
+
(isServerReady as any).mockResolvedValue(true);
|
|
55
|
+
(listModels as any).mockResolvedValue([mockModel]);
|
|
56
|
+
|
|
57
|
+
const manager = new CommandManager(mockPi as any);
|
|
58
|
+
await manager.initialize();
|
|
59
|
+
|
|
60
|
+
expect(resolveUrl).toHaveBeenCalledWith(expect.any(String));
|
|
61
|
+
expect(listModels).toHaveBeenCalled();
|
|
62
|
+
expect(mockPi.registerProvider).toHaveBeenCalledWith(PROVIDER_ID, {
|
|
63
|
+
name: PROVIDER_NAME,
|
|
64
|
+
baseUrl: "http://127.0.0.1:8080",
|
|
65
|
+
api: "openai-completions",
|
|
66
|
+
apiKey: "test-key",
|
|
67
|
+
models: [{ id: "test-model", maxTokens: 32000 }],
|
|
68
|
+
});
|
|
69
|
+
});
|
|
70
|
+
|
|
71
|
+
it("should call notFoundCommand when server is not ready in run()", async () => {
|
|
72
|
+
(isServerReady as any).mockResolvedValue(false);
|
|
73
|
+
|
|
74
|
+
const manager = new CommandManager(mockPi as any);
|
|
75
|
+
await manager.run("", { ui: { notify: vi.fn() } } as any);
|
|
76
|
+
|
|
77
|
+
expect(mockPi.registerProvider).not.toHaveBeenCalled();
|
|
78
|
+
});
|
|
79
|
+
|
|
80
|
+
it("should show info for all models when args is 'info'", async () => {
|
|
81
|
+
const mockModel = {
|
|
82
|
+
name: "test-model",
|
|
83
|
+
id: "test-model",
|
|
84
|
+
getInfo: vi.fn().mockResolvedValue("Model info for test-model"),
|
|
85
|
+
toProviderConfig: vi.fn().mockResolvedValue({ id: "test-model" }),
|
|
86
|
+
};
|
|
87
|
+
(isServerReady as any).mockResolvedValue(true);
|
|
88
|
+
(listModels as any).mockResolvedValue([mockModel]);
|
|
89
|
+
|
|
90
|
+
const notifyFn = vi.fn();
|
|
91
|
+
const manager = new CommandManager(mockPi as any);
|
|
92
|
+
await manager.initialize();
|
|
93
|
+
await manager.run("info", {
|
|
94
|
+
ui: { notify: notifyFn, theme: { fg: (_c: string, t: string) => t } },
|
|
95
|
+
} as any);
|
|
96
|
+
|
|
97
|
+
expect(notifyFn).toHaveBeenCalledWith("Model info for test-model", "info");
|
|
98
|
+
expect(listModels).toHaveBeenCalledOnce();
|
|
99
|
+
});
|
|
100
|
+
|
|
101
|
+
it("should unload all models when args is 'unload'", async () => {
|
|
102
|
+
const mockModel1 = {
|
|
103
|
+
name: "model-1",
|
|
104
|
+
id: "model-1",
|
|
105
|
+
unload: vi.fn().mockResolvedValue(undefined),
|
|
106
|
+
toProviderConfig: vi.fn().mockResolvedValue({ id: "model-1" }),
|
|
107
|
+
};
|
|
108
|
+
const mockModel2 = {
|
|
109
|
+
name: "model-2",
|
|
110
|
+
id: "model-2",
|
|
111
|
+
unload: vi.fn().mockResolvedValue(undefined),
|
|
112
|
+
toProviderConfig: vi.fn().mockResolvedValue({ id: "model-2" }),
|
|
113
|
+
};
|
|
114
|
+
(isServerReady as any).mockResolvedValue(true);
|
|
115
|
+
(listModels as any).mockResolvedValue([mockModel1, mockModel2]);
|
|
116
|
+
|
|
117
|
+
const notifyFn = vi.fn();
|
|
118
|
+
const manager = new CommandManager(mockPi as any);
|
|
119
|
+
await manager.initialize();
|
|
120
|
+
await manager.run("unload", {
|
|
121
|
+
ui: { notify: notifyFn },
|
|
122
|
+
} as any);
|
|
123
|
+
|
|
124
|
+
expect(mockModel1.unload).toHaveBeenCalled();
|
|
125
|
+
expect(mockModel2.unload).toHaveBeenCalled();
|
|
126
|
+
expect(notifyFn).toHaveBeenCalledWith(
|
|
127
|
+
"Unloaded all Llama.cpp models",
|
|
128
|
+
"info",
|
|
129
|
+
);
|
|
130
|
+
});
|
|
131
|
+
|
|
132
|
+
it("should dispatch modelsCommand when args is empty", async () => {
|
|
133
|
+
const mockModel = {
|
|
134
|
+
name: "test-model",
|
|
135
|
+
id: "test-model",
|
|
136
|
+
getLabel: vi.fn().mockResolvedValue("test-model"),
|
|
137
|
+
toProviderConfig: vi.fn().mockResolvedValue({ id: "test-model" }),
|
|
138
|
+
};
|
|
139
|
+
(isServerReady as any).mockResolvedValue(true);
|
|
140
|
+
(listModels as any).mockResolvedValue([mockModel]);
|
|
141
|
+
|
|
142
|
+
const selectFn = vi.fn().mockReturnValue(null); // cancel immediately
|
|
143
|
+
const manager = new CommandManager(mockPi as any);
|
|
144
|
+
await manager.initialize();
|
|
145
|
+
await manager.run("", {
|
|
146
|
+
ui: { notify: vi.fn(), select: selectFn },
|
|
147
|
+
} as any);
|
|
148
|
+
|
|
149
|
+
// modelsCommand was called (select is invoked for model picking)
|
|
150
|
+
expect(selectFn).toHaveBeenCalled();
|
|
151
|
+
});
|
|
152
|
+
});
|
package/tests/handlers.test.ts
CHANGED
|
@@ -62,7 +62,12 @@ const getActionsForModel = async (model: TestModel): Promise<Array<Action>> => {
|
|
|
62
62
|
[Status.LOADED]: [Action.SWITCH, Action.UNLOAD, Action.INFO, Action.CANCEL],
|
|
63
63
|
[Status.LOADING]: [Action.INFO, Action.CANCEL],
|
|
64
64
|
[Status.FAILED]: [Action.RETRY, Action.CANCEL],
|
|
65
|
-
[Status.SLEEPING]: [
|
|
65
|
+
[Status.SLEEPING]: [
|
|
66
|
+
Action.SWITCH,
|
|
67
|
+
Action.UNLOAD,
|
|
68
|
+
Action.INFO,
|
|
69
|
+
Action.CANCEL,
|
|
70
|
+
],
|
|
66
71
|
[Status.UNLOADED]: [Action.LOAD, Action.CANCEL],
|
|
67
72
|
};
|
|
68
73
|
|
|
@@ -106,7 +111,7 @@ describe("Action availability", () => {
|
|
|
106
111
|
{
|
|
107
112
|
mode: Mode.ROUTER,
|
|
108
113
|
status: Status.SLEEPING,
|
|
109
|
-
expected: [Action.UNLOAD, Action.INFO, Action.CANCEL],
|
|
114
|
+
expected: [Action.SWITCH, Action.UNLOAD, Action.INFO, Action.CANCEL],
|
|
110
115
|
},
|
|
111
116
|
{
|
|
112
117
|
mode: Mode.ROUTER,
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
|
2
|
+
|
|
3
|
+
// Set up fake timers before any imports so setTimeout is mocked globally
|
|
4
|
+
vi.useFakeTimers();
|
|
5
|
+
|
|
6
|
+
import type { ExtensionContext } from "@earendil-works/pi-coding-agent";
|
|
7
|
+
import {
|
|
8
|
+
modelsCommand,
|
|
9
|
+
onSessionBeforeSwitch,
|
|
10
|
+
resetInflightModel,
|
|
11
|
+
} from "../src/commands/models";
|
|
12
|
+
import { Action } from "../src/enums/action";
|
|
13
|
+
import { Mode } from "../src/enums/mode";
|
|
14
|
+
import { Status } from "../src/enums/status";
|
|
15
|
+
import { BaseModel } from "../src/models/baseModel";
|
|
16
|
+
|
|
17
|
+
// Mock the retriever module
|
|
18
|
+
vi.mock("../src/tools/retriever", () => ({
|
|
19
|
+
rpc: vi.fn(),
|
|
20
|
+
isServerReady: vi.fn(),
|
|
21
|
+
listModels: vi.fn(),
|
|
22
|
+
}));
|
|
23
|
+
|
|
24
|
+
// Helper to create a mock BaseModel
|
|
25
|
+
const createMockModel = (
|
|
26
|
+
name: string,
|
|
27
|
+
overrides: Partial<BaseModel> = {},
|
|
28
|
+
): BaseModel =>
|
|
29
|
+
({
|
|
30
|
+
name,
|
|
31
|
+
id: name,
|
|
32
|
+
mode: Mode.ROUTER,
|
|
33
|
+
capabilities: ["text"] as ["text"],
|
|
34
|
+
getStatus: vi.fn().mockResolvedValue(Status.LOADED),
|
|
35
|
+
getContextSize: vi.fn().mockResolvedValue(4096),
|
|
36
|
+
getInfo: vi.fn().mockResolvedValue(`Model: ${name}\nID: ${name}`),
|
|
37
|
+
load: vi.fn().mockResolvedValue(undefined),
|
|
38
|
+
unload: vi.fn().mockResolvedValue(undefined),
|
|
39
|
+
toProviderConfig: vi.fn().mockResolvedValue({}),
|
|
40
|
+
getLabel: vi.fn().mockResolvedValue(name),
|
|
41
|
+
...overrides,
|
|
42
|
+
}) as unknown as BaseModel;
|
|
43
|
+
|
|
44
|
+
const createMockCtx = (
|
|
45
|
+
selectFn: (prompt: string, options: string[]) => string | null,
|
|
46
|
+
) => ({
|
|
47
|
+
cwd: "/tmp/test",
|
|
48
|
+
ui: {
|
|
49
|
+
select: vi.fn(selectFn),
|
|
50
|
+
notify: vi.fn(),
|
|
51
|
+
theme: {
|
|
52
|
+
fg: (color: string, text: string) => text,
|
|
53
|
+
},
|
|
54
|
+
},
|
|
55
|
+
modelRegistry: {
|
|
56
|
+
find: vi.fn().mockReturnValue({ id: "test-model-id" }),
|
|
57
|
+
},
|
|
58
|
+
});
|
|
59
|
+
|
|
60
|
+
const createMockPiContext = (notifyFn: ReturnType<typeof vi.fn>) =>
|
|
61
|
+
({
|
|
62
|
+
ui: {
|
|
63
|
+
notify: notifyFn,
|
|
64
|
+
},
|
|
65
|
+
}) as any as ExtensionContext;
|
|
66
|
+
|
|
67
|
+
const createMockPi = () => ({
|
|
68
|
+
setModel: vi.fn(),
|
|
69
|
+
registerProvider: vi.fn(),
|
|
70
|
+
});
|
|
71
|
+
|
|
72
|
+
beforeEach(() => {
|
|
73
|
+
vi.clearAllTimers();
|
|
74
|
+
resetInflightModel();
|
|
75
|
+
});
|
|
76
|
+
|
|
77
|
+
afterEach(() => {
|
|
78
|
+
vi.clearAllTimers();
|
|
79
|
+
});
|
|
80
|
+
|
|
81
|
+
describe("modelsCommand", () => {
|
|
82
|
+
it("should return early on cancel (null model selection)", async () => {
|
|
83
|
+
const models = [createMockModel("model-a")];
|
|
84
|
+
const ctx = createMockCtx(() => null);
|
|
85
|
+
const pi = createMockPi();
|
|
86
|
+
|
|
87
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
88
|
+
|
|
89
|
+
expect(ctx.ui.notify).not.toHaveBeenCalled();
|
|
90
|
+
});
|
|
91
|
+
|
|
92
|
+
it("should show info when INFO action is selected", async () => {
|
|
93
|
+
const model = createMockModel("model-a");
|
|
94
|
+
const models = [model];
|
|
95
|
+
const ctx = createMockCtx((prompt) => {
|
|
96
|
+
if (prompt.includes("models")) return "model-a";
|
|
97
|
+
return Action.INFO;
|
|
98
|
+
});
|
|
99
|
+
const pi = createMockPi();
|
|
100
|
+
|
|
101
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
102
|
+
|
|
103
|
+
expect(ctx.ui.notify).toHaveBeenCalledWith(
|
|
104
|
+
"Model: model-a\nID: model-a",
|
|
105
|
+
"info",
|
|
106
|
+
);
|
|
107
|
+
});
|
|
108
|
+
|
|
109
|
+
it("should unload model when UNLOAD action is selected", async () => {
|
|
110
|
+
const model = createMockModel("model-a");
|
|
111
|
+
const models = [model];
|
|
112
|
+
const ctx = createMockCtx((prompt) => {
|
|
113
|
+
if (prompt.includes("models")) return "model-a";
|
|
114
|
+
return Action.UNLOAD;
|
|
115
|
+
});
|
|
116
|
+
const pi = createMockPi();
|
|
117
|
+
|
|
118
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
119
|
+
|
|
120
|
+
expect(model.unload).toHaveBeenCalled();
|
|
121
|
+
expect(ctx.ui.notify).toHaveBeenCalledWith("Unloaded model-a", "info");
|
|
122
|
+
});
|
|
123
|
+
|
|
124
|
+
it("should load model when LOAD action is selected", async () => {
|
|
125
|
+
const loadFn = vi.fn().mockResolvedValue(undefined);
|
|
126
|
+
const model = createMockModel("model-a");
|
|
127
|
+
(model.load as any) = loadFn;
|
|
128
|
+
(model.getStatus as any).mockResolvedValue(Status.UNLOADED);
|
|
129
|
+
const models = [model];
|
|
130
|
+
const ctx = createMockCtx((prompt) => {
|
|
131
|
+
if (prompt.includes("models")) return "model-a";
|
|
132
|
+
return Action.LOAD;
|
|
133
|
+
});
|
|
134
|
+
const pi = createMockPi();
|
|
135
|
+
|
|
136
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
137
|
+
await vi.waitFor(() => expect(loadFn).toHaveBeenCalled());
|
|
138
|
+
await vi.waitFor(() => expect(pi.setModel).toHaveBeenCalled());
|
|
139
|
+
});
|
|
140
|
+
|
|
141
|
+
it("should show warning when session changes during model load", async () => {
|
|
142
|
+
// Create a deferred promise so we can control when the load completes
|
|
143
|
+
let resolveLoad: () => void;
|
|
144
|
+
const loadPromise = new Promise<void>((resolve) => {
|
|
145
|
+
resolveLoad = resolve;
|
|
146
|
+
});
|
|
147
|
+
const model = createMockModel("model-a", {
|
|
148
|
+
load: () => loadPromise,
|
|
149
|
+
getStatus: vi.fn().mockResolvedValue(Status.UNLOADED),
|
|
150
|
+
});
|
|
151
|
+
const models = [model];
|
|
152
|
+
|
|
153
|
+
let selectCallCount = 0;
|
|
154
|
+
const ctx = createMockCtx(() => {
|
|
155
|
+
selectCallCount++;
|
|
156
|
+
// 1st: select model, 2nd: select LOAD
|
|
157
|
+
if (selectCallCount === 1) return "model-a";
|
|
158
|
+
if (selectCallCount === 2) return Action.LOAD;
|
|
159
|
+
return null;
|
|
160
|
+
});
|
|
161
|
+
const pi = createMockPi();
|
|
162
|
+
|
|
163
|
+
// Start the load (non-blocking)
|
|
164
|
+
const modelsPromise = modelsCommand(ctx as any, pi as any, models);
|
|
165
|
+
|
|
166
|
+
// Advance past the microtask that sets inflightModel
|
|
167
|
+
await vi.advanceTimersByTimeAsync(0);
|
|
168
|
+
|
|
169
|
+
// Simulate session switch while model is still loading
|
|
170
|
+
// onSessionBeforeSwitch awaits READABLE_TIMEOUT (15s) for the notification
|
|
171
|
+
const switchPromise = onSessionBeforeSwitch(
|
|
172
|
+
{} as any,
|
|
173
|
+
createMockPiContext(ctx.ui.notify as any),
|
|
174
|
+
);
|
|
175
|
+
await vi.advanceTimersByTimeAsync(15000);
|
|
176
|
+
await switchPromise;
|
|
177
|
+
|
|
178
|
+
// Should have shown a warning notification
|
|
179
|
+
expect(ctx.ui.notify).toHaveBeenCalledWith(
|
|
180
|
+
expect.stringContaining("Session change detected"),
|
|
181
|
+
"warning",
|
|
182
|
+
);
|
|
183
|
+
expect(ctx.ui.notify).toHaveBeenCalledWith(
|
|
184
|
+
expect.stringContaining("model-a"),
|
|
185
|
+
"warning",
|
|
186
|
+
);
|
|
187
|
+
|
|
188
|
+
// Complete the load so inflightModel is cleared
|
|
189
|
+
resolveLoad!();
|
|
190
|
+
await modelsPromise;
|
|
191
|
+
});
|
|
192
|
+
|
|
193
|
+
it("should not warn when no model is loading", async () => {
|
|
194
|
+
const notifyFn = vi.fn();
|
|
195
|
+
const ctx = createMockPiContext(notifyFn);
|
|
196
|
+
|
|
197
|
+
await onSessionBeforeSwitch({} as any, ctx);
|
|
198
|
+
|
|
199
|
+
expect(notifyFn).not.toHaveBeenCalled();
|
|
200
|
+
// No timers should be scheduled
|
|
201
|
+
expect(vi.getTimerCount()).toBe(0);
|
|
202
|
+
});
|
|
203
|
+
|
|
204
|
+
it("should clear inflightModel after load completes successfully", async () => {
|
|
205
|
+
const loadFn = vi.fn().mockResolvedValue(undefined);
|
|
206
|
+
const model = createMockModel("model-a", {
|
|
207
|
+
load: loadFn,
|
|
208
|
+
getStatus: vi.fn().mockResolvedValue(Status.UNLOADED),
|
|
209
|
+
});
|
|
210
|
+
const models = [model];
|
|
211
|
+
const ctx = createMockCtx((prompt) => {
|
|
212
|
+
if (prompt.includes("models")) return "model-a";
|
|
213
|
+
return Action.LOAD;
|
|
214
|
+
});
|
|
215
|
+
const pi = createMockPi();
|
|
216
|
+
|
|
217
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
218
|
+
await vi.waitFor(() => expect(loadFn).toHaveBeenCalled());
|
|
219
|
+
await vi.waitFor(() => expect(pi.setModel).toHaveBeenCalled());
|
|
220
|
+
|
|
221
|
+
// inflightModel should be cleared after completion
|
|
222
|
+
// (verified indirectly: calling onSessionBeforeSwitch should not warn)
|
|
223
|
+
await vi.advanceTimersByTimeAsync(0);
|
|
224
|
+
const notifyFn = vi.fn();
|
|
225
|
+
await onSessionBeforeSwitch({} as any, createMockPiContext(notifyFn));
|
|
226
|
+
expect(notifyFn).not.toHaveBeenCalled();
|
|
227
|
+
});
|
|
228
|
+
|
|
229
|
+
it("should clear inflightModel after load fails", async () => {
|
|
230
|
+
const loadFn = vi.fn().mockRejectedValue(new Error("Load failed"));
|
|
231
|
+
const model = createMockModel("model-a", {
|
|
232
|
+
load: loadFn,
|
|
233
|
+
getStatus: vi.fn().mockResolvedValue(Status.FAILED),
|
|
234
|
+
});
|
|
235
|
+
const models = [model];
|
|
236
|
+
const ctx = createMockCtx((prompt) => {
|
|
237
|
+
if (prompt.includes("models")) return "model-a";
|
|
238
|
+
return Action.RETRY;
|
|
239
|
+
});
|
|
240
|
+
const pi = createMockPi();
|
|
241
|
+
|
|
242
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
243
|
+
await vi.waitFor(() => expect(loadFn).toHaveBeenCalled());
|
|
244
|
+
|
|
245
|
+
// inflightModel should be cleared after failure
|
|
246
|
+
await vi.advanceTimersByTimeAsync(0);
|
|
247
|
+
const notifyFn = vi.fn();
|
|
248
|
+
await onSessionBeforeSwitch({} as any, createMockPiContext(notifyFn));
|
|
249
|
+
expect(notifyFn).not.toHaveBeenCalled();
|
|
250
|
+
});
|
|
251
|
+
|
|
252
|
+
it("should loop back to model selection when action is cancelled", async () => {
|
|
253
|
+
const model = createMockModel("model-a");
|
|
254
|
+
const models = [model];
|
|
255
|
+
|
|
256
|
+
let selectCallCount = 0;
|
|
257
|
+
const ctx = createMockCtx(() => {
|
|
258
|
+
selectCallCount++;
|
|
259
|
+
// 1st: select model-a, 2nd: cancel action, 3rd: cancel model => exit
|
|
260
|
+
if (selectCallCount === 1) return "model-a";
|
|
261
|
+
return null;
|
|
262
|
+
});
|
|
263
|
+
const pi = createMockPi();
|
|
264
|
+
|
|
265
|
+
await modelsCommand(ctx as any, pi as any, models);
|
|
266
|
+
|
|
267
|
+
expect(ctx.ui.select).toHaveBeenCalledTimes(3);
|
|
268
|
+
expect(ctx.ui.notify).not.toHaveBeenCalled();
|
|
269
|
+
});
|
|
270
|
+
});
|
|
@@ -130,9 +130,14 @@ describe("RouterModel context size extraction", () => {
|
|
|
130
130
|
},
|
|
131
131
|
],
|
|
132
132
|
});
|
|
133
|
-
// Second call: super.getContextSize() -> /
|
|
133
|
+
// Second call: super.getContextSize() -> /models with meta.n_ctx
|
|
134
134
|
mockRpc.mockResolvedValueOnce({
|
|
135
|
-
|
|
135
|
+
data: [
|
|
136
|
+
{
|
|
137
|
+
id: "test-model",
|
|
138
|
+
meta: { n_ctx: 4096 },
|
|
139
|
+
},
|
|
140
|
+
],
|
|
136
141
|
});
|
|
137
142
|
|
|
138
143
|
const model = new RouterModel(
|
|
@@ -149,7 +154,7 @@ describe("RouterModel context size extraction", () => {
|
|
|
149
154
|
expect(ctxSize).toBe(4096);
|
|
150
155
|
});
|
|
151
156
|
|
|
152
|
-
it("should return
|
|
157
|
+
it("should return n_ctx from meta when loaded without context size args", async () => {
|
|
153
158
|
// First call: getStatus() -> /models
|
|
154
159
|
mockRpc.mockResolvedValueOnce({
|
|
155
160
|
data: [
|
|
@@ -163,17 +168,16 @@ describe("RouterModel context size extraction", () => {
|
|
|
163
168
|
},
|
|
164
169
|
],
|
|
165
170
|
});
|
|
166
|
-
// Second call: super.getContextSize() -> /models
|
|
171
|
+
// Second call: super.getContextSize() -> /models with meta.n_ctx
|
|
167
172
|
mockRpc.mockResolvedValueOnce({
|
|
168
173
|
data: [
|
|
169
174
|
{
|
|
170
175
|
id: "test-model",
|
|
176
|
+
meta: { n_ctx: 4096 },
|
|
171
177
|
},
|
|
172
178
|
],
|
|
173
179
|
});
|
|
174
180
|
|
|
175
|
-
const { DEFAULT_CTX } = await import("../src/constants");
|
|
176
|
-
|
|
177
181
|
const model = new RouterModel(
|
|
178
182
|
createModel({
|
|
179
183
|
status: {
|
|
@@ -185,13 +189,12 @@ describe("RouterModel context size extraction", () => {
|
|
|
185
189
|
);
|
|
186
190
|
|
|
187
191
|
const ctxSize = await model.getContextSize();
|
|
188
|
-
expect(ctxSize).toBe(
|
|
192
|
+
expect(ctxSize).toBe(4096);
|
|
189
193
|
});
|
|
190
194
|
});
|
|
191
195
|
|
|
192
196
|
describe("RouterModel capabilities detection", () => {
|
|
193
|
-
it("should detect image capability
|
|
194
|
-
// getStatus() calls /models first
|
|
197
|
+
it("should detect image capability from architecture.input_modalities", async () => {
|
|
195
198
|
mockRpc.mockResolvedValueOnce({
|
|
196
199
|
data: [
|
|
197
200
|
{
|
|
@@ -202,21 +205,22 @@ describe("RouterModel capabilities detection", () => {
|
|
|
202
205
|
preset: "default",
|
|
203
206
|
failed: false,
|
|
204
207
|
},
|
|
208
|
+
architecture: {
|
|
209
|
+
input_modalities: ["text", "image"],
|
|
210
|
+
output_modalities: ["text"],
|
|
211
|
+
},
|
|
205
212
|
},
|
|
206
213
|
],
|
|
207
214
|
});
|
|
208
|
-
// super.getCapabilities() calls /props?model=<id>
|
|
209
|
-
mockRpc.mockResolvedValueOnce({ modalities: { vision: true } });
|
|
210
215
|
|
|
211
216
|
const model = new RouterModel(createModel());
|
|
212
217
|
const capabilities = await model.getCapabilities();
|
|
213
218
|
|
|
214
|
-
expect(capabilities).toEqual(["image"]);
|
|
215
|
-
expect(mockRpc).toHaveBeenCalledWith("/
|
|
219
|
+
expect(capabilities).toEqual(["text", "image"]);
|
|
220
|
+
expect(mockRpc).toHaveBeenCalledWith("/models");
|
|
216
221
|
});
|
|
217
222
|
|
|
218
|
-
it("should detect text-only capability when
|
|
219
|
-
// getStatus() calls /models first
|
|
223
|
+
it("should detect text-only capability when only text in input_modalities", async () => {
|
|
220
224
|
mockRpc.mockResolvedValueOnce({
|
|
221
225
|
data: [
|
|
222
226
|
{
|
|
@@ -227,11 +231,13 @@ describe("RouterModel capabilities detection", () => {
|
|
|
227
231
|
preset: "default",
|
|
228
232
|
failed: false,
|
|
229
233
|
},
|
|
234
|
+
architecture: {
|
|
235
|
+
input_modalities: ["text"],
|
|
236
|
+
output_modalities: ["text"],
|
|
237
|
+
},
|
|
230
238
|
},
|
|
231
239
|
],
|
|
232
240
|
});
|
|
233
|
-
// super.getCapabilities() calls /props?model=<id>
|
|
234
|
-
mockRpc.mockResolvedValueOnce({ modalities: { vision: false } });
|
|
235
241
|
|
|
236
242
|
const model = new RouterModel(createModel());
|
|
237
243
|
const capabilities = await model.getCapabilities();
|
|
@@ -239,12 +245,11 @@ describe("RouterModel capabilities detection", () => {
|
|
|
239
245
|
expect(capabilities).toEqual(["text"]);
|
|
240
246
|
});
|
|
241
247
|
|
|
242
|
-
it("should
|
|
243
|
-
// getStatus() calls /models first
|
|
248
|
+
it("should return text when model not found in /models response", async () => {
|
|
244
249
|
mockRpc.mockResolvedValueOnce({
|
|
245
250
|
data: [
|
|
246
251
|
{
|
|
247
|
-
id: "
|
|
252
|
+
id: "other-model",
|
|
248
253
|
status: {
|
|
249
254
|
value: "loaded",
|
|
250
255
|
args: [],
|
|
@@ -254,76 +259,12 @@ describe("RouterModel capabilities detection", () => {
|
|
|
254
259
|
},
|
|
255
260
|
],
|
|
256
261
|
});
|
|
257
|
-
// super.getCapabilities() calls /props?model=<id> which fails
|
|
258
|
-
mockRpc.mockRejectedValueOnce(new Error("Connection refused"));
|
|
259
262
|
|
|
260
263
|
const model = new RouterModel(createModel());
|
|
261
264
|
const capabilities = await model.getCapabilities();
|
|
262
265
|
|
|
263
266
|
expect(capabilities).toEqual(["text"]);
|
|
264
267
|
});
|
|
265
|
-
|
|
266
|
-
it("should use status.args to detect image capability when not loaded", async () => {
|
|
267
|
-
// getStatus() calls /models first, returns unloaded
|
|
268
|
-
mockRpc.mockResolvedValueOnce({
|
|
269
|
-
data: [
|
|
270
|
-
{
|
|
271
|
-
id: "test-model",
|
|
272
|
-
status: {
|
|
273
|
-
value: "unloaded",
|
|
274
|
-
args: ["--model", "gguf", "--mmproj", "mmproj.gguf"],
|
|
275
|
-
preset: "default",
|
|
276
|
-
failed: false,
|
|
277
|
-
},
|
|
278
|
-
},
|
|
279
|
-
],
|
|
280
|
-
});
|
|
281
|
-
|
|
282
|
-
const model = new RouterModel(
|
|
283
|
-
createModel({
|
|
284
|
-
status: {
|
|
285
|
-
value: "unloaded",
|
|
286
|
-
args: ["--model", "gguf", "--mmproj", "mmproj.gguf"],
|
|
287
|
-
preset: "default",
|
|
288
|
-
failed: false,
|
|
289
|
-
},
|
|
290
|
-
}),
|
|
291
|
-
);
|
|
292
|
-
const capabilities = await model.getCapabilities();
|
|
293
|
-
|
|
294
|
-
expect(capabilities).toEqual(["image"]);
|
|
295
|
-
});
|
|
296
|
-
|
|
297
|
-
it("should return text when not loaded and no --mmproj in args", async () => {
|
|
298
|
-
// getStatus() calls /models first, returns unloaded
|
|
299
|
-
mockRpc.mockResolvedValueOnce({
|
|
300
|
-
data: [
|
|
301
|
-
{
|
|
302
|
-
id: "test-model",
|
|
303
|
-
status: {
|
|
304
|
-
value: "unloaded",
|
|
305
|
-
args: ["--model", "gguf"],
|
|
306
|
-
preset: "default",
|
|
307
|
-
failed: false,
|
|
308
|
-
},
|
|
309
|
-
},
|
|
310
|
-
],
|
|
311
|
-
});
|
|
312
|
-
|
|
313
|
-
const model = new RouterModel(
|
|
314
|
-
createModel({
|
|
315
|
-
status: {
|
|
316
|
-
value: "unloaded",
|
|
317
|
-
args: ["--model", "gguf"],
|
|
318
|
-
preset: "default",
|
|
319
|
-
failed: false,
|
|
320
|
-
},
|
|
321
|
-
}),
|
|
322
|
-
);
|
|
323
|
-
const capabilities = await model.getCapabilities();
|
|
324
|
-
|
|
325
|
-
expect(capabilities).toEqual(["text"]);
|
|
326
|
-
});
|
|
327
268
|
});
|
|
328
269
|
|
|
329
270
|
describe("RouterModel mode", () => {
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import { beforeEach, describe, expect, it, vi } from "vitest";
|
|
2
|
-
import { DEFAULT_CTX } from "../src/constants";
|
|
3
2
|
import { Mode } from "../src/enums/mode";
|
|
4
3
|
import { Status } from "../src/enums/status";
|
|
5
4
|
import { ModelProperty } from "../src/interfaces/endpoints/models";
|
|
@@ -34,27 +33,22 @@ describe("SingleModel mode", () => {
|
|
|
34
33
|
});
|
|
35
34
|
|
|
36
35
|
describe("SingleModel capabilities", () => {
|
|
37
|
-
it("should detect image capability when
|
|
38
|
-
mockRpc.mockResolvedValueOnce({
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
const capabilities = await model.getCapabilities();
|
|
42
|
-
|
|
43
|
-
expect(capabilities).toEqual(["image"]);
|
|
44
|
-
expect(mockRpc).toHaveBeenCalledWith("/props?model=test");
|
|
45
|
-
});
|
|
46
|
-
|
|
47
|
-
it("should detect text-only capability when modalities.vision is false", async () => {
|
|
48
|
-
mockRpc.mockResolvedValueOnce({ modalities: { vision: false } });
|
|
36
|
+
it("should detect image capability when multimodal is in capabilities", async () => {
|
|
37
|
+
mockRpc.mockResolvedValueOnce({
|
|
38
|
+
models: [{ id: "test", capabilities: ["multimodal"] }],
|
|
39
|
+
});
|
|
49
40
|
|
|
50
41
|
const model = createModel();
|
|
51
42
|
const capabilities = await model.getCapabilities();
|
|
52
43
|
|
|
53
|
-
expect(capabilities).toEqual(["text"]);
|
|
44
|
+
expect(capabilities).toEqual(["text", "image"]);
|
|
45
|
+
expect(mockRpc).toHaveBeenCalledWith("/models");
|
|
54
46
|
});
|
|
55
47
|
|
|
56
|
-
it("should
|
|
57
|
-
mockRpc.
|
|
48
|
+
it("should detect text-only capability when multimodal is not in capabilities", async () => {
|
|
49
|
+
mockRpc.mockResolvedValueOnce({
|
|
50
|
+
models: [{ id: "test", capabilities: [] }],
|
|
51
|
+
});
|
|
58
52
|
|
|
59
53
|
const model = createModel();
|
|
60
54
|
const capabilities = await model.getCapabilities();
|
|
@@ -71,7 +65,9 @@ describe("SingleModel getStatus", () => {
|
|
|
71
65
|
const status = await model.getStatus();
|
|
72
66
|
|
|
73
67
|
expect(status).toBe(Status.LOADED);
|
|
74
|
-
expect(mockRpc).toHaveBeenCalledWith(
|
|
68
|
+
expect(mockRpc).toHaveBeenCalledWith(
|
|
69
|
+
`/props?model=${model.id}&autoload=false`,
|
|
70
|
+
);
|
|
75
71
|
});
|
|
76
72
|
|
|
77
73
|
it("should return SLEEPING when is_sleeping is true", async () => {
|
|
@@ -85,24 +81,15 @@ describe("SingleModel getStatus", () => {
|
|
|
85
81
|
});
|
|
86
82
|
|
|
87
83
|
describe("SingleModel getContextSize", () => {
|
|
88
|
-
it("should return n_ctx from /
|
|
84
|
+
it("should return n_ctx from /models endpoint meta", async () => {
|
|
89
85
|
mockRpc.mockResolvedValueOnce({
|
|
90
|
-
|
|
86
|
+
data: [{ id: "test", meta: { n_ctx: 8192 } }],
|
|
91
87
|
});
|
|
92
88
|
|
|
93
89
|
const model = createModel();
|
|
94
90
|
const ctxSize = await model.getContextSize();
|
|
95
91
|
|
|
96
92
|
expect(ctxSize).toBe(8192);
|
|
97
|
-
expect(mockRpc).toHaveBeenCalledWith("/
|
|
98
|
-
});
|
|
99
|
-
|
|
100
|
-
it("should return DEFAULT_CTX when /props fails", async () => {
|
|
101
|
-
mockRpc.mockRejectedValueOnce(new Error("Connection refused"));
|
|
102
|
-
|
|
103
|
-
const model = createModel();
|
|
104
|
-
const ctxSize = await model.getContextSize();
|
|
105
|
-
|
|
106
|
-
expect(ctxSize).toBe(DEFAULT_CTX);
|
|
93
|
+
expect(mockRpc).toHaveBeenCalledWith("/models");
|
|
107
94
|
});
|
|
108
95
|
});
|
package/src/tools/provider.ts
DELETED
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
import type { ExtensionAPI } from "@earendil-works/pi-coding-agent";
|
|
2
|
-
import { PROVIDER_ID, PROVIDER_NAME } from "../constants";
|
|
3
|
-
import type { BaseModel } from "../models/baseModel";
|
|
4
|
-
import { resolveApiKey, resolveUrl } from "./resolver";
|
|
5
|
-
import { listModels } from "./retriever";
|
|
6
|
-
|
|
7
|
-
/**
|
|
8
|
-
* Registers the Llama.cpp provider and returns the fetched models.
|
|
9
|
-
*
|
|
10
|
-
* @param pi The Pi extension API
|
|
11
|
-
* @returns The list of models fetched from the server
|
|
12
|
-
*/
|
|
13
|
-
export const registerLlamaCppProvider = async (
|
|
14
|
-
pi: ExtensionAPI,
|
|
15
|
-
): Promise<BaseModel[]> => {
|
|
16
|
-
const baseUrl = `${await resolveUrl(process.cwd())}/v1`;
|
|
17
|
-
const models = await listModels();
|
|
18
|
-
|
|
19
|
-
pi.registerProvider(PROVIDER_ID, {
|
|
20
|
-
name: PROVIDER_NAME,
|
|
21
|
-
baseUrl,
|
|
22
|
-
api: "openai-completions",
|
|
23
|
-
apiKey: await resolveApiKey(),
|
|
24
|
-
models: await Promise.all(models.map((m) => m.toProviderConfig())),
|
|
25
|
-
});
|
|
26
|
-
|
|
27
|
-
return models;
|
|
28
|
-
};
|