tau-sim 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tau_sim/__init__.py +0 -0
- tau_sim/api/__init__.py +0 -0
- tau_sim/api/routes.py +148 -0
- tau_sim/cli.py +62 -0
- tau_sim/config.py +23 -0
- tau_sim/llm/__init__.py +0 -0
- tau_sim/llm/agent.py +195 -0
- tau_sim/llm/manual.py +91 -0
- tau_sim/llm/providers.py +115 -0
- tau_sim/main.py +66 -0
- tau_sim/middleware/__init__.py +0 -0
- tau_sim/middleware/session.py +62 -0
- tau_sim/schemas.py +62 -0
- tau_sim/sim/__init__.py +0 -0
- tau_sim/sim/menagerie.py +169 -0
- tau_sim/sim/runner.py +27 -0
- tau_sim/sim/runner_core.py +155 -0
- tau_sim/sim/sandbox.py +111 -0
- tau_sim/storage/__init__.py +0 -0
- tau_sim/storage/projects.py +188 -0
- tau_sim/ws/__init__.py +0 -0
- tau_sim/ws/stream.py +46 -0
- tau_sim-0.1.0.dist-info/METADATA +30 -0
- tau_sim-0.1.0.dist-info/RECORD +27 -0
- tau_sim-0.1.0.dist-info/WHEEL +5 -0
- tau_sim-0.1.0.dist-info/entry_points.txt +2 -0
- tau_sim-0.1.0.dist-info/top_level.txt +1 -0
tau_sim/__init__.py
ADDED
|
File without changes
|
tau_sim/api/__init__.py
ADDED
|
File without changes
|
tau_sim/api/routes.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from fastapi import APIRouter, HTTPException, Request, Response
|
|
2
|
+
|
|
3
|
+
from tau_sim.llm import manual
|
|
4
|
+
from tau_sim.llm.agent import parse_manual_response, run_chat
|
|
5
|
+
from tau_sim.middleware.session import session_id_from_request
|
|
6
|
+
from tau_sim.schemas import (
|
|
7
|
+
ChatRequest,
|
|
8
|
+
ChatResponse,
|
|
9
|
+
FileWrite,
|
|
10
|
+
ProjectCreate,
|
|
11
|
+
ProjectInfo,
|
|
12
|
+
)
|
|
13
|
+
from tau_sim.sim import menagerie
|
|
14
|
+
from tau_sim.storage import projects
|
|
15
|
+
|
|
16
|
+
router = APIRouter(prefix="/api")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _sid(request: Request) -> str:
    """Return the session id for ``request``.

    Thin shorthand over ``session_id_from_request`` used by the route handlers.
    """
    session_id = session_id_from_request(request)
    return session_id
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# ---- projects ----------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
@router.get("/projects", response_model=list[ProjectInfo])
def list_projects(request: Request) -> list[ProjectInfo]:
    """List the projects stored for the caller's session."""
    session_id = _sid(request)
    return projects.list_projects(session_id)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@router.post("/projects", response_model=ProjectInfo)
def create_project(body: ProjectCreate, request: Request) -> ProjectInfo:
    """Create a project named ``body.name`` from ``body.template`` in the caller's session."""
    session_id = _sid(request)
    created = projects.create_project(session_id, body.name, body.template)
    return created
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@router.get("/projects/{project_id}", response_model=ProjectInfo)
def get_project(project_id: str, request: Request) -> ProjectInfo:
    """Return one project's info, or HTTP 404 when storage raises ``FileNotFoundError``."""
    try:
        info = projects.get_project(_sid(request), project_id)
    except FileNotFoundError as missing:
        raise HTTPException(status_code=404, detail=str(missing))
    else:
        return info
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@router.delete("/projects/{project_id}")
def delete_project(project_id: str, request: Request) -> dict:
    """Delete a project for the caller's session and acknowledge with ``{"ok": True}``.

    Any exception raised by the storage layer propagates unchanged.
    """
    session_id = _sid(request)
    projects.delete_project(session_id, project_id)
    return dict(ok=True)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ---- menagerie import --------------------------------------------------
