capt-hook 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- capt_hook-0.2.0.dist-info/METADATA +113 -0
- capt_hook-0.2.0.dist-info/RECORD +57 -0
- capt_hook-0.2.0.dist-info/WHEEL +4 -0
- capt_hook-0.2.0.dist-info/entry_points.txt +3 -0
- capt_hook-0.2.0.dist-info/licenses/LICENSE +73 -0
- captain_hook/__init__.py +246 -0
- captain_hook/__main__.py +6 -0
- captain_hook/app.py +278 -0
- captain_hook/classifiers/__init__.py +30 -0
- captain_hook/classifiers/conductor.py +35 -0
- captain_hook/classifiers/droid.py +20 -0
- captain_hook/classifiers/native.py +19 -0
- captain_hook/cli.py +341 -0
- captain_hook/command.py +356 -0
- captain_hook/conditions.py +136 -0
- captain_hook/context.py +161 -0
- captain_hook/dispatch.py +107 -0
- captain_hook/events.py +318 -0
- captain_hook/file.py +120 -0
- captain_hook/llm/__init__.py +9 -0
- captain_hook/llm/backends.py +152 -0
- captain_hook/loader.py +62 -0
- captain_hook/log.py +60 -0
- captain_hook/primitives/__init__.py +51 -0
- captain_hook/primitives/audit.py +71 -0
- captain_hook/primitives/commands.py +61 -0
- captain_hook/primitives/lint.py +216 -0
- captain_hook/primitives/llm.py +376 -0
- captain_hook/primitives/nudge.py +95 -0
- captain_hook/prompt.py +103 -0
- captain_hook/py.typed +1 -0
- captain_hook/session.py +158 -0
- captain_hook/settings.py +120 -0
- captain_hook/signals/__init__.py +86 -0
- captain_hook/signals/nlp.py +105 -0
- captain_hook/state.py +221 -0
- captain_hook/styleguide/__init__.py +183 -0
- captain_hook/styleguide/query.py +238 -0
- captain_hook/styleguide/scope.py +46 -0
- captain_hook/styleguide/types.py +70 -0
- captain_hook/tasks.py +112 -0
- captain_hook/templates/example_hook.py.tmpl +85 -0
- captain_hook/testing/__init__.py +10 -0
- captain_hook/testing/helpers.py +392 -0
- captain_hook/testing/session_cache.py +50 -0
- captain_hook/testing/types.py +88 -0
- captain_hook/tests/__init__.py +27 -0
- captain_hook/tests/helpers.py +361 -0
- captain_hook/tools.py +59 -0
- captain_hook/transcript/__init__.py +572 -0
- captain_hook/transcript/inputs.py +226 -0
- captain_hook/transcript/models.py +186 -0
- captain_hook/types.py +381 -0
- captain_hook/util/__init__.py +0 -0
- captain_hook/util/model_cache.py +87 -0
- captain_hook/utils.py +27 -0
- captain_hook/workflow.py +119 -0
captain_hook/session.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import shutil
|
|
6
|
+
import tempfile
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from hashlib import sha256
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import ClassVar, Generic, TypeVar, overload
|
|
11
|
+
|
|
12
|
+
from loguru import logger
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
|
|
15
|
+
M = TypeVar("M", bound=BaseModel)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def state_root() -> Path:
    """Return the root directory under which hook state is stored.

    The lookup is delegated to ``captain_hook.settings.resolve_state_dir``.
    The import happens at call time rather than at module import.
    """
    from captain_hook.settings import resolve_state_dir

    root = resolve_state_dir()
    return root
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def session_hash(transcript_path: str | Path) -> str:
    """Return a short, stable identifier for a transcript path.

    Args:
        transcript_path: Path of the transcript; it is stringified as given
            (no normalisation or resolution is applied).

    Returns:
        The first 12 hex characters of the SHA-256 digest of the path string.
    """
    digest = sha256(str(transcript_path).encode())
    return digest.hexdigest()[:12]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def ensure_session(transcript_path: str | Path) -> Path:
    """Create (if needed) and return the session directory for a transcript.

    The directory lives at ``<state_root>/hooks/sessions/<session_hash>``. On
    first creation a ``.transcript_path`` marker holding the transcript path is
    written; an existing marker is never overwritten.

    Args:
        transcript_path: Path of the transcript the session belongs to.

    Returns:
        The session directory.
    """
    session_dir = state_root().joinpath("hooks", "sessions", session_hash(transcript_path))
    session_dir.mkdir(parents=True, exist_ok=True)

    marker = session_dir / ".transcript_path"
    if marker.exists():
        return session_dir

    marker.write_text(str(transcript_path))
    return session_dir
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def cleanup_stale() -> None:
    """Remove session directories whose recorded transcript no longer exists.

    Every directory under ``<state_root>/hooks/sessions`` is inspected. A
    directory is deleted only when its ``.transcript_path`` marker can be read
    and the path it names is missing. Directories without a readable marker are
    left alone, so one damaged or concurrently removed marker cannot abort the
    sweep of the remaining sessions.
    """
    sessions = state_root() / "hooks" / "sessions"
    # ``is_dir`` also covers the case where the path exists but is not a directory.
    if not sessions.is_dir():
        return

    for sd in sessions.iterdir():
        if not sd.is_dir():
            continue
        try:
            recorded = (sd / ".transcript_path").read_text().strip()
        except (OSError, ValueError):
            # Missing, unreadable or undecodable marker: staleness cannot be
            # established, so keep the directory and move on.
            continue
        if not Path(recorded).exists():
            shutil.rmtree(sd, ignore_errors=True)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SessionSlot(Generic[M]):  # noqa: UP046
    """A typed slot for reading/writing a single Pydantic model in a session directory.

    The slot is backed by ``<session_dir>/<model_key>.json``. When no session
    directory is given the slot has no path: reads return the default and
    writes/deletes are no-ops.
    """

    def __init__(self, session_dir: Path | None, model: type[M]) -> None:
        """Bind the slot to ``model`` inside ``session_dir`` (or to nothing when it is ``None``)."""
        self._model = model
        self._path = (session_dir / f"{self.model_key(model)}.json") if session_dir else None

    @staticmethod
    def model_key(model: type[BaseModel]) -> str:
        """Return the snake_case file stem derived from the model's class name."""
        # Insert "_" before every capital letter except one at the very start.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", model.__name__).lower()

    @property
    def path(self) -> Path | None:
        """The JSON file backing this slot, or ``None`` when there is no session directory."""
        return self._path

    @overload
    def get(self) -> M | None: ...
    @overload
    def get(self, default: M) -> M: ...
    def get(self, default: M | None = None) -> M | None:
        """Load the stored model.

        Args:
            default: Value returned when nothing is stored or the stored data
                cannot be read/validated.

        Returns:
            The validated model instance, or ``default``.
        """
        if not self._path or not self._path.exists():
            return default
        try:
            # ``set`` writes UTF-8 bytes, so decode as UTF-8 rather than with
            # the platform's locale encoding.
            return self._model.model_validate_json(self._path.read_text(encoding="utf-8"))
        except Exception:
            logger.bind(model=self._model.__name__, path=str(self._path)).opt(exception=True).warning(
                "failed to read session state",
            )
            return default

    def set(self, obj: M) -> None:
        """Persist ``obj`` by writing a temp file and renaming it over the slot file.

        ``OSError`` is logged and swallowed (persistence is best-effort); any
        other exception propagates after the temp file has been removed.
        """
        if not self._path:
            return
        try:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            tmp_fd, tmp_name = tempfile.mkstemp(
                dir=self._path.parent,
                suffix=".tmp",
            )
            try:
                # The file object owns the descriptor: it is closed exactly once
                # on every path, and ``write`` loops until all bytes are written.
                with os.fdopen(tmp_fd, "wb") as tmp:
                    tmp.write(obj.model_dump_json().encode())
                os.replace(tmp_name, self._path)
            except BaseException:
                Path(tmp_name).unlink(missing_ok=True)
                raise
        except OSError:
            logger.bind(path=str(self._path)).opt(exception=True).warning("failed to persist session state")

    def delete(self) -> None:
        """Remove the slot file if the slot has a path; a missing file is not an error."""
        if self._path:
            self._path.unlink(missing_ok=True)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class SessionStore:
    """Class-keyed store providing typed ``SessionSlot`` access via ``store[ModelClass]``.

    ``TRACKED`` is a class-level registry shared by all stores; it is filled
    through ``track`` and consulted by ``tracked_models`` / ``tracked_paths``.
    """

    TRACKED: ClassVar[list[type[BaseModel]]] = []

    def __init__(self, session_dir: Path | None) -> None:
        """Remember the session directory that slots will be created in."""
        self._dir = session_dir

    def __getitem__(self, model: type[M]) -> SessionSlot[M]:
        """Return a fresh slot for ``model`` in this store's directory."""
        return SessionSlot(self._dir, model)

    def load(self, model: type[M]) -> M:
        """Read ``model`` from its session slot, defaulting to a fresh ``model()``.

        Args:
            model: The Pydantic model class to read.

        Returns:
            The persisted instance, or a newly constructed ``model()`` when the
            slot yields nothing.
        """
        slot = self[model]
        fallback = model()
        return slot.get(fallback)

    @classmethod
    def track(cls, model: type[BaseModel]) -> None:
        """Register ``model`` so it appears in ``tracked_models()`` and ``tracked_paths()``."""
        if model in cls.TRACKED:
            return
        cls.TRACKED.append(model)

    @classmethod
    def untrack(cls, model: type[BaseModel]) -> None:
        """Reverse ``track`` — primarily for test isolation."""
        if model not in cls.TRACKED:
            return
        cls.TRACKED.remove(model)

    @classmethod
    def tracked_models(cls) -> Sequence[type[BaseModel]]:
        """Return the registered tracked-state models as an immutable tuple."""
        return tuple(cls.TRACKED)

    def tracked_paths(self) -> dict[str, Path]:
        """Return ``{ModelClass.__name__: Path}`` for every tracked model whose slot has a path."""
        paths: dict[str, Path] = {}
        for tracked in type(self).TRACKED:
            slot_path = self[tracked].path
            if slot_path:
                paths[tracked.__name__] = slot_path
        return paths
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def session_state[T: BaseModel](cls: type[T]) -> type[T]:
|
|
150
|
+
"""Decorator that registers a Pydantic model for collective ``SessionStore`` introspection.
|
|
151
|
+
|
|
152
|
+
Example:
|
|
153
|
+
>>> @session_state
|
|
154
|
+
... class Snapshot(BaseModel):
|
|
155
|
+
... op_id: str
|
|
156
|
+
"""
|
|
157
|
+
SessionStore.track(cls)
|
|
158
|
+
return cls
|
captain_hook/settings.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import types
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, get_type_hints
|
|
7
|
+
|
|
8
|
+
from pydantic import Field
|
|
9
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
10
|
+
|
|
11
|
+
# Value types for which ``AutoConf`` infers a field type from an unannotated default.
INFERRABLE_PRIMITIVES = (str, int, float, bool)

