ai-push-hooks 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.ai-push-hooks.toml +73 -0
- package/LICENSE +21 -0
- package/README.md +234 -0
- package/bin/ai-push-hooks.js +35 -0
- package/package.json +24 -0
- package/pyproject.toml +38 -0
- package/run.sh +29 -0
- package/src/ai_push_hooks/__init__.py +6 -0
- package/src/ai_push_hooks/__main__.py +3 -0
- package/src/ai_push_hooks/artifacts.py +86 -0
- package/src/ai_push_hooks/cli.py +49 -0
- package/src/ai_push_hooks/config.py +356 -0
- package/src/ai_push_hooks/engine.py +172 -0
- package/src/ai_push_hooks/executors/__init__.py +1 -0
- package/src/ai_push_hooks/executors/apply.py +55 -0
- package/src/ai_push_hooks/executors/assertions.py +44 -0
- package/src/ai_push_hooks/executors/exec.py +413 -0
- package/src/ai_push_hooks/executors/llm.py +308 -0
- package/src/ai_push_hooks/hook.py +130 -0
- package/src/ai_push_hooks/modules/__init__.py +11 -0
- package/src/ai_push_hooks/modules/beads.py +46 -0
- package/src/ai_push_hooks/modules/docs.py +159 -0
- package/src/ai_push_hooks/modules/pr.py +73 -0
- package/src/ai_push_hooks/prompts_builtin.py +135 -0
- package/src/ai_push_hooks/types.py +236 -0
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import pathlib
|
|
7
|
+
import re
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from .prompts_builtin import BUILTIN_PROMPTS
|
|
11
|
+
from .types import GeneralConfig, HookConfig, HookError, LlmConfig, LoggingConfig, ModuleConfig, StepConfig, SUPPORTED_STEP_TYPES, WorkflowConfig
|
|
12
|
+
from .executors.exec import env_bool
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import tomllib
|
|
16
|
+
except ModuleNotFoundError: # pragma: no cover
|
|
17
|
+
tomllib = None # type: ignore[assignment]
|
|
18
|
+
|
|
19
|
+
# Built-in defaults.  A user config file (.ai-push-hooks.toml) is deep-merged
# on top of this mapping before validation in _build_config.
DEFAULT_CONFIG_RAW: dict[str, Any] = {
    "general": {
        "enabled": True,
        "allow_push_on_error": False,
        "require_clean_worktree": False,
        "skip_on_sync_branch": True,
    },
    "llm": {
        "runner": "opencode",
        "model": "openai/gpt-5.3-codex-spark",
        "variant": "",
        "timeout_seconds": 800,
        "max_parallel": 2,
        "json_max_retries": 2,
        "invalid_json_feedback_max_chars": 6000,
        "json_retry_new_session": True,
        "delete_session_after_run": True,
        "max_diff_bytes": 180000,
        "session_title_prefix": "ai-push-hooks",
    },
    "logging": {
        "level": "status",
        "jsonl": True,
        "dir": ".git/ai-push-hooks/logs",
        "capture_llm_transcript": True,
        "transcript_dir": ".git/ai-push-hooks/transcripts",
        "summary_dir": ".git/ai-push-hooks/summaries",
        "print_llm_output": False,
    },
    # The default workflow runs only the "docs" module defined below.
    "workflow": {"modules": ["docs"]},
    "modules": {
        "docs": {
            "enabled": True,
            # collect -> query -> analyze -> apply -> assert pipeline; step
            # inputs reference earlier steps as "<step-id>/<artifact-name>".
            "steps": [
                {"id": "collect", "type": "collect", "collector": "docs_context"},
                {
                    "id": "query",
                    "type": "llm",
                    "inputs": ["collect/push.diff", "collect/changed-files.txt"],
                    "output": "queries.json",
                    "schema": "string_array",
                    "fallback_prompt_id": "docs-query-basic",
                },
                {
                    "id": "analyze",
                    "type": "llm",
                    "inputs": [
                        "collect/push.diff",
                        "collect/docs-context.txt",
                        "query/queries.json",
                        "collect/recent-commits.txt",
                    ],
                    "output": "issues.json",
                    "schema": "docs_issue_array",
                    "fallback_prompt_id": "docs-analysis-basic",
                },
                {
                    "id": "apply",
                    "type": "apply",
                    "inputs": ["collect/push.diff", "collect/docs-context.txt", "analyze/issues.json"],
                    "allow_paths": ["README.md", "docs/**/*.md"],
                    "fallback_prompt_id": "docs-apply-basic",
                },
                {
                    "id": "assert",
                    "type": "assert",
                    "inputs": ["apply/result.json"],
                    "assertion": "docs_apply_requires_manual_commit",
                },
            ]
        }
    },
}

