kekkai-cli 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kekkai/__init__.py +7 -0
- kekkai/cli.py +1038 -0
- kekkai/config.py +403 -0
- kekkai/dojo.py +419 -0
- kekkai/dojo_import.py +213 -0
- kekkai/github/__init__.py +16 -0
- kekkai/github/commenter.py +198 -0
- kekkai/github/models.py +56 -0
- kekkai/github/sanitizer.py +112 -0
- kekkai/installer/__init__.py +39 -0
- kekkai/installer/errors.py +23 -0
- kekkai/installer/extract.py +161 -0
- kekkai/installer/manager.py +252 -0
- kekkai/installer/manifest.py +189 -0
- kekkai/installer/verify.py +86 -0
- kekkai/manifest.py +77 -0
- kekkai/output.py +218 -0
- kekkai/paths.py +46 -0
- kekkai/policy.py +326 -0
- kekkai/runner.py +70 -0
- kekkai/scanners/__init__.py +67 -0
- kekkai/scanners/backends/__init__.py +14 -0
- kekkai/scanners/backends/base.py +73 -0
- kekkai/scanners/backends/docker.py +178 -0
- kekkai/scanners/backends/native.py +240 -0
- kekkai/scanners/base.py +110 -0
- kekkai/scanners/container.py +144 -0
- kekkai/scanners/falco.py +237 -0
- kekkai/scanners/gitleaks.py +237 -0
- kekkai/scanners/semgrep.py +227 -0
- kekkai/scanners/trivy.py +246 -0
- kekkai/scanners/url_policy.py +163 -0
- kekkai/scanners/zap.py +340 -0
- kekkai/threatflow/__init__.py +94 -0
- kekkai/threatflow/artifacts.py +476 -0
- kekkai/threatflow/chunking.py +361 -0
- kekkai/threatflow/core.py +438 -0
- kekkai/threatflow/mermaid.py +374 -0
- kekkai/threatflow/model_adapter.py +491 -0
- kekkai/threatflow/prompts.py +277 -0
- kekkai/threatflow/redaction.py +228 -0
- kekkai/threatflow/sanitizer.py +643 -0
- kekkai/triage/__init__.py +33 -0
- kekkai/triage/app.py +168 -0
- kekkai/triage/audit.py +203 -0
- kekkai/triage/ignore.py +269 -0
- kekkai/triage/models.py +185 -0
- kekkai/triage/screens.py +341 -0
- kekkai/triage/widgets.py +169 -0
- kekkai_cli-1.0.0.dist-info/METADATA +135 -0
- kekkai_cli-1.0.0.dist-info/RECORD +90 -0
- kekkai_cli-1.0.0.dist-info/WHEEL +5 -0
- kekkai_cli-1.0.0.dist-info/entry_points.txt +3 -0
- kekkai_cli-1.0.0.dist-info/top_level.txt +3 -0
- kekkai_core/__init__.py +3 -0
- kekkai_core/ci/__init__.py +11 -0
- kekkai_core/ci/benchmarks.py +354 -0
- kekkai_core/ci/metadata.py +104 -0
- kekkai_core/ci/validators.py +92 -0
- kekkai_core/docker/__init__.py +17 -0
- kekkai_core/docker/metadata.py +153 -0
- kekkai_core/docker/sbom.py +173 -0
- kekkai_core/docker/security.py +158 -0
- kekkai_core/docker/signing.py +135 -0
- kekkai_core/redaction.py +84 -0
- kekkai_core/slsa/__init__.py +13 -0
- kekkai_core/slsa/verify.py +121 -0
- kekkai_core/windows/__init__.py +29 -0
- kekkai_core/windows/chocolatey.py +335 -0
- kekkai_core/windows/installer.py +256 -0
- kekkai_core/windows/scoop.py +165 -0
- kekkai_core/windows/validators.py +220 -0
- portal/__init__.py +19 -0
- portal/api.py +155 -0
- portal/auth.py +103 -0
- portal/enterprise/__init__.py +32 -0
- portal/enterprise/audit.py +435 -0
- portal/enterprise/licensing.py +342 -0
- portal/enterprise/rbac.py +276 -0
- portal/enterprise/saml.py +595 -0
- portal/ops/__init__.py +53 -0
- portal/ops/backup.py +553 -0
- portal/ops/log_shipper.py +469 -0
- portal/ops/monitoring.py +517 -0
- portal/ops/restore.py +469 -0
- portal/ops/secrets.py +408 -0
- portal/ops/upgrade.py +591 -0
- portal/tenants.py +340 -0
- portal/uploads.py +259 -0
- portal/web.py +384 -0
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
"""Model adapter protocol for ThreatFlow LLM backends.
|
|
2
|
+
|
|
3
|
+
Supports:
|
|
4
|
+
- Local models (default, privacy-preserving)
|
|
5
|
+
- Remote APIs (OpenAI, Anthropic - opt-in with warning)
|
|
6
|
+
- Mock adapter for testing
|
|
7
|
+
|
|
8
|
+
ASVS V13.1.3: Timeouts and resource limits on all model calls.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import time
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import TYPE_CHECKING, Any
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
pass
|
|
24
|
+
|
|
25
|
+
# Module-scoped logger; all adapters in this file log through it.
logger = logging.getLogger(__name__)

# Default ModelConfig.timeout_seconds (remote adapters pass it to urlopen).
DEFAULT_TIMEOUT_SECONDS: int = 120
# Default ModelConfig.max_tokens: cap on tokens generated per call.
DEFAULT_MAX_TOKENS: int = 4096
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class ModelResponse:
    """Immutable result of a single model invocation.

    Only ``content`` and ``model_name`` are required; token counts and latency
    default to zero when a backend does not report them.
    """

    # Text produced by the model ("" when the call failed).
    content: str
    # Identifier of the model (or stub/error marker) that produced the reply.
    model_name: str
    # Token accounting as reported by the backend, if any.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Wall time of the call in milliseconds.
    latency_ms: int = 0
    # Backend payload kept for debugging; a fresh dict per instance.
    raw_response: dict[str, Any] = field(default_factory=dict)

    @property
    def success(self) -> bool:
        """Whether the model returned any content at all."""
        if self.content:
            return True
        return False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ModelConfig:
    """Settings shared by all model adapters.

    Optional fields left as ``None`` are "not set"; each adapter applies its
    own fallback (environment variables or built-in defaults).
    """

    # Timeout in seconds for a model call (remote adapters hand it to urlopen).
    timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS
    # Maximum number of tokens to generate.
    max_tokens: int = DEFAULT_MAX_TOKENS
    # Sampling temperature; kept low for near-deterministic output.
    temperature: float = 0.1
    # Path to a local model file (local adapter only).
    model_path: str | None = None
    # Credential for remote APIs.
    api_key: str | None = None
    # Custom endpoint URL for remote APIs.
    api_base: str | None = None
    # Explicit model identifier; takes precedence over adapter defaults.
    model_name: str | None = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ModelAdapter(ABC):
    """Common interface implemented by every ThreatFlow LLM backend.

    Concrete adapters supply a display name, declare whether inference stays
    on the local machine, and implement :meth:`generate`. A default
    :meth:`health_check` is provided on top of :meth:`generate`.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier shown in logs and user-facing output."""
        ...

    @property
    @abstractmethod
    def is_local(self) -> bool:
        """``True`` when the adapter makes no external calls."""
        ...

    @abstractmethod
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Run one prompt through the model.

        Args:
            system_prompt: Instructions that frame the model's behaviour.
            user_prompt: The content the model should analyse or answer.
            config: Per-call overrides; ``None`` means adapter defaults.

        Returns:
            A ModelResponse holding the generated text and call metadata.
        """
        ...

    def health_check(self) -> bool:
        """Send a tiny probe prompt and report whether the model answered.

        Returns:
            ``True`` if the probe response has content; ``False`` if it is
            empty or if anything in the probe raises.
        """
        try:
            probe = self.generate(
                system_prompt="Respond with OK.",
                user_prompt="Health check",
                config=ModelConfig(timeout_seconds=10, max_tokens=10),
            )
            healthy = probe.success
        except Exception:
            healthy = False
        return healthy
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LocalModelAdapter(ModelAdapter):
    """Adapter for local LLM inference.

    Supports common local model formats via llama-cpp-python or similar.
    Falls back to a stub if no local model is available.

    The model path comes from the constructor or, failing that, the
    ``KEKKAI_THREATFLOW_MODEL_PATH`` environment variable. The model is loaded
    lazily on first use and cached on the instance.

    NOTE(review): ``ModelConfig.timeout_seconds`` is not enforced here; the
    llama-cpp call in :meth:`generate` is made without any timeout.
    """

    def __init__(self, model_path: str | None = None) -> None:
        # An explicit argument wins; otherwise fall back to the environment.
        self._model_path = model_path or os.environ.get("KEKKAI_THREATFLOW_MODEL_PATH")
        # Populated by _load_model(); stays None until a load succeeds.
        self._model: Any = None

    @property
    def name(self) -> str:
        return "local"

    @property
    def is_local(self) -> bool:
        return True

    def _load_model(self) -> Any:
        """Lazy-load the local model.

        Returns:
            The cached ``llama_cpp.Llama`` instance, or ``None`` if no path is
            configured, the file does not exist, llama-cpp-python is not
            installed, or construction fails. Failures are not cached, so the
            next call tries again.
        """
        if self._model is not None:
            return self._model

        if not self._model_path:
            logger.warning("No local model path configured")
            return None

        model_path = Path(self._model_path)
        if not model_path.exists():
            logger.warning("Local model not found: %s", model_path)
            return None

        try:
            # Optional dependency, imported only when a model is actually needed.
            from llama_cpp import Llama  # type: ignore[import-not-found]

            self._model = Llama(
                model_path=str(model_path),
                n_ctx=4096,
                n_threads=4,
                verbose=False,
            )
            logger.info("Loaded local model: %s", model_path.name)
            return self._model
        except ImportError:
            logger.warning("llama-cpp-python not installed, local model unavailable")
            return None
        except Exception as e:
            logger.error("Failed to load local model: %s", e)
            return None

    def health_check(self) -> bool:
        """Check that a local model is loadable and produces output.

        When no model can be loaded, :meth:`generate` returns a non-empty
        placeholder string, which the generic probe would count as success.
        Bail out early in that case so an absent model is reported unhealthy.
        """
        if self._load_model() is None:
            return False
        return super().health_check()

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Generate using local model.

        Returns:
            The model's completion. If no model can be loaded, a placeholder
            response with ``model_name="local-stub"`` is returned. If inference
            raises, the error is logged and an empty-content response with
            ``model_name="local-error"`` is returned.
        """
        config = config or ModelConfig()
        # Monotonic clock: latency must not jump with wall-clock adjustments.
        # Timing starts before loading, so a first call includes load time.
        start_time = time.monotonic()

        model = self._load_model()
        if model is None:
            # Return a stub response indicating local model unavailable
            return ModelResponse(
                content="[LOCAL MODEL UNAVAILABLE - Install llama-cpp-python]",
                model_name="local-stub",
                latency_ms=0,
            )

        try:
            # Format prompt for chat-style models
            full_prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{user_prompt}\n<|assistant|>\n"

            response = model(
                full_prompt,
                max_tokens=config.max_tokens,
                temperature=config.temperature,
                # Stop before the model starts inventing a new turn.
                stop=["<|user|>", "<|system|>"],
            )

            content = response["choices"][0]["text"].strip()
            latency_ms = int((time.monotonic() - start_time) * 1000)
            usage = response.get("usage") or {}

            model_name = "local:unknown"
            if self._model_path:
                model_name = f"local:{Path(self._model_path).name}"
            return ModelResponse(
                content=content,
                model_name=model_name,
                prompt_tokens=usage.get("prompt_tokens", 0),
                completion_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
                latency_ms=latency_ms,
                raw_response=dict(response),
            )
        except Exception as e:
            logger.error("Local model generation failed: %s", e)
            return ModelResponse(
                content="",
                model_name="local-error",
                latency_ms=int((time.monotonic() - start_time) * 1000),
            )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class RemoteModelAdapter(ModelAdapter):
    """Adapter for remote LLM APIs (OpenAI, Anthropic).

    WARNING: Using this adapter sends code to external services.

    The API key and endpoint fall back to the ``KEKKAI_THREATFLOW_API_KEY`` and
    ``KEKKAI_THREATFLOW_API_BASE`` environment variables. Network/HTTP
    failures, malformed replies and rejected endpoints are logged and reported
    as a ModelResponse with empty content rather than raised.
    """

    # Model used per provider when neither the constructor nor the per-call
    # config names one. Keys double as the set of supported providers.
    _DEFAULT_MODELS: dict[str, str] = {
        "openai": "gpt-4o-mini",
        "anthropic": "claude-3-haiku-20240307",
    }
    # urllib also honours schemes such as file:// or ftp://; only HTTP(S)
    # endpoints may receive prompts and API keys.
    _ALLOWED_SCHEMES = frozenset({"http", "https"})

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        model_name: str | None = None,
        provider: str = "openai",
    ) -> None:
        self._api_key = api_key or os.environ.get("KEKKAI_THREATFLOW_API_KEY")
        self._api_base = api_base or os.environ.get("KEKKAI_THREATFLOW_API_BASE")
        # Pick a provider-appropriate default (previously "gpt-4o-mini" was
        # used even for Anthropic).
        self._model_name = model_name or self._DEFAULT_MODELS.get(provider, "gpt-4o-mini")
        self._provider = provider

    @property
    def name(self) -> str:
        return f"remote:{self._provider}"

    @property
    def is_local(self) -> bool:
        return False

    def health_check(self) -> bool:
        """Check that the remote API is configured and answering.

        Without an API key, or with an unsupported provider, :meth:`generate`
        returns a non-empty placeholder that the generic probe would count as
        success. Report those cases as unhealthy without any network traffic.
        """
        if not self._api_key or self._provider not in self._DEFAULT_MODELS:
            return False
        return super().health_check()

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Generate using remote API.

        Returns placeholder content when no API key is configured or the
        provider is unsupported. Any exception raised while calling the API is
        logged and turned into an empty-content response.
        """
        config = config or ModelConfig()
        start_time = time.monotonic()

        if not self._api_key:
            return ModelResponse(
                content="[REMOTE API KEY NOT CONFIGURED]",
                model_name=self._model_name,
                latency_ms=0,
            )

        try:
            if self._provider == "openai":
                return self._generate_openai(system_prompt, user_prompt, config, start_time)
            if self._provider == "anthropic":
                return self._generate_anthropic(system_prompt, user_prompt, config, start_time)
            return ModelResponse(
                content=f"[UNSUPPORTED PROVIDER: {self._provider}]",
                model_name=self._model_name,
                latency_ms=0,
            )
        except Exception as e:
            logger.error("Remote API call failed: %s", e)
            return ModelResponse(
                content="",
                model_name=self._model_name,
                latency_ms=self._elapsed_ms(start_time),
            )

    @staticmethod
    def _elapsed_ms(start_time: float) -> int:
        """Milliseconds since ``start_time`` (a ``time.monotonic()`` value)."""
        return int((time.monotonic() - start_time) * 1000)

    def _post_json(
        self,
        url: str,
        headers: dict[str, str],
        payload: dict[str, Any],
        timeout: float,
    ) -> dict[str, Any]:
        """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

        Raises:
            ValueError: if ``url`` is not an absolute http(s) URL, or the reply
                body is not valid JSON.
            urllib.error.URLError: on network or HTTP errors (HTTPError is a
                subclass).
        """
        import urllib.parse
        import urllib.request

        parts = urllib.parse.urlsplit(url)
        if parts.scheme.lower() not in self._ALLOWED_SCHEMES or not parts.netloc:
            # Do not echo the URL itself: it may embed credentials.
            raise ValueError(f"Refusing API endpoint with unsupported scheme {parts.scheme!r}")

        req = urllib.request.Request(  # noqa: S310 - scheme validated above
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:  # noqa: S310 # nosec B310
            return json.loads(resp.read().decode("utf-8"))

    def _generate_openai(
        self,
        system_prompt: str,
        user_prompt: str,
        config: ModelConfig,
        start_time: float,
    ) -> ModelResponse:
        """Generate using the OpenAI Chat Completions API.

        ``start_time`` is a ``time.monotonic()`` timestamp used for latency.
        """
        import urllib.error

        model = config.model_name or self._model_name
        url = self._api_base or "https://api.openai.com/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
        }

        try:
            response_data = self._post_json(url, headers, payload, config.timeout_seconds)
        except urllib.error.URLError as e:
            logger.error("OpenAI API error: %s", e)
            return ModelResponse(
                content="",
                model_name=model,
                latency_ms=self._elapsed_ms(start_time),
            )

        # "content" may be null (e.g. refusals); normalise to "".
        content = response_data["choices"][0]["message"].get("content") or ""
        usage = response_data.get("usage") or {}
        return ModelResponse(
            content=content,
            model_name=response_data.get("model", model),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            latency_ms=self._elapsed_ms(start_time),
            raw_response=response_data,
        )

    def _generate_anthropic(
        self,
        system_prompt: str,
        user_prompt: str,
        config: ModelConfig,
        start_time: float,
    ) -> ModelResponse:
        """Generate using the Anthropic Messages API.

        ``start_time`` is a ``time.monotonic()`` timestamp used for latency.
        """
        import urllib.error

        # Honour the configured model (previously a hard-coded Haiku model
        # silently overrode the constructor's model_name).
        model = config.model_name or self._model_name
        url = self._api_base or "https://api.anthropic.com/v1/messages"
        headers = {
            "x-api-key": self._api_key or "",
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        payload = {
            "model": model,
            "max_tokens": config.max_tokens,
            # Previously omitted, so the low-temperature setting was ignored.
            "temperature": config.temperature,
            "system": system_prompt,
            "messages": [{"role": "user", "content": user_prompt}],
        }

        try:
            response_data = self._post_json(url, headers, payload, config.timeout_seconds)
        except urllib.error.URLError as e:
            logger.error("Anthropic API error: %s", e)
            return ModelResponse(
                content="",
                model_name=model,
                latency_ms=self._elapsed_ms(start_time),
            )

        # The reply is a list of content blocks; concatenate every text block.
        blocks = response_data.get("content") or []
        content = "".join(
            block["text"]
            for block in blocks
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        )
        usage = response_data.get("usage") or {}
        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)
        return ModelResponse(
            content=content,
            model_name=response_data.get("model", model),
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            latency_ms=self._elapsed_ms(start_time),
            raw_response=response_data,
        )
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class MockModelAdapter(ModelAdapter):
    """Mock adapter for testing.

    Each configured keyword is compared case-insensitively against the user
    prompt, in insertion order. The first keyword found as a substring decides
    the reply; otherwise ``default_response`` is returned. Every call is
    recorded and ``config`` is ignored.
    """

    def __init__(
        self,
        responses: dict[str, str] | None = None,
        default_response: str = "Mock response",
    ) -> None:
        # Copy, so set_response() never mutates the caller's dict. Previously
        # it did whenever a non-empty dict was passed, but not for an empty one.
        self._responses: dict[str, str] = dict(responses or {})
        self._default_response = default_response
        self._call_history: list[tuple[str, str]] = []

    @property
    def name(self) -> str:
        return "mock"

    @property
    def is_local(self) -> bool:
        return True

    @property
    def call_history(self) -> list[tuple[str, str]]:
        """Get a copy of the (system_prompt, user_prompt) pairs seen so far."""
        return list(self._call_history)

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Generate a mock response by keyword lookup (see class docstring)."""
        self._call_history.append((system_prompt, user_prompt))

        lowered_prompt = user_prompt.lower()
        for keyword, response in self._responses.items():
            if keyword.lower() in lowered_prompt:
                return ModelResponse(
                    content=response,
                    model_name="mock",
                    latency_ms=1,
                )

        return ModelResponse(
            content=self._default_response,
            model_name="mock",
            latency_ms=1,
        )

    def set_response(self, keyword: str, response: str) -> None:
        """Set (or replace) the response for a specific keyword."""
        self._responses[keyword] = response

    def clear_history(self) -> None:
        """Clear call history."""
        self._call_history.clear()
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def create_adapter(
    mode: str = "local",
    config: ModelConfig | None = None,
) -> ModelAdapter:
    """Build the adapter matching ``mode``.

    Args:
        mode: One of ``"local"``, ``"openai"``, ``"anthropic"`` or ``"mock"``.
        config: Source of path, credentials, endpoint and model name. ``None``
            means a default ModelConfig.

    Returns:
        A freshly constructed ModelAdapter.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    settings = config if config is not None else ModelConfig()

    if mode == "mock":
        return MockModelAdapter()
    if mode == "local":
        return LocalModelAdapter(model_path=settings.model_path)

    # Remaining modes are remote providers; each has its own fallback model.
    if mode == "openai":
        fallback_model = "gpt-4o-mini"
    elif mode == "anthropic":
        fallback_model = "claude-3-haiku-20240307"
    else:
        raise ValueError(f"Unknown adapter mode: {mode}")

    return RemoteModelAdapter(
        api_key=settings.api_key,
        api_base=settings.api_base,
        model_name=settings.model_name or fallback_model,
        provider=mode,
    )
|