# Default for ``HooksSettings.planning_agents``.
DEFAULT_PLANNING_AGENTS = [
    "Explore",
    "Plan",
    "general-purpose",
    "explore",
    "plan",
    "web-analyzer",
    "search-specialist",
    "claude-code-guide",
    "context-manager",
    "sentry-error-debugger",
    "logfire-trace-debugger",
    "sentry:issue-summarizer",
]

# Default for ``HooksSettings.waiting_tools``.
DEFAULT_WAITING_TOOLS = [
    "Monitor",
    "TeamCreate",
    "ScheduleWakeup",
    "SendMessage",
]

DEFAULT_STATE_DIR = Path.home() / ".claude" / "state"
# XDG: an unset *or empty* XDG_CACHE_HOME means "use ~/.cache". ``.get(name, default)``
# would turn an empty value into ``Path("")``, i.e. the current working directory.
DEFAULT_LOG_DIR = Path(os.environ.get("XDG_CACHE_HOME") or (Path.home() / ".cache")) / "captain-hook" / "logs"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def resolve_state_dir() -> Path:
    """Return the state directory.

    The first non-empty value among ``CAPTAIN_HOOK_STATE_DIR`` and
    ``CLAUDE_HOOKS_STATE_DIR`` wins; otherwise ``DEFAULT_STATE_DIR`` is used.
    """
    for variable in ("CAPTAIN_HOOK_STATE_DIR", "CLAUDE_HOOKS_STATE_DIR"):
        if configured := os.environ.get(variable):
            return Path(configured)
    return Path(DEFAULT_STATE_DIR)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def resolve_log_dir() -> Path:
    """Return the log directory: ``CAPTAIN_HOOK_LOG_DIR`` when set and non-empty, else ``DEFAULT_LOG_DIR``."""
    if configured := os.environ.get("CAPTAIN_HOOK_LOG_DIR"):
        return Path(configured)
    return Path(DEFAULT_LOG_DIR)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class HooksSettings(BaseSettings):
    """Base settings class for hook configuration, backed by environment variables with ``HOOKS_`` prefix."""

    model_config = SettingsConfigDict(env_prefix="HOOKS_")

    # List defaults are copied per instance so the module-level lists are never shared.
    planning_agents: list[str] = Field(
        default_factory=lambda: list(DEFAULT_PLANNING_AGENTS),
    )
    waiting_tools: list[str] = Field(
        default_factory=lambda: list(DEFAULT_WAITING_TOOLS),
    )

    # Directory defaults are resolved when an instance is created, not at import.
    state_dir: Path = Field(default_factory=resolve_state_dir)
    log_dir: Path = Field(default_factory=resolve_log_dir)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class AutoConf:
    """Automatic settings builder that infers a ``HooksSettings`` subclass from a conf module's attributes."""

    @staticmethod
    def should_skip(name: str, val: Any) -> bool:
        """Return True for names/values that never become settings fields.

        Skipped: private names, ALL-CAPS names, callables and modules.
        """
        if name.startswith("_") or name.isupper():
            return True
        return callable(val) or isinstance(val, types.ModuleType)

    @staticmethod
    def find_settings_class(module: types.ModuleType) -> type[HooksSettings] | None:
        """Return the first proper ``HooksSettings`` subclass found in ``module``, if any."""
        for candidate in vars(module).values():
            if not isinstance(candidate, type):
                continue
            if candidate is not HooksSettings and issubclass(candidate, HooksSettings):
                return candidate
        return None

    @staticmethod
    def build_settings(module: types.ModuleType, prefix: str = "HOOKS_") -> BaseSettings:
        """Build a settings instance for ``module``.

        An explicit ``HooksSettings`` subclass in the module is instantiated
        directly. Otherwise an ``AutoSettings`` subclass is synthesised from the
        module's attributes: annotated names keep their annotation, unannotated
        lists/tuples/primitives get an inferred type, and ``None``, dicts and
        sets are ignored.

        Args:
            module: The conf module to inspect.
            prefix: Environment-variable prefix for the synthesised class.

        Returns:
            An instance of the explicit or synthesised settings class.
        """
        explicit = AutoConf.find_settings_class(module)
        if explicit:
            return explicit()

        try:
            hints = get_type_hints(module)
        except Exception:
            # Unresolvable annotations: fall back to pure value-based inference.
            hints = {}

        unskipped = {k for k in vars(module) if not AutoConf.should_skip(k, getattr(module, k))}
        candidates = sorted(set(hints.keys()) | unskipped)

        fields: dict[str, tuple[type, Any]] = {}
        for name in candidates:
            val = getattr(module, name, None)
            if AutoConf.should_skip(name, val):
                continue
            if val is None or isinstance(val, (dict, set)):
                continue

            annotated = name in hints
            if isinstance(val, list):
                # Lists are copied per instance; ``v=val`` binds the current value.
                field_type = hints[name] if annotated else list
                fields[name] = (field_type, Field(default_factory=lambda v=val: list(v)))
            elif annotated:
                fields[name] = (hints[name], Field(default=val))
            elif isinstance(val, tuple):
                fields[name] = (tuple, Field(default_factory=lambda v=val: tuple(v)))
            elif isinstance(val, INFERRABLE_PRIMITIVES):
                fields[name] = (type(val), Field(default=val))

        namespace: dict[str, Any] = {
            "__annotations__": {k: t for k, (t, _) in fields.items()},
            "model_config": SettingsConfigDict(env_prefix=prefix),
            **{k: fd for k, (_, fd) in fields.items()},
        }
        auto_cls = type("AutoSettings", (HooksSettings,), namespace)
        return auto_cls()
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def build_settings(module: types.ModuleType, prefix: str = "HOOKS_") -> BaseSettings:
    """Build a settings instance from a conf module.

    Thin module-level alias for ``AutoConf.build_settings``.

    Args:
        module: The conf module to inspect.
        prefix: Environment-variable prefix passed through to ``AutoConf``.

    Returns:
        The settings instance produced by ``AutoConf.build_settings``.
    """
    settings = AutoConf.build_settings(module, prefix)
    return settings
