abstractcore 2.5.2__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. abstractcore/__init__.py +19 -1
  2. abstractcore/architectures/detection.py +252 -6
  3. abstractcore/assets/architecture_formats.json +14 -1
  4. abstractcore/assets/model_capabilities.json +533 -10
  5. abstractcore/compression/__init__.py +29 -0
  6. abstractcore/compression/analytics.py +420 -0
  7. abstractcore/compression/cache.py +250 -0
  8. abstractcore/compression/config.py +279 -0
  9. abstractcore/compression/exceptions.py +30 -0
  10. abstractcore/compression/glyph_processor.py +381 -0
  11. abstractcore/compression/optimizer.py +388 -0
  12. abstractcore/compression/orchestrator.py +380 -0
  13. abstractcore/compression/pil_text_renderer.py +818 -0
  14. abstractcore/compression/quality.py +226 -0
  15. abstractcore/compression/text_formatter.py +666 -0
  16. abstractcore/compression/vision_compressor.py +371 -0
  17. abstractcore/config/main.py +64 -0
  18. abstractcore/config/manager.py +100 -5
  19. abstractcore/core/retry.py +2 -2
  20. abstractcore/core/session.py +193 -7
  21. abstractcore/download.py +253 -0
  22. abstractcore/embeddings/manager.py +2 -2
  23. abstractcore/events/__init__.py +113 -2
  24. abstractcore/exceptions/__init__.py +49 -2
  25. abstractcore/media/auto_handler.py +312 -18
  26. abstractcore/media/handlers/local_handler.py +14 -2
  27. abstractcore/media/handlers/openai_handler.py +62 -3
  28. abstractcore/media/processors/__init__.py +11 -1
  29. abstractcore/media/processors/direct_pdf_processor.py +210 -0
  30. abstractcore/media/processors/glyph_pdf_processor.py +227 -0
  31. abstractcore/media/processors/image_processor.py +7 -1
  32. abstractcore/media/processors/office_processor.py +2 -2
  33. abstractcore/media/processors/text_processor.py +18 -3
  34. abstractcore/media/types.py +164 -7
  35. abstractcore/media/utils/image_scaler.py +2 -2
  36. abstractcore/media/vision_fallback.py +2 -2
  37. abstractcore/providers/__init__.py +18 -0
  38. abstractcore/providers/anthropic_provider.py +228 -8
  39. abstractcore/providers/base.py +378 -11
  40. abstractcore/providers/huggingface_provider.py +563 -23
  41. abstractcore/providers/lmstudio_provider.py +284 -4
  42. abstractcore/providers/mlx_provider.py +27 -2
  43. abstractcore/providers/model_capabilities.py +352 -0
  44. abstractcore/providers/ollama_provider.py +282 -6
  45. abstractcore/providers/openai_provider.py +286 -8
  46. abstractcore/providers/registry.py +85 -13
  47. abstractcore/providers/streaming.py +2 -2
  48. abstractcore/server/app.py +91 -81
  49. abstractcore/tools/common_tools.py +2 -2
  50. abstractcore/tools/handler.py +2 -2
  51. abstractcore/tools/parser.py +2 -2
  52. abstractcore/tools/registry.py +2 -2
  53. abstractcore/tools/syntax_rewriter.py +2 -2
  54. abstractcore/tools/tag_rewriter.py +3 -3
  55. abstractcore/utils/__init__.py +4 -1
  56. abstractcore/utils/self_fixes.py +2 -2
  57. abstractcore/utils/trace_export.py +287 -0
  58. abstractcore/utils/version.py +1 -1
  59. abstractcore/utils/vlm_token_calculator.py +655 -0
  60. {abstractcore-2.5.2.dist-info → abstractcore-2.6.0.dist-info}/METADATA +207 -8
  61. abstractcore-2.6.0.dist-info/RECORD +108 -0
  62. abstractcore-2.5.2.dist-info/RECORD +0 -90
  63. {abstractcore-2.5.2.dist-info → abstractcore-2.6.0.dist-info}/WHEEL +0 -0
  64. {abstractcore-2.5.2.dist-info → abstractcore-2.6.0.dist-info}/entry_points.txt +0 -0
  65. {abstractcore-2.5.2.dist-info → abstractcore-2.6.0.dist-info}/licenses/LICENSE +0 -0
  66. {abstractcore-2.5.2.dist-info → abstractcore-2.6.0.dist-info}/top_level.txt +0 -0
abstractcore/core/session.py
@@ -3,11 +3,12 @@ BasicSession for conversation tracking.
 Target: <500 lines maximum.
 """

-from typing import List, Optional, Dict, Any, Union, Iterator, Callable
+from typing import List, Optional, Dict, Any, Union, Iterator, AsyncIterator, Callable
 from datetime import datetime
 from pathlib import Path
 import json
 import uuid
+import asyncio
 from collections.abc import Generator

 from .interface import AbstractCoreInterface
@@ -34,9 +35,10 @@ class BasicSession:
                  auto_compact: bool = False,
                  auto_compact_threshold: int = 6000,
                  temperature: Optional[float] = None,
-                 seed: Optional[int] = None):
+                 seed: Optional[int] = None,
+                 enable_tracing: bool = False):
         """Initialize basic session
-
+
         Args:
             provider: LLM provider instance
             system_prompt: System prompt for the session
@@ -48,6 +50,7 @@ class BasicSession:
            auto_compact_threshold: Token threshold for auto-compaction
            temperature: Default temperature for generation (0.0-1.0)
            seed: Default seed for deterministic generation
+           enable_tracing: Enable interaction tracing for observability
        """

        self.provider = provider
@@ -59,11 +62,15 @@ class BasicSession:
         self.auto_compact = auto_compact
         self.auto_compact_threshold = auto_compact_threshold
         self._original_session = None  # Track if this is a compacted session
-
+
         # Store session-level generation parameters
         self.temperature = temperature
         self.seed = seed
-
+
+        # Setup interaction tracing
+        self.enable_tracing = enable_tracing
+        self.interaction_traces: List[Dict[str, Any]] = []  # Session-specific traces
+
         # Optional analytics fields
         self.summary = None
         self.assessment = None
@@ -214,6 +221,16 @@ class BasicSession:
         if 'seed' not in kwargs and self.seed is not None:
             kwargs['seed'] = self.seed

+        # Add trace metadata if tracing is enabled
+        if self.enable_tracing:
+            if 'trace_metadata' not in kwargs:
+                kwargs['trace_metadata'] = {}
+            kwargs['trace_metadata'].update({
+                'session_id': self.id,
+                'step_type': kwargs.get('step_type', 'chat'),
+                'attempt_number': kwargs.get('attempt_number', 1)
+            })
+
         # Call provider
         response = self.provider.generate(
             prompt=prompt,
@@ -231,6 +248,14 @@ class BasicSession:
         # Non-streaming response
         if hasattr(response, 'content') and response.content:
             self.add_message('assistant', response.content)
+
+        # Capture trace if enabled and available
+        if self.enable_tracing and hasattr(self.provider, 'get_traces'):
+            if hasattr(response, 'metadata') and response.metadata and 'trace_id' in response.metadata:
+                trace = self.provider.get_traces(response.metadata['trace_id'])
+                if trace:
+                    self.interaction_traces.append(trace)
+
         return response

     def _handle_streaming_response(self, response_iterator: Iterator[GenerateResponse]) -> Iterator[GenerateResponse]:
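Taken together, these hunks wire per-call tracing through the existing sync path: the session stamps trace_metadata onto each provider call, then pulls the finished trace back through the provider's get_traces(). A minimal usage sketch, assuming `llm` is an already-configured provider instance (the variable is a stand-in, not part of the diff):

# Sketch: session-level tracing; `llm` stands in for a configured provider.
from abstractcore.core.session import BasicSession

session = BasicSession(provider=llm, enable_tracing=True)
session.generate("What is Python?")

# Traces only accumulate when the provider implements get_traces().
for trace in session.get_interaction_history():
    print(trace["trace_id"], trace["response"]["usage"])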
@@ -249,6 +274,136 @@ class BasicSession:
         if collected_content:
             self.add_message('assistant', collected_content)

+    async def agenerate(self,
+                        prompt: str,
+                        name: Optional[str] = None,
+                        location: Optional[str] = None,
+                        **kwargs) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
+        """
+        Async generation with conversation history.
+
+        Args:
+            prompt: User message
+            name: Optional speaker name
+            location: Optional location context
+            **kwargs: Generation parameters (stream, temperature, etc.)
+
+        Returns:
+            GenerateResponse or AsyncIterator for streaming
+
+        Example:
+            # Async chat interaction
+            response = await session.agenerate('What is Python?')
+
+            # Async streaming
+            async for chunk in await session.agenerate('Tell me a story', stream=True):
+                print(chunk.content, end='')
+        """
+        if not self.provider:
+            raise ValueError("No provider configured")
+
+        # Check for auto-compaction before generating
+        if self.auto_compact and self.should_compact(self.auto_compact_threshold):
+            print(f"🗜️ Auto-compacting session (tokens: {self.get_token_estimate()} > {self.auto_compact_threshold})")
+            compacted = self.compact(reason="auto_threshold")
+            # Replace current session with compacted version
+            self._replace_with_compacted(compacted)
+
+        # Pre-processing (fast, sync is fine)
+        self.add_message('user', prompt, name=name, location=location)
+
+        # Format messages for provider (exclude the current user message since provider will add it)
+        messages = self._format_messages_for_provider_excluding_current()
+
+        # Use session tools if not provided in kwargs
+        if 'tools' not in kwargs and self.tools:
+            kwargs['tools'] = self.tools
+
+        # Pass session tool_call_tags if available and not overridden in kwargs
+        if hasattr(self, 'tool_call_tags') and self.tool_call_tags is not None and 'tool_call_tags' not in kwargs:
+            kwargs['tool_call_tags'] = self.tool_call_tags
+
+        # Extract media parameter explicitly
+        media = kwargs.pop('media', None)
+
+        # Add session-level parameters if not overridden in kwargs
+        if 'temperature' not in kwargs and self.temperature is not None:
+            kwargs['temperature'] = self.temperature
+        if 'seed' not in kwargs and self.seed is not None:
+            kwargs['seed'] = self.seed
+
+        # Add trace metadata if tracing is enabled
+        if self.enable_tracing:
+            if 'trace_metadata' not in kwargs:
+                kwargs['trace_metadata'] = {}
+            kwargs['trace_metadata'].update({
+                'session_id': self.id,
+                'step_type': kwargs.get('step_type', 'chat'),
+                'attempt_number': kwargs.get('attempt_number', 1)
+            })
+
+        # Check if streaming
+        stream = kwargs.get('stream', False)
+
+        if stream:
+            # Return async streaming wrapper that adds assistant message after
+            return self._async_session_stream(prompt, messages, media, **kwargs)
+        else:
+            # Async generation
+            response = await self.provider.agenerate(
+                prompt=prompt,
+                messages=messages,
+                system_prompt=self.system_prompt,
+                media=media,
+                **kwargs
+            )
+
+            # Post-processing (fast, sync is fine)
+            if hasattr(response, 'content') and response.content:
+                self.add_message('assistant', response.content)
+
+            # Capture trace if enabled and available
+            if self.enable_tracing and hasattr(self.provider, 'get_traces'):
+                if hasattr(response, 'metadata') and response.metadata and 'trace_id' in response.metadata:
+                    trace = self.provider.get_traces(response.metadata['trace_id'])
+                    if trace:
+                        self.interaction_traces.append(trace)
+
+            return response
+
+    async def _async_session_stream(self,
+                                    prompt: str,
+                                    messages: List[Dict[str, str]],
+                                    media: Optional[List],
+                                    **kwargs) -> AsyncIterator[GenerateResponse]:
+        """Async streaming with session history management."""
+        collected_content = ""
+
+        # Remove 'stream' from kwargs since we're explicitly setting it
+        kwargs_copy = {k: v for k, v in kwargs.items() if k != 'stream'}
+
+        # CRITICAL: Await first to get async generator, then iterate
+        stream_gen = await self.provider.agenerate(
+            prompt=prompt,
+            messages=messages,
+            system_prompt=self.system_prompt,
+            media=media,
+            stream=True,
+            **kwargs_copy
+        )
+
+        async for chunk in stream_gen:
+            # Yield the chunk for the caller
+            yield chunk
+
+            # Collect content for history
+            if hasattr(chunk, 'content') and chunk.content:
+                collected_content += chunk.content
+
+        # After streaming completes, add assistant message
+        if collected_content:
+            self.add_message('assistant', collected_content)
+
     def _format_messages_for_provider(self) -> List[Dict[str, str]]:
         """Format messages for provider API"""
         return [
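The docstring's usage pattern is worth spelling out, since agenerate() must itself be awaited before the streaming iterator can be consumed. A runnable sketch, again assuming `llm` is a configured provider whose agenerate() is implemented:

import asyncio
from abstractcore.core.session import BasicSession

async def main() -> None:
    session = BasicSession(provider=llm)  # `llm`: any provider instance

    # Non-streaming: the await yields a GenerateResponse directly.
    response = await session.agenerate("What is Python?")
    print(response.content)

    # Streaming: await first (returns the async generator), then iterate.
    async for chunk in await session.agenerate("Tell me a story", stream=True):
        print(chunk.content, end="")

asyncio.run(main())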
@@ -937,5 +1092,36 @@
             focus_participant=focus_participant,
             depth=depth_enum
         )
-
-        return results
+
+        return results
+
+    def get_interaction_history(self) -> List[Dict[str, Any]]:
+        """
+        Get all interaction traces for this session.
+
+        Returns a list of all LLM interaction traces captured during the session.
+        Each trace contains complete information about the prompt, parameters,
+        and response for observability and debugging.
+
+        Returns:
+            List of trace dictionaries containing:
+            - trace_id: Unique identifier for the interaction
+            - timestamp: ISO format timestamp
+            - provider: Provider name
+            - model: Model name
+            - prompt: User prompt
+            - system_prompt: System prompt (if any)
+            - messages: Conversation history
+            - parameters: Generation parameters (temperature, tokens, etc.)
+            - response: Full response with content, usage, timing
+            - metadata: Custom metadata (session_id, step_type, etc.)
+
+        Example:
+            >>> session = BasicSession(provider=llm, enable_tracing=True)
+            >>> response = session.generate("What is Python?")
+            >>> traces = session.get_interaction_history()
+            >>> print(f"Captured {len(traces)} interactions")
+            >>> print(f"First trace: {traces[0]['trace_id']}")
+            >>> print(f"Tokens used: {traces[0]['response']['usage']}")
+        """
+        return self.interaction_traces.copy()
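Because traces are plain dictionaries, they can be persisted without extra machinery. The wheel also adds abstractcore/utils/trace_export.py, which likely offers richer export, but its API is not shown in this excerpt, so a plain-JSON sketch (continuing the `session` from the tracing sketch above):

import json

traces = session.get_interaction_history()  # list of plain dicts
with open("session_traces.json", "w") as f:
    json.dump(traces, f, indent=2, default=str)  # default=str handles datetimes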
abstractcore/download.py (new file)
@@ -0,0 +1,253 @@
+"""
+Model download API with async progress reporting.
+
+Provides a provider-agnostic interface for downloading models from Ollama,
+HuggingFace Hub, and MLX with streaming progress updates.
+"""
+
+import json
+import asyncio
+from dataclasses import dataclass
+from enum import Enum
+from typing import AsyncIterator, Optional
+
+import httpx
+
+
+class DownloadStatus(Enum):
+    """Download progress status."""
+
+    STARTING = "starting"
+    DOWNLOADING = "downloading"
+    VERIFYING = "verifying"
+    COMPLETE = "complete"
+    ERROR = "error"
+
+
+@dataclass
+class DownloadProgress:
+    """Progress information for model download."""
+
+    status: DownloadStatus
+    message: str
+    percent: Optional[float] = None  # 0-100
+    downloaded_bytes: Optional[int] = None
+    total_bytes: Optional[int] = None
+
+
+async def download_model(
+    provider: str,
+    model: str,
+    token: Optional[str] = None,
+    base_url: Optional[str] = None,
+) -> AsyncIterator[DownloadProgress]:
+    """
+    Download a model with async progress reporting.
+
+    This function provides a unified interface for downloading models across
+    different providers. Progress updates are yielded as DownloadProgress
+    dataclasses that include status, message, and optional progress percentage.
+
+    Args:
+        provider: Provider name ("ollama", "huggingface", "mlx")
+        model: Model identifier:
+            - Ollama: "llama3:8b", "gemma3:1b", etc.
+            - HuggingFace/MLX: "meta-llama/Llama-2-7b", "mlx-community/Qwen3-4B-4bit", etc.
+        token: Optional auth token (for HuggingFace gated models)
+        base_url: Optional custom base URL (for Ollama, default: http://localhost:11434)
+
+    Yields:
+        DownloadProgress: Progress updates with status, message, and optional metrics
+
+    Raises:
+        ValueError: If provider doesn't support downloads (OpenAI, Anthropic, LMStudio)
+        httpx.HTTPStatusError: If Ollama server returns error
+        Exception: Various exceptions from HuggingFace Hub (RepositoryNotFoundError, etc.)
+
+    Examples:
+        Download Ollama model:
+        >>> async for progress in download_model("ollama", "gemma3:1b"):
+        ...     print(f"{progress.status.value}: {progress.message}")
+        ...     if progress.percent:
+        ...         print(f"  Progress: {progress.percent:.1f}%")
+
+        Download HuggingFace model with token:
+        >>> async for progress in download_model(
+        ...     "huggingface",
+        ...     "meta-llama/Llama-2-7b",
+        ...     token="hf_..."
+        ... ):
+        ...     print(f"{progress.message}")
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower == "ollama":
+        async for progress in _download_ollama(model, base_url):
+            yield progress
+    elif provider_lower in ("huggingface", "mlx"):
+        async for progress in _download_huggingface(model, token):
+            yield progress
+    else:
+        raise ValueError(
+            f"Provider '{provider}' does not support model downloads. "
+            f"Supported providers: ollama, huggingface, mlx. "
+            f"Note: OpenAI and Anthropic are cloud-only; LMStudio has no download API."
+        )
+
+
+async def _download_ollama(
+    model: str,
+    base_url: Optional[str] = None,
+) -> AsyncIterator[DownloadProgress]:
+    """
+    Download model from Ollama using /api/pull endpoint.
+
+    Args:
+        model: Ollama model name (e.g., "llama3:8b", "gemma3:1b")
+        base_url: Ollama server URL (default: http://localhost:11434)
+
+    Yields:
+        DownloadProgress with status updates from Ollama streaming response
+    """
+    url = (base_url or "http://localhost:11434").rstrip("/")
+
+    yield DownloadProgress(
+        status=DownloadStatus.STARTING, message=f"Pulling {model} from Ollama..."
+    )
+
+    try:
+        async with httpx.AsyncClient(timeout=None) as client:
+            async with client.stream(
+                "POST",
+                f"{url}/api/pull",
+                json={"name": model, "stream": True},
+            ) as response:
+                response.raise_for_status()
+
+                async for line in response.aiter_lines():
+                    if not line:
+                        continue
+
+                    try:
+                        data = json.loads(line)
+                    except json.JSONDecodeError:
+                        continue
+
+                    status_msg = data.get("status", "")
+
+                    # Parse progress from Ollama response
+                    # Format: {"status": "downloading...", "total": 123, "completed": 45}
+                    if "total" in data and "completed" in data:
+                        total = data["total"]
+                        completed = data["completed"]
+                        percent = (completed / total * 100) if total > 0 else 0
+
+                        yield DownloadProgress(
+                            status=DownloadStatus.DOWNLOADING,
+                            message=status_msg,
+                            percent=percent,
+                            downloaded_bytes=completed,
+                            total_bytes=total,
+                        )
+                    elif status_msg == "success":
+                        yield DownloadProgress(
+                            status=DownloadStatus.COMPLETE,
+                            message=f"Successfully pulled {model}",
+                            percent=100.0,
+                        )
+                    elif "verifying" in status_msg.lower():
+                        yield DownloadProgress(
+                            status=DownloadStatus.VERIFYING,
+                            message=status_msg,
+                        )
+                    else:
+                        # Other status messages (pulling manifest, etc.)
+                        yield DownloadProgress(
+                            status=DownloadStatus.DOWNLOADING,
+                            message=status_msg,
+                        )
+
+    except httpx.HTTPStatusError as e:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Ollama server error: {e.response.status_code} - {e.response.text}",
+        )
+    except httpx.ConnectError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Cannot connect to Ollama server at {url}. Is Ollama running?",
+        )
+    except Exception as e:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Download failed: {str(e)}",
+        )
+
+
+async def _download_huggingface(
+    model: str,
+    token: Optional[str] = None,
+) -> AsyncIterator[DownloadProgress]:
+    """
+    Download model from HuggingFace Hub.
+
+    Args:
+        model: HuggingFace model identifier (e.g., "meta-llama/Llama-2-7b")
+        token: Optional HuggingFace token (required for gated models)
+
+    Yields:
+        DownloadProgress with status updates
+    """
+    yield DownloadProgress(
+        status=DownloadStatus.STARTING,
+        message=f"Downloading {model} from HuggingFace Hub...",
+    )
+
+    try:
+        # Import here to make huggingface_hub optional
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
+    except ImportError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=(
+                "huggingface_hub is not installed. "
+                "Install with: pip install abstractcore[huggingface]"
+            ),
+        )
+        return
+
+    try:
+        # Run blocking download in thread
+        # Note: snapshot_download doesn't have built-in async progress callbacks
+        # We provide start and completion messages
+        await asyncio.to_thread(
+            snapshot_download,
+            repo_id=model,
+            token=token,
+        )
+
+        yield DownloadProgress(
+            status=DownloadStatus.COMPLETE,
+            message=f"Successfully downloaded {model}",
+            percent=100.0,
+        )
+
+    except RepositoryNotFoundError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Model '{model}' not found on HuggingFace Hub",
+        )
+    except GatedRepoError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=(
+                f"Model '{model}' requires authentication. "
+                f"Provide a HuggingFace token via the 'token' parameter."
+            ),
+        )
+    except Exception as e:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Download failed: {str(e)}",
+        )
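A small driver for this API, as a sketch assuming a local Ollama server; only names defined in the new module are used. Note the design choice visible above: transport failures surface as DownloadProgress(status=ERROR) rather than raised exceptions, so the loop checks status instead of wrapping in try/except.

import asyncio
from abstractcore.download import DownloadStatus, download_model

async def pull(provider: str, model: str) -> None:
    async for p in download_model(provider, model):
        if p.status is DownloadStatus.DOWNLOADING and p.percent is not None:
            print(f"\r{p.percent:5.1f}%  {p.message}", end="", flush=True)
        else:
            print(f"\n{p.status.value}: {p.message}")
        if p.status is DownloadStatus.ERROR:
            return

asyncio.run(pull("ollama", "gemma3:1b"))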
abstractcore/embeddings/manager.py
@@ -7,7 +7,6 @@ Production-ready embedding generation with SOTA models and efficient serving.

 import hashlib
 import pickle
-import logging
 import atexit
 import sys
 import builtins
@@ -33,8 +32,9 @@ except ImportError:
     emit_global = None

 from .models import EmbeddingBackend, get_model_config, list_available_models, get_default_model
+from ..utils.structured_logging import get_logger

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)


 @contextmanager
abstractcore/events/__init__.py
@@ -20,15 +20,17 @@ from enum import Enum
 from dataclasses import dataclass, field
 from datetime import datetime
 import uuid
+import asyncio


 class EventType(Enum):
     """Minimal event system - clean, simple, efficient"""

-    # Core events (4) - matches LangChain pattern
+    # Core events (5) - matches LangChain pattern + async progress
     GENERATION_STARTED = "generation_started"  # Unified for streaming and non-streaming
     GENERATION_COMPLETED = "generation_completed"  # Includes all metrics
     TOOL_STARTED = "tool_started"  # Before tool execution
+    TOOL_PROGRESS = "tool_progress"  # Real-time progress during tool execution
     TOOL_COMPLETED = "tool_completed"  # After tool execution

     # Error handling (1)
@@ -60,6 +62,7 @@ class EventEmitter:

     def __init__(self):
         self._listeners: Dict[EventType, List[Callable]] = {}
+        self._async_listeners: Dict[EventType, List[Callable]] = {}

     def on(self, event_type: EventType, handler: Callable):
         """
@@ -141,6 +144,67 @@ class EventEmitter:
             }
         )

+    def on_async(self, event_type: EventType, handler: Callable):
+        """
+        Register an async event handler.
+
+        Args:
+            event_type: Type of event to listen for
+            handler: Async function to call when event occurs
+        """
+        if event_type not in self._async_listeners:
+            self._async_listeners[event_type] = []
+        self._async_listeners[event_type].append(handler)
+
+    async def emit_async(self, event_type: EventType, data: Dict[str, Any], source: Optional[str] = None, **kwargs) -> Event:
+        """
+        Emit an event asynchronously to all registered handlers.
+
+        Runs async handlers concurrently with asyncio.gather().
+        Also triggers sync handlers for backward compatibility.
+
+        Args:
+            event_type: Type of event
+            data: Event data
+            source: Source of the event
+            **kwargs: Additional event attributes (model_name, tokens, etc.)
+
+        Returns:
+            The event object
+        """
+        # Filter kwargs to only include valid Event fields
+        try:
+            valid_fields = set(Event.__dataclass_fields__.keys())
+        except AttributeError:
+            # Fallback for older Python versions
+            valid_fields = {'trace_id', 'span_id', 'request_id', 'duration_ms', 'model_name',
+                            'provider_name', 'tokens_input', 'tokens_output', 'cost_usd', 'metadata'}
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
+
+        event = Event(
+            type=event_type,
+            timestamp=datetime.now(),
+            data=data,
+            source=source or self.__class__.__name__,
+            **filtered_kwargs
+        )
+
+        # Run async handlers concurrently
+        if event_type in self._async_listeners:
+            tasks = [handler(event) for handler in self._async_listeners[event_type]]
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Also run sync handlers (backward compatible)
+        if event_type in self._listeners:
+            for handler in self._listeners[event_type]:
+                try:
+                    handler(event)
+                except Exception as e:
+                    # Log error but don't stop event propagation
+                    print(f"Error in event handler: {e}")
+
+        return event
+

 class GlobalEventBus:
     """
@@ -149,6 +213,7 @@ class GlobalEventBus:
     """
     _instance = None
     _listeners: Dict[EventType, List[Callable]] = {}
+    _async_listeners: Dict[EventType, List[Callable]] = {}

     def __new__(cls):
         if cls._instance is None:
@@ -199,6 +264,52 @@ class GlobalEventBus:
     def clear(cls):
         """Clear all global event handlers"""
         cls._listeners.clear()
+        cls._async_listeners.clear()
+
+    @classmethod
+    def on_async(cls, event_type: EventType, handler: Callable):
+        """Register a global async event handler"""
+        if event_type not in cls._async_listeners:
+            cls._async_listeners[event_type] = []
+        cls._async_listeners[event_type].append(handler)
+
+    @classmethod
+    async def emit_async(cls, event_type: EventType, data: Dict[str, Any], source: Optional[str] = None, **kwargs):
+        """
+        Emit a global event asynchronously.
+
+        Runs async handlers concurrently with asyncio.gather().
+        Also triggers sync handlers for backward compatibility.
+        """
+        # Filter kwargs to only include valid Event fields
+        try:
+            valid_fields = set(Event.__dataclass_fields__.keys())
+        except AttributeError:
+            # Fallback for older Python versions
+            valid_fields = {'trace_id', 'span_id', 'request_id', 'duration_ms', 'model_name',
+                            'provider_name', 'tokens_input', 'tokens_output', 'cost_usd', 'metadata'}
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
+
+        event = Event(
+            type=event_type,
+            timestamp=datetime.now(),
+            data=data,
+            source=source or "GlobalEventBus",
+            **filtered_kwargs
+        )
+
+        # Run async handlers concurrently
+        if event_type in cls._async_listeners:
+            tasks = [handler(event) for handler in cls._async_listeners[event_type]]
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Also run sync handlers (backward compatible)
+        if event_type in cls._listeners:
+            for handler in cls._listeners[event_type]:
+                try:
+                    handler(event)
+                except Exception as e:
+                    print(f"Error in global event handler: {e}")


 # Convenience functions
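The same hooks exist on the process-wide bus as classmethods, so no instance is needed; note that clear() now wipes async listeners too. A brief sketch:

import asyncio
from abstractcore.events import EventType, GlobalEventBus

async def audit(event):
    print(f"global: {event.type.value} from {event.source}")

GlobalEventBus.on_async(EventType.GENERATION_COMPLETED, audit)
asyncio.run(GlobalEventBus.emit_async(
    EventType.GENERATION_COMPLETED,
    {"status": "ok"},  # illustrative payload
))
GlobalEventBus.clear()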
@@ -216,7 +327,7 @@ def emit_global(event_type: EventType, data: Dict[str, Any], source: Optional[str]
 def create_generation_event(model_name: str, provider_name: str,
                             tokens_input: int = None, tokens_output: int = None,
                             duration_ms: float = None, cost_usd: float = None,
-                            **data) -> Dict[str, Any]:
+                            **data) -> tuple[Dict[str, Any], Dict[str, Any]]:
     """Create standardized generation event data"""
     event_data = {
         "model_name": model_name,