# Top-level tables accepted in a config file; anything else is rejected by
# _build_config with a "Legacy or unsupported config keys" error.
ALLOWED_TOP_LEVEL_KEYS = {"general", "llm", "logging", "workflow", "modules"}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _parse_multiline_string(lines: list[str], index: int, initial: str) -> tuple[str, int]:
|
|
97
|
+
chunks: list[str] = []
|
|
98
|
+
value = initial[3:]
|
|
99
|
+
while True:
|
|
100
|
+
end_index = value.find('"""')
|
|
101
|
+
if end_index >= 0:
|
|
102
|
+
chunks.append(value[:end_index])
|
|
103
|
+
return "\n".join(chunks), index
|
|
104
|
+
chunks.append(value)
|
|
105
|
+
index += 1
|
|
106
|
+
if index >= len(lines):
|
|
107
|
+
raise HookError("Unterminated multiline string in TOML fallback parser")
|
|
108
|
+
value = lines[index]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _assign_path(root: dict[str, Any], path: list[str], value: Any, array_mode: bool = False) -> dict[str, Any]:
|
|
112
|
+
current: Any = root
|
|
113
|
+
for part in path[:-1]:
|
|
114
|
+
if isinstance(current, list):
|
|
115
|
+
if not current:
|
|
116
|
+
current.append({})
|
|
117
|
+
current = current[-1]
|
|
118
|
+
current = current.setdefault(part, {})
|
|
119
|
+
key = path[-1]
|
|
120
|
+
if array_mode:
|
|
121
|
+
items = current.setdefault(key, [])
|
|
122
|
+
if not isinstance(items, list):
|
|
123
|
+
raise HookError(f"Invalid array-of-table path: {'.'.join(path)}")
|
|
124
|
+
item: dict[str, Any] = {}
|
|
125
|
+
items.append(item)
|
|
126
|
+
return item
|
|
127
|
+
current[key] = value
|
|
128
|
+
return current
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _parse_scalar(raw: str) -> Any:
|
|
132
|
+
raw = raw.strip()
|
|
133
|
+
if raw.startswith('"') and raw.endswith('"'):
|
|
134
|
+
return raw[1:-1]
|
|
135
|
+
if raw in {"true", "false"}:
|
|
136
|
+
return raw == "true"
|
|
137
|
+
if re.fullmatch(r"-?\d+", raw):
|
|
138
|
+
return int(raw)
|
|
139
|
+
if raw.startswith("[") and raw.endswith("]"):
|
|
140
|
+
return json.loads(raw)
|
|
141
|
+
return raw
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def parse_toml_fallback(raw: str) -> dict[str, Any]:
    """Parse a small TOML subset, used when ``tomllib`` is unavailable.

    Supports ``[table]`` headers, ``[[array-of-table]]`` headers, comments,
    and ``key = value`` lines whose values are quoted strings, booleans,
    integers, JSON-style arrays, or triple-quoted multiline strings.
    Unrecognized line shapes are silently skipped.
    """
    parsed: dict[str, Any] = {}
    lines = raw.splitlines()
    current: Any = parsed  # table receiving subsequent key/value lines
    index = 0
    while index < len(lines):
        line = lines[index].strip()
        index += 1
        if not line or line.startswith("#"):
            continue
        # [[a.b]] header: append a fresh table to the array at that path and
        # make it current.
        if line.startswith("[[") and line.endswith("]]"):
            path = [part.strip() for part in line[2:-2].split(".") if part.strip()]
            current = _assign_path(parsed, path, None, array_mode=True)
            continue
        # [a.b] header: descend from the root, creating tables as needed.
        if line.startswith("[") and line.endswith("]"):
            path = [part.strip() for part in line[1:-1].split(".") if part.strip()]
            current = parsed
            for part in path:
                current = current.setdefault(part, {})
            continue
        if "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        value = value.strip()
        if value.startswith('"""'):
            # Hand the helper the (un-advanced) index of the current line; the
            # returned index determines where this loop resumes scanning.
            # NOTE(review): verify the helper returns an index PAST the
            # closing delimiter — resuming on the delimiter line would
            # re-scan it (infinite loop for single-line multiline values).
            parsed_value, index = _parse_multiline_string(lines, index - 1, value)
        else:
            parsed_value = _parse_scalar(value)
        current[key] = parsed_value
    return parsed
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict combining *base* with *override*.

    Nested dicts are merged recursively; any other value in *override*
    replaces the base value.  Neither argument is mutated, and the result
    shares no mutable state with either input (everything is deep-copied).
    """
    result = copy.deepcopy(base)
    for name, incoming in override.items():
        existing = result.get(name)
        # A dict-valued existing entry implies the key is present in result.
        if isinstance(existing, dict) and isinstance(incoming, dict):
            result[name] = deep_merge(existing, incoming)
        else:
            result[name] = copy.deepcopy(incoming)
    return result
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _normalize_step(raw: dict[str, Any]) -> StepConfig:
    """Validate one raw step table and convert it into a StepConfig.

    Raises HookError for an unknown step type, a missing id, a promptable
    step with no prompt source, or a missing per-type required field
    (collector/output/allow_paths/executor/assertion).
    """

    def _opt_str(key: str) -> str | None:
        # Optional scalar: absent/None -> None, anything else -> stripped str.
        # (Replaces nine copies of the same conditional, each of which also
        # called raw.get() twice.)
        value = raw.get(key)
        return None if value is None else str(value).strip()

    step_type = str(raw.get("type", "")).strip()
    if step_type not in SUPPORTED_STEP_TYPES:
        raise HookError(f"Unknown step type: {step_type}")
    step = StepConfig(
        id=str(raw.get("id", "")).strip(),
        type=step_type,
        inputs=tuple(str(item) for item in raw.get("inputs", []) or []),
        output=_opt_str("output"),
        schema=_opt_str("schema"),
        prompt=_opt_str("prompt"),
        prompt_file=_opt_str("prompt_file"),
        fallback_prompt_id=_opt_str("fallback_prompt_id"),
        collector=_opt_str("collector"),
        allow_paths=tuple(str(item) for item in raw.get("allow_paths", []) or []),
        executor=_opt_str("executor"),
        assertion=_opt_str("assertion"),
        when_env=_opt_str("when_env"),
    )
    if not step.id:
        raise HookError("Every workflow step requires a non-empty id")
    if step.is_promptable and not any([step.prompt, step.prompt_file, step.fallback_prompt_id]):
        raise HookError(f"Promptable step `{step.id}` requires prompt, prompt_file, or fallback_prompt_id")
    # Per-type required fields keep downstream executors free of None checks.
    if step.type == "collect" and not step.collector:
        raise HookError(f"Collect step `{step.id}` requires collector")
    if step.type == "llm" and not step.output:
        raise HookError(f"LLM step `{step.id}` requires output")
    if step.type == "apply" and not step.allow_paths:
        raise HookError(f"Apply step `{step.id}` requires allow_paths")
    if step.type == "exec" and not step.executor:
        raise HookError(f"Exec step `{step.id}` requires executor")
    if step.type == "assert" and not step.assertion:
        raise HookError(f"Assert step `{step.id}` requires assertion")
    return step
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _build_config(raw: dict[str, Any]) -> HookConfig:
    """Validate the merged raw mapping and build the typed HookConfig.

    Rejects unknown top-level keys, requires workflow.modules to name at
    least one module, and materializes only the modules the workflow
    references (other entries under [modules] are ignored).
    """
    unknown = set(raw) - ALLOWED_TOP_LEVEL_KEYS
    if unknown:
        raise HookError(
            "Legacy or unsupported config keys are not allowed: " + ", ".join(sorted(unknown))
        )

    workflow_modules = tuple(str(item) for item in raw.get("workflow", {}).get("modules", []) or [])
    if not workflow_modules:
        raise HookError("workflow.modules must define at least one module id")

    module_payload = raw.get("modules", {})
    if not isinstance(module_payload, dict):
        raise HookError("modules must be a table")

    modules: dict[str, ModuleConfig] = {}
    for module_id in workflow_modules:
        # Every referenced module must be declared, even if disabled.
        if module_id not in module_payload:
            raise HookError(f"workflow.modules references unknown module `{module_id}`")
        module_raw = module_payload[module_id]
        steps_raw = module_raw.get("steps", [])
        if not isinstance(steps_raw, list) or not steps_raw:
            raise HookError(f"Module `{module_id}` must define a non-empty steps array")
        modules[module_id] = ModuleConfig(
            id=module_id,
            enabled=bool(module_raw.get("enabled", True)),
            steps=tuple(_normalize_step(step) for step in steps_raw),
        )

    # NOTE(review): assumes the *Config constructors reject unexpected keys
    # (catching typos in these tables) — verify against types.py.
    general = GeneralConfig(**raw.get("general", {}))
    llm = LlmConfig(**raw.get("llm", {}))
    logging = LoggingConfig(**raw.get("logging", {}))
    return HookConfig(
        general=general,
        llm=llm,
        logging=logging,
        workflow=WorkflowConfig(modules=workflow_modules),
        modules=modules,
    )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _apply_env_overrides(config: HookConfig) -> HookConfig:
    """Apply AI_PUSH_HOOKS_* environment overrides on top of *config*.

    The typed config is flattened back into a raw mapping, mutated with any
    recognized environment values, then rebuilt through _build_config so the
    result goes through the same validation as file-based configuration.
    """
    raw = {
        "general": {
            "enabled": config.general.enabled,
            "allow_push_on_error": config.general.allow_push_on_error,
            "require_clean_worktree": config.general.require_clean_worktree,
            "skip_on_sync_branch": config.general.skip_on_sync_branch,
        },
        # NOTE(review): __dict__.copy() assumes plain (non-slots) instances
        # — verify against the dataclass definitions in types.py.
        "llm": config.llm.__dict__.copy(),
        "logging": config.logging.__dict__.copy(),
        "workflow": {"modules": list(config.workflow.modules)},
        "modules": {},
    }
    for module_id, module in config.modules.items():
        raw["modules"][module_id] = {
            "enabled": module.enabled,
            "steps": [step.__dict__.copy() for step in module.steps],
        }

    # A truthy AI_PUSH_HOOKS_SKIP disables the whole hook; a falsy or unset
    # value changes nothing (note the deliberate `is True`).
    skip = env_bool("AI_PUSH_HOOKS_SKIP")
    if skip is True:
        raw["general"]["enabled"] = False
    allow_on_error = env_bool("AI_PUSH_HOOKS_ALLOW_PUSH_ON_ERROR")
    if allow_on_error is not None:
        raw["general"]["allow_push_on_error"] = allow_on_error
    require_clean = env_bool("AI_PUSH_HOOKS_REQUIRE_CLEAN")
    if require_clean is not None:
        raw["general"]["require_clean_worktree"] = require_clean
    # ALLOW_DIRTY can only relax the clean-worktree requirement, never add it.
    allow_dirty = env_bool("AI_PUSH_HOOKS_ALLOW_DIRTY")
    if allow_dirty is True:
        raw["general"]["require_clean_worktree"] = False

    logging_level = os.getenv("AI_PUSH_HOOKS_LOG_LEVEL")
    if logging_level:
        raw["logging"]["level"] = logging_level.strip().lower()
    print_output = env_bool("AI_PUSH_HOOKS_PRINT_LLM_OUTPUT")
    if print_output is not None:
        raw["logging"]["print_llm_output"] = print_output
    model = os.getenv("AI_PUSH_HOOKS_MODEL")
    if model:
        raw["llm"]["model"] = model
    # Variant accepts empty string, hence `is not None` rather than truthiness.
    variant = os.getenv("AI_PUSH_HOOKS_VARIANT")
    if variant is not None:
        raw["llm"]["variant"] = variant.strip()
    timeout = os.getenv("AI_PUSH_HOOKS_TIMEOUT_SECONDS")
    if timeout:
        # NOTE(review): int() raises ValueError (not HookError) on malformed
        # values — confirm that is the intended failure mode.
        raw["llm"]["timeout_seconds"] = int(timeout)
    return _build_config(raw)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def load_config(repo_root: pathlib.Path) -> tuple[HookConfig, pathlib.Path | None]:
    """Load hook configuration for *repo_root*.

    Starts from DEFAULT_CONFIG_RAW, deep-merges the first config file found
    (`.ai-push-hooks.toml` preferred over `ai-push-hooks.toml`), then applies
    environment-variable overrides.  Returns the built config together with
    the path of the file that was merged (None when only defaults applied).
    """
    config_path: pathlib.Path | None = None
    raw = copy.deepcopy(DEFAULT_CONFIG_RAW)
    for candidate in [repo_root / ".ai-push-hooks.toml", repo_root / "ai-push-hooks.toml"]:
        if candidate.exists():
            config_path = candidate
            text = candidate.read_text(encoding="utf-8")
            # tomllib may be absent (module-level import is guarded); fall
            # back to the minimal line-oriented parser in this module.
            loaded = tomllib.loads(text) if tomllib is not None else parse_toml_fallback(text)
            if not isinstance(loaded, dict):
                raise HookError(f"Invalid config format in {candidate}")
            raw = deep_merge(raw, loaded)
            break
    return _apply_env_overrides(_build_config(raw)), config_path
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def resolve_prompt_text(repo_root: pathlib.Path, step: StepConfig) -> str:
    """Resolve the prompt text for *step*, trying sources in priority order.

    Order: inline ``prompt``; then ``prompt_file`` (relative paths resolved
    against *repo_root*); then the built-in prompt named by
    ``fallback_prompt_id``.  Raises HookError when no non-empty source exists.
    """
    if step.prompt and step.prompt.strip():
        return step.prompt.strip()
    if step.prompt_file:
        prompt_path = pathlib.Path(step.prompt_file)
        if not prompt_path.is_absolute():
            prompt_path = (repo_root / prompt_path).resolve()
        if prompt_path.exists():
            text = prompt_path.read_text(encoding="utf-8").strip()
            if text:
                return text
        # Prompt file missing or empty: fall back to a built-in if one is
        # configured, otherwise fail with the concrete path that was tried.
        if step.fallback_prompt_id:
            return resolve_builtin_prompt(step.fallback_prompt_id)
        raise HookError(f"Prompt file not found or empty for step `{step.id}`: {prompt_path}")
    if step.fallback_prompt_id:
        return resolve_builtin_prompt(step.fallback_prompt_id)
    raise HookError(f"No prompt source available for step `{step.id}`")
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def resolve_builtin_prompt(prompt_id: str) -> str:
    """Return the packaged prompt registered under *prompt_id*.

    Raises HookError when the id is unknown or maps to an empty prompt.
    """
    text = BUILTIN_PROMPTS.get(prompt_id)
    if text:
        return text
    raise HookError(f"Unknown built-in prompt id: {prompt_id}")
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import pathlib
|
|
4
|
+
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
from .artifacts import ArtifactStore
|
|
8
|
+
from .config import resolve_prompt_text
|
|
9
|
+
from .executors.apply import run_apply_step
|
|
10
|
+
from .executors.assertions import ASSERTION_HANDLERS
|
|
11
|
+
from .executors.exec import EXEC_HANDLERS, env_bool
|
|
12
|
+
from .executors.llm import run_llm_step
|
|
13
|
+
from .modules import COLLECTORS
|
|
14
|
+
from .types import CollectorResult, HookError, ModuleRuntimeState, RuntimeContext, StepConfig, StepResult, WorkflowRunResult
|
|
15
|
+
|
|
16
|
+
# Signatures for the pluggable handler registries the engine dispatches to.
# Collectors produce artifacts/metadata for a module; exec and assertion
# handlers receive the resolved artifact paths of the step's inputs.
CollectorHandler = Callable[[RuntimeContext, ModuleRuntimeState], CollectorResult]
ExecHandler = Callable[[RuntimeContext, ModuleRuntimeState, StepConfig, list[pathlib.Path]], dict[str, Any]]
AssertionHandler = Callable[[RuntimeContext, StepConfig, list[pathlib.Path]], dict[str, Any]]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class WorkflowEngine:
    """Schedules and executes the step pipelines of all enabled modules.

    Steps flagged read-only (StepConfig.is_read_only) may run concurrently,
    capped at llm.max_parallel; a non-read-only step runs exclusively, with
    nothing else in flight.  All handler registries and executors are
    injectable for testing.
    """

    def __init__(
        self,
        context: RuntimeContext,
        artifacts: ArtifactStore,
        collectors: dict[str, CollectorHandler] | None = None,
        exec_handlers: dict[str, ExecHandler] | None = None,
        assertion_handlers: dict[str, AssertionHandler] | None = None,
        llm_executor: Callable[[RuntimeContext, StepConfig, str, list[pathlib.Path], str], Any] = run_llm_step,
        apply_executor: Callable[[RuntimeContext, ModuleRuntimeState, StepConfig, str, list[pathlib.Path], str], dict[str, object]] = run_apply_step,
    ) -> None:
        self.context = context
        self.artifacts = artifacts
        # Fall back to the module-level registries when none are injected.
        self.collectors = collectors or COLLECTORS
        self.exec_handlers = exec_handlers or EXEC_HANDLERS
        self.assertion_handlers = assertion_handlers or ASSERTION_HANDLERS
        self.llm_executor = llm_executor
        self.apply_executor = apply_executor

    def run(self) -> WorkflowRunResult:
        """Drive every enabled workflow module to completion.

        Returns a WorkflowRunResult mapping module id to final status.
        Raises HookError on a scheduler deadlock; a step exception propagates
        after its module is marked failed.
        """
        self.artifacts.prepare()
        # Only modules that are referenced by the workflow AND enabled.
        states = [
            ModuleRuntimeState(module=self.context.config.modules[module_id])
            for module_id in self.context.config.workflow.modules
            if self.context.config.modules[module_id].enabled
        ]
        statuses: dict[str, str] = {state.module.id: "pending" for state in states}
        futures: dict[Future[StepResult], tuple[ModuleRuntimeState, StepConfig]] = {}

        with ThreadPoolExecutor(max_workers=max(1, self.context.config.llm.max_parallel)) as pool:
            while True:
                # Scheduling pass: submit every step that may legally start.
                for state in states:
                    if state.status in {"completed", "failed"}:
                        statuses[state.module.id] = state.status
                        continue
                    if state.active_step_id is not None:
                        continue  # module already has a step in flight
                    step = state.next_step
                    if step is None:
                        state.status = "completed"
                        statuses[state.module.id] = "completed"
                        continue
                    # A mutating step may only start with nothing in flight.
                    if futures and not step.is_read_only:
                        continue
                    # Nothing may start while a mutating step is running.
                    if any(not running_step.is_read_only for _future, (_state, running_step) in futures.items()):
                        continue
                    # NOTE(review): duplicates the `futures and not
                    # step.is_read_only` guard above; kept as written.
                    if not step.is_read_only and futures:
                        continue
                    # Cap concurrent read-only steps at the LLM parallelism.
                    if step.is_read_only and len(futures) >= max(1, self.context.config.llm.max_parallel):
                        continue
                    future = pool.submit(self._execute_step, state, step)
                    futures[future] = (state, step)
                    state.active_step_id = step.id
                    state.status = "running"
                    if not step.is_read_only:
                        break  # exclusive step submitted; stop scheduling

                if not futures:
                    if all(state.status == "completed" for state in states):
                        break
                    # Nothing in flight yet work remains: no progress possible.
                    pending = [state.module.id for state in states if state.status not in {"completed", "failed"}]
                    raise HookError("Scheduler deadlock while running modules: " + ", ".join(pending))

                # Reap at least one finished step before the next pass.
                done, _ = wait(set(futures), return_when=FIRST_COMPLETED)
                for future in done:
                    state, step = futures.pop(future)
                    state.active_step_id = None
                    try:
                        result = future.result()
                    except Exception as exc:  # noqa: BLE001
                        # Mark the owning module failed, then let the error
                        # abort the whole workflow run.
                        state.status = "failed"
                        state.error = str(exc)
                        raise

                    state.metadata.update(result.metadata)
                    state.step_index += 1
                    if result.metadata.get("skip_module"):
                        # The step asked to skip the rest of the module.
                        state.step_index = len(state.module.steps)
                        state.status = "completed"
                    elif state.next_step is None:
                        state.status = "completed"
                    else:
                        state.status = "pending"
                    statuses[state.module.id] = state.status

        return WorkflowRunResult(run_dir=self.artifacts.run_dir, modules=statuses)

    def _execute_step(self, state: ModuleRuntimeState, step: StepConfig) -> StepResult:
        """Execute one step according to its type and persist its artifacts."""
        # `when_env` gates a step behind an explicitly-truthy env variable.
        if step.when_env and env_bool(step.when_env) is not True:
            payload = {"skipped": True, "reason": f"{step.when_env} not enabled"}
            path = self.artifacts.write_json(state, state.step_index, step.id, "result.json", payload)
            return StepResult(status="skipped", artifacts={"result.json": path}, metadata={})

        if step.type == "collect":
            return self._run_collect(state, step)

        # Resolve "step-id/artifact" references into concrete file paths.
        input_paths = [self.artifacts.resolve_input(state, reference) for reference in step.inputs]
        stage_name = f"{state.module.id}.{step.id}"

        if step.type == "llm":
            prompt = resolve_prompt_text(self.context.repo_root, step)
            payload = self.llm_executor(self.context, step, prompt, input_paths, stage_name)
            artifact_name = step.output or "result.json"
            # Structured payloads (or .json targets) are serialized as JSON;
            # everything else is written as text.
            if isinstance(payload, (dict, list)) or artifact_name.endswith(".json"):
                path = self.artifacts.write_json(state, state.step_index, step.id, artifact_name, payload)
            else:
                path = self.artifacts.write_text(state, state.step_index, step.id, artifact_name, str(payload))
            return StepResult(artifacts={artifact_name: path})

        if step.type == "apply":
            prompt = resolve_prompt_text(self.context.repo_root, step)
            payload = self.apply_executor(self.context, state, step, prompt, input_paths, stage_name)
            path = self.artifacts.write_json(state, state.step_index, step.id, "result.json", payload)
            return StepResult(artifacts={"result.json": path})

        if step.type == "exec":
            handler = self.exec_handlers.get(step.executor or "")
            if handler is None:
                raise HookError(f"Unknown exec handler: {step.executor}")
            payload = handler(self.context, state, step, input_paths)
            path = self.artifacts.write_json(state, state.step_index, step.id, "result.json", payload)
            return StepResult(artifacts={"result.json": path})

        if step.type == "assert":
            handler = self.assertion_handlers.get(step.assertion or "")
            if handler is None:
                raise HookError(f"Unknown assertion handler: {step.assertion}")
            payload = handler(self.context, step, input_paths)
            path = self.artifacts.write_json(state, state.step_index, step.id, "result.json", payload)
            # The result is persisted before a failing assertion aborts.
            if not bool(payload.get("ok", False)):
                raise HookError(str(payload.get("message", "assertion failed")))
            return StepResult(artifacts={"result.json": path})

        raise HookError(f"Unsupported step type: {step.type}")

    def _run_collect(self, state: ModuleRuntimeState, step: StepConfig) -> StepResult:
        """Run a collector and write each of its artifacts to the store."""
        handler = self.collectors.get(step.collector or "")
        if handler is None:
            raise HookError(f"Unknown collector: {step.collector}")
        result = handler(self.context, state)
        artifacts: dict[str, pathlib.Path] = {}
        for artifact_name, payload in result.artifacts.items():
            # Same JSON-vs-text convention as the llm branch above.
            if isinstance(payload, (dict, list)) or artifact_name.endswith(".json"):
                path = self.artifacts.write_json(state, state.step_index, step.id, artifact_name, payload)
            else:
                path = self.artifacts.write_text(state, state.step_index, step.id, artifact_name, str(payload))
            artifacts[artifact_name] = path
        metadata = dict(result.metadata)
        # Propagate the collector's skip request to the scheduler in run().
        if result.skip_module:
            metadata["skip_module"] = True
            metadata["skip_reason"] = result.skip_reason
        return StepResult(artifacts=artifacts, metadata=metadata)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Workflow executors."""
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import pathlib
|
|
5
|
+
|
|
6
|
+
from ..types import HookError, ModuleRuntimeState, RuntimeContext, StepConfig
|
|
7
|
+
from .exec import list_repo_changes, path_matches
|
|
8
|
+
from .llm import call_opencode, finalize_opencode_session
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def run_apply_step(
    context: RuntimeContext,
    state: ModuleRuntimeState,
    step: StepConfig,
    prompt: str,
    input_paths: list[pathlib.Path],
    stage_name: str,
) -> dict[str, object]:
    """Run an `apply` step: let the LLM edit the worktree, then audit edits.

    Short-circuits when an upstream issues artifact is an empty list.  After
    the LLM session, any changed file outside `step.allow_paths` aborts the
    step with HookError.  Returns a summary payload for result.json.
    """
    # Nothing to fix: skip entirely if an *issues.json input holds an empty list.
    for input_path in input_paths:
        if input_path.name.endswith("issues.json"):
            issues = json.loads(input_path.read_text(encoding="utf-8"))
            if isinstance(issues, list) and not issues:
                return {"changed": False, "changed_files": [], "skipped": True}

    # Snapshot changed paths so post-run diffs can be attributed to the LLM.
    baseline = list_repo_changes(context.repo_root)
    files = list(input_paths)
    # AGENTS.md, when present, is handed to the model as extra repo guidance.
    agents = context.repo_root / "AGENTS.md"
    if agents.exists():
        files.append(agents)

    result = call_opencode(
        context,
        stage_name=stage_name,
        purpose=f"{step.type}:{step.id}",
        prompt=prompt,
        files=files,
    )
    # Session cleanup happens before the outcome checks below, so it runs
    # even when the call failed.
    finalize_opencode_session(context, stage_name, result.session_id)
    if result.return_code != 0:
        details = result.stderr.strip() or result.stdout.strip() or f"exit code {result.return_code}"
        raise HookError(f"Apply step failed: {details}")

    after = list_repo_changes(context.repo_root)
    changed_files = sorted(after - baseline)
    # Reject any modification escaping the configured allowlist patterns.
    unexpected = [
        path for path in changed_files if not any(path_matches(path, pattern) for pattern in step.allow_paths)
    ]
    if unexpected:
        raise HookError("Apply step modified files outside allowlist: " + ", ".join(unexpected))
    return {
        "changed": bool(changed_files),
        "changed_files": changed_files,
        "allowed_paths": list(step.allow_paths),
        "skipped": False,
    }
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import pathlib
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ..types import RuntimeContext, StepConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def docs_apply_requires_manual_commit(
    _context: RuntimeContext,
    _step: StepConfig,
    inputs: list[pathlib.Path],
) -> dict[str, Any]:
    """Fail the assertion when the apply step left uncommitted doc edits.

    Reads the apply step's result payload from ``inputs[0]``; a non-empty
    ``changed_files`` list means the push must stop so the author can review
    and commit the generated documentation changes.
    """
    result = json.loads(inputs[0].read_text(encoding="utf-8"))
    touched = result.get("changed_files", [])
    if not touched:
        return {"ok": True, "message": "", "changed_files": touched}
    return {
        "ok": False,
        "message": "Documentation updates were applied; review and commit them before pushing again.",
        "changed_files": touched,
    }
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def beads_alignment_clean(
    _context: RuntimeContext,
    _step: StepConfig,
    inputs: list[pathlib.Path],
) -> dict[str, Any]:
    """Fail the assertion when the beads alignment report is unresolved.

    Reads a JSON payload from ``inputs[0]``; a truthy ``unresolved`` flag
    blocks the push until the operator intervenes.
    """
    report = json.loads(inputs[0].read_text(encoding="utf-8"))
    if bool(report.get("unresolved", False)):
        return {
            "ok": False,
            "message": "Beads alignment requires manual action before push.",
        }
    return {"ok": True, "message": ""}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Registry mapping a step's `assertion` config value to its handler; the
# workflow engine looks assertions up here by name.
ASSERTION_HANDLERS = {
    "docs_apply_requires_manual_commit": docs_apply_requires_manual_commit,
    "beads_alignment_clean": beads_alignment_clean,
}
|