jaf-py 2.5.10__py3-none-any.whl → 2.5.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +310 -210
  54. jaf/core/types.py +403 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +475 -283
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.12.dist-info/RECORD +97 -0
  88. jaf_py-2.5.10.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/top_level.txt +0 -0
jaf/core/guardrails.py CHANGED
@@ -31,7 +31,7 @@ from .types import (
31
31
  GuardrailEvent,
32
32
  GuardrailEventData,
33
33
  GuardrailViolationEvent,
34
- GuardrailViolationEventData
34
+ GuardrailViolationEventData,
35
35
  )
36
36
 
37
37
  # Constants for content length limits
@@ -48,7 +48,7 @@ OUTPUT_GUARDRAIL_TIMEOUT_MS = 15000
48
48
 
49
49
  class GuardrailCircuitBreaker:
50
50
  """Circuit breaker for guardrail execution to handle repeated failures."""
51
-
51
+
52
52
  def __init__(self, max_failures: int = 5, reset_time_ms: int = 60000):
53
53
  self.failures = 0
54
54
  self.last_failure_time = 0
@@ -59,12 +59,12 @@ class GuardrailCircuitBreaker:
59
59
  """Check if circuit breaker is open (blocking requests)."""
60
60
  if self.failures < self.max_failures:
61
61
  return False
62
-
62
+
63
63
  time_since_last_failure = (time.time() * 1000) - self.last_failure_time
64
64
  if time_since_last_failure > self.reset_time_ms:
65
65
  self.failures = 0
66
66
  return False
67
-
67
+
68
68
  return True
69
69
 
70
70
  def record_failure(self) -> None:
@@ -79,14 +79,17 @@ class GuardrailCircuitBreaker:
79
79
  def should_be_cleaned_up(self, max_age: int) -> bool:
80
80
  """Check if this circuit breaker should be cleaned up."""
81
81
  now = time.time() * 1000
82
- return (self.last_failure_time > 0 and
83
- (now - self.last_failure_time) > max_age and
84
- not self.is_open())
82
+ return (
83
+ self.last_failure_time > 0
84
+ and (now - self.last_failure_time) > max_age
85
+ and not self.is_open()
86
+ )
85
87
 
86
88
 
87
89
  @dataclass
88
90
  class CacheEntry:
89
91
  """Cache entry for guardrail results."""
92
+
90
93
  result: ValidationResult
91
94
  timestamp: float
92
95
  hit_count: int = 1
@@ -94,7 +97,7 @@ class CacheEntry:
94
97
 
95
98
  class GuardrailCache:
96
99
  """LRU cache for guardrail results."""
97
-
100
+
98
101
  def __init__(self, max_size: int = 1000, ttl_ms: int = 300000):
99
102
  self.cache: Dict[str, CacheEntry] = {}
100
103
  self.max_size = max_size
@@ -105,7 +108,7 @@ class GuardrailCache:
105
108
  content_hash = self._hash_string(content[:1000])
106
109
  rule_hash = self._hash_string(rule_prompt)
107
110
  return f"guardrail_{stage}_{model_name}_{rule_hash}_{content_hash}_{len(content)}"
108
-
111
+
109
112
  def _hash_string(self, s: str) -> str:
110
113
  """Simple hash function for strings."""
111
114
  hash_val = 0
@@ -122,47 +125,47 @@ class GuardrailCache:
122
125
  """Evict least recently used entry."""
123
126
  if len(self.cache) < self.max_size:
124
127
  return
125
-
128
+
126
129
  lru_key: Optional[str] = None
127
- lru_score = float('inf')
130
+ lru_score = float("inf")
128
131
  now = time.time() * 1000
129
-
132
+
130
133
  for key, entry in self.cache.items():
131
134
  age_hours = (now - entry.timestamp) / (1000 * 60 * 60)
132
135
  score = entry.hit_count / (1 + age_hours)
133
136
  if score < lru_score:
134
137
  lru_score = score
135
138
  lru_key = key
136
-
139
+
137
140
  if lru_key:
138
141
  del self.cache[lru_key]
139
142
 
140
- def get(self, stage: str, rule_prompt: str, content: str, model_name: str) -> Optional[ValidationResult]:
143
+ def get(
144
+ self, stage: str, rule_prompt: str, content: str, model_name: str
145
+ ) -> Optional[ValidationResult]:
141
146
  """Get cached result."""
142
147
  key = self._create_key(stage, rule_prompt, content, model_name)
143
148
  entry = self.cache.get(key)
144
-
149
+
145
150
  if not entry or self._is_expired(entry):
146
151
  if entry:
147
152
  del self.cache[key]
