dtSpark 1.1.0a2.tar.gz → 1.1.0a6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/PKG-INFO +9 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/README.md +7 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/pyproject.toml +1 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/setup.py +1 -1
- dtspark-1.1.0a6/src/dtSpark/_version.txt +1 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/authentication.py +1 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/bedrock.py +238 -239
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/costs.py +9 -5
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/pricing.py +25 -21
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/cli_interface.py +69 -62
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/conversation_manager.py +54 -47
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/core/application.py +151 -111
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/core/context_compaction.py +241 -226
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/__init__.py +36 -22
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/action_monitor.py +46 -17
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/daemon_app.py +126 -104
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/daemon_manager.py +59 -23
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/pid_file.py +3 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/autonomous_actions.py +3 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/credential_prompt.py +52 -54
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/files/manager.py +6 -12
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/limits/__init__.py +1 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/limits/tokens.py +2 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/llm/anthropic_direct.py +246 -141
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/llm/ollama.py +3 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/mcp_integration/manager.py +4 -4
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/mcp_integration/tool_selector.py +83 -77
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/resources/config.yaml.template +10 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/safety/patterns.py +45 -46
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/safety/prompt_inspector.py +8 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/scheduler/creation_tools.py +273 -181
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/scheduler/executor.py +503 -221
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/tools/builtin.py +70 -53
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/endpoints/autonomous_actions.py +12 -9
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/endpoints/chat.py +18 -6
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/endpoints/conversations.py +57 -17
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/endpoints/main_menu.py +132 -105
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/endpoints/streaming.py +2 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/server.py +65 -5
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/ssl_utils.py +3 -3
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/static/css/dark-theme.css +8 -29
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/static/js/actions.js +2 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/static/js/chat.js +6 -8
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/static/js/main.js +8 -8
- dtspark-1.1.0a6/src/dtSpark/web/static/js/sse-client.js +250 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/actions.html +5 -5
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/base.html +13 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/chat.html +52 -50
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/conversations.html +50 -22
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/goodbye.html +2 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/main_menu.html +17 -17
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/new_conversation.html +51 -20
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/web_interface.py +2 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/PKG-INFO +9 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/requires.txt +1 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/debug_bulk_api.py +3 -3
- dtspark-1.1.0a6/tests/diagnose_aws_costs.py +242 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_builtin_tools_integration.py +3 -4
- dtspark-1.1.0a6/tests/test_bulk_pricing.py +97 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_ollama_conversation.py +3 -3
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_ollama_integration.py +2 -2
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_pricing_integration.py +1 -1
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_prompt_inspection.py +104 -81
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_web_session.py +0 -1
- dtspark-1.1.0a2/src/dtSpark/_version.txt +0 -1
- dtspark-1.1.0a2/src/dtSpark/web/static/js/sse-client.js +0 -242
- dtspark-1.1.0a2/tests/diagnose_aws_costs.py +0 -226
- dtspark-1.1.0a2/tests/test_bulk_pricing.py +0 -88
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/LICENSE +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/MANIFEST.in +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/setup.cfg +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/_description.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/_full_name.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/_licence.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/_metadata.yaml +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/_name.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/core/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/__main__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/daemon/execution_coordinator.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/backends.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/connection.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/conversations.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/files.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/mcp_ops.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/messages.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/schema.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/tool_permissions.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/database/usage.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/files/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/launch.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/limits/costs.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/llm/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/llm/base.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/llm/context_limits.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/llm/manager.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/mcp_integration/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/safety/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/safety/llm_service.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/safety/violation_logger.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/scheduler/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/scheduler/execution_queue.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/scheduler/manager.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/tools/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/auth.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/dependencies.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/endpoints/__init__.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/session.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/web/templates/login.html +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/SOURCES.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/dependency_links.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/entry_points.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/not-zip-safe +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark.egg-info/top_level.txt +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/README.md +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_builtin_tools.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_document_archive_tools.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_filesystem_tools.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_mcp_server.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_ollama_context.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_status_indicator.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_tool_selector.py +0 -0
- {dtspark-1.1.0a2 → dtspark-1.1.0a6}/tests/test_web_auth.py +0 -0

{dtspark-1.1.0a2 → dtspark-1.1.0a6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dtSpark
-Version: 1.1.0a2
+Version: 1.1.0a6
 Summary: Secure Personal AI Research Kit - Multi-provider LLM CLI/Web interface with MCP tool integration
 Home-page: https://github.com/digital-thought/dtSpark
 Author: Matthew Westwood-Hill

@@ -42,7 +42,7 @@ Requires-Dist: httpx>=0.24.0
 Requires-Dist: aiohttp>=3.8.0
 Requires-Dist: mcp>=0.9.0
 Requires-Dist: pyyaml>=6.0
-Requires-Dist: dtPyAppFramework>=4.
+Requires-Dist: dtPyAppFramework>=4.1.2
 Requires-Dist: tiktoken>=0.5.0
 Requires-Dist: ollama>=0.2.0
 Requires-Dist: cryptography>=41.0.0

@@ -83,6 +83,13 @@ Dynamic: requires-python
 [](https://opensource.org/licenses/MIT)
 [](https://www.python.org/downloads/)
 
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+
 **Spark** is a powerful, multi-provider LLM interface for conversational AI with integrated tool support. It supports AWS Bedrock, Anthropic Direct API, and Ollama local models through both CLI and Web interfaces.
 
 ## Key Features

{dtspark-1.1.0a2 → dtspark-1.1.0a6}/README.md

@@ -3,6 +3,13 @@
 [](https://opensource.org/licenses/MIT)
 [](https://www.python.org/downloads/)
 
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+[](https://sonarcloud.io/summary/new_code?id=Digital-Thought_dtSpark)
+
 **Spark** is a powerful, multi-provider LLM interface for conversational AI with integrated tool support. It supports AWS Bedrock, Anthropic Direct API, and Ollama local models through both CLI and Web interfaces.
 
 ## Key Features

dtspark-1.1.0a6/src/dtSpark/_version.txt (new file)

@@ -0,0 +1 @@
+1.1.0a6

{dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/authentication.py

@@ -161,7 +161,7 @@ class AWSAuthenticator:
             # Check if this is an SSO token expiration error
             if 'Token has expired' in error_str or 'refresh failed' in error_str:
                 logging.warning("AWS SSO token has expired")
-                logging.info(
+                logging.info("Attempting automatic re-authentication...")
 
                 # Try to trigger SSO login automatically
                 if self.trigger_sso_login():
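
The authentication.py hunk above completes the log message on the automatic SSO re-authentication path. For context, a minimal sketch of that detect-and-retry pattern, assuming a `trigger_sso_login()` callable like the one referenced in the hunk; the standalone `call_with_sso_reauth` helper and its signature are illustrative only and do not exist in dtSpark, where the equivalent logic lives inside `AWSAuthenticator`:

```python
import logging

from botocore.exceptions import ClientError


def call_with_sso_reauth(client_call, trigger_sso_login, *args, **kwargs):
    """Run a boto3 call once, re-authenticating and retrying if the SSO token expired.

    Hypothetical helper for illustration; only the error-string check and the
    trigger_sso_login() call come from the diff above.
    """
    try:
        return client_call(*args, **kwargs)
    except ClientError as exc:
        error_str = str(exc)
        # Same detection used in the hunk above.
        if 'Token has expired' in error_str or 'refresh failed' in error_str:
            logging.warning("AWS SSO token has expired")
            logging.info("Attempting automatic re-authentication...")
            if trigger_sso_login():
                # One retry after a successful interactive SSO login.
                return client_call(*args, **kwargs)
        raise
```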

{dtspark-1.1.0a2 → dtspark-1.1.0a6}/src/dtSpark/aws/bedrock.py

@@ -57,77 +57,28 @@ class BedrockService(LLMService):
         """
         models = []
 
-        # Get inference profiles (recommended approach)
         try:
             response = self.bedrock_client.list_inference_profiles()
 
             for profile in response.get('inferenceProfileSummaries', []):
-                # Only include ACTIVE profiles
                 if profile.get('status') != 'ACTIVE':
                     continue
 
-                # Extract model info from the profile
                 profile_models = profile.get('models', [])
                 model_id = profile_models[0].get('modelArn', '').split('/')[-1] if profile_models else 'unknown'
-                profile_name_lower = profile['inferenceProfileName'].lower()
-                model_id_lower = model_id.lower()
 
-                if 'embed' in profile_name_lower or 'embed' in model_id_lower:
-                    logging.debug(f"Skipping embedding model: {profile['inferenceProfileName']}")
+                if self._should_skip_profile(profile, model_id):
                     continue
 
-                if 'stable-diffusion' in profile_name_lower or 'stable-diffusion' in model_id_lower:
-                    logging.debug(f"Skipping image generation model: {profile['inferenceProfileName']}")
+                if not self._verify_model_access(profile, profile_models):
                     continue
 
-                # Check if the model has been granted access
-                try:
-                    # Try to get the foundation model details to verify access
-                    if profile_models and len(profile_models) > 0:
-                        foundation_model_arn = profile_models[0].get('modelArn', '')
-                        if foundation_model_arn:
-                            # Extract the model ID from the ARN
-                            foundation_model_id = foundation_model_arn.split('/')[-1]
-                            try:
-                                # Attempt to get foundation model details
-                                self.bedrock_client.get_foundation_model(modelIdentifier=foundation_model_id)
-                            except ClientError as model_error:
-                                # If we get access denied or validation error, skip this model
-                                error_code = model_error.response.get('Error', {}).get('Code', '')
-                                if error_code in ['AccessDeniedException', 'ValidationException', 'ResourceNotFoundException']:
-                                    logging.debug(f"Skipping model without access: {profile['inferenceProfileName']} ({error_code})")
-                                    continue
-                                # For other errors, log but continue (might be accessible)
-                                logging.debug(f"Could not verify access for {profile['inferenceProfileName']}: {error_code}")
-                except Exception as verify_error:
-                    logging.debug(f"Error verifying model access for {profile['inferenceProfileName']}: {verify_error}")
-                    # If we can't verify, skip it to be safe
-                    continue
-
-                # Determine model maker from model ID or profile name
-                model_maker = 'Unknown'
-
-                if 'anthropic' in model_id_lower or 'anthropic' in profile_name_lower or 'claude' in profile_name_lower:
-                    model_maker = 'Anthropic'
-                elif 'amazon' in model_id_lower or 'amazon' in profile_name_lower or 'titan' in profile_name_lower:
-                    model_maker = 'Amazon'
-                elif 'meta' in model_id_lower or 'meta' in profile_name_lower or 'llama' in profile_name_lower:
-                    model_maker = 'Meta'
-                elif 'ai21' in model_id_lower or 'ai21' in profile_name_lower or 'jamba' in profile_name_lower:
-                    model_maker = 'AI21'
-                elif 'cohere' in model_id_lower or 'cohere' in profile_name_lower:
-                    model_maker = 'Cohere'
-                elif 'mistral' in model_id_lower or 'mistral' in profile_name_lower:
-                    model_maker = 'Mistral'
+                model_maker = self._detect_model_maker(model_id, profile['inferenceProfileName'])
 
                 models.append({
                     'id': profile['inferenceProfileArn'],
                     'name': profile['inferenceProfileName'],
-                    'model_maker': model_maker,
-                    # 'provider' will be added by LLM manager to indicate service (AWS Bedrock)
+                    'model_maker': model_maker,
                     'access_info': self.get_access_info(),
                     'input_modalities': ['TEXT'],
                     'output_modalities': ['TEXT'],

@@ -142,12 +93,67 @@ class BedrockService(LLMService):
         except Exception as e:
             logging.error(f"Unexpected error listing inference profiles: {e}")
 
-        # Sort models by model maker and name for better display
         models.sort(key=lambda x: (x.get('model_maker', 'Unknown'), x['name']))
-
         logging.info(f"Total available models: {len(models)}")
         return models
 
+    @staticmethod
+    def _should_skip_profile(profile: Dict[str, Any], model_id: str) -> bool:
+        """Check whether an inference profile should be excluded from the model list."""
+        profile_name_lower = profile['inferenceProfileName'].lower()
+        model_id_lower = model_id.lower()
+
+        if 'embed' in profile_name_lower or 'embed' in model_id_lower:
+            logging.debug(f"Skipping embedding model: {profile['inferenceProfileName']}")
+            return True
+
+        if 'stable-diffusion' in profile_name_lower or 'stable-diffusion' in model_id_lower:
+            logging.debug(f"Skipping image generation model: {profile['inferenceProfileName']}")
+            return True
+
+        return False
+
+    def _verify_model_access(self, profile: Dict[str, Any], profile_models: List[Dict[str, Any]]) -> bool:
+        """Verify that the underlying foundation model is accessible. Returns True if accessible."""
+        _NO_ACCESS_CODES = {'AccessDeniedException', 'ValidationException', 'ResourceNotFoundException'}
+        try:
+            if not profile_models:
+                return True
+            foundation_model_arn = profile_models[0].get('modelArn', '')
+            if not foundation_model_arn:
+                return True
+            foundation_model_id = foundation_model_arn.split('/')[-1]
+            try:
+                self.bedrock_client.get_foundation_model(modelIdentifier=foundation_model_id)
+            except ClientError as model_error:
+                error_code = model_error.response.get('Error', {}).get('Code', '')
+                if error_code in _NO_ACCESS_CODES:
+                    logging.debug(f"Skipping model without access: {profile['inferenceProfileName']} ({error_code})")
+                    return False
+                logging.debug(f"Could not verify access for {profile['inferenceProfileName']}: {error_code}")
+        except Exception as verify_error:
+            logging.debug(f"Error verifying model access for {profile['inferenceProfileName']}: {verify_error}")
+            return False
+        return True
+
+    @staticmethod
+    def _detect_model_maker(model_id: str, profile_name: str) -> str:
+        """Determine the model maker from a model ID and profile name."""
+        id_lower = model_id.lower()
+        name_lower = profile_name.lower()
+        maker_keywords = [
+            ('Anthropic', ['anthropic', 'claude']),
+            ('Amazon', ['amazon', 'titan']),
+            ('Meta', ['meta', 'llama']),
+            ('AI21', ['ai21', 'jamba']),
+            ('Cohere', ['cohere']),
+            ('Mistral', ['mistral']),
+        ]
+        for maker, keywords in maker_keywords:
+            if any(kw in id_lower or kw in name_lower for kw in keywords):
+                return maker
+        return 'Unknown'
+
     def set_model(self, model_id: str):
         """
         Set the current model for chat operations.

@@ -173,6 +179,17 @@ class BedrockService(LLMService):
 
         logging.info(f"{'Inference profile' if self.is_inference_profile else 'Model'} set to: {model_id}")
 
+    # Transient error codes that should be retried
+    _TRANSIENT_ERRORS = {
+        'ThrottlingException',
+        'TooManyRequestsException',
+        'ModelTimeoutException',
+        'ServiceUnavailableException',
+        'InternalServerError',
+        'ModelNotReadyException',
+        'ModelStreamErrorException',
+    }
+
     def invoke_model(self, messages: List[Dict[str, str]], max_tokens: int = 4096,
                      temperature: float = 0.7, tools: Optional[List[Dict[str, Any]]] = None,
                      system: Optional[str] = None, max_retries: int = 3) -> Optional[Dict[str, Any]]:

@@ -200,17 +217,6 @@ class BedrockService(LLMService):
                 'error_type': 'ConfigurationError'
             }
 
-        # Transient error codes that should be retried
-        transient_errors = [
-            'ThrottlingException',
-            'TooManyRequestsException',
-            'ModelTimeoutException',
-            'ServiceUnavailableException',
-            'InternalServerError',
-            'ModelNotReadyException',
-            'ModelStreamErrorException'
-        ]
-
         import time
         attempt = 0
 

@@ -218,77 +224,14 @@ class BedrockService(LLMService):
             if attempt > 1:
                 logging.info(f"Retry attempt {attempt}/{max_retries} for model invocation")
 
-                # Log the request for debugging
-                logging.debug(f"Request body keys: {list(request_body.keys())}")
-                if 'tools' in request_body:
-                    logging.debug(f"Tools count: {len(request_body['tools'])}")
-                logging.debug(f"max_tokens is set to {max_tokens}")
-                try:
-                    response = self.bedrock_runtime_client.invoke_model(
-                        modelId=self.current_model_id,
-                        contentType='application/json',
-                        accept='application/json',
-                        body=json.dumps(request_body)
-                    )
-                except Exception as api_error:
-                    logging.error(f"Bedrock API error: {api_error}")
-                    logging.error(f"Request body: {json.dumps(request_body, indent=2)}")
-                    raise
-
-                # Parse the response
-                response_body = json.loads(response['body'].read())
-                parsed_response = self._parse_response(response_body)
-
-                logging.debug(f"{'Inference profile' if self.is_inference_profile else 'Model'} invoked successfully: {self.current_model_id}")
-                return parsed_response
-
-            except ClientError as e:
-                error_code = e.response['Error']['Code']
-                error_message = e.response['Error']['Message']
-
-                # Log detailed error information
-                logging.error(f"Bedrock API error - Code: {error_code}, Message: {error_message}")
-
-                # Check if this is a transient error that should be retried
-                if error_code in transient_errors and attempt <= max_retries:
-                    wait_time = min(2 ** (attempt - 1), 30)  # Exponential backoff, max 30 seconds
-                    logging.warning(f"Transient error {error_code}, retrying in {wait_time} seconds... (attempt {attempt}/{max_retries})")
-                    time.sleep(wait_time)
-                    continue  # Retry
-
-                # Non-transient error or max retries reached - return error details
-                return {
-                    'error': True,
-                    'error_code': error_code,
-                    'error_message': error_message,
-                    'error_type': 'ClientError',
-                    'retries_attempted': attempt - 1
-                }
-
-            except Exception as e:
-                logging.error(f"Unexpected error invoking {'inference profile' if self.is_inference_profile else 'model'}: {e}")
-                logging.error(f"Error type: {type(e).__name__}")
-                import traceback
-                logging.error(f"Traceback: {traceback.format_exc()}")
-
-                # Return error details (unexpected errors are not retried)
-                return {
-                    'error': True,
-                    'error_code': type(e).__name__,
-                    'error_message': str(e),
-                    'error_type': 'Exception',
-                    'retries_attempted': 0
-                }
-
-        # Should not reach here, but just in case
+            result = self._attempt_invocation(messages, max_tokens, temperature, tools, system, attempt, max_retries)
+            if result.get('_retry'):
+                wait_time = min(2 ** (attempt - 1), 30)
+                time.sleep(wait_time)
+                attempt += 1
+                continue
+            return result
+
         return {
             'error': True,
             'error_code': 'MaxRetriesExceeded',

@@ -297,6 +240,67 @@ class BedrockService(LLMService):
             'retries_attempted': max_retries
         }
 
+    def _attempt_invocation(self, messages, max_tokens, temperature, tools, system,
+                            attempt, max_retries) -> Dict[str, Any]:
+        """Execute a single model invocation attempt. Returns a _retry sentinel on transient failure."""
+        model_label = 'inference profile' if self.is_inference_profile else 'model'
+        try:
+            request_body = self._format_request(messages, max_tokens, temperature, tools, system)
+            logging.debug(f"Invoking {model_label}: {self.current_model_id}")
+            logging.debug(f"Request body keys: {list(request_body.keys())}")
+            if 'tools' in request_body:
+                logging.debug(f"Tools count: {len(request_body['tools'])}")
+            logging.debug(f"max_tokens is set to {max_tokens}")
+
+            try:
+                response = self.bedrock_runtime_client.invoke_model(
+                    modelId=self.current_model_id,
+                    contentType='application/json',
+                    accept='application/json',
+                    body=json.dumps(request_body)
+                )
+            except Exception as api_error:
+                logging.error(f"Bedrock API error: {api_error}")
+                logging.error(f"Request body: {json.dumps(request_body, indent=2)}")
+                raise
+
+            response_body = json.loads(response['body'].read())
+            parsed_response = self._parse_response(response_body)
+            logging.debug(f"{model_label} invoked successfully: {self.current_model_id}")
+            return parsed_response
+
+        except ClientError as e:
+            error_code = e.response['Error']['Code']
+            error_message = e.response['Error']['Message']
+            logging.error(f"Bedrock API error - Code: {error_code}, Message: {error_message}")
+
+            if error_code in self._TRANSIENT_ERRORS and attempt <= max_retries:
+                wait_time = min(2 ** (attempt - 1), 30)
+                logging.warning(f"Transient error {error_code}, retrying in {wait_time} seconds... "
+                                f"(attempt {attempt}/{max_retries})")
+                return {'_retry': True}
+
+            return {
+                'error': True,
+                'error_code': error_code,
+                'error_message': error_message,
+                'error_type': 'ClientError',
+                'retries_attempted': attempt - 1
+            }
+
+        except Exception as e:
+            logging.error(f"Unexpected error invoking {model_label}: {e}")
+            logging.error(f"Error type: {type(e).__name__}")
+            import traceback
+            logging.error(f"Traceback: {traceback.format_exc()}")
+            return {
+                'error': True,
+                'error_code': type(e).__name__,
+                'error_message': str(e),
+                'error_type': 'Exception',
+                'retries_attempted': 0
+            }
+
     def _format_request(self, messages: List[Dict[str, str]], max_tokens: int,
                         temperature: float, tools: Optional[List[Dict[str, Any]]] = None,
                         system: Optional[str] = None) -> Dict[str, Any]:

@@ -402,96 +406,92 @@ class BedrockService(LLMService):
         Returns:
             Standardised response dictionary
         """
-        # Use model_identifier for provider detection (works for both direct models and profiles)
         model_id = self.model_identifier or self.current_model_id
-
-            content_blocks
-            text_parts = []
-
-            for block in content_blocks:
-                if block.get('type') == 'text':
-                    text_parts.append(block.get('text', ''))
-                    parsed_content.append({
-                        'type': 'text',
-                        'text': block.get('text', '')
-                    })
-                elif block.get('type') == 'tool_use':
-                    parsed_content.append({
-                        'type': 'tool_use',
-                        'id': block.get('id'),
-                        'name': block.get('name'),
-                        'input': block.get('input', {})
-                    })
-
-            # Combine text for backwards compatibility
-            text = '\n'.join(text_parts) if text_parts else ''
-
-            return {
-                'content': text,
-                'content_blocks': parsed_content,
-                'stop_reason': response_body.get('stop_reason'),
-                'usage': response_body.get('usage', {})
-            }
-
-        # Amazon Titan models
-        elif 'amazon.titan' in model_id or 'titan' in model_id.lower():
-            results = response_body.get('results', [])
-            text = results[0].get('outputText', '') if results else ''
-
-            return {
-                'content': text,
-                'stop_reason': results[0].get('completionReason') if results else None,
-                'usage': {
-                    'input_tokens': response_body.get('inputTextTokenCount', 0),
-                    'output_tokens': response_body.get('results', [{}])[0].get('tokenCount', 0)
-                }
-            }
+        model_lower = model_id.lower()
+
+        if 'anthropic.claude' in model_id or 'anthropic' in model_lower:
+            return self._parse_anthropic_response(response_body)
+        if 'amazon.titan' in model_id or 'titan' in model_lower:
+            return self._parse_titan_response(response_body)
+        if 'meta.llama' in model_id or 'llama' in model_lower:
+            return self._parse_llama_response(response_body)
+        if 'ai21' in model_id:
+            return self._parse_ai21_response(response_body)
+        if 'cohere' in model_id:
+            return self._parse_cohere_response(response_body)
+
+        return {'content': str(response_body), 'stop_reason': None, 'usage': {}}
+
+    @staticmethod
+    def _parse_anthropic_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
+        """Parse an Anthropic Claude response."""
+        content_blocks = response_body.get('content', [])
+        parsed_content = []
+        text_parts = []
+
+        for block in content_blocks:
+            if block.get('type') == 'text':
+                text_parts.append(block.get('text', ''))
+                parsed_content.append({'type': 'text', 'text': block.get('text', '')})
+            elif block.get('type') == 'tool_use':
+                parsed_content.append({
+                    'type': 'tool_use',
+                    'id': block.get('id'),
+                    'name': block.get('name'),
+                    'input': block.get('input', {})
+                })
+
+        return {
+            'content': '\n'.join(text_parts) if text_parts else '',
+            'content_blocks': parsed_content,
+            'stop_reason': response_body.get('stop_reason'),
+            'usage': response_body.get('usage', {})
+        }
+
+    @staticmethod
+    def _parse_titan_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
+        """Parse an Amazon Titan response."""
+        results = response_body.get('results', [])
+        return {
+            'content': results[0].get('outputText', '') if results else '',
+            'stop_reason': results[0].get('completionReason') if results else None,
+            'usage': {
+                'input_tokens': response_body.get('inputTextTokenCount', 0),
+                'output_tokens': response_body.get('results', [{}])[0].get('tokenCount', 0)
+            }
+        }
+
+    @staticmethod
+    def _parse_llama_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
+        """Parse a Meta Llama response."""
+        return {
+            'content': response_body.get('generation', ''),
+            'stop_reason': response_body.get('stop_reason'),
+            'usage': {
+                'input_tokens': response_body.get('prompt_token_count', 0),
+                'output_tokens': response_body.get('generation_token_count', 0)
+            }
+        }
+
+    @staticmethod
+    def _parse_ai21_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
+        """Parse an AI21 response."""
+        completions = response_body.get('completions', [])
+        return {
+            'content': completions[0].get('data', {}).get('text', '') if completions else '',
+            'stop_reason': completions[0].get('finishReason', {}).get('reason') if completions else None,
+            'usage': {}
+        }
+
+    @staticmethod
+    def _parse_cohere_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
+        """Parse a Cohere response."""
+        generations = response_body.get('generations', [])
+        return {
+            'content': generations[0].get('text', '') if generations else '',
+            'stop_reason': generations[0].get('finish_reason') if generations else None,
+            'usage': {}
+        }
 
     def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
         """

@@ -544,30 +544,29 @@ class BedrockService(LLMService):
             Estimated total token count for all messages
         """
         total_tokens = 0
-
         for message in messages:
-            # Count tokens for role (small overhead)
             total_tokens += 4  # Approximate overhead for role formatting
-
-            # Count tokens in content
-            content = message.get('content', '')
-
-            # Handle content that might be a list (for multi-part content)
-            if isinstance(content, list):
-                for part in content:
-                    if isinstance(part, dict):
-                        if 'text' in part:
-                            total_tokens += self.count_tokens(part['text'], model_id)
-                        # Add overhead for other content types (images, documents, etc.)
-                        elif 'image' in part or 'document' in part:
-                            total_tokens += 1000  # Rough estimate for non-text content
-                    elif isinstance(part, str):
-                        total_tokens += self.count_tokens(part, model_id)
-            elif isinstance(content, str):
-                total_tokens += self.count_tokens(content, model_id)
-
+            total_tokens += self._count_content_tokens(message.get('content', ''), model_id)
         return total_tokens
 
+    def _count_content_tokens(self, content, model_id: Optional[str] = None) -> int:
+        """Count tokens for a single message content value (string or list of parts)."""
+        if isinstance(content, str):
+            return self.count_tokens(content, model_id)
+        if not isinstance(content, list):
+            return 0
+
+        total = 0
+        for part in content:
+            if isinstance(part, str):
+                total += self.count_tokens(part, model_id)
+            elif isinstance(part, dict):
+                if 'text' in part:
+                    total += self.count_tokens(part['text'], model_id)
+                elif 'image' in part or 'document' in part:
+                    total += 1000  # Rough estimate for non-text content
+        return total
+
     def get_current_model_id(self) -> Optional[str]:
         """
         Get the currently selected model ID.
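
Taken together, the bedrock.py changes keep the public contract of `invoke_model()` intact: callers still receive either a parsed-response dictionary or an error dictionary. A minimal sketch of consuming that contract; the `handle_bedrock_result` function and the printing are illustrative assumptions, while the dictionary keys (`error`, `error_code`, `error_message`, `retries_attempted`, `content`, `content_blocks`, `stop_reason`, `usage`) come from the hunks above:

```python
from typing import Any, Dict


def handle_bedrock_result(result: Dict[str, Any]) -> str:
    """Consume a result dictionary as produced by BedrockService.invoke_model()."""
    if result.get('error'):
        # Error shape produced on ClientError/Exception or after MaxRetriesExceeded.
        raise RuntimeError(
            f"Bedrock call failed ({result.get('error_code')}): "
            f"{result.get('error_message')} "
            f"after {result.get('retries_attempted', 0)} retries"
        )

    # Success shape produced by the per-provider _parse_*_response() helpers.
    usage = result.get('usage', {})
    print(f"stop_reason={result.get('stop_reason')}, "
          f"tokens in/out={usage.get('input_tokens', 0)}/{usage.get('output_tokens', 0)}")

    # Anthropic responses also carry structured content_blocks (text / tool_use).
    for block in result.get('content_blocks', []):
        if block.get('type') == 'tool_use':
            print(f"tool requested: {block.get('name')} with input {block.get('input')}")

    return result.get('content', '')
```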
|