hajimi-claw 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.lock +2602 -0
- package/Cargo.toml +57 -0
- package/README.md +73 -0
- package/bin/hajimi-claw.js +28 -0
- package/config.example.toml +32 -0
- package/crates/hajimi-claw-agent/Cargo.toml +25 -0
- package/crates/hajimi-claw-agent/src/lib.rs +351 -0
- package/crates/hajimi-claw-bot/Cargo.toml +18 -0
- package/crates/hajimi-claw-bot/src/lib.rs +305 -0
- package/crates/hajimi-claw-daemon/Cargo.toml +24 -0
- package/crates/hajimi-claw-daemon/src/lib.rs +173 -0
- package/crates/hajimi-claw-exec/Cargo.toml +21 -0
- package/crates/hajimi-claw-exec/src/lib.rs +419 -0
- package/crates/hajimi-claw-gateway/Cargo.toml +27 -0
- package/crates/hajimi-claw-gateway/src/lib.rs +747 -0
- package/crates/hajimi-claw-llm/Cargo.toml +19 -0
- package/crates/hajimi-claw-llm/src/lib.rs +367 -0
- package/crates/hajimi-claw-policy/Cargo.toml +14 -0
- package/crates/hajimi-claw-policy/src/lib.rs +381 -0
- package/crates/hajimi-claw-store/Cargo.toml +17 -0
- package/crates/hajimi-claw-store/src/lib.rs +730 -0
- package/crates/hajimi-claw-tools/Cargo.toml +21 -0
- package/crates/hajimi-claw-tools/src/lib.rs +758 -0
- package/crates/hajimi-claw-types/Cargo.toml +16 -0
- package/crates/hajimi-claw-types/src/lib.rs +300 -0
- package/package.json +26 -0
- package/scripts/npm-install.js +45 -0
- package/src/main.rs +4 -0
|
@@ -0,0 +1,730 @@
|
|
|
1
|
+
use std::path::Path;
|
|
2
|
+
use std::sync::{Arc, Mutex};
|
|
3
|
+
|
|
4
|
+
use aes_gcm::aead::{Aead, KeyInit, OsRng, rand_core::RngCore};
|
|
5
|
+
use aes_gcm::{Aes256Gcm, Nonce};
|
|
6
|
+
use anyhow::{Context, Result};
|
|
7
|
+
use base64::Engine;
|
|
8
|
+
use base64::engine::general_purpose::STANDARD as BASE64;
|
|
9
|
+
use chrono::{DateTime, Utc};
|
|
10
|
+
use hajimi_claw_types::{
|
|
11
|
+
ApprovalRequest, ConversationId, ConversationMessage, OnboardingSession, ProviderConfig,
|
|
12
|
+
ProviderDraft, ProviderKind, ProviderRecord, SessionHandle, SessionSummary, TaskId, TaskKind,
|
|
13
|
+
TaskStatus,
|
|
14
|
+
};
|
|
15
|
+
use rusqlite::{Connection, OptionalExtension, params};
|
|
16
|
+
use sha2::{Digest, Sha256};
|
|
17
|
+
|
|
18
|
+
pub struct Store {
|
|
19
|
+
connection: Mutex<Connection>,
|
|
20
|
+
cipher: Option<Arc<SecretCipher>>,
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
/// AES-256-GCM cipher used to protect provider secrets at rest.
///
/// The type does not derive `Debug`, which keeps the raw key out of formatted output.
#[derive(Clone)]
pub struct SecretCipher {
    /// Raw 256-bit AES key (see `SecretCipher::from_passphrase` for how it is derived).
    key: [u8; 32],
}
|
|
27
|
+
|
|
28
|
+
impl Store {
|
|
29
|
+
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
|
|
30
|
+
Self::open_with_cipher(path, None)
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
pub fn open_with_cipher(
|
|
34
|
+
path: impl AsRef<Path>,
|
|
35
|
+
cipher: Option<Arc<SecretCipher>>,
|
|
36
|
+
) -> Result<Self> {
|
|
37
|
+
let connection = Connection::open(path).context("open sqlite database")?;
|
|
38
|
+
connection.pragma_update(None, "journal_mode", "WAL")?;
|
|
39
|
+
Self::from_connection(connection, cipher)
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
pub fn open_in_memory() -> Result<Self> {
|
|
43
|
+
Self::open_in_memory_with_cipher(None)
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
pub fn open_in_memory_with_cipher(cipher: Option<Arc<SecretCipher>>) -> Result<Self> {
|
|
47
|
+
let connection = Connection::open_in_memory().context("open sqlite memory database")?;
|
|
48
|
+
connection.pragma_update(None, "journal_mode", "WAL")?;
|
|
49
|
+
Self::from_connection(connection, cipher)
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
fn from_connection(connection: Connection, cipher: Option<Arc<SecretCipher>>) -> Result<Self> {
|
|
53
|
+
let store = Self {
|
|
54
|
+
connection: Mutex::new(connection),
|
|
55
|
+
cipher,
|
|
56
|
+
};
|
|
57
|
+
store.migrate()?;
|
|
58
|
+
Ok(store)
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
fn migrate(&self) -> Result<()> {
|
|
62
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
63
|
+
connection.execute_batch(
|
|
64
|
+
r#"
|
|
65
|
+
CREATE TABLE IF NOT EXISTS messages (
|
|
66
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
67
|
+
conversation_id TEXT NOT NULL,
|
|
68
|
+
role TEXT NOT NULL,
|
|
69
|
+
content TEXT NOT NULL,
|
|
70
|
+
created_at TEXT NOT NULL
|
|
71
|
+
);
|
|
72
|
+
|
|
73
|
+
CREATE TABLE IF NOT EXISTS tasks (
|
|
74
|
+
id TEXT PRIMARY KEY,
|
|
75
|
+
kind TEXT NOT NULL,
|
|
76
|
+
description TEXT NOT NULL,
|
|
77
|
+
queued_at TEXT NOT NULL,
|
|
78
|
+
started_at TEXT,
|
|
79
|
+
finished_at TEXT,
|
|
80
|
+
running INTEGER NOT NULL
|
|
81
|
+
);
|
|
82
|
+
|
|
83
|
+
CREATE TABLE IF NOT EXISTS shell_sessions (
|
|
84
|
+
id TEXT PRIMARY KEY,
|
|
85
|
+
name TEXT NOT NULL,
|
|
86
|
+
cwd TEXT NOT NULL,
|
|
87
|
+
created_at TEXT NOT NULL,
|
|
88
|
+
last_used_at TEXT NOT NULL,
|
|
89
|
+
active INTEGER NOT NULL
|
|
90
|
+
);
|
|
91
|
+
|
|
92
|
+
CREATE TABLE IF NOT EXISTS approvals (
|
|
93
|
+
id TEXT PRIMARY KEY,
|
|
94
|
+
reason TEXT NOT NULL,
|
|
95
|
+
risk_level TEXT NOT NULL,
|
|
96
|
+
command_preview TEXT NOT NULL,
|
|
97
|
+
cwd TEXT,
|
|
98
|
+
expires_at TEXT NOT NULL,
|
|
99
|
+
approved INTEGER
|
|
100
|
+
);
|
|
101
|
+
|
|
102
|
+
CREATE TABLE IF NOT EXISTS command_audit (
|
|
103
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
104
|
+
task_id TEXT,
|
|
105
|
+
session_id TEXT,
|
|
106
|
+
command_preview TEXT NOT NULL,
|
|
107
|
+
exit_code INTEGER,
|
|
108
|
+
duration_ms INTEGER NOT NULL,
|
|
109
|
+
created_at TEXT NOT NULL
|
|
110
|
+
);
|
|
111
|
+
|
|
112
|
+
CREATE TABLE IF NOT EXISTS conversation_summaries (
|
|
113
|
+
conversation_id TEXT PRIMARY KEY,
|
|
114
|
+
summary_json TEXT NOT NULL,
|
|
115
|
+
updated_at TEXT NOT NULL
|
|
116
|
+
);
|
|
117
|
+
|
|
118
|
+
CREATE TABLE IF NOT EXISTS config_kv (
|
|
119
|
+
key TEXT PRIMARY KEY,
|
|
120
|
+
value TEXT NOT NULL
|
|
121
|
+
);
|
|
122
|
+
|
|
123
|
+
CREATE TABLE IF NOT EXISTS providers (
|
|
124
|
+
id TEXT PRIMARY KEY,
|
|
125
|
+
label TEXT NOT NULL,
|
|
126
|
+
kind TEXT NOT NULL,
|
|
127
|
+
base_url TEXT NOT NULL,
|
|
128
|
+
api_key TEXT NOT NULL,
|
|
129
|
+
model TEXT NOT NULL,
|
|
130
|
+
enabled INTEGER NOT NULL,
|
|
131
|
+
extra_headers_json TEXT NOT NULL,
|
|
132
|
+
is_default INTEGER NOT NULL,
|
|
133
|
+
created_at TEXT NOT NULL
|
|
134
|
+
);
|
|
135
|
+
|
|
136
|
+
CREATE TABLE IF NOT EXISTS onboarding_sessions (
|
|
137
|
+
chat_id INTEGER NOT NULL,
|
|
138
|
+
user_id INTEGER NOT NULL,
|
|
139
|
+
step TEXT NOT NULL,
|
|
140
|
+
draft_json TEXT NOT NULL,
|
|
141
|
+
updated_at TEXT NOT NULL,
|
|
142
|
+
PRIMARY KEY(chat_id, user_id)
|
|
143
|
+
);
|
|
144
|
+
|
|
145
|
+
CREATE TABLE IF NOT EXISTS chat_provider_bindings (
|
|
146
|
+
chat_id INTEGER PRIMARY KEY,
|
|
147
|
+
provider_id TEXT NOT NULL
|
|
148
|
+
);
|
|
149
|
+
"#,
|
|
150
|
+
)?;
|
|
151
|
+
Ok(())
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
pub fn save_message(
|
|
155
|
+
&self,
|
|
156
|
+
conversation_id: ConversationId,
|
|
157
|
+
message: &ConversationMessage,
|
|
158
|
+
) -> Result<()> {
|
|
159
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
160
|
+
connection.execute(
|
|
161
|
+
"INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
|
162
|
+
params![
|
|
163
|
+
conversation_id.to_string(),
|
|
164
|
+
format!("{:?}", message.role),
|
|
165
|
+
message.content,
|
|
166
|
+
message.created_at.to_rfc3339()
|
|
167
|
+
],
|
|
168
|
+
)?;
|
|
169
|
+
Ok(())
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
pub fn upsert_task(&self, task: &TaskStatus) -> Result<()> {
|
|
173
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
174
|
+
connection.execute(
|
|
175
|
+
r#"
|
|
176
|
+
INSERT INTO tasks (id, kind, description, queued_at, started_at, finished_at, running)
|
|
177
|
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
178
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
179
|
+
kind=excluded.kind,
|
|
180
|
+
description=excluded.description,
|
|
181
|
+
queued_at=excluded.queued_at,
|
|
182
|
+
started_at=excluded.started_at,
|
|
183
|
+
finished_at=excluded.finished_at,
|
|
184
|
+
running=excluded.running
|
|
185
|
+
"#,
|
|
186
|
+
params![
|
|
187
|
+
task.id.to_string(),
|
|
188
|
+
format!("{:?}", task.kind),
|
|
189
|
+
task.description,
|
|
190
|
+
task.queued_at.to_rfc3339(),
|
|
191
|
+
task.started_at.map(|ts| ts.to_rfc3339()),
|
|
192
|
+
task.finished_at.map(|ts| ts.to_rfc3339()),
|
|
193
|
+
i64::from(task.running),
|
|
194
|
+
],
|
|
195
|
+
)?;
|
|
196
|
+
Ok(())
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
pub fn list_tasks(&self) -> Result<Vec<TaskStatus>> {
|
|
200
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
201
|
+
let mut stmt = connection.prepare(
|
|
202
|
+
"SELECT id, kind, description, queued_at, started_at, finished_at, running FROM tasks ORDER BY queued_at DESC",
|
|
203
|
+
)?;
|
|
204
|
+
let rows = stmt.query_map([], |row| {
|
|
205
|
+
Ok(TaskStatus {
|
|
206
|
+
id: TaskId(uuid::Uuid::parse_str(row.get::<_, String>(0)?.as_str()).unwrap()),
|
|
207
|
+
kind: match row.get::<_, String>(1)?.as_str() {
|
|
208
|
+
"PersistentShellTask" => TaskKind::PersistentShellTask,
|
|
209
|
+
_ => TaskKind::EphemeralAgentTask,
|
|
210
|
+
},
|
|
211
|
+
description: row.get(2)?,
|
|
212
|
+
queued_at: parse_ts(row.get::<_, String>(3)?),
|
|
213
|
+
started_at: row.get::<_, Option<String>>(4)?.map(parse_ts),
|
|
214
|
+
finished_at: row.get::<_, Option<String>>(5)?.map(parse_ts),
|
|
215
|
+
running: row.get::<_, i64>(6)? != 0,
|
|
216
|
+
})
|
|
217
|
+
})?;
|
|
218
|
+
|
|
219
|
+
rows.collect::<rusqlite::Result<Vec<_>>>()
|
|
220
|
+
.map_err(Into::into)
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
pub fn upsert_session(&self, session: &SessionHandle, active: bool) -> Result<()> {
|
|
224
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
225
|
+
connection.execute(
|
|
226
|
+
r#"
|
|
227
|
+
INSERT INTO shell_sessions (id, name, cwd, created_at, last_used_at, active)
|
|
228
|
+
VALUES (?, ?, ?, ?, ?, ?)
|
|
229
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
230
|
+
name=excluded.name,
|
|
231
|
+
cwd=excluded.cwd,
|
|
232
|
+
created_at=excluded.created_at,
|
|
233
|
+
last_used_at=excluded.last_used_at,
|
|
234
|
+
active=excluded.active
|
|
235
|
+
"#,
|
|
236
|
+
params![
|
|
237
|
+
session.id.to_string(),
|
|
238
|
+
session.name,
|
|
239
|
+
session.cwd.display().to_string(),
|
|
240
|
+
session.created_at.to_rfc3339(),
|
|
241
|
+
session.last_used_at.to_rfc3339(),
|
|
242
|
+
i64::from(active),
|
|
243
|
+
],
|
|
244
|
+
)?;
|
|
245
|
+
Ok(())
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
pub fn save_approval(&self, approval: &ApprovalRequest, approved: Option<bool>) -> Result<()> {
|
|
249
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
250
|
+
connection.execute(
|
|
251
|
+
r#"
|
|
252
|
+
INSERT INTO approvals (id, reason, risk_level, command_preview, cwd, expires_at, approved)
|
|
253
|
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
254
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
255
|
+
reason=excluded.reason,
|
|
256
|
+
risk_level=excluded.risk_level,
|
|
257
|
+
command_preview=excluded.command_preview,
|
|
258
|
+
cwd=excluded.cwd,
|
|
259
|
+
expires_at=excluded.expires_at,
|
|
260
|
+
approved=excluded.approved
|
|
261
|
+
"#,
|
|
262
|
+
params![
|
|
263
|
+
approval.request_id.to_string(),
|
|
264
|
+
approval.reason,
|
|
265
|
+
format!("{:?}", approval.risk_level),
|
|
266
|
+
approval.command_preview,
|
|
267
|
+
approval.cwd.as_ref().map(|cwd| cwd.display().to_string()),
|
|
268
|
+
approval.expires_at.to_rfc3339(),
|
|
269
|
+
approved.map(i64::from),
|
|
270
|
+
],
|
|
271
|
+
)?;
|
|
272
|
+
Ok(())
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
pub fn get_approval_state(&self, request_id: &str) -> Result<Option<Option<bool>>> {
|
|
276
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
277
|
+
connection
|
|
278
|
+
.query_row(
|
|
279
|
+
"SELECT approved FROM approvals WHERE id = ?",
|
|
280
|
+
params![request_id],
|
|
281
|
+
|row| {
|
|
282
|
+
row.get::<_, Option<i64>>(0)
|
|
283
|
+
.map(|value| value.map(|v| v != 0))
|
|
284
|
+
},
|
|
285
|
+
)
|
|
286
|
+
.optional()
|
|
287
|
+
.map_err(Into::into)
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
pub fn append_command_audit(
|
|
291
|
+
&self,
|
|
292
|
+
task_id: Option<TaskId>,
|
|
293
|
+
session_id: Option<String>,
|
|
294
|
+
command_preview: &str,
|
|
295
|
+
exit_code: Option<i32>,
|
|
296
|
+
duration_ms: u128,
|
|
297
|
+
) -> Result<()> {
|
|
298
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
299
|
+
connection.execute(
|
|
300
|
+
"INSERT INTO command_audit (task_id, session_id, command_preview, exit_code, duration_ms, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
|
301
|
+
params![
|
|
302
|
+
task_id.map(|id| id.to_string()),
|
|
303
|
+
session_id,
|
|
304
|
+
command_preview,
|
|
305
|
+
exit_code,
|
|
306
|
+
duration_ms as i64,
|
|
307
|
+
Utc::now().to_rfc3339(),
|
|
308
|
+
],
|
|
309
|
+
)?;
|
|
310
|
+
Ok(())
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
pub fn save_summary(&self, summary: &SessionSummary) -> Result<()> {
|
|
314
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
315
|
+
connection.execute(
|
|
316
|
+
r#"
|
|
317
|
+
INSERT INTO conversation_summaries (conversation_id, summary_json, updated_at)
|
|
318
|
+
VALUES (?, ?, ?)
|
|
319
|
+
ON CONFLICT(conversation_id) DO UPDATE SET
|
|
320
|
+
summary_json=excluded.summary_json,
|
|
321
|
+
updated_at=excluded.updated_at
|
|
322
|
+
"#,
|
|
323
|
+
params![
|
|
324
|
+
summary.session_id.to_string(),
|
|
325
|
+
serde_json::to_string(summary)?,
|
|
326
|
+
Utc::now().to_rfc3339(),
|
|
327
|
+
],
|
|
328
|
+
)?;
|
|
329
|
+
Ok(())
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
pub fn load_summary(&self, conversation_id: ConversationId) -> Result<Option<SessionSummary>> {
|
|
333
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
334
|
+
let payload = connection
|
|
335
|
+
.query_row(
|
|
336
|
+
"SELECT summary_json FROM conversation_summaries WHERE conversation_id = ?",
|
|
337
|
+
params![conversation_id.to_string()],
|
|
338
|
+
|row| row.get::<_, String>(0),
|
|
339
|
+
)
|
|
340
|
+
.optional()?;
|
|
341
|
+
|
|
342
|
+
payload
|
|
343
|
+
.map(|json| serde_json::from_str(&json).context("decode conversation summary"))
|
|
344
|
+
.transpose()
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
pub fn set_config(&self, key: &str, value: &str) -> Result<()> {
|
|
348
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
349
|
+
connection.execute(
|
|
350
|
+
"INSERT INTO config_kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value",
|
|
351
|
+
params![key, value],
|
|
352
|
+
)?;
|
|
353
|
+
Ok(())
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
pub fn get_config(&self, key: &str) -> Result<Option<String>> {
|
|
357
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
358
|
+
connection
|
|
359
|
+
.query_row(
|
|
360
|
+
"SELECT value FROM config_kv WHERE key = ?",
|
|
361
|
+
params![key],
|
|
362
|
+
|row| row.get(0),
|
|
363
|
+
)
|
|
364
|
+
.optional()
|
|
365
|
+
.map_err(Into::into)
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
pub fn upsert_provider(&self, record: &ProviderRecord) -> Result<()> {
|
|
369
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
370
|
+
if record.is_default {
|
|
371
|
+
connection.execute("UPDATE providers SET is_default = 0", [])?;
|
|
372
|
+
}
|
|
373
|
+
let encrypted_api_key = self.encrypt_secret(&record.config.api_key)?;
|
|
374
|
+
connection.execute(
|
|
375
|
+
r#"
|
|
376
|
+
INSERT INTO providers (id, label, kind, base_url, api_key, model, enabled, extra_headers_json, is_default, created_at)
|
|
377
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
378
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
379
|
+
label=excluded.label,
|
|
380
|
+
kind=excluded.kind,
|
|
381
|
+
base_url=excluded.base_url,
|
|
382
|
+
api_key=excluded.api_key,
|
|
383
|
+
model=excluded.model,
|
|
384
|
+
enabled=excluded.enabled,
|
|
385
|
+
extra_headers_json=excluded.extra_headers_json,
|
|
386
|
+
is_default=excluded.is_default,
|
|
387
|
+
created_at=excluded.created_at
|
|
388
|
+
"#,
|
|
389
|
+
params![
|
|
390
|
+
record.config.id,
|
|
391
|
+
record.config.label,
|
|
392
|
+
record.config.kind.as_str(),
|
|
393
|
+
record.config.base_url,
|
|
394
|
+
encrypted_api_key,
|
|
395
|
+
record.config.model,
|
|
396
|
+
i64::from(record.config.enabled),
|
|
397
|
+
serde_json::to_string(&record.config.extra_headers)?,
|
|
398
|
+
i64::from(record.is_default),
|
|
399
|
+
record.config.created_at.to_rfc3339(),
|
|
400
|
+
],
|
|
401
|
+
)?;
|
|
402
|
+
Ok(())
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
pub fn list_providers(&self) -> Result<Vec<ProviderRecord>> {
|
|
406
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
407
|
+
let mut stmt = connection.prepare(
|
|
408
|
+
r#"
|
|
409
|
+
SELECT id, label, kind, base_url, api_key, model, enabled, extra_headers_json, is_default, created_at
|
|
410
|
+
FROM providers
|
|
411
|
+
ORDER BY is_default DESC, created_at ASC
|
|
412
|
+
"#,
|
|
413
|
+
)?;
|
|
414
|
+
let rows = stmt.query_map([], |row| self.row_to_provider(row))?;
|
|
415
|
+
rows.collect::<rusqlite::Result<Vec<_>>>()
|
|
416
|
+
.map_err(Into::into)
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
pub fn get_provider(&self, provider_id: &str) -> Result<Option<ProviderRecord>> {
|
|
420
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
421
|
+
connection
|
|
422
|
+
.query_row(
|
|
423
|
+
r#"
|
|
424
|
+
SELECT id, label, kind, base_url, api_key, model, enabled, extra_headers_json, is_default, created_at
|
|
425
|
+
FROM providers WHERE id = ?
|
|
426
|
+
"#,
|
|
427
|
+
params![provider_id],
|
|
428
|
+
|row| self.row_to_provider(row),
|
|
429
|
+
)
|
|
430
|
+
.optional()
|
|
431
|
+
.map_err(Into::into)
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
pub fn get_default_provider(&self) -> Result<Option<ProviderRecord>> {
|
|
435
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
436
|
+
connection
|
|
437
|
+
.query_row(
|
|
438
|
+
r#"
|
|
439
|
+
SELECT id, label, kind, base_url, api_key, model, enabled, extra_headers_json, is_default, created_at
|
|
440
|
+
FROM providers WHERE is_default = 1 LIMIT 1
|
|
441
|
+
"#,
|
|
442
|
+
[],
|
|
443
|
+
|row| self.row_to_provider(row),
|
|
444
|
+
)
|
|
445
|
+
.optional()
|
|
446
|
+
.map_err(Into::into)
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
pub fn set_default_provider(&self, provider_id: &str) -> Result<()> {
|
|
450
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
451
|
+
connection.execute("UPDATE providers SET is_default = 0", [])?;
|
|
452
|
+
connection.execute(
|
|
453
|
+
"UPDATE providers SET is_default = 1 WHERE id = ?",
|
|
454
|
+
params![provider_id],
|
|
455
|
+
)?;
|
|
456
|
+
Ok(())
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
pub fn bind_provider_to_chat(&self, chat_id: i64, provider_id: &str) -> Result<()> {
|
|
460
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
461
|
+
connection.execute(
|
|
462
|
+
r#"
|
|
463
|
+
INSERT INTO chat_provider_bindings (chat_id, provider_id)
|
|
464
|
+
VALUES (?, ?)
|
|
465
|
+
ON CONFLICT(chat_id) DO UPDATE SET provider_id=excluded.provider_id
|
|
466
|
+
"#,
|
|
467
|
+
params![chat_id, provider_id],
|
|
468
|
+
)?;
|
|
469
|
+
Ok(())
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
pub fn get_bound_provider_id(&self, chat_id: i64) -> Result<Option<String>> {
|
|
473
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
474
|
+
connection
|
|
475
|
+
.query_row(
|
|
476
|
+
"SELECT provider_id FROM chat_provider_bindings WHERE chat_id = ?",
|
|
477
|
+
params![chat_id],
|
|
478
|
+
|row| row.get(0),
|
|
479
|
+
)
|
|
480
|
+
.optional()
|
|
481
|
+
.map_err(Into::into)
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
pub fn resolve_provider_for_chat(&self, chat_id: i64) -> Result<Option<ProviderRecord>> {
|
|
485
|
+
if let Some(provider_id) = self.get_bound_provider_id(chat_id)? {
|
|
486
|
+
if let Some(record) = self.get_provider(&provider_id)? {
|
|
487
|
+
return Ok(Some(record));
|
|
488
|
+
}
|
|
489
|
+
}
|
|
490
|
+
self.get_default_provider()
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
pub fn save_onboarding_session(&self, session: &OnboardingSession) -> Result<()> {
|
|
494
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
495
|
+
connection.execute(
|
|
496
|
+
r#"
|
|
497
|
+
INSERT INTO onboarding_sessions (chat_id, user_id, step, draft_json, updated_at)
|
|
498
|
+
VALUES (?, ?, ?, ?, ?)
|
|
499
|
+
ON CONFLICT(chat_id, user_id) DO UPDATE SET
|
|
500
|
+
step=excluded.step,
|
|
501
|
+
draft_json=excluded.draft_json,
|
|
502
|
+
updated_at=excluded.updated_at
|
|
503
|
+
"#,
|
|
504
|
+
params![
|
|
505
|
+
session.chat_id,
|
|
506
|
+
session.user_id,
|
|
507
|
+
format!("{:?}", session.step),
|
|
508
|
+
serde_json::to_string(&session.draft)?,
|
|
509
|
+
session.updated_at.to_rfc3339(),
|
|
510
|
+
],
|
|
511
|
+
)?;
|
|
512
|
+
Ok(())
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
pub fn load_onboarding_session(
|
|
516
|
+
&self,
|
|
517
|
+
chat_id: i64,
|
|
518
|
+
user_id: i64,
|
|
519
|
+
) -> Result<Option<OnboardingSession>> {
|
|
520
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
521
|
+
connection
|
|
522
|
+
.query_row(
|
|
523
|
+
"SELECT step, draft_json, updated_at FROM onboarding_sessions WHERE chat_id = ? AND user_id = ?",
|
|
524
|
+
params![chat_id, user_id],
|
|
525
|
+
|row| {
|
|
526
|
+
let step = parse_onboarding_step(row.get::<_, String>(0)?);
|
|
527
|
+
let draft_json: String = row.get(1)?;
|
|
528
|
+
let draft: ProviderDraft = serde_json::from_str(&draft_json).map_err(to_sql_err)?;
|
|
529
|
+
let updated_at = parse_ts(row.get::<_, String>(2)?);
|
|
530
|
+
Ok(OnboardingSession {
|
|
531
|
+
user_id,
|
|
532
|
+
chat_id,
|
|
533
|
+
step,
|
|
534
|
+
draft,
|
|
535
|
+
updated_at,
|
|
536
|
+
})
|
|
537
|
+
},
|
|
538
|
+
)
|
|
539
|
+
.optional()
|
|
540
|
+
.map_err(Into::into)
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
pub fn clear_onboarding_session(&self, chat_id: i64, user_id: i64) -> Result<()> {
|
|
544
|
+
let connection = self.connection.lock().expect("store lock poisoned");
|
|
545
|
+
connection.execute(
|
|
546
|
+
"DELETE FROM onboarding_sessions WHERE chat_id = ? AND user_id = ?",
|
|
547
|
+
params![chat_id, user_id],
|
|
548
|
+
)?;
|
|
549
|
+
Ok(())
|
|
550
|
+
}
|
|
551
|
+
|
|
552
|
+
fn row_to_provider(&self, row: &rusqlite::Row<'_>) -> rusqlite::Result<ProviderRecord> {
|
|
553
|
+
let kind = match row.get::<_, String>(2)?.as_str() {
|
|
554
|
+
"custom-chat-completions" => ProviderKind::CustomChatCompletions,
|
|
555
|
+
_ => ProviderKind::OpenAiCompatible,
|
|
556
|
+
};
|
|
557
|
+
let headers_json: String = row.get(7)?;
|
|
558
|
+
let extra_headers = serde_json::from_str(&headers_json).map_err(to_sql_err)?;
|
|
559
|
+
let api_key_raw: String = row.get(4)?;
|
|
560
|
+
let api_key = self
|
|
561
|
+
.decrypt_secret(&api_key_raw)
|
|
562
|
+
.map_err(to_sql_anyhow_err)?;
|
|
563
|
+
Ok(ProviderRecord {
|
|
564
|
+
config: ProviderConfig {
|
|
565
|
+
id: row.get(0)?,
|
|
566
|
+
label: row.get(1)?,
|
|
567
|
+
kind,
|
|
568
|
+
base_url: row.get(3)?,
|
|
569
|
+
api_key,
|
|
570
|
+
model: row.get(5)?,
|
|
571
|
+
enabled: row.get::<_, i64>(6)? != 0,
|
|
572
|
+
extra_headers,
|
|
573
|
+
created_at: parse_ts(row.get(9)?),
|
|
574
|
+
},
|
|
575
|
+
is_default: row.get::<_, i64>(8)? != 0,
|
|
576
|
+
})
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
fn encrypt_secret(&self, value: &str) -> Result<String> {
|
|
580
|
+
match &self.cipher {
|
|
581
|
+
Some(cipher) => cipher.encrypt(value),
|
|
582
|
+
None => Ok(value.to_string()),
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
fn decrypt_secret(&self, value: &str) -> Result<String> {
|
|
587
|
+
match &self.cipher {
|
|
588
|
+
Some(cipher) => cipher.decrypt_or_passthrough(value),
|
|
589
|
+
None => Ok(value.to_string()),
|
|
590
|
+
}
|
|
591
|
+
}
|
|
592
|
+
}
|
|
593
|
+
|
|
594
|
+
impl SecretCipher {
|
|
595
|
+
pub fn from_passphrase(passphrase: &str) -> Result<Self> {
|
|
596
|
+
if passphrase.trim().is_empty() {
|
|
597
|
+
anyhow::bail!("master key must not be empty");
|
|
598
|
+
}
|
|
599
|
+
let digest = Sha256::digest(passphrase.as_bytes());
|
|
600
|
+
let mut key = [0_u8; 32];
|
|
601
|
+
key.copy_from_slice(&digest);
|
|
602
|
+
Ok(Self { key })
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
pub fn encrypt(&self, plaintext: &str) -> Result<String> {
|
|
606
|
+
let cipher = Aes256Gcm::new_from_slice(&self.key).context("initialize aes-256-gcm")?;
|
|
607
|
+
let mut nonce_bytes = [0_u8; 12];
|
|
608
|
+
OsRng.fill_bytes(&mut nonce_bytes);
|
|
609
|
+
let nonce = Nonce::from_slice(&nonce_bytes);
|
|
610
|
+
let ciphertext = cipher
|
|
611
|
+
.encrypt(nonce, plaintext.as_bytes())
|
|
612
|
+
.map_err(|_| anyhow::anyhow!("failed to encrypt secret"))?;
|
|
613
|
+
let mut payload = nonce_bytes.to_vec();
|
|
614
|
+
payload.extend(ciphertext);
|
|
615
|
+
Ok(format!("enc:v1:{}", BASE64.encode(payload)))
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
pub fn decrypt_or_passthrough(&self, value: &str) -> Result<String> {
|
|
619
|
+
if !value.starts_with("enc:v1:") {
|
|
620
|
+
return Ok(value.to_string());
|
|
621
|
+
}
|
|
622
|
+
let encoded = value.trim_start_matches("enc:v1:");
|
|
623
|
+
let payload = BASE64
|
|
624
|
+
.decode(encoded)
|
|
625
|
+
.context("decode encrypted provider secret")?;
|
|
626
|
+
if payload.len() < 13 {
|
|
627
|
+
anyhow::bail!("encrypted payload is too short");
|
|
628
|
+
}
|
|
629
|
+
let (nonce_bytes, ciphertext) = payload.split_at(12);
|
|
630
|
+
let cipher = Aes256Gcm::new_from_slice(&self.key).context("initialize aes-256-gcm")?;
|
|
631
|
+
let plaintext = cipher
|
|
632
|
+
.decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
|
|
633
|
+
.map_err(|_| anyhow::anyhow!("failed to decrypt provider secret"))?;
|
|
634
|
+
String::from_utf8(plaintext).context("provider secret is not valid utf-8")
|
|
635
|
+
}
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
fn parse_ts(ts: String) -> DateTime<Utc> {
|
|
639
|
+
DateTime::parse_from_rfc3339(&ts)
|
|
640
|
+
.expect("timestamp stored in rfc3339")
|
|
641
|
+
.with_timezone(&Utc)
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
fn parse_onboarding_step(step: String) -> hajimi_claw_types::OnboardingStep {
|
|
645
|
+
match step.as_str() {
|
|
646
|
+
"ProviderKind" => hajimi_claw_types::OnboardingStep::ProviderKind,
|
|
647
|
+
"ProviderBaseUrl" => hajimi_claw_types::OnboardingStep::ProviderBaseUrl,
|
|
648
|
+
"ProviderApiKey" => hajimi_claw_types::OnboardingStep::ProviderApiKey,
|
|
649
|
+
"ProviderModel" => hajimi_claw_types::OnboardingStep::ProviderModel,
|
|
650
|
+
"Completed" => hajimi_claw_types::OnboardingStep::Completed,
|
|
651
|
+
_ => hajimi_claw_types::OnboardingStep::ProviderLabel,
|
|
652
|
+
}
|
|
653
|
+
}
|
|
654
|
+
|
|
655
|
+
fn to_sql_err(err: serde_json::Error) -> rusqlite::Error {
|
|
656
|
+
rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(err))
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
fn to_sql_anyhow_err(err: anyhow::Error) -> rusqlite::Error {
|
|
660
|
+
rusqlite::Error::FromSqlConversionFailure(
|
|
661
|
+
0,
|
|
662
|
+
rusqlite::types::Type::Text,
|
|
663
|
+
Box::new(std::io::Error::other(err.to_string())),
|
|
664
|
+
)
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::Utc;
    use hajimi_claw_types::{
        ConversationMessage, MessageRole, ProviderConfig, ProviderKind, ProviderRecord,
    };

    use super::{SecretCipher, Store};

    #[test]
    fn persists_message_and_task() {
        let store = Store::open_in_memory().unwrap();
        let conversation_id = hajimi_claw_types::ConversationId::new();
        store
            .save_message(
                conversation_id,
                &ConversationMessage {
                    role: MessageRole::User,
                    content: "hello".into(),
                    created_at: Utc::now(),
                },
            )
            .unwrap();

        let task = hajimi_claw_types::TaskStatus {
            id: hajimi_claw_types::TaskId::new(),
            kind: hajimi_claw_types::TaskKind::EphemeralAgentTask,
            description: "test".into(),
            queued_at: Utc::now(),
            started_at: None,
            finished_at: None,
            running: false,
        };
        store.upsert_task(&task).unwrap();
        let tasks = store.list_tasks().unwrap();
        assert_eq!(tasks.len(), 1);
        // Check that the row round-trips, not just that one exists.
        assert_eq!(tasks[0].description, "test");
        assert!(!tasks[0].running);
        assert!(tasks[0].started_at.is_none());
        assert!(matches!(
            tasks[0].kind,
            hajimi_claw_types::TaskKind::EphemeralAgentTask
        ));
    }

    #[test]
    fn encrypts_provider_api_key_when_cipher_enabled() {
        let cipher = Arc::new(SecretCipher::from_passphrase("secret").unwrap());
        let store = Store::open_in_memory_with_cipher(Some(cipher)).unwrap();
        store
            .upsert_provider(&ProviderRecord {
                config: ProviderConfig {
                    id: "demo".into(),
                    label: "Demo".into(),
                    kind: ProviderKind::OpenAiCompatible,
                    base_url: "https://example.com/v1".into(),
                    api_key: "top-secret".into(),
                    model: "gpt-demo".into(),
                    enabled: true,
                    extra_headers: vec![],
                    created_at: Utc::now(),
                },
                is_default: true,
            })
            .unwrap();

        // The value at rest must be the encrypted form, not the plaintext.
        let raw: String = store
            .connection
            .lock()
            .unwrap()
            .query_row(
                "SELECT api_key FROM providers WHERE id = ?",
                ["demo"],
                |row| row.get(0),
            )
            .unwrap();
        assert!(raw.starts_with("enc:v1:"));
        assert_ne!(raw, "top-secret");

        let provider = store.get_default_provider().unwrap().unwrap();
        assert_eq!(provider.config.api_key, "top-secret");
    }

    #[test]
    fn cipher_round_trips_and_uses_fresh_nonces() {
        let cipher = SecretCipher::from_passphrase("k").unwrap();
        let first = cipher.encrypt("value").unwrap();
        let second = cipher.encrypt("value").unwrap();
        assert_ne!(first, second);
        assert_eq!(cipher.decrypt_or_passthrough(&first).unwrap(), "value");
        assert_eq!(cipher.decrypt_or_passthrough(&second).unwrap(), "value");
    }

    #[test]
    fn decrypting_with_a_different_passphrase_fails() {
        let sealed = SecretCipher::from_passphrase("one")
            .unwrap()
            .encrypt("value")
            .unwrap();
        let other = SecretCipher::from_passphrase("two").unwrap();
        assert!(other.decrypt_or_passthrough(&sealed).is_err());
    }

    #[test]
    fn rejects_blank_passphrase_and_passes_through_plaintext() {
        assert!(SecretCipher::from_passphrase("   ").is_err());
        let cipher = SecretCipher::from_passphrase("k").unwrap();
        assert_eq!(cipher.decrypt_or_passthrough("plain").unwrap(), "plain");
    }
}
|