|
|
captain_hook/signals/__init__.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from captain_hook.signals.nlp import NlpSignal
|
|
8
|
+
from captain_hook.types import Event, Signal, Signals
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from captain_hook.events import BaseHookEvent
|
|
12
|
+
|
|
13
|
+
TSignalPattern = Signal | NlpSignal
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def score_signals(patterns: Sequence[TSignalPattern], text: str) -> int:
    """Return the summed weight of every pattern that matches ``text``.

    An ``NlpSignal`` matches when ``nlp_scan`` returns at least one sentence
    for its clauses; a ``Signal`` matches when its regex is found anywhere in
    ``text``. Anything else contributes nothing.
    """
    from captain_hook.signals.nlp import nlp_scan

    score = 0
    for candidate in patterns:
        if isinstance(candidate, NlpSignal) and nlp_scan(candidate.clauses, text):
            score += candidate.weight
        elif isinstance(candidate, Signal) and re.search(candidate.pattern, text, candidate.flags):
            score += candidate.weight
    return score
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def extract_signal_context(patterns: Sequence[TSignalPattern], text: str) -> list[str]:
    """Collect the pieces of ``text`` that each pattern matched.

    For an ``NlpSignal`` these are the sentences returned by ``nlp_scan``; for
    a ``Signal`` they are the individual lines of ``text`` in which the regex
    is found. Results are concatenated in pattern order.
    """
    from captain_hook.signals.nlp import nlp_scan

    matched: list[str] = []
    for candidate in patterns:
        if isinstance(candidate, NlpSignal):
            matched.extend(nlp_scan(candidate.clauses, text))
        elif isinstance(candidate, Signal):
            for line in text.splitlines():
                if re.search(candidate.pattern, line, candidate.flags):
                    matched.append(line)
    return matched
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def transcript_texts(evt: BaseHookEvent, window: int) -> list[str]:
    """Extract text from recent transcript messages for signal scoring.

    For ``UserPromptSubmit`` events with a non-empty prompt, returns just the
    user prompt. Otherwise returns the non-empty ``.text`` of the last
    ``window`` messages.
    """
    if evt.event == Event.UserPromptSubmit and evt.user_prompt:
        return [evt.user_prompt]

    recent = evt.ctx.t.recent(window)
    return [msg.text for msg in recent.messages if msg.text]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def cite_message(sig: Signals, triggering: list[str], message: str) -> str:
    """Append trigger context to a message when signal matches are found.

    The triggering texts are joined with newlines and scanned with
    ``extract_signal_context``; when nothing matches, ``message`` is returned
    unchanged.
    """
    haystack = "\n".join(triggering)
    context = extract_signal_context(sig.patterns, haystack)
    if not context:
        return message
    return f"{message}\n\nTriggered by: {'; '.join(context)}"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def resolve_signals(signals: Sequence[Signal | NlpSignal] | Signals | None) -> Signals | None:
    """Normalize signals input into a ``Signals`` bundle, or None.

    ``None`` and existing ``Signals`` bundles pass through untouched. A bare
    sequence of signals is wrapped with ``threshold=1`` (any single match
    triggers).
    """
    if signals is None or isinstance(signals, Signals):
        return signals
    wrapped = Signals(patterns=list(signals), threshold=1)
    return wrapped
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
__all__ = [
|
|
79
|
+
"Signal",
|
|
80
|
+
"Signals",
|
|
81
|
+
"cite_message",
|
|
82
|
+
"extract_signal_context",
|
|
83
|
+
"resolve_signals",
|
|
84
|
+
"score_signals",
|
|
85
|
+
"transcript_texts",
|
|
86
|
+
]
|
|
captain_hook/signals/nlp.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from spacy.tokens import Doc, Span, Token
|
|
10
|
+
|
|
11
|
+
__all__ = ["Clause", "NlpSignal", "Phrase", "dep_related", "nlp_scan"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True, slots=True, init=False)
|
|
15
|
+
class Phrase:
|
|
16
|
+
lemmas: tuple[str, ...]
|
|
17
|
+
|
|
18
|
+
def __init__(self, *terms: str) -> None:
|
|
19
|
+
object.__setattr__(self, "lemmas", tuple(t.lower() for t in terms))
|
|
20
|
+
|
|
21
|
+
@classmethod
|
|
22
|
+
def expand(cls, *terms: str, pos: str = "n") -> Phrase:
|
|
23
|
+
from captain_hook.state import RESOURCES
|
|
24
|
+
|
|
25
|
+
return cls(
|
|
26
|
+
*{
|
|
27
|
+
lemma.replace("_", " ")
|
|
28
|
+
for term in terms
|
|
29
|
+
for ss in RESOURCES.wn.synsets(term, pos=pos)
|
|
30
|
+
for lemma in ss.lemmas()
|
|
31
|
+
}
|
|
32
|
+
| {t.lower() for t in terms}
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True, slots=True)
|
|
37
|
+
class Clause:
|
|
38
|
+
noun: Phrase
|
|
39
|
+
verb: Phrase | None = None
|
|
40
|
+
adj: Phrase | None = None
|
|
41
|
+
negated: bool = False
|
|
42
|
+
|
|
43
|
+
def __post_init__(self) -> None:
|
|
44
|
+
if not self.verb and not self.adj and not self.negated and not any(" " in lemma for lemma in self.noun.lemmas):
|
|
45
|
+
raise ValueError("Clause needs verb, adj, negated, or a compound noun phrase")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True, kw_only=True)
|
|
49
|
+
class NlpSignal:
|
|
50
|
+
clauses: Sequence[Clause]
|
|
51
|
+
weight: int = 1
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@functools.lru_cache
def parse(text: str) -> Doc:
    """Run the shared spaCy pipeline over ``text``.

    Results are memoised per distinct ``text`` by ``functools.lru_cache`` with
    its default size. The pipeline is fetched from ``RESOURCES`` at call time.
    """
    from captain_hook.state import RESOURCES

    pipeline = RESOURCES.spacy
    return pipeline(text)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def ancestors(tok: Token, max_hops: int) -> set[Token]:
    """Return ``tok`` together with up to ``max_hops`` of its successive heads.

    The walk stops early at a token that is its own head.
    """
    chain = {tok}
    current = tok
    hops = 0
    while hops < max_hops:
        if current == current.head:
            break
        current = current.head
        chain.add(current)
        hops += 1
    return chain
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def dep_related(a: Token, b: Token, max_hops: int = 3) -> bool:
    """Return True when ``a`` and ``b`` share a token within ``max_hops`` head steps of each."""
    reachable_from_a = ancestors(a, max_hops)
    reachable_from_b = ancestors(b, max_hops)
    return not reachable_from_a.isdisjoint(reachable_from_b)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def find_lemma_matches(phrase: Phrase, sent: Span, pos: set[str]) -> list[Token]:
    """Return the tokens of ``sent`` that match any lemma of ``phrase``.

    A token matches a lemma when its part of speech is in ``pos`` and its
    lower-cased lemma equals the lemma's last word. For multi-word lemmas,
    every preceding word must also appear among the token's ``compound``
    children. Results are ordered by phrase lemma, then by token position;
    empty lemmas are skipped.
    """
    hits: list[Token] = []
    for lemma in phrase.lemmas:
        words = lemma.split()
        if not words:
            continue
        *modifiers, head = words
        for tok in sent:
            if tok.pos_ not in pos:
                continue
            if tok.lemma_.lower() != head:
                continue
            if modifiers:
                compounds = {child.lemma_.lower() for child in tok.children if child.dep_ == "compound"}
                if not all(word in compounds for word in modifiers):
                    continue
            hits.append(tok)
    return hits
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def match_clause(clause: Clause, sent: Span) -> bool:
    """Return True when some noun token in ``sent`` satisfies every part of ``clause``.

    For a candidate noun token, each constraint present on the clause must
    hold: a verb match dependency-related to the noun, an adjective match
    (ADJ/ADV/PART) dependency-related to the noun, and — when ``negated`` — a
    ``neg`` token dependency-related to the noun.
    """
    for noun_tok in find_lemma_matches(clause.noun, sent, {"NOUN", "PROPN"}):
        if clause.verb:
            verbs = find_lemma_matches(clause.verb, sent, {"VERB"})
            if not any(dep_related(noun_tok, v) for v in verbs):
                continue
        if clause.adj:
            modifiers = find_lemma_matches(clause.adj, sent, {"ADJ", "ADV", "PART"})
            if not any(dep_related(noun_tok, a) for a in modifiers):
                continue
        if clause.negated:
            if not any(t.dep_ == "neg" and dep_related(noun_tok, t) for t in sent):
                continue
        return True
    return False
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def nlp_scan(clauses: Sequence[Clause], text: str) -> list[str]:
    """Return the stripped sentences of ``text`` that match at least one clause.

    Blank or whitespace-only text yields an empty list without invoking the
    parser.
    """
    if not text.strip():
        return []

    matched: list[str] = []
    for sent in parse(text).sents:
        if any(match_clause(clause, sent) for clause in clauses):
            matched.append(sent.text.strip())
    return matched
