proxilion 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
proxilion/streaming/detector.py

@@ -0,0 +1,761 @@
"""
Tool call detection in streaming LLM responses.

Provides utilities for detecting and extracting tool calls from
streaming responses across multiple LLM providers (OpenAI, Anthropic, Google).
"""

from __future__ import annotations

import contextlib
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any

class StreamEventType(Enum):
    """Types of events that can occur during streaming."""

    TEXT = "text"
    TOOL_CALL_START = "tool_call_start"
    TOOL_CALL_DELTA = "tool_call_delta"
    TOOL_CALL_END = "tool_call_end"
    DONE = "done"
    ERROR = "error"

@dataclass
class PartialToolCall:
    """
    A tool call that is being buffered during streaming.

    Accumulates argument chunks until the tool call is complete.

    Attributes:
        id: Unique identifier for this tool call.
        name: Name of the tool being called.
        arguments_buffer: Accumulated JSON argument string.
        is_complete: Whether the tool call has finished streaming.
        started_at: When the tool call started.
        index: Index in the response (for multiple tool calls).
    """

    id: str
    name: str
    arguments_buffer: str = ""
    is_complete: bool = False
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    index: int = 0

    def append_arguments(self, delta: str) -> None:
        """Append argument delta to the buffer."""
        self.arguments_buffer += delta

    def get_arguments(self) -> dict[str, Any]:
        """
        Parse and return the accumulated arguments.

        Returns:
            Parsed arguments dictionary.

        Raises:
            ValueError: If arguments cannot be parsed as JSON.
        """
        if not self.arguments_buffer:
            return {}
        try:
            return json.loads(self.arguments_buffer)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid tool call arguments: {e}") from e

    def complete(self) -> None:
        """Mark the tool call as complete."""
        self.is_complete = True

@dataclass
class DetectedToolCall:
    """
    A fully detected tool call extracted from streaming.

    Attributes:
        id: Unique identifier for this tool call.
        name: Name of the tool being called.
        arguments: Parsed arguments dictionary.
        raw_arguments: Original JSON string of arguments.
        index: Index in the response (for multiple tool calls).
        detected_at: When the tool call was fully detected.
    """

    id: str
    name: str
    arguments: dict[str, Any]
    raw_arguments: str = ""
    index: int = 0
    detected_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @classmethod
    def from_partial(cls, partial: PartialToolCall) -> DetectedToolCall:
        """Create a DetectedToolCall from a completed PartialToolCall."""
        return cls(
            id=partial.id,
            name=partial.name,
            arguments=partial.get_arguments(),
            raw_arguments=partial.arguments_buffer,
            index=partial.index,
        )

@dataclass
class StreamEvent:
    """
    An event emitted during stream processing.

    Attributes:
        type: The type of event.
        content: Text content (for TEXT events).
        tool_call: The tool call (for TOOL_CALL_* events).
        partial_call: The partial tool call (for TOOL_CALL_DELTA events).
        error: Error message (for ERROR events).
        raw_chunk: The original chunk that produced this event.
    """

    type: StreamEventType
    content: str | None = None
    tool_call: DetectedToolCall | None = None
    partial_call: PartialToolCall | None = None
    error: str | None = None
    raw_chunk: Any = None

    @classmethod
    def text(cls, content: str, raw_chunk: Any = None) -> StreamEvent:
        """Create a TEXT event."""
        return cls(type=StreamEventType.TEXT, content=content, raw_chunk=raw_chunk)

    @classmethod
    def tool_call_start(
        cls, partial: PartialToolCall, raw_chunk: Any = None
    ) -> StreamEvent:
        """Create a TOOL_CALL_START event."""
        return cls(
            type=StreamEventType.TOOL_CALL_START,
            partial_call=partial,
            raw_chunk=raw_chunk,
        )

    @classmethod
    def tool_call_delta(
        cls, partial: PartialToolCall, raw_chunk: Any = None
    ) -> StreamEvent:
        """Create a TOOL_CALL_DELTA event."""
        return cls(
            type=StreamEventType.TOOL_CALL_DELTA,
            partial_call=partial,
            raw_chunk=raw_chunk,
        )

    @classmethod
    def tool_call_end(
        cls, tool_call: DetectedToolCall, raw_chunk: Any = None
    ) -> StreamEvent:
        """Create a TOOL_CALL_END event."""
        return cls(
            type=StreamEventType.TOOL_CALL_END,
            tool_call=tool_call,
            raw_chunk=raw_chunk,
        )

    @classmethod
    def done(cls, raw_chunk: Any = None) -> StreamEvent:
        """Create a DONE event."""
        return cls(type=StreamEventType.DONE, raw_chunk=raw_chunk)

    @classmethod
    def error_event(cls, error: str, raw_chunk: Any = None) -> StreamEvent:
        """Create an ERROR event."""
        return cls(type=StreamEventType.ERROR, error=error, raw_chunk=raw_chunk)

class StreamingToolCallDetector:
    """
    Detect and buffer tool calls from streaming LLM responses.

    Works with OpenAI, Anthropic, and Google streaming formats,
    automatically detecting the provider from chunk structure.

    Example:
        >>> detector = StreamingToolCallDetector()
        >>> async for chunk in llm_stream:
        ...     events = detector.process_chunk(chunk)
        ...     for event in events:
        ...         if event.type == StreamEventType.TOOL_CALL_END:
        ...             # Full tool call is now available
        ...             tool_call = event.tool_call
        ...             result = auth.authorize(user, "execute", tool_call.name)
        ...         elif event.type == StreamEventType.TEXT:
        ...             print(event.content, end="")

    Attributes:
        provider: The LLM provider ("openai", "anthropic", "google", or "auto").
    """

    SUPPORTED_PROVIDERS = ("auto", "openai", "anthropic", "google")

    def __init__(self, provider: str = "auto") -> None:
        """
        Initialize the detector.

        Args:
            provider: LLM provider name or "auto" for detection.

        Raises:
            ValueError: If provider is not supported.
        """
        if provider not in self.SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported provider: {provider}. "
                f"Supported: {self.SUPPORTED_PROVIDERS}"
            )
        self.provider = provider
        self._partial_calls: dict[str, PartialToolCall] = {}
        self._text_buffer: str = ""
        self._detected: bool = False

    def process_chunk(self, chunk: Any) -> list[StreamEvent]:
        """
        Process a streaming chunk and return any events.

        Args:
            chunk: A streaming chunk from an LLM provider.

        Returns:
            List of StreamEvent objects detected in this chunk.

        Raises:
            ValueError: If the provider cannot be determined.
        """
        # Detect provider from chunk structure if auto
        if self.provider == "auto" and not self._detected:
            detected = self._detect_provider(chunk)
            if detected:
                self.provider = detected
                self._detected = True
            else:
                # Try to process without detection
                return self._process_generic_chunk(chunk)

        if self.provider == "openai":
            return self._process_openai_chunk(chunk)
        elif self.provider == "anthropic":
            return self._process_anthropic_chunk(chunk)
        elif self.provider == "google":
            return self._process_google_chunk(chunk)
        else:
            return self._process_generic_chunk(chunk)

    def _detect_provider(self, chunk: Any) -> str | None:
        """
        Detect the provider from chunk structure.

        Args:
            chunk: A streaming chunk to analyze.

        Returns:
            Provider name or None if detection failed.
        """
        # Check if it's a dict-like object
        if isinstance(chunk, dict):
            # OpenAI streaming format
            if "choices" in chunk:
                return "openai"
            # Anthropic streaming format
            if "type" in chunk and chunk.get("type") in (
                "content_block_start",
                "content_block_delta",
                "content_block_stop",
                "message_start",
                "message_delta",
                "message_stop",
            ):
                return "anthropic"
            # Google/Gemini format
            if "candidates" in chunk:
                return "google"

        # Check for object attributes (SDK response objects)
        if hasattr(chunk, "choices"):
            return "openai"
        if hasattr(chunk, "type"):
            chunk_type = getattr(chunk, "type", None)
            if chunk_type in (
                "content_block_start",
                "content_block_delta",
                "content_block_stop",
                "message_start",
                "message_delta",
                "message_stop",
            ):
                return "anthropic"
        if hasattr(chunk, "candidates"):
            return "google"

        return None

    def _process_openai_chunk(self, chunk: Any) -> list[StreamEvent]:
        """Process an OpenAI streaming chunk."""
        events: list[StreamEvent] = []

        # Handle dict or object format
        choices = (
            chunk.get("choices", [])
            if isinstance(chunk, dict)
            else getattr(chunk, "choices", [])
        )

        if not choices:
            return events

        for choice in choices:
            # Get delta from choice
            if isinstance(choice, dict):
                delta = choice.get("delta", {})
                finish_reason = choice.get("finish_reason")
                index = choice.get("index", 0)
            else:
                delta = getattr(choice, "delta", None)
                finish_reason = getattr(choice, "finish_reason", None)
                index = getattr(choice, "index", 0)
            if delta is None:
                continue

            # Check for content (text)
            content = (
                delta.get("content")
                if isinstance(delta, dict)
                else getattr(delta, "content", None)
            )
            if content:
                self._text_buffer += content
                events.append(StreamEvent.text(content, chunk))

            # Check for tool calls
            tool_calls = (
                delta.get("tool_calls")
                if isinstance(delta, dict)
                else getattr(delta, "tool_calls", None)
            )
            if tool_calls:
                for tc in tool_calls:
                    tc_events = self._process_openai_tool_call(tc, index, chunk)
                    events.extend(tc_events)

            # Check for finish
            if finish_reason == "stop":
                events.append(StreamEvent.done(chunk))
            elif finish_reason == "tool_calls":
                # Mark all pending tool calls as complete
                for _call_id, partial in list(self._partial_calls.items()):
                    if not partial.is_complete:
                        partial.complete()
                        tool_call = DetectedToolCall.from_partial(partial)
                        events.append(StreamEvent.tool_call_end(tool_call, chunk))

        return events

    def _process_openai_tool_call(
        self, tc: Any, choice_index: int, chunk: Any
    ) -> list[StreamEvent]:
        """Process a tool call delta from OpenAI format."""
        events: list[StreamEvent] = []

        # Extract tool call fields
        if isinstance(tc, dict):
            tc_index = tc.get("index", 0)
            tc_id = tc.get("id")
            tc_function = tc.get("function", {})
            tc_name = tc_function.get("name") if isinstance(tc_function, dict) else None
            tc_args = (
                tc_function.get("arguments")
                if isinstance(tc_function, dict)
                else None
            )
        else:
            tc_index = getattr(tc, "index", 0)
            tc_id = getattr(tc, "id", None)
            tc_function = getattr(tc, "function", None)
            tc_name = getattr(tc_function, "name", None) if tc_function else None
            tc_args = getattr(tc_function, "arguments", None) if tc_function else None

        # Create unique key for this tool call
        call_key = f"{choice_index}_{tc_index}"

        # Check if this is a new tool call
        if tc_id and tc_name and call_key not in self._partial_calls:
            partial = PartialToolCall(
                id=tc_id,
                name=tc_name,
                index=tc_index,
            )
            self._partial_calls[call_key] = partial
            events.append(StreamEvent.tool_call_start(partial, chunk))
        elif call_key not in self._partial_calls and tc_id:
            # New tool call without name yet
            partial = PartialToolCall(
                id=tc_id,
                name="",
                index=tc_index,
            )
            self._partial_calls[call_key] = partial
            events.append(StreamEvent.tool_call_start(partial, chunk))

        # Update existing tool call
        if call_key in self._partial_calls:
            partial = self._partial_calls[call_key]
            if tc_name and not partial.name:
                partial.name = tc_name
            if tc_args:
                partial.append_arguments(tc_args)
                events.append(StreamEvent.tool_call_delta(partial, chunk))

        return events

    def _process_anthropic_chunk(self, chunk: Any) -> list[StreamEvent]:
        """Process an Anthropic streaming chunk."""
        events: list[StreamEvent] = []

        # Get chunk type
        chunk_type = (
            chunk.get("type")
            if isinstance(chunk, dict)
            else getattr(chunk, "type", None)
        )

        if chunk_type == "content_block_start":
            events.extend(self._handle_anthropic_block_start(chunk))
        elif chunk_type == "content_block_delta":
            events.extend(self._handle_anthropic_block_delta(chunk))
        elif chunk_type == "content_block_stop":
            events.extend(self._handle_anthropic_block_stop(chunk))
        elif chunk_type == "message_stop":
            events.append(StreamEvent.done(chunk))

        return events

    def _handle_anthropic_block_start(self, chunk: Any) -> list[StreamEvent]:
        """Handle Anthropic content_block_start event."""
        events: list[StreamEvent] = []

        # Get the content block
        content_block = (
            chunk.get("content_block")
            if isinstance(chunk, dict)
            else getattr(chunk, "content_block", None)
        )
        index = (
            chunk.get("index", 0)
            if isinstance(chunk, dict)
            else getattr(chunk, "index", 0)
        )

        if not content_block:
            return events

        block_type = (
            content_block.get("type")
            if isinstance(content_block, dict)
            else getattr(content_block, "type", None)
        )

        if block_type == "tool_use":
            # Start of a tool call
            tc_id = (
                content_block.get("id")
                if isinstance(content_block, dict)
                else getattr(content_block, "id", None)
            )
            tc_name = (
                content_block.get("name")
                if isinstance(content_block, dict)
                else getattr(content_block, "name", None)
            )

            if tc_id:
                partial = PartialToolCall(
                    id=tc_id,
                    name=tc_name or "",
                    index=index,
                )
                self._partial_calls[tc_id] = partial
                events.append(StreamEvent.tool_call_start(partial, chunk))

        return events

    def _handle_anthropic_block_delta(self, chunk: Any) -> list[StreamEvent]:
        """Handle Anthropic content_block_delta event."""
        events: list[StreamEvent] = []

        # Get the delta
        delta = (
            chunk.get("delta")
            if isinstance(chunk, dict)
            else getattr(chunk, "delta", None)
        )
        index = (
            chunk.get("index", 0)
            if isinstance(chunk, dict)
            else getattr(chunk, "index", 0)
        )

        if not delta:
            return events

        delta_type = (
            delta.get("type")
            if isinstance(delta, dict)
            else getattr(delta, "type", None)
        )

        if delta_type == "text_delta":
            # Text content
            text = (
                delta.get("text")
                if isinstance(delta, dict)
                else getattr(delta, "text", None)
            )
            if text:
                self._text_buffer += text
                events.append(StreamEvent.text(text, chunk))

        elif delta_type == "input_json_delta":
            # Tool call arguments
            partial_json = (
                delta.get("partial_json")
                if isinstance(delta, dict)
                else getattr(delta, "partial_json", None)
            )
            if partial_json:
                # Find the partial call for this index
                for partial in self._partial_calls.values():
                    if partial.index == index and not partial.is_complete:
                        partial.append_arguments(partial_json)
                        events.append(StreamEvent.tool_call_delta(partial, chunk))
                        break

        return events

    def _handle_anthropic_block_stop(self, chunk: Any) -> list[StreamEvent]:
        """Handle Anthropic content_block_stop event."""
        events: list[StreamEvent] = []

        index = (
            chunk.get("index", 0)
            if isinstance(chunk, dict)
            else getattr(chunk, "index", 0)
        )

        # Find and complete the partial call at this index
        for _call_id, partial in list(self._partial_calls.items()):
            if partial.index == index and not partial.is_complete:
                partial.complete()
                try:
                    tool_call = DetectedToolCall.from_partial(partial)
                    events.append(StreamEvent.tool_call_end(tool_call, chunk))
                except ValueError as e:
                    events.append(StreamEvent.error_event(str(e), chunk))
                break

        return events

    def _process_google_chunk(self, chunk: Any) -> list[StreamEvent]:
        """Process a Google/Gemini streaming chunk."""
        events: list[StreamEvent] = []

        # Get candidates
        candidates = (
            chunk.get("candidates", [])
            if isinstance(chunk, dict)
            else getattr(chunk, "candidates", [])
        )

        if not candidates:
            return events

        for cand_idx, candidate in enumerate(candidates):
            # Get content from candidate
            content = (
                candidate.get("content")
                if isinstance(candidate, dict)
                else getattr(candidate, "content", None)
            )
            finish_reason = (
                candidate.get("finishReason")
                if isinstance(candidate, dict)
                else getattr(candidate, "finish_reason", None)
            )

            if content:
                parts = (
                    content.get("parts", [])
                    if isinstance(content, dict)
                    else getattr(content, "parts", [])
                )

                for part_idx, part in enumerate(parts):
                    part_events = self._process_google_part(
                        part, cand_idx, part_idx, chunk
                    )
                    events.extend(part_events)

            # Check for finish
            if finish_reason == "STOP":
                # Complete any pending tool calls
                for partial in list(self._partial_calls.values()):
                    if not partial.is_complete:
                        partial.complete()
                        try:
                            tool_call = DetectedToolCall.from_partial(partial)
                            events.append(StreamEvent.tool_call_end(tool_call, chunk))
                        except ValueError:
                            pass
                events.append(StreamEvent.done(chunk))

        return events

    def _process_google_part(
        self, part: Any, cand_idx: int, part_idx: int, chunk: Any
    ) -> list[StreamEvent]:
        """Process a part from a Google response."""
        events: list[StreamEvent] = []

        # Check for text
        text = (
            part.get("text") if isinstance(part, dict) else getattr(part, "text", None)
        )
        if text:
            self._text_buffer += text
            events.append(StreamEvent.text(text, chunk))

        # Check for function call
        function_call = (
            part.get("functionCall")
            if isinstance(part, dict)
            else getattr(part, "function_call", None)
        )
        if function_call:
            # Google sends function calls complete in one chunk
            fc_name = (
                function_call.get("name")
                if isinstance(function_call, dict)
                else getattr(function_call, "name", None)
            )
            fc_args = (
                function_call.get("args")
                if isinstance(function_call, dict)
                else getattr(function_call, "args", None)
            )

            call_id = f"google_{cand_idx}_{part_idx}"
            partial = PartialToolCall(
                id=call_id,
                name=fc_name or "",
                index=part_idx,
            )

            # Google sends args as dict, not JSON string
            if fc_args:
                if isinstance(fc_args, dict):
                    partial.arguments_buffer = json.dumps(fc_args)
                else:
                    partial.arguments_buffer = str(fc_args)

            partial.complete()
            self._partial_calls[call_id] = partial

            events.append(StreamEvent.tool_call_start(partial, chunk))
            tool_call = DetectedToolCall.from_partial(partial)
            events.append(StreamEvent.tool_call_end(tool_call, chunk))

        return events

    def _process_generic_chunk(self, chunk: Any) -> list[StreamEvent]:
        """Process a chunk with unknown format."""
        events: list[StreamEvent] = []

        # Try to extract text content
        if isinstance(chunk, str):
            self._text_buffer += chunk
            events.append(StreamEvent.text(chunk, chunk))
        elif isinstance(chunk, dict):
            # Try common keys
            for key in ("content", "text", "message", "data"):
                if key in chunk and isinstance(chunk[key], str):
                    self._text_buffer += chunk[key]
                    events.append(StreamEvent.text(chunk[key], chunk))
                    break

        return events

    def get_pending_calls(self) -> list[PartialToolCall]:
        """
        Get tool calls that are still being streamed.

        Returns:
            List of incomplete PartialToolCall objects.
        """
        return [c for c in self._partial_calls.values() if not c.is_complete]

    def get_completed_calls(self) -> list[PartialToolCall]:
        """
        Get tool calls that have completed streaming.

        Returns:
            List of complete PartialToolCall objects.
        """
        return [c for c in self._partial_calls.values() if c.is_complete]

    def get_all_detected_calls(self) -> list[DetectedToolCall]:
        """
        Get all completed tool calls as DetectedToolCall objects.

        Returns:
            List of DetectedToolCall objects.
        """
        result = []
        for partial in self._partial_calls.values():
            if partial.is_complete:
                with contextlib.suppress(ValueError):
                    result.append(DetectedToolCall.from_partial(partial))
        return result

    def get_text_buffer(self) -> str:
        """
        Get all accumulated text content.

        Returns:
            Concatenated text from all TEXT events.
        """
        return self._text_buffer

    def reset(self) -> None:
        """Reset detector state for a new stream."""
        self._partial_calls.clear()
        self._text_buffer = ""
        if self.provider == "auto":
            self._detected = False

    def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about the current stream processing.

        Returns:
            Dictionary with processing statistics.
        """
        return {
            "provider": self.provider,
            "pending_calls": len(self.get_pending_calls()),
            "completed_calls": len(self.get_completed_calls()),
            "text_length": len(self._text_buffer),
            "total_calls": len(self._partial_calls),
        }
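
To make the detector's contract concrete, here is a minimal driver sketch. It is illustrative and not part of the package: the chunk payloads are hand-built dicts of the shape the OpenAI dict branches above accept, and the import path assumes the module is importable as shown in the file list.

# Illustrative sketch: drive StreamingToolCallDetector with hand-built
# OpenAI-style dict chunks (not real API output).
from proxilion.streaming.detector import (
    StreamEventType,
    StreamingToolCallDetector,
)

detector = StreamingToolCallDetector(provider="openai")

# A text delta, a tool call split across two deltas, then a "tool_calls" finish.
chunks = [
    {"choices": [{"index": 0, "delta": {"content": "Checking the weather. "}}]},
    {"choices": [{"index": 0, "delta": {"tool_calls": [
        {"index": 0, "id": "call_1",
         "function": {"name": "get_weather", "arguments": '{"city": '}},
    ]}}]},
    {"choices": [{"index": 0, "delta": {"tool_calls": [
        {"index": 0, "function": {"arguments": '"Paris"}'}},
    ]}}]},
    {"choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]},
]

for chunk in chunks:
    for event in detector.process_chunk(chunk):
        if event.type == StreamEventType.TEXT:
            print(event.content, end="")
        elif event.type == StreamEventType.TOOL_CALL_END:
            # Arguments are parsed from the accumulated JSON buffer:
            # prints: get_weather {'city': 'Paris'}
            print(event.tool_call.name, event.tool_call.arguments)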
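A similar sketch for auto-detection, continuing from the imports above: with the default provider of "auto", the shape of the first chunk decides which handler runs, here an Anthropic-style tool_use block (again hand-built for illustration).

# Illustrative sketch: auto-detection settles on "anthropic" from the
# first chunk's "type" field, then the Anthropic handlers take over.
detector = StreamingToolCallDetector()  # provider="auto"

anthropic_chunks = [
    {"type": "content_block_start", "index": 0,
     "content_block": {"type": "tool_use", "id": "toolu_1", "name": "search"}},
    {"type": "content_block_delta", "index": 0,
     "delta": {"type": "input_json_delta",
               "partial_json": '{"query": "streaming tool calls"}'}},
    {"type": "content_block_stop", "index": 0},
    {"type": "message_stop"},
]

for chunk in anthropic_chunks:
    for event in detector.process_chunk(chunk):
        if event.type == StreamEventType.TOOL_CALL_END:
            # prints: search {'query': 'streaming tool calls'}
            print(event.tool_call.name, event.tool_call.arguments)

# The detector reports the provider it settled on, e.g.:
# {'provider': 'anthropic', 'pending_calls': 0, 'completed_calls': 1,
#  'text_length': 0, 'total_calls': 1}
print(detector.get_stats())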