abstractcore 2.5.3__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- abstractcore/__init__.py +7 -1
- abstractcore/architectures/detection.py +2 -2
- abstractcore/core/retry.py +2 -2
- abstractcore/core/session.py +132 -1
- abstractcore/download.py +253 -0
- abstractcore/embeddings/manager.py +2 -2
- abstractcore/events/__init__.py +112 -1
- abstractcore/exceptions/__init__.py +49 -2
- abstractcore/media/processors/office_processor.py +2 -2
- abstractcore/media/utils/image_scaler.py +2 -2
- abstractcore/media/vision_fallback.py +2 -2
- abstractcore/providers/anthropic_provider.py +200 -6
- abstractcore/providers/base.py +100 -5
- abstractcore/providers/lmstudio_provider.py +246 -2
- abstractcore/providers/ollama_provider.py +244 -2
- abstractcore/providers/openai_provider.py +258 -6
- abstractcore/providers/streaming.py +2 -2
- abstractcore/tools/common_tools.py +2 -2
- abstractcore/tools/handler.py +2 -2
- abstractcore/tools/parser.py +2 -2
- abstractcore/tools/registry.py +2 -2
- abstractcore/tools/syntax_rewriter.py +2 -2
- abstractcore/tools/tag_rewriter.py +3 -3
- abstractcore/utils/self_fixes.py +2 -2
- abstractcore/utils/version.py +1 -1
- {abstractcore-2.5.3.dist-info → abstractcore-2.6.0.dist-info}/METADATA +102 -4
- {abstractcore-2.5.3.dist-info → abstractcore-2.6.0.dist-info}/RECORD +31 -30
- {abstractcore-2.5.3.dist-info → abstractcore-2.6.0.dist-info}/WHEEL +0 -0
- {abstractcore-2.5.3.dist-info → abstractcore-2.6.0.dist-info}/entry_points.txt +0 -0
- {abstractcore-2.5.3.dist-info → abstractcore-2.6.0.dist-info}/licenses/LICENSE +0 -0
- {abstractcore-2.5.3.dist-info → abstractcore-2.6.0.dist-info}/top_level.txt +0 -0
abstractcore/__init__.py
CHANGED
@@ -49,6 +49,9 @@ _has_processing = True
 # Tools module (core functionality)
 from .tools import tool
 
+# Download module (core functionality)
+from .download import download_model, DownloadProgress, DownloadStatus
+
 # Compression module (optional import)
 try:
     from .compression import GlyphConfig, CompressionOrchestrator
@@ -67,7 +70,10 @@ __all__ = [
     'ModelNotFoundError',
     'ProviderAPIError',
     'AuthenticationError',
-    'tool'
+    'tool',
+    'download_model',
+    'DownloadProgress',
+    'DownloadStatus',
 ]
 
 if _has_embeddings:
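The new download API is re-exported at the package top level. A minimal sanity check of the new surface (assuming abstractcore 2.6.0 is installed):

from abstractcore import download_model, DownloadProgress, DownloadStatus

print(DownloadStatus.DOWNLOADING.value)  # -> "downloading"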
abstractcore/architectures/detection.py
CHANGED
@@ -9,9 +9,9 @@ import json
 import os
 from typing import Dict, Any, Optional, List
 from pathlib import Path
-import logging
+from ..utils.structured_logging import get_logger
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 # Cache for loaded JSON data
 _architecture_formats: Optional[Dict[str, Any]] = None
abstractcore/core/retry.py
CHANGED
@@ -8,13 +8,13 @@ and production LLM system requirements.
 
 import time
 import random
-import logging
 from typing import Type, Optional, Set, Dict, Any
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from enum import Enum
+from ..utils.structured_logging import get_logger
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 class RetryableErrorType(Enum):
abstractcore/core/session.py
CHANGED
@@ -3,11 +3,12 @@ BasicSession for conversation tracking.
 Target: <500 lines maximum.
 """
 
-from typing import List, Optional, Dict, Any, Union, Iterator, Callable
+from typing import List, Optional, Dict, Any, Union, Iterator, AsyncIterator, Callable
 from datetime import datetime
 from pathlib import Path
 import json
 import uuid
+import asyncio
 from collections.abc import Generator
 
 from .interface import AbstractCoreInterface
@@ -273,6 +274,136 @@ class BasicSession:
         if collected_content:
             self.add_message('assistant', collected_content)
 
+    async def agenerate(self,
+                        prompt: str,
+                        name: Optional[str] = None,
+                        location: Optional[str] = None,
+                        **kwargs) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
+        """
+        Async generation with conversation history.
+
+        Args:
+            prompt: User message
+            name: Optional speaker name
+            location: Optional location context
+            **kwargs: Generation parameters (stream, temperature, etc.)
+
+        Returns:
+            GenerateResponse or AsyncIterator for streaming
+
+        Example:
+            # Async chat interaction
+            response = await session.agenerate('What is Python?')
+
+            # Async streaming
+            async for chunk in await session.agenerate('Tell me a story', stream=True):
+                print(chunk.content, end='')
+        """
+        if not self.provider:
+            raise ValueError("No provider configured")
+
+        # Check for auto-compaction before generating
+        if self.auto_compact and self.should_compact(self.auto_compact_threshold):
+            print(f"🗜️ Auto-compacting session (tokens: {self.get_token_estimate()} > {self.auto_compact_threshold})")
+            compacted = self.compact(reason="auto_threshold")
+            # Replace current session with compacted version
+            self._replace_with_compacted(compacted)
+
+        # Pre-processing (fast, sync is fine)
+        self.add_message('user', prompt, name=name, location=location)
+
+        # Format messages for provider (exclude the current user message since provider will add it)
+        messages = self._format_messages_for_provider_excluding_current()
+
+        # Use session tools if not provided in kwargs
+        if 'tools' not in kwargs and self.tools:
+            kwargs['tools'] = self.tools
+
+        # Pass session tool_call_tags if available and not overridden in kwargs
+        if hasattr(self, 'tool_call_tags') and self.tool_call_tags is not None and 'tool_call_tags' not in kwargs:
+            kwargs['tool_call_tags'] = self.tool_call_tags
+
+        # Extract media parameter explicitly
+        media = kwargs.pop('media', None)
+
+        # Add session-level parameters if not overridden in kwargs
+        if 'temperature' not in kwargs and self.temperature is not None:
+            kwargs['temperature'] = self.temperature
+        if 'seed' not in kwargs and self.seed is not None:
+            kwargs['seed'] = self.seed
+
+        # Add trace metadata if tracing is enabled
+        if self.enable_tracing:
+            if 'trace_metadata' not in kwargs:
+                kwargs['trace_metadata'] = {}
+            kwargs['trace_metadata'].update({
+                'session_id': self.id,
+                'step_type': kwargs.get('step_type', 'chat'),
+                'attempt_number': kwargs.get('attempt_number', 1)
+            })
+
+        # Check if streaming
+        stream = kwargs.get('stream', False)
+
+        if stream:
+            # Return async streaming wrapper that adds assistant message after
+            return self._async_session_stream(prompt, messages, media, **kwargs)
+        else:
+            # Async generation
+            response = await self.provider.agenerate(
+                prompt=prompt,
+                messages=messages,
+                system_prompt=self.system_prompt,
+                media=media,
+                **kwargs
+            )
+
+            # Post-processing (fast, sync is fine)
+            if hasattr(response, 'content') and response.content:
+                self.add_message('assistant', response.content)
+
+            # Capture trace if enabled and available
+            if self.enable_tracing and hasattr(self.provider, 'get_traces'):
+                if hasattr(response, 'metadata') and response.metadata and 'trace_id' in response.metadata:
+                    trace = self.provider.get_traces(response.metadata['trace_id'])
+                    if trace:
+                        self.interaction_traces.append(trace)
+
+            return response
+
+    async def _async_session_stream(self,
+                                    prompt: str,
+                                    messages: List[Dict[str, str]],
+                                    media: Optional[List],
+                                    **kwargs) -> AsyncIterator[GenerateResponse]:
+        """Async streaming with session history management."""
+        collected_content = ""
+
+        # Remove 'stream' from kwargs since we're explicitly setting it
+        kwargs_copy = {k: v for k, v in kwargs.items() if k != 'stream'}
+
+        # CRITICAL: Await first to get async generator, then iterate
+        stream_gen = await self.provider.agenerate(
+            prompt=prompt,
+            messages=messages,
+            system_prompt=self.system_prompt,
+            media=media,
+            stream=True,
+            **kwargs_copy
+        )
+
+        async for chunk in stream_gen:
+            # Yield the chunk for the caller
+            yield chunk
+
+            # Collect content for history
+            if hasattr(chunk, 'content') and chunk.content:
+                collected_content += chunk.content
+
+        # After streaming completes, add assistant message
+        if collected_content:
+            self.add_message('assistant', collected_content)
+
     def _format_messages_for_provider(self) -> List[Dict[str, str]]:
         """Format messages for provider API"""
         return [
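The new agenerate call returns an awaitable that, with stream=True, resolves to an async generator, so streaming callers await first and then iterate, exactly as the docstring shows. A minimal driver sketch (session construction is not part of this diff, so `session` is assumed to be an already-configured BasicSession):

import asyncio

async def main(session):
    # Non-streaming: one awaited call; the reply is added to session history.
    response = await session.agenerate("What is Python?")
    print(response.content)

    # Streaming: await yields the async generator, then iterate it.
    stream = await session.agenerate("Tell me a story", stream=True)
    async for chunk in stream:
        print(chunk.content, end="")

# asyncio.run(main(session))  # with a configured BasicSession instance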
abstractcore/download.py
ADDED
@@ -0,0 +1,253 @@
+"""
+Model download API with async progress reporting.
+
+Provides a provider-agnostic interface for downloading models from Ollama,
+HuggingFace Hub, and MLX with streaming progress updates.
+"""
+
+import json
+import asyncio
+from dataclasses import dataclass
+from enum import Enum
+from typing import AsyncIterator, Optional
+
+import httpx
+
+
+class DownloadStatus(Enum):
+    """Download progress status."""
+
+    STARTING = "starting"
+    DOWNLOADING = "downloading"
+    VERIFYING = "verifying"
+    COMPLETE = "complete"
+    ERROR = "error"
+
+
+@dataclass
+class DownloadProgress:
+    """Progress information for model download."""
+
+    status: DownloadStatus
+    message: str
+    percent: Optional[float] = None  # 0-100
+    downloaded_bytes: Optional[int] = None
+    total_bytes: Optional[int] = None
+
+
+async def download_model(
+    provider: str,
+    model: str,
+    token: Optional[str] = None,
+    base_url: Optional[str] = None,
+) -> AsyncIterator[DownloadProgress]:
+    """
+    Download a model with async progress reporting.
+
+    This function provides a unified interface for downloading models across
+    different providers. Progress updates are yielded as DownloadProgress
+    dataclasses that include status, message, and optional progress percentage.
+
+    Args:
+        provider: Provider name ("ollama", "huggingface", "mlx")
+        model: Model identifier:
+            - Ollama: "llama3:8b", "gemma3:1b", etc.
+            - HuggingFace/MLX: "meta-llama/Llama-2-7b", "mlx-community/Qwen3-4B-4bit", etc.
+        token: Optional auth token (for HuggingFace gated models)
+        base_url: Optional custom base URL (for Ollama, default: http://localhost:11434)
+
+    Yields:
+        DownloadProgress: Progress updates with status, message, and optional metrics
+
+    Raises:
+        ValueError: If provider doesn't support downloads (OpenAI, Anthropic, LMStudio)
+        httpx.HTTPStatusError: If Ollama server returns error
+        Exception: Various exceptions from HuggingFace Hub (RepositoryNotFoundError, etc.)
+
+    Examples:
+        Download Ollama model:
+        >>> async for progress in download_model("ollama", "gemma3:1b"):
+        ...     print(f"{progress.status.value}: {progress.message}")
+        ...     if progress.percent:
+        ...         print(f"  Progress: {progress.percent:.1f}%")
+
+        Download HuggingFace model with token:
+        >>> async for progress in download_model(
+        ...     "huggingface",
+        ...     "meta-llama/Llama-2-7b",
+        ...     token="hf_..."
+        ... ):
+        ...     print(f"{progress.message}")
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower == "ollama":
+        async for progress in _download_ollama(model, base_url):
+            yield progress
+    elif provider_lower in ("huggingface", "mlx"):
+        async for progress in _download_huggingface(model, token):
+            yield progress
+    else:
+        raise ValueError(
+            f"Provider '{provider}' does not support model downloads. "
+            f"Supported providers: ollama, huggingface, mlx. "
+            f"Note: OpenAI and Anthropic are cloud-only; LMStudio has no download API."
+        )
+
+
+async def _download_ollama(
+    model: str,
+    base_url: Optional[str] = None,
+) -> AsyncIterator[DownloadProgress]:
+    """
+    Download model from Ollama using /api/pull endpoint.
+
+    Args:
+        model: Ollama model name (e.g., "llama3:8b", "gemma3:1b")
+        base_url: Ollama server URL (default: http://localhost:11434)
+
+    Yields:
+        DownloadProgress with status updates from Ollama streaming response
+    """
+    url = (base_url or "http://localhost:11434").rstrip("/")
+
+    yield DownloadProgress(
+        status=DownloadStatus.STARTING, message=f"Pulling {model} from Ollama..."
+    )
+
+    try:
+        async with httpx.AsyncClient(timeout=None) as client:
+            async with client.stream(
+                "POST",
+                f"{url}/api/pull",
+                json={"name": model, "stream": True},
+            ) as response:
+                response.raise_for_status()
+
+                async for line in response.aiter_lines():
+                    if not line:
+                        continue
+
+                    try:
+                        data = json.loads(line)
+                    except json.JSONDecodeError:
+                        continue
+
+                    status_msg = data.get("status", "")
+
+                    # Parse progress from Ollama response
+                    # Format: {"status": "downloading...", "total": 123, "completed": 45}
+                    if "total" in data and "completed" in data:
+                        total = data["total"]
+                        completed = data["completed"]
+                        percent = (completed / total * 100) if total > 0 else 0
+
+                        yield DownloadProgress(
+                            status=DownloadStatus.DOWNLOADING,
+                            message=status_msg,
+                            percent=percent,
+                            downloaded_bytes=completed,
+                            total_bytes=total,
+                        )
+                    elif status_msg == "success":
+                        yield DownloadProgress(
+                            status=DownloadStatus.COMPLETE,
+                            message=f"Successfully pulled {model}",
+                            percent=100.0,
+                        )
+                    elif "verifying" in status_msg.lower():
+                        yield DownloadProgress(
+                            status=DownloadStatus.VERIFYING,
+                            message=status_msg,
+                        )
+                    else:
+                        # Other status messages (pulling manifest, etc.)
+                        yield DownloadProgress(
+                            status=DownloadStatus.DOWNLOADING,
+                            message=status_msg,
+                        )
+
+    except httpx.HTTPStatusError as e:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Ollama server error: {e.response.status_code} - {e.response.text}",
+        )
+    except httpx.ConnectError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Cannot connect to Ollama server at {url}. Is Ollama running?",
+        )
+    except Exception as e:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Download failed: {str(e)}",
+        )
+
+
+async def _download_huggingface(
+    model: str,
+    token: Optional[str] = None,
+) -> AsyncIterator[DownloadProgress]:
+    """
+    Download model from HuggingFace Hub.
+
+    Args:
+        model: HuggingFace model identifier (e.g., "meta-llama/Llama-2-7b")
+        token: Optional HuggingFace token (required for gated models)
+
+    Yields:
+        DownloadProgress with status updates
+    """
+    yield DownloadProgress(
+        status=DownloadStatus.STARTING,
+        message=f"Downloading {model} from HuggingFace Hub...",
+    )
+
+    try:
+        # Import here to make huggingface_hub optional
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
+    except ImportError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=(
+                "huggingface_hub is not installed. "
+                "Install with: pip install abstractcore[huggingface]"
+            ),
+        )
+        return
+
+    try:
+        # Run blocking download in thread
+        # Note: snapshot_download doesn't have built-in async progress callbacks
+        # We provide start and completion messages
+        await asyncio.to_thread(
+            snapshot_download,
+            repo_id=model,
+            token=token,
+        )
+
+        yield DownloadProgress(
+            status=DownloadStatus.COMPLETE,
+            message=f"Successfully downloaded {model}",
+            percent=100.0,
+        )
+
+    except RepositoryNotFoundError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Model '{model}' not found on HuggingFace Hub",
+        )
+    except GatedRepoError:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=(
+                f"Model '{model}' requires authentication. "
+                f"Provide a HuggingFace token via the 'token' parameter."
+            ),
+        )
+    except Exception as e:
+        yield DownloadProgress(
+            status=DownloadStatus.ERROR,
+            message=f"Download failed: {str(e)}",
+        )
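Note that failures inside the download helpers are yielded as ERROR progress events rather than raised, while an unsupported provider raises ValueError up front. A minimal consumer sketch (assumes abstractcore 2.6.0 is installed and a local Ollama server is running):

import asyncio

from abstractcore import download_model, DownloadStatus

async def pull(provider: str, model: str) -> None:
    async for progress in download_model(provider, model):
        if progress.percent is not None:
            print(f"{progress.message} ({progress.percent:.1f}%)")
        else:
            print(progress.message)
        if progress.status is DownloadStatus.ERROR:
            break  # errors arrive as progress events, not exceptions

asyncio.run(pull("ollama", "gemma3:1b"))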
abstractcore/embeddings/manager.py
CHANGED
@@ -7,7 +7,6 @@ Production-ready embedding generation with SOTA models and efficient serving.
 
 import hashlib
 import pickle
-import logging
 import atexit
 import sys
 import builtins
@@ -33,8 +32,9 @@ except ImportError:
     emit_global = None
 
 from .models import EmbeddingBackend, get_model_config, list_available_models, get_default_model
+from ..utils.structured_logging import get_logger
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 @contextmanager
abstractcore/events/__init__.py
CHANGED
@@ -20,15 +20,17 @@ from enum import Enum
 from dataclasses import dataclass, field
 from datetime import datetime
 import uuid
+import asyncio
 
 
 class EventType(Enum):
     """Minimal event system - clean, simple, efficient"""
 
-    # Core events (
+    # Core events (5) - matches LangChain pattern + async progress
     GENERATION_STARTED = "generation_started"  # Unified for streaming and non-streaming
     GENERATION_COMPLETED = "generation_completed"  # Includes all metrics
     TOOL_STARTED = "tool_started"  # Before tool execution
+    TOOL_PROGRESS = "tool_progress"  # Real-time progress during tool execution
     TOOL_COMPLETED = "tool_completed"  # After tool execution
 
     # Error handling (1)
@@ -60,6 +62,7 @@ class EventEmitter:
 
     def __init__(self):
         self._listeners: Dict[EventType, List[Callable]] = {}
+        self._async_listeners: Dict[EventType, List[Callable]] = {}
 
     def on(self, event_type: EventType, handler: Callable):
         """
@@ -141,6 +144,67 @@ class EventEmitter:
             }
         )
 
+    def on_async(self, event_type: EventType, handler: Callable):
+        """
+        Register an async event handler.
+
+        Args:
+            event_type: Type of event to listen for
+            handler: Async function to call when event occurs
+        """
+        if event_type not in self._async_listeners:
+            self._async_listeners[event_type] = []
+        self._async_listeners[event_type].append(handler)
+
+    async def emit_async(self, event_type: EventType, data: Dict[str, Any], source: Optional[str] = None, **kwargs) -> Event:
+        """
+        Emit an event asynchronously to all registered handlers.
+
+        Runs async handlers concurrently with asyncio.gather().
+        Also triggers sync handlers for backward compatibility.
+
+        Args:
+            event_type: Type of event
+            data: Event data
+            source: Source of the event
+            **kwargs: Additional event attributes (model_name, tokens, etc.)
+
+        Returns:
+            The event object
+        """
+        # Filter kwargs to only include valid Event fields
+        try:
+            valid_fields = set(Event.__dataclass_fields__.keys())
+        except AttributeError:
+            # Fallback for older Python versions
+            valid_fields = {'trace_id', 'span_id', 'request_id', 'duration_ms', 'model_name',
+                            'provider_name', 'tokens_input', 'tokens_output', 'cost_usd', 'metadata'}
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
+
+        event = Event(
+            type=event_type,
+            timestamp=datetime.now(),
+            data=data,
+            source=source or self.__class__.__name__,
+            **filtered_kwargs
+        )
+
+        # Run async handlers concurrently
+        if event_type in self._async_listeners:
+            tasks = [handler(event) for handler in self._async_listeners[event_type]]
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Also run sync handlers (backward compatible)
+        if event_type in self._listeners:
+            for handler in self._listeners[event_type]:
+                try:
+                    handler(event)
+                except Exception as e:
+                    # Log error but don't stop event propagation
+                    print(f"Error in event handler: {e}")
+
+        return event
+
 
 class GlobalEventBus:
     """
@@ -149,6 +213,7 @@ class GlobalEventBus:
     """
     _instance = None
    _listeners: Dict[EventType, List[Callable]] = {}
+    _async_listeners: Dict[EventType, List[Callable]] = {}
 
     def __new__(cls):
        if cls._instance is None:
@@ -199,6 +264,52 @@ class GlobalEventBus:
     def clear(cls):
         """Clear all global event handlers"""
         cls._listeners.clear()
+        cls._async_listeners.clear()
+
+    @classmethod
+    def on_async(cls, event_type: EventType, handler: Callable):
+        """Register a global async event handler"""
+        if event_type not in cls._async_listeners:
+            cls._async_listeners[event_type] = []
+        cls._async_listeners[event_type].append(handler)
+
+    @classmethod
+    async def emit_async(cls, event_type: EventType, data: Dict[str, Any], source: Optional[str] = None, **kwargs):
+        """
+        Emit a global event asynchronously.
+
+        Runs async handlers concurrently with asyncio.gather().
+        Also triggers sync handlers for backward compatibility.
+        """
+        # Filter kwargs to only include valid Event fields
+        try:
+            valid_fields = set(Event.__dataclass_fields__.keys())
+        except AttributeError:
+            # Fallback for older Python versions
+            valid_fields = {'trace_id', 'span_id', 'request_id', 'duration_ms', 'model_name',
+                            'provider_name', 'tokens_input', 'tokens_output', 'cost_usd', 'metadata'}
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
+
+        event = Event(
+            type=event_type,
+            timestamp=datetime.now(),
+            data=data,
+            source=source or "GlobalEventBus",
+            **filtered_kwargs
+        )
+
+        # Run async handlers concurrently
+        if event_type in cls._async_listeners:
+            tasks = [handler(event) for handler in cls._async_listeners[event_type]]
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Also run sync handlers (backward compatible)
+        if event_type in cls._listeners:
+            for handler in cls._listeners[event_type]:
+                try:
+                    handler(event)
+                except Exception as e:
+                    print(f"Error in global event handler: {e}")
 
 
 # Convenience functions
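emit_async runs async handlers concurrently via asyncio.gather() and still invokes sync handlers, so existing listeners keep working. A minimal sketch of the new global async path (handler name and event payload are illustrative):

import asyncio

from abstractcore.events import EventType, GlobalEventBus

async def log_tool_progress(event):
    # Async handlers receive the Event object; several would run concurrently.
    print(f"[{event.type.value}] {event.data}")

GlobalEventBus.on_async(EventType.TOOL_PROGRESS, log_tool_progress)

asyncio.run(GlobalEventBus.emit_async(
    EventType.TOOL_PROGRESS,
    {"tool": "web_search", "step": "fetching results"},
    source="example",
))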
abstractcore/exceptions/__init__.py
CHANGED
@@ -106,10 +106,55 @@ def format_model_error(provider: str, invalid_model: str, available_models: list
     return message.rstrip()
 
 
+def format_auth_error(provider: str, reason: str = None) -> str:
+    """
+    Format actionable authentication error with setup instructions.
+
+    Args:
+        provider: Provider name (e.g., "openai", "anthropic")
+        reason: Optional reason for auth failure
+
+    Returns:
+        Formatted error message with fix instructions
+    """
+    urls = {
+        "openai": "https://platform.openai.com/api-keys",
+        "anthropic": "https://console.anthropic.com/settings/keys",
+    }
+    msg = f"{provider.upper()} authentication failed"
+    if reason:
+        msg += f": {reason}"
+    msg += f"\nFix: abstractcore --set-api-key {provider} YOUR_KEY"
+    if provider.lower() in urls:
+        msg += f"\nGet key: {urls[provider.lower()]}"
+    return msg
+
+
+def format_provider_error(provider: str, reason: str) -> str:
+    """
+    Format actionable provider unavailability error with setup instructions.
+
+    Args:
+        provider: Provider name (e.g., "ollama", "lmstudio")
+        reason: Reason for unavailability (e.g., "Connection refused")
+
+    Returns:
+        Formatted error message with setup instructions
+    """
+    instructions = {
+        "ollama": "Install: https://ollama.com/download\nStart: ollama serve",
+        "lmstudio": "Install: https://lmstudio.ai/\nEnable API in settings",
+    }
+    msg = f"Provider '{provider}' unavailable: {reason}"
+    if provider.lower() in instructions:
+        msg += f"\n{instructions[provider.lower()]}"
+    return msg
+
+
 # Export all exceptions for easy importing
 __all__ = [
     'AbstractCoreError',
-    'ProviderError',
+    'ProviderError',
     'ProviderAPIError',
     'AuthenticationError',
     'Authentication',  # Backward compatibility alias
@@ -121,5 +166,7 @@ __all__ = [
     'SessionError',
     'ConfigurationError',
     'ModelNotFoundError',
-    'format_model_error'
+    'format_model_error',
+    'format_auth_error',
+    'format_provider_error'
 ]
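Both helpers return plain multi-line strings; tracing the code above with illustrative reason strings gives:

from abstractcore.exceptions import format_auth_error, format_provider_error

print(format_auth_error("openai", "invalid API key"))
# OPENAI authentication failed: invalid API key
# Fix: abstractcore --set-api-key openai YOUR_KEY
# Get key: https://platform.openai.com/api-keys

print(format_provider_error("ollama", "Connection refused"))
# Provider 'ollama' unavailable: Connection refused
# Install: https://ollama.com/download
# Start: ollama serve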