researchloop 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- researchloop/__init__.py +1 -0
- researchloop/__main__.py +3 -0
- researchloop/cli.py +1138 -0
- researchloop/clusters/__init__.py +4 -0
- researchloop/clusters/monitor.py +199 -0
- researchloop/clusters/ssh.py +183 -0
- researchloop/comms/__init__.py +0 -0
- researchloop/comms/base.py +34 -0
- researchloop/comms/conversation.py +465 -0
- researchloop/comms/ntfy.py +95 -0
- researchloop/comms/router.py +71 -0
- researchloop/comms/slack.py +188 -0
- researchloop/core/__init__.py +0 -0
- researchloop/core/auth.py +78 -0
- researchloop/core/config.py +328 -0
- researchloop/core/credentials.py +38 -0
- researchloop/core/models.py +119 -0
- researchloop/core/orchestrator.py +910 -0
- researchloop/dashboard/__init__.py +0 -0
- researchloop/dashboard/app.py +15 -0
- researchloop/dashboard/auth.py +60 -0
- researchloop/dashboard/routes.py +912 -0
- researchloop/dashboard/templates/base.html +84 -0
- researchloop/dashboard/templates/login.html +12 -0
- researchloop/dashboard/templates/loop_detail.html +58 -0
- researchloop/dashboard/templates/loops.html +61 -0
- researchloop/dashboard/templates/setup.html +14 -0
- researchloop/dashboard/templates/sprint_detail.html +109 -0
- researchloop/dashboard/templates/sprints.html +48 -0
- researchloop/dashboard/templates/studies.html +18 -0
- researchloop/dashboard/templates/study_detail.html +64 -0
- researchloop/db/__init__.py +5 -0
- researchloop/db/database.py +86 -0
- researchloop/db/migrations.py +172 -0
- researchloop/db/queries.py +351 -0
- researchloop/runner/__init__.py +1 -0
- researchloop/runner/claude.py +169 -0
- researchloop/runner/job_templates/sge.sh.j2 +319 -0
- researchloop/runner/job_templates/slurm.sh.j2 +336 -0
- researchloop/runner/main.py +156 -0
- researchloop/runner/pipeline.py +272 -0
- researchloop/runner/templates/fix_issues.md.j2 +11 -0
- researchloop/runner/templates/idea_generator.md.j2 +16 -0
- researchloop/runner/templates/red_team.md.j2 +15 -0
- researchloop/runner/templates/report.md.j2 +31 -0
- researchloop/runner/templates/research_sprint.md.j2 +51 -0
- researchloop/runner/templates/summarizer.md.j2 +7 -0
- researchloop/runner/upload.py +153 -0
- researchloop/schedulers/__init__.py +11 -0
- researchloop/schedulers/base.py +43 -0
- researchloop/schedulers/local.py +188 -0
- researchloop/schedulers/sge.py +163 -0
- researchloop/schedulers/slurm.py +179 -0
- researchloop/sprints/__init__.py +0 -0
- researchloop/sprints/auto_loop.py +458 -0
- researchloop/sprints/manager.py +750 -0
- researchloop/studies/__init__.py +0 -0
- researchloop/studies/manager.py +102 -0
- researchloop-0.1.0.dist-info/METADATA +596 -0
- researchloop-0.1.0.dist-info/RECORD +63 -0
- researchloop-0.1.0.dist-info/WHEEL +4 -0
- researchloop-0.1.0.dist-info/entry_points.txt +3 -0
- researchloop-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,750 @@
|
|
|
1
|
+
"""Sprint lifecycle management -- create, submit, cancel, and complete sprints."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
import jinja2
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from researchloop.clusters.ssh import SSHManager
|
|
16
|
+
from researchloop.comms.router import NotificationRouter
|
|
17
|
+
from researchloop.core.config import Config
|
|
18
|
+
from researchloop.db.database import Database
|
|
19
|
+
from researchloop.schedulers.base import BaseScheduler
|
|
20
|
+
|
|
21
|
+
from researchloop.core.models import (
|
|
22
|
+
Sprint,
|
|
23
|
+
SprintStatus,
|
|
24
|
+
format_sprint_dirname,
|
|
25
|
+
generate_sprint_id,
|
|
26
|
+
)
|
|
27
|
+
from researchloop.db import queries
|
|
28
|
+
from researchloop.studies.manager import StudyManager
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _b64encode(text: str) -> str:
|
|
34
|
+
"""Base64-encode a string for safe SSH transfer."""
|
|
35
|
+
return base64.b64encode(text.encode("utf-8")).decode("ascii")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Jinja2 environment pointing at the runner/job_templates directory.
|
|
39
|
+
_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "runner" / "job_templates"
|
|
40
|
+
_jinja_env = jinja2.Environment(
|
|
41
|
+
loader=jinja2.FileSystemLoader(str(_TEMPLATES_DIR)),
|
|
42
|
+
keep_trailing_newline=True,
|
|
43
|
+
undefined=jinja2.StrictUndefined,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Prompt templates for the research pipeline steps.
|
|
47
|
+
_PROMPT_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "runner" / "templates"
|
|
48
|
+
_prompt_env = jinja2.Environment(
|
|
49
|
+
loader=jinja2.FileSystemLoader(str(_PROMPT_TEMPLATES_DIR)),
|
|
50
|
+
keep_trailing_newline=True,
|
|
51
|
+
undefined=jinja2.StrictUndefined,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class SprintManager:
|
|
56
|
+
"""Manages the full lifecycle of research sprints.
|
|
57
|
+
|
|
58
|
+
Coordinates between the database, SSH connections, job schedulers,
|
|
59
|
+
and the notification router.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
def __init__(
|
|
63
|
+
self,
|
|
64
|
+
db: Database,
|
|
65
|
+
config: Config,
|
|
66
|
+
ssh_manager: SSHManager,
|
|
67
|
+
schedulers: dict[str, BaseScheduler],
|
|
68
|
+
study_manager: StudyManager | None = None,
|
|
69
|
+
notification_router: NotificationRouter | None = None,
|
|
70
|
+
) -> None:
|
|
71
|
+
self.db = db
|
|
72
|
+
self.config = config
|
|
73
|
+
self.ssh_manager = ssh_manager
|
|
74
|
+
self.schedulers = schedulers
|
|
75
|
+
self.study_manager = study_manager
|
|
76
|
+
self.notification_router = notification_router
|
|
77
|
+
|
|
78
|
+
# ------------------------------------------------------------------
|
|
79
|
+
# Create
|
|
80
|
+
# ------------------------------------------------------------------
|
|
81
|
+
|
|
82
|
+
async def create_sprint(self, study_name: str, idea: str | None = None) -> Sprint:
|
|
83
|
+
"""Create a new sprint record in the database.
|
|
84
|
+
|
|
85
|
+
The sprint is created with status ``PENDING`` -- it has not yet
|
|
86
|
+
been submitted to a cluster scheduler.
|
|
87
|
+
"""
|
|
88
|
+
sprint_id = generate_sprint_id()
|
|
89
|
+
directory = format_sprint_dirname(sprint_id, idea)
|
|
90
|
+
|
|
91
|
+
await queries.create_sprint(
|
|
92
|
+
self.db,
|
|
93
|
+
id=sprint_id,
|
|
94
|
+
study_name=study_name,
|
|
95
|
+
idea=idea,
|
|
96
|
+
directory=directory,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
sprint = Sprint(
|
|
100
|
+
id=sprint_id,
|
|
101
|
+
study_name=study_name,
|
|
102
|
+
idea=idea,
|
|
103
|
+
status=SprintStatus.PENDING,
|
|
104
|
+
directory=directory,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
logger.info(
|
|
108
|
+
"Created sprint %s for study %r: %s",
|
|
109
|
+
sprint_id,
|
|
110
|
+
study_name,
|
|
111
|
+
idea,
|
|
112
|
+
)
|
|
113
|
+
return sprint
|
|
114
|
+
|
|
115
|
+
# ------------------------------------------------------------------
|
|
116
|
+
# Submit
|
|
117
|
+
# ------------------------------------------------------------------
|
|
118
|
+
|
|
119
|
+
async def submit_sprint(
|
|
120
|
+
self,
|
|
121
|
+
sprint_id: str,
|
|
122
|
+
extra_job_options: dict[str, str] | None = None,
|
|
123
|
+
) -> str:
|
|
124
|
+
"""Submit a pending sprint to its cluster scheduler.
|
|
125
|
+
|
|
126
|
+
Returns the scheduler-assigned job ID.
|
|
127
|
+
"""
|
|
128
|
+
sprint = await queries.get_sprint(self.db, sprint_id)
|
|
129
|
+
if sprint is None:
|
|
130
|
+
raise ValueError(f"Sprint not found: {sprint_id}")
|
|
131
|
+
|
|
132
|
+
study_name: str = sprint["study_name"]
|
|
133
|
+
|
|
134
|
+
# Resolve cluster config through study manager or config lookup.
|
|
135
|
+
if self.study_manager is not None:
|
|
136
|
+
cluster_cfg = await self.study_manager.get_cluster_config(study_name)
|
|
137
|
+
else:
|
|
138
|
+
# Fallback: look up directly from config.
|
|
139
|
+
study_row = await queries.get_study(self.db, study_name)
|
|
140
|
+
if study_row is None:
|
|
141
|
+
raise ValueError(f"Study not found: {study_name}")
|
|
142
|
+
cluster_name = study_row["cluster"]
|
|
143
|
+
cluster_cfg = None
|
|
144
|
+
for c in self.config.clusters:
|
|
145
|
+
if c.name == cluster_name:
|
|
146
|
+
cluster_cfg = c
|
|
147
|
+
break
|
|
148
|
+
if cluster_cfg is None:
|
|
149
|
+
raise ValueError(f"Cluster not found: {cluster_name}")
|
|
150
|
+
|
|
151
|
+
# Look up the study config for template variables.
|
|
152
|
+
study_cfg = None
|
|
153
|
+
for s in self.config.studies:
|
|
154
|
+
if s.name == study_name:
|
|
155
|
+
study_cfg = s
|
|
156
|
+
break
|
|
157
|
+
|
|
158
|
+
# Resolve scheduler.
|
|
159
|
+
scheduler = self.schedulers.get(cluster_cfg.name)
|
|
160
|
+
if scheduler is None:
|
|
161
|
+
scheduler = self.schedulers.get(cluster_cfg.scheduler_type)
|
|
162
|
+
if scheduler is None:
|
|
163
|
+
raise ValueError(
|
|
164
|
+
f"No scheduler registered for cluster {cluster_cfg.name!r} "
|
|
165
|
+
f"or type {cluster_cfg.scheduler_type!r}"
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
sprint_dirname = sprint.get("directory", sprint_id)
|
|
169
|
+
|
|
170
|
+
# Collect context: global → cluster → study.
|
|
171
|
+
# Each level supports inline text + file paths.
|
|
172
|
+
context_parts: list[str] = []
|
|
173
|
+
|
|
174
|
+
# 1. Global inline context.
|
|
175
|
+
if self.config.context:
|
|
176
|
+
context_parts.append(self.config.context)
|
|
177
|
+
|
|
178
|
+
# 2. Global context files.
|
|
179
|
+
for ctx_path in self.config.context_paths:
|
|
180
|
+
p = Path(ctx_path)
|
|
181
|
+
if p.exists():
|
|
182
|
+
context_parts.append(p.read_text(encoding="utf-8"))
|
|
183
|
+
logger.info("Loaded global context file: %s", p)
|
|
184
|
+
|
|
185
|
+
# 3. Cluster inline context.
|
|
186
|
+
if cluster_cfg.context:
|
|
187
|
+
context_parts.append(cluster_cfg.context)
|
|
188
|
+
|
|
189
|
+
# 4. Cluster context files.
|
|
190
|
+
for ctx_path in cluster_cfg.context_paths:
|
|
191
|
+
p = Path(ctx_path)
|
|
192
|
+
if p.exists():
|
|
193
|
+
context_parts.append(p.read_text(encoding="utf-8"))
|
|
194
|
+
logger.info("Loaded cluster context file: %s", p)
|
|
195
|
+
|
|
196
|
+
# 5. Study inline context.
|
|
197
|
+
if study_cfg and study_cfg.context:
|
|
198
|
+
context_parts.append(study_cfg.context)
|
|
199
|
+
|
|
200
|
+
# 6. Study context file.
|
|
201
|
+
if study_cfg and study_cfg.claude_md_path:
|
|
202
|
+
p = Path(study_cfg.claude_md_path)
|
|
203
|
+
if p.exists():
|
|
204
|
+
context_parts.append(p.read_text(encoding="utf-8"))
|
|
205
|
+
logger.info("Loaded study context file: %s", p)
|
|
206
|
+
|
|
207
|
+
has_context = bool(context_parts)
|
|
208
|
+
study_context = "\n\n".join(context_parts) if has_context else ""
|
|
209
|
+
idea = sprint["idea"]
|
|
210
|
+
red_team_rounds = study_cfg.red_team_max_rounds if study_cfg else 3
|
|
211
|
+
|
|
212
|
+
# Resolve the base directory for sprints.
|
|
213
|
+
# Priority: study.sprints_dir > working_dir/<study_name>
|
|
214
|
+
if study_cfg and study_cfg.sprints_dir:
|
|
215
|
+
sprints_base = study_cfg.sprints_dir
|
|
216
|
+
else:
|
|
217
|
+
sprints_base = f"{cluster_cfg.working_dir}/{study_name}"
|
|
218
|
+
sprint_remote_dir = f"{sprints_base}/{sprint_dirname}"
|
|
219
|
+
|
|
220
|
+
# Pre-render all pipeline prompt templates.
|
|
221
|
+
def _render_prompt(name: str, **kw: object) -> str:
|
|
222
|
+
return _prompt_env.get_template(name).render(**kw)
|
|
223
|
+
|
|
224
|
+
prompts: list[dict[str, str]] = []
|
|
225
|
+
|
|
226
|
+
# For loop sprints, idea is None — the job script
|
|
227
|
+
# will generate it and overwrite the prompt.
|
|
228
|
+
idea_text = idea or "(will be auto-generated)"
|
|
229
|
+
|
|
230
|
+
# Research prompt
|
|
231
|
+
prompts.append(
|
|
232
|
+
{
|
|
233
|
+
"filename": "prompt_research.md",
|
|
234
|
+
"content_b64": _b64encode(
|
|
235
|
+
_render_prompt(
|
|
236
|
+
"research_sprint.md.j2",
|
|
237
|
+
study_context=study_context,
|
|
238
|
+
idea=idea_text,
|
|
239
|
+
sprint_dir=sprint_remote_dir,
|
|
240
|
+
)
|
|
241
|
+
),
|
|
242
|
+
}
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
# Red-team + fix prompts (one pair per round)
|
|
246
|
+
for r in range(1, red_team_rounds + 1):
|
|
247
|
+
prompts.append(
|
|
248
|
+
{
|
|
249
|
+
"filename": f"prompt_red_team_{r}.md",
|
|
250
|
+
"content_b64": _b64encode(
|
|
251
|
+
_render_prompt(
|
|
252
|
+
"red_team.md.j2",
|
|
253
|
+
idea=idea_text,
|
|
254
|
+
round_number=r,
|
|
255
|
+
max_rounds=red_team_rounds,
|
|
256
|
+
)
|
|
257
|
+
),
|
|
258
|
+
}
|
|
259
|
+
)
|
|
260
|
+
prompts.append(
|
|
261
|
+
{
|
|
262
|
+
"filename": f"prompt_fix_{r}.md",
|
|
263
|
+
"content_b64": _b64encode(
|
|
264
|
+
_render_prompt(
|
|
265
|
+
"fix_issues.md.j2",
|
|
266
|
+
round_number=r,
|
|
267
|
+
)
|
|
268
|
+
),
|
|
269
|
+
}
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
# Report + summarize prompts
|
|
273
|
+
prompts.append(
|
|
274
|
+
{
|
|
275
|
+
"filename": "prompt_report.md",
|
|
276
|
+
"content_b64": _b64encode(
|
|
277
|
+
_render_prompt("report.md.j2", idea=idea_text)
|
|
278
|
+
),
|
|
279
|
+
}
|
|
280
|
+
)
|
|
281
|
+
prompts.append(
|
|
282
|
+
{
|
|
283
|
+
"filename": "prompt_summarize.md",
|
|
284
|
+
"content_b64": _b64encode(_render_prompt("summarizer.md.j2")),
|
|
285
|
+
}
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
# If this sprint belongs to an auto-loop, add the idea
|
|
289
|
+
# generator prompt so the job generates its own idea.
|
|
290
|
+
is_loop_sprint = bool(sprint.get("loop_id"))
|
|
291
|
+
if is_loop_sprint:
|
|
292
|
+
# Find the loop for extra context.
|
|
293
|
+
loop_context = ""
|
|
294
|
+
all_loops = await queries.list_auto_loops(self.db)
|
|
295
|
+
for lp in all_loops:
|
|
296
|
+
if lp.get("current_sprint_id") == sprint_id:
|
|
297
|
+
meta = lp.get("metadata_json")
|
|
298
|
+
if meta:
|
|
299
|
+
try:
|
|
300
|
+
import json as _json
|
|
301
|
+
|
|
302
|
+
loop_context = _json.loads(meta).get("context", "")
|
|
303
|
+
except Exception:
|
|
304
|
+
pass
|
|
305
|
+
break
|
|
306
|
+
|
|
307
|
+
# Collect previous summaries.
|
|
308
|
+
prev_sprints = await queries.list_sprints(
|
|
309
|
+
self.db, study_name=study_name, limit=50
|
|
310
|
+
)
|
|
311
|
+
prev_summaries = [
|
|
312
|
+
{
|
|
313
|
+
"id": s["id"],
|
|
314
|
+
"summary": s.get("summary", ""),
|
|
315
|
+
}
|
|
316
|
+
for s in prev_sprints
|
|
317
|
+
if s.get("summary")
|
|
318
|
+
]
|
|
319
|
+
|
|
320
|
+
# Build the idea generator prompt with loop context.
|
|
321
|
+
idea_prompt = _render_prompt(
|
|
322
|
+
"idea_generator.md.j2",
|
|
323
|
+
study_context=study_context,
|
|
324
|
+
previous_sprints=prev_summaries,
|
|
325
|
+
)
|
|
326
|
+
if loop_context:
|
|
327
|
+
idea_prompt += f"\n\n## Additional Guidance\n{loop_context}\n"
|
|
328
|
+
|
|
329
|
+
prompts.append(
|
|
330
|
+
{
|
|
331
|
+
"filename": "prompt_generate_idea.md",
|
|
332
|
+
"content_b64": _b64encode(idea_prompt),
|
|
333
|
+
}
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
# Render the job script.
|
|
337
|
+
template_name = f"{cluster_cfg.scheduler_type}.sh.j2"
|
|
338
|
+
template = _jinja_env.get_template(template_name)
|
|
339
|
+
job_script = template.render(
|
|
340
|
+
sprint_id=sprint_id,
|
|
341
|
+
study_name=study_name,
|
|
342
|
+
idea=idea_text,
|
|
343
|
+
sprint_dirname=sprint_dirname,
|
|
344
|
+
job_name=f"rl-{sprint_id}",
|
|
345
|
+
working_dir=sprints_base,
|
|
346
|
+
time_limit=(
|
|
347
|
+
f"{study_cfg.max_sprint_duration_hours}:00:00"
|
|
348
|
+
if study_cfg
|
|
349
|
+
else "8:00:00"
|
|
350
|
+
),
|
|
351
|
+
environment=cluster_cfg.environment,
|
|
352
|
+
job_options={
|
|
353
|
+
**cluster_cfg.job_options,
|
|
354
|
+
**(study_cfg.job_options if study_cfg else {}),
|
|
355
|
+
**(extra_job_options or {}),
|
|
356
|
+
},
|
|
357
|
+
claude_command=(
|
|
358
|
+
(study_cfg.claude_command if study_cfg else "")
|
|
359
|
+
or cluster_cfg.claude_command
|
|
360
|
+
or self.config.claude_command
|
|
361
|
+
or "claude --dangerously-skip-permissions"
|
|
362
|
+
),
|
|
363
|
+
orchestrator_url=self.config.orchestrator_url or "",
|
|
364
|
+
webhook_token=sprint.get("webhook_token", ""),
|
|
365
|
+
red_team_max_rounds=red_team_rounds,
|
|
366
|
+
prompts=prompts,
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
# SSH to cluster: create sprint directory and write job script.
|
|
370
|
+
cluster_dict = {
|
|
371
|
+
"host": cluster_cfg.host,
|
|
372
|
+
"port": cluster_cfg.port,
|
|
373
|
+
"user": cluster_cfg.user,
|
|
374
|
+
"key_path": cluster_cfg.key_path,
|
|
375
|
+
}
|
|
376
|
+
ssh = await self.ssh_manager.get_connection(cluster_dict)
|
|
377
|
+
|
|
378
|
+
sprint_remote_dir = f"{sprints_base}/{sprint_dirname}"
|
|
379
|
+
await ssh.run(
|
|
380
|
+
f"mkdir -p {sprint_remote_dir}/.researchloop {sprint_remote_dir}/results"
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
# Upload CLAUDE.md so Claude CLI picks it up automatically.
|
|
384
|
+
if has_context:
|
|
385
|
+
encoded_ctx = _b64encode(study_context)
|
|
386
|
+
await ssh.run(
|
|
387
|
+
f"echo '{encoded_ctx}' | base64 -d > {sprint_remote_dir}/CLAUDE.md"
|
|
388
|
+
)
|
|
389
|
+
logger.info(
|
|
390
|
+
"Uploaded CLAUDE.md (%d parts) to %s",
|
|
391
|
+
len(context_parts),
|
|
392
|
+
sprint_remote_dir,
|
|
393
|
+
)
|
|
394
|
+
|
|
395
|
+
# Write idea.txt so it's always available on cluster.
|
|
396
|
+
if idea:
|
|
397
|
+
encoded_idea = _b64encode(idea)
|
|
398
|
+
await ssh.run(
|
|
399
|
+
f"echo '{encoded_idea}' | base64 -d > {sprint_remote_dir}/idea.txt"
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Write the job script via base64.
|
|
403
|
+
# Prompts are embedded in the script as base64.
|
|
404
|
+
script_path = f"{sprint_remote_dir}/run_sprint.sh"
|
|
405
|
+
encoded_script = _b64encode(job_script)
|
|
406
|
+
await ssh.run(f"echo '{encoded_script}' | base64 -d > {script_path}")
|
|
407
|
+
await ssh.run(f"chmod +x {script_path}")
|
|
408
|
+
|
|
409
|
+
# Submit via the scheduler.
|
|
410
|
+
job_id = await scheduler.submit(
|
|
411
|
+
ssh=ssh,
|
|
412
|
+
script=script_path,
|
|
413
|
+
job_name=f"rl-{sprint_id}",
|
|
414
|
+
working_dir=sprint_remote_dir,
|
|
415
|
+
env=cluster_cfg.environment or None,
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# Update the sprint record.
|
|
419
|
+
now = datetime.now(timezone.utc).isoformat()
|
|
420
|
+
await queries.update_sprint(
|
|
421
|
+
self.db,
|
|
422
|
+
sprint_id,
|
|
423
|
+
job_id=job_id,
|
|
424
|
+
status=SprintStatus.SUBMITTED.value,
|
|
425
|
+
started_at=now,
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
logger.info(
|
|
429
|
+
"Sprint %s submitted as job %s on cluster %s",
|
|
430
|
+
sprint_id,
|
|
431
|
+
job_id,
|
|
432
|
+
cluster_cfg.name,
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
# Notify.
|
|
436
|
+
if self.notification_router is not None:
|
|
437
|
+
await self.notification_router.notify_sprint_started(
|
|
438
|
+
sprint_id=sprint_id,
|
|
439
|
+
study_name=study_name,
|
|
440
|
+
idea=sprint["idea"],
|
|
441
|
+
)
|
|
442
|
+
|
|
443
|
+
return job_id
|
|
444
|
+
|
|
445
|
+
# ------------------------------------------------------------------
|
|
446
|
+
# Combined create + submit
|
|
447
|
+
# ------------------------------------------------------------------
|
|
448
|
+
|
|
449
|
+
async def run_sprint(
|
|
450
|
+
self,
|
|
451
|
+
study_name: str,
|
|
452
|
+
idea: str | None = None,
|
|
453
|
+
job_options: dict[str, str] | None = None,
|
|
454
|
+
) -> Sprint:
|
|
455
|
+
"""Create a sprint and immediately submit it.
|
|
456
|
+
|
|
457
|
+
Returns the :class:`Sprint` with updated status and job ID.
|
|
458
|
+
"""
|
|
459
|
+
sprint = await self.create_sprint(study_name, idea)
|
|
460
|
+
job_id = await self.submit_sprint(sprint.id, extra_job_options=job_options)
|
|
461
|
+
sprint.status = SprintStatus.SUBMITTED
|
|
462
|
+
sprint.job_id = job_id
|
|
463
|
+
return sprint
|
|
464
|
+
|
|
465
|
+
# ------------------------------------------------------------------
|
|
466
|
+
# Cancel
|
|
467
|
+
# ------------------------------------------------------------------
|
|
468
|
+
|
|
469
|
+
async def cancel_sprint(self, sprint_id: str) -> bool:
|
|
470
|
+
"""Cancel a running or submitted sprint.
|
|
471
|
+
|
|
472
|
+
Returns ``True`` if the cancellation succeeded.
|
|
473
|
+
If the sprint belongs to an auto-loop, the loop is also stopped.
|
|
474
|
+
"""
|
|
475
|
+
sprint = await queries.get_sprint(self.db, sprint_id)
|
|
476
|
+
if sprint is None:
|
|
477
|
+
raise ValueError(f"Sprint not found: {sprint_id}")
|
|
478
|
+
|
|
479
|
+
study_name: str = sprint["study_name"]
|
|
480
|
+
|
|
481
|
+
# Resolve cluster.
|
|
482
|
+
if self.study_manager is not None:
|
|
483
|
+
cluster_cfg = await self.study_manager.get_cluster_config(study_name)
|
|
484
|
+
else:
|
|
485
|
+
study_row = await queries.get_study(self.db, study_name)
|
|
486
|
+
if study_row is None:
|
|
487
|
+
raise ValueError(f"Study not found: {study_name}")
|
|
488
|
+
cluster_name = study_row["cluster"]
|
|
489
|
+
cluster_cfg = None
|
|
490
|
+
for c in self.config.clusters:
|
|
491
|
+
if c.name == cluster_name:
|
|
492
|
+
cluster_cfg = c
|
|
493
|
+
break
|
|
494
|
+
if cluster_cfg is None:
|
|
495
|
+
raise ValueError(f"Cluster not found: {cluster_name}")
|
|
496
|
+
|
|
497
|
+
scheduler = self.schedulers.get(cluster_cfg.name)
|
|
498
|
+
if scheduler is None:
|
|
499
|
+
scheduler = self.schedulers.get(cluster_cfg.scheduler_type)
|
|
500
|
+
if scheduler is None:
|
|
501
|
+
raise ValueError(f"No scheduler for cluster {cluster_cfg.name!r}")
|
|
502
|
+
|
|
503
|
+
job_id = sprint.get("job_id")
|
|
504
|
+
if not job_id:
|
|
505
|
+
logger.warning("Sprint %s has no job_id, marking as cancelled", sprint_id)
|
|
506
|
+
await queries.update_sprint(
|
|
507
|
+
self.db,
|
|
508
|
+
sprint_id,
|
|
509
|
+
status=SprintStatus.CANCELLED.value,
|
|
510
|
+
completed_at=datetime.now(timezone.utc).isoformat(),
|
|
511
|
+
)
|
|
512
|
+
else:
|
|
513
|
+
cluster_dict = {
|
|
514
|
+
"host": cluster_cfg.host,
|
|
515
|
+
"port": cluster_cfg.port,
|
|
516
|
+
"user": cluster_cfg.user,
|
|
517
|
+
"key_path": cluster_cfg.key_path,
|
|
518
|
+
}
|
|
519
|
+
ssh = await self.ssh_manager.get_connection(cluster_dict)
|
|
520
|
+
await scheduler.cancel(ssh, job_id)
|
|
521
|
+
|
|
522
|
+
await queries.update_sprint(
|
|
523
|
+
self.db,
|
|
524
|
+
sprint_id,
|
|
525
|
+
status=SprintStatus.CANCELLED.value,
|
|
526
|
+
completed_at=datetime.now(timezone.utc).isoformat(),
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
logger.info("Sprint %s cancelled", sprint_id)
|
|
530
|
+
|
|
531
|
+
# Stop the parent auto-loop if this sprint belongs to one.
|
|
532
|
+
loop_id = sprint.get("loop_id")
|
|
533
|
+
if loop_id:
|
|
534
|
+
try:
|
|
535
|
+
loop = await queries.get_auto_loop(self.db, loop_id)
|
|
536
|
+
if loop and loop["status"] == "running":
|
|
537
|
+
await queries.update_auto_loop(
|
|
538
|
+
self.db,
|
|
539
|
+
loop_id,
|
|
540
|
+
status="stopped",
|
|
541
|
+
stopped_at=datetime.now(timezone.utc).isoformat(),
|
|
542
|
+
)
|
|
543
|
+
logger.info(
|
|
544
|
+
"Auto-loop %s stopped (sprint %s cancelled)",
|
|
545
|
+
loop_id,
|
|
546
|
+
sprint_id,
|
|
547
|
+
)
|
|
548
|
+
except Exception:
|
|
549
|
+
logger.warning(
|
|
550
|
+
"Failed to stop loop %s after cancelling sprint %s",
|
|
551
|
+
loop_id,
|
|
552
|
+
sprint_id,
|
|
553
|
+
exc_info=True,
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
# Notify about cancellation.
|
|
557
|
+
if self.notification_router is not None:
|
|
558
|
+
await self.notification_router.notify_sprint_failed(
|
|
559
|
+
sprint_id=sprint_id,
|
|
560
|
+
study_name=study_name,
|
|
561
|
+
error="Sprint cancelled",
|
|
562
|
+
)
|
|
563
|
+
|
|
564
|
+
return True
|
|
565
|
+
|
|
566
|
+
# ------------------------------------------------------------------
|
|
567
|
+
# Query helpers
|
|
568
|
+
# ------------------------------------------------------------------
|
|
569
|
+
|
|
570
|
+
async def get_sprint(self, sprint_id: str) -> dict | None:
|
|
571
|
+
"""Return a single sprint by ID, or ``None``."""
|
|
572
|
+
return await queries.get_sprint(self.db, sprint_id)
|
|
573
|
+
|
|
574
|
+
async def list_sprints(
|
|
575
|
+
self, study_name: str | None = None, limit: int = 50
|
|
576
|
+
) -> list[dict]:
|
|
577
|
+
"""Return sprints, optionally filtered by study name."""
|
|
578
|
+
return await queries.list_sprints(self.db, study_name=study_name, limit=limit)
|
|
579
|
+
|
|
580
|
+
# ------------------------------------------------------------------
|
|
581
|
+
# Completion handling
|
|
582
|
+
# ------------------------------------------------------------------
|
|
583
|
+
|
|
584
|
+
async def handle_completion(
|
|
585
|
+
self,
|
|
586
|
+
sprint_id: str,
|
|
587
|
+
status: str,
|
|
588
|
+
summary: str | None = None,
|
|
589
|
+
error: str | None = None,
|
|
590
|
+
idea: str | None = None,
|
|
591
|
+
) -> None:
|
|
592
|
+
"""Handle a sprint completion event.
|
|
593
|
+
|
|
594
|
+
Updates the database, sends notifications, and creates an event
|
|
595
|
+
record.
|
|
596
|
+
"""
|
|
597
|
+
now = datetime.now(timezone.utc).isoformat()
|
|
598
|
+
|
|
599
|
+
update_kw: dict[str, str | None] = {
|
|
600
|
+
"status": status,
|
|
601
|
+
"completed_at": now,
|
|
602
|
+
"summary": summary,
|
|
603
|
+
"error": error,
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
# Update the idea if it was auto-generated (sprint had idea=None).
|
|
607
|
+
sprint_before = await queries.get_sprint(self.db, sprint_id)
|
|
608
|
+
if sprint_before and not sprint_before.get("idea"):
|
|
609
|
+
if idea:
|
|
610
|
+
update_kw["idea"] = idea[:500]
|
|
611
|
+
else:
|
|
612
|
+
# Fallback: try to read idea.txt from the cluster.
|
|
613
|
+
fetched = await self._fetch_idea(sprint_before)
|
|
614
|
+
if fetched:
|
|
615
|
+
update_kw["idea"] = fetched[:500]
|
|
616
|
+
|
|
617
|
+
await queries.update_sprint(self.db, sprint_id, **update_kw)
|
|
618
|
+
|
|
619
|
+
sprint = await queries.get_sprint(self.db, sprint_id)
|
|
620
|
+
study_name = sprint["study_name"] if sprint else "unknown"
|
|
621
|
+
|
|
622
|
+
# Create an event record.
|
|
623
|
+
event_data = json.dumps(
|
|
624
|
+
{
|
|
625
|
+
"status": status,
|
|
626
|
+
"summary": summary,
|
|
627
|
+
"error": error,
|
|
628
|
+
}
|
|
629
|
+
)
|
|
630
|
+
await queries.create_event(
|
|
631
|
+
self.db,
|
|
632
|
+
sprint_id=sprint_id,
|
|
633
|
+
event_type="sprint_completed",
|
|
634
|
+
data_json=event_data,
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
# Try to fetch the PDF for the notification.
|
|
638
|
+
pdf_local: str | None = None
|
|
639
|
+
if status == SprintStatus.COMPLETED.value and sprint:
|
|
640
|
+
pdf_local = await self._fetch_pdf(sprint)
|
|
641
|
+
|
|
642
|
+
# Notify via configured channels.
|
|
643
|
+
if self.notification_router is not None:
|
|
644
|
+
if status == SprintStatus.COMPLETED.value:
|
|
645
|
+
await self.notification_router.notify_sprint_completed(
|
|
646
|
+
sprint_id=sprint_id,
|
|
647
|
+
study_name=study_name,
|
|
648
|
+
summary=summary or "No summary provided",
|
|
649
|
+
pdf_path=pdf_local,
|
|
650
|
+
)
|
|
651
|
+
elif status == SprintStatus.FAILED.value:
|
|
652
|
+
await self.notification_router.notify_sprint_failed(
|
|
653
|
+
sprint_id=sprint_id,
|
|
654
|
+
study_name=study_name,
|
|
655
|
+
error=error or "Unknown error",
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
logger.info(
|
|
659
|
+
"Sprint %s completion handled: status=%s",
|
|
660
|
+
sprint_id,
|
|
661
|
+
status,
|
|
662
|
+
)
|
|
663
|
+
|
|
664
|
+
async def _fetch_idea(self, sprint: dict) -> str | None:
|
|
665
|
+
"""Try to read idea.txt from the cluster for auto-loop sprints."""
|
|
666
|
+
try:
|
|
667
|
+
study_name = sprint["study_name"]
|
|
668
|
+
if self.study_manager is None:
|
|
669
|
+
return None
|
|
670
|
+
cluster_cfg = await self.study_manager.get_cluster_config(study_name)
|
|
671
|
+
study_cfg = None
|
|
672
|
+
for s in self.config.studies:
|
|
673
|
+
if s.name == study_name:
|
|
674
|
+
study_cfg = s
|
|
675
|
+
break
|
|
676
|
+
if study_cfg and study_cfg.sprints_dir:
|
|
677
|
+
sbase = study_cfg.sprints_dir
|
|
678
|
+
else:
|
|
679
|
+
sbase = f"{cluster_cfg.working_dir}/{study_name}"
|
|
680
|
+
sp_dir = sprint.get("directory", "")
|
|
681
|
+
remote_idea = f"{sbase}/{sp_dir}/idea.txt"
|
|
682
|
+
|
|
683
|
+
conn = {
|
|
684
|
+
"host": cluster_cfg.host,
|
|
685
|
+
"port": cluster_cfg.port,
|
|
686
|
+
"user": cluster_cfg.user,
|
|
687
|
+
"key_path": cluster_cfg.key_path,
|
|
688
|
+
}
|
|
689
|
+
ssh = await self.ssh_manager.get_connection(conn)
|
|
690
|
+
stdout, _, rc = await ssh.run(f"cat {remote_idea} 2>/dev/null")
|
|
691
|
+
if rc == 0 and stdout.strip():
|
|
692
|
+
return stdout.strip()
|
|
693
|
+
return None
|
|
694
|
+
except Exception:
|
|
695
|
+
logger.debug("Idea fetch failed for %s", sprint.get("id"), exc_info=True)
|
|
696
|
+
return None
|
|
697
|
+
|
|
698
|
+
async def _fetch_pdf(self, sprint: dict) -> str | None:
|
|
699
|
+
"""Try to download report.pdf from the cluster."""
|
|
700
|
+
try:
|
|
701
|
+
study_name = sprint["study_name"]
|
|
702
|
+
if self.study_manager is None:
|
|
703
|
+
logger.warning("PDF fetch: no study_manager")
|
|
704
|
+
return None
|
|
705
|
+
cluster_cfg = await self.study_manager.get_cluster_config(study_name)
|
|
706
|
+
# Resolve sprint path.
|
|
707
|
+
study_cfg = None
|
|
708
|
+
for s in self.config.studies:
|
|
709
|
+
if s.name == study_name:
|
|
710
|
+
study_cfg = s
|
|
711
|
+
break
|
|
712
|
+
if study_cfg and study_cfg.sprints_dir:
|
|
713
|
+
sbase = study_cfg.sprints_dir
|
|
714
|
+
else:
|
|
715
|
+
sbase = f"{cluster_cfg.working_dir}/{study_name}"
|
|
716
|
+
sp_dir = sprint.get("directory", "")
|
|
717
|
+
remote_pdf = f"{sbase}/{sp_dir}/report.pdf"
|
|
718
|
+
|
|
719
|
+
conn = {
|
|
720
|
+
"host": cluster_cfg.host,
|
|
721
|
+
"port": cluster_cfg.port,
|
|
722
|
+
"user": cluster_cfg.user,
|
|
723
|
+
"key_path": cluster_cfg.key_path,
|
|
724
|
+
}
|
|
725
|
+
ssh = await self.ssh_manager.get_connection(conn)
|
|
726
|
+
|
|
727
|
+
# Check if PDF exists.
|
|
728
|
+
_, _, rc = await ssh.run(f"test -f {remote_pdf}")
|
|
729
|
+
if rc != 0:
|
|
730
|
+
logger.info(
|
|
731
|
+
"No report.pdf for %s at %s",
|
|
732
|
+
sprint["id"],
|
|
733
|
+
remote_pdf,
|
|
734
|
+
)
|
|
735
|
+
return None
|
|
736
|
+
|
|
737
|
+
# Download to local artifact dir.
|
|
738
|
+
art_dir = Path(self.config.artifact_dir) / sprint["id"]
|
|
739
|
+
art_dir.mkdir(parents=True, exist_ok=True)
|
|
740
|
+
local_pdf = str(art_dir / "report.pdf")
|
|
741
|
+
await ssh.download_file(remote_pdf, local_pdf)
|
|
742
|
+
logger.info("Downloaded PDF for %s", sprint["id"])
|
|
743
|
+
return local_pdf
|
|
744
|
+
except Exception:
|
|
745
|
+
logger.warning(
|
|
746
|
+
"PDF fetch failed for %s",
|
|
747
|
+
sprint.get("id"),
|
|
748
|
+
exc_info=True,
|
|
749
|
+
)
|
|
750
|
+
return None
|