|
captain_hook/state.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""Hook fire-count and primitive echo-suppression state, plus shared NLP resources (spaCy, WordNet)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import inspect
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from collections.abc import Callable, Iterable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from functools import cached_property
|
|
10
|
+
from hashlib import sha256
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
13
|
+
|
|
14
|
+
from pydantic import BaseModel, Field
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from types import FrameType, ModuleType
|
|
18
|
+
|
|
19
|
+
import spacy
|
|
20
|
+
|
|
21
|
+
from captain_hook.events import BaseHookEvent
|
|
22
|
+
from captain_hook.types import Signals
|
|
23
|
+
|
|
24
|
+
# Directory containing this package; used to recognise framework-internal files.
FRAMEWORK_DIR = str(Path(__file__).resolve().parent)
# XDG: an unset *or empty* XDG_CACHE_HOME means "use ~/.cache". ``.get(name, default)``
# would turn an empty value into ``Path("")``, i.e. the current working directory.
CACHE_ROOT = Path(os.environ.get("XDG_CACHE_HOME") or (Path.home() / ".cache")) / "captain-hook"
SPACY_MODEL = "en_core_web_sm"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class NlpResources:
    """Lazily loaded NLP resources; each property is computed once per instance."""

    @cached_property
    def spacy(self) -> spacy.language.Language:
        """Load the spaCy pipeline named by ``SPACY_MODEL``.

        An installed package is preferred, then a previously cached pipeline.

        Raises:
            RuntimeError: If the model is neither installed nor cached.
        """
        import spacy

        from captain_hook.util.model_cache import cached_pipeline

        if spacy.util.is_package(SPACY_MODEL):
            return spacy.load(SPACY_MODEL)

        # We refuse to auto-download from a live hook: it's a ~100MB silent fetch behind
        # the agent's back. If a previous run / explicit install already populated the
        # cache, use that; otherwise, raise with an actionable install hint.
        cached = cached_pipeline()
        if cached:
            return spacy.load(cached)

        raise RuntimeError(
            f"spaCy model {SPACY_MODEL!r} is not installed. "
            f"Install it explicitly before running hooks that use NLP signals: "
            f"`python -m spacy download {SPACY_MODEL}` "
            f"or `python -c \"from captain_hook.util.model_cache import ensure_spacy_model; ensure_spacy_model()\"`."
        )

    @cached_property
    def wn(self) -> ModuleType:
        """Return the ``wn`` module, downloading the ``oewn:2025`` lexicon if it is absent."""
        import wn

        # NOTE(review): unlike ``spacy`` above, this downloads on first use without
        # prompting — confirm that is intended for live hooks.
        installed = wn.lexicons(lexicon="oewn:2025")
        if not installed:
            wn.download("oewn:2025", progress_handler=None)
        return wn
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
RESOURCES = NlpResources()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class HookState(BaseModel):
    """Per-hook persistent state tracked across events in a session (currently just ``fire_count`` for ``max_fires``)."""

    # Starts at zero for a hook with no stored state.
    fire_count: int = 0
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ECHO_WINDOW: number of subsequent transcript messages after a nudge fires during which we
|
|
69
|
+
# suppress restating the same idea. Tuned for "the agent reads our nudge and the next ~5
|
|
70
|
+
# assistant messages reference the same concept".
|
|
71
|
+
# ECHO_THRESHOLD: fraction of content lemmas in a candidate text that must overlap with the
|
|
72
|
+
# nudge's lemmas to count as an echo. 0.4 = "if 40%+ of the meaningful words match, the
|
|
73
|
+
# agent is parroting".
|
|
74
|
+
# ECHO_MIN_OVERLAP: absolute minimum overlap to count, so short messages don't pass the
|
|
75
|
+
# fractional threshold trivially (e.g. a 2-token message would otherwise hit 0.5).
|
|
76
|
+
ECHO_WINDOW = 5
|
|
77
|
+
ECHO_THRESHOLD = 0.4
|
|
78
|
+
ECHO_MIN_OVERLAP = 2
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class PrimitiveState(BaseModel):
    """Per-primitive state for nudges/gates: last fire index, consumed-signal hashes, and echo-window lemmas."""

    # Transcript length recorded at the last fire.
    last_fired_at: int = 0
    # ``text_hash`` values of texts that must not trigger again.
    consumed: set[str] = Field(default_factory=set)
    # Content lemmas of the last nudge and of the texts that triggered it.
    echo_lemmas: set[str] = Field(default_factory=set)
    # Transcript length at which echo suppression stops (exclusive).
    echo_window_end: int = 0

    @staticmethod
    def content_lemmas(text: str) -> set[str]:
        """Return the lower-cased lemmas of the non-stop nouns, verbs and adjectives in ``text``.

        Lemmas of two characters or fewer are dropped.
        """
        return {
            tok.lemma_.lower()
            for tok in RESOURCES.spacy(text)
            if tok.pos_ in {"NOUN", "VERB", "ADJ"} and not tok.is_stop and len(tok.lemma_) > 2
        }

    def is_echo(self, text: str) -> bool:
        """Return True when ``text`` overlaps the stored echo lemmas strongly enough.

        Both the absolute overlap (``ECHO_MIN_OVERLAP``) and the overlap as a
        fraction of the text's own lemmas (``ECHO_THRESHOLD``) must be reached.
        """
        return bool(
            self.echo_lemmas
            and (text_lemmas := self.content_lemmas(text))
            and len(overlap := text_lemmas & self.echo_lemmas) >= ECHO_MIN_OVERLAP
            and len(overlap) / len(text_lemmas) >= ECHO_THRESHOLD
        )

    def consume_echoes(self, texts: list[str], transcript_len: int) -> None:
        """Mark echoing texts as consumed while the echo window is still open."""
        if not self.echo_lemmas or transcript_len >= self.echo_window_end:
            return

        for text in texts:
            if (h := text_hash(text)) not in self.consumed and self.is_echo(text):
                self.consumed.add(h)

    def seed_echo_window(self, triggering_texts: list[str], message: str, transcript_len: int) -> None:
        """Start a new echo window from the triggering texts and the emitted message."""
        self.echo_lemmas = self.content_lemmas(" ".join(triggering_texts)) | self.content_lemmas(message)
        self.echo_window_end = transcript_len + ECHO_WINDOW

    def match_signals(self, sig: Signals, texts: list[str]) -> list[str] | None:
        """Return the not-yet-consumed texts that reach ``sig.threshold``, consuming them.

        Args:
            sig: Signal bundle providing ``patterns`` and ``threshold``.
            texts: Candidate texts, scored independently.

        Returns:
            The matching texts in input order, or ``None`` when none match.
            Matching texts are added to ``consumed`` as a side effect.
        """
        from captain_hook.signals import score_signals

        # Hash each text once and keep hash/text paired. ``consumed`` is only
        # updated afterwards, so duplicate texts in ``texts`` all contribute.
        hits = [
            (h, text)
            for text in texts
            if (h := text_hash(text)) not in self.consumed and score_signals(sig.patterns, text) >= sig.threshold
        ]
        if not hits:
            return None
        self.consumed.update(h for h, _ in hits)
        return [text for _, text in hits]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def text_hash(text: str) -> str:
    """Return the first 16 hex characters of the SHA-256 digest of ``text``."""
    digest = sha256(text.encode())
    return digest.hexdigest()[:16]
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def package_aware_stem(p: Path) -> str:
    """Return a display stem for a source file, qualified by its package when it has one.

    The result is ``"<parent dir>.<stem>"`` when the file sits next to a
    non-empty ``__init__.py``, is not itself ``__init__.py`` and is not inside
    the framework directory. Otherwise it is just the file's stem.

    Args:
        p: Path of the source file.

    Returns:
        The qualified or plain stem.
    """
    # Compare path components, not string prefixes: a sibling such as
    # ``captain_hook_extras`` must not be mistaken for the framework directory.
    if p.name == "__init__.py" or p.is_relative_to(FRAMEWORK_DIR):
        return p.stem

    init = p.parent / "__init__.py"
    try:
        in_package = init.stat().st_size > 0
    except OSError:
        # Missing or unreadable __init__.py: treat as "not a package".
        in_package = False
    return f"{p.parent.name}.{p.stem}" if in_package else p.stem
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def caller_stem() -> str:
    """Return the package-aware stem of the nearest caller outside the framework.

    Walks up the call stack from the immediate caller, skipping frames whose
    source file lies inside ``FRAMEWORK_DIR``.

    Returns:
        ``package_aware_stem`` of the first non-framework frame's file, or
        ``"unknown"`` when no such frame exists.
    """
    frame: FrameType | None = inspect.currentframe()
    if frame:
        frame = frame.f_back
    # Compare path components, not string prefixes: code in a sibling directory
    # such as ``captain_hook_extras`` is a real caller, not framework code.
    while frame and Path(frame.f_code.co_filename).is_relative_to(FRAMEWORK_DIR):
        frame = frame.f_back
    return package_aware_stem(Path(frame.f_code.co_filename)) if frame else "unknown"
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def hook_name(prefix: str, label: str | None, message: str) -> str:
    """Build a hook identifier of the form ``"<caller stem>:<prefix>_<suffix>"``.

    Args:
        prefix: Kind of hook, placed before the suffix.
        label: Optional human label. When non-empty, it is lower-cased and
            slugified (runs of characters outside ``a-z0-9`` become ``_``,
            outer underscores stripped).
        message: Fallback source for the suffix: the first 8 hex characters of
            its SHA-256 digest are used when ``label`` is empty or ``None``.

    Returns:
        The identifier string.
    """
    if label:
        slug = re.sub(r"[^a-z0-9]+", "_", label.lower())
        suffix = slug.strip("_")
    else:
        suffix = sha256(message.encode()).hexdigest()[:8]
    return f"{caller_stem()}:{prefix}_{suffix}"
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def record_fire(evt: BaseHookEvent) -> None:
    """Store the current transcript length as ``last_fired_at`` in the session's ``PrimitiveState``.

    A fresh ``PrimitiveState`` is used when none is stored yet.
    """
    slot = evt.ctx.s[PrimitiveState]
    state = slot.get(PrimitiveState())
    state.last_fired_at = len(evt.ctx.t)
    slot.set(state)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def fired_this_turn(evt: BaseHookEvent) -> bool:
    """Return True when the stored ``last_fired_at`` is after the current turn's start index.

    Returns False when no ``PrimitiveState`` is stored for the session.
    """
    state = evt.ctx.s[PrimitiveState].get()
    if state is None:
        return False
    return state.last_fired_at > evt.ctx.turn.start_idx
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
from captain_hook.session import SessionStore # noqa: E402
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass
class EchoTracker:
    """Echo suppression on top of the session's ``PrimitiveState``.

    ``record`` stores the lemmas of an emitted text and opens a window of
    ``window`` transcript messages. ``saw`` reports whether a later text
    repeats those lemmas closely enough while the window is open.
    """

    # Number of transcript messages the echo window stays open after ``record``.
    window: int = ECHO_WINDOW
    # Minimum share of the candidate text's lemmas that must overlap.
    threshold: float = ECHO_THRESHOLD
    # Minimum absolute number of overlapping lemmas.
    min_overlap: int = ECHO_MIN_OVERLAP

    def saw(self, text: str, *, evt: BaseHookEvent) -> bool:
        """Return True when ``text`` echoes the recorded lemmas inside the open window.

        The tracker's own ``threshold`` and ``min_overlap`` are applied. With
        the default values the result matches ``PrimitiveState.is_echo``.
        """
        ps = evt.ctx.s[PrimitiveState].get()
        if ps is None or not ps.echo_lemmas or len(evt.ctx.t) >= ps.echo_window_end:
            return False

        text_lemmas = PrimitiveState.content_lemmas(text)
        if not text_lemmas:
            return False
        overlap = text_lemmas & ps.echo_lemmas
        return len(overlap) >= self.min_overlap and len(overlap) / len(text_lemmas) >= self.threshold

    def record(self, text: str, triggering: Iterable[str], *, evt: BaseHookEvent) -> None:
        """Store the lemmas of ``text`` and ``triggering`` and open a new echo window."""
        ps = evt.ctx.s[PrimitiveState].get(PrimitiveState())
        ps.echo_lemmas = PrimitiveState.content_lemmas(" ".join(triggering)) | PrimitiveState.content_lemmas(text)
        ps.echo_window_end = len(evt.ctx.t) + self.window
        evt.ctx.s[PrimitiveState].set(ps)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