148
153
  return None
149
-
154
+
150
155
  entry.hit_count += 1
151
156
  entry.timestamp = time.time() * 1000
152
-
157
+
153
158
  return entry.result
154
159
 
155
- def set(self, stage: str, rule_prompt: str, content: str, model_name: str, result: ValidationResult) -> None:
160
+ def set(
161
+ self, stage: str, rule_prompt: str, content: str, model_name: str, result: ValidationResult
162
+ ) -> None:
156
163
  """Cache a result."""
157
164
  key = self._create_key(stage, rule_prompt, content, model_name)
158
-
165
+
159
166
  self._evict_lru()
160
-
161
- self.cache[key] = CacheEntry(
162
- result=result,
163
- timestamp=time.time() * 1000,
164
- hit_count=1
165
- )
167
+
168
+ self.cache[key] = CacheEntry(result=result, timestamp=time.time() * 1000, hit_count=1)
166
169
 
167
170
  def clear(self) -> None:
168
171
  """Clear all cached entries."""
@@ -170,10 +173,7 @@ class GuardrailCache:
170
173
 
171
174
  def get_stats(self) -> Dict[str, Any]:
172
175
  """Get cache statistics."""
173
- return {
174
- 'size': len(self.cache),
175
- 'max_size': self.max_size
176
- }
176
+ return {"size": len(self.cache), "max_size": self.max_size}
177
177
 
178
178
 
179
179
  # Global instances
@@ -202,19 +202,26 @@ async def _create_llm_guardrail(
202
202
  stage: str,
203
203
  rule_prompt: str,
204
204
  fast_model: Optional[str] = None,
205
- fail_safe: str = 'allow',
206
- timeout_ms: int = 30000
205
+ fail_safe: str = "allow",
206
+ timeout_ms: int = 30000,
207
207
  ) -> Guardrail:
208
208
  """Create an LLM-based guardrail function."""
209
-
209
+
210
210
  async def guardrail_func(content: Any) -> ValidationResult:
211
211
  content_str = str(content) if not isinstance(content, str) else content
212
-
212
+
213
213
  model_to_use = fast_model or config.default_fast_model
214
214
  if not model_to_use:
215
- print(f"[JAF:GUARDRAILS] No fast model available for LLM guardrail evaluation, using failSafe: {fail_safe}")
216
- return (ValidValidationResult() if fail_safe == 'allow'
217
- else InvalidValidationResult(error_message='No model available for guardrail evaluation'))
215
+ print(
216
+ f"[JAF:GUARDRAILS] No fast model available for LLM guardrail evaluation, using failSafe: {fail_safe}"
217
+ )
218
+ return (
219
+ ValidValidationResult()
220
+ if fail_safe == "allow"
221
+ else InvalidValidationResult(
222
+ error_message="No model available for guardrail evaluation"
223
+ )
224
+ )
218
225
 
219
226
  # Check cache first
220
227
  cached_result = _guardrail_cache.get(stage, rule_prompt, content_str, model_to_use)
