chuk-ai-session-manager 0.7.1__py3-none-any.whl → 0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chuk_ai_session_manager/__init__.py +84 -40
- chuk_ai_session_manager/api/__init__.py +1 -1
- chuk_ai_session_manager/api/simple_api.py +53 -59
- chuk_ai_session_manager/exceptions.py +31 -17
- chuk_ai_session_manager/guards/__init__.py +118 -0
- chuk_ai_session_manager/guards/bindings.py +217 -0
- chuk_ai_session_manager/guards/cache.py +163 -0
- chuk_ai_session_manager/guards/manager.py +819 -0
- chuk_ai_session_manager/guards/models.py +498 -0
- chuk_ai_session_manager/guards/ungrounded.py +159 -0
- chuk_ai_session_manager/infinite_conversation.py +86 -79
- chuk_ai_session_manager/memory/__init__.py +247 -0
- chuk_ai_session_manager/memory/artifacts_bridge.py +469 -0
- chuk_ai_session_manager/memory/context_packer.py +347 -0
- chuk_ai_session_manager/memory/fault_handler.py +507 -0
- chuk_ai_session_manager/memory/manifest.py +307 -0
- chuk_ai_session_manager/memory/models.py +1084 -0
- chuk_ai_session_manager/memory/mutation_log.py +186 -0
- chuk_ai_session_manager/memory/pack_cache.py +206 -0
- chuk_ai_session_manager/memory/page_table.py +275 -0
- chuk_ai_session_manager/memory/prefetcher.py +192 -0
- chuk_ai_session_manager/memory/tlb.py +247 -0
- chuk_ai_session_manager/memory/vm_prompts.py +238 -0
- chuk_ai_session_manager/memory/working_set.py +574 -0
- chuk_ai_session_manager/models/__init__.py +21 -9
- chuk_ai_session_manager/models/event_source.py +3 -1
- chuk_ai_session_manager/models/event_type.py +10 -1
- chuk_ai_session_manager/models/session.py +103 -68
- chuk_ai_session_manager/models/session_event.py +69 -68
- chuk_ai_session_manager/models/session_metadata.py +9 -10
- chuk_ai_session_manager/models/session_run.py +21 -22
- chuk_ai_session_manager/models/token_usage.py +76 -76
- chuk_ai_session_manager/procedural_memory/__init__.py +70 -0
- chuk_ai_session_manager/procedural_memory/formatter.py +407 -0
- chuk_ai_session_manager/procedural_memory/manager.py +523 -0
- chuk_ai_session_manager/procedural_memory/models.py +371 -0
- chuk_ai_session_manager/sample_tools.py +79 -46
- chuk_ai_session_manager/session_aware_tool_processor.py +27 -16
- chuk_ai_session_manager/session_manager.py +238 -197
- chuk_ai_session_manager/session_prompt_builder.py +163 -111
- chuk_ai_session_manager/session_storage.py +45 -52
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.dist-info}/METADATA +79 -3
- chuk_ai_session_manager-0.8.dist-info/RECORD +45 -0
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.dist-info}/WHEEL +1 -1
- chuk_ai_session_manager-0.7.1.dist-info/RECORD +0 -22
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,819 @@
|
|
|
1
|
+
# chuk_ai_session_manager/guards/manager.py
|
|
2
|
+
"""Slim ToolStateManager - coordinates guards and state.
|
|
3
|
+
|
|
4
|
+
This is a thin facade that wires together:
|
|
5
|
+
- BindingManager for $vN references
|
|
6
|
+
- ResultCache for deduplication
|
|
7
|
+
- All guards (precondition, budget, ungrounded, runaway, per-tool)
|
|
8
|
+
|
|
9
|
+
All heavy logic is in the guards and sub-managers.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
import re
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel, Field
|
|
19
|
+
|
|
20
|
+
from chuk_ai_session_manager.guards.bindings import BindingManager
|
|
21
|
+
from chuk_ai_session_manager.guards.cache import ResultCache
|
|
22
|
+
from chuk_ai_session_manager.guards.models import (
|
|
23
|
+
CachedToolResult,
|
|
24
|
+
NamedVariable,
|
|
25
|
+
PerToolCallStatus,
|
|
26
|
+
ReferenceCheckResult,
|
|
27
|
+
RunawayStatus,
|
|
28
|
+
RuntimeLimits,
|
|
29
|
+
RuntimeMode,
|
|
30
|
+
SoftBlockReason,
|
|
31
|
+
ToolClassification,
|
|
32
|
+
UngroundedCallResult,
|
|
33
|
+
ValueBinding,
|
|
34
|
+
)
|
|
35
|
+
from chuk_ai_session_manager.guards.ungrounded import (
|
|
36
|
+
UngroundedGuard,
|
|
37
|
+
UngroundedGuardConfig,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Import guards from chuk-tool-processor
|
|
41
|
+
from chuk_tool_processor.guards import (
|
|
42
|
+
BudgetGuard,
|
|
43
|
+
BudgetGuardConfig,
|
|
44
|
+
GuardResult,
|
|
45
|
+
GuardVerdict,
|
|
46
|
+
PerToolGuard,
|
|
47
|
+
PerToolGuardConfig,
|
|
48
|
+
PreconditionGuard,
|
|
49
|
+
PreconditionGuardConfig,
|
|
50
|
+
RunawayGuard,
|
|
51
|
+
RunawayGuardConfig,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# Module logger, named after this module so handlers can target it.
log = logging.getLogger(name=__name__)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ToolStateManager(BaseModel):
|
|
58
|
+
"""Coordinates tool state and guards.
|
|
59
|
+
|
|
60
|
+
Pydantic-native, slim coordinator. All logic delegated to:
|
|
61
|
+
- bindings: BindingManager for $vN references
|
|
62
|
+
- cache: ResultCache for deduplication
|
|
63
|
+
- guards: Individual guard instances
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
# Sub-managers
|
|
67
|
+
bindings: BindingManager = Field(default_factory=BindingManager)
|
|
68
|
+
cache: ResultCache = Field(default_factory=ResultCache)
|
|
69
|
+
|
|
70
|
+
# Guards (initialized lazily)
|
|
71
|
+
precondition_guard: PreconditionGuard | None = Field(default=None, exclude=True)
|
|
72
|
+
budget_guard: BudgetGuard | None = Field(default=None, exclude=True)
|
|
73
|
+
ungrounded_guard: UngroundedGuard | None = Field(default=None, exclude=True)
|
|
74
|
+
runaway_guard: RunawayGuard | None = Field(default=None, exclude=True)
|
|
75
|
+
per_tool_guard: PerToolGuard | None = Field(default=None, exclude=True)
|
|
76
|
+
|
|
77
|
+
# User-provided literals (whitelisted for ungrounded check)
|
|
78
|
+
user_literals: set[float] = Field(default_factory=set)
|
|
79
|
+
|
|
80
|
+
# Stated values from assistant text
|
|
81
|
+
stated_values: dict[float, str] = Field(default_factory=dict)
|
|
82
|
+
|
|
83
|
+
# Runtime limits
|
|
84
|
+
limits: RuntimeLimits = Field(default_factory=RuntimeLimits)
|
|
85
|
+
|
|
86
|
+
# Per-tool call tracking (0 = unlimited)
|
|
87
|
+
per_tool_limit: int = Field(default=0)
|
|
88
|
+
tool_call_counts: dict[str, int] = Field(default_factory=dict)
|
|
89
|
+
|
|
90
|
+
model_config = {"arbitrary_types_allowed": True}
|
|
91
|
+
|
|
92
|
+
def model_post_init(self, __context: Any) -> None:
    """Pydantic post-initialisation hook that wires up the guards.

    Args:
        __context: Context object pydantic hands to the hook; unused here.
    """
    # The guard callbacks close over ``self``, so they are only created once
    # every field has been populated.
    build_guards = self._init_guards
    build_guards()
