chuk-ai-session-manager 0.7.1__py3-none-any.whl → 0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chuk_ai_session_manager/__init__.py +84 -40
- chuk_ai_session_manager/api/__init__.py +1 -1
- chuk_ai_session_manager/api/simple_api.py +53 -59
- chuk_ai_session_manager/exceptions.py +31 -17
- chuk_ai_session_manager/guards/__init__.py +118 -0
- chuk_ai_session_manager/guards/bindings.py +217 -0
- chuk_ai_session_manager/guards/cache.py +163 -0
- chuk_ai_session_manager/guards/manager.py +819 -0
- chuk_ai_session_manager/guards/models.py +498 -0
- chuk_ai_session_manager/guards/ungrounded.py +159 -0
- chuk_ai_session_manager/infinite_conversation.py +86 -79
- chuk_ai_session_manager/memory/__init__.py +247 -0
- chuk_ai_session_manager/memory/artifacts_bridge.py +469 -0
- chuk_ai_session_manager/memory/context_packer.py +347 -0
- chuk_ai_session_manager/memory/fault_handler.py +507 -0
- chuk_ai_session_manager/memory/manifest.py +307 -0
- chuk_ai_session_manager/memory/models.py +1084 -0
- chuk_ai_session_manager/memory/mutation_log.py +186 -0
- chuk_ai_session_manager/memory/pack_cache.py +206 -0
- chuk_ai_session_manager/memory/page_table.py +275 -0
- chuk_ai_session_manager/memory/prefetcher.py +192 -0
- chuk_ai_session_manager/memory/tlb.py +247 -0
- chuk_ai_session_manager/memory/vm_prompts.py +238 -0
- chuk_ai_session_manager/memory/working_set.py +574 -0
- chuk_ai_session_manager/models/__init__.py +21 -9
- chuk_ai_session_manager/models/event_source.py +3 -1
- chuk_ai_session_manager/models/event_type.py +10 -1
- chuk_ai_session_manager/models/session.py +103 -68
- chuk_ai_session_manager/models/session_event.py +69 -68
- chuk_ai_session_manager/models/session_metadata.py +9 -10
- chuk_ai_session_manager/models/session_run.py +21 -22
- chuk_ai_session_manager/models/token_usage.py +76 -76
- chuk_ai_session_manager/procedural_memory/__init__.py +70 -0
- chuk_ai_session_manager/procedural_memory/formatter.py +407 -0
- chuk_ai_session_manager/procedural_memory/manager.py +523 -0
- chuk_ai_session_manager/procedural_memory/models.py +371 -0
- chuk_ai_session_manager/sample_tools.py +79 -46
- chuk_ai_session_manager/session_aware_tool_processor.py +27 -16
- chuk_ai_session_manager/session_manager.py +238 -197
- chuk_ai_session_manager/session_prompt_builder.py +163 -111
- chuk_ai_session_manager/session_storage.py +45 -52
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.dist-info}/METADATA +79 -3
- chuk_ai_session_manager-0.8.dist-info/RECORD +45 -0
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.dist-info}/WHEEL +1 -1
- chuk_ai_session_manager-0.7.1.dist-info/RECORD +0 -22
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
# chuk_ai_session_manager/guards/models.py
|
|
2
|
+
"""Pydantic models for tool state management.
|
|
3
|
+
|
|
4
|
+
All state-related models in one place, fully type-safe.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from pydantic import BaseModel, Field
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def classify_value_type(value: Any) -> "ValueType":
    """Classify the type of a value for binding.

    Args:
        value: Raw value produced by a tool.

    Returns:
        ``ValueType.NUMBER`` for ints, floats and strings that ``float()``
        can parse; ``ValueType.STRING`` for any other string;
        ``ValueType.LIST`` / ``ValueType.OBJECT`` for lists and dicts;
        ``ValueType.UNKNOWN`` for everything else, including booleans.
    """
    # bool is a subclass of int, so it has to be ruled out before the numeric
    # check; otherwise True/False would be bound as the numbers 1 and 0.
    if isinstance(value, bool):
        return ValueType.UNKNOWN
    if isinstance(value, (int, float)):
        return ValueType.NUMBER
    if isinstance(value, str):
        try:
            float(value)
        except (ValueError, TypeError):
            return ValueType.STRING
        return ValueType.NUMBER
    if isinstance(value, list):
        return ValueType.LIST
    if isinstance(value, dict):
        return ValueType.OBJECT
    return ValueType.UNKNOWN
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def compute_args_hash(arguments: dict[str, Any]) -> str:
    """Return a short, key-order-independent fingerprint of tool arguments.

    The arguments are serialised as JSON with sorted keys (values that JSON
    cannot encode are passed through ``str``), hashed with SHA-256, and the
    first 16 hex characters of the digest are returned.
    """
    canonical = json.dumps(arguments, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ValueType(str, Enum):
    """Kinds of value a binding can hold.

    Members are ``str`` instances as well, so each one compares equal to its
    lowercase value.
    """

    # Scalars
    NUMBER = "number"
    STRING = "string"

    # Structured data
    JSON = "json"
    LIST = "list"
    OBJECT = "object"

    # Anything not covered above
    UNKNOWN = "unknown"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ToolClassification:
    """Central definitions for tool classification.

    Guards and managers should use these definitions rather than
    maintaining their own hardcoded sets.

    All lookups are made on the tool's base name: the text after the last
    ``.`` in the given name (the whole name when there is no dot),
    lower-cased.
    """

    # Discovery tools - search/list/get schemas (count against discovery budget)
    DISCOVERY_TOOLS: frozenset[str] = frozenset(
        {
            "list_tools",
            "search_tools",
            "get_tool_schema",
            "get_tool_schemas",
        }
    )

    # Idempotent math tools - safe to call multiple times, exempt from per-tool limits
    IDEMPOTENT_MATH_TOOLS: frozenset[str] = frozenset(
        {
            "add",
            "subtract",
            "multiply",
            "divide",
            "sqrt",
            "pow",
            "power",
            "log",
            "exp",
            "sin",
            "cos",
            "tan",
            "abs",
            "floor",
            "ceil",
            "round",
        }
    )

    # Parameterized tools - require computed input values (precondition guard)
    PARAMETERIZED_TOOLS: frozenset[str] = frozenset(
        {
            "normal_cdf",
            "normal_pdf",
            "normal_sf",
            "t_cdf",
            "t_sf",
            "t_test",
            "chi_cdf",
            "chi_sf",
            "chi_square",
        }
    )

    @staticmethod
    def _base_name(tool_name: str) -> str:
        """Return the lower-cased segment of *tool_name* after its last dot."""
        # rpartition yields the whole string as the last element when there
        # is no dot, so no separate "no namespace" branch is needed.
        return tool_name.rpartition(".")[2].lower()

    @classmethod
    def is_discovery_tool(cls, tool_name: str) -> bool:
        """Check if tool is a discovery tool."""
        return cls._base_name(tool_name) in cls.DISCOVERY_TOOLS

    @classmethod
    def is_idempotent_math_tool(cls, tool_name: str) -> bool:
        """Check if tool is an idempotent math tool."""
        return cls._base_name(tool_name) in cls.IDEMPOTENT_MATH_TOOLS

    @classmethod
    def is_parameterized_tool(cls, tool_name: str) -> bool:
        """Check if tool requires computed values."""
        return cls._base_name(tool_name) in cls.PARAMETERIZED_TOOLS
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class RuntimeMode(str, Enum):
    """Named presets for how strictly the runtime enforces its guards."""

    # Feels like ChatGPT/Claude UI - warn but allow
    SMOOTH = "smooth"

    # Best for solver/math/physics - hard enforcement
    STRICT = "strict"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class EnforcementLevel(str, Enum):
    """How firmly a guard or constraint is applied."""

    # No enforcement
    OFF = "off"

    # Proceed but log warning
    WARN = "warn"

    # Do not execute, return error
    BLOCK = "block"
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class CacheScope(str, Enum):
    """Lifetime of cached tool results."""

    # Cache per conversation turn
    TURN = "turn"

    # Cache for entire session
    SESSION = "session"
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class UnusedResultAction(str, Enum):
    """What to do when a tool result is never referenced."""

    # No enforcement
    OFF = "off"

    # Warn but continue
    WARN = "warn"

    # Block next tool call
    BLOCK_NEXT_TOOL = "block-next-tool"
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class ValueBinding(BaseModel):
    """A tool result stored under a stable value ID.

    Each tool result is given an ID (v1, v2, v3...) so that later tool calls
    can refer to it with the $vN syntax.
    """

    # --- identity and provenance ---
    id: str = Field(..., description="Value ID, e.g., 'v1', 'v2'")
    tool_name: str = Field(..., description="Name of the tool that produced this value")
    args_hash: str = Field(..., description="Hash of arguments for dedup")

    # --- payload ---
    raw_value: Any = Field(..., description="The raw value from the tool")
    value_type: ValueType = Field(..., description="Classified type of the value")
    timestamp: datetime = Field(default_factory=datetime.now)

    # --- usage tracking ---
    aliases: list[str] = Field(default_factory=list, description="Model-provided names")
    used: bool = Field(default=False, description="Has this value been referenced?")
    used_in: list[str] = Field(
        default_factory=list, description="Tool calls that used this"
    )

    model_config = {"arbitrary_types_allowed": True}

    @property
    def typed_value(self) -> Any:
        """Return the value coerced according to ``value_type``.

        NUMBER bindings are converted with ``float()``; if that conversion
        raises ``ValueError`` or ``TypeError`` the raw value is returned
        unchanged. Every other type is returned as stored.
        """
        if self.value_type != ValueType.NUMBER:
            return self.raw_value
        try:
            return float(self.raw_value)
        except (ValueError, TypeError):
            return self.raw_value

    def format_for_model(self) -> str:
        """Render this binding as a single ``$id = value # from tool`` line."""
        rendered = self._render_value()
        suffix = f" (aka {', '.join(self.aliases)})" if self.aliases else ""
        return f"${self.id} = {rendered}{suffix} # from {self.tool_name}"

    def _render_value(self) -> str:
        """Produce the compact, truncated text shown for the bound value."""
        kind = self.value_type
        raw = self.raw_value

        if kind == ValueType.NUMBER:
            number = self.typed_value
            if not isinstance(number, float):
                return str(number)
            magnitude = abs(number)
            # Very small or very large magnitudes use scientific notation.
            if magnitude < 0.0001 or magnitude > 10000:
                return f"{number:.6e}"
            return f"{number:.6f}"

        if kind == ValueType.STRING:
            text = str(raw)
            return f'"{text}"' if len(text) < 50 else f'"{text[:47]}..."'

        if kind == ValueType.LIST:
            if not isinstance(raw, list):
                return str(raw)[:50]
            if len(raw) == 0:
                return "[]"
            if len(raw) <= 3:
                return str(raw)[:60]
            return f"[{len(raw)} items]"

        if kind == ValueType.OBJECT:
            if not isinstance(raw, dict):
                return str(raw)[:50]
            keys = list(raw.keys())
            if len(keys) == 0:
                return "{}"
            if len(keys) <= 3:
                return f"{{keys: {keys}}}"
            return f"{{object with {len(keys)} keys}}"

        return str(raw)[:50]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class ReferenceCheckResult(BaseModel):
    """Outcome of validating the references found in tool arguments."""

    # Overall verdict; required.
    valid: bool = Field(..., description="Whether all references are valid")

    # Details, each defaulting to empty.
    missing_refs: list[str] = Field(
        description="References that don't exist",
        default_factory=list,
    )
    resolved_refs: dict[str, Any] = Field(
        description="ref -> resolved value",
        default_factory=dict,
    )
    message: str = Field(description="Human-readable message", default="")
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class PerToolCallStatus(BaseModel):
    """Per-tool call counter used for anti-thrash tracking."""

    tool_name: str

    # Number of calls recorded so far.
    call_count: int = 0

    max_calls: int = Field(
        description="Default max before requiring justification",
        default=3,
    )

    requires_justification: bool = False
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class UngroundedCallResult(BaseModel):
    """Outcome of checking whether a tool call is ungrounded.

    A call counts as ungrounded when all of the following hold:
    - its arguments contain numeric literals,
    - it contains no $vN references,
    - values exist that it could have referenced.
    """

    is_ungrounded: bool = False
    numeric_args: list[str] = Field(default_factory=list)
    has_bindings: bool = False
    message: str = ""
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class RunawayStatus(BaseModel):
    """Result of the runaway-detection checks."""

    should_stop: bool = False
    reason: str | None = None
    budget_exhausted: bool = False
    degenerate_detected: bool = False
    saturation_detected: bool = False
    calls_remaining: int = 0

    @property
    def message(self) -> str:
        """Return a user-facing explanation of why execution should stop.

        The first flag that is set wins, in the order budget, degenerate,
        saturation. With no flag set, ``reason`` is used, or a generic
        fallback when ``reason`` is empty.
        """
        candidates = (
            (
                self.budget_exhausted,
                f"Tool call budget exhausted ({self.calls_remaining} remaining). Use computed values to answer.",
            ),
            (
                self.degenerate_detected,
                "Degenerate output detected (0.0, 1.0, or repeating). Results have saturated.",
            ),
            (
                self.saturation_detected,
                "Numeric saturation detected. Values are at machine precision limits.",
            ),
        )
        for flagged, text in candidates:
            if flagged:
                return text
        return self.reason or "Unknown stop reason"
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class CachedToolResult(BaseModel):
    """One cached tool invocation: its arguments, result and bookkeeping."""

    tool_name: str
    arguments: dict[str, Any]
    result: Any
    timestamp: datetime = Field(default_factory=datetime.now)
    call_count: int = 1

    model_config = {"arbitrary_types_allowed": True}

    @property
    def signature(self) -> str:
        """Return ``tool_name:<arguments as key-sorted JSON>`` for this call."""
        serialized = json.dumps(self.arguments, sort_keys=True, default=str)
        return f"{self.tool_name}:{serialized}"

    @property
    def is_numeric(self) -> bool:
        """Tell whether the result is an int/float or a float-parsable string.

        Booleans count as numeric here because ``bool`` is a subclass of
        ``int``.
        """
        outcome = self.result
        if isinstance(outcome, (int, float)):
            return True
        if not isinstance(outcome, str):
            return False
        try:
            float(outcome)
        except (ValueError, TypeError):
            return False
        return True

    @property
    def numeric_value(self) -> float | None:
        """Return the result as a float, or ``None`` when it is not numeric."""
        outcome = self.result
        if isinstance(outcome, (int, float)):
            return float(outcome)
        if not isinstance(outcome, str):
            return None
        try:
            return float(outcome)
        except (ValueError, TypeError):
            return None

    def format_compact(self) -> str:
        """Render ``tool(args) = result`` for compact state display.

        Numeric results are shown with six decimals, switching to scientific
        notation for very small or very large magnitudes; other results are
        stringified and cut to 50 characters.
        """
        call = f"{self.tool_name}({self._format_args()})"

        number = self.numeric_value if self.is_numeric else None
        if number is not None:
            extreme = abs(number) < 0.0001 or abs(number) > 10000
            spec = ".6e" if extreme else ".6f"
            return f"{call} = {number:{spec}}"

        text = str(self.result)
        if len(text) > 50:
            text = text[:47] + "..."
        return f"{call} = {text}"

    def _format_args(self) -> str:
        """Render the arguments compactly for ``format_compact``.

        A lone numeric argument is shown bare; otherwise each argument is
        shown as ``key=value``, with long strings and non-scalar values
        elided as ``...``.
        """
        args = self.arguments
        if not args:
            return ""

        if len(args) == 1:
            (only,) = args.values()
            if isinstance(only, (int, float)):
                return str(only)

        def render(key: str, val: Any) -> str:
            if isinstance(val, (int, float)):
                return f"{key}={val}"
            if isinstance(val, str) and len(val) < 20:
                return f'{key}="{val}"'
            return f"{key}=..."

        return ", ".join(render(key, val) for key, val in args.items())
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
class NamedVariable(BaseModel):
    """A named numeric value taken from a tool result."""

    name: str
    value: float
    units: str | None = None

    # Where the value came from, when known.
    source_tool: str | None = None
    source_args: dict[str, Any] | None = None

    model_config = {"arbitrary_types_allowed": True}

    def format_compact(self) -> str:
        """Render ``name = value`` with six decimals, plus units when set."""
        base = f"{self.name} = {self.value:.6f}"
        return f"{base} {self.units}" if self.units else base
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class SoftBlockReason(str, Enum):
    """Causes for which a tool call can be soft-blocked."""

    # Argument / reference problems
    UNGROUNDED_ARGS = "ungrounded_args"
    MISSING_REFS = "missing_refs"

    # Budget and rate limits
    BUDGET_EXHAUSTED = "budget_exhausted"
    PER_TOOL_LIMIT = "per_tool_limit"

    # Unmet prerequisites
    MISSING_DEPENDENCY = "missing_dependency"
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class RepairAction(str, Enum):
    """Repairs the runtime can attempt on a soft-blocked call."""

    # Automatic repairs
    REBIND_FROM_EXISTING = "rebind_from_existing"
    COMPUTE_MISSING = "compute_missing"
    REWRITE_CALL = "rewrite_call"
    SYMBOLIC_FALLBACK = "symbolic_fallback"

    # Hand the decision back to the user
    ASK_USER = "ask_user"
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
class SoftBlock(BaseModel):
    """A blocked tool call that the runtime may still be able to repair."""

    reason: SoftBlockReason
    tool_name: str = ""
    arguments: dict[str, Any] = Field(default_factory=dict)
    message: str = ""

    # Repair bookkeeping.
    repair_attempts: int = 0
    max_repairs: int = 3

    model_config = {"arbitrary_types_allowed": True}

    @property
    def can_repair(self) -> bool:
        """Tell whether the repair budget still allows another attempt."""
        return self.max_repairs > self.repair_attempts

    @property
    def next_repair_action(self) -> RepairAction:
        """Choose the repair to try next, based on why the call was blocked.

        Ungrounded arguments are rebound from existing values; missing
        references or dependencies are computed; anything else is handed to
        the user.
        """
        cause = self.reason
        if cause == SoftBlockReason.UNGROUNDED_ARGS:
            return RepairAction.REBIND_FROM_EXISTING
        if cause in (SoftBlockReason.MISSING_REFS, SoftBlockReason.MISSING_DEPENDENCY):
            return RepairAction.COMPUTE_MISSING
        return RepairAction.ASK_USER
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class RuntimeLimits(BaseModel):
    """Limits and enforcement settings for the tool runtime.

    Two ready-made presets are available as alternate constructors:
    - ``smooth()``: feels like ChatGPT/Claude UI - warn but allow, auto-retry
    - ``strict()``: best for solver/math - hard enforcement, dataflow discipline
    """

    # Budget for discovery calls (search_tools, list_tools, get_tool_schema).
    discovery_budget: int = 5

    # Budgets for execution calls (call_tool) and for all tool calls together.
    execution_budget: int = 12
    tool_budget_total: int = 15

    # Anti-thrash cap per tool; 0 or a negative value disables the cap.
    per_tool_cap: int = 0

    # How long cached results live.
    cache_scope: CacheScope = CacheScope.TURN

    # Binding enforcement.
    require_bindings: EnforcementLevel = EnforcementLevel.WARN
    ungrounded_grace_calls: int = 1

    # Handling of results that are never used.
    unused_results: UnusedResultAction = UnusedResultAction.WARN

    @classmethod
    def smooth(cls) -> "RuntimeLimits":
        """Build the preset for a smooth, UI-like experience."""
        preset = {
            "discovery_budget": 6,
            "execution_budget": 15,
            "tool_budget_total": 20,
            "per_tool_cap": 0,  # Unlimited per-tool calls
            "cache_scope": CacheScope.TURN,
            "require_bindings": EnforcementLevel.WARN,
            "ungrounded_grace_calls": 2,
            "unused_results": UnusedResultAction.WARN,
        }
        return cls(**preset)

    @classmethod
    def strict(cls) -> "RuntimeLimits":
        """Build the preset for strict dataflow enforcement."""
        preset = {
            "discovery_budget": 4,
            "execution_budget": 10,
            "tool_budget_total": 12,
            "per_tool_cap": 0,  # Unlimited per-tool calls
            "cache_scope": CacheScope.TURN,
            "require_bindings": EnforcementLevel.BLOCK,
            "ungrounded_grace_calls": 0,
            "unused_results": UnusedResultAction.WARN,
        }
        return cls(**preset)
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# chuk_ai_session_manager/guards/ungrounded.py
|
|
2
|
+
"""Ungrounded call guard - detects missing $vN references.
|
|
3
|
+
|
|
4
|
+
Catches when the model passes numeric literals that should have been
|
|
5
|
+
the result of prior computation. Enforces dataflow discipline.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import re
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from pydantic import BaseModel, Field
|
|
15
|
+
|
|
16
|
+
# Import base classes from chuk-tool-processor
|
|
17
|
+
from chuk_tool_processor.guards import BaseGuard, EnforcementLevel, GuardResult
|
|
18
|
+
|
|
19
|
+
# Reference pattern: $v1, $v2, ${v1}, ${myalias}
# Matches a "$" followed by an identifier, with optional braces around the
# identifier. Capture group 1 is the bare name (e.g. "v1" or "myalias").
REFERENCE_PATTERN = re.compile(r"\$\{?([a-zA-Z_][a-zA-Z0-9_]*|v\d+)\}?")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class UngroundedGuardConfig(BaseModel):
    """Settings for the ungrounded-call guard."""

    # Number of ungrounded calls tolerated before the guard starts blocking.
    grace_calls: int = Field(default=1)

    # Enforcement level applied when an ungrounded call is seen.
    mode: EnforcementLevel = Field(default=EnforcementLevel.WARN)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UngroundedGuard(BaseGuard):
    """Guard that detects ungrounded numeric arguments.

    An ungrounded call is when:
    - Arguments contain numeric literals (int, float, or numeric strings)
      that were not supplied by the user
    - No $vN references exist in the arguments
    - The model should have used a computed value instead

    This catches the anti-pattern where a model passes a literal number
    that should have been the result of a prior computation.
    """

    def __init__(
        self,
        config: UngroundedGuardConfig | None = None,
        get_user_literals: Any = None,  # Callable[[], set[float]]
        get_bindings: Any = None,  # Callable[[], dict]
    ):
        """Create the guard.

        Args:
            config: Guard settings; a default ``UngroundedGuardConfig`` is
                used when omitted.
            get_user_literals: Optional zero-argument callable returning the
                numbers the user supplied; these are never flagged.
            get_bindings: Optional zero-argument callable returning the
                current bindings keyed by value ID; used only to list the
                available values in the message.
        """
        self.config = config or UngroundedGuardConfig()
        self._get_user_literals = get_user_literals
        self._get_bindings = get_bindings
        # Ungrounded calls seen since the last reset().
        self._ungrounded_count = 0

    def check(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> GuardResult:
        """Check if tool call has ungrounded numeric arguments.

        Args:
            tool_name: Name of the tool being called
            arguments: Arguments passed to the tool

        Returns:
            GuardResult - allow when the guard is off, when there are no
            non-user numeric arguments, or when any reference is present;
            otherwise BLOCK (mode is BLOCK and the grace allowance is used
            up) or WARN.
        """
        if self.config.mode == EnforcementLevel.OFF:
            return self.allow()

        # Get user-provided literals (these are allowed)
        user_literals = self._get_user_literals() if self._get_user_literals else set()

        # Find numeric arguments not from user
        numeric_args = self._find_numeric_args(arguments, user_literals)
        if not numeric_args:
            return self.allow()

        # Any $vN / ${alias} reference anywhere in the serialised arguments
        # is enough to treat the call as grounded.
        args_str = json.dumps(arguments, default=str)
        if REFERENCE_PATTERN.search(args_str):
            return self.allow()

        # Ungrounded call detected
        self._ungrounded_count += 1

        # Get available bindings for helpful message
        bindings = self._get_bindings() if self._get_bindings else {}
        has_bindings = bool(bindings)
        message = self._build_message(tool_name, numeric_args, bindings)

        # Block only in BLOCK mode and only once the grace calls are used up.
        if (
            self.config.mode == EnforcementLevel.BLOCK
            and self._ungrounded_count > self.config.grace_calls
        ):
            return self.block(
                reason=message,
                ungrounded_count=self._ungrounded_count,
                numeric_args=numeric_args,
                has_bindings=has_bindings,
            )

        return self.warn(
            reason=message,
            ungrounded_count=self._ungrounded_count,
            # NOTE(review): the "+ 1" counts the current call as still
            # remaining (grace_calls=1 reports 1 on the first warning);
            # confirm that this is what consumers of the value expect.
            grace_remaining=max(
                0, self.config.grace_calls - self._ungrounded_count + 1
            ),
            numeric_args=numeric_args,
            has_bindings=has_bindings,
        )

    def reset(self) -> None:
        """Reset for new prompt."""
        self._ungrounded_count = 0

    @staticmethod
    def _build_message(
        tool_name: str,
        numeric_args: dict[str, float],
        bindings: Any,
    ) -> str:
        """Compose the warning/block text for an ungrounded call."""
        arg_names = ", ".join(f"`{name}`" for name in numeric_args)
        if bindings:
            available = [f"${bid}" for bid in bindings.keys()]
            return (
                f"Ungrounded call: `{tool_name}` has numeric arguments ({arg_names}) "
                f"but no $vN references. Available values: {', '.join(available)}. "
                "Did you mean to use a computed value?"
            )
        return (
            f"Ungrounded call: `{tool_name}` has numeric arguments ({arg_names}) "
            "but no prior computations exist. Compute input values first."
        )

    def _find_numeric_args(
        self,
        arguments: dict[str, Any],
        user_literals: set[float],
    ) -> dict[str, float]:
        """Find numeric arguments not from user input.

        Args:
            arguments: Arguments passed to the tool.
            user_literals: Numbers the user supplied; matching values are
                not reported.

        Returns:
            Mapping of argument name to numeric value for every top-level
            argument that is an int, a float, or a string ``float()`` can
            parse, and that is not a user literal. Booleans and the
            ``tool_name`` key are ignored.
        """
        numeric = {}
        for key, value in arguments.items():
            if key == "tool_name":
                continue
            # bool is a subclass of int; flags are not numeric literals.
            if isinstance(value, bool):
                continue
            if isinstance(value, (int, float)):
                try:
                    as_float = float(value)
                except OverflowError:
                    # An int too large for a float cannot match any float
                    # user literal, so report it instead of crashing.
                    numeric[key] = value
                    continue
                if as_float not in user_literals:
                    numeric[key] = value
            elif isinstance(value, str):
                try:
                    num_val = float(value)
                except (ValueError, TypeError):
                    continue
                if num_val not in user_literals:
                    numeric[key] = num_val
        return numeric
|