codeatrium 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,290 @@
1
+ """蒸留モジュール: claude -p で exchange を palace object に変換する
2
+
3
+ SPEC Section 6 DISTILLER フロー準拠:
4
+ ① files_touched を regex で抽出(LLM非使用)
5
+ ② claude -p で palace object 生成(--output-format json --json-schema)
6
+ ③ distill_text を embedding して vec_palace に登録
7
+ ④ files_touched を tree-sitter で解析してシンボルを symbols テーブルに登録
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import datetime
13
+ import hashlib
14
+ import os
15
+ import re
16
+ import struct
17
+ from collections.abc import Callable
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
22
+
23
+ from codeatrium.embedder import Embedder
24
+ from codeatrium.llm import DISTILL_PROMPT_TEMPLATE, call_claude
25
+ from codeatrium.models import PalaceObject
26
+
27
+ # ---- ファイルパス抽出 ----
28
+
29
+ _FILES_PATTERN = re.compile(
30
+ r"(/(?:[a-zA-Z0-9._\-]+/)*[a-zA-Z0-9._\-]+\.[a-zA-Z0-9]+)" # absolute path
31
+ r"|([a-zA-Z0-9._\-]+(?:/[a-zA-Z0-9._\-]+)+\.[a-zA-Z0-9]+)" # relative path (at least one directory level)
32
+ )
33
+
34
+
35
+ # ---- 内部ヘルパー ----
36
+
37
+
38
+ def _sha256(text: str) -> str:
39
+ return hashlib.sha256(text.encode()).hexdigest()
40
+
41
+
42
+ # ---- 公開 API ----
43
+
44
+
45
+ _EXTERNAL_PATH_MARKERS = (
46
+ "site-packages/",
47
+ "dist-packages/",
48
+ "/lib/python",
49
+ "/opt/",
50
+ "/usr/lib/",
51
+ "/usr/local/lib/",
52
+ ".venv/",
53
+ "/venv/",
54
+ "node_modules/",
55
+ )
56
+
57
+
58
+ def _is_external_path(path: str, project_root: str | None = None) -> bool:
59
+ """プロジェクト外のパスか判定する。
60
+
61
+ 絶対パス: project_root が指定されていればその配下かチェック。
62
+ 相対パス: ハードコードマーカーでフィルタ。
63
+ """
64
+ if path.startswith("/"):
65
+ # 絶対パス: project_root 配下でなければ外部
66
+ if project_root:
67
+ return not path.startswith(project_root)
68
+ # project_root 不明時はマーカーでフォールバック
69
+ return any(marker in path for marker in _EXTERNAL_PATH_MARKERS)
70
+
71
+
72
def extract_files_touched(
    user_content: str, agent_content: str, project_root: str | None = None
) -> list[str]:
    """Collect file paths mentioned in an exchange using ``_FILES_PATTERN``.

    The user and agent texts are joined with a newline and scanned once.
    Paths are returned in order of first appearance, without duplicates.
    Each candidate is passed through ``_is_external_path`` together with
    ``project_root`` normalised to end in exactly one trailing slash (or
    ``None`` when no root was given); candidates it rejects are dropped.
    """
    combined = "\n".join((user_content, agent_content))

    normalized_root = None
    if project_root:
        normalized_root = project_root.rstrip("/") + "/"

    # A dict keeps insertion order and gives O(1) duplicate checks.
    ordered: dict[str, None] = {}
    for absolute, relative in _FILES_PATTERN.findall(combined):
        candidate = absolute or relative
        if not candidate or candidate in ordered:
            continue
        if _is_external_path(candidate, normalized_root):
            continue
        ordered[candidate] = None
    return list(ordered)
91
+
92
+
93
def distill_exchange(
    exchange_id: str,
    user_content: str,
    agent_content: str,
    ply_start: int,
    ply_end: int,
    model: str | None = None,
    project_root: str | None = None,
) -> PalaceObject:
    """Distill a single exchange into a ``PalaceObject``.

    The user and agent texts are joined with a newline, truncated to the
    first 4000 characters, substituted into ``DISTILL_PROMPT_TEMPLATE`` and
    sent through ``call_claude``. ``files_touched`` is extracted separately
    from the full (untruncated) texts via ``extract_files_touched``.

    ``exchange_id`` is accepted but not used by this function.
    """
    combined = "\n".join((user_content, agent_content))
    prompt = DISTILL_PROMPT_TEMPLATE.format(
        ply_start=ply_start,
        ply_end=ply_end,
        messages_text=combined[:4000],
    )
    response = call_claude(prompt, model=model)

    touched = extract_files_touched(
        user_content, agent_content, project_root=project_root
    )

    core = response["exchange_core"]
    context = response["specific_context"]
    rooms = response["room_assignments"]
    return PalaceObject(
        exchange_core=core,
        specific_context=context,
        room_assignments=rooms,
        files_touched=touched,
    )
119
+
120
+
121
def save_palace_object(
    db_path: Path,
    exchange_id: str,
    palace: PalaceObject,
    embedding: Any,  # np.ndarray
) -> None:
    """Persist a ``PalaceObject`` and mark its exchange as distilled.

    Within one connection this:
      1. inserts the palace object into ``palace_objects`` (ignored if the
         id already exists),
      2. inserts one ``rooms`` row per entry of ``palace.room_assignments``,
      3. stores the embedding in ``vec_palace`` unless a row for this
         palace id is already present,
      4. extracts symbols from every path in ``palace.files_touched`` with
         ``SymbolResolver`` and inserts them into ``symbols``,
      5. sets ``exchanges.distilled_at`` to the current UTC time.

    Args:
        db_path: Database location handed to ``get_connection``.
        exchange_id: Id of the exchange the palace object was distilled from.
        palace: The distilled object to store.
        embedding: Embedding of the distill text; converted to float32 and
            packed as raw native-order floats. Assumed to be a 1-D
            ``np.ndarray`` -- TODO confirm against callers.

    The connection is always closed, even when a statement raises; in that
    case ``commit`` is never reached.
    """
    import numpy as np

    from codeatrium.db import get_connection

    palace_id = _sha256(f"palace:{exchange_id}")
    distill_text = palace.exchange_core + "\n" + palace.specific_context

    con = get_connection(db_path)
    try:
        con.execute(
            """
            INSERT OR IGNORE INTO palace_objects
            (id, exchange_id, exchange_core, specific_context, distill_text)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                palace_id,
                exchange_id,
                palace.exchange_core,
                palace.specific_context,
                distill_text,
            ),
        )

        for room in palace.room_assignments:
            # dedup hash identifies the room itself; the row id additionally
            # ties it to this palace object.
            dedup = _sha256(f"{room['room_type']}:{room['room_key']}")
            room_id = _sha256(f"{palace_id}:{dedup}")
            con.execute(
                """
                INSERT OR IGNORE INTO rooms
                (id, palace_object_id, room_type, room_key, room_label, relevance, dedup_hash)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    room_id,
                    palace_id,
                    room["room_type"],
                    room["room_key"],
                    room["room_label"],
                    room["relevance"],
                    dedup,
                ),
            )

        arr = embedding.astype(np.float32)
        blob = struct.pack(f"{len(arr)}f", *arr.tolist())
        # Explicit existence check instead of INSERT OR IGNORE for vec_palace.
        exists = con.execute(
            "SELECT 1 FROM vec_palace WHERE palace_id = ?", (palace_id,)
        ).fetchone()
        if not exists:
            con.execute(
                "INSERT INTO vec_palace (palace_id, embedding) VALUES (?, ?)",
                (palace_id, blob),
            )

        # Symbol resolution for every touched file.
        from codeatrium.resolver import SymbolResolver

        resolver = SymbolResolver()
        for file_str in palace.files_touched:
            for sym in resolver.extract(Path(file_str)):
                # The same hash is used as both primary key and dedup hash.
                sym_id = _sha256(f"{sym.symbol_name}:{sym.file_path}")
                con.execute(
                    """
                    INSERT OR IGNORE INTO symbols
                    (id, palace_object_id, symbol_name, symbol_kind,
                    file_path, signature, line, dedup_hash)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        sym_id,
                        palace_id,
                        sym.symbol_name,
                        sym.symbol_kind,
                        sym.file_path,
                        sym.signature,
                        sym.line,
                        sym_id,
                    ),
                )

        # Naive UTC timestamp in the same string format that the deprecated
        # datetime.utcnow().isoformat() produced (no "+00:00" suffix).
        distilled_at = (
            datetime.datetime.now(datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat()
        )
        con.execute(
            "UPDATE exchanges SET distilled_at = ? WHERE id = ?",
            (distilled_at, exchange_id),
        )

        con.commit()
    finally:
        con.close()
216
+
217
+
218
def distill_all(
    db_path: Path,
    limit: int | None = None,
    model: str | None = None,
    on_progress: Callable[..., None] | None = None,
    project_root: str | None = None,
    distill_min_chars: int = 100,
) -> int:
    """Distill every exchange that has not been distilled yet.

    Before selecting work, exchanges that should never be distilled are
    marked with ``distilled_at = 'skipped'``: those belonging to a
    conversation with fewer than two exchanges, and those whose combined
    user + agent text is shorter than ``distill_min_chars``.

    Args:
        db_path: Database location handed to ``get_connection``.
        limit: Maximum number of exchanges to process; ``None`` = no limit.
        model: Passed through to ``distill_exchange``.
        on_progress: Optional callback invoked after every exchange as
            ``on_progress(current, total)`` on success or
            ``on_progress(current, total, error=str(exc))`` on failure,
            where ``current`` is the number of successes so far.
        project_root: Passed through to ``distill_exchange``.
        distill_min_chars: Minimum combined text length to be distilled.

    Returns:
        The number of exchanges that were distilled and saved successfully.
        Failed exchanges are reported through ``on_progress`` and are not
        marked, so they remain candidates for a later run.
    """
    from codeatrium.db import get_connection

    con = get_connection(db_path)
    try:
        # Mark exchanges that are out of scope for distillation:
        #   - sessions consisting of a single exchange
        #   - exchanges shorter than distill_min_chars
        #     (one-phrase instructions, system messages, ...)
        con.execute(
            """
            UPDATE exchanges SET distilled_at = 'skipped'
            WHERE distilled_at IS NULL
            AND ((SELECT COUNT(*) FROM exchanges e2
            WHERE e2.conversation_id = exchanges.conversation_id) < 2
            OR LENGTH(user_content) + LENGTH(agent_content) < ?)
            """,
            (distill_min_chars,),
        )
        con.commit()

        query = """
            SELECT e.id, e.user_content, e.agent_content, e.ply_start, e.ply_end
            FROM exchanges e
            WHERE e.distilled_at IS NULL
            """
        params: list[int] = []
        if limit is not None:
            query += " LIMIT ?"
            params.append(int(limit))
        rows = con.execute(query, params).fetchall()
    finally:
        # Release the connection even if one of the statements above raises;
        # save_palace_object opens its own connection per exchange.
        con.close()

    if not rows:
        return 0

    total = len(rows)
    embedder = Embedder()
    count = 0
    for row in rows:
        try:
            palace = distill_exchange(
                row["id"],
                row["user_content"],
                row["agent_content"],
                row["ply_start"],
                row["ply_end"],
                model=model,
                project_root=project_root,
            )
            distill_text = palace.exchange_core + "\n" + palace.specific_context
            vec = embedder.embed_passage(distill_text)
            save_palace_object(db_path, row["id"], palace, vec)
            count += 1
        except Exception as e:
            # Per-exchange boundary: one failure must not abort the batch.
            # The error is surfaced to the caller through the callback.
            if on_progress is not None:
                on_progress(count, total, error=str(e))
            continue
        if on_progress is not None:
            on_progress(count, total)

    return count
codeatrium/embedder.py ADDED
@@ -0,0 +1,168 @@
1
+ """
2
+ sentence-transformers / multilingual-e5-small の embedding ラッパー
3
+
4
+ モデル: intfloat/multilingual-e5-small(384次元・日本語+英語混在対応・CPU動作)
5
+
6
+ ソケットサーバー方式:
7
+ - .codeatrium/embedder.sock が存在すれば Unix ソケット経由で embed(< 1秒)
8
+ - ソケットなければ直接モデルをロード(~7秒)+バックグラウンドでサーバー起動
9
+ - 2回目以降は常にソケット経由になるため高速
10
+
11
+ cold start が問題になった場合は SentenceTransformer(..., backend="onnx") でも対処可能
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import socket
18
+ import subprocess
19
+ import sys
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+
25
+
26
+ def _loci_python() -> str:
27
+ """venv の Python パスを返す(sys.executable はシステム Python の場合があるため)"""
28
+ # loci CLI と同じ venv の python を使う
29
+ python = Path(sys.executable).parent / "python3"
30
+ if python.exists():
31
+ return str(python)
32
+ return sys.executable
33
+
34
+
35
+ MODEL_NAME = "intfloat/multilingual-e5-small"
36
+ CONNECT_TIMEOUT = 2.0 # socket connect timeout (seconds)
37
+ RECV_TIMEOUT = 30.0 # timeout for receiving the embedding (seconds)
38
+
39
+
40
+ def _sock_path_from_env() -> Path | None:
41
+ """環境変数 CODEATRIUM_SOCK_PATH があれば優先使用(テスト用)。
42
+ CODEATRIUM_NO_SOCK=1 の場合はソケット無効(サーバー内自己接続デッドロック防止)。
43
+ """
44
+ import os
45
+
46
+ if os.environ.get("CODEATRIUM_NO_SOCK"):
47
+ return None
48
+ p = os.environ.get("CODEATRIUM_SOCK_PATH")
49
+ return Path(p) if p else None
50
+
51
+
52
+ def _find_sock_path() -> Path | None:
53
+ """DB の親ディレクトリの embedder.sock を探す"""
54
+ import os
55
+
56
+ if os.environ.get("CODEATRIUM_NO_SOCK"):
57
+ return None
58
+ # .codeatrium/ の場所を git root から解決
59
+ try:
60
+ import subprocess as sp
61
+
62
+ result = sp.run(
63
+ ["git", "rev-parse", "--show-toplevel"],
64
+ capture_output=True,
65
+ text=True,
66
+ )
67
+ if result.returncode == 0:
68
+ return Path(result.stdout.strip()) / ".codeatrium" / "embedder.sock"
69
+ except Exception:
70
+ pass
71
+ return None
72
+
73
+
74
+ def _try_socket_embed(sock_path: Path, req_type: str, text: str) -> np.ndarray | None:
75
+ """ソケットサーバーに接続して embedding を取得する。失敗時は None を返す。"""
76
+ if not sock_path.exists():
77
+ return None
78
+ try:
79
+ with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
80
+ s.settimeout(CONNECT_TIMEOUT)
81
+ s.connect(str(sock_path))
82
+ s.settimeout(RECV_TIMEOUT)
83
+ req = json.dumps({"type": req_type, "text": text}) + "\n"
84
+ s.sendall(req.encode())
85
+ buf = b""
86
+ while b"\n" not in buf:
87
+ chunk = s.recv(65536)
88
+ if not chunk:
89
+ break
90
+ buf += chunk
91
+ resp = json.loads(buf.split(b"\n")[0])
92
+ if "embedding" in resp:
93
+ return np.array(resp["embedding"], dtype=np.float32)
94
+ except Exception:
95
+ pass
96
+ return None
97
+
98
+
99
def _start_server_background(sock_path: Path) -> None:
    """Launch ``codeatrium.embedder_server`` in a new session (best effort).

    The child is started with stdout/stderr discarded. Afterwards the
    function checks for ``sock_path`` up to 20 times, sleeping 0.2 seconds
    after each unsuccessful check, and returns as soon as the file exists.
    Any exception raised along the way is swallowed.
    """
    try:
        command = [_loci_python(), "-m", "codeatrium.embedder_server", str(sock_path)]
        subprocess.Popen(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
        # Give the server a moment to create its socket file.
        polls_left = 20
        while polls_left > 0 and not sock_path.exists():
            time.sleep(0.2)
            polls_left -= 1
    except Exception:
        pass
115
+
116
+
117
class Embedder:
    """Thin wrapper around the multilingual-e5-small embedding model.

    When a socket server is reachable the embedding is fetched over the Unix
    socket (fast path). Otherwise the model is loaded in-process and a
    server is started in the background for subsequent calls.
    """

    def __init__(self, sock_path: Path | None = None) -> None:
        # Resolution order: environment override, explicit argument,
        # discovery via the git root.
        resolved = _sock_path_from_env()
        if not resolved:
            resolved = sock_path
        if not resolved:
            resolved = _find_sock_path()
        self._sock_path: Path | None = resolved
        self._model = None  # loaded lazily

    def _ensure_model(self) -> None:
        """Load the model in-process on first use (fallback without socket)."""
        if self._model is not None:
            return
        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(MODEL_NAME)

    def _embed_via_socket_or_direct(
        self, text: str, req_type: str, prefix: str
    ) -> np.ndarray:
        """Embed ``text`` via the socket if possible, else with the local model."""
        # 1) Try the socket server first.
        sock = self._sock_path
        if sock is not None:
            remote = _try_socket_embed(sock, req_type, text)
            if remote is not None:
                return remote

        # 2) Fall back to the in-process model.
        self._ensure_model()
        assert self._model is not None
        encoded = self._model.encode(
            [f"{prefix}{text}"],
            normalize_embeddings=True,
        )
        vector = encoded[0].astype(np.float32)

        # 3) Start a background server so later calls can use the fast path.
        if self._sock_path is not None and not self._sock_path.exists():
            _start_server_background(self._sock_path)

        return vector

    def embed(self, text: str) -> np.ndarray:
        """Return a query embedding (``query: `` prefix)."""
        return self._embed_via_socket_or_direct(text, "query", "query: ")

    def embed_passage(self, text: str) -> np.ndarray:
        """Return an embedding for indexing (``passage: `` prefix)."""
        return self._embed_via_socket_or_direct(text, "passage", "passage: ")
@@ -0,0 +1,172 @@
1
+ """
2
+ embedder_server.py — multilingual-e5-small を常駐させる Unix ソケットサーバー
3
+
4
+ プロトコル:
5
+ request (改行区切り JSON): {"type": "query"|"passage", "text": "..."}
6
+ response (改行区切り JSON): {"embedding": [0.1, 0.2, ...]}
7
+ 特殊コマンド: {"type": "ping"} → {"status": "ok"}
8
+ {"type": "stop"} → サーバー終了
9
+
10
+ ライフサイクル:
11
+ - loci server start でバックグラウンド起動
12
+ - IDLE_TIMEOUT 秒間リクエストなし → 自動終了
13
+ - loci server stop / SIGTERM で即終了
14
+ - ソケットファイルは .codeatrium/embedder.sock
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ import os
21
+ import signal
22
+ import socket
23
+ import sys
24
+ import threading
25
+ import time
26
+ from pathlib import Path
27
+ from typing import TYPE_CHECKING
28
+
29
+ if TYPE_CHECKING:
30
+ from codeatrium.embedder import Embedder
31
+
32
+ IDLE_TIMEOUT = 600 # exit automatically after 10 minutes without requests
33
+ BACKLOG = 8 # backlog passed to listen() on the server socket
34
+ RECV_BUFSIZE = 65536 # maximum number of bytes requested per recv() call
35
+
36
+
37
+ def _load_embedder() -> Embedder:
38
+ # サーバー内では直接モデルを使う(ソケット経由にすると自己接続デッドロック)
39
+
40
+ os.environ["CODEATRIUM_NO_SOCK"] = "1"
41
+ from codeatrium.embedder import Embedder
42
+
43
+ embedder = Embedder()
44
+ del os.environ["CODEATRIUM_NO_SOCK"]
45
+ return embedder
46
+
47
+
48
+ def _handle_client(
49
+ conn: socket.socket,
50
+ embedder: Embedder,
51
+ last_activity: list[float],
52
+ stop_event: threading.Event,
53
+ ) -> None:
54
+ """1クライアント接続を処理する"""
55
+ try:
56
+ buf = b""
57
+ while True:
58
+ chunk = conn.recv(RECV_BUFSIZE)
59
+ if not chunk:
60
+ break
61
+ buf += chunk
62
+ while b"\n" in buf:
63
+ line, buf = buf.split(b"\n", 1)
64
+ line = line.strip()
65
+ if not line:
66
+ continue
67
+ try:
68
+ req = json.loads(line)
69
+ except json.JSONDecodeError:
70
+ conn.sendall(json.dumps({"error": "invalid json"}).encode() + b"\n")
71
+ continue
72
+
73
+ req_type = req.get("type", "")
74
+
75
+ if req_type == "ping":
76
+ conn.sendall(json.dumps({"status": "ok"}).encode() + b"\n")
77
+ last_activity[0] = time.monotonic()
78
+ continue
79
+
80
+ if req_type == "stop":
81
+ conn.sendall(json.dumps({"status": "stopping"}).encode() + b"\n")
82
+ stop_event.set()
83
+ return
84
+
85
+ text = req.get("text", "")
86
+ if not text:
87
+ conn.sendall(json.dumps({"error": "missing text"}).encode() + b"\n")
88
+ continue
89
+
90
+ if req_type == "query":
91
+ vec = embedder.embed(text)
92
+ elif req_type == "passage":
93
+ vec = embedder.embed_passage(text)
94
+ else:
95
+ conn.sendall(
96
+ json.dumps({"error": f"unknown type: {req_type}"}).encode()
97
+ + b"\n"
98
+ )
99
+ continue
100
+
101
+ resp = json.dumps({"embedding": vec.tolist()})
102
+ conn.sendall(resp.encode() + b"\n")
103
+ last_activity[0] = time.monotonic()
104
+ except (OSError, BrokenPipeError):
105
+ pass
106
+ finally:
107
+ try:
108
+ conn.close()
109
+ except OSError:
110
+ pass
111
+
112
+
113
def run_server(sock_path: Path) -> None:
    """Run the embedding socket server until stopped (blocking).

    Any existing file at ``sock_path`` is removed, the embedder is loaded,
    and a Unix stream socket is bound at ``sock_path`` under a temporary
    umask of ``0o177`` (owner-only permissions). Each accepted connection is
    handled by ``_handle_client`` on its own daemon thread.

    The loop ends when ``stop_event`` is set (by a client ``stop`` request or
    by SIGTERM/SIGINT), when no request has been served for more than
    ``IDLE_TIMEOUT`` seconds, or when ``accept`` fails with an ``OSError``
    other than a timeout. On exit -- including a failure during bind/listen
    -- the server socket is closed and the socket file is removed.
    """
    # Remove a stale socket file left by a previous run.
    sock_path.unlink(missing_ok=True)

    embedder = _load_embedder()

    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        old_umask = os.umask(0o177)  # enforce 0o600 on the socket, no TOCTOU
        try:
            server.bind(str(sock_path))
        finally:
            os.umask(old_umask)
        server.listen(BACKLOG)
        server.settimeout(1.0)  # accept() timeout, so the idle check runs

        last_activity: list[float] = [time.monotonic()]
        stop_event = threading.Event()

        def _sigterm_handler(signum: int, frame: object) -> None:
            stop_event.set()

        signal.signal(signal.SIGTERM, _sigterm_handler)
        signal.signal(signal.SIGINT, _sigterm_handler)

        while not stop_event.is_set():
            # Idle timeout check.
            if time.monotonic() - last_activity[0] > IDLE_TIMEOUT:
                break

            try:
                conn, _ = server.accept()
            except socket.timeout:
                # socket.timeout is the same class as TimeoutError on
                # Python 3.10+, and a distinct OSError subclass before that.
                # Catching it by this name keeps the loop alive on both.
                continue
            except OSError:
                break

            # One thread per client.
            t = threading.Thread(
                target=_handle_client,
                args=(conn, embedder, last_activity, stop_event),
                daemon=True,
            )
            t.start()
    finally:
        server.close()
        sock_path.unlink(missing_ok=True)
161
+
162
+
163
def main() -> None:
    """CLI entry point: ``python -m codeatrium.embedder_server <sock_path>``.

    Prints a usage line to stderr and exits with status 1 when no socket
    path argument is given; otherwise runs the server on that path.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m codeatrium.embedder_server <sock_path>", file=sys.stderr)
        sys.exit(1)
    run_server(Path(cli_args[0]))


if __name__ == "__main__":
    main()