|
|
95
|
+
|
|
96
|
+
def _init_guards(self) -> None:
    """(Re)build every guard from the current ``self.limits``.

    Each call constructs brand-new guard instances and assigns them to the
    ``*_guard`` fields, replacing whatever instances were there before.
    Callbacks are lambdas so the guards always read the manager's live
    bindings and user literals rather than a snapshot.
    """
    limits = self.limits

    precondition_cfg = PreconditionGuardConfig(
        parameterized_tools=set(ToolClassification.PARAMETERIZED_TOOLS),
        safe_values={0.0, 1.0},
    )
    self.precondition_guard = PreconditionGuard(
        config=precondition_cfg,
        get_binding_count=lambda: len(self.bindings.bindings),
        get_binding_values=lambda: self.bindings.get_numeric_values(),
        get_user_literals=lambda: self.user_literals,
    )

    budget_cfg = BudgetGuardConfig(
        discovery_budget=limits.discovery_budget,
        execution_budget=limits.execution_budget,
        total_budget=limits.tool_budget_total,
    )
    self.budget_guard = BudgetGuard(config=budget_cfg)

    ungrounded_cfg = UngroundedGuardConfig(
        grace_calls=limits.ungrounded_grace_calls,
        mode=limits.require_bindings,
    )
    self.ungrounded_guard = UngroundedGuard(
        config=ungrounded_cfg,
        get_user_literals=lambda: self.user_literals,
        get_bindings=lambda: self.bindings.bindings,
    )

    # The runaway guard always uses its default configuration.
    self.runaway_guard = RunawayGuard(config=RunawayGuardConfig())

    per_tool_cfg = PerToolGuardConfig(default_limit=limits.per_tool_cap)
    self.per_tool_guard = PerToolGuard(config=per_tool_cfg)
|
|
129
|
+
|
|
130
|
+
# =========================================================================
|
|
131
|
+
# Configuration
|
|
132
|
+
# =========================================================================
|
|
133
|
+
|
|
134
|
+
def configure(self, limits: RuntimeLimits) -> None:
    """Install new runtime limits and rebuild every guard from them.

    Args:
        limits: Limits to use from now on.
    """
    self.limits = limits
    # Guard configs copy values out of ``self.limits`` when they are built,
    # so existing guards would not see the new limits without a rebuild.
    self._init_guards()
    message = f"Configured runtime limits: {limits}"
    log.info(message)
|
|
139
|
+
|
|
140
|
+
def set_mode(self, mode: RuntimeMode | str) -> None:
    """Switch to a preset runtime mode.

    Args:
        mode: A ``RuntimeMode`` or a string, which is lower-cased and passed
            to ``RuntimeMode(...)``. ``SMOOTH`` and ``STRICT`` apply the
            matching ``RuntimeLimits`` preset via ``configure``; any other
            mode only logs the change.
    """
    mode = RuntimeMode(mode.lower()) if isinstance(mode, str) else mode

    factory = None
    if mode == RuntimeMode.SMOOTH:
        factory = RuntimeLimits.smooth
    elif mode == RuntimeMode.STRICT:
        factory = RuntimeLimits.strict

    if factory is not None:
        self.configure(factory())
    log.info(f"Runtime mode set to: {mode}")