|
|
50
|
+
|
|
51
|
+
@router.get("/menagerie/models")
def menagerie_models() -> list[str]:
    """Return the names reported by ``menagerie.list_models``.

    Every failure of the listing call is reported to the client as HTTP 502.
    """
    try:
        names = menagerie.list_models()
    except Exception as e:  # noqa: BLE001
        raise HTTPException(status_code=502, detail=f"menagerie list failed: {e}")
    return names
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@router.post("/projects/{project_id}/import/menagerie")
def import_menagerie(project_id: str, body: dict, request: Request) -> dict:
    """Import a Menagerie model into an existing project.

    Args:
        project_id: Target project.
        body: JSON object; its ``model`` key must be a non-blank string.
        request: Incoming request, used to resolve the session.

    Returns:
        Whatever ``menagerie.import_model`` returns.

    Raises:
        HTTPException: 400 for a missing/blank/non-string ``model`` or an
            importer ``FileNotFoundError``/``ValueError``; 404 when the project
            root cannot be resolved; 429 on ``PermissionError``; 502 otherwise.
    """
    model = (body or {}).get("model")
    # A JSON null / number / list used to crash on ``.strip()`` and produce a 500.
    if not isinstance(model, str) or not model.strip():
        raise HTTPException(status_code=400, detail="`model` is required")
    model = model.strip()

    try:
        root = projects.project_root(_sid(request), project_id)
    except (FileNotFoundError, ValueError) as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    try:
        return menagerie.import_model(model, root)
    except (FileNotFoundError, ValueError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except PermissionError as e:
        # NOTE(review): presumably the importer signals rate limiting with
        # PermissionError, hence 429 — confirm against menagerie.import_model.
        raise HTTPException(status_code=429, detail=str(e)) from e
    except Exception as e:  # noqa: BLE001
        raise HTTPException(status_code=502, detail=f"import failed: {e}") from e
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# ---- files -------------------------------------------------------------
|
|
79
|
+
|
|
80
|
+
@router.get("/projects/{project_id}/file")
def read_file(project_id: str, path: str, request: Request) -> dict:
    """Return ``{"path", "content"}`` for one project file.

    ``FileNotFoundError`` from storage becomes HTTP 404 and ``ValueError``
    becomes HTTP 400.
    """
    try:
        content = projects.read_file(_sid(request), project_id, path)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="file not found")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"path": path, "content": content}
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@router.put("/projects/{project_id}/file")
def write_file(project_id: str, body: FileWrite, request: Request) -> dict:
    """Write ``body.content`` to ``body.path`` inside the project.

    Returns:
        ``{"ok": True}`` on success.

    Raises:
        HTTPException: 404 when storage raises ``FileNotFoundError`` (same
            mapping the sibling read/import routes use), 400 on ``ValueError``.
    """
    try:
        projects.write_file(_sid(request), project_id, body.path, body.content)
    except FileNotFoundError as e:
        # Previously unhandled, so this surfaced as a 500.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"ok": True}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# ---- chat --------------------------------------------------------------
|
|
103
|
+
|
|
104
|
+
@router.post("/chat", response_model=ChatResponse)
def chat(body: ChatRequest, request: Request) -> ChatResponse:
    """Run one chat turn for the caller's session.

    Raises:
        HTTPException: 404 when the project does not exist (``run_chat`` loads
            it through ``projects.get_project``, which raises
            ``FileNotFoundError``); 400 for ``RuntimeError`` (e.g. missing API
            key) and ``ValueError`` (e.g. unknown provider from
            ``get_provider``); 500 for any other failure.
    """
    try:
        return run_chat(_sid(request), body)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except (RuntimeError, ValueError) as e:
        # Both are client-correctable here; they used to be split between a
        # 400 (RuntimeError) and a generic 500 "LLM error" (ValueError).
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:  # noqa: BLE001
        raise HTTPException(status_code=500, detail=f"LLM error: {e!r}") from e
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# ---- manual / human-in-the-loop ----------------------------------------
|
|
115
|
+
|
|
116
|
+
@router.get("/chat/result/{pending_id}")
def chat_result(pending_id: str, response: Response) -> dict:
    """Poll for the response to a manual (human-in-the-loop) chat request.

    Answers HTTP 202 with ``{"pending": True}`` while no response exists,
    HTTP 400 when reading it raises ``ValueError``, otherwise the dumped
    ``ChatResponse``.
    """
    try:
        payload = manual.read_response(pending_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if payload is not None:
        return parse_manual_response(payload).model_dump()
    response.status_code = 202
    return {"pending": True}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@router.get("/manual/pending")
def manual_pending() -> list[dict]:
    """Return every queued manual request that has no response yet."""
    waiting = manual.list_pending()
    return waiting
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@router.get("/manual/pending/{pending_id}")
def manual_pending_one(pending_id: str) -> dict:
    """Return the stored request for one pending manual chat.

    Raises:
        HTTPException: 400 when the id is rejected with ``ValueError`` (the
            same mapping ``chat_result`` uses), 404 when no request exists.
    """
    try:
        req = manual.get_request(pending_id)
    except ValueError as e:
        # ``manual._slot`` raises ValueError("invalid id"); this was a 500 before.
        raise HTTPException(status_code=400, detail=str(e)) from e
    if req is None:
        raise HTTPException(status_code=404, detail="not found")
    return req
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@router.post("/manual/pending/{pending_id}/respond")
def manual_respond(pending_id: str, payload: dict) -> dict:
    """Store a human-written response for a pending manual chat.

    Returns:
        ``{"ok": True}`` once the response is written.

    Raises:
        HTTPException: 404 when the pending slot does not exist, 400 when the
            id is rejected with ``ValueError``.
    """
    try:
        manual.write_response(pending_id, payload)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail="not found") from e
    except ValueError as e:
        # ``manual._slot`` raises ValueError("invalid id"); this was a 500 before.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"ok": True}
|
tau_sim/cli.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""CLI entrypoint: ``tau`` launches the backend + opens a browser.
|
|
2
|
+
|
|
3
|
+
Bundles the SPA if `backend/static/` exists (production build); otherwise
|
|
4
|
+
expects Vite dev server on port 5173.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
import threading
|
|
12
|
+
import time
|
|
13
|
+
import webbrowser
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main() -> None:
    """Parse CLI flags, export settings through the environment, and run the server.

    Flag defaults come from the ``PORT``, ``MUJOCO_GL`` and ``PROJECTS_DIR``
    environment variables. The resolved values are written back to
    ``os.environ`` before uvicorn loads ``tau_sim.main:app``. Unless
    ``--no-browser`` is given, a daemon thread opens the UI in the default
    browser after a short delay.
    """
    parser = argparse.ArgumentParser(prog="tau")
    parser.add_argument("--host", default="127.0.0.1")
    # Keep the env value as a *string* default: argparse runs ``type=int`` on
    # string defaults only when the flag is absent, so a malformed PORT yields
    # a normal usage error (not a traceback) and an explicit --port still wins.
    parser.add_argument("--port", default=os.environ.get("PORT", "7860"), type=int)
    parser.add_argument("--no-browser", action="store_true",
                        help="don't auto-open the browser")
    parser.add_argument("--mujoco-gl",
                        default=os.environ.get("MUJOCO_GL", "egl"),
                        choices=["egl", "osmesa", "glfw"],
                        help="MuJoCo rendering backend")
    parser.add_argument("--projects-dir",
                        default=os.environ.get("PROJECTS_DIR"),
                        help="where to store project data "
                             "(default: ~/.tau/projects)")
    args = parser.parse_args()

    from pathlib import Path

    if args.projects_dir:
        projects_dir = Path(args.projects_dir).expanduser()
    else:
        projects_dir = Path.home() / ".tau" / "projects"
    projects_dir.mkdir(parents=True, exist_ok=True)
    os.environ["PROJECTS_DIR"] = str(projects_dir)
    os.environ["MUJOCO_GL"] = args.mujoco_gl
    # Local CLI = single user, no sandbox by default.
    os.environ.setdefault("TAU_SINGLE_USER", "1")

    print(f"\n tau-sim - http://{args.host}:{args.port}", file=sys.stderr)
    print(f" data: {projects_dir}", file=sys.stderr)
    print(f" render: MUJOCO_GL={args.mujoco_gl}\n", file=sys.stderr)

    if not args.no_browser:
        # A wildcard bind address is not something a browser can reliably
        # connect to, so point it at the loopback name instead.
        browser_host = "localhost" if args.host in ("0.0.0.0", "::") else args.host
        url = f"http://{browser_host}:{args.port}/"

        def _open_browser() -> None:
            # Brief delay before opening, presumably so the server is already listening.
            time.sleep(0.8)
            webbrowser.open(url)

        threading.Thread(target=_open_browser, daemon=True).start()

    # Import uvicorn lazily so --help is fast.
    import uvicorn
    uvicorn.run("tau_sim.main:app", host=args.host, port=args.port, log_level="warning")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()
|
tau_sim/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Settings(BaseSettings):
    """Application configuration loaded through pydantic-settings."""

    # Also read values from a ``.env`` file; keys that match no field are ignored.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    # Provider API keys. They default to empty strings.
    openai_api_key: str = ""
    gemini_api_key: str = ""
    openrouter_api_key: str = ""

    # Default provider name and per-provider default model ids.
    llm_default_provider: str = "openai"
    llm_default_model_openai: str = "gpt-4o-mini"
    llm_default_model_gemini: str = "gemini-1.5-flash"
    llm_default_model_openrouter: str = "deepseek/deepseek-v4-flash"
    openrouter_base_url: str = "https://openrouter.ai/api/v1"

    # Project storage directory (relative to the working directory by default)
    # and the MuJoCo rendering backend name.
    projects_dir: Path = Path("./projects_data")
    mujoco_gl: str = "egl"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Module-level settings instance, constructed at import time.
settings = Settings()
# Import side effect: make sure the projects directory exists.
settings.projects_dir.mkdir(parents=True, exist_ok=True)
|
tau_sim/llm/__init__.py
ADDED
|
File without changes
|
tau_sim/llm/agent.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""Chat agent that asks the LLM to return a JSON object with a reply
|
|
2
|
+
plus an optional list of full-file replacement proposals."""
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from tau_sim.llm import manual
|
|
9
|
+
from tau_sim.llm.providers import get_provider
|
|
10
|
+
from tau_sim.schemas import ChatMessage, ChatRequest, ChatResponse, DiffProposal
|
|
11
|
+
from tau_sim.storage import projects
|
|
12
|
+
|
|
13
|
+
_SYSTEM = """You are an expert coding assistant embedded in an IDE for
|
|
14
|
+
building MuJoCo robotic-simulation environments. You collaborate with
|
|
15
|
+
the user on ONE project at a time. The project contains MJCF XML scenes
|
|
16
|
+
and a Python entry module that exposes `make_env()`.
|
|
17
|
+
|
|
18
|
+
# Response format
|
|
19
|
+
|
|
20
|
+
Respond with a single JSON object of EXACTLY this shape and nothing else
|
|
21
|
+
(no prose, no markdown fences):
|
|
22
|
+
|
|
23
|
+
{
|
|
24
|
+
"reply": "<short natural-language explanation for the user>",
|
|
25
|
+
"proposals": [
|
|
26
|
+
{
|
|
27
|
+
"path": "<project-relative file path>",
|
|
28
|
+
"new_content": "<the COMPLETE new contents of that file>",
|
|
29
|
+
"rationale": "<one-sentence why>"
|
|
30
|
+
}
|
|
31
|
+
]
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
Rules for `proposals`:
|
|
35
|
+
- `new_content` is the FULL file (never a diff/patch).
|
|
36
|
+
- Include only files you actually want to change. May be empty.
|
|
37
|
+
- Paths are relative to the project root. Use forward slashes.
|
|
38
|
+
- Creating a new file inside the current project is allowed when the
|
|
39
|
+
user's request implies it (e.g. a helper module, a new asset).
|
|
40
|
+
- NEVER include text outside the JSON object.
|
|
41
|
+
|
|
42
|
+
# Project scope (very important)
|
|
43
|
+
|
|
44
|
+
Each project represents ONE simulated system (one robot, one task).
|
|
45
|
+
When the user asks for something that does not fit the current project:
|
|
46
|
+
|
|
47
|
+
- If the request is a *modification* of the existing system (change
|
|
48
|
+
geometry, tune a reward, add a sensor, fix a bug) -> emit proposals.
|
|
49
|
+
- If the request is a *different system* (e.g. project is a cartpole and
|
|
50
|
+
the user asks for a quadruped, a manipulator, a drone, etc.) -> do
|
|
51
|
+
NOT replace the existing files. Instead return an empty `proposals`
|
|
52
|
+
array and use `reply` to tell the user to click the “+ New” button in
|
|
53
|
+
the top bar to create a new project for that system, and offer to
|
|
54
|
+
generate the env once they switch to it. Mention the suggested
|
|
55
|
+
project name.
|
|
56
|
+
- Never silently replace the entry env, the scene XML, or any file with
|
|
57
|
+
code for an unrelated robot or task.
|
|
58
|
+
|
|
59
|
+
When in doubt, prefer to ASK in `reply` rather than guess.
|
|
60
|
+
|
|
61
|
+
# Importing real robot models
|
|
62
|
+
|
|
63
|
+
You CANNOT fetch URLs or external files yourself. If the user asks to
|
|
64
|
+
import a model from `mujoco_menagerie` (or any other repo of pre-built
|
|
65
|
+
MuJoCo assets), do NOT try to inline the XML/meshes. Instead, in
|
|
66
|
+
`reply`, tell them to:
|
|
67
|
+
|
|
68
|
+
1. Click **+ New** in the top bar.
|
|
69
|
+
2. In the modal, choose template **"Import from MuJoCo Menagerie"**.
|
|
70
|
+
3. Pick the model from the dropdown (e.g. `unitree_h1`, `franka_emika_panda`).
|
|
71
|
+
|
|
72
|
+
The importer downloads the XML + meshes and writes a stub `env.py`. You
|
|
73
|
+
can then help them edit the env's reward / policy in that new project.
|
|
74
|
+
|
|
75
|
+
# Env API contract (the runner calls these exact methods)
|
|
76
|
+
|
|
77
|
+
The entry module must define `make_env()` returning an object with:
|
|
78
|
+
|
|
79
|
+
class MyEnv:
|
|
80
|
+
xml_path = "scene.xml" # any project-relative .xml file
|
|
81
|
+
|
|
82
|
+
def reset(self, data): # may also accept (self, model, data)
|
|
83
|
+
data.qpos[:] = ...
|
|
84
|
+
data.qvel[:] = 0.0
|
|
85
|
+
return self._obs(data) # numpy array of observations
|
|
86
|
+
|
|
87
|
+
def step(self, model, data, action):
|
|
88
|
+
data.ctrl[:] = action # set actuator commands
|
|
89
|
+
obs = self._obs(data)
|
|
90
|
+
reward = float(...)
|
|
91
|
+
done = bool(...)
|
|
92
|
+
return obs, reward, done
|
|
93
|
+
|
|
94
|
+
def policy(self, obs): # called every control step
|
|
95
|
+
return np.array([...]) # shape must match model.nu
|
|
96
|
+
|
|
97
|
+
The runner advances physics (`mujoco.mj_step`) between `step()` calls;
|
|
98
|
+
do NOT call `mj_step` yourself. Render setup is also handled for you.
|
|
99
|
+
|
|
100
|
+
# Style
|
|
101
|
+
|
|
102
|
+
- Keep MJCF valid (closing tags, valid attribute values).
|
|
103
|
+
- Keep Python importable. `import numpy as np` if you use it.
|
|
104
|
+
- Prefer small, focused diffs. Don't reformat unrelated code.
|
|
105
|
+
- Don't add docstrings or comments unless the user asked for them.
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _build_context(session_id: str, project_id: str) -> str:
    """Render the project header and every text file into one prompt string.

    The result starts with a header (project name/id, detected entry module,
    a reminder of the runner contract, file count) followed by one
    ``--- <path> ---`` section per file.
    """
    info = projects.get_project(session_id, project_id)
    files = projects.read_all_text(session_id, project_id)
    entry = _entry_file(files)

    header = "\n".join(
        [
            f"### Project: {info.name!r} (id: {info.id})",
            f"Entry env module: {entry or '(none yet)'}",
            "The runner loads this entry module, calls `make_env()`, then",
            "drives `env.reset` / `env.step` / `env.policy` per the contract.",
            "All file edits must stay within the scope of this single project.",
            "",
            f"### Current project files ({len(files)} files):",
            "",
        ]
    )
    sections = [header]
    sections.extend(f"--- {path} ---\n{content}\n" for path, content in files.items())
    return "\n".join(sections)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _entry_file(files: dict[str, str]) -> str | None:
|
|
130
|
+
"""Mirror the runner's entry-file rule for the LLM's benefit."""
|
|
131
|
+
if "env.py" in files:
|
|
132
|
+
return "env.py"
|
|
133
|
+
candidates = sorted(p for p in files if p.endswith("_env.py"))
|
|
134
|
+
return candidates[0] if candidates else None
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _parse(raw: str) -> dict[str, Any]:
|
|
138
|
+
raw = raw.strip()
|
|
139
|
+
# Strip stray markdown fences if a model ignores response_format.
|
|
140
|
+
if raw.startswith("```"):
|
|
141
|
+
raw = raw.strip("`")
|
|
142
|
+
if raw.lower().startswith("json"):
|
|
143
|
+
raw = raw[4:]
|
|
144
|
+
raw = raw.strip()
|
|
145
|
+
try:
|
|
146
|
+
return json.loads(raw)
|
|
147
|
+
except json.JSONDecodeError:
|
|
148
|
+
return {"reply": raw, "proposals": []}
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def run_chat(session_id: str, req: ChatRequest) -> ChatResponse:
    """Answer one chat request, either via an LLM provider or the manual queue.

    The project context is prepended as a leading user message. With the
    ``manual`` provider the request is enqueued and a placeholder response
    carrying ``pending_id`` is returned; otherwise the provider's reply is
    parsed and proposals that fail ``DiffProposal`` validation are dropped.
    """
    project_context = _build_context(session_id, req.project_id)
    messages = [ChatMessage(role="user", content=project_context)]
    messages.extend(req.messages)

    provider_name = (req.provider or "").lower()
    if provider_name == "manual":
        serialized = [m.model_dump() for m in messages]
        pending = manual.enqueue(req.project_id, _SYSTEM, serialized)
        return ChatResponse(
            reply="(waiting for human-in-the-loop response…)",
            pending_id=pending,
        )

    llm = get_provider(req.provider, req.api_key)
    parsed = _parse(llm.complete(_SYSTEM, messages, req.model))

    accepted: list[DiffProposal] = []
    for candidate in parsed.get("proposals", []) or []:
        try:
            accepted.append(DiffProposal(**candidate))
        except Exception:
            # Best effort: skip malformed proposals, keep the rest.
            continue

    reply_text = str(parsed.get("reply", "")).strip()
    return ChatResponse(reply=reply_text or "(no reply)", proposals=accepted)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def parse_manual_response(payload: dict[str, Any]) -> ChatResponse:
    """Validate a manually-written response and return a ChatResponse.

    Args:
        payload: Decoded ``response.json``. It is expected to be an object
            with ``reply`` and optional ``proposals``. Because the file is
            hand-written, other JSON shapes are tolerated: a bare string is
            used as the reply, and anything else yields an empty reply.

    Returns:
        A ``ChatResponse`` holding the reply (``"(no reply)"`` when blank)
        and every proposal that passes ``DiffProposal`` validation.
    """
    if not isinstance(payload, dict):
        # Previously crashed on ``.get`` for a list/string/number payload.
        payload = {"reply": payload if isinstance(payload, str) else ""}

    proposals: list[DiffProposal] = []
    for p in payload.get("proposals", []) or []:
        try:
            proposals.append(DiffProposal(**p))
        except Exception:
            # Best effort: skip malformed proposals, keep the rest.
            continue
    return ChatResponse(
        reply=str(payload.get("reply", "")).strip() or "(no reply)",
        proposals=proposals,
    )
|
tau_sim/llm/manual.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Human-in-the-loop \"LLM\" queue.
|
|
2
|
+
|
|
3
|
+
When the user picks the `manual` provider, the chat request is written
|
|
4
|
+
to ``manual_queue/<id>/request.json`` and the API returns a `pending_id`.
|
|
5
|
+
A human (or Copilot) drops a JSON file at ``manual_queue/<id>/response.json``
|
|
6
|
+
and the frontend, which is polling, picks it up.
|
|
7
|
+
|
|
8
|
+
Response file shape (same as the regular LLM JSON contract):
|
|
9
|
+
|
|
10
|
+
{"reply": "...", "proposals": [{"path": "...", "new_content": "...",
|
|
11
|
+
"rationale": "..."}]}
|
|
12
|
+
"""
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import time
|
|
17
|
+
import uuid
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from tau_sim.config import settings
|
|
22
|
+
|
|
23
|
+
# The manual queue is a "manual_queue" directory placed next to the projects directory.
QUEUE_DIR: Path = settings.projects_dir.parent.resolve() / "manual_queue"
# Import side effect: create the queue directory if it is missing.
QUEUE_DIR.mkdir(parents=True, exist_ok=True)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _slot(pending_id: str) -> Path:
    """Map a pending id to its directory directly inside ``QUEUE_DIR``.

    Args:
        pending_id: Queue entry id (ids created by ``enqueue`` are short hex strings).

    Returns:
        The resolved slot path.

    Raises:
        ValueError: If the id does not resolve to a direct child of
            ``QUEUE_DIR``. This rejects traversal (``..``), nested paths, and
            ids such as ``""`` or ``"."`` that resolve to the queue root
            itself. Those passed the old ``is_relative_to`` check, so
            ``delete(".")`` could remove the entire queue.
    """
    p = (QUEUE_DIR / pending_id).resolve()
    if p.parent != QUEUE_DIR:
        raise ValueError("invalid id")
    return p
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def enqueue(project_id: str, system: str, messages: list[dict[str, str]]) -> str:
    """Create a new queue slot, store the request in it, and return the slot id.

    The id is the first 10 hex characters of a random UUID; the slot
    directory must not already exist. The request is written to
    ``request.json`` as indented JSON.
    """
    pid = uuid.uuid4().hex[:10]
    slot = _slot(pid)
    slot.mkdir(parents=True, exist_ok=False)

    record = {
        "id": pid,
        "project_id": project_id,
        "created_at": time.time(),
        "system": system,
        "messages": messages,
    }
    request_file = slot / "request.json"
    request_file.write_text(json.dumps(record, indent=2))
    return pid
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def list_pending() -> list[dict[str, Any]]:
    """Return the stored requests of all slots that have no response yet.

    Slots are visited in sorted path order. Entries that are not directories,
    lack ``request.json``, already have ``response.json``, or hold undecodable
    JSON are skipped.
    """
    pending: list[dict[str, Any]] = []
    slots = sorted(entry for entry in QUEUE_DIR.iterdir() if entry.is_dir())
    for slot in slots:
        request_file = slot / "request.json"
        answered = slot / "response.json"
        if not request_file.exists() or answered.exists():
            continue
        try:
            pending.append(json.loads(request_file.read_text()))
        except json.JSONDecodeError:
            continue
    return pending
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_request(pending_id: str) -> dict[str, Any] | None:
    """Return the decoded ``request.json`` for a slot, or ``None`` if the file is absent."""
    request_file = _slot(pending_id) / "request.json"
    if not request_file.exists():
        return None
    return json.loads(request_file.read_text())
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def write_response(pending_id: str, payload: dict[str, Any]) -> None:
    """Store ``payload`` as the slot's ``response.json``.

    The JSON is written to a uniquely named temporary file in the slot and
    then renamed into place. A reader polling for ``response.json`` therefore
    sees either no file or a complete one, never a half-written file that
    fails to decode.

    Args:
        pending_id: Id of an existing queue slot.
        payload: JSON-serialisable response body.

    Raises:
        FileNotFoundError: If the slot directory does not exist.
        ValueError: If ``pending_id`` is rejected by ``_slot``.
    """
    slot = _slot(pending_id)
    if not slot.is_dir():
        raise FileNotFoundError(pending_id)

    target = slot / "response.json"
    # Unique name so concurrent writers do not clobber each other's temp file.
    tmp = slot / f".response.{uuid.uuid4().hex}.tmp"
    try:
        tmp.write_text(json.dumps(payload, indent=2))
        tmp.replace(target)  # atomic rename within the same directory
    finally:
        # No-op after a successful rename; cleans up if writing failed.
        tmp.unlink(missing_ok=True)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def read_response(pending_id: str) -> dict[str, Any] | None:
    """Return the decoded ``response.json`` for a slot, or ``None`` if the file is absent."""
    response_file = _slot(pending_id) / "response.json"
    if not response_file.exists():
        return None
    return json.loads(response_file.read_text())
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def delete(pending_id: str) -> None:
    """Remove a slot directory and its contents; do nothing if it does not exist."""
    import shutil

    slot = _slot(pending_id)
    if not slot.is_dir():
        return
    shutil.rmtree(slot)
|
tau_sim/llm/providers.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Pluggable LLM providers.
|
|
2
|
+
|
|
3
|
+
Each provider exposes ``complete(system, messages, model) -> str``.
|
|
4
|
+
The agent layer wraps these and parses structured diff proposals.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
|
|
10
|
+
from tau_sim.config import settings
|
|
11
|
+
from tau_sim.schemas import ChatMessage
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LLMProvider(ABC):
    """Abstract interface implemented by every LLM backend in this module."""

    # Short provider identifier; the concrete subclasses set it (e.g. "openai").
    name: str

    @abstractmethod
    def complete(self, system: str, messages: list[ChatMessage], model: str | None) -> str:
        """Return the model's raw text reply.

        Args:
            system: System prompt.
            messages: Conversation turns, oldest first.
            model: Model id, or ``None`` to use the provider's configured default.
        """
        ...
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OpenAIProvider(LLMProvider):
    """Provider backed by the OpenAI chat-completions API."""

    name = "openai"

    def __init__(self, api_key: str | None = None) -> None:
        """Create the client from ``api_key`` or the configured key.

        Raises:
            RuntimeError: If neither source supplies a key.
        """
        from openai import OpenAI

        resolved_key = api_key or settings.openai_api_key
        if not resolved_key:
            raise RuntimeError("OPENAI_API_KEY is not set")
        self._client = OpenAI(api_key=resolved_key)

    def complete(self, system: str, messages: list[ChatMessage], model: str | None) -> str:
        """Send the system prompt plus ``messages`` and return the reply text ("" if none)."""
        chat = [{"role": "system", "content": system}]
        chat.extend({"role": m.role, "content": m.content} for m in messages)
        completion = self._client.chat.completions.create(
            model=model or settings.llm_default_model_openai,
            messages=chat,
            response_format={"type": "json_object"},
            temperature=0.2,
        )
        text = completion.choices[0].message.content
        return text or ""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GeminiProvider(LLMProvider):
    """Provider backed by the ``google.generativeai`` SDK."""

    name = "gemini"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the SDK with ``api_key`` or the configured key.

        Raises:
            RuntimeError: If neither source supplies a key.
        """
        import google.generativeai as genai

        resolved_key = api_key or settings.gemini_api_key
        if not resolved_key:
            raise RuntimeError("GEMINI_API_KEY is not set")
        genai.configure(api_key=resolved_key)
        self._genai = genai

    def complete(self, system: str, messages: list[ChatMessage], model: str | None) -> str:
        """Replay all but the last message as chat history, send the last, return the text.

        Every non-``user`` role is mapped to ``model``.
        """
        gm = self._genai.GenerativeModel(
            model_name=model or settings.llm_default_model_gemini,
            system_instruction=system,
            generation_config={"response_mime_type": "application/json", "temperature": 0.2},
        )
        # Gemini expects alternating user/model turns.
        history = [
            {"role": "user" if m.role == "user" else "model", "parts": [m.content]}
            for m in messages[:-1]
        ]
        session = gm.start_chat(history=history)
        final_prompt = messages[-1].content if messages else ""
        reply = session.send_message(final_prompt)
        return reply.text or ""
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class OpenRouterProvider(LLMProvider):
    """OpenAI-compatible client pointed at OpenRouter."""

    name = "openrouter"

    def __init__(self, api_key: str | None = None) -> None:
        """Create the client from ``api_key`` or the configured key.

        Raises:
            RuntimeError: If neither source supplies a key.
        """
        from openai import OpenAI

        resolved_key = api_key or settings.openrouter_api_key
        if not resolved_key:
            raise RuntimeError("OPENROUTER_API_KEY is not set")
        # Optional but recommended by OpenRouter for attribution.
        attribution_headers = {
            "HTTP-Referer": "https://github.com/sim_llm",
            "X-Title": "sim_llm",
        }
        self._client = OpenAI(
            api_key=resolved_key,
            base_url=settings.openrouter_base_url,
            default_headers=attribution_headers,
        )

    def complete(self, system: str, messages: list[ChatMessage], model: str | None) -> str:
        """Send the system prompt plus ``messages`` and return the reply text ("" if none)."""
        chat = [{"role": "system", "content": system}]
        chat.extend({"role": m.role, "content": m.content} for m in messages)
        completion = self._client.chat.completions.create(
            model=model or settings.llm_default_model_openrouter,
            messages=chat,
            response_format={"type": "json_object"},
            temperature=0.2,
        )
        text = completion.choices[0].message.content
        return text or ""
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_provider(name: str | None, api_key: str | None = None) -> LLMProvider:
    """Instantiate the provider selected by ``name`` (case-insensitive).

    Falls back to ``settings.llm_default_provider`` when ``name`` is empty.

    Raises:
        ValueError: If the resolved name is not a known provider.
    """
    name = (name or settings.llm_default_provider).lower()
    factories = {
        "openai": OpenAIProvider,
        "gemini": GeminiProvider,
        "openrouter": OpenRouterProvider,
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"unknown provider: {name}")
    return factory(api_key)
|