T = TypeVar("T", bound=BaseModel)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def workflow_state(name: str) -> Callable[[type[T]], type[T]]:
    """Return a class decorator that turns a Pydantic model into tracked workflow state.

    The decorator stores ``name`` on the class as ``__workflow_name__``,
    registers the class with ``SessionStore`` and attaches three helpers that
    operate on the event's session store:

    * ``Model.load(evt)`` — classmethod returning the stored instance or a fresh one.
    * ``instance.save(evt)`` — persist the instance.
    * ``Model.reset(evt)`` — classmethod deleting the stored instance.

    Args:
        name: Workflow name recorded on the decorated class.

    Returns:
        The decorator; it returns the same class object it receives.
    """

    def wrap(cls: type[T]) -> type[T]:
        def load(inner_cls: type[T], evt: BaseHookEvent) -> T:
            return evt.ctx.s.load(inner_cls)

        def save(self: T, evt: BaseHookEvent) -> None:
            evt.ctx.s[type(self)].set(self)

        def reset(inner_cls: type[T], evt: BaseHookEvent) -> None:
            evt.ctx.s[inner_cls].delete()

        setattr(cls, "__workflow_name__", name)  # noqa: B010
        SessionStore.track(cls)

        setattr(cls, "load", classmethod(load))  # noqa: B010
        setattr(cls, "save", save)  # noqa: B010
        setattr(cls, "reset", classmethod(reset))  # noqa: B010
        return cls

    return wrap
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
SessionStore.track(HookState)
|
|
221
|
+
SessionStore.track(PrimitiveState)
|