control-zero 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- control_zero/__init__.py +31 -0
- control_zero/client.py +584 -0
- control_zero/integrations/crewai/__init__.py +53 -0
- control_zero/integrations/crewai/agent.py +267 -0
- control_zero/integrations/crewai/crew.py +381 -0
- control_zero/integrations/crewai/task.py +291 -0
- control_zero/integrations/crewai/tool.py +299 -0
- control_zero/integrations/langchain/__init__.py +58 -0
- control_zero/integrations/langchain/agent.py +311 -0
- control_zero/integrations/langchain/callbacks.py +441 -0
- control_zero/integrations/langchain/chain.py +319 -0
- control_zero/integrations/langchain/graph.py +441 -0
- control_zero/integrations/langchain/tool.py +271 -0
- control_zero/llm/__init__.py +77 -0
- control_zero/llm/anthropic/__init__.py +35 -0
- control_zero/llm/anthropic/client.py +136 -0
- control_zero/llm/anthropic/messages.py +375 -0
- control_zero/llm/base.py +551 -0
- control_zero/llm/cohere/__init__.py +32 -0
- control_zero/llm/cohere/client.py +402 -0
- control_zero/llm/gemini/__init__.py +34 -0
- control_zero/llm/gemini/client.py +486 -0
- control_zero/llm/groq/__init__.py +32 -0
- control_zero/llm/groq/client.py +330 -0
- control_zero/llm/mistral/__init__.py +32 -0
- control_zero/llm/mistral/client.py +319 -0
- control_zero/llm/ollama/__init__.py +31 -0
- control_zero/llm/ollama/client.py +439 -0
- control_zero/llm/openai/__init__.py +34 -0
- control_zero/llm/openai/chat.py +331 -0
- control_zero/llm/openai/client.py +182 -0
- control_zero/logging/__init__.py +5 -0
- control_zero/logging/async_logger.py +65 -0
- control_zero/mcp/__init__.py +5 -0
- control_zero/mcp/middleware.py +148 -0
- control_zero/policy/__init__.py +5 -0
- control_zero/policy/enforcer.py +99 -0
- control_zero/secrets/__init__.py +5 -0
- control_zero/secrets/manager.py +77 -0
- control_zero/types.py +51 -0
- control_zero-0.2.0.dist-info/METADATA +216 -0
- control_zero-0.2.0.dist-info/RECORD +44 -0
- control_zero-0.2.0.dist-info/WHEEL +4 -0
- control_zero-0.2.0.dist-info/licenses/LICENSE +17 -0
control_zero/llm/base.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base classes for LLM governance wrappers.
|
|
3
|
+
|
|
4
|
+
This module provides the foundational classes that all LLM provider
|
|
5
|
+
wrappers inherit from, ensuring consistent governance behavior across
|
|
6
|
+
OpenAI, Anthropic, Gemini, and other providers.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
|
|
13
|
+
from enum import Enum
|
|
14
|
+
|
|
15
|
+
from control_zero.policy import PolicyDecision, PolicyDeniedError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GovernanceAction(str, Enum):
    """Actions that can be governed for LLM calls.

    Members mix in ``str`` so a member can be used anywhere a plain string
    is accepted; ``.value`` is what the audit logger records as the request
    "method" (see ``GovernedLLM._log_request``).
    """
    CHAT_COMPLETION = "chat_completion"
    TEXT_COMPLETION = "text_completion"
    EMBEDDING = "embedding"
    FUNCTION_CALL = "function_call"
    TOOL_USE = "tool_use"
    IMAGE_GENERATION = "image_generation"
    AUDIO_TRANSCRIPTION = "audio_transcription"
    MODERATION = "moderation"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ModelPolicy:
    """Policy for model access control.

    Model matching in ``GovernedLLM._check_model_policy`` is case-insensitive
    substring matching, and the deny list takes precedence over the allow list.
    """
    # Substring patterns of permitted models; empty list = all models allowed
    # (unless denied below).
    allowed_models: List[str] = field(default_factory=list)
    # Substring patterns of models that are always rejected; checked first.
    denied_models: List[str] = field(default_factory=list)
    # Presumably the model to use when the caller specifies none; not
    # consulted anywhere in this module — TODO confirm against provider wrappers.
    default_model: Optional[str] = None
    # NOTE(review): the two limits below are not enforced in this module;
    # presumably provider-specific wrappers apply them — confirm.
    max_tokens_per_request: Optional[int] = None
    max_context_length: Optional[int] = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class CostPolicy:
    """Policy for cost control.

    ``None`` disables a limit. Of these, only ``max_cost_per_request``,
    ``max_cost_per_day`` and ``max_requests_per_day`` are enforced by
    ``GovernedLLM._check_cost_policy`` in this module; the remaining fields
    are not consulted here (presumably enforced elsewhere — confirm).
    """
    # Hard ceiling on the estimated USD cost of a single call.
    max_cost_per_request: Optional[float] = None
    # Rolling-24h USD budget (window starts at the wrapper's last reset,
    # not at calendar midnight — see GovernedLLM._check_daily_reset).
    max_cost_per_day: Optional[float] = None
    max_cost_per_month: Optional[float] = None
    max_tokens_per_day: Optional[int] = None
    max_requests_per_minute: Optional[int] = None
    # Rolling-24h request count ceiling.
    max_requests_per_day: Optional[int] = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class FunctionPolicy:
    """Policy for function/tool calling.

    Name matching is case-insensitive substring matching; the deny list is
    checked before the allow list (see ``GovernedLLM._check_function_policy``
    and ``GovernedChatMixin._filter_functions_for_governance``).
    """
    # Substring patterns of callable functions; empty = all allowed.
    allowed_functions: List[str] = field(default_factory=list)
    # Substring patterns of functions that are always rejected.
    denied_functions: List[str] = field(default_factory=list)
    # NOTE(review): the two fields below are not consulted in this module —
    # presumably honored by higher-level integrations; confirm.
    require_approval: List[str] = field(default_factory=list)
    max_function_calls_per_request: Optional[int] = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class ContentPolicy:
    """Policy for content filtering."""
    # When True, request messages are scanned with the regex-based detector
    # in GovernedLLM._detect_pii.
    enable_pii_detection: bool = False
    # What to do on a PII hit: "mask" (replace in place, see
    # GovernedChatMixin), "block" (deny the request, see
    # GovernedLLM._pre_request_checks), or "warn". "warn" has no visible
    # handling in this module — TODO confirm intended behavior.
    pii_action: str = "mask"  # mask, block, warn
    # NOTE(review): blocked_topics and max_output_tokens are not enforced
    # anywhere in this module — confirm where they are applied.
    blocked_topics: List[str] = field(default_factory=list)
    # System prompt injected by GovernedChatMixin when the caller supplies
    # no system message of their own.
    required_system_prompt: Optional[str] = None
    max_output_tokens: Optional[int] = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
class LLMGovernanceConfig:
    """Complete governance configuration for LLM calls.

    Aggregates the four sub-policies plus global switches; each sub-policy
    defaults to a permissive, empty instance.
    """
    model_policy: ModelPolicy = field(default_factory=ModelPolicy)
    cost_policy: CostPolicy = field(default_factory=CostPolicy)
    function_policy: FunctionPolicy = field(default_factory=FunctionPolicy)
    content_policy: ContentPolicy = field(default_factory=ContentPolicy)
    # Gates all calls to Control Zero's logger (_log_request / _log_denied).
    enable_audit_logging: bool = True
    # NOTE(review): not read anywhere in this module — presumably consumed
    # by streaming-capable provider wrappers; confirm.
    enable_streaming_governance: bool = True
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class LLMUsageMetrics:
    """Metrics captured for one LLM call (tokens, latency, estimated cost)."""
    provider: str
    model: str
    action: GovernanceAction
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    latency_ms: int = 0
    estimated_cost: float = 0.0
    function_calls: int = 0
    cached: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; the action enum is flattened to its string value."""
        payload: Dict[str, Any] = {
            "provider": self.provider,
            "model": self.model,
            "action": self.action.value,
        }
        # Remaining fields are copied verbatim under their own names.
        for attr in (
            "input_tokens",
            "output_tokens",
            "total_tokens",
            "latency_ms",
            "estimated_cost",
            "function_calls",
            "cached",
        ):
            payload[attr] = getattr(self, attr)
        return payload
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# Model pricing (approximate, per 1M tokens, USD).
# NOTE: provider prices drift over time; treat these as rough estimates.
MODEL_PRICING = {
    # OpenAI
    "gpt-4-turbo": {"input": 10.0, "output": 30.0},
    "gpt-4": {"input": 30.0, "output": 60.0},
    "gpt-4o": {"input": 2.50, "output": 10.0},
    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
    "gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
    "o1-preview": {"input": 15.0, "output": 60.0},
    "o1-mini": {"input": 3.0, "output": 12.0},
    # Anthropic
    "claude-3-opus": {"input": 15.0, "output": 75.0},
    "claude-3-sonnet": {"input": 3.0, "output": 15.0},
    "claude-3-haiku": {"input": 0.25, "output": 1.25},
    "claude-3.5-sonnet": {"input": 3.0, "output": 15.0},
    "claude-opus-4": {"input": 15.0, "output": 75.0},
    # Google
    "gemini-1.5-pro": {"input": 3.50, "output": 10.50},
    "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
    "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
    # Groq
    "llama-3.1-70b": {"input": 0.59, "output": 0.79},
    "llama-3.1-8b": {"input": 0.05, "output": 0.08},
    "mixtral-8x7b": {"input": 0.24, "output": 0.24},
    # Mistral
    "mistral-large": {"input": 2.0, "output": 6.0},
    "mistral-small": {"input": 0.2, "output": 0.6},
    "codestral": {"input": 0.2, "output": 0.6},
}


def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost of a model call.

    The model name is matched against ``MODEL_PRICING`` keys by
    case-insensitive substring, preferring the longest (most specific)
    key so that e.g. ``"gpt-4o-mini"`` is priced as gpt-4o-mini rather
    than hitting the shorter ``"gpt-4"`` entry first.

    Args:
        model: Provider model name (any casing; version suffixes such as
            ``-20240307`` still match their base entry).
        input_tokens: Prompt token count.
        output_tokens: Completion token count.

    Returns:
        Estimated cost in USD. Unknown models fall back to a flat
        $1.00 per 1M total tokens.
    """
    # Normalize model name for case-insensitive matching.
    model_lower = model.lower()

    # Longest key first: the previous insertion-order scan matched "gpt-4"
    # before "gpt-4o"/"gpt-4o-mini", mis-pricing the newer models by ~200x.
    for model_key in sorted(MODEL_PRICING, key=len, reverse=True):
        if model_key in model_lower:
            pricing = MODEL_PRICING[model_key]
            input_cost = (input_tokens / 1_000_000) * pricing["input"]
            output_cost = (output_tokens / 1_000_000) * pricing["output"]
            return input_cost + output_cost

    # Default estimate if model not found: $1.00 per 1M combined tokens.
    return (input_tokens + output_tokens) / 1_000_000 * 1.0
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Type of the wrapped provider client; GovernedLLM is Generic[T] so the
# `.client` property preserves the concrete client type for callers.
T = TypeVar("T")  # Original client type
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class GovernedLLM(ABC, Generic[T]):
    """
    Base class for governed LLM clients.

    This class wraps any LLM provider's client and adds:
    - Policy enforcement (model access, function calling, content)
    - Cost tracking and limits
    - Audit logging
    - PII detection and masking

    Subclasses implement provider-specific wrapping logic and are expected
    to call ``_pre_request_checks`` before, and ``_post_request_update`` /
    ``_log_request`` after, each provider call.
    """

    def __init__(
        self,
        client: T,
        control_zero: Any,  # ControlZeroClient
        config: Optional[LLMGovernanceConfig] = None,
        user_context: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize a governed LLM client.

        Args:
            client: The original LLM provider client (OpenAI, Anthropic, etc.)
            control_zero: Control Zero client for policy and logging
            config: Governance configuration (optional, uses defaults if not provided)
            user_context: Context about the current user (user_id, role, etc.)
        """
        self._client = client
        self._cz = control_zero
        self._config = config or LLMGovernanceConfig()
        self._user_context = user_context or {}

        # Usage tracking. In-memory only: counters live for the lifetime of
        # this wrapper instance and are not shared across processes.
        # NOTE(review): updates are not synchronized — presumably one wrapper
        # per thread; confirm before concurrent use.
        self._session_metrics: List[LLMUsageMetrics] = []
        self._daily_cost: float = 0.0
        self._daily_tokens: int = 0
        self._daily_requests: int = 0
        self._last_reset: float = time.time()

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name (e.g., 'openai', 'anthropic')."""
        pass

    @property
    def client(self) -> T:
        """Access the underlying (ungoverned) client."""
        return self._client

    def _check_daily_reset(self) -> None:
        """Reset daily counters if a new day has started.

        "Day" here is a rolling 24-hour window measured from the last reset,
        not a calendar day.
        """
        current_time = time.time()
        # Reset if more than 24 hours since last reset
        if current_time - self._last_reset > 86400:
            self._daily_cost = 0.0
            self._daily_tokens = 0
            self._daily_requests = 0
            self._last_reset = current_time

    def _check_model_policy(self, model: str) -> PolicyDecision:
        """Check if the model is allowed.

        Matching is case-insensitive substring matching; the deny list
        takes precedence over the allow list. An empty allow list permits
        every model that is not denied.
        """
        policy = self._config.model_policy

        # Check denied list first
        if policy.denied_models:
            for denied in policy.denied_models:
                if denied.lower() in model.lower():
                    return PolicyDecision(
                        effect="deny",
                        reason=f"Model '{model}' is not allowed by policy",
                        policy_id="model_deny_list"
                    )

        # Check allowed list if specified
        if policy.allowed_models:
            allowed = False
            for allowed_model in policy.allowed_models:
                if allowed_model.lower() in model.lower():
                    allowed = True
                    break
            if not allowed:
                return PolicyDecision(
                    effect="deny",
                    reason=f"Model '{model}' is not in the allowed models list",
                    policy_id="model_allow_list"
                )

        return PolicyDecision(effect="allow")

    def _check_cost_policy(self, estimated_cost: float) -> PolicyDecision:
        """Check if the request is within cost limits.

        Only per-request cost, rolling-daily cost and rolling-daily request
        count are enforced here; the other CostPolicy fields are not
        consulted in this method. A limit set to 0 is treated the same as
        None (disabled) because of the truthiness checks below.
        """
        self._check_daily_reset()
        policy = self._config.cost_policy

        # Check per-request limit
        if policy.max_cost_per_request:
            if estimated_cost > policy.max_cost_per_request:
                return PolicyDecision(
                    effect="deny",
                    reason=f"Estimated cost ${estimated_cost:.4f} exceeds per-request limit ${policy.max_cost_per_request:.4f}",
                    policy_id="cost_per_request"
                )

        # Check daily limit (pre-charge: the estimate counts toward the cap
        # before the request actually runs)
        if policy.max_cost_per_day:
            if self._daily_cost + estimated_cost > policy.max_cost_per_day:
                return PolicyDecision(
                    effect="deny",
                    reason=f"Daily cost limit ${policy.max_cost_per_day:.2f} would be exceeded",
                    policy_id="cost_per_day"
                )

        # Check daily request limit
        if policy.max_requests_per_day:
            if self._daily_requests >= policy.max_requests_per_day:
                return PolicyDecision(
                    effect="deny",
                    reason=f"Daily request limit of {policy.max_requests_per_day} reached",
                    policy_id="requests_per_day"
                )

        return PolicyDecision(effect="allow")

    def _check_function_policy(self, functions: List[Dict[str, Any]]) -> PolicyDecision:
        """Check if function calls are allowed.

        Returns a deny decision for the first offending function; accepts
        both bare {"name": ...} and OpenAI-style nested
        {"function": {"name": ...}} specs. Matching is case-insensitive
        substring matching, deny list first.
        """
        policy = self._config.function_policy

        for func in functions:
            func_name = func.get("name", func.get("function", {}).get("name", ""))

            # Check denied list
            if policy.denied_functions:
                for denied in policy.denied_functions:
                    if denied.lower() in func_name.lower():
                        return PolicyDecision(
                            effect="deny",
                            reason=f"Function '{func_name}' is not allowed",
                            policy_id="function_deny_list"
                        )

            # Check allowed list if specified
            if policy.allowed_functions:
                allowed = False
                for allowed_func in policy.allowed_functions:
                    if allowed_func.lower() in func_name.lower():
                        allowed = True
                        break
                if not allowed:
                    return PolicyDecision(
                        effect="deny",
                        reason=f"Function '{func_name}' is not in the allowed list",
                        policy_id="function_allow_list"
                    )

        return PolicyDecision(effect="allow")

    def _detect_pii(self, text: str) -> List[Dict[str, Any]]:
        """
        Simple PII detection.
        In production, use a more sophisticated solution.

        Returns a list of {"type": <pattern name>, "value": <matched text>}
        findings. The regexes are deliberately loose (e.g. the ip_address
        pattern accepts octets > 255) and overlapping number formats may
        produce findings under more than one type.
        """
        import re

        patterns = {
            "email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
            "phone": r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
            "ssn": r'\b\d{3}[-]?\d{2}[-]?\d{4}\b',
            "credit_card": r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
            "ip_address": r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b',
        }

        findings = []
        for pii_type, pattern in patterns.items():
            matches = re.findall(pattern, text)
            for match in matches:
                findings.append({"type": pii_type, "value": match})

        return findings

    def _mask_pii(self, text: str) -> str:
        """Mask detected PII in text.

        Every occurrence of each detected value is replaced with a
        "[TYPE_MASKED]" placeholder (str.replace is global).
        """
        findings = self._detect_pii(text)
        masked = text
        for finding in findings:
            pii_type = finding["type"]
            value = finding["value"]
            mask = f"[{pii_type.upper()}_MASKED]"
            masked = masked.replace(value, mask)
        return masked

    def _pre_request_checks(
        self,
        model: str,
        action: GovernanceAction,
        messages: Optional[List[Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        estimated_tokens: int = 0,
    ) -> None:
        """
        Run all pre-request governance checks.
        Raises PolicyDeniedError if any check fails.

        Order: model policy, cost policy, function policy, then PII
        blocking. Each denial is audit-logged before the exception is
        raised. PII masking (as opposed to blocking) is not done here —
        see GovernedChatMixin.
        """
        # Check model policy
        decision = self._check_model_policy(model)
        if decision.effect == "deny":
            self._log_denied(model, action, decision)
            raise PolicyDeniedError(decision)

        # Estimate cost and check cost policy. Output is assumed to be
        # roughly half the estimated input tokens.
        estimated_cost = estimate_cost(model, estimated_tokens, estimated_tokens // 2)
        decision = self._check_cost_policy(estimated_cost)
        if decision.effect == "deny":
            self._log_denied(model, action, decision)
            raise PolicyDeniedError(decision)

        # Check function policy if functions provided
        if functions:
            decision = self._check_function_policy(functions)
            if decision.effect == "deny":
                self._log_denied(model, action, decision)
                raise PolicyDeniedError(decision)

        # Check content policy - PII detection. Only string contents are
        # scanned; structured (list-form) message content is skipped.
        if self._config.content_policy.enable_pii_detection and messages:
            for msg in messages:
                content = msg.get("content", "")
                if isinstance(content, str):
                    findings = self._detect_pii(content)
                    if findings:
                        pii_action = self._config.content_policy.pii_action
                        if pii_action == "block":
                            decision = PolicyDecision(
                                effect="deny",
                                reason=f"PII detected in request: {[f['type'] for f in findings]}",
                                policy_id="pii_detection"
                            )
                            self._log_denied(model, action, decision)
                            raise PolicyDeniedError(decision)

    def _post_request_update(self, metrics: LLMUsageMetrics) -> None:
        """Update tracking after a successful request.

        NOTE: _daily_tokens is accumulated here but no token-based daily
        limit is enforced in _check_cost_policy.
        """
        self._session_metrics.append(metrics)
        self._daily_cost += metrics.estimated_cost
        self._daily_tokens += metrics.total_tokens
        self._daily_requests += 1

    def _log_request(
        self,
        model: str,
        action: GovernanceAction,
        metrics: LLMUsageMetrics,
        status: str = "success",
        error: Optional[str] = None,
    ) -> None:
        """Log request to Control Zero.

        Best-effort: logging failures are swallowed so they never break the
        caller's request. No-op when audit logging is disabled.
        NOTE(review): the `model` parameter is currently unused here (the
        log line carries only provider/action) — confirm whether it should
        be forwarded.
        """
        if not self._config.enable_audit_logging:
            return

        try:
            # Use Control Zero's logging
            self._cz._log(
                tool=f"llm:{self.provider_name}",
                method=action.value,
                status=status,
                latency_ms=metrics.latency_ms,
                error_type="LLMError" if error else None,
                error_message=error,
            )
        except Exception:
            pass  # Don't fail on logging errors

    def _log_denied(
        self,
        model: str,
        action: GovernanceAction,
        decision: PolicyDecision,
    ) -> None:
        """Log a denied request (best-effort, like _log_request)."""
        if not self._config.enable_audit_logging:
            return

        try:
            self._cz._log(
                tool=f"llm:{self.provider_name}",
                method=action.value,
                status="denied",
                latency_ms=0,
                policy_decision=decision,
            )
        except Exception:
            pass

    def get_session_metrics(self) -> Dict[str, Any]:
        """Get aggregated metrics for the current session.

        Session totals are derived from _session_metrics; the daily_*
        values are the rolling-24h counters (which survive
        reset_session_metrics).
        """
        total_cost = sum(m.estimated_cost for m in self._session_metrics)
        total_tokens = sum(m.total_tokens for m in self._session_metrics)
        total_requests = len(self._session_metrics)

        return {
            "provider": self.provider_name,
            "total_cost": total_cost,
            "total_tokens": total_tokens,
            "total_requests": total_requests,
            "daily_cost": self._daily_cost,
            "daily_tokens": self._daily_tokens,
            "daily_requests": self._daily_requests,
        }

    def reset_session_metrics(self) -> None:
        """Reset session metrics (daily counters are left untouched)."""
        self._session_metrics = []
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
class GovernedChatMixin:
    """Mixin providing governed chat completion functionality.

    Expects the host class to carry a ``_config`` attribute
    (``LLMGovernanceConfig``) and, for masking, a ``_mask_pii`` method;
    both are looked up defensively so the mixin degrades gracefully when
    the host lacks them.
    """

    def _process_messages_for_governance(
        self,
        messages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Return a policy-adjusted copy of *messages*.

        Prepends the required system prompt when the caller supplied no
        system message, and masks PII in string contents when masking is
        enabled. Input messages are never mutated (shallow copies).
        """
        content_policy = getattr(self, '_config', LLMGovernanceConfig()).content_policy
        result: List[Dict[str, Any]] = []

        # Inject the mandated system prompt only when none is present.
        if content_policy.required_system_prompt:
            if not any(m.get("role") == "system" for m in messages):
                result.append({
                    "role": "system",
                    "content": content_policy.required_system_prompt
                })

        should_mask = (
            content_policy.enable_pii_detection
            and content_policy.pii_action == "mask"
        )
        for original in messages:
            entry = original.copy()
            body = original.get("content", "")
            # Only plain-string contents are maskable; structured content
            # passes through untouched.
            if should_mask and isinstance(body, str):
                masker = getattr(self, '_mask_pii', lambda x: x)
                entry["content"] = masker(body)
            result.append(entry)

        return result

    def _filter_functions_for_governance(
        self,
        functions: Optional[List[Dict[str, Any]]],
    ) -> Optional[List[Dict[str, Any]]]:
        """Drop function/tool specs disallowed by the function policy.

        Returns the input unchanged when it is empty or when neither list
        is configured; returns None when every function was filtered out.
        Matching is case-insensitive substring matching, deny list first.
        """
        if not functions:
            return functions

        policy = getattr(self, '_config', LLMGovernanceConfig()).function_policy
        if not policy.allowed_functions and not policy.denied_functions:
            return functions

        kept: List[Dict[str, Any]] = []
        for spec in functions:
            # Supports both bare {"name": ...} and nested
            # {"function": {"name": ...}} specs.
            spec_name = spec.get("name", spec.get("function", {}).get("name", "")).lower()

            if policy.denied_functions and any(
                blocked.lower() in spec_name for blocked in policy.denied_functions
            ):
                continue
            if policy.allowed_functions and not any(
                permitted.lower() in spec_name for permitted in policy.allowed_functions
            ):
                continue
            kept.append(spec)

        return kept if kept else None
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Control Zero Cohere Governance Wrapper.
|
|
3
|
+
|
|
4
|
+
This module provides governance wrappers for the Cohere Python SDK,
|
|
5
|
+
enabling policy enforcement, cost tracking, and audit logging for
|
|
6
|
+
Cohere API calls including chat, RAG, and reranking.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from control_zero import ControlZeroClient
|
|
10
|
+
from control_zero.llm.cohere import GovernedCohere
|
|
11
|
+
import cohere
|
|
12
|
+
|
|
13
|
+
# Initialize Control Zero
|
|
14
|
+
cz_client = ControlZeroClient(api_key="cz_live_xxx")
|
|
15
|
+
cz_client.initialize()
|
|
16
|
+
|
|
17
|
+
# Wrap Cohere client with governance
|
|
18
|
+
co_client = cohere.ClientV2()
|
|
19
|
+
governed = GovernedCohere(client=co_client, control_zero=cz_client)
|
|
20
|
+
|
|
21
|
+
# All calls are now governed
|
|
22
|
+
response = governed.chat(
|
|
23
|
+
model="command-r-plus",
|
|
24
|
+
messages=[{"role": "user", "content": "Hello"}]
|
|
25
|
+
)
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from control_zero.llm.cohere.client import GovernedCohere
|
|
29
|
+
|
|
30
|
+
__all__ = [
|
|
31
|
+
"GovernedCohere",
|
|
32
|
+
]
|