@@ -225,35 +232,52 @@ async def _create_llm_guardrail(
225
232
  # Check circuit breaker
226
233
  circuit_breaker = _get_circuit_breaker(stage, model_to_use)
227
234
  if circuit_breaker.is_open():
228
- print(f"[JAF:GUARDRAILS] Circuit breaker open for {stage} guardrail on model {model_to_use}, using failSafe: {fail_safe}")
229
- return (ValidValidationResult() if fail_safe == 'allow'
230
- else InvalidValidationResult(error_message='Circuit breaker open - too many recent failures'))
235
+ print(
236
+ f"[JAF:GUARDRAILS] Circuit breaker open for {stage} guardrail on model {model_to_use}, using failSafe: {fail_safe}"
237
+ )
238
+ return (
239
+ ValidValidationResult()
240
+ if fail_safe == "allow"
241
+ else InvalidValidationResult(
242
+ error_message="Circuit breaker open - too many recent failures"
243
+ )
244
+ )
231
245
 
232
246
  # Validate content
233
247
  if not content_str:
234
248
  print(f"[JAF:GUARDRAILS] Invalid content provided to {stage} guardrail")
235
- return (ValidValidationResult() if fail_safe == 'allow'
236
- else InvalidValidationResult(error_message='Invalid content provided to guardrail'))
249
+ return (
250
+ ValidValidationResult()
251
+ if fail_safe == "allow"
252
+ else InvalidValidationResult(error_message="Invalid content provided to guardrail")
253
+ )
237
254
 
238
255
  # Check content length
239
- max_content_length = SHORT_TIMEOUT_MAX_CONTENT if timeout_ms < 10000 else LONG_TIMEOUT_MAX_CONTENT
256
+ max_content_length = (
257
+ SHORT_TIMEOUT_MAX_CONTENT if timeout_ms < 10000 else LONG_TIMEOUT_MAX_CONTENT
258
+ )
240
259
  if len(content_str) > max_content_length:
241
- print(f"[JAF:GUARDRAILS] Content too large for {stage} guardrail ({len(content_str)} chars, max: {max_content_length})")
242
- return (ValidValidationResult() if fail_safe == 'allow'
243
- else InvalidValidationResult(error_message=f'Content too large for guardrail evaluation ({len(content_str)} > {max_content_length} chars)'))
260
+ print(
261
+ f"[JAF:GUARDRAILS] Content too large for {stage} guardrail ({len(content_str)} chars, max: {max_content_length})"
262
+ )
263
+ return (
264
+ ValidValidationResult()
265
+ if fail_safe == "allow"
266
+ else InvalidValidationResult(
267
+ error_message=f"Content too large for guardrail evaluation ({len(content_str)} > {max_content_length} chars)"
268
+ )
269
+ )
244
270
 
245
271
  # Prepare content for evaluation
246
- sanitized_content = (content_str
247
- .replace('"""', '[TRIPLE_QUOTE]')
248
- .replace('\n', ' ')[:2000])
249
-
272
+ sanitized_content = content_str.replace('"""', "[TRIPLE_QUOTE]").replace("\n", " ")[:2000]
273
+
250
274
  eval_prompt = f"""You are a guardrail validator for {stage}.
251
275
  Rules:
252
276
  {rule_prompt}
253
277
 
254
- Decide if the {stage} {'user message' if stage == 'input' else 'assistant output'} complies with the rules.
278
+ Decide if the {stage} {"user message" if stage == "input" else "assistant output"} complies with the rules.
255
279
  Return a JSON object with keys: {{"allowed": boolean, "reason": string}}. Do not include extra text.
256
- {stage.capitalize()} {'user message' if stage == 'input' else 'assistant output'}:
280
+ {stage.capitalize()} {"user message" if stage == "input" else "assistant output"}:
257
281
  \"\"\"
258
282
  {sanitized_content}
259
283
  \"\"\""""
@@ -261,22 +285,22 @@ Return a JSON object with keys: {{"allowed": boolean, "reason": string}}. Do not
261
285
  try:
262
286
  # Create temporary state for guardrail evaluation
263
287
  temp_state = RunState(
264
- run_id=create_run_id('guardrail-eval'),
265
- trace_id=create_trace_id('guardrail-eval'),
288
+ run_id=create_run_id("guardrail-eval"),
289
+ trace_id=create_trace_id("guardrail-eval"),
266
290
  messages=[Message(role=ContentRole.USER, content=eval_prompt)],
267
- current_agent_name='guardrail-evaluator',
291
+ current_agent_name="guardrail-evaluator",
268
292
  context={},
269
- turn_count=0
293
+ turn_count=0,
270
294
  )
271
295
 
272
296
  # Create evaluation agent
273
297
  def eval_instructions(state: RunState) -> str:
274
- return 'You are a guardrail validator. Return only valid JSON.'
298
+ return "You are a guardrail validator. Return only valid JSON."
275
299
 
276
300
  eval_agent = Agent(
277
- name='guardrail-evaluator',
301
+ name="guardrail-evaluator",
278
302
  instructions=eval_instructions,
279
- model_config={'name': model_to_use} if hasattr(config, 'ModelConfig') else None
303
+ model_config={"name": model_to_use} if hasattr(config, "ModelConfig") else None,
280
304
  )
281
305
 
282
306
  # Create guardrail config (no guardrails to avoid recursion)
@@ -289,28 +313,30 @@ Return a JSON object with keys: {{"allowed": boolean, "reason": string}}. Do not
289
313
  initial_input_guardrails=None,
290
314
  final_output_guardrails=None,
291
315
  on_event=None,
292
- prefer_streaming=config.prefer_streaming
316
+ prefer_streaming=config.prefer_streaming,
293
317
  )
294
318
 
295
319
  # Execute with timeout
296
- completion_promise = config.model_provider.get_completion(temp_state, eval_agent, guardrail_config)
320
+ completion_promise = config.model_provider.get_completion(
321
+ temp_state, eval_agent, guardrail_config
322
+ )
297
323
  response = await _with_timeout(
298
324
  completion_promise,
299
325
  timeout_ms,
300
- f"{stage} guardrail evaluation timed out after {timeout_ms}ms"
326
+ f"{stage} guardrail evaluation timed out after {timeout_ms}ms",
301
327
  )