|
|
150
|
+
|
|
151
|
+
# =========================================================================
|
|
152
|
+
# Guard Checks
|
|
153
|
+
# =========================================================================
|
|
154
|
+
|
|
155
|
+
def check_all_guards(
    self,
    tool_name: str,
    arguments: dict[str, Any],
) -> GuardResult:
    """Evaluate the gating guards in order and stop at the first block.

    Order: precondition, budget, ungrounded, per-tool. Guards that are
    ``None`` are skipped. ``WARN`` verdicts are logged but do not stop
    evaluation. The runaway guard is not part of this chain.

    Returns:
        The first blocking ``GuardResult``, or an ``ALLOW`` result.
    """
    chain = (
        self.precondition_guard,
        self.budget_guard,
        self.ungrounded_guard,
        self.per_tool_guard,
    )
    active = [candidate for candidate in chain if candidate is not None]

    for guard in active:
        result = guard.check(tool_name, arguments)
        if result.blocked:
            return result
        if result.verdict == GuardVerdict.WARN:
            log.warning(f"{guard.__class__.__name__}: {result.reason}")

    return GuardResult(verdict=GuardVerdict.ALLOW)
|
|
178
|
+
|
|
179
|
+
def check_preconditions(
    self,
    tool_name: str,
    arguments: dict[str, Any],
) -> tuple[bool, str | None]:
    """Ask the precondition guard whether ``tool_name`` may run.

    Returns:
        ``(True, None)`` when allowed (or when no precondition guard is set),
        otherwise ``(False, reason)`` taken from the blocking result.
    """
    guard = self.precondition_guard
    if guard is not None:
        verdict = guard.check(tool_name, arguments)
        if verdict.blocked:
            return False, verdict.reason
    return True, None
|
|
192
|
+
|
|
193
|
+
# =========================================================================
|
|
194
|
+
# Value Binding (delegated to BindingManager)
|
|
195
|
+
# =========================================================================
|
|
196
|
+
|
|
197
|
+
def bind_value(
    self,
    tool_name: str,
    arguments: dict[str, Any],
    value: Any,
    aliases: list[str] | None = None,
) -> ValueBinding:
    """Record a tool result as a new ``$vN`` binding.

    Args:
        tool_name: Tool that produced ``value``.
        arguments: Arguments the tool was called with.
        value: The result to bind.
        aliases: Optional extra names for the binding.

    Returns:
        The ``ValueBinding`` created by the binding manager.
    """
    manager = self.bindings
    binding = manager.bind(tool_name, arguments, value, aliases)
    log.debug(f"Bound ${binding.id} = {value} from {tool_name}")
    return binding
|
|
208
|
+
|
|
209
|
+
def get_binding(self, ref: str) -> ValueBinding | None:
    """Look up a binding by its ID or one of its aliases (``None`` if absent)."""
    manager = self.bindings
    return manager.get(ref)
|
|
212
|
+
|
|
213
|
+
def resolve_references(self, arguments: dict[str, Any]) -> dict[str, Any]:
    """Return ``arguments`` with ``$vN`` references substituted by the binding manager."""
    manager = self.bindings
    return manager.resolve_references(arguments)
|
|
216
|
+
|
|
217
|
+
def check_references(self, arguments: dict[str, Any]) -> ReferenceCheckResult:
    """Validate every ``$v...`` reference found anywhere in ``arguments``.

    Strings starting with ``"$v"`` are looked up (without the leading ``$``)
    in the binding manager. Dict values and list items are searched
    recursively; tuples and other containers are not inspected.

    Returns:
        A ``ReferenceCheckResult`` listing missing references (in the order
        they were encountered) and the raw values of the references that
        resolved.
    """
    missing_refs: list[str] = []
    resolved_refs: dict[str, Any] = {}

    def visit(node: Any) -> None:
        if isinstance(node, dict):
            children: Any = node.values()
        elif isinstance(node, list):
            children = node
        else:
            if isinstance(node, str) and node.startswith("$v"):
                found = self.bindings.get(node[1:])
                if found:
                    resolved_refs[node] = found.raw_value
                else:
                    missing_refs.append(node)
            return
        for child in children:
            visit(child)

    for top_level in arguments.values():
        visit(top_level)

    if not missing_refs:
        return ReferenceCheckResult(
            valid=True,
            missing_refs=[],
            resolved_refs=resolved_refs,
            message="All references valid",
        )
    return ReferenceCheckResult(
        valid=False,
        missing_refs=missing_refs,
        resolved_refs=resolved_refs,
        message=f"Missing references: {', '.join(missing_refs)}",
    )
|
|
254
|
+
|
|
255
|
+
# =========================================================================
|
|
256
|
+
# Cache (delegated to ResultCache)
|
|
257
|
+
# =========================================================================
|
|
258
|
+
|
|
259
|
+
def get_cached_result(
    self,
    tool_name: str,
    arguments: dict[str, Any],
) -> CachedToolResult | None:
    """Return the cached result for this exact call, or ``None`` if there is none."""
    cache = self.cache
    return cache.get(tool_name, arguments)
|
|
266
|
+
|
|
267
|
+
def cache_result(
    self,
    tool_name: str,
    arguments: dict[str, Any],
    result: Any,
) -> CachedToolResult:
    """Store ``result`` in the result cache under this tool call; return the entry."""
    cache = self.cache
    return cache.put(tool_name, arguments, result)
|
|
275
|
+
|
|
276
|
+
def store_variable(
    self,
    name: str,
    value: float,
    units: str | None = None,
    source_tool: str | None = None,
) -> NamedVariable:
    """Save a named numeric variable in the result cache and return it.

    Args:
        name: Variable name.
        value: Numeric value.
        units: Optional unit label.
        source_tool: Optional name of the tool that produced the value.
    """
    cache = self.cache
    return cache.store_variable(name, value, units, source_tool)
|
|
285
|
+
|
|
286
|
+
def get_variable(self, name: str) -> NamedVariable | None:
    """Fetch a named variable from the result cache (``None`` if unknown)."""
    cache = self.cache
    return cache.get_variable(name)
|
|
289
|
+
|
|
290
|
+
def get_cache_stats(self) -> dict[str, Any]:
    """Return the statistics dictionary reported by the result cache."""
    stats = self.cache.get_stats()
    return stats
|
|
293
|
+
|
|
294
|
+
def format_duplicate_message(
    self, tool_name: str, arguments: dict[str, Any]
) -> str:
    """Build the cache's message describing a repeated tool call."""
    formatter = self.cache.format_duplicate_message
    return formatter(tool_name, arguments)
|
|
299
|
+
|
|
300
|
+
def format_duplicate_recovery_message(
    self, tool_name: str, arguments: dict[str, Any]
) -> str:
    """Alias: produce the same duplicate-call text via the result cache."""
    formatter = self.cache.format_duplicate_message
    return formatter(tool_name, arguments)
|
|
305
|
+
|
|
306
|
+
# =========================================================================
|
|
307
|
+
# Budget Tracking
|
|
308
|
+
# =========================================================================
|
|
309
|
+
|
|
310
|
+
def record_tool_call(self, tool_name: str) -> None:
    """Count one call of ``tool_name`` against the budget and per-tool guards.

    Guards that are not set (falsy) are skipped.
    """
    for guard in (self.budget_guard, self.per_tool_guard):
        if guard:
            guard.record_call(tool_name)
|
|
316
|
+
|
|
317
|
+
def get_budget_status(self) -> dict[str, Any]:
    """Return the budget guard's status dictionary.

    Without a budget guard, a zeroed ``{"used": 0, "limit": 0}`` entry is
    returned for each of ``discovery``, ``execution`` and ``total``.
    """
    guard = self.budget_guard
    if not guard:
        return {
            bucket: {"used": 0, "limit": 0}
            for bucket in ("discovery", "execution", "total")
        }
    status: dict[str, Any] = guard.get_status()
    return status
|
|
327
|
+
|
|
328
|
+
def set_budget(self, budget: int) -> None:
    """Set the discovery, execution and total tool budgets to ``budget``.

    Every other configured limit (e.g. ``ungrounded_grace_calls``,
    ``require_bindings``, ``per_tool_cap``) is preserved, and all guards are
    rebuilt so the new budget takes effect.

    Args:
        budget: Value applied to all three budgets.
    """
    updates = {
        "tool_budget_total": budget,
        "execution_budget": budget,
        "discovery_budget": budget,
    }
    # Previously a default-constructed RuntimeLimits replaced self.limits,
    # silently discarding any preset applied via set_mode()/configure().
    copy_with = getattr(self.limits, "model_copy", None)
    if copy_with is not None:
        self.limits = copy_with(update=updates)
    else:
        # NOTE(review): limits object without pydantic's model_copy; fall
        # back to the old behaviour of building limits from the budget only.
        self.limits = RuntimeLimits(**updates)
    self._init_guards()
|
|
334
|
+
|
|
335
|
+
def get_discovery_status(self) -> dict[str, int]:
    """Return ``{"used", "limit"}`` for the discovery budget (zeros if no guard)."""
    if not self.budget_guard:
        return {"used": 0, "limit": 0}
    section = self.budget_guard.get_status()["discovery"]
    return {"used": section["used"], "limit": section["limit"]}
|
|
344
|
+
|
|
345
|
+
def get_execution_status(self) -> dict[str, int]:
    """Return ``{"used", "limit"}`` for the execution budget (zeros if no guard)."""
    if not self.budget_guard:
        return {"used": 0, "limit": 0}
    section = self.budget_guard.get_status()["execution"]
    return {"used": section["used"], "limit": section["limit"]}
|
|
354
|
+
|
|
355
|
+
def is_discovery_exhausted(self) -> bool:
    """Return True when discovery calls have reached the discovery limit.

    Without a budget guard there is no budget to exhaust, so this returns
    False. Previously the zeroed fallback status (0 >= 0) reported
    exhaustion, which disagreed with ``check_runaway``.
    """
    if not self.budget_guard:
        return False
    status = self.get_discovery_status()
    return status["used"] >= status["limit"]
|
|
359
|
+
|
|
360
|
+
def is_execution_exhausted(self) -> bool:
    """Return True when execution calls have reached the execution limit.

    Without a budget guard there is no budget to exhaust, so this returns
    False. Previously the zeroed fallback status (0 >= 0) reported
    exhaustion, which disagreed with ``check_runaway``.
    """
    if not self.budget_guard:
        return False
    status = self.get_execution_status()
    return status["used"] >= status["limit"]
|
|
364
|
+
|
|
365
|
+
def increment_discovery_call(self) -> None:
    """Count one discovery call by recording a ``search_tools`` call on the budget guard."""
    guard = self.budget_guard
    if not guard:
        return
    guard.record_call("search_tools")
|
|
369
|
+
|
|
370
|
+
def increment_execution_call(self) -> None:
    """Count one execution call by recording an ``execute_tool`` call on the budget guard."""
    guard = self.budget_guard
    if not guard:
        return
    guard.record_call("execute_tool")
|
|
374
|
+
|
|
375
|
+
def get_discovered_tools(self) -> set[str]:
    """Return the tool names the budget guard has marked as discovered.

    NOTE: this hands back the guard's own private ``_discovered_tools`` set
    (not a copy), or a fresh empty set when there is no budget guard.
    """
    guard = self.budget_guard
    if not guard:
        return set()
    tools: set[str] = guard._discovered_tools
    return tools
|
|
381
|
+
|
|
382
|
+
def is_tool_discovered(self, tool_name: str) -> bool:
    """Return True if ``tool_name`` is among the discovered tools."""
    discovered = self.get_discovered_tools()
    return tool_name in discovered
|
|
385
|
+
|
|
386
|
+
def record_numeric_result(self, value: float) -> None:
    """Feed a numeric tool result to the runaway guard, if there is one."""
    guard = self.runaway_guard
    if not guard:
        return
    guard.record_result(value)
|
|
390
|
+
|
|
391
|
+
@property
def _recent_numeric_results(self) -> list[float]:
    """Recent numeric results tracked by the runaway guard.

    Exposes the guard's private ``_recent_values`` list directly (not a
    copy); an empty list is returned when there is no runaway guard.
    """
    guard = self.runaway_guard
    if not guard:
        return []
    values: list[float] = guard._recent_values
    return values
|
|
398
|
+
|
|
399
|
+
def register_discovered_tool(self, tool_name: str) -> None:
    """Mark ``tool_name`` as discovered on the budget guard, if there is one."""
    guard = self.budget_guard
    if not guard:
        return
    guard.register_discovered_tool(tool_name)
|
|
403
|
+
|
|
404
|
+
# =========================================================================
|
|
405
|
+
# User Literals
|
|
406
|
+
# =========================================================================
|
|
407
|
+
|
|
408
|
+
def register_user_literals(self, text: str) -> int:
    """Extract numeric literals from user text and add them to ``user_literals``.

    Recognised forms: integers, decimals (including a leading-dot form such
    as ``.5``), and exponents such as ``1e-3``. A ``-`` that directly follows
    a digit or dot is treated as a separator, not a sign, so ``"1-10"``
    registers ``1`` and ``10`` (previously ``1`` and ``-10``).

    Args:
        text: Prompt text to scan.

    Returns:
        The number of literals parsed. Repeated values are counted each time
        they appear.
    """
    pattern = re.compile(
        r"(?:(?<![\d.])-)?"  # optional sign, unless it follows a number
        r"(?:\d+\.?\d*|(?<!\d)\.\d+)"  # 12, 12., 12.5, or .5
        r"(?:[eE][+-]?\d+)?"  # optional exponent
    )

    count = 0
    for match in pattern.findall(text):
        try:
            value = float(match)
        except ValueError:
            # Defensive: every pattern match should be float-parsable.
            continue
        self.user_literals.add(value)
        count += 1

    if count > 0:
        log.debug("Registered %d user literals: %s", count, self.user_literals)
    return count
|
|
425
|
+
|
|
426
|
+
# =========================================================================
|
|
427
|
+
# Tool Classification (delegated to ToolClassification)
|
|
428
|
+
# =========================================================================
|
|
429
|
+
|
|
430
|
+
def is_discovery_tool(self, tool_name: str) -> bool:
    """Return whether ``ToolClassification`` treats the tool as a discovery tool."""
    classifier = ToolClassification.is_discovery_tool
    return classifier(tool_name)
|
|
433
|
+
|
|
434
|
+
def is_execution_tool(self, tool_name: str) -> bool:
    """Return True for every tool that ``is_discovery_tool`` rejects."""
    discovery = self.is_discovery_tool(tool_name)
    return not discovery
|
|
437
|
+
|
|
438
|
+
def is_idempotent_math_tool(self, tool_name: str) -> bool:
    """Return whether ``ToolClassification`` lists the tool as idempotent math."""
    classifier = ToolClassification.is_idempotent_math_tool
    return classifier(tool_name)
|
|
441
|
+
|
|
442
|
+
def is_parameterized_tool(self, tool_name: str) -> bool:
    """Return whether ``ToolClassification`` lists the tool as parameterized."""
    classifier = ToolClassification.is_parameterized_tool
    return classifier(tool_name)
|
|
445
|
+
|
|
446
|
+
def classify_by_result(self, tool_name: str, result: Any) -> None:
    """Register tool names advertised by a discovery-style result.

    Two result shapes are recognised:

    * ``{"results": [{"name": ...}, ...]}``: every dict item that has a
      ``"name"`` key is registered as discovered;
    * ``{"function": {"name": ...}}``: the function's name is registered.

    Anything else is ignored. ``tool_name`` is currently unused.
    """
    if not isinstance(result, dict):
        return

    if "results" in result and isinstance(result["results"], list):
        for item in result["results"]:
            if isinstance(item, dict) and "name" in item:
                self.register_discovered_tool(item["name"])
    elif "function" in result:
        func = result["function"]
        # Only a mapping can carry a name. For a string, `"name" in func` is
        # a substring test that led to `str["name"]`; for None it raised.
        if isinstance(func, dict) and "name" in func:
            self.register_discovered_tool(func["name"])
|
|
457
|
+
|
|
458
|
+
# =========================================================================
|
|
459
|
+
# Ungrounded Call Detection
|
|
460
|
+
# =========================================================================
|
|
461
|
+
|
|
462
|
+
def check_ungrounded_call(
    self,
    tool_name: str,
    arguments: dict[str, Any],
) -> UngroundedCallResult:
    """Report top-level numeric arguments that are not grounded.

    The ungrounded guard decides whether the call is suspicious (blocked or
    WARN). If it is, every top-level ``int``/``float`` argument whose value
    is not in ``user_literals`` is listed as ``"key=value"``. Booleans are
    excluded: ``bool`` subclasses ``int``, but flags are not numeric literals.

    Returns:
        An ``UngroundedCallResult``. ``is_ungrounded`` is False when there is
        no guard, the guard allows the call, or no offending argument remains.
    """
    guard = self.ungrounded_guard
    if guard is None:
        return UngroundedCallResult(is_ungrounded=False)

    result = guard.check(tool_name, arguments)
    if not (result.blocked or result.verdict == GuardVerdict.WARN):
        return UngroundedCallResult(is_ungrounded=False)

    numeric_args = [
        f"{k}={v}"
        for k, v in arguments.items()
        if isinstance(v, (int, float))
        and not isinstance(v, bool)
        and v not in self.user_literals
    ]

    return UngroundedCallResult(
        is_ungrounded=bool(numeric_args),
        numeric_args=numeric_args,
        has_bindings=bool(self.bindings.bindings),
        message=result.reason,
    )
|
|
488
|
+
|
|
489
|
+
def should_auto_rebound(self, tool_name: str) -> bool:
    """True when the tool is idempotent math and at least one binding exists."""
    if not self.is_idempotent_math_tool(tool_name):
        return False
    return bool(self.bindings.bindings)
|
|
492
|
+
|
|
493
|
+
def check_tool_preconditions(
    self,
    tool_name: str,
    arguments: dict[str, Any],
) -> tuple[bool, str | None]:
    """Alias of ``check_preconditions``; returns ``(allowed, reason)``."""
    checker = self.check_preconditions
    return checker(tool_name, arguments)
|
|
500
|
+
|
|
501
|
+
def try_soft_block_repair(
    self,
    tool_name: str,
    arguments: dict[str, Any],
    reason: SoftBlockReason | UngroundedCallResult,
) -> tuple[bool, dict[str, Any] | None, str | None]:
    """Try to repair a soft-blocked call by swapping literals for ``$vN`` refs.

    Each offending numeric argument is compared against the numeric binding
    values. The first binding within ``1e-9`` replaces the literal with its
    ``$vN`` reference. Booleans are never treated as numbers on either side.

    Args:
        tool_name: Tool whose call was soft-blocked.
        arguments: The original call arguments (not mutated).
        reason: Either an ``UngroundedCallResult`` (its ``numeric_args`` are
            used) or a ``SoftBlockReason``. Only ``UNGROUNDED_ARGS`` is
            repairable.

    Returns:
        ``(True, repaired_args, None)`` if at least one argument was
        rebound, otherwise ``(False, None, fallback_message_or_None)``.
    """
    if isinstance(reason, UngroundedCallResult):
        if not reason.has_bindings:
            return False, None, None
        numeric_args = reason.numeric_args
    elif reason == SoftBlockReason.UNGROUNDED_ARGS:
        numeric_args = [
            f"{k}={v}"
            for k, v in arguments.items()
            if isinstance(v, (int, float))
            and not isinstance(v, bool)
            and v not in self.user_literals
        ]
        if not numeric_args or not self.bindings.bindings:
            fallback = (
                f"Cannot call {tool_name} with literal values. "
                "Please compute the required values first."
            )
            return False, None, fallback
    else:
        return False, None, None

    repaired = dict(arguments)
    any_repaired = False
    for arg_str in numeric_args:
        if "=" not in arg_str:
            continue
        # Split on the last "=": numbers never contain one, but keys might.
        key, val_str = arg_str.rsplit("=", 1)
        try:
            val = float(val_str)
        except ValueError:
            continue
        for binding in self.bindings.bindings.values():
            raw = binding.raw_value
            if isinstance(raw, bool) or not isinstance(raw, (int, float)):
                continue
            if abs(raw - val) < 1e-9:
                repaired[key] = f"${binding.id}"
                log.info(f"Repaired {key}={val} -> ${binding.id}")
                any_repaired = True
                break

    if any_repaired:
        return True, repaired, None

    fallback = (
        f"Cannot auto-repair call to {tool_name}. "
        f"No matching bindings found for {numeric_args}."
    )
    return False, None, fallback
|
|
551
|
+
|
|
552
|
+
# =========================================================================
|
|
553
|
+
# Per-Tool Tracking
|
|
554
|
+
# =========================================================================
|
|
555
|
+
|
|
556
|
+
def get_tool_call_count(self, tool_name: str) -> int:
    """Return how often a tool was counted, keyed by its lower-cased last dotted segment."""
    key = tool_name.rsplit(".", 1)[-1].lower()
    return self.tool_call_counts.get(key, 0)
|
|
560
|
+
|
|
561
|
+
def increment_tool_call(self, tool_name: str) -> None:
    """Bump the local counter for a tool and notify the per-tool guard.

    The local counter is keyed by the lower-cased last dotted segment of
    ``tool_name``; the per-tool guard receives the full name.
    """
    key = tool_name.rsplit(".", 1)[-1].lower()
    counts = self.tool_call_counts
    counts[key] = counts.get(key, 0) + 1

    guard = self.per_tool_guard
    if guard is not None:
        guard.record_call(tool_name)
|
|
568
|
+
|
|
569
|
+
def track_tool_call(self, tool_name: str) -> PerToolCallStatus:
    """Build a ``PerToolCallStatus`` for ``tool_name`` without incrementing it.

    ``requires_justification`` is set once the call count reaches
    ``per_tool_limit``; a limit of 0 or below disables that check.
    """
    short_name = tool_name.rsplit(".", 1)[-1]
    calls = self.get_tool_call_count(tool_name)
    limit = self.per_tool_limit
    over_limit = limit > 0 and calls >= limit
    return PerToolCallStatus(
        tool_name=short_name,
        call_count=calls,
        max_calls=limit,
        requires_justification=over_limit,
    )
|
|
585
|
+
|
|
586
|
+
def format_tool_limit_warning(self, tool_name: str) -> str:
    """Build the warning shown when a tool has been called too often."""
    count = self.get_tool_call_count(tool_name)
    first_line = (
        f"Tool '{tool_name}' has been called {count} times "
        f"(limit: {self.per_tool_limit}).\n"
    )
    return first_line + "Consider using cached results or computed values instead."
|
|
593
|
+
|
|
594
|
+
def check_per_tool_limit(self, tool_name: str) -> GuardResult:
    """Run the per-tool guard for ``tool_name`` with empty arguments.

    Returns an ``ALLOW`` result when no per-tool guard is configured.
    """
    guard = self.per_tool_guard
    if guard is not None:
        return guard.check(tool_name, {})
    return GuardResult(verdict=GuardVerdict.ALLOW)
|
|
600
|
+
|
|
601
|
+
# =========================================================================
|
|
602
|
+
# Runaway Detection
|
|
603
|
+
# =========================================================================
|
|
604
|
+
|
|
605
|
+
def check_runaway(self, tool_name: str | None = None) -> RunawayStatus:
    """Decide whether tool execution should stop.

    Checks, in order:

    1. the runaway guard (saturation / degenerate results);
    2. when ``tool_name`` is given, the discovery or execution budget that
       applies to that tool;
    3. the total budget, compared against discovery + execution calls used.

    Returns:
        A ``RunawayStatus`` with ``should_stop=True`` and a reason for the
        first check that trips, otherwise ``should_stop=False``.
    """
    if self.runaway_guard:
        result = self.runaway_guard.check(tool_name or "", {})
        if result.blocked:
            # Tolerate a blocking result without a reason string.
            reason_text = (result.reason or "").lower()
            return RunawayStatus(
                should_stop=True,
                reason=result.reason,
                saturation_detected="saturation" in reason_text,
                degenerate_detected="degenerate" in reason_text,
            )

    if not self.budget_guard:
        return RunawayStatus(should_stop=False)

    status = self.budget_guard.get_status()

    if tool_name:
        if self.is_discovery_tool(tool_name):
            bucket, label = "discovery", "Discovery"
        else:
            bucket, label = "execution", "Execution"
        if status[bucket]["used"] >= status[bucket]["limit"]:
            return RunawayStatus(
                should_stop=True,
                reason=f"{label} budget exhausted",
                budget_exhausted=True,
                calls_remaining=0,
            )

    total_used = status["discovery"]["used"] + status["execution"]["used"]
    if total_used >= status["total"]["limit"]:
        return RunawayStatus(
            should_stop=True,
            reason="Total tool budget exhausted",
            budget_exhausted=True,
            calls_remaining=0,
        )

    return RunawayStatus(should_stop=False)
|
|
649
|
+
|
|
650
|
+
# =========================================================================
|
|
651
|
+
# Formatting
|
|
652
|
+
# =========================================================================
|
|
653
|
+
|
|
654
|
+
def format_state_for_model(self, max_items: int = 10) -> str:
    """Combine the binding summary and cache summary for the model.

    Empty sections are dropped; the rest are separated by a blank line.

    Args:
        max_items: Forwarded to the cache's ``format_state``.
    """
    sections = [
        self.bindings.format_for_model(),
        self.cache.format_state(max_items),
    ]
    return "\n\n".join(section for section in sections if section)
|
|
667
|
+
|
|
668
|
+
def format_budget_status(self) -> str:
    """Render ``Discovery: u/l | Execution: u/l``, or ``""`` without a budget guard."""
    guard = self.budget_guard
    if not guard:
        return ""
    status = guard.get_status()
    discovery, execution = status["discovery"], status["execution"]
    return (
        f"Discovery: {discovery['used']}/{discovery['limit']} | "
        f"Execution: {execution['used']}/{execution['limit']}"
    )
|
|
678
|
+
|
|
679
|
+
def format_bindings_for_model(self) -> str:
    """Return the binding manager's model-facing summary of all bindings."""
    manager = self.bindings
    return manager.format_for_model()
|
|
682
|
+
|
|
683
|
+
def get_duplicate_count(self) -> int:
    """Return the duplicate-call counter kept by the result cache."""
    cache = self.cache
    return cache.duplicate_count
|
|
686
|
+
|
|
687
|
+
def format_discovery_exhausted_message(self) -> str:
    """Build the message shown once the discovery budget is used up.

    The current state summary is embedded between the header and the
    instructions.
    """
    summary = self.format_state_for_model()
    header = (
        "**Discovery budget exhausted.** You have searched/listed tools "
        "enough times.\n\n"
    )
    footer = (
        "Please proceed with calling tools using the schemas you already have, "
        "or provide your answer using the computed values above."
    )
    return f"{header}{summary}\n\n{footer}"
|
|
696
|
+
|
|
697
|
+
def format_execution_exhausted_message(self) -> str:
    """Build the notice sent to the model once the execution budget is spent.

    The message is a bold headline saying no more tool calls are allowed,
    the current state summary from ``format_state_for_model()``, and a
    request for the final answer. Sections are separated by blank lines.
    """
    summary = self.format_state_for_model()
    pieces = (
        "**Execution budget exhausted.** No more tool calls allowed.\n\n",
        f"{summary}\n\n",
        "Please provide your final answer using the computed values above.",
    )
    return "".join(pieces)
|
|
705
|
+
|
|
706
|
+
def format_budget_exhausted_message(self) -> str:
    """Build the notice sent to the model once the total tool budget is spent.

    The message is a bold headline, the current state summary from
    ``format_state_for_model()``, and a request for the final answer.
    Sections are separated by blank lines.
    """
    summary = self.format_state_for_model()
    pieces = (
        "**Tool budget exhausted.** You have made the maximum allowed tool calls.\n\n",
        f"{summary}\n\n",
        "Please provide your final answer using the computed values above.",
    )
    return "".join(pieces)
|
|
714
|
+
|
|
715
|
+
def format_saturation_message(self, last_value: float) -> str:
    """Build the notice sent when numeric saturation has been detected.

    Args:
        last_value: The most recent value, interpolated verbatim into the
            headline.

    Returns:
        A headline with ``last_value``, an explanation line, the current
        state summary, and a request for the final answer. Sections are
        separated by blank lines.
    """
    summary = self.format_state_for_model()
    pieces = [
        f"**Numeric saturation detected.** Last value: {last_value}\n\n",
        "Values have converged or reached machine precision limits.\n\n",
        f"{summary}\n\n",
        "Please provide your final answer using the computed values above.",
    ]
    return "".join(pieces)
|
|
724
|
+
|
|
725
|
+
def format_unused_warning(self, max_names: int = 5) -> str:
    """Format a warning listing tool-result bindings that were never used.

    A binding counts as unused when its ``used`` attribute is falsy. Each
    listed binding is referenced as ``$<id>``. Bindings beyond the first
    ``max_names`` are summarised as ``(+N more)``.

    Args:
        max_names: Maximum number of bindings listed by name. Defaults to 5,
            the previously hard-coded limit. Negative values are treated
            as 0.

    Returns:
        The warning sentence, or ``""`` when there are no unused bindings.
    """
    unused = [b for b in self.bindings.bindings.values() if not b.used]
    if not unused:
        return ""

    # Clamp so a negative limit cannot turn into a "from the end" slice.
    limit = max(0, max_names)
    shown = unused[:limit]
    names = ", ".join(f"${b.id}" for b in shown)

    hidden = len(unused) - len(shown)
    if hidden:
        suffix = f"(+{hidden} more)"
        # Avoid a dangling leading space when no names are listed at all.
        names = f"{names} {suffix}" if names else suffix

    return f"Unused tool results: {names}. Consider using these values."
|
|
734
|
+
|
|
735
|
+
def extract_bindings_from_text(self, text: str) -> list[ValueBinding]:
    """Extract ``name = number`` assignments from assistant text as bindings.

    An assignment is an identifier followed by ``=`` and a numeric literal.
    The identifier is made of ASCII letters, Greek letters, digits and
    underscores, and does not start with a digit. The literal has an
    optional leading ``-``, an optional decimal part and an optional
    exponent.

    A match is skipped when the 10 characters preceding it contain ``==``,
    ``!=``, ``if ``, ``while `` or ``for ``, to avoid picking up comparisons
    and loop headers. Every accepted match is recorded through
    ``self.bindings.bind`` with ``tool_name="assistant_text"`` and the
    identifier as its alias.

    NOTE(review): the number capture stops at the first literal, so text
    such as ``x = 2 * 3`` binds ``x`` to 2.0.

    Returns:
        The bindings created by this call, in text order.
    """
    assignment_re = re.compile(
        r"([a-zA-Zα-ωΑ-Ω_][a-zA-Zα-ωΑ-Ω0-9_]*(?:_[a-zA-Z0-9]+)?)\s*=\s*"
        r"(-?\d+\.?\d*(?:[eE][+-]?\d+)?)"
    )
    skip_markers = ["==", "!=", "if ", "while ", "for "]
    extracted: list[ValueBinding] = []

    for m in assignment_re.finditer(text):
        alias = m.group(1)
        # A ValueError anywhere while handling this match skips just this match.
        try:
            number = float(m.group(2))

            lookbehind = text[max(0, m.start() - 10) : m.start()]
            if any(marker in lookbehind for marker in skip_markers):
                continue

            created = self.bindings.bind(
                tool_name="assistant_text",
                arguments={"source": "extracted"},
                value=number,
                aliases=[alias],
            )
            extracted.append(created)
            log.debug(
                f"Extracted binding: ${created.id} = {number} (alias: {alias})"
            )
        except ValueError:
            continue

    return extracted
|
|
769
|
+
|
|
770
|
+
# =========================================================================
|
|
771
|
+
# Lifecycle
|
|
772
|
+
# =========================================================================
|
|
773
|
+
|
|
774
|
+
def reset_for_new_prompt(self) -> None:
    """Reset the state that is scoped to a single user prompt.

    Resets the value bindings and empties the user literals, stated values
    and per-tool call counters. Then resets each guard attribute
    (budget, ungrounded, runaway, per-tool) that is set. Guards that are
    falsy, e.g. ``None``, are skipped. The result cache is left intact;
    ``clear()`` additionally resets it.
    """
    self.bindings.reset()
    for container in (self.user_literals, self.stated_values, self.tool_call_counts):
        container.clear()

    # Guard attributes are read one at a time, in the same order as before.
    for guard_attr in ("budget_guard", "ungrounded_guard", "runaway_guard", "per_tool_guard"):
        guard = getattr(self, guard_attr)
        if guard:
            guard.reset()

    log.debug("Reset tool state for new prompt")
|
|
791
|
+
|
|
792
|
+
def clear(self) -> None:
    """Clear all state, e.g. when a new conversation starts.

    Resets the result cache and then delegates to ``reset_for_new_prompt()``.
    That method resets the value bindings, user literals, stated values,
    per-tool call counts and every configured guard, so those are not
    reset separately here.
    """
    # The cache is the only state that survives reset_for_new_prompt().
    self.cache.reset()
    self.reset_for_new_prompt()
    log.debug("Tool state cleared")
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
# Global instance: the module-level ToolStateManager singleton.
# It stays None until get_tool_state() creates it lazily, and
# reset_tool_state() replaces it with a fresh instance.
_tool_state: ToolStateManager | None = None
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def get_tool_state() -> ToolStateManager:
    """Return the module-wide ToolStateManager, creating it on first access.

    Subsequent calls return the same instance until ``reset_tool_state()``
    replaces it.
    """
    global _tool_state
    manager = _tool_state
    if manager is None:
        manager = ToolStateManager()
        _tool_state = manager
    return manager
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
def reset_tool_state() -> None:
    """Reset the global tool state for a new conversation.

    If a manager already exists, its ``clear()`` is called, so any caller
    still holding a reference to it sees empty state. The global is then
    rebound to a brand-new ``ToolStateManager``, and later
    ``get_tool_state()`` calls return that new instance.
    """
    global _tool_state
    # Explicit None check: relying on truthiness would silently skip clear()
    # if the manager ever defined __len__/__bool__ and evaluated as falsy.
    if _tool_state is not None:
        _tool_state.clear()
    _tool_state = ToolStateManager()
|