codeatrium 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeatrium/__init__.py +3 -0
- codeatrium/__main__.py +5 -0
- codeatrium/cli/__init__.py +295 -0
- codeatrium/cli/distill_cmd.py +76 -0
- codeatrium/cli/hook_cmd.py +24 -0
- codeatrium/cli/index_cmd.py +62 -0
- codeatrium/cli/prime_cmd.py +90 -0
- codeatrium/cli/search_cmd.py +128 -0
- codeatrium/cli/server_cmd.py +122 -0
- codeatrium/cli/show_cmd.py +151 -0
- codeatrium/cli/status_cmd.py +59 -0
- codeatrium/config.py +96 -0
- codeatrium/db.py +135 -0
- codeatrium/distiller.py +290 -0
- codeatrium/embedder.py +168 -0
- codeatrium/embedder_server.py +172 -0
- codeatrium/hooks.py +156 -0
- codeatrium/indexer.py +237 -0
- codeatrium/llm.py +148 -0
- codeatrium/models.py +53 -0
- codeatrium/paths.py +74 -0
- codeatrium/py.typed +0 -0
- codeatrium/resolver.py +301 -0
- codeatrium/search.py +273 -0
- codeatrium-0.1.0.dist-info/METADATA +180 -0
- codeatrium-0.1.0.dist-info/RECORD +29 -0
- codeatrium-0.1.0.dist-info/WHEEL +4 -0
- codeatrium-0.1.0.dist-info/entry_points.txt +2 -0
- codeatrium-0.1.0.dist-info/licenses/LICENSE +21 -0
codeatrium/distiller.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""蒸留モジュール: claude -p で exchange を palace object に変換する
|
|
2
|
+
|
|
3
|
+
SPEC Section 6 DISTILLER フロー準拠:
|
|
4
|
+
① files_touched を regex で抽出(LLM非使用)
|
|
5
|
+
② claude -p で palace object 生成(--output-format json --json-schema)
|
|
6
|
+
③ distill_text を embedding して vec_palace に登録
|
|
7
|
+
④ files_touched を tree-sitter で解析してシンボルを symbols テーブルに登録
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import datetime
|
|
13
|
+
import hashlib
|
|
14
|
+
import os
|
|
15
|
+
import re
|
|
16
|
+
import struct
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
# Default TOKENIZERS_PARALLELISM to "false" before codeatrium.embedder is imported
# below (presumably to keep the tokenizers library used by sentence-transformers
# from using/warning about parallelism in forked processes — TODO confirm).
# setdefault leaves any value the user already exported untouched.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
|
22
|
+
|
|
23
|
+
from codeatrium.embedder import Embedder
|
|
24
|
+
from codeatrium.llm import DISTILL_PROMPT_TEMPLATE, call_claude
|
|
25
|
+
from codeatrium.models import PalaceObject
|
|
26
|
+
|
|
27
|
+
# ---- ファイルパス抽出 ----
|
|
28
|
+
|
|
29
|
+
# Two alternatives, each in its own capture group, so re.findall returns
# (absolute, relative) 2-tuples in which exactly one element is non-empty.
# Both forms must end in a dotted extension; a bare file name without any
# directory component (e.g. "foo.py") is not matched.
_FILES_PATTERN = re.compile(
    r"(/(?:[a-zA-Z0-9._\-]+/)*[a-zA-Z0-9._\-]+\.[a-zA-Z0-9]+)"  # absolute path
    r"|([a-zA-Z0-9._\-]+(?:/[a-zA-Z0-9._\-]+)+\.[a-zA-Z0-9]+)"  # relative path (at least one directory level)
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ---- 内部ヘルパー ----
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _sha256(text: str) -> str:
    """Return the hex SHA-256 digest of ``text`` encoded as UTF-8.

    Used to derive deterministic row ids and dedup hashes.
    """
    digest = hashlib.sha256()
    digest.update(text.encode())
    return digest.hexdigest()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ---- 公開 API ----
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
_EXTERNAL_PATH_MARKERS = (
|
|
46
|
+
"site-packages/",
|
|
47
|
+
"dist-packages/",
|
|
48
|
+
"/lib/python",
|
|
49
|
+
"/opt/",
|
|
50
|
+
"/usr/lib/",
|
|
51
|
+
"/usr/local/lib/",
|
|
52
|
+
".venv/",
|
|
53
|
+
"/venv/",
|
|
54
|
+
"node_modules/",
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _is_external_path(path: str, project_root: str | None = None) -> bool:
|
|
59
|
+
"""プロジェクト外のパスか判定する。
|
|
60
|
+
|
|
61
|
+
絶対パス: project_root が指定されていればその配下かチェック。
|
|
62
|
+
相対パス: ハードコードマーカーでフィルタ。
|
|
63
|
+
"""
|
|
64
|
+
if path.startswith("/"):
|
|
65
|
+
# 絶対パス: project_root 配下でなければ外部
|
|
66
|
+
if project_root:
|
|
67
|
+
return not path.startswith(project_root)
|
|
68
|
+
# project_root 不明時はマーカーでフォールバック
|
|
69
|
+
return any(marker in path for marker in _EXTERNAL_PATH_MARKERS)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def extract_files_touched(
    user_content: str, agent_content: str, project_root: str | None = None
) -> list[str]:
    """Collect file paths mentioned in an exchange, without calling an LLM.

    The user and agent messages are joined with a newline and scanned with
    ``_FILES_PATTERN``. Each distinct path is kept once, in order of first
    appearance.

    Args:
        user_content: The user's message text.
        agent_content: The agent's reply text.
        project_root: Optional project directory. When given, absolute paths
            are kept only if they lie under it. Relative paths are filtered by
            the hard-coded external-path markers (node_modules etc.).

    Returns:
        The de-duplicated list of paths judged to belong to the project.
    """
    prefix = None
    if project_root:
        # Trailing slash so the prefix test respects directory boundaries.
        prefix = project_root.rstrip("/") + "/"

    # A dict keeps insertion order and doubles as the "already seen" set.
    ordered: dict[str, None] = {}
    haystack = "\n".join((user_content, agent_content))
    for absolute, relative in _FILES_PATTERN.findall(haystack):
        candidate = absolute or relative
        if not candidate or candidate in ordered:
            continue
        if _is_external_path(candidate, prefix):
            continue
        ordered[candidate] = None
    return list(ordered)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def distill_exchange(
    exchange_id: str,
    user_content: str,
    agent_content: str,
    ply_start: int,
    ply_end: int,
    model: str | None = None,
    project_root: str | None = None,
) -> PalaceObject:
    """Turn one exchange into a ``PalaceObject`` using ``call_claude``.

    The prompt is ``DISTILL_PROMPT_TEMPLATE`` filled with ``ply_start``,
    ``ply_end`` and the user + agent text. That text is joined with a newline
    and only its first 4000 characters are used. ``exchange_core``,
    ``specific_context`` and ``room_assignments`` are read from the LLM
    response. ``files_touched`` is extracted by regex from the full,
    untruncated messages instead.

    Args:
        exchange_id: Id of the exchange (not used by this function).
        user_content: User message text.
        agent_content: Agent reply text.
        ply_start: Start of the ply range, substituted into the prompt.
        ply_end: End of the ply range, substituted into the prompt.
        model: Optional model name forwarded to ``call_claude``.
        project_root: Forwarded to ``extract_files_touched``.

    Returns:
        The assembled ``PalaceObject``.
    """
    joined = user_content + "\n" + agent_content
    response = call_claude(
        DISTILL_PROMPT_TEMPLATE.format(
            ply_start=ply_start,
            ply_end=ply_end,
            messages_text=joined[:4000],
        ),
        model=model,
    )
    touched = extract_files_touched(
        user_content, agent_content, project_root=project_root
    )
    return PalaceObject(
        exchange_core=response["exchange_core"],
        specific_context=response["specific_context"],
        room_assignments=response["room_assignments"],
        files_touched=touched,
    )
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def save_palace_object(
    db_path: Path,
    exchange_id: str,
    palace: PalaceObject,
    embedding: Any,  # np.ndarray
) -> None:
    """Persist a PalaceObject and mark its exchange as distilled.

    All of the following is written on one connection from ``get_connection``
    and committed together at the end:

    * the ``palace_objects`` row (id = sha256("palace:<exchange_id>"));
    * one ``rooms`` row per entry in ``palace.room_assignments``;
    * the embedding in ``vec_palace`` (skipped if a row already exists);
    * one ``symbols`` row per symbol that ``SymbolResolver`` (tree-sitter)
      extracts from each path in ``palace.files_touched``;
    * ``exchanges.distilled_at`` set to the current UTC time as naive
      ISO-8601 text.

    Rows use ``INSERT OR IGNORE`` or an existence check, so re-saving the
    same exchange does not duplicate them. Room and symbol rows use a
    palace-scoped ``id`` and store the cross-palace identity in
    ``dedup_hash``, so the same room or symbol can be linked to every palace
    object that touches it.

    The connection is always closed. If anything raises before the commit,
    the pending writes are not committed (assuming the connection is not in
    autocommit mode — depends on ``get_connection``).

    Args:
        db_path: Path to the SQLite database.
        exchange_id: Id of the exchange that was distilled.
        palace: Distillation result to store.
        embedding: Vector for ``exchange_core + "\\n" + specific_context``.
            It is flattened to float32 before storage (presumably the 1-D
            array returned by ``Embedder.embed_passage``).
    """
    import numpy as np

    from codeatrium.db import get_connection
    from codeatrium.resolver import SymbolResolver

    palace_id = _sha256(f"palace:{exchange_id}")
    distill_text = palace.exchange_core + "\n" + palace.specific_context
    # Native-endian float32 bytes. For a 1-D vector this is byte-for-byte what
    # struct.pack(f"{n}f", *vec) produced; flattening also tolerates a
    # (1, dim) row.
    blob = np.asarray(embedding, dtype=np.float32).reshape(-1).tobytes()

    con = get_connection(db_path)
    try:
        con.execute(
            """
            INSERT OR IGNORE INTO palace_objects
                (id, exchange_id, exchange_core, specific_context, distill_text)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                palace_id,
                exchange_id,
                palace.exchange_core,
                palace.specific_context,
                distill_text,
            ),
        )

        for room in palace.room_assignments:
            # dedup: identity of the room across palaces; id: scoped to this palace.
            dedup = _sha256(f"{room['room_type']}:{room['room_key']}")
            room_id = _sha256(f"{palace_id}:{dedup}")
            con.execute(
                """
                INSERT OR IGNORE INTO rooms
                    (id, palace_object_id, room_type, room_key, room_label, relevance, dedup_hash)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    room_id,
                    palace_id,
                    room["room_type"],
                    room["room_key"],
                    room["room_label"],
                    room["relevance"],
                    dedup,
                ),
            )

        exists = con.execute(
            "SELECT 1 FROM vec_palace WHERE palace_id = ?", (palace_id,)
        ).fetchone()
        if not exists:
            con.execute(
                "INSERT INTO vec_palace (palace_id, embedding) VALUES (?, ?)",
                (palace_id, blob),
            )

        # Resolve symbols in the touched files with tree-sitter.
        resolver = SymbolResolver()
        for file_str in palace.files_touched:
            for sym in resolver.extract(Path(file_str)):
                # Same identity hash as before, kept in dedup_hash. The row id
                # is palace-scoped, mirroring rooms above. A global id would
                # make INSERT OR IGNORE link each symbol only to the first
                # palace object that ever touched it.
                sym_dedup = _sha256(f"{sym.symbol_name}:{sym.file_path}")
                sym_id = _sha256(f"{palace_id}:{sym_dedup}")
                con.execute(
                    """
                    INSERT OR IGNORE INTO symbols
                        (id, palace_object_id, symbol_name, symbol_kind,
                         file_path, signature, line, dedup_hash)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        sym_id,
                        palace_id,
                        sym.symbol_name,
                        sym.symbol_kind,
                        sym.file_path,
                        sym.signature,
                        sym.line,
                        sym_dedup,
                    ),
                )

        # Produces the same naive-UTC text as the deprecated
        # datetime.utcnow().isoformat().
        distilled_at = (
            datetime.datetime.now(datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat()
        )
        con.execute(
            "UPDATE exchanges SET distilled_at = ? WHERE id = ?",
            (distilled_at, exchange_id),
        )

        con.commit()
    finally:
        con.close()
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def distill_all(
    db_path: Path,
    limit: int | None = None,
    model: str | None = None,
    on_progress: Callable[..., None] | None = None,
    project_root: str | None = None,
    distill_min_chars: int = 100,
) -> int:
    """Distill every exchange whose ``distilled_at`` is still NULL.

    Before selecting work, exchanges not worth distilling are marked
    ``distilled_at = 'skipped'``. These are:

    * exchanges in a conversation that currently has fewer than two
      exchanges; and
    * exchanges whose combined user + agent text is shorter than
      ``distill_min_chars`` characters (NULL content counts as empty).

    Each remaining exchange, up to ``limit``, is distilled with
    ``distill_exchange``, embedded with ``Embedder.embed_passage`` and stored
    with ``save_palace_object``. A failure on one exchange is reported via
    ``on_progress`` and does not stop the loop. That exchange keeps
    ``distilled_at = NULL`` and will be selected again on the next run.

    Args:
        db_path: Path to the SQLite database.
        limit: Maximum number of exchanges to attempt; ``None`` means no limit.
        model: Model name forwarded to ``call_claude``.
        on_progress: Callback invoked after every attempt. It is called as
            ``on_progress(succeeded, total)`` on success, or as
            ``on_progress(succeeded, total, error=message)`` on failure.
        project_root: Forwarded to ``extract_files_touched``.
        distill_min_chars: Minimum combined length; shorter exchanges are
            skipped (default 100).

    Returns:
        Number of exchanges distilled and saved successfully.
    """
    from codeatrium.db import get_connection

    con = get_connection(db_path)
    try:
        # NOTE(review): the "< 2 exchanges" rule uses the current DB contents.
        # If this runs while a conversation is still in progress, its first
        # exchange is marked 'skipped' permanently even if more exchanges
        # follow later — confirm this is intended.
        con.execute(
            """
            UPDATE exchanges SET distilled_at = 'skipped'
            WHERE distilled_at IS NULL
              AND ((SELECT COUNT(*) FROM exchanges e2
                    WHERE e2.conversation_id = exchanges.conversation_id) < 2
                   OR LENGTH(COALESCE(user_content, ''))
                      + LENGTH(COALESCE(agent_content, '')) < ?)
            """,
            (distill_min_chars,),
        )
        con.commit()

        query = """
            SELECT e.id, e.user_content, e.agent_content, e.ply_start, e.ply_end
            FROM exchanges e
            WHERE e.distilled_at IS NULL
        """
        params: list[int] = []
        if limit is not None:
            query += " LIMIT ?"
            params.append(int(limit))
        rows = con.execute(query, params).fetchall()
    finally:
        # Close before the slow LLM loop; save_palace_object opens its own
        # connection for each exchange.
        con.close()

    if not rows:
        return 0

    total = len(rows)
    embedder = Embedder()
    count = 0
    for row in rows:
        try:
            palace = distill_exchange(
                row["id"],
                row["user_content"],
                row["agent_content"],
                row["ply_start"],
                row["ply_end"],
                model=model,
                project_root=project_root,
            )
            distill_text = palace.exchange_core + "\n" + palace.specific_context
            vec = embedder.embed_passage(distill_text)
            save_palace_object(db_path, row["id"], palace, vec)
        except Exception as e:
            # Per-exchange boundary: report the failure and move on.
            if on_progress is not None:
                on_progress(count, total, error=str(e))
            continue
        count += 1
        if on_progress is not None:
            on_progress(count, total)

    return count
|
codeatrium/embedder.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""
|
|
2
|
+
sentence-transformers / multilingual-e5-small の embedding ラッパー
|
|
3
|
+
|
|
4
|
+
モデル: intfloat/multilingual-e5-small(384次元・日本語+英語混在対応・CPU動作)
|
|
5
|
+
|
|
6
|
+
ソケットサーバー方式:
|
|
7
|
+
- .codeatrium/embedder.sock が存在すれば Unix ソケット経由で embed(< 1秒)
|
|
8
|
+
- ソケットなければ直接モデルをロード(~7秒)+バックグラウンドでサーバー起動
|
|
9
|
+
- 2回目以降は常にソケット経由になるため高速
|
|
10
|
+
|
|
11
|
+
cold start が問題になった場合は SentenceTransformer(..., backend="onnx") でも対処可能
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import socket
|
|
18
|
+
import subprocess
|
|
19
|
+
import sys
|
|
20
|
+
import time
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _loci_python() -> str:
    """Return the interpreter used to launch the embedder server.

    Prefers a ``python3`` executable in the same directory as
    ``sys.executable`` (the running CLI's bin directory). If none exists
    there, ``sys.executable`` itself is returned.
    """
    candidate = Path(sys.executable).parent.joinpath("python3")
    return str(candidate) if candidate.exists() else sys.executable
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Model name passed to SentenceTransformer when the model is loaded in-process.
MODEL_NAME = "intfloat/multilingual-e5-small"
CONNECT_TIMEOUT = 2.0  # socket connect timeout, in seconds
RECV_TIMEOUT = 30.0  # timeout while waiting for the embedding reply, in seconds
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _sock_path_from_env() -> Path | None:
|
|
41
|
+
"""環境変数 CODEATRIUM_SOCK_PATH があれば優先使用(テスト用)。
|
|
42
|
+
CODEATRIUM_NO_SOCK=1 の場合はソケット無効(サーバー内自己接続デッドロック防止)。
|
|
43
|
+
"""
|
|
44
|
+
import os
|
|
45
|
+
|
|
46
|
+
if os.environ.get("CODEATRIUM_NO_SOCK"):
|
|
47
|
+
return None
|
|
48
|
+
p = os.environ.get("CODEATRIUM_SOCK_PATH")
|
|
49
|
+
return Path(p) if p else None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _find_sock_path() -> Path | None:
|
|
53
|
+
"""DB の親ディレクトリの embedder.sock を探す"""
|
|
54
|
+
import os
|
|
55
|
+
|
|
56
|
+
if os.environ.get("CODEATRIUM_NO_SOCK"):
|
|
57
|
+
return None
|
|
58
|
+
# .codeatrium/ の場所を git root から解決
|
|
59
|
+
try:
|
|
60
|
+
import subprocess as sp
|
|
61
|
+
|
|
62
|
+
result = sp.run(
|
|
63
|
+
["git", "rev-parse", "--show-toplevel"],
|
|
64
|
+
capture_output=True,
|
|
65
|
+
text=True,
|
|
66
|
+
)
|
|
67
|
+
if result.returncode == 0:
|
|
68
|
+
return Path(result.stdout.strip()) / ".codeatrium" / "embedder.sock"
|
|
69
|
+
except Exception:
|
|
70
|
+
pass
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _try_socket_embed(sock_path: Path, req_type: str, text: str) -> np.ndarray | None:
|
|
75
|
+
"""ソケットサーバーに接続して embedding を取得する。失敗時は None を返す。"""
|
|
76
|
+
if not sock_path.exists():
|
|
77
|
+
return None
|
|
78
|
+
try:
|
|
79
|
+
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
|
|
80
|
+
s.settimeout(CONNECT_TIMEOUT)
|
|
81
|
+
s.connect(str(sock_path))
|
|
82
|
+
s.settimeout(RECV_TIMEOUT)
|
|
83
|
+
req = json.dumps({"type": req_type, "text": text}) + "\n"
|
|
84
|
+
s.sendall(req.encode())
|
|
85
|
+
buf = b""
|
|
86
|
+
while b"\n" not in buf:
|
|
87
|
+
chunk = s.recv(65536)
|
|
88
|
+
if not chunk:
|
|
89
|
+
break
|
|
90
|
+
buf += chunk
|
|
91
|
+
resp = json.loads(buf.split(b"\n")[0])
|
|
92
|
+
if "embedding" in resp:
|
|
93
|
+
return np.array(resp["embedding"], dtype=np.float32)
|
|
94
|
+
except Exception:
|
|
95
|
+
pass
|
|
96
|
+
return None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _start_server_background(sock_path: Path) -> None:
    """Spawn ``<python> -m codeatrium.embedder_server <sock_path>`` detached.

    The child runs in its own session with stdout and stderr discarded. This
    function then polls for the socket file every 0.2 s, at most 20 times
    (about 4 s), and returns as soon as the file appears. Every error is
    swallowed, because starting the server only speeds up later calls.
    """
    try:
        command = [_loci_python(), "-m", "codeatrium.embedder_server", str(sock_path)]
        subprocess.Popen(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
        # Give the server a short head start.
        attempts = 0
        while attempts < 20 and not sock_path.exists():
            time.sleep(0.2)
            attempts += 1
    except Exception:
        pass
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class Embedder:
|
|
118
|
+
"""multilingual-e5-small の薄いラッパー。
|
|
119
|
+
|
|
120
|
+
ソケットサーバーが起動済みなら高速パス(Unix ソケット)を使い、
|
|
121
|
+
なければ直接モデルをロードしてバックグラウンドでサーバーを起動する。
|
|
122
|
+
"""
|
|
123
|
+
|
|
124
|
+
def __init__(self, sock_path: Path | None = None) -> None:
|
|
125
|
+
self._sock_path: Path | None = (
|
|
126
|
+
_sock_path_from_env() or sock_path or _find_sock_path()
|
|
127
|
+
)
|
|
128
|
+
self._model = None # 遅延ロード
|
|
129
|
+
|
|
130
|
+
def _ensure_model(self) -> None:
|
|
131
|
+
"""直接モデルをロードする(ソケット不使用時のフォールバック)"""
|
|
132
|
+
if self._model is None:
|
|
133
|
+
from sentence_transformers import SentenceTransformer
|
|
134
|
+
|
|
135
|
+
self._model = SentenceTransformer(MODEL_NAME)
|
|
136
|
+
|
|
137
|
+
def _embed_via_socket_or_direct(
|
|
138
|
+
self, text: str, req_type: str, prefix: str
|
|
139
|
+
) -> np.ndarray:
|
|
140
|
+
"""ソケット優先・なければ直接ロード+サーバー起動"""
|
|
141
|
+
# ① ソケット経由を試みる
|
|
142
|
+
if self._sock_path is not None:
|
|
143
|
+
vec = _try_socket_embed(self._sock_path, req_type, text)
|
|
144
|
+
if vec is not None:
|
|
145
|
+
return vec
|
|
146
|
+
|
|
147
|
+
# ② 直接ロード
|
|
148
|
+
self._ensure_model()
|
|
149
|
+
assert self._model is not None
|
|
150
|
+
result = self._model.encode(
|
|
151
|
+
[f"{prefix}{text}"],
|
|
152
|
+
normalize_embeddings=True,
|
|
153
|
+
)
|
|
154
|
+
vec = result[0].astype(np.float32)
|
|
155
|
+
|
|
156
|
+
# ③ バックグラウンドでサーバーを起動(次回から高速化)
|
|
157
|
+
if self._sock_path is not None and not self._sock_path.exists():
|
|
158
|
+
_start_server_background(self._sock_path)
|
|
159
|
+
|
|
160
|
+
return vec
|
|
161
|
+
|
|
162
|
+
def embed(self, text: str) -> np.ndarray:
|
|
163
|
+
"""クエリ用 embedding(query: プレフィックス)"""
|
|
164
|
+
return self._embed_via_socket_or_direct(text, "query", "query: ")
|
|
165
|
+
|
|
166
|
+
def embed_passage(self, text: str) -> np.ndarray:
|
|
167
|
+
"""インデックス登録用 embedding(passage: プレフィックス)"""
|
|
168
|
+
return self._embed_via_socket_or_direct(text, "passage", "passage: ")
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
embedder_server.py — multilingual-e5-small を常駐させる Unix ソケットサーバー
|
|
3
|
+
|
|
4
|
+
プロトコル:
|
|
5
|
+
request (改行区切り JSON): {"type": "query"|"passage", "text": "..."}
|
|
6
|
+
response (改行区切り JSON): {"embedding": [0.1, 0.2, ...]}
|
|
7
|
+
特殊コマンド: {"type": "ping"} → {"status": "ok"}
|
|
8
|
+
{"type": "stop"} → サーバー終了
|
|
9
|
+
|
|
10
|
+
ライフサイクル:
|
|
11
|
+
- loci server start でバックグラウンド起動
|
|
12
|
+
- IDLE_TIMEOUT 秒間リクエストなし → 自動終了
|
|
13
|
+
- loci server stop / SIGTERM で即終了
|
|
14
|
+
- ソケットファイルは .codeatrium/embedder.sock
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import os
|
|
21
|
+
import signal
|
|
22
|
+
import socket
|
|
23
|
+
import sys
|
|
24
|
+
import threading
|
|
25
|
+
import time
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import TYPE_CHECKING
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from codeatrium.embedder import Embedder
|
|
31
|
+
|
|
32
|
+
IDLE_TIMEOUT = 600  # seconds (10 min) without a handled request before run_server exits
BACKLOG = 8  # listen() backlog of pending client connections
RECV_BUFSIZE = 65536  # maximum bytes read per recv() call in _handle_client
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _load_embedder() -> "Embedder":
    """Build the Embedder used by the server, forcing in-process model use.

    ``CODEATRIUM_NO_SOCK=1`` is set while ``Embedder()`` is constructed, so
    the instance resolves no socket path. An Embedder inside the server that
    talked to the server's own socket would deadlock.

    The variable's previous value, or its absence, is restored afterwards,
    even when the import or construction fails. The old code instead deleted
    a user-provided value and leaked the override on error.
    """
    previous = os.environ.get("CODEATRIUM_NO_SOCK")
    os.environ["CODEATRIUM_NO_SOCK"] = "1"
    try:
        from codeatrium.embedder import Embedder

        return Embedder()
    finally:
        if previous is None:
            os.environ.pop("CODEATRIUM_NO_SOCK", None)
        else:
            os.environ["CODEATRIUM_NO_SOCK"] = previous
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _handle_client(
    conn: socket.socket,
    embedder: "Embedder",
    last_activity: list[float],
    stop_event: threading.Event,
) -> None:
    """Serve newline-delimited JSON requests on one client connection.

    Reads until the client closes the connection. Every non-empty line is
    one request and gets exactly one JSON reply line:

    * ``{"type": "ping"}`` -> ``{"status": "ok"}``
    * ``{"type": "stop"}`` -> ``{"status": "stopping"}``. This also sets
      ``stop_event`` and stops serving this connection.
    * ``{"type": "query"|"passage", "text": ...}`` -> ``{"embedding": [...]}``
    * anything else -> ``{"error": ...}``. This covers invalid JSON or
      encoding, a non-object request, missing or empty text, an unknown
      type, and an exception raised by the embedder.

    ``last_activity[0]`` is set to ``time.monotonic()`` after every ping and
    every successful embedding; ``run_server`` uses it for its idle timeout.
    Socket errors end the session quietly, and the connection is always
    closed.
    """

    def reply(payload: dict) -> None:
        conn.sendall(json.dumps(payload).encode() + b"\n")

    try:
        buf = b""
        while True:
            chunk = conn.recv(RECV_BUFSIZE)
            if not chunk:
                break
            buf += chunk
            while b"\n" in buf:
                line, buf = buf.split(b"\n", 1)
                line = line.strip()
                if not line:
                    continue
                try:
                    req = json.loads(line)
                except ValueError:
                    # JSONDecodeError, and UnicodeDecodeError for non-UTF-8 bytes.
                    reply({"error": "invalid json"})
                    continue
                if not isinstance(req, dict):
                    # e.g. a bare list or number; req.get would raise and kill
                    # this handler without any reply.
                    reply({"error": "invalid request"})
                    continue

                req_type = req.get("type", "")

                if req_type == "ping":
                    reply({"status": "ok"})
                    last_activity[0] = time.monotonic()
                    continue

                if req_type == "stop":
                    reply({"status": "stopping"})
                    stop_event.set()
                    return

                text = req.get("text", "")
                if not text:
                    reply({"error": "missing text"})
                    continue

                if req_type == "query":
                    embed = embedder.embed
                elif req_type == "passage":
                    embed = embedder.embed_passage
                else:
                    reply({"error": f"unknown type: {req_type}"})
                    continue

                try:
                    vec = embed(text)
                except Exception as exc:
                    # Keep the protocol's one-reply-per-request contract, so
                    # the client falls back promptly instead of waiting for EOF.
                    reply({"error": f"embedding failed: {exc}"})
                    continue

                reply({"embedding": vec.tolist()})
                last_activity[0] = time.monotonic()
    except OSError:
        # Includes BrokenPipeError / ConnectionResetError from a vanished client.
        pass
    finally:
        try:
            conn.close()
        except OSError:
            pass
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def run_server(sock_path: Path) -> None:
    """Serve embedding requests on a Unix socket at ``sock_path`` (blocking).

    Startup steps:

    1. Remove any existing file at ``sock_path``.
    2. Load the model in-process with ``_load_embedder``.
    3. Bind under umask 0o177, so the socket is created with mode 0o600.
    4. Accept connections, handling each client on its own daemon thread
       with ``_handle_client``.

    The server returns when:

    * a client sends ``stop``;
    * SIGTERM or SIGINT is received;
    * ``IDLE_TIMEOUT`` seconds pass without a handled request; or
    * ``accept`` fails with an ``OSError``.

    On exit the listening socket is closed. The socket file is removed only
    if it is still the one this server bound. Another server may have
    replaced it in the meantime, because every server unlinks the old file
    on startup; deleting that newer server's socket would leave it running
    but unreachable. If setup fails, the listening socket is still closed.
    """
    # Remove any existing socket file.
    sock_path.unlink(missing_ok=True)

    embedder = _load_embedder()

    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    bound_inode: int | None = None
    try:
        old_umask = os.umask(0o177)  # enforce 0o600 on the socket, no TOCTOU
        try:
            server.bind(str(sock_path))
        finally:
            os.umask(old_umask)
        # Remember which file we created, so cleanup only removes our own.
        bound_inode = sock_path.stat().st_ino
        server.listen(BACKLOG)
        server.settimeout(1.0)  # accept() timeout so idle/stop checks run every second

        last_activity: list[float] = [time.monotonic()]
        stop_event = threading.Event()

        def _sigterm_handler(signum: int, frame: object) -> None:
            stop_event.set()

        signal.signal(signal.SIGTERM, _sigterm_handler)
        signal.signal(signal.SIGINT, _sigterm_handler)

        while not stop_event.is_set():
            if time.monotonic() - last_activity[0] > IDLE_TIMEOUT:
                break

            try:
                conn, _ = server.accept()
            except TimeoutError:
                continue
            except OSError:
                break

            # One daemon thread per client connection.
            t = threading.Thread(
                target=_handle_client,
                args=(conn, embedder, last_activity, stop_event),
                daemon=True,
            )
            t.start()
    finally:
        if bound_inode is not None:
            # Check ownership while our listening socket still holds the inode,
            # so it cannot have been reused by a newer file.
            try:
                if sock_path.stat().st_ino == bound_inode:
                    sock_path.unlink()
            except OSError:
                pass
        server.close()
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def main() -> None:
    """Command-line entry point: ``python -m codeatrium.embedder_server <sock_path>``.

    If no socket path is given, prints a usage line to stderr and exits with
    status 1. Otherwise runs the server on that path until it stops.
    """
    args = sys.argv[1:]
    if not args:
        print("Usage: python -m codeatrium.embedder_server <sock_path>", file=sys.stderr)
        sys.exit(1)
    run_server(Path(args[0]))
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
if __name__ == "__main__":
    # Lets the module be run directly: python -m codeatrium.embedder_server <sock_path>
    main()
|