control_zero-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- control_zero/__init__.py +31 -0
- control_zero/client.py +584 -0
- control_zero/integrations/crewai/__init__.py +53 -0
- control_zero/integrations/crewai/agent.py +267 -0
- control_zero/integrations/crewai/crew.py +381 -0
- control_zero/integrations/crewai/task.py +291 -0
- control_zero/integrations/crewai/tool.py +299 -0
- control_zero/integrations/langchain/__init__.py +58 -0
- control_zero/integrations/langchain/agent.py +311 -0
- control_zero/integrations/langchain/callbacks.py +441 -0
- control_zero/integrations/langchain/chain.py +319 -0
- control_zero/integrations/langchain/graph.py +441 -0
- control_zero/integrations/langchain/tool.py +271 -0
- control_zero/llm/__init__.py +77 -0
- control_zero/llm/anthropic/__init__.py +35 -0
- control_zero/llm/anthropic/client.py +136 -0
- control_zero/llm/anthropic/messages.py +375 -0
- control_zero/llm/base.py +551 -0
- control_zero/llm/cohere/__init__.py +32 -0
- control_zero/llm/cohere/client.py +402 -0
- control_zero/llm/gemini/__init__.py +34 -0
- control_zero/llm/gemini/client.py +486 -0
- control_zero/llm/groq/__init__.py +32 -0
- control_zero/llm/groq/client.py +330 -0
- control_zero/llm/mistral/__init__.py +32 -0
- control_zero/llm/mistral/client.py +319 -0
- control_zero/llm/ollama/__init__.py +31 -0
- control_zero/llm/ollama/client.py +439 -0
- control_zero/llm/openai/__init__.py +34 -0
- control_zero/llm/openai/chat.py +331 -0
- control_zero/llm/openai/client.py +182 -0
- control_zero/logging/__init__.py +5 -0
- control_zero/logging/async_logger.py +65 -0
- control_zero/mcp/__init__.py +5 -0
- control_zero/mcp/middleware.py +148 -0
- control_zero/policy/__init__.py +5 -0
- control_zero/policy/enforcer.py +99 -0
- control_zero/secrets/__init__.py +5 -0
- control_zero/secrets/manager.py +77 -0
- control_zero/types.py +51 -0
- control_zero-0.2.0.dist-info/METADATA +216 -0
- control_zero-0.2.0.dist-info/RECORD +44 -0
- control_zero-0.2.0.dist-info/WHEEL +4 -0
- control_zero-0.2.0.dist-info/licenses/LICENSE +17 -0
control_zero/llm/cohere/client.py
@@ -0,0 +1,402 @@
+"""
+Governed Cohere client wrapper.
+
+Provides governance features for the Cohere Python SDK including:
+- Model access control
+- Cost tracking and limits
+- Tool use governance
+- PII detection and masking
+- Audit logging
+
+Supports chat, RAG (retrieval-augmented generation), and reranking.
+"""
+
+import time
+from typing import Any, Dict, Iterator, List, Optional, Union
+
+from control_zero.llm.base import (
+    GovernanceAction,
+    GovernedLLM,
+    GovernedChatMixin,
+    LLMGovernanceConfig,
+    LLMUsageMetrics,
+    estimate_cost,
+)
+from control_zero.policy import PolicyDeniedError
+
+
+class GovernedCohere(GovernedLLM, GovernedChatMixin):
+    """
+    Governed wrapper for the Cohere Python SDK.
+
+    Supports:
+    - Chat API (command models)
+    - RAG with connectors
+    - Reranking
+    - Embeddings
+
+    Example:
+        from control_zero import ControlZeroClient
+        from control_zero.llm.cohere import GovernedCohere
+        import cohere
+
+        cz = ControlZeroClient(api_key="...")
+        cz.initialize()
+
+        client = cohere.ClientV2()
+        governed = GovernedCohere(client=client, control_zero=cz)
+
+        response = governed.chat(
+            model="command-r-plus",
+            messages=[{"role": "user", "content": "Hello!"}]
+        )
+    """
+
+    def __init__(
+        self,
+        client: Any,
+        control_zero: Any,
+        config: Optional[LLMGovernanceConfig] = None,
+        user_context: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(client, control_zero, config, user_context)
+
+    @property
+    def provider_name(self) -> str:
+        return "cohere"
+
+    def chat(
+        self,
+        *,
+        model: str,
+        messages: List[Dict[str, Any]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        documents: Optional[List[Dict[str, Any]]] = None,
+        connectors: Optional[List[Dict[str, Any]]] = None,
+        stream: bool = False,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        p: Optional[float] = None,
+        k: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        seed: Optional[int] = None,
+        **kwargs,
+    ) -> Any:
+        """
+        Create a governed chat completion.
+
+        Args:
+            model: Model to use (e.g., "command-r-plus")
+            messages: List of message dicts
+            tools: Tool definitions
+            documents: Documents for RAG
+            connectors: Connector configurations
+            stream: Whether to stream
+            max_tokens: Maximum output tokens
+            temperature: Sampling temperature
+            p: Nucleus sampling
+            k: Top-k sampling
+            stop_sequences: Stop sequences
+            seed: Random seed
+            **kwargs: Additional parameters
+
+        Returns:
+            ChatResponse or stream iterator
+
+        Raises:
+            PolicyDeniedError: If request violates policy
+        """
+        start_time = time.time()
+
+        # Estimate tokens
+        estimated_input_tokens = self._estimate_message_tokens(messages)
+        if documents:
+            for doc in documents:
+                estimated_input_tokens += len(str(doc.get("text", ""))) // 4
+
+        # Prepare tools for policy check
+        tools_to_check = []
+        if tools:
+            tools_to_check = [{"name": t.get("name", ""), "type": "function"} for t in tools]
+
+        # Run governance checks
+        self._pre_request_checks(
+            model=model,
+            action=GovernanceAction.CHAT_COMPLETION,
+            messages=messages,
+            functions=tools_to_check,
+            estimated_tokens=estimated_input_tokens,
+        )
+
+        # Process messages for governance
+        processed_messages = self._process_messages_for_governance(messages)
+
+        # Filter tools
+        filtered_tools = self._filter_tools_for_governance(tools)
+
+        # Apply max_tokens limit
+        governed_max_tokens = max_tokens
+        if self._config.content_policy.max_output_tokens:
+            if max_tokens:
+                governed_max_tokens = min(max_tokens, self._config.content_policy.max_output_tokens)
+            else:
+                governed_max_tokens = self._config.content_policy.max_output_tokens
+
+        # Build request
+        request_kwargs = {
+            "model": model,
+            "messages": processed_messages,
+        }
+
+        if filtered_tools:
+            request_kwargs["tools"] = filtered_tools
+        if documents:
+            request_kwargs["documents"] = documents
+        if connectors:
+            request_kwargs["connectors"] = connectors
+        if governed_max_tokens:
+            request_kwargs["max_tokens"] = governed_max_tokens
+        if temperature is not None:
+            request_kwargs["temperature"] = temperature
+        if p is not None:
+            request_kwargs["p"] = p
+        if k is not None:
+            request_kwargs["k"] = k
+        if stop_sequences:
+            request_kwargs["stop_sequences"] = stop_sequences
+        if seed is not None:
+            request_kwargs["seed"] = seed
+
+        request_kwargs.update(kwargs)
+
+        # Handle streaming
+        if stream:
+            return self._create_stream(request_kwargs, start_time, model)
+
+        # Make API call
+        try:
+            response = self._client.chat(**request_kwargs)
+            latency_ms = int((time.time() - start_time) * 1000)
+
+            # Extract metrics
+            usage = getattr(response, "usage", None)
+            input_tokens = getattr(usage, "billed_units", {}).get("input_tokens", estimated_input_tokens) if usage else estimated_input_tokens
+            output_tokens = getattr(usage, "billed_units", {}).get("output_tokens", 0) if usage else 0
+            total_tokens = input_tokens + output_tokens
+
+            # Count tool calls
+            tool_call_count = 0
+            if hasattr(response, "tool_calls") and response.tool_calls:
+                tool_call_count = len(response.tool_calls)
+
+            metrics = LLMUsageMetrics(
+                provider="cohere",
+                model=model,
+                action=GovernanceAction.CHAT_COMPLETION,
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=total_tokens,
+                latency_ms=latency_ms,
+                estimated_cost=estimate_cost(model, input_tokens, output_tokens),
+                function_calls=tool_call_count,
+            )
+
+            self._post_request_update(metrics)
+            self._log_request(model, GovernanceAction.CHAT_COMPLETION, metrics)
+
+            return response
+
+        except PolicyDeniedError:
+            raise
+        except Exception as e:
+            latency_ms = int((time.time() - start_time) * 1000)
+            metrics = LLMUsageMetrics(
+                provider="cohere",
+                model=model,
+                action=GovernanceAction.CHAT_COMPLETION,
+                latency_ms=latency_ms,
+            )
+            self._log_request(
+                model, GovernanceAction.CHAT_COMPLETION, metrics,
+                status="error", error=str(e)
+            )
+            raise
+
+    def _create_stream(
+        self,
+        request_kwargs: Dict[str, Any],
+        start_time: float,
+        model: str,
+    ) -> Iterator[Any]:
+        """Create a governed streaming response."""
+        total_tokens = 0
+
+        try:
+            stream = self._client.chat_stream(**request_kwargs)
+
+            for event in stream:
+                total_tokens += 1
+                yield event
+
+            latency_ms = int((time.time() - start_time) * 1000)
+
+            metrics = LLMUsageMetrics(
+                provider="cohere",
+                model=model,
+                action=GovernanceAction.CHAT_COMPLETION,
+                total_tokens=total_tokens,
+                latency_ms=latency_ms,
+                estimated_cost=estimate_cost(model, total_tokens // 2, total_tokens // 2),
+            )
+
+            self._post_request_update(metrics)
+            self._log_request(model, GovernanceAction.CHAT_COMPLETION, metrics)
+
+        except Exception as e:
+            latency_ms = int((time.time() - start_time) * 1000)
+            metrics = LLMUsageMetrics(
+                provider="cohere",
+                model=model,
+                action=GovernanceAction.CHAT_COMPLETION,
+                latency_ms=latency_ms,
+            )
+            self._log_request(
+                model, GovernanceAction.CHAT_COMPLETION, metrics,
+                status="error", error=str(e)
+            )
+            raise
+
+    def rerank(
+        self,
+        *,
+        model: str,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        top_n: Optional[int] = None,
+        return_documents: bool = True,
+        **kwargs,
+    ) -> Any:
+        """
+        Rerank documents with governance.
+
+        Args:
+            model: Rerank model (e.g., "rerank-v3.5")
+            query: Query to rerank against
+            documents: Documents to rerank
+            top_n: Number of results to return
+            return_documents: Whether to return document content
+            **kwargs: Additional parameters
+
+        Returns:
+            RerankResponse
+        """
+        start_time = time.time()
+
+        # Governance checks
+        estimated_tokens = len(query) // 4 + sum(len(str(d)) // 4 for d in documents)
+
+        self._pre_request_checks(
+            model=model,
+            action=GovernanceAction.MODERATION,
+            estimated_tokens=estimated_tokens,
+        )
+
+        try:
+            response = self._client.rerank(
+                model=model,
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                return_documents=return_documents,
+                **kwargs,
+            )
+
+            latency_ms = int((time.time() - start_time) * 1000)
+
+            metrics = LLMUsageMetrics(
+                provider="cohere",
+                model=model,
+                action=GovernanceAction.MODERATION,
+                input_tokens=estimated_tokens,
+                latency_ms=latency_ms,
+                estimated_cost=0.0,  # Rerank pricing is different
+            )
+
+            self._post_request_update(metrics)
+            self._log_request(model, GovernanceAction.MODERATION, metrics)
+
+            return response
+
+        except Exception as e:
+            latency_ms = int((time.time() - start_time) * 1000)
+            metrics = LLMUsageMetrics(
+                provider="cohere",
+                model=model,
+                action=GovernanceAction.MODERATION,
+                latency_ms=latency_ms,
+            )
+            self._log_request(
+                model, GovernanceAction.MODERATION, metrics,
+                status="error", error=str(e)
+            )
+            raise
+
+    @property
+    def embed(self):
+        """Access the underlying embed endpoint (pass-through)."""
+        return self._client.embed
+
+    def _filter_tools_for_governance(
+        self,
+        tools: Optional[List[Dict[str, Any]]],
+    ) -> Optional[List[Dict[str, Any]]]:
+        """Filter tools according to governance policies."""
+        if not tools:
+            return tools
+
+        policy = self._config.function_policy
+
+        if not policy.allowed_functions and not policy.denied_functions:
+            return tools
+
+        filtered = []
+        for tool in tools:
+            tool_name = tool.get("name", "")
+
+            # Skip denied
+            if policy.denied_functions:
+                denied = any(d.lower() in tool_name.lower() for d in policy.denied_functions)
+                if denied:
+                    continue
+
+            # Check allowed
+            if policy.allowed_functions:
+                allowed = any(a.lower() in tool_name.lower() for a in policy.allowed_functions)
+                if not allowed:
+                    continue
+
+            filtered.append(tool)
+
+        return filtered if filtered else None
+
+    def _estimate_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
+        """Estimate token count for messages."""
+        total_chars = sum(len(str(m.get("content", ""))) for m in messages)
+        return max(1, total_chars // 4)
+
+    def with_user_context(self, user_context: Dict[str, Any]) -> "GovernedCohere":
+        merged_context = {**self._user_context, **user_context}
+        return GovernedCohere(
+            client=self._client,
+            control_zero=self._cz,
+            config=self._config,
+            user_context=merged_context,
+        )
+
+    def with_config(self, config: LLMGovernanceConfig) -> "GovernedCohere":
+        return GovernedCohere(
+            client=self._client,
+            control_zero=self._cz,
+            config=config,
+            user_context=self._user_context,
+        )
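For orientation, here is a minimal sketch of how the wrapper above would be consumed end to end. It is illustrative only: the API key, the user-context shape, and the prompt are placeholder assumptions, not values from the package. It exercises the two behaviors that differ from the raw Cohere SDK: chat(stream=True) returns a governed generator whose usage metrics are logged only once the caller exhausts it (see _create_stream above), and a request that violates policy raises PolicyDeniedError from _pre_request_checks before the provider is called.

    # Hedged usage sketch for GovernedCohere; the key, user_id field, and
    # prompt below are placeholders, not part of the published package.
    import cohere

    from control_zero import ControlZeroClient
    from control_zero.llm.cohere import GovernedCohere
    from control_zero.policy import PolicyDeniedError

    cz = ControlZeroClient(api_key="cz_live_...")  # placeholder key
    cz.initialize()

    governed = GovernedCohere(client=cohere.ClientV2(), control_zero=cz)

    # with_user_context returns a new wrapper with merged context rather
    # than mutating this one, so per-request attribution stays isolated.
    scoped = governed.with_user_context({"user_id": "u-123"})  # assumed shape

    try:
        # stream=True routes through _create_stream, which yields raw Cohere
        # events and logs usage metrics after the stream is exhausted.
        for event in scoped.chat(
            model="command-r-plus",
            messages=[{"role": "user", "content": "Summarize this thread."}],
            stream=True,
        ):
            print(event)
    except PolicyDeniedError as exc:
        # Raised before the Cohere API is ever called.
        print(f"Blocked by policy: {exc}")

Because the policy checks run inside chat() itself, a denied request never reaches cohere.ClientV2, so no provider tokens are consumed for blocked calls.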
control_zero/llm/gemini/__init__.py
@@ -0,0 +1,34 @@
+"""
+Control Zero Google Gemini Governance Wrapper.
+
+This module provides governance wrappers for the Google Generative AI SDK,
+enabling policy enforcement, cost tracking, and audit logging for
+all Gemini API calls.
+
+Usage:
+    from control_zero import ControlZeroClient
+    from control_zero.llm.gemini import GovernedGemini
+    import google.generativeai as genai
+
+    # Initialize Control Zero
+    cz_client = ControlZeroClient(api_key="cz_live_xxx")
+    cz_client.initialize()
+
+    # Configure Gemini
+    genai.configure(api_key="your-gemini-api-key")
+
+    # Create governed model
+    governed = GovernedGemini(
+        model_name="gemini-1.5-pro",
+        control_zero=cz_client
+    )
+
+    # All calls are now governed
+    response = governed.generate_content("Hello, how are you?")
+"""
+
+from control_zero.llm.gemini.client import GovernedGemini
+
+__all__ = [
+    "GovernedGemini",
+]