control_zero-0.2.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported public registry. It is provided for informational purposes only and reflects the package as it appears in that registry.
- control_zero/__init__.py +31 -0
- control_zero/client.py +584 -0
- control_zero/integrations/crewai/__init__.py +53 -0
- control_zero/integrations/crewai/agent.py +267 -0
- control_zero/integrations/crewai/crew.py +381 -0
- control_zero/integrations/crewai/task.py +291 -0
- control_zero/integrations/crewai/tool.py +299 -0
- control_zero/integrations/langchain/__init__.py +58 -0
- control_zero/integrations/langchain/agent.py +311 -0
- control_zero/integrations/langchain/callbacks.py +441 -0
- control_zero/integrations/langchain/chain.py +319 -0
- control_zero/integrations/langchain/graph.py +441 -0
- control_zero/integrations/langchain/tool.py +271 -0
- control_zero/llm/__init__.py +77 -0
- control_zero/llm/anthropic/__init__.py +35 -0
- control_zero/llm/anthropic/client.py +136 -0
- control_zero/llm/anthropic/messages.py +375 -0
- control_zero/llm/base.py +551 -0
- control_zero/llm/cohere/__init__.py +32 -0
- control_zero/llm/cohere/client.py +402 -0
- control_zero/llm/gemini/__init__.py +34 -0
- control_zero/llm/gemini/client.py +486 -0
- control_zero/llm/groq/__init__.py +32 -0
- control_zero/llm/groq/client.py +330 -0
- control_zero/llm/mistral/__init__.py +32 -0
- control_zero/llm/mistral/client.py +319 -0
- control_zero/llm/ollama/__init__.py +31 -0
- control_zero/llm/ollama/client.py +439 -0
- control_zero/llm/openai/__init__.py +34 -0
- control_zero/llm/openai/chat.py +331 -0
- control_zero/llm/openai/client.py +182 -0
- control_zero/logging/__init__.py +5 -0
- control_zero/logging/async_logger.py +65 -0
- control_zero/mcp/__init__.py +5 -0
- control_zero/mcp/middleware.py +148 -0
- control_zero/policy/__init__.py +5 -0
- control_zero/policy/enforcer.py +99 -0
- control_zero/secrets/__init__.py +5 -0
- control_zero/secrets/manager.py +77 -0
- control_zero/types.py +51 -0
- control_zero-0.2.0.dist-info/METADATA +216 -0
- control_zero-0.2.0.dist-info/RECORD +44 -0
- control_zero-0.2.0.dist-info/WHEEL +4 -0
- control_zero-0.2.0.dist-info/licenses/LICENSE +17 -0
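
Of the files listed above, the hunks below cover the Groq and Mistral wrappers under control_zero/llm/. As a quick orientation, here is a minimal sketch of the shared wrapping pattern; it assumes the distribution installs as control-zero, uses placeholder key values, and simply mirrors the usage examples embedded in the module docstrings that follow.

# Minimal sketch of the shared wrapping pattern (assumes `pip install control-zero`;
# API key values are placeholders). Mirrors the docstring examples in the hunks below.
from control_zero import ControlZeroClient
from control_zero.llm.groq import GovernedGroq
from groq import Groq

cz = ControlZeroClient(api_key="cz_live_xxx")
cz.initialize()

governed = GovernedGroq(client=Groq(), control_zero=cz)
response = governed.chat.completions.create(
    model="llama-3.1-70b-versatile",
    messages=[{"role": "user", "content": "Hello!"}],
)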
control_zero/llm/groq/client.py
@@ -0,0 +1,330 @@
"""
Governed Groq client wrapper.

Provides governance features for the Groq Python SDK including:
- Model access control
- Cost tracking and limits
- Function calling governance
- PII detection and masking
- Audit logging

Groq provides fast inference for open-source models like Llama, Mixtral, etc.
"""

import time
from typing import Any, Dict, Iterator, List, Optional, Union

from control_zero.llm.base import (
    GovernanceAction,
    GovernedLLM,
    GovernedChatMixin,
    LLMGovernanceConfig,
    LLMUsageMetrics,
    estimate_cost,
)
from control_zero.policy import PolicyDeniedError


class GovernedChatCompletions(GovernedChatMixin):
    """Governed wrapper for Groq chat completions."""

    def __init__(self, governed_client: "GovernedGroq"):
        self._governed = governed_client
        self._client = governed_client._client

    def create(
        self,
        *,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
        response_format: Optional[Dict[str, str]] = None,
        seed: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """
        Create a governed chat completion.

        Args:
            model: Model to use (e.g., "llama-3.1-70b-versatile")
            messages: List of message dicts
            tools: Tool definitions
            tool_choice: Tool selection mode
            stream: Whether to stream
            max_tokens: Maximum output tokens
            temperature: Sampling temperature
            top_p: Nucleus sampling
            n: Number of completions
            stop: Stop sequences
            presence_penalty: Presence penalty
            frequency_penalty: Frequency penalty
            user: End-user ID
            response_format: Response format
            seed: Random seed
            **kwargs: Additional parameters

        Returns:
            ChatCompletion response or stream iterator

        Raises:
            PolicyDeniedError: If request violates policy
        """
        start_time = time.time()

        # Estimate tokens
        estimated_input_tokens = self._estimate_message_tokens(messages)

        # Prepare tools for policy check
        tools_to_check = []
        if tools:
            tools_to_check = [{"name": t.get("function", {}).get("name", ""), "type": "function"} for t in tools]

        # Run governance checks
        self._governed._pre_request_checks(
            model=model,
            action=GovernanceAction.CHAT_COMPLETION,
            messages=messages,
            functions=tools_to_check,
            estimated_tokens=estimated_input_tokens,
        )

        # Process messages for governance
        processed_messages = self._process_messages_for_governance(messages)

        # Filter tools
        filtered_tools = self._filter_functions_for_governance(tools)

        # Apply max_tokens limit
        governed_max_tokens = max_tokens
        if self._governed._config.content_policy.max_output_tokens:
            if max_tokens:
                governed_max_tokens = min(max_tokens, self._governed._config.content_policy.max_output_tokens)
            else:
                governed_max_tokens = self._governed._config.content_policy.max_output_tokens

        # Build request
        request_kwargs = {
            "model": model,
            "messages": processed_messages,
        }

        if filtered_tools:
            request_kwargs["tools"] = filtered_tools
        if tool_choice and filtered_tools:
            request_kwargs["tool_choice"] = tool_choice
        if governed_max_tokens:
            request_kwargs["max_tokens"] = governed_max_tokens
        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if top_p is not None:
            request_kwargs["top_p"] = top_p
        if n is not None:
            request_kwargs["n"] = n
        if stop is not None:
            request_kwargs["stop"] = stop
        if presence_penalty is not None:
            request_kwargs["presence_penalty"] = presence_penalty
        if frequency_penalty is not None:
            request_kwargs["frequency_penalty"] = frequency_penalty
        if user is not None:
            request_kwargs["user"] = user
        elif self._governed._user_context.get("user_id"):
            request_kwargs["user"] = str(self._governed._user_context["user_id"])
        if response_format is not None:
            request_kwargs["response_format"] = response_format
        if seed is not None:
            request_kwargs["seed"] = seed

        request_kwargs.update(kwargs)

        # Handle streaming
        if stream:
            request_kwargs["stream"] = True
            return self._create_stream(request_kwargs, start_time, model)

        # Make API call
        try:
            response = self._client.chat.completions.create(**request_kwargs)
            latency_ms = int((time.time() - start_time) * 1000)

            # Extract metrics
            usage = getattr(response, "usage", None)
            input_tokens = usage.prompt_tokens if usage else estimated_input_tokens
            output_tokens = usage.completion_tokens if usage else 0
            total_tokens = usage.total_tokens if usage else input_tokens + output_tokens

            # Count tool calls
            tool_call_count = 0
            for choice in response.choices:
                if hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
                    tool_call_count += len(choice.message.tool_calls)

            metrics = LLMUsageMetrics(
                provider="groq",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
                latency_ms=latency_ms,
                estimated_cost=estimate_cost(model, input_tokens, output_tokens),
                function_calls=tool_call_count,
            )

            self._governed._post_request_update(metrics)
            self._governed._log_request(model, GovernanceAction.CHAT_COMPLETION, metrics)

            return response

        except PolicyDeniedError:
            raise
        except Exception as e:
            latency_ms = int((time.time() - start_time) * 1000)
            metrics = LLMUsageMetrics(
                provider="groq",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                latency_ms=latency_ms,
            )
            self._governed._log_request(
                model, GovernanceAction.CHAT_COMPLETION, metrics,
                status="error", error=str(e)
            )
            raise

    def _create_stream(
        self,
        request_kwargs: Dict[str, Any],
        start_time: float,
        model: str,
    ) -> Iterator[Any]:
        """Create a governed streaming response."""
        total_tokens = 0
        tool_call_count = 0

        try:
            stream = self._client.chat.completions.create(**request_kwargs)

            for chunk in stream:
                total_tokens += 1

                if chunk.choices:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, "tool_calls") and delta.tool_calls:
                        tool_call_count += len(delta.tool_calls)

                yield chunk

            latency_ms = int((time.time() - start_time) * 1000)

            metrics = LLMUsageMetrics(
                provider="groq",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                total_tokens=total_tokens,
                latency_ms=latency_ms,
                estimated_cost=estimate_cost(model, total_tokens // 2, total_tokens // 2),
                function_calls=tool_call_count,
            )

            self._governed._post_request_update(metrics)
            self._governed._log_request(model, GovernanceAction.CHAT_COMPLETION, metrics)

        except Exception as e:
            latency_ms = int((time.time() - start_time) * 1000)
            metrics = LLMUsageMetrics(
                provider="groq",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                latency_ms=latency_ms,
            )
            self._governed._log_request(
                model, GovernanceAction.CHAT_COMPLETION, metrics,
                status="error", error=str(e)
            )
            raise

    def _estimate_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Estimate token count for messages."""
        total_chars = sum(len(str(m.get("content", ""))) for m in messages)
        return max(1, total_chars // 4)


class GovernedChat:
    """Governed wrapper for Groq chat namespace."""

    def __init__(self, governed_client: "GovernedGroq"):
        self._completions = GovernedChatCompletions(governed_client)

    @property
    def completions(self) -> GovernedChatCompletions:
        return self._completions


class GovernedGroq(GovernedLLM):
    """
    Governed wrapper for the Groq Python SDK.

    Groq provides fast inference for open-source models like Llama and Mixtral.

    Example:
        from control_zero import ControlZeroClient
        from control_zero.llm.groq import GovernedGroq
        from groq import Groq

        cz = ControlZeroClient(api_key="...")
        cz.initialize()

        client = Groq()
        governed = GovernedGroq(client=client, control_zero=cz)

        response = governed.chat.completions.create(
            model="llama-3.1-70b-versatile",
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """

    def __init__(
        self,
        client: Any,
        control_zero: Any,
        config: Optional[LLMGovernanceConfig] = None,
        user_context: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(client, control_zero, config, user_context)
        self._chat = GovernedChat(self)

    @property
    def provider_name(self) -> str:
        return "groq"

    @property
    def chat(self) -> GovernedChat:
        return self._chat

    def with_user_context(self, user_context: Dict[str, Any]) -> "GovernedGroq":
        merged_context = {**self._user_context, **user_context}
        return GovernedGroq(
            client=self._client,
            control_zero=self._cz,
            config=self._config,
            user_context=merged_context,
        )

    def with_config(self, config: LLMGovernanceConfig) -> "GovernedGroq":
        return GovernedGroq(
            client=self._client,
            control_zero=self._cz,
            config=config,
            user_context=self._user_context,
        )
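
Since create() routes stream=True calls through _create_stream(), which yields chunks while tallying an approximate token count, the caller iterates the return value directly. A hedged sketch of consuming that governed stream, assuming Groq's OpenAI-compatible chunk shape (choices[0].delta.content) and the governed client constructed in the docstring example above; the model and prompt are illustrative.

# Sketch: consuming a governed streaming completion. The chunk shape is assumed to
# follow Groq's OpenAI-compatible delta format; model and prompt are illustrative.
stream = governed.chat.completions.create(
    model="llama-3.1-70b-versatile",
    messages=[{"role": "user", "content": "Write a haiku about audit logs."}],
    stream=True,
)
for chunk in stream:
    if chunk.choices:
        piece = getattr(chunk.choices[0].delta, "content", None)
        if piece:
            print(piece, end="", flush=True)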
control_zero/llm/mistral/__init__.py
@@ -0,0 +1,32 @@
"""
Control Zero Mistral Governance Wrapper.

This module provides governance wrappers for the Mistral Python SDK,
enabling policy enforcement, cost tracking, and audit logging for
all Mistral API calls.

Usage:
    from control_zero import ControlZeroClient
    from control_zero.llm.mistral import GovernedMistral
    from mistralai import Mistral

    # Initialize Control Zero
    cz_client = ControlZeroClient(api_key="cz_live_xxx")
    cz_client.initialize()

    # Wrap Mistral client with governance
    mistral_client = Mistral(api_key="your-mistral-key")
    governed = GovernedMistral(client=mistral_client, control_zero=cz_client)

    # All calls are now governed
    response = governed.chat.complete(
        model="mistral-large-latest",
        messages=[{"role": "user", "content": "Hello"}]
    )
"""

from control_zero.llm.mistral.client import GovernedMistral

__all__ = [
    "GovernedMistral",
]
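
Note the namespace difference from the Groq wrapper: GovernedMistral (defined in the next hunk) exposes chat.complete() and chat.stream() rather than chat.completions.create(). A hedged sketch of the streaming path, reusing the governed client from the docstring above and mirroring the wrapper's own defensive chunk access; the values are illustrative.

# Sketch: governed Mistral streaming. The .data access mirrors how the wrapper
# itself handles mistralai stream chunks; model and prompt are illustrative.
for chunk in governed.chat.stream(
    model="mistral-large-latest",
    messages=[{"role": "user", "content": "Hello"}],
):
    data = getattr(chunk, "data", chunk)
    if getattr(data, "choices", None):
        piece = getattr(data.choices[0].delta, "content", None)
        if piece:
            print(piece, end="", flush=True)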
control_zero/llm/mistral/client.py
@@ -0,0 +1,319 @@
"""
Governed Mistral client wrapper.

Provides governance features for the Mistral Python SDK including:
- Model access control
- Cost tracking and limits
- Function calling governance
- PII detection and masking
- Audit logging
"""

import time
from typing import Any, Dict, Iterator, List, Optional, Union

from control_zero.llm.base import (
    GovernanceAction,
    GovernedLLM,
    GovernedChatMixin,
    LLMGovernanceConfig,
    LLMUsageMetrics,
    estimate_cost,
)
from control_zero.policy import PolicyDeniedError


class GovernedChatCompletions(GovernedChatMixin):
    """Governed wrapper for Mistral chat completions."""

    def __init__(self, governed_client: "GovernedMistral"):
        self._governed = governed_client
        self._client = governed_client._client

    def complete(
        self,
        *,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        safe_prompt: bool = False,
        response_format: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> Any:
        """
        Create a governed chat completion.

        Args:
            model: Model to use (e.g., "mistral-large-latest")
            messages: List of message dicts
            tools: Tool definitions
            tool_choice: Tool selection mode
            stream: Whether to stream
            max_tokens: Maximum output tokens
            temperature: Sampling temperature
            top_p: Nucleus sampling
            random_seed: Random seed
            safe_prompt: Enable safe mode
            response_format: Response format
            **kwargs: Additional parameters

        Returns:
            ChatCompletionResponse or stream iterator

        Raises:
            PolicyDeniedError: If request violates policy
        """
        start_time = time.time()

        # Estimate tokens
        estimated_input_tokens = self._estimate_message_tokens(messages)

        # Prepare tools for policy check
        tools_to_check = []
        if tools:
            tools_to_check = [{"name": t.get("function", {}).get("name", ""), "type": "function"} for t in tools]

        # Run governance checks
        self._governed._pre_request_checks(
            model=model,
            action=GovernanceAction.CHAT_COMPLETION,
            messages=messages,
            functions=tools_to_check,
            estimated_tokens=estimated_input_tokens,
        )

        # Process messages for governance
        processed_messages = self._process_messages_for_governance(messages)

        # Filter tools
        filtered_tools = self._filter_functions_for_governance(tools)

        # Apply max_tokens limit
        governed_max_tokens = max_tokens
        if self._governed._config.content_policy.max_output_tokens:
            if max_tokens:
                governed_max_tokens = min(max_tokens, self._governed._config.content_policy.max_output_tokens)
            else:
                governed_max_tokens = self._governed._config.content_policy.max_output_tokens

        # Build request
        request_kwargs = {
            "model": model,
            "messages": processed_messages,
        }

        if filtered_tools:
            request_kwargs["tools"] = filtered_tools
        if tool_choice and filtered_tools:
            request_kwargs["tool_choice"] = tool_choice
        if governed_max_tokens:
            request_kwargs["max_tokens"] = governed_max_tokens
        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if top_p is not None:
            request_kwargs["top_p"] = top_p
        if random_seed is not None:
            request_kwargs["random_seed"] = random_seed
        if safe_prompt:
            request_kwargs["safe_prompt"] = safe_prompt
        if response_format is not None:
            request_kwargs["response_format"] = response_format

        request_kwargs.update(kwargs)

        # Handle streaming
        if stream:
            return self._create_stream(request_kwargs, start_time, model)

        # Make API call
        try:
            response = self._client.chat.complete(**request_kwargs)
            latency_ms = int((time.time() - start_time) * 1000)

            # Extract metrics
            usage = getattr(response, "usage", None)
            input_tokens = usage.prompt_tokens if usage else estimated_input_tokens
            output_tokens = usage.completion_tokens if usage else 0
            total_tokens = usage.total_tokens if usage else input_tokens + output_tokens

            # Count tool calls
            tool_call_count = 0
            for choice in response.choices:
                if hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
                    tool_call_count += len(choice.message.tool_calls)

            metrics = LLMUsageMetrics(
                provider="mistral",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
                latency_ms=latency_ms,
                estimated_cost=estimate_cost(model, input_tokens, output_tokens),
                function_calls=tool_call_count,
            )

            self._governed._post_request_update(metrics)
            self._governed._log_request(model, GovernanceAction.CHAT_COMPLETION, metrics)

            return response

        except PolicyDeniedError:
            raise
        except Exception as e:
            latency_ms = int((time.time() - start_time) * 1000)
            metrics = LLMUsageMetrics(
                provider="mistral",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                latency_ms=latency_ms,
            )
            self._governed._log_request(
                model, GovernanceAction.CHAT_COMPLETION, metrics,
                status="error", error=str(e)
            )
            raise

    def _create_stream(
        self,
        request_kwargs: Dict[str, Any],
        start_time: float,
        model: str,
    ) -> Iterator[Any]:
        """Create a governed streaming response."""
        total_tokens = 0
        tool_call_count = 0

        try:
            stream = self._client.chat.stream(**request_kwargs)

            for chunk in stream:
                total_tokens += 1

                data = getattr(chunk, "data", chunk)
                if hasattr(data, "choices"):
                    for choice in data.choices:
                        if hasattr(choice, "delta") and hasattr(choice.delta, "tool_calls"):
                            if choice.delta.tool_calls:
                                tool_call_count += len(choice.delta.tool_calls)

                yield chunk

            latency_ms = int((time.time() - start_time) * 1000)

            metrics = LLMUsageMetrics(
                provider="mistral",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                total_tokens=total_tokens,
                latency_ms=latency_ms,
                estimated_cost=estimate_cost(model, total_tokens // 2, total_tokens // 2),
                function_calls=tool_call_count,
            )

            self._governed._post_request_update(metrics)
            self._governed._log_request(model, GovernanceAction.CHAT_COMPLETION, metrics)

        except Exception as e:
            latency_ms = int((time.time() - start_time) * 1000)
            metrics = LLMUsageMetrics(
                provider="mistral",
                model=model,
                action=GovernanceAction.CHAT_COMPLETION,
                latency_ms=latency_ms,
            )
            self._governed._log_request(
                model, GovernanceAction.CHAT_COMPLETION, metrics,
                status="error", error=str(e)
            )
            raise

    def _estimate_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Estimate token count for messages."""
        total_chars = sum(len(str(m.get("content", ""))) for m in messages)
        return max(1, total_chars // 4)


class GovernedChat:
    """Governed wrapper for Mistral chat namespace."""

    def __init__(self, governed_client: "GovernedMistral"):
        self._completions = GovernedChatCompletions(governed_client)

    def complete(self, **kwargs) -> Any:
        """Create a chat completion."""
        return self._completions.complete(**kwargs)

    def stream(self, **kwargs) -> Iterator[Any]:
        """Create a streaming chat completion."""
        kwargs["stream"] = True
        return self._completions.complete(**kwargs)


class GovernedMistral(GovernedLLM):
    """
    Governed wrapper for the Mistral Python SDK.

    Example:
        from control_zero import ControlZeroClient
        from control_zero.llm.mistral import GovernedMistral
        from mistralai import Mistral

        cz = ControlZeroClient(api_key="...")
        cz.initialize()

        client = Mistral(api_key="...")
        governed = GovernedMistral(client=client, control_zero=cz)

        response = governed.chat.complete(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """

    def __init__(
        self,
        client: Any,
        control_zero: Any,
        config: Optional[LLMGovernanceConfig] = None,
        user_context: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(client, control_zero, config, user_context)
        self._chat = GovernedChat(self)

    @property
    def provider_name(self) -> str:
        return "mistral"

    @property
    def chat(self) -> GovernedChat:
        return self._chat

    @property
    def embeddings(self):
        """Access the underlying embeddings endpoint (pass-through)."""
        return self._client.embeddings

    def with_user_context(self, user_context: Dict[str, Any]) -> "GovernedMistral":
        merged_context = {**self._user_context, **user_context}
        return GovernedMistral(
            client=self._client,
            control_zero=self._cz,
            config=self._config,
            user_context=merged_context,
        )

    def with_config(self, config: LLMGovernanceConfig) -> "GovernedMistral":
        return GovernedMistral(
            client=self._client,
            control_zero=self._cz,
            config=config,
            user_context=self._user_context,
        )
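
Both provider wrappers return new instances from with_user_context() and with_config() while sharing the same underlying SDK client, so per-request attribution does not require rewrapping. A hedged sketch of per-user attribution, reusing the governed Mistral client from the docstring above; the "user_id" key is the one the Groq create() path forwards as its user parameter, any other context keys are assumptions about what the audit log records, and constructing a custom LLMGovernanceConfig is omitted because its fields are not shown in this diff.

# Sketch: derive a per-user governed client (illustrative context values).
# "user_id" is the key read from _user_context in the code above; "team" is an
# assumed extra field for audit attribution.
per_user = governed.with_user_context({"user_id": "user-123", "team": "support"})
response = per_user.chat.complete(
    model="mistral-large-latest",
    messages=[{"role": "user", "content": "Summarize this support ticket."}],
)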