302
328
 
303
329
  # Handle different response formats
304
330
  response_content = None
305
- if hasattr(response, 'message') and response.message:
306
- if hasattr(response.message, 'content'):
331
+ if hasattr(response, "message") and response.message:
332
+ if hasattr(response.message, "content"):
307
333
  response_content = response.message.content
308
334
  elif isinstance(response, dict):
309
- if 'message' in response and response['message']:
310
- if isinstance(response['message'], dict) and 'content' in response['message']:
311
- response_content = response['message']['content']
312
- elif hasattr(response['message'], 'content'):
313
- response_content = response['message'].content
335
+ if "message" in response and response["message"]:
336
+ if isinstance(response["message"], dict) and "content" in response["message"]:
337
+ response_content = response["message"]["content"]
338
+ elif hasattr(response["message"], "content"):
339
+ response_content = response["message"].content
314
340
 
315
341
  if not response_content:
316
342
  circuit_breaker.record_success()
@@ -320,62 +346,76 @@ Return a JSON object with keys: {{"allowed": boolean, "reason": string}}. Do not
320
346
 
321
347
  # Parse response
322
348
  parsed = json_parse_llm_output(response_content)
323
- allowed = bool(parsed.get('allowed', True) if parsed else True)
324
- reason = str(parsed.get('reason', 'Guardrail violation') if parsed else 'Guardrail violation')
325
-
349
+ allowed = bool(parsed.get("allowed", True) if parsed else True)
350
+ reason = str(
351
+ parsed.get("reason", "Guardrail violation") if parsed else "Guardrail violation"
352
+ )
353
+
326
354
  circuit_breaker.record_success()
327
-
328
- result = (ValidValidationResult() if allowed
329
- else InvalidValidationResult(error_message=reason))
330
-
355
+
356
+ result = (
357
+ ValidValidationResult()
358
+ if allowed
359
+ else InvalidValidationResult(error_message=reason)
360
+ )
361
+
331
362
  _guardrail_cache.set(stage, rule_prompt, content_str, model_to_use, result)
332
363
  return result
333
364
 
334
365
  except Exception as e:
335
366
  circuit_breaker.record_failure()
336
-
367
+
337
368
  error_message = str(e)
338
- is_timeout = 'Timeout' in error_message
339
-
369
+ is_timeout = "Timeout" in error_message
370
+
340
371
  log_message = f"[JAF:GUARDRAILS] {stage} guardrail evaluation failed"
341
372
  if is_timeout:
342
373
  print(f"{log_message} due to timeout ({timeout_ms}ms), using failSafe: {fail_safe}")
343
374
  else:
344
375
  print(f"{log_message}, using failSafe: {fail_safe} - {error_message}")
345
-
346
- return (ValidValidationResult() if fail_safe == 'allow'
347
- else InvalidValidationResult(error_message=f'Guardrail evaluation failed: {error_message}'))
376
+
377
+ return (
378
+ ValidValidationResult()
379
+ if fail_safe == "allow"
380
+ else InvalidValidationResult(
381
+ error_message=f"Guardrail evaluation failed: {error_message}"
382
+ )
383
+ )
348
384
 
349
385
  return guardrail_func
350
386
 
351
387
 
352
388
  async def build_effective_guardrails(
353
- current_agent: Agent,
354
- config: RunConfig
389
+ current_agent: Agent, config: RunConfig
355
390
  ) -> Tuple[List[Guardrail], List[Guardrail]]:
356
391
  """Build effective input and output guardrails for an agent."""
357
392
  effective_input_guardrails: List[Guardrail] = []
358
393
  effective_output_guardrails: List[Guardrail] = []
359
-
394
+
360
395
  try:
361
- raw_guardrails_cfg = (current_agent.advanced_config.guardrails
362
- if current_agent.advanced_config
363
- else None)
396
+ raw_guardrails_cfg = (
397
+ current_agent.advanced_config.guardrails if current_agent.advanced_config else None
398
+ )
364
399
  guardrails_cfg = validate_guardrails_config(raw_guardrails_cfg)
365
400
 
366
401
  fast_model = guardrails_cfg.fast_model or config.default_fast_model
367
402
  if not fast_model and (guardrails_cfg.input_prompt or guardrails_cfg.output_prompt):
368
- print('[JAF:GUARDRAILS] No fast model available for LLM guardrails - skipping LLM-based validation')
369
-
370
- print('[JAF:GUARDRAILS] Configuration:', {
371
- 'hasInputPrompt': bool(guardrails_cfg.input_prompt),
372
- 'hasOutputPrompt': bool(guardrails_cfg.output_prompt),
373
- 'requireCitations': guardrails_cfg.require_citations,
374
- 'executionMode': guardrails_cfg.execution_mode,
375
- 'failSafe': guardrails_cfg.fail_safe,
376
- 'timeoutMs': guardrails_cfg.timeout_ms,
377
- 'fastModel': fast_model or 'none'
378
- })
403
+ print(
404
+ "[JAF:GUARDRAILS] No fast model available for LLM guardrails - skipping LLM-based validation"
405
+ )
406
+
407
+ print(
408
+ "[JAF:GUARDRAILS] Configuration:",
409
+ {
410
+ "hasInputPrompt": bool(guardrails_cfg.input_prompt),
411
+ "hasOutputPrompt": bool(guardrails_cfg.output_prompt),
412
+ "requireCitations": guardrails_cfg.require_citations,
413
+ "executionMode": guardrails_cfg.execution_mode,
414
+ "failSafe": guardrails_cfg.fail_safe,
415
+ "timeoutMs": guardrails_cfg.timeout_ms,
416
+ "fastModel": fast_model or "none",
417
+ },
418
+ )
379
419
 
380
420
  # Start with global guardrails
381
421
  effective_input_guardrails = list(config.initial_input_guardrails or [])
@@ -384,41 +424,55 @@ async def build_effective_guardrails(
384
424
  # Add input prompt guardrail
385
425
  if guardrails_cfg.input_prompt and guardrails_cfg.input_prompt.strip():
386
426
  input_guardrail = await _create_llm_guardrail(
387
- config, 'input', guardrails_cfg.input_prompt,
388
- fast_model, guardrails_cfg.fail_safe, guardrails_cfg.timeout_ms
427
+ config,
428
+ "input",
429
+ guardrails_cfg.input_prompt,
430
+ fast_model,
431
+ guardrails_cfg.fail_safe,
432
+ guardrails_cfg.timeout_ms,
389
433
  )
390
434
  effective_input_guardrails.append(input_guardrail)
391
435
 
392
436
  # Add citation requirement guardrail
393
437
  if guardrails_cfg.require_citations:
438
+
394
439
  def citation_guardrail(output: Any) -> ValidationResult:
395
440
  def find_text(val: Any) -> str:
396
441
  if isinstance(val, str):
397
442
  return val
398
443
  elif isinstance(val, list):
399
- return ' '.join(find_text(item) for item in val)
444
+ return " ".join(find_text(item) for item in val)
400
445
  elif isinstance(val, dict):
401
- return ' '.join(find_text(v) for v in val.values())
446
+ return " ".join(find_text(v) for v in val.values())
402
447
  else:
403
448
  return str(val)
404
-
449
+
405
450
  text = find_text(output)
406
- has_citation = bool(re.search(r'\[(\d+)\]', text))
407
- return (ValidValidationResult() if has_citation
408
- else InvalidValidationResult(error_message="Missing required [n] citation in output"))
409
-
451
+ has_citation = bool(re.search(r"\[(\d+)\]", text))
452
+ return (
453
+ ValidValidationResult()
454
+ if has_citation
455
+ else InvalidValidationResult(
456
+ error_message="Missing required [n] citation in output"
457
+ )
458
+ )
459
+
410
460
  effective_output_guardrails.append(citation_guardrail)
411
461
 
412
462
  # Add output prompt guardrail
413
463
  if guardrails_cfg.output_prompt and guardrails_cfg.output_prompt.strip():
414
464
  output_guardrail = await _create_llm_guardrail(
415
- config, 'output', guardrails_cfg.output_prompt,
416
- fast_model, guardrails_cfg.fail_safe, guardrails_cfg.timeout_ms
465
+ config,
466
+ "output",
467
+ guardrails_cfg.output_prompt,
468
+ fast_model,
469
+ guardrails_cfg.fail_safe,
470
+ guardrails_cfg.timeout_ms,
417
471
  )
418
472
  effective_output_guardrails.append(output_guardrail)
419
473
 
420
474
  except Exception as e:
421
- print(f'[JAF:GUARDRAILS] Failed to configure advanced guardrails: {e}')
475
+ print(f"[JAF:GUARDRAILS] Failed to configure advanced guardrails: {e}")
422
476
  # Fall back to global guardrails only
423
477
  effective_input_guardrails = list(config.initial_input_guardrails or [])
424
478
  effective_output_guardrails = list(config.final_output_guardrails or [])
@@ -427,192 +481,212 @@ async def build_effective_guardrails(
427
481
 
428
482
 
429
483
  async def execute_input_guardrails_sequential(
430
- input_guardrails: List[Guardrail],
431
- first_user_message: Message,
432
- config: RunConfig
484
+ input_guardrails: List[Guardrail], first_user_message: Message, config: RunConfig
433
485
  ) -> ValidationResult:
434
486
  """Execute input guardrails sequentially."""
435
487
  if not input_guardrails:
436
488
  return ValidValidationResult()
437
489
 
438
490
  print(f"[JAF:GUARDRAILS] Starting {len(input_guardrails)} input guardrails (sequential)")
439
-
491
+
440
492
  content = get_text_content(first_user_message.content)
441
-
493
+
442
494
  for i, guardrail in enumerate(input_guardrails):
443
495
  guardrail_name = f"input-guardrail-{i + 1}"
444
-
496
+
445
497
  try:
446
498
  print(f"[JAF:GUARDRAILS] Starting {guardrail_name}")
447
-
499
+
448
500
  timeout_ms = GUARDRAIL_TIMEOUT_MS
449
501
  result = await _with_timeout(
450
- guardrail(content) if asyncio.iscoroutinefunction(guardrail) else guardrail(content),
502
+ guardrail(content)
503
+ if asyncio.iscoroutinefunction(guardrail)
504
+ else guardrail(content),
451
505
  timeout_ms,
452
- f"{guardrail_name} execution timed out after {timeout_ms}ms"
506
+ f"{guardrail_name} execution timed out after {timeout_ms}ms",
453
507
  )
454
-
508
+
455
509
  print(f"[JAF:GUARDRAILS] {guardrail_name} completed: {result}")
456
-
510
+
457
511
  if not result.is_valid:
458
- error_message = getattr(result, 'error_message', 'Guardrail violation')
512
+ error_message = getattr(result, "error_message", "Guardrail violation")
459
513
  print(f"🚨 {guardrail_name} violation: {error_message}")
460
514
  if config.on_event:
461
- config.on_event(GuardrailViolationEvent(
462
- data=GuardrailViolationEventData(stage='input', reason=error_message)
463
- ))
515
+ config.on_event(
516
+ GuardrailViolationEvent(
517
+ data=GuardrailViolationEventData(stage="input", reason=error_message)
518
+ )
519
+ )
464
520
  return result
465
-
521
+
466
522
  except Exception as error:
467
523
  error_message = str(error)
468
524
  print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")
469
-
470
- is_system_error = 'Timeout' in error_message or 'Circuit breaker' in error_message
471
-
525
+
526
+ is_system_error = "Timeout" in error_message or "Circuit breaker" in error_message
527
+
472
528
  if is_system_error:
473
- print(f"[JAF:GUARDRAILS] {guardrail_name} system error, continuing: {error_message}")
529
+ print(
530
+ f"[JAF:GUARDRAILS] {guardrail_name} system error, continuing: {error_message}"
531
+ )
474
532
  continue
475
533
  else:
476
534
  if config.on_event:
477
- config.on_event(GuardrailViolationEvent(
478
- data=GuardrailViolationEventData(stage='input', reason=error_message)
479
- ))
535
+ config.on_event(
536
+ GuardrailViolationEvent(
537
+ data=GuardrailViolationEventData(stage="input", reason=error_message)
538
+ )
539
+ )
480
540
  return InvalidValidationResult(error_message=error_message)
481
-
541
+
482
542
  print("✅ All input guardrails passed (sequential).")
483
543
  return ValidValidationResult()
484
544
 
485
545
 
486
546
  async def execute_input_guardrails_parallel(
487
- input_guardrails: List[Guardrail],
488
- first_user_message: Message,
489
- config: RunConfig
547
+ input_guardrails: List[Guardrail], first_user_message: Message, config: RunConfig
490
548
  ) -> ValidationResult:
491
549
  """Execute input guardrails in parallel."""
492
550
  if not input_guardrails:
493
551
  return ValidValidationResult()
494
552
 
495
553
  print(f"[JAF:GUARDRAILS] Starting {len(input_guardrails)} input guardrails")
496
-
554
+
497
555
  content = get_text_content(first_user_message.content)
498
-
556
+
499
557
  async def run_guardrail(guardrail: Guardrail, index: int):
500
558
  guardrail_name = f"input-guardrail-{index + 1}"
501
-
559
+
502
560
  try:
503
561
  print(f"[JAF:GUARDRAILS] Starting {guardrail_name}")
504
-
505
- timeout_ms = DEFAULT_FAST_MODEL_TIMEOUT_MS if config.default_fast_model else DEFAULT_TIMEOUT_MS
506
-
562
+
563
+ timeout_ms = (
564
+ DEFAULT_FAST_MODEL_TIMEOUT_MS if config.default_fast_model else DEFAULT_TIMEOUT_MS
565
+ )
566
+
507
567
  if asyncio.iscoroutinefunction(guardrail):
508
- result = await _with_timeout(guardrail(content), timeout_ms,
509
- f"{guardrail_name} execution timed out after {timeout_ms}ms")
568
+ result = await _with_timeout(
569
+ guardrail(content),
570
+ timeout_ms,
571
+ f"{guardrail_name} execution timed out after {timeout_ms}ms",
572
+ )
510
573
  else:
511
574
  result = guardrail(content)
512
-
575
+
513
576
  print(f"[JAF:GUARDRAILS] {guardrail_name} completed: {result}")
514
- return {'result': result, 'guardrail_index': index}
515
-
577
+ return {"result": result, "guardrail_index": index}
578
+
516
579
  except Exception as error:
517
580
  error_message = str(error)
518
581
  print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")
519
-
582
+
520
583
  return {
521
- 'result': ValidValidationResult(),
522
- 'guardrail_index': index,
523
- 'warning': f"Guardrail {index + 1} failed but was skipped: {error_message}"
584
+ "result": ValidValidationResult(),
585
+ "guardrail_index": index,
586
+ "warning": f"Guardrail {index + 1} failed but was skipped: {error_message}",
524
587
  }
525
-
588
+
526
589
  try:
527
590
  # Run all guardrails in parallel
528
591
  tasks = [run_guardrail(guardrail, i) for i, guardrail in enumerate(input_guardrails)]
529
592
  results = await asyncio.gather(*tasks, return_exceptions=True)
530
-
593
+
531
594
  print("[JAF:GUARDRAILS] Input guardrails completed. Checking results...")
532
-
595
+
533
596
  warnings = []
534
-
597
+
535
598
  for i, result in enumerate(results):
536
599
  if isinstance(result, Exception):
537
600
  error_message = str(result)
538
601
  print(f"[JAF:GUARDRAILS] Input guardrail {i + 1} promise rejected: {error_message}")
539
602
  warnings.append(f"Guardrail {i + 1} failed: {error_message}")
540
603
  continue
541
-
542
- if 'warning' in result:
543
- warnings.append(result['warning'])
544
-
545
- validation_result = result['result']
604
+
605
+ if "warning" in result:
606
+ warnings.append(result["warning"])
607
+
608
+ validation_result = result["result"]
546
609
  if not validation_result.is_valid:
547
- error_message = getattr(validation_result, 'error_message', 'Guardrail violation')
548
- print(f"🚨 Input guardrail {result['guardrail_index'] + 1} violation: {error_message}")
610
+ error_message = getattr(validation_result, "error_message", "Guardrail violation")
611
+ print(
612
+ f"🚨 Input guardrail {result['guardrail_index'] + 1} violation: {error_message}"
613
+ )
549
614
  if config.on_event:
550
- config.on_event(GuardrailViolationEvent(
551
- data=GuardrailViolationEventData(stage='input', reason=error_message)
552
- ))
615
+ config.on_event(
616
+ GuardrailViolationEvent(
617
+ data=GuardrailViolationEventData(stage="input", reason=error_message)
618
+ )
619
+ )
553
620
  return validation_result
554
-
621
+
555
622
  if warnings:
556
623
  print(f"[JAF:GUARDRAILS] {len(warnings)} guardrail warnings: {warnings}")
557
-
624
+
558
625
  print("✅ All input guardrails passed.")
559
626
  return ValidValidationResult()
560
-
627
+
561
628
  except Exception as error:
562
629
  print(f"[JAF:GUARDRAILS] Catastrophic failure in input guardrail execution: {error}")
563
630
  return ValidValidationResult() # Fail gracefully
564
631
 
565
632
 
566
633
  async def execute_output_guardrails(
567
- output_guardrails: List[Guardrail],
568
- output: Any,
569
- config: RunConfig
634
+ output_guardrails: List[Guardrail], output: Any, config: RunConfig
570
635
  ) -> ValidationResult:
571
636
  """Execute output guardrails sequentially."""
572
637
  if not output_guardrails:
573
638
  return ValidValidationResult()
574
639
 
575
640
  print(f"[JAF:GUARDRAILS] Checking {len(output_guardrails)} output guardrails")
576
-
641
+
577
642
  for i, guardrail in enumerate(output_guardrails):
578
643
  guardrail_name = f"output-guardrail-{i + 1}"
579
-
644
+
580
645
  try:
581
646
  timeout_ms = OUTPUT_GUARDRAIL_TIMEOUT_MS
582
-
647
+
583
648
  if asyncio.iscoroutinefunction(guardrail):
584
- result = await _with_timeout(guardrail(output), timeout_ms,
585
- f"{guardrail_name} execution timed out after {timeout_ms}ms")
649
+ result = await _with_timeout(
650
+ guardrail(output),
651
+ timeout_ms,
652
+ f"{guardrail_name} execution timed out after {timeout_ms}ms",
653
+ )
586
654
  else:
587
655
  result = guardrail(output)
588
-
656
+
589
657
  if not result.is_valid:
590
- error_message = getattr(result, 'error_message', 'Guardrail violation')
658
+ error_message = getattr(result, "error_message", "Guardrail violation")
591
659
  print(f"🚨 {guardrail_name} violation: {error_message}")
592
660
  if config.on_event:
593
- config.on_event(GuardrailViolationEvent(
594
- data=GuardrailViolationEventData(stage='output', reason=error_message)
595
- ))
661
+ config.on_event(
662
+ GuardrailViolationEvent(
663
+ data=GuardrailViolationEventData(stage="output", reason=error_message)
664
+ )
665
+ )
596
666
  return result
597
-
667
+
598
668
  print(f"✅ {guardrail_name} passed")
599
-
669
+
600
670
  except Exception as error:
601
671
  error_message = str(error)
602
672
  print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")
603
-
604
- is_system_error = 'Timeout' in error_message or 'Circuit breaker' in error_message
605
-
673
+
674
+ is_system_error = "Timeout" in error_message or "Circuit breaker" in error_message
675
+
606
676
  if is_system_error:
607
- print(f"[JAF:GUARDRAILS] {guardrail_name} system error, allowing output: {error_message}")
677
+ print(
678
+ f"[JAF:GUARDRAILS] {guardrail_name} system error, allowing output: {error_message}"
679
+ )
608
680
  continue
609
681
  else:
610
682
  if config.on_event:
611
- config.on_event(GuardrailViolationEvent(
612
- data=GuardrailViolationEventData(stage='output', reason=error_message)
613
- ))
683
+ config.on_event(
684
+ GuardrailViolationEvent(
685
+ data=GuardrailViolationEventData(stage="output", reason=error_message)
686
+ )
687
+ )
614
688
  return InvalidValidationResult(error_message=error_message)
615
-
689
+
616
690
  print("✅ All output guardrails passed")
617
691
  return ValidValidationResult()
618
692
 
@@ -623,40 +697,40 @@ def cleanup_circuit_breakers() -> None:
623
697
  for key, breaker in _circuit_breakers.items():
624
698
  if breaker.should_be_cleaned_up(CIRCUIT_BREAKER_CLEANUP_MAX_AGE):
625
699
  to_remove.append(key)
626
-
700
+
627
701
  for key in to_remove:
628
702
  del _circuit_breakers[key]
629
703
 
630
704
 
631
705
  class GuardrailCacheManager:
632
706
  """Manager for guardrail cache operations."""
633
-
707
+
634
708
  @staticmethod
635
709
  def get_stats() -> Dict[str, Any]:
636
710
  """Get cache statistics."""
637
711
  return _guardrail_cache.get_stats()
638
-
712
+
639
713
  @staticmethod
640
714
  def clear() -> None:
641
715
  """Clear cache."""
642
716
  _guardrail_cache.clear()
643
-
717
+
644
718
  @staticmethod
645
719
  def get_metrics() -> Dict[str, Any]:
646
720
  """Get cache metrics."""
647
721
  stats = _guardrail_cache.get_stats()
648
722
  return {
649
723
  **stats,
650
- 'utilization_percent': (stats['size'] / stats['max_size']) * 100,
651
- 'circuit_breakers_count': len(_circuit_breakers)
724
+ "utilization_percent": (stats["size"] / stats["max_size"]) * 100,
725
+ "circuit_breakers_count": len(_circuit_breakers),
652
726
  }
653
-
727
+
654
728
  @staticmethod
655
729
  def log_stats() -> None:
656
730
  """Log cache statistics."""
657
731
  metrics = GuardrailCacheManager.get_metrics()
658
- print('[JAF:GUARDRAILS] Cache stats:', metrics)
659
-
732
+ print("[JAF:GUARDRAILS] Cache stats:", metrics)
733
+
660
734
  @staticmethod
661
735
  def cleanup() -> None:
662
736
  """Cleanup old entries."""
@@ -664,4 +738,4 @@ class GuardrailCacheManager:
664
738
 
665
739
 
666
740
  # Export the cache manager
667
- guardrail_cache_manager = GuardrailCacheManager()
741
+ guardrail_cache_manager = GuardrailCacheManager()