pomera-ai-commander 0.1.0
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- package/LICENSE +21 -0
- package/README.md +680 -0
- package/bin/pomera-ai-commander.js +62 -0
- package/core/__init__.py +66 -0
- package/core/__pycache__/__init__.cpython-313.pyc +0 -0
- package/core/__pycache__/app_context.cpython-313.pyc +0 -0
- package/core/__pycache__/async_text_processor.cpython-313.pyc +0 -0
- package/core/__pycache__/backup_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/backup_recovery_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/content_hash_cache.cpython-313.pyc +0 -0
- package/core/__pycache__/context_menu.cpython-313.pyc +0 -0
- package/core/__pycache__/data_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/database_connection_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_curl_settings_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_promera_ai_settings_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_schema.cpython-313.pyc +0 -0
- package/core/__pycache__/database_schema_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_settings_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_settings_manager_interface.cpython-313.pyc +0 -0
- package/core/__pycache__/dialog_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/efficient_line_numbers.cpython-313.pyc +0 -0
- package/core/__pycache__/error_handler.cpython-313.pyc +0 -0
- package/core/__pycache__/error_service.cpython-313.pyc +0 -0
- package/core/__pycache__/event_consolidator.cpython-313.pyc +0 -0
- package/core/__pycache__/memory_efficient_text_widget.cpython-313.pyc +0 -0
- package/core/__pycache__/migration_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/migration_test_suite.cpython-313.pyc +0 -0
- package/core/__pycache__/migration_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/optimized_find_replace.cpython-313.pyc +0 -0
- package/core/__pycache__/optimized_pattern_engine.cpython-313.pyc +0 -0
- package/core/__pycache__/optimized_search_highlighter.cpython-313.pyc +0 -0
- package/core/__pycache__/performance_monitor.cpython-313.pyc +0 -0
- package/core/__pycache__/persistence_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/progressive_stats_calculator.cpython-313.pyc +0 -0
- package/core/__pycache__/regex_pattern_cache.cpython-313.pyc +0 -0
- package/core/__pycache__/regex_pattern_library.cpython-313.pyc +0 -0
- package/core/__pycache__/search_operation_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_defaults_registry.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_integrity_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_serializer.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/smart_stats_calculator.cpython-313.pyc +0 -0
- package/core/__pycache__/statistics_update_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/stats_config_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/streaming_text_handler.cpython-313.pyc +0 -0
- package/core/__pycache__/task_scheduler.cpython-313.pyc +0 -0
- package/core/__pycache__/visibility_monitor.cpython-313.pyc +0 -0
- package/core/__pycache__/widget_cache.cpython-313.pyc +0 -0
- package/core/app_context.py +482 -0
- package/core/async_text_processor.py +422 -0
- package/core/backup_manager.py +656 -0
- package/core/backup_recovery_manager.py +1034 -0
- package/core/content_hash_cache.py +509 -0
- package/core/context_menu.py +313 -0
- package/core/data_validator.py +1067 -0
- package/core/database_connection_manager.py +745 -0
- package/core/database_curl_settings_manager.py +609 -0
- package/core/database_promera_ai_settings_manager.py +447 -0
- package/core/database_schema.py +412 -0
- package/core/database_schema_manager.py +396 -0
- package/core/database_settings_manager.py +1508 -0
- package/core/database_settings_manager_interface.py +457 -0
- package/core/dialog_manager.py +735 -0
- package/core/efficient_line_numbers.py +511 -0
- package/core/error_handler.py +747 -0
- package/core/error_service.py +431 -0
- package/core/event_consolidator.py +512 -0
- package/core/mcp/__init__.py +43 -0
- package/core/mcp/__pycache__/__init__.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/protocol.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/schema.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/server_stdio.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/tool_registry.cpython-313.pyc +0 -0
- package/core/mcp/protocol.py +288 -0
- package/core/mcp/schema.py +251 -0
- package/core/mcp/server_stdio.py +299 -0
- package/core/mcp/tool_registry.py +2345 -0
- package/core/memory_efficient_text_widget.py +712 -0
- package/core/migration_manager.py +915 -0
- package/core/migration_test_suite.py +1086 -0
- package/core/migration_validator.py +1144 -0
- package/core/optimized_find_replace.py +715 -0
- package/core/optimized_pattern_engine.py +424 -0
- package/core/optimized_search_highlighter.py +553 -0
- package/core/performance_monitor.py +675 -0
- package/core/persistence_manager.py +713 -0
- package/core/progressive_stats_calculator.py +632 -0
- package/core/regex_pattern_cache.py +530 -0
- package/core/regex_pattern_library.py +351 -0
- package/core/search_operation_manager.py +435 -0
- package/core/settings_defaults_registry.py +1087 -0
- package/core/settings_integrity_validator.py +1112 -0
- package/core/settings_serializer.py +558 -0
- package/core/settings_validator.py +1824 -0
- package/core/smart_stats_calculator.py +710 -0
- package/core/statistics_update_manager.py +619 -0
- package/core/stats_config_manager.py +858 -0
- package/core/streaming_text_handler.py +723 -0
- package/core/task_scheduler.py +596 -0
- package/core/update_pattern_library.py +169 -0
- package/core/visibility_monitor.py +596 -0
- package/core/widget_cache.py +498 -0
- package/mcp.json +61 -0
- package/package.json +57 -0
- package/pomera.py +7483 -0
- package/pomera_mcp_server.py +144 -0
- package/tools/__init__.py +5 -0
- package/tools/__pycache__/__init__.cpython-313.pyc +0 -0
- package/tools/__pycache__/ai_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/ascii_art_generator.cpython-313.pyc +0 -0
- package/tools/__pycache__/base64_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/base_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/case_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/column_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/cron_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_history.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_processor.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_settings.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/diff_viewer.cpython-313.pyc +0 -0
- package/tools/__pycache__/email_extraction_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/email_header_analyzer.cpython-313.pyc +0 -0
- package/tools/__pycache__/extraction_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/find_replace.cpython-313.pyc +0 -0
- package/tools/__pycache__/folder_file_reporter.cpython-313.pyc +0 -0
- package/tools/__pycache__/folder_file_reporter_adapter.cpython-313.pyc +0 -0
- package/tools/__pycache__/generator_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/hash_generator.cpython-313.pyc +0 -0
- package/tools/__pycache__/html_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/huggingface_helper.cpython-313.pyc +0 -0
- package/tools/__pycache__/jsonxml_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/line_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/list_comparator.cpython-313.pyc +0 -0
- package/tools/__pycache__/markdown_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/mcp_widget.cpython-313.pyc +0 -0
- package/tools/__pycache__/notes_widget.cpython-313.pyc +0 -0
- package/tools/__pycache__/number_base_converter.cpython-313.pyc +0 -0
- package/tools/__pycache__/regex_extractor.cpython-313.pyc +0 -0
- package/tools/__pycache__/slug_generator.cpython-313.pyc +0 -0
- package/tools/__pycache__/sorter_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/string_escape_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/text_statistics_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/text_wrapper.cpython-313.pyc +0 -0
- package/tools/__pycache__/timestamp_converter.cpython-313.pyc +0 -0
- package/tools/__pycache__/tool_loader.cpython-313.pyc +0 -0
- package/tools/__pycache__/translator_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/url_link_extractor.cpython-313.pyc +0 -0
- package/tools/__pycache__/url_parser.cpython-313.pyc +0 -0
- package/tools/__pycache__/whitespace_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/word_frequency_counter.cpython-313.pyc +0 -0
- package/tools/ai_tools.py +2892 -0
- package/tools/ascii_art_generator.py +353 -0
- package/tools/base64_tools.py +184 -0
- package/tools/base_tool.py +511 -0
- package/tools/case_tool.py +309 -0
- package/tools/column_tools.py +396 -0
- package/tools/cron_tool.py +885 -0
- package/tools/curl_history.py +601 -0
- package/tools/curl_processor.py +1208 -0
- package/tools/curl_settings.py +503 -0
- package/tools/curl_tool.py +5467 -0
- package/tools/diff_viewer.py +1072 -0
- package/tools/email_extraction_tool.py +249 -0
- package/tools/email_header_analyzer.py +426 -0
- package/tools/extraction_tools.py +250 -0
- package/tools/find_replace.py +1751 -0
- package/tools/folder_file_reporter.py +1463 -0
- package/tools/folder_file_reporter_adapter.py +480 -0
- package/tools/generator_tools.py +1217 -0
- package/tools/hash_generator.py +256 -0
- package/tools/html_tool.py +657 -0
- package/tools/huggingface_helper.py +449 -0
- package/tools/jsonxml_tool.py +730 -0
- package/tools/line_tools.py +419 -0
- package/tools/list_comparator.py +720 -0
- package/tools/markdown_tools.py +562 -0
- package/tools/mcp_widget.py +1417 -0
- package/tools/notes_widget.py +973 -0
- package/tools/number_base_converter.py +373 -0
- package/tools/regex_extractor.py +572 -0
- package/tools/slug_generator.py +311 -0
- package/tools/sorter_tools.py +459 -0
- package/tools/string_escape_tool.py +393 -0
- package/tools/text_statistics_tool.py +366 -0
- package/tools/text_wrapper.py +431 -0
- package/tools/timestamp_converter.py +422 -0
- package/tools/tool_loader.py +710 -0
- package/tools/translator_tools.py +523 -0
- package/tools/url_link_extractor.py +262 -0
- package/tools/url_parser.py +205 -0
- package/tools/whitespace_tools.py +356 -0
- package/tools/word_frequency_counter.py +147 -0
|
@@ -0,0 +1,2892 @@
|
|
|
1
|
+
import tkinter as tk
|
|
2
|
+
from tkinter import ttk, messagebox, filedialog
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import requests
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
import random
|
|
9
|
+
import webbrowser
|
|
10
|
+
import hashlib
|
|
11
|
+
import hmac
|
|
12
|
+
import urllib.parse
|
|
13
|
+
import os
|
|
14
|
+
import base64
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from google.oauth2 import service_account
|
|
19
|
+
from google.auth.transport.requests import Request
|
|
20
|
+
GOOGLE_AUTH_AVAILABLE = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
GOOGLE_AUTH_AVAILABLE = False
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from huggingface_hub import InferenceClient
|
|
26
|
+
from huggingface_hub.utils import HfHubHTTPError
|
|
27
|
+
HUGGINGFACE_AVAILABLE = True
|
|
28
|
+
except ImportError:
|
|
29
|
+
HUGGINGFACE_AVAILABLE = False
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
from tools.huggingface_helper import process_huggingface_request
|
|
33
|
+
HUGGINGFACE_HELPER_AVAILABLE = True
|
|
34
|
+
except ImportError:
|
|
35
|
+
HUGGINGFACE_HELPER_AVAILABLE = False
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
from cryptography.fernet import Fernet
|
|
39
|
+
from cryptography.hazmat.primitives import hashes
|
|
40
|
+
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
41
|
+
ENCRYPTION_AVAILABLE = True
|
|
42
|
+
except ImportError:
|
|
43
|
+
ENCRYPTION_AVAILABLE = False
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
from core.streaming_text_handler import (
|
|
47
|
+
StreamingTextHandler,
|
|
48
|
+
StreamingTextManager,
|
|
49
|
+
StreamConfig,
|
|
50
|
+
StreamMetrics
|
|
51
|
+
)
|
|
52
|
+
STREAMING_AVAILABLE = True
|
|
53
|
+
except ImportError:
|
|
54
|
+
STREAMING_AVAILABLE = False
|
|
55
|
+
|
|
56
|
+
def get_system_encryption_key():
|
|
57
|
+
"""Generate encryption key based on system characteristics"""
|
|
58
|
+
if not ENCRYPTION_AVAILABLE:
|
|
59
|
+
return None
|
|
60
|
+
|
|
61
|
+
try:
|
|
62
|
+
# Use machine-specific data as salt
|
|
63
|
+
machine_id = os.environ.get('COMPUTERNAME', '') + os.environ.get('USERNAME', '')
|
|
64
|
+
if not machine_id:
|
|
65
|
+
machine_id = os.environ.get('HOSTNAME', '') + os.environ.get('USER', '')
|
|
66
|
+
|
|
67
|
+
salt = machine_id.encode()[:16].ljust(16, b'0')
|
|
68
|
+
|
|
69
|
+
kdf = PBKDF2HMAC(
|
|
70
|
+
algorithm=hashes.SHA256(),
|
|
71
|
+
length=32,
|
|
72
|
+
salt=salt,
|
|
73
|
+
iterations=100000,
|
|
74
|
+
)
|
|
75
|
+
key = base64.urlsafe_b64encode(kdf.derive(b"pomera_ai_tool_encryption"))
|
|
76
|
+
return Fernet(key)
|
|
77
|
+
except Exception:
|
|
78
|
+
return None
|
|
79
|
+
|
|
80
|
+
def encrypt_api_key(api_key):
|
|
81
|
+
"""Encrypt API key for storage"""
|
|
82
|
+
if not api_key or api_key == "putinyourkey" or not ENCRYPTION_AVAILABLE:
|
|
83
|
+
return api_key
|
|
84
|
+
|
|
85
|
+
# Check if already encrypted (starts with our prefix)
|
|
86
|
+
if api_key.startswith("ENC:"):
|
|
87
|
+
return api_key
|
|
88
|
+
|
|
89
|
+
try:
|
|
90
|
+
fernet = get_system_encryption_key()
|
|
91
|
+
if not fernet:
|
|
92
|
+
return api_key
|
|
93
|
+
|
|
94
|
+
encrypted = fernet.encrypt(api_key.encode())
|
|
95
|
+
return "ENC:" + base64.urlsafe_b64encode(encrypted).decode()
|
|
96
|
+
except Exception:
|
|
97
|
+
return api_key # Fallback to unencrypted if encryption fails
|
|
98
|
+
|
|
99
|
+
def decrypt_api_key(encrypted_key):
|
|
100
|
+
"""Decrypt API key for use"""
|
|
101
|
+
if not encrypted_key or encrypted_key == "putinyourkey" or not ENCRYPTION_AVAILABLE:
|
|
102
|
+
return encrypted_key
|
|
103
|
+
|
|
104
|
+
# Check if encrypted (starts with our prefix)
|
|
105
|
+
if not encrypted_key.startswith("ENC:"):
|
|
106
|
+
return encrypted_key # Not encrypted, return as-is
|
|
107
|
+
|
|
108
|
+
try:
|
|
109
|
+
fernet = get_system_encryption_key()
|
|
110
|
+
if not fernet:
|
|
111
|
+
return encrypted_key
|
|
112
|
+
|
|
113
|
+
# Remove prefix and decrypt
|
|
114
|
+
encrypted_data = encrypted_key[4:] # Remove "ENC:" prefix
|
|
115
|
+
encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode())
|
|
116
|
+
decrypted = fernet.decrypt(encrypted_bytes)
|
|
117
|
+
return decrypted.decode()
|
|
118
|
+
except Exception:
|
|
119
|
+
return encrypted_key # Fallback to encrypted value if decryption fails
|
|
120
|
+
|
|
121
|
+
class AIToolsWidget(ttk.Frame):
|
|
122
|
+
"""A tabbed interface for all AI tools."""
|
|
123
|
+
|
|
124
|
+
def __init__(self, parent, app_instance, dialog_manager=None):
|
|
125
|
+
super().__init__(parent)
|
|
126
|
+
self.app = app_instance
|
|
127
|
+
self.logger = app_instance.logger
|
|
128
|
+
self.dialog_manager = dialog_manager
|
|
129
|
+
|
|
130
|
+
# AI provider configurations
|
|
131
|
+
self.ai_providers = {
|
|
132
|
+
"Google AI": {
|
|
133
|
+
"url_template": "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}",
|
|
134
|
+
"headers_template": {'Content-Type': 'application/json'},
|
|
135
|
+
"api_url": "https://aistudio.google.com/apikey"
|
|
136
|
+
},
|
|
137
|
+
"Vertex AI": {
|
|
138
|
+
"url_template": "https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/google/models/{model}:generateContent",
|
|
139
|
+
"headers_template": {'Content-Type': 'application/json', 'Authorization': 'Bearer {access_token}'},
|
|
140
|
+
"api_url": "https://cloud.google.com/vertex-ai/docs/authentication"
|
|
141
|
+
},
|
|
142
|
+
"Azure AI": {
|
|
143
|
+
"url_template": "{endpoint}/models/chat/completions?api-version={api_version}", # Used for Foundry; Azure OpenAI uses /openai/deployments/{model}/...
|
|
144
|
+
"headers_template": {'Content-Type': 'application/json', 'api-key': '{api_key}'},
|
|
145
|
+
"api_url": "https://learn.microsoft.com/en-us/azure/ai-foundry/foundry-models/how-to/quickstart-ai-project"
|
|
146
|
+
},
|
|
147
|
+
"Anthropic AI": {
|
|
148
|
+
"url": "https://api.anthropic.com/v1/messages",
|
|
149
|
+
"headers_template": {"x-api-key": "{api_key}", "anthropic-version": "2023-06-01", "Content-Type": "application/json"},
|
|
150
|
+
"api_url": "https://console.anthropic.com/settings/keys"
|
|
151
|
+
},
|
|
152
|
+
"OpenAI": {
|
|
153
|
+
"url": "https://api.openai.com/v1/chat/completions",
|
|
154
|
+
"headers_template": {"Authorization": "Bearer {api_key}", "Content-Type": "application/json"},
|
|
155
|
+
"api_url": "https://platform.openai.com/settings/organization/api-keys"
|
|
156
|
+
},
|
|
157
|
+
"Cohere AI": {
|
|
158
|
+
"url": "https://api.cohere.com/v1/chat",
|
|
159
|
+
"headers_template": {"Authorization": "Bearer {api_key}", "Content-Type": "application/json"},
|
|
160
|
+
"api_url": "https://dashboard.cohere.com/api-keys"
|
|
161
|
+
},
|
|
162
|
+
"HuggingFace AI": {
|
|
163
|
+
"api_url": "https://huggingface.co/settings/tokens"
|
|
164
|
+
},
|
|
165
|
+
"Groq AI": {
|
|
166
|
+
"url": "https://api.groq.com/openai/v1/chat/completions",
|
|
167
|
+
"headers_template": {"Authorization": "Bearer {api_key}", "Content-Type": "application/json"},
|
|
168
|
+
"api_url": "https://console.groq.com/keys"
|
|
169
|
+
},
|
|
170
|
+
"OpenRouterAI": {
|
|
171
|
+
"url": "https://openrouter.ai/api/v1/chat/completions",
|
|
172
|
+
"headers_template": {
|
|
173
|
+
"Authorization": "Bearer {api_key}",
|
|
174
|
+
"Content-Type": "application/json",
|
|
175
|
+
"HTTP-Referer": "https://github.com/matbanik/Pomera-AI-Commander",
|
|
176
|
+
"X-Title": "Pomera AI Commander"
|
|
177
|
+
},
|
|
178
|
+
"api_url": "https://openrouter.ai/settings/keys"
|
|
179
|
+
},
|
|
180
|
+
"LM Studio": {
|
|
181
|
+
"url_template": "{base_url}/v1/chat/completions",
|
|
182
|
+
"headers_template": {"Content-Type": "application/json"},
|
|
183
|
+
"api_url": "http://lmstudio.ai/",
|
|
184
|
+
"local_service": True
|
|
185
|
+
},
|
|
186
|
+
"AWS Bedrock": {
|
|
187
|
+
# Using Converse API (recommended) - provides unified interface across all models
|
|
188
|
+
"url": "https://bedrock-runtime.{region}.amazonaws.com/model/{model}/converse",
|
|
189
|
+
"url_invoke": "https://bedrock-runtime.{region}.amazonaws.com/model/{model}/invoke", # Fallback for legacy
|
|
190
|
+
"headers_template": {"Content-Type": "application/json", "Accept": "application/json"},
|
|
191
|
+
"api_url": "https://docs.aws.amazon.com/bedrock/latest/userguide/getting-started.html",
|
|
192
|
+
"aws_service": True
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
self.current_provider = "Google AI"
|
|
197
|
+
self.ai_widgets = {}
|
|
198
|
+
self._ai_thread = None
|
|
199
|
+
|
|
200
|
+
# Streaming support - enabled by default when available
|
|
201
|
+
self._streaming_enabled = STREAMING_AVAILABLE
|
|
202
|
+
self._streaming_handler = None
|
|
203
|
+
self._streaming_manager = None
|
|
204
|
+
|
|
205
|
+
self.create_widgets()
|
|
206
|
+
|
|
207
|
+
# Show encryption status in logs
|
|
208
|
+
if ENCRYPTION_AVAILABLE:
|
|
209
|
+
self.logger.info("API Key encryption is ENABLED - keys will be encrypted at rest")
|
|
210
|
+
else:
|
|
211
|
+
self.logger.warning("API Key encryption is DISABLED - cryptography library not found. Install with: pip install cryptography")
|
|
212
|
+
|
|
213
|
+
# Show streaming status in logs
|
|
214
|
+
if STREAMING_AVAILABLE:
|
|
215
|
+
self.logger.info("Streaming text handler is ENABLED - AI responses will be streamed progressively")
|
|
216
|
+
else:
|
|
217
|
+
self.logger.warning("Streaming text handler is NOT AVAILABLE - AI responses will be displayed at once")
|
|
218
|
+
|
|
219
|
+
def apply_font_to_widgets(self, font_tuple):
|
|
220
|
+
"""Apply font to all text widgets in AI Tools."""
|
|
221
|
+
try:
|
|
222
|
+
for provider_name, widgets in self.ai_widgets.items():
|
|
223
|
+
for widget_name, widget in widgets.items():
|
|
224
|
+
# Apply to Text widgets (like system prompts)
|
|
225
|
+
if isinstance(widget, tk.Text):
|
|
226
|
+
widget.configure(font=font_tuple)
|
|
227
|
+
|
|
228
|
+
self.logger.debug(f"Applied font {font_tuple} to AI Tools widgets")
|
|
229
|
+
except Exception as e:
|
|
230
|
+
self.logger.debug(f"Error applying font to AI Tools widgets: {e}")
|
|
231
|
+
|
|
232
|
+
def get_api_key_for_provider(self, provider_name, settings):
|
|
233
|
+
"""Get decrypted API key for a provider"""
|
|
234
|
+
if provider_name == "LM Studio":
|
|
235
|
+
return "" # LM Studio doesn't use API keys
|
|
236
|
+
|
|
237
|
+
encrypted_key = settings.get("API_KEY", "")
|
|
238
|
+
return decrypt_api_key(encrypted_key)
|
|
239
|
+
|
|
240
|
+
def get_aws_credential(self, settings, credential_name):
|
|
241
|
+
"""Get decrypted AWS credential"""
|
|
242
|
+
encrypted_credential = settings.get(credential_name, "")
|
|
243
|
+
return decrypt_api_key(encrypted_credential)
|
|
244
|
+
|
|
245
|
+
def save_encrypted_api_key(self, provider_name, api_key):
|
|
246
|
+
"""Save encrypted API key for a provider"""
|
|
247
|
+
if provider_name == "LM Studio":
|
|
248
|
+
return # LM Studio doesn't use API keys
|
|
249
|
+
|
|
250
|
+
if not api_key or api_key == "putinyourkey":
|
|
251
|
+
# Don't encrypt empty or placeholder keys
|
|
252
|
+
self.app.settings["tool_settings"][provider_name]["API_KEY"] = api_key
|
|
253
|
+
else:
|
|
254
|
+
encrypted_key = encrypt_api_key(api_key)
|
|
255
|
+
self.app.settings["tool_settings"][provider_name]["API_KEY"] = encrypted_key
|
|
256
|
+
|
|
257
|
+
self.app.save_settings()
|
|
258
|
+
|
|
259
|
+
def upload_vertex_ai_json(self, provider_name):
|
|
260
|
+
"""Upload and parse Vertex AI service account JSON file."""
|
|
261
|
+
try:
|
|
262
|
+
file_path = filedialog.askopenfilename(
|
|
263
|
+
title="Select Vertex AI Service Account JSON File",
|
|
264
|
+
filetypes=[
|
|
265
|
+
("JSON files", "*.json"),
|
|
266
|
+
("All files", "*.*")
|
|
267
|
+
]
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
if not file_path:
|
|
271
|
+
return
|
|
272
|
+
|
|
273
|
+
# Read and parse JSON file
|
|
274
|
+
with open(file_path, 'r', encoding='utf-8') as f:
|
|
275
|
+
json_data = json.load(f)
|
|
276
|
+
|
|
277
|
+
# Validate required fields
|
|
278
|
+
required_fields = ['type', 'project_id', 'private_key_id', 'private_key',
|
|
279
|
+
'client_email', 'client_id', 'auth_uri', 'token_uri']
|
|
280
|
+
missing_fields = [field for field in required_fields if field not in json_data]
|
|
281
|
+
|
|
282
|
+
if missing_fields:
|
|
283
|
+
self._show_error("Invalid JSON File",
|
|
284
|
+
f"Missing required fields: {', '.join(missing_fields)}")
|
|
285
|
+
return
|
|
286
|
+
|
|
287
|
+
# Encrypt private_key
|
|
288
|
+
encrypted_private_key = encrypt_api_key(json_data['private_key'])
|
|
289
|
+
|
|
290
|
+
# Store in database
|
|
291
|
+
if hasattr(self.app, 'db_settings_manager') and self.app.db_settings_manager:
|
|
292
|
+
conn_manager = self.app.db_settings_manager.connection_manager
|
|
293
|
+
with conn_manager.transaction() as conn:
|
|
294
|
+
# Delete existing record (singleton pattern)
|
|
295
|
+
conn.execute("DELETE FROM vertex_ai_json")
|
|
296
|
+
|
|
297
|
+
# Insert new record
|
|
298
|
+
conn.execute("""
|
|
299
|
+
INSERT INTO vertex_ai_json (
|
|
300
|
+
type, project_id, private_key_id, private_key,
|
|
301
|
+
client_email, client_id, auth_uri, token_uri,
|
|
302
|
+
auth_provider_x509_cert_url, client_x509_cert_url, universe_domain
|
|
303
|
+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
304
|
+
""", (
|
|
305
|
+
json_data.get('type', ''),
|
|
306
|
+
json_data.get('project_id', ''),
|
|
307
|
+
json_data.get('private_key_id', ''),
|
|
308
|
+
encrypted_private_key,
|
|
309
|
+
json_data.get('client_email', ''),
|
|
310
|
+
json_data.get('client_id', ''),
|
|
311
|
+
json_data.get('auth_uri', ''),
|
|
312
|
+
json_data.get('token_uri', ''),
|
|
313
|
+
json_data.get('auth_provider_x509_cert_url'),
|
|
314
|
+
json_data.get('client_x509_cert_url'),
|
|
315
|
+
json_data.get('universe_domain')
|
|
316
|
+
))
|
|
317
|
+
|
|
318
|
+
# Update location setting if not already set (default to us-central1)
|
|
319
|
+
settings = self.get_current_settings()
|
|
320
|
+
if not settings.get("LOCATION"):
|
|
321
|
+
self.app.db_settings_manager.set_tool_setting(provider_name, "LOCATION", "us-central1")
|
|
322
|
+
|
|
323
|
+
# Update project_id in tool_settings if not already set
|
|
324
|
+
if not settings.get("PROJECT_ID"):
|
|
325
|
+
self.app.db_settings_manager.set_tool_setting(provider_name, "PROJECT_ID", json_data.get('project_id', ''))
|
|
326
|
+
|
|
327
|
+
self._show_info("Success", "Vertex AI service account JSON uploaded and stored successfully.")
|
|
328
|
+
self.logger.info(f"Vertex AI JSON uploaded: project_id={json_data.get('project_id')}")
|
|
329
|
+
|
|
330
|
+
# Update status label if widget exists
|
|
331
|
+
if provider_name in self.ai_widgets and "JSON_STATUS" in self.ai_widgets[provider_name]:
|
|
332
|
+
status_label = self.ai_widgets[provider_name]["JSON_STATUS"]
|
|
333
|
+
status_label.config(text=f"✓ Loaded: {json_data.get('project_id', 'Unknown')}", foreground="green")
|
|
334
|
+
else:
|
|
335
|
+
self._show_error("Error", "Database settings manager not available")
|
|
336
|
+
|
|
337
|
+
except json.JSONDecodeError as e:
|
|
338
|
+
self._show_error("Invalid JSON", f"The file is not valid JSON: {str(e)}")
|
|
339
|
+
except Exception as e:
|
|
340
|
+
self.logger.error(f"Error uploading Vertex AI JSON: {e}", exc_info=True)
|
|
341
|
+
self._show_error("Error", f"Failed to upload JSON file: {str(e)}")
|
|
342
|
+
|
|
343
|
+
def get_vertex_ai_credentials(self):
|
|
344
|
+
"""Get Vertex AI service account credentials from database."""
|
|
345
|
+
try:
|
|
346
|
+
if not hasattr(self.app, 'db_settings_manager') or not self.app.db_settings_manager:
|
|
347
|
+
return None
|
|
348
|
+
|
|
349
|
+
conn_manager = self.app.db_settings_manager.connection_manager
|
|
350
|
+
with conn_manager.transaction() as conn:
|
|
351
|
+
cursor = conn.execute("""
|
|
352
|
+
SELECT type, project_id, private_key_id, private_key,
|
|
353
|
+
client_email, client_id, auth_uri, token_uri,
|
|
354
|
+
auth_provider_x509_cert_url, client_x509_cert_url, universe_domain
|
|
355
|
+
FROM vertex_ai_json
|
|
356
|
+
ORDER BY updated_at DESC
|
|
357
|
+
LIMIT 1
|
|
358
|
+
""")
|
|
359
|
+
row = cursor.fetchone()
|
|
360
|
+
|
|
361
|
+
if not row:
|
|
362
|
+
return None
|
|
363
|
+
|
|
364
|
+
# Decrypt private_key
|
|
365
|
+
decrypted_private_key = decrypt_api_key(row[3])
|
|
366
|
+
|
|
367
|
+
# Reconstruct JSON structure
|
|
368
|
+
credentials_dict = {
|
|
369
|
+
'type': row[0],
|
|
370
|
+
'project_id': row[1],
|
|
371
|
+
'private_key_id': row[2],
|
|
372
|
+
'private_key': decrypted_private_key,
|
|
373
|
+
'client_email': row[4],
|
|
374
|
+
'client_id': row[5],
|
|
375
|
+
'auth_uri': row[6],
|
|
376
|
+
'token_uri': row[7],
|
|
377
|
+
'auth_provider_x509_cert_url': row[8],
|
|
378
|
+
'client_x509_cert_url': row[9],
|
|
379
|
+
'universe_domain': row[10]
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
return credentials_dict
|
|
383
|
+
|
|
384
|
+
except Exception as e:
|
|
385
|
+
self.logger.error(f"Error getting Vertex AI credentials: {e}", exc_info=True)
|
|
386
|
+
return None
|
|
387
|
+
|
|
388
|
+
def get_vertex_ai_access_token(self):
|
|
389
|
+
"""Get OAuth2 access token for Vertex AI using service account credentials."""
|
|
390
|
+
if not GOOGLE_AUTH_AVAILABLE:
|
|
391
|
+
self.logger.error("google-auth library not available")
|
|
392
|
+
return None
|
|
393
|
+
|
|
394
|
+
try:
|
|
395
|
+
credentials_dict = self.get_vertex_ai_credentials()
|
|
396
|
+
if not credentials_dict:
|
|
397
|
+
self.logger.warning("No Vertex AI credentials found in database")
|
|
398
|
+
return None
|
|
399
|
+
|
|
400
|
+
# Create credentials from service account info
|
|
401
|
+
credentials = service_account.Credentials.from_service_account_info(
|
|
402
|
+
credentials_dict,
|
|
403
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform']
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
# Refresh token if needed
|
|
407
|
+
if not credentials.valid:
|
|
408
|
+
request = Request()
|
|
409
|
+
credentials.refresh(request)
|
|
410
|
+
|
|
411
|
+
# Get access token
|
|
412
|
+
access_token = credentials.token
|
|
413
|
+
self.logger.debug("Vertex AI access token obtained successfully")
|
|
414
|
+
|
|
415
|
+
return access_token
|
|
416
|
+
|
|
417
|
+
except Exception as e:
|
|
418
|
+
self.logger.error(f"Error getting Vertex AI access token: {e}", exc_info=True)
|
|
419
|
+
return None
|
|
420
|
+
|
|
421
|
+
def _show_info(self, title, message, category="success"):
|
|
422
|
+
"""Show info dialog using DialogManager if available, otherwise use messagebox."""
|
|
423
|
+
if self.dialog_manager:
|
|
424
|
+
return self.dialog_manager.show_info(title, message, category)
|
|
425
|
+
else:
|
|
426
|
+
from tkinter import messagebox
|
|
427
|
+
messagebox.showinfo(title, message)
|
|
428
|
+
return True
|
|
429
|
+
|
|
430
|
+
def _show_warning(self, title, message, category="warning"):
|
|
431
|
+
"""Show warning dialog using DialogManager if available, otherwise use messagebox."""
|
|
432
|
+
if self.dialog_manager:
|
|
433
|
+
return self.dialog_manager.show_warning(title, message, category)
|
|
434
|
+
else:
|
|
435
|
+
from tkinter import messagebox
|
|
436
|
+
messagebox.showwarning(title, message)
|
|
437
|
+
return True
|
|
438
|
+
|
|
439
|
+
def _show_error(self, title, message):
|
|
440
|
+
"""Show error dialog using DialogManager if available, otherwise use messagebox."""
|
|
441
|
+
if self.dialog_manager:
|
|
442
|
+
return self.dialog_manager.show_error(title, message)
|
|
443
|
+
else:
|
|
444
|
+
from tkinter import messagebox
|
|
445
|
+
messagebox.showerror(title, message)
|
|
446
|
+
return True
|
|
447
|
+
|
|
448
|
+
def create_widgets(self):
|
|
449
|
+
"""Create the tabbed interface for AI tools."""
|
|
450
|
+
# Create notebook for tabs
|
|
451
|
+
self.notebook = ttk.Notebook(self)
|
|
452
|
+
self.notebook.pack(fill=tk.BOTH, expand=True)
|
|
453
|
+
|
|
454
|
+
# Create tabs for each AI provider
|
|
455
|
+
self.tabs = {}
|
|
456
|
+
for provider in self.ai_providers.keys():
|
|
457
|
+
tab_frame = ttk.Frame(self.notebook)
|
|
458
|
+
self.notebook.add(tab_frame, text=provider)
|
|
459
|
+
self.tabs[provider] = tab_frame
|
|
460
|
+
self.create_provider_widgets(tab_frame, provider)
|
|
461
|
+
|
|
462
|
+
# Bind tab selection event
|
|
463
|
+
self.notebook.bind("<<NotebookTabChanged>>", self.on_tab_changed)
|
|
464
|
+
|
|
465
|
+
# Set initial tab
|
|
466
|
+
self.notebook.select(0)
|
|
467
|
+
self.current_provider = list(self.ai_providers.keys())[0]
|
|
468
|
+
|
|
469
|
+
def on_tab_changed(self, event=None):
|
|
470
|
+
"""Handle tab change event."""
|
|
471
|
+
try:
|
|
472
|
+
selected_tab = self.notebook.select()
|
|
473
|
+
tab_index = self.notebook.index(selected_tab)
|
|
474
|
+
self.current_provider = list(self.ai_providers.keys())[tab_index]
|
|
475
|
+
|
|
476
|
+
# Ensure AWS Bedrock fields are properly visible when switching to that tab
|
|
477
|
+
if self.current_provider == "AWS Bedrock":
|
|
478
|
+
self.after_idle(lambda: self.update_aws_credentials_fields(self.current_provider))
|
|
479
|
+
|
|
480
|
+
self.app.on_tool_setting_change() # Notify parent app of change
|
|
481
|
+
except tk.TclError:
|
|
482
|
+
pass
|
|
483
|
+
|
|
484
|
+
def create_provider_widgets(self, parent, provider_name):
|
|
485
|
+
"""Create widgets for a specific AI provider."""
|
|
486
|
+
# Get settings for this provider
|
|
487
|
+
settings = self.app.settings["tool_settings"].get(provider_name, {})
|
|
488
|
+
|
|
489
|
+
# Create main container with reduced padding - don't expand vertically
|
|
490
|
+
main_frame = ttk.Frame(parent)
|
|
491
|
+
main_frame.pack(fill=tk.X, padx=5, pady=5, anchor="n")
|
|
492
|
+
|
|
493
|
+
# Top frame with API key, model, and process button all on same line
|
|
494
|
+
top_frame = ttk.Frame(main_frame)
|
|
495
|
+
top_frame.pack(fill=tk.X, pady=(0, 5))
|
|
496
|
+
|
|
497
|
+
# Store reference for later access
|
|
498
|
+
if provider_name not in self.ai_widgets:
|
|
499
|
+
self.ai_widgets[provider_name] = {}
|
|
500
|
+
|
|
501
|
+
# API Configuration section (different for LM Studio and AWS Bedrock)
|
|
502
|
+
if provider_name == "LM Studio":
|
|
503
|
+
# LM Studio Configuration section
|
|
504
|
+
lm_frame = ttk.LabelFrame(top_frame, text="LM Studio Configuration")
|
|
505
|
+
lm_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
506
|
+
|
|
507
|
+
ttk.Label(lm_frame, text="Base URL:").pack(side=tk.LEFT, padx=(5, 5))
|
|
508
|
+
|
|
509
|
+
base_url_var = tk.StringVar(value=settings.get("BASE_URL", "http://127.0.0.1:1234"))
|
|
510
|
+
base_url_entry = ttk.Entry(lm_frame, textvariable=base_url_var, width=20)
|
|
511
|
+
base_url_entry.pack(side=tk.LEFT, padx=(0, 5))
|
|
512
|
+
base_url_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
513
|
+
|
|
514
|
+
self.ai_widgets[provider_name]["BASE_URL"] = base_url_var
|
|
515
|
+
|
|
516
|
+
# Refresh models button
|
|
517
|
+
ttk.Button(lm_frame, text="Refresh Models",
|
|
518
|
+
command=lambda: self.refresh_lm_studio_models(provider_name)).pack(side=tk.LEFT, padx=(5, 5))
|
|
519
|
+
elif provider_name == "AWS Bedrock":
|
|
520
|
+
# AWS Bedrock Configuration section
|
|
521
|
+
aws_frame = ttk.LabelFrame(top_frame, text="AWS Bedrock Configuration")
|
|
522
|
+
aws_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
523
|
+
|
|
524
|
+
# Authentication Method
|
|
525
|
+
ttk.Label(aws_frame, text="Auth Method:").pack(side=tk.LEFT, padx=(5, 5))
|
|
526
|
+
|
|
527
|
+
auth_method_var = tk.StringVar(value=settings.get("AUTH_METHOD", "api_key"))
|
|
528
|
+
auth_combo = ttk.Combobox(aws_frame, textvariable=auth_method_var,
|
|
529
|
+
values=[
|
|
530
|
+
"API Key (Bearer Token)",
|
|
531
|
+
"IAM (Explicit Credentials)",
|
|
532
|
+
"Session Token (Temporary Credentials)",
|
|
533
|
+
"IAM (Implied Credentials)"
|
|
534
|
+
],
|
|
535
|
+
state="readonly", width=30)
|
|
536
|
+
|
|
537
|
+
# Set the display value based on stored value
|
|
538
|
+
stored_auth = settings.get("AUTH_METHOD", "api_key") # Default to api_key for consistency
|
|
539
|
+
|
|
540
|
+
# Ensure the AUTH_METHOD is saved in settings if not present
|
|
541
|
+
if "AUTH_METHOD" not in settings:
|
|
542
|
+
if provider_name not in self.app.settings["tool_settings"]:
|
|
543
|
+
self.app.settings["tool_settings"][provider_name] = {}
|
|
544
|
+
self.app.settings["tool_settings"][provider_name]["AUTH_METHOD"] = stored_auth
|
|
545
|
+
self.app.save_settings()
|
|
546
|
+
self.logger.debug(f"Initialized AWS Bedrock AUTH_METHOD to: {stored_auth}")
|
|
547
|
+
|
|
548
|
+
if stored_auth == "api_key":
|
|
549
|
+
auth_combo.set("API Key (Bearer Token)")
|
|
550
|
+
elif stored_auth == "iam":
|
|
551
|
+
auth_combo.set("IAM (Explicit Credentials)")
|
|
552
|
+
elif stored_auth == "sessionToken":
|
|
553
|
+
auth_combo.set("Session Token (Temporary Credentials)")
|
|
554
|
+
elif stored_auth == "iam_role":
|
|
555
|
+
auth_combo.set("IAM (Implied Credentials)")
|
|
556
|
+
else:
|
|
557
|
+
# Fallback for any unknown values
|
|
558
|
+
auth_combo.set("API Key (Bearer Token)")
|
|
559
|
+
stored_auth = "api_key"
|
|
560
|
+
# Update settings with corrected value
|
|
561
|
+
self.app.settings["tool_settings"][provider_name]["AUTH_METHOD"] = stored_auth
|
|
562
|
+
self.app.save_settings()
|
|
563
|
+
|
|
564
|
+
auth_combo.pack(side=tk.LEFT, padx=(0, 5))
|
|
565
|
+
auth_method_var.trace_add("write", lambda *args: [self.on_aws_auth_change(provider_name), self.update_aws_credentials_fields(provider_name)])
|
|
566
|
+
|
|
567
|
+
self.ai_widgets[provider_name]["AUTH_METHOD"] = auth_method_var
|
|
568
|
+
|
|
569
|
+
# AWS Region
|
|
570
|
+
ttk.Label(aws_frame, text="Region:").pack(side=tk.LEFT, padx=(10, 5))
|
|
571
|
+
|
|
572
|
+
region_var = tk.StringVar(value=settings.get("AWS_REGION", "us-west-2"))
|
|
573
|
+
aws_regions = [
|
|
574
|
+
"us-east-1", "us-east-2", "us-west-1", "us-west-2",
|
|
575
|
+
"ca-central-1", "eu-north-1", "eu-west-1", "eu-west-2",
|
|
576
|
+
"eu-west-3", "eu-central-1", "eu-south-1", "af-south-1",
|
|
577
|
+
"ap-northeast-1", "ap-northeast-2", "ap-northeast-3",
|
|
578
|
+
"ap-southeast-1", "ap-southeast-2", "ap-southeast-3",
|
|
579
|
+
"ap-east-1", "ap-south-1", "sa-east-1", "me-south-1"
|
|
580
|
+
]
|
|
581
|
+
region_combo = ttk.Combobox(aws_frame, textvariable=region_var,
|
|
582
|
+
values=aws_regions, state="readonly", width=15)
|
|
583
|
+
region_combo.pack(side=tk.LEFT, padx=(0, 5))
|
|
584
|
+
region_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
585
|
+
|
|
586
|
+
self.ai_widgets[provider_name]["AWS_REGION"] = region_var
|
|
587
|
+
elif provider_name == "Vertex AI":
|
|
588
|
+
# Vertex AI Configuration section with JSON upload
|
|
589
|
+
encryption_status = "🔒" if ENCRYPTION_AVAILABLE else "⚠️"
|
|
590
|
+
api_frame = ttk.LabelFrame(top_frame, text=f"API Configuration {encryption_status}")
|
|
591
|
+
api_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
592
|
+
|
|
593
|
+
ttk.Label(api_frame, text="Service Account:").pack(side=tk.LEFT, padx=(5, 5))
|
|
594
|
+
|
|
595
|
+
# Upload JSON button for Vertex AI
|
|
596
|
+
ttk.Button(api_frame, text="Upload JSON",
|
|
597
|
+
command=lambda: self.upload_vertex_ai_json(provider_name)).pack(side=tk.LEFT, padx=(5, 5))
|
|
598
|
+
|
|
599
|
+
# Status label to show if JSON is loaded
|
|
600
|
+
status_label = ttk.Label(api_frame, text="", foreground="gray")
|
|
601
|
+
status_label.pack(side=tk.LEFT, padx=(5, 5))
|
|
602
|
+
self.ai_widgets[provider_name]["JSON_STATUS"] = status_label
|
|
603
|
+
|
|
604
|
+
# Check if credentials exist and update status
|
|
605
|
+
credentials = self.get_vertex_ai_credentials()
|
|
606
|
+
if credentials:
|
|
607
|
+
status_label.config(text=f"✓ Loaded: {credentials.get('project_id', 'Unknown')}", foreground="green")
|
|
608
|
+
else:
|
|
609
|
+
status_label.config(text="No JSON loaded", foreground="red")
|
|
610
|
+
|
|
611
|
+
# API key link button (docs)
|
|
612
|
+
ttk.Button(api_frame, text="Get API Key",
|
|
613
|
+
command=lambda: webbrowser.open(self.ai_providers[provider_name]["api_url"])).pack(side=tk.LEFT, padx=(5, 5))
|
|
614
|
+
elif provider_name == "Azure AI":
|
|
615
|
+
# Azure AI Configuration section
|
|
616
|
+
encryption_status = "🔒" if ENCRYPTION_AVAILABLE else "⚠️"
|
|
617
|
+
api_frame = ttk.LabelFrame(top_frame, text=f"API Configuration {encryption_status}")
|
|
618
|
+
api_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
619
|
+
|
|
620
|
+
ttk.Label(api_frame, text="API Key:").pack(side=tk.LEFT, padx=(5, 5))
|
|
621
|
+
|
|
622
|
+
# Get decrypted API key for display
|
|
623
|
+
decrypted_key = self.get_api_key_for_provider(provider_name, settings)
|
|
624
|
+
api_key_var = tk.StringVar(value=decrypted_key if decrypted_key else "putinyourkey")
|
|
625
|
+
api_key_entry = ttk.Entry(api_frame, textvariable=api_key_var, show="*", width=20)
|
|
626
|
+
api_key_entry.pack(side=tk.LEFT, padx=(0, 5))
|
|
627
|
+
api_key_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
628
|
+
|
|
629
|
+
self.ai_widgets[provider_name]["API_KEY"] = api_key_var
|
|
630
|
+
|
|
631
|
+
# API key link button
|
|
632
|
+
ttk.Button(api_frame, text="Get API Key",
|
|
633
|
+
command=lambda: webbrowser.open(self.ai_providers[provider_name]["api_url"])).pack(side=tk.LEFT, padx=(5, 5))
|
|
634
|
+
|
|
635
|
+
# Resource Endpoint field
|
|
636
|
+
endpoint_frame = ttk.LabelFrame(top_frame, text="Endpoint")
|
|
637
|
+
endpoint_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
638
|
+
|
|
639
|
+
ttk.Label(endpoint_frame, text="Resource Endpoint:").pack(side=tk.LEFT, padx=(5, 5))
|
|
640
|
+
|
|
641
|
+
endpoint_var = tk.StringVar(value=settings.get("ENDPOINT", ""))
|
|
642
|
+
endpoint_entry = ttk.Entry(endpoint_frame, textvariable=endpoint_var, width=30)
|
|
643
|
+
endpoint_entry.pack(side=tk.LEFT, padx=(0, 5))
|
|
644
|
+
endpoint_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
645
|
+
|
|
646
|
+
self.ai_widgets[provider_name]["ENDPOINT"] = endpoint_var
|
|
647
|
+
|
|
648
|
+
# API Version field
|
|
649
|
+
api_version_frame = ttk.LabelFrame(top_frame, text="API Version")
|
|
650
|
+
api_version_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
651
|
+
|
|
652
|
+
ttk.Label(api_version_frame, text="API Version:").pack(side=tk.LEFT, padx=(5, 5))
|
|
653
|
+
|
|
654
|
+
api_version_var = tk.StringVar(value=settings.get("API_VERSION", "2024-10-21"))
|
|
655
|
+
api_version_entry = ttk.Entry(api_version_frame, textvariable=api_version_var, width=15)
|
|
656
|
+
api_version_entry.pack(side=tk.LEFT, padx=(0, 5))
|
|
657
|
+
api_version_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
658
|
+
|
|
659
|
+
self.ai_widgets[provider_name]["API_VERSION"] = api_version_var
|
|
660
|
+
else:
|
|
661
|
+
# Standard API Configuration section
|
|
662
|
+
encryption_status = "🔒" if ENCRYPTION_AVAILABLE else "⚠️"
|
|
663
|
+
api_frame = ttk.LabelFrame(top_frame, text=f"API Configuration {encryption_status}")
|
|
664
|
+
api_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
665
|
+
|
|
666
|
+
ttk.Label(api_frame, text="API Key:").pack(side=tk.LEFT, padx=(5, 5))
|
|
667
|
+
|
|
668
|
+
# Get decrypted API key for display
|
|
669
|
+
decrypted_key = self.get_api_key_for_provider(provider_name, settings)
|
|
670
|
+
api_key_var = tk.StringVar(value=decrypted_key if decrypted_key else "putinyourkey")
|
|
671
|
+
api_key_entry = ttk.Entry(api_frame, textvariable=api_key_var, show="*", width=20)
|
|
672
|
+
api_key_entry.pack(side=tk.LEFT, padx=(0, 5))
|
|
673
|
+
api_key_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
674
|
+
|
|
675
|
+
self.ai_widgets[provider_name]["API_KEY"] = api_key_var
|
|
676
|
+
|
|
677
|
+
# API key link button
|
|
678
|
+
ttk.Button(api_frame, text="Get API Key",
|
|
679
|
+
command=lambda: webbrowser.open(self.ai_providers[provider_name]["api_url"])).pack(side=tk.LEFT, padx=(5, 5))
|
|
680
|
+
|
|
681
|
+
# Vertex AI Location field (similar to AWS Region)
|
|
682
|
+
if provider_name == "Vertex AI":
|
|
683
|
+
location_frame = ttk.LabelFrame(top_frame, text="Location")
|
|
684
|
+
location_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
685
|
+
|
|
686
|
+
ttk.Label(location_frame, text="Location:").pack(side=tk.LEFT, padx=(5, 5))
|
|
687
|
+
|
|
688
|
+
location_var = tk.StringVar(value=settings.get("LOCATION", "us-central1"))
|
|
689
|
+
vertex_locations = [
|
|
690
|
+
"us-central1", "us-east1", "us-east4", "us-west1", "us-west4",
|
|
691
|
+
"europe-west1", "europe-west4", "europe-west6", "asia-east1",
|
|
692
|
+
"asia-northeast1", "asia-southeast1", "asia-south1"
|
|
693
|
+
]
|
|
694
|
+
location_combo = ttk.Combobox(location_frame, textvariable=location_var,
|
|
695
|
+
values=vertex_locations, state="readonly", width=15)
|
|
696
|
+
location_combo.pack(side=tk.LEFT, padx=(0, 5))
|
|
697
|
+
location_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
698
|
+
|
|
699
|
+
self.ai_widgets[provider_name]["LOCATION"] = location_var
|
|
700
|
+
|
|
701
|
+
# Model Configuration section
|
|
702
|
+
model_frame = ttk.LabelFrame(top_frame, text="Model Configuration")
|
|
703
|
+
model_frame.pack(side=tk.LEFT, padx=(0, 10), fill=tk.Y)
|
|
704
|
+
|
|
705
|
+
if provider_name == "Azure AI":
|
|
706
|
+
ttk.Label(model_frame, text="Model (Deployment Name):").pack(side=tk.LEFT, padx=(5, 5))
|
|
707
|
+
else:
|
|
708
|
+
ttk.Label(model_frame, text="Model:").pack(side=tk.LEFT, padx=(5, 5))
|
|
709
|
+
|
|
710
|
+
# Set default models for Vertex AI if not present
|
|
711
|
+
if provider_name == "Vertex AI":
|
|
712
|
+
models_list = settings.get("MODELS_LIST", [])
|
|
713
|
+
if not models_list:
|
|
714
|
+
models_list = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
|
715
|
+
if hasattr(self.app, 'db_settings_manager') and self.app.db_settings_manager:
|
|
716
|
+
self.app.db_settings_manager.set_tool_setting(provider_name, "MODELS_LIST", models_list)
|
|
717
|
+
else:
|
|
718
|
+
if provider_name not in self.app.settings["tool_settings"]:
|
|
719
|
+
self.app.settings["tool_settings"][provider_name] = {}
|
|
720
|
+
self.app.settings["tool_settings"][provider_name]["MODELS_LIST"] = models_list
|
|
721
|
+
self.app.save_settings()
|
|
722
|
+
|
|
723
|
+
# Set default model if not present
|
|
724
|
+
if not settings.get("MODEL"):
|
|
725
|
+
default_model = "gemini-2.5-flash"
|
|
726
|
+
if hasattr(self.app, 'db_settings_manager') and self.app.db_settings_manager:
|
|
727
|
+
self.app.db_settings_manager.set_tool_setting(provider_name, "MODEL", default_model)
|
|
728
|
+
else:
|
|
729
|
+
if provider_name not in self.app.settings["tool_settings"]:
|
|
730
|
+
self.app.settings["tool_settings"][provider_name] = {}
|
|
731
|
+
self.app.settings["tool_settings"][provider_name]["MODEL"] = default_model
|
|
732
|
+
self.app.save_settings()
|
|
733
|
+
settings["MODEL"] = default_model
|
|
734
|
+
settings["MODELS_LIST"] = models_list
|
|
735
|
+
|
|
736
|
+
model_var = tk.StringVar(value=settings.get("MODEL", ""))
|
|
737
|
+
models_list = settings.get("MODELS_LIST", [])
|
|
738
|
+
|
|
739
|
+
model_combo = ttk.Combobox(model_frame, textvariable=model_var, values=models_list, width=30)
|
|
740
|
+
model_combo.pack(side=tk.LEFT, padx=(0, 5))
|
|
741
|
+
model_combo.bind("<<ComboboxSelected>>", lambda e: self.on_setting_change(provider_name))
|
|
742
|
+
model_combo.bind("<KeyRelease>", lambda e: self.on_setting_change(provider_name))
|
|
743
|
+
|
|
744
|
+
# Model buttons
|
|
745
|
+
if provider_name == "AWS Bedrock":
|
|
746
|
+
# Refresh Models button for AWS Bedrock
|
|
747
|
+
ttk.Button(model_frame, text="Refresh Models",
|
|
748
|
+
command=lambda: self.refresh_bedrock_models(provider_name)).pack(side=tk.LEFT, padx=(0, 5))
|
|
749
|
+
elif provider_name == "LM Studio":
|
|
750
|
+
# Store model combobox reference for LM Studio
|
|
751
|
+
pass # LM Studio refresh button is in the configuration section
|
|
752
|
+
elif provider_name == "Google AI":
|
|
753
|
+
# Refresh Models button for Google AI (fetches from API)
|
|
754
|
+
ttk.Button(model_frame, text="Refresh",
|
|
755
|
+
command=lambda: self.refresh_google_ai_models(provider_name)).pack(side=tk.LEFT, padx=(0, 5))
|
|
756
|
+
# Model edit button
|
|
757
|
+
ttk.Button(model_frame, text="\u270E",
|
|
758
|
+
command=lambda: self.open_model_editor(provider_name), width=3).pack(side=tk.LEFT, padx=(0, 5))
|
|
759
|
+
elif provider_name == "OpenRouterAI":
|
|
760
|
+
# Refresh Models button for OpenRouter (fetches from API)
|
|
761
|
+
ttk.Button(model_frame, text="Refresh",
|
|
762
|
+
command=lambda: self.refresh_openrouter_models(provider_name)).pack(side=tk.LEFT, padx=(0, 5))
|
|
763
|
+
# Model edit button
|
|
764
|
+
ttk.Button(model_frame, text="\u270E",
|
|
765
|
+
command=lambda: self.open_model_editor(provider_name), width=3).pack(side=tk.LEFT, padx=(0, 5))
|
|
766
|
+
else:
|
|
767
|
+
# Model edit button for other providers
|
|
768
|
+
ttk.Button(model_frame, text="\u270E",
|
|
769
|
+
command=lambda: self.open_model_editor(provider_name), width=3).pack(side=tk.LEFT, padx=(0, 5))
|
|
770
|
+
|
|
771
|
+
self.ai_widgets[provider_name]["MODEL"] = model_var
|
|
772
|
+
|
|
773
|
+
# Store model combobox reference for LM Studio and AWS Bedrock
|
|
774
|
+
if provider_name in ["LM Studio", "AWS Bedrock"]:
|
|
775
|
+
self.ai_widgets[provider_name]["MODEL_COMBO"] = model_combo
|
|
776
|
+
|
|
777
|
+
# Max Tokens for LM Studio
|
|
778
|
+
if provider_name == "LM Studio":
|
|
779
|
+
ttk.Label(model_frame, text="Max Tokens:").pack(side=tk.LEFT, padx=(10, 5))
|
|
780
|
+
|
|
781
|
+
max_tokens_var = tk.StringVar(value=settings.get("MAX_TOKENS", "2048"))
|
|
782
|
+
max_tokens_entry = ttk.Entry(model_frame, textvariable=max_tokens_var, width=10)
|
|
783
|
+
max_tokens_entry.pack(side=tk.LEFT, padx=(0, 5))
|
|
784
|
+
max_tokens_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
785
|
+
|
|
786
|
+
self.ai_widgets[provider_name]["MAX_TOKENS"] = max_tokens_var
|
|
787
|
+
|
|
788
|
+
# AWS Bedrock specific fields
|
|
789
|
+
if provider_name == "AWS Bedrock":
|
|
790
|
+
# AWS Credentials section
|
|
791
|
+
self.aws_creds_frame = ttk.LabelFrame(main_frame, text="AWS Credentials")
|
|
792
|
+
self.aws_creds_frame.pack(fill=tk.X, pady=(5, 0))
|
|
793
|
+
|
|
794
|
+
# Add note about AWS Bedrock authentication
|
|
795
|
+
note_frame = ttk.Frame(self.aws_creds_frame)
|
|
796
|
+
note_frame.pack(fill=tk.X, padx=5, pady=2)
|
|
797
|
+
|
|
798
|
+
auth_note = "AWS Bedrock supports both API Key (Bearer Token) and IAM authentication.\nAPI Key is simpler, IAM provides more granular control."
|
|
799
|
+
if ENCRYPTION_AVAILABLE:
|
|
800
|
+
auth_note += "\n🔒 API keys are encrypted at rest for security."
|
|
801
|
+
else:
|
|
802
|
+
auth_note += "\n⚠️ API keys are stored in plain text. Install 'cryptography' for encryption."
|
|
803
|
+
|
|
804
|
+
note_label = ttk.Label(note_frame, text=auth_note, foreground="blue", font=('TkDefaultFont', 8))
|
|
805
|
+
note_label.pack(side=tk.LEFT)
|
|
806
|
+
|
|
807
|
+
# API Key row
|
|
808
|
+
self.api_key_row = ttk.Frame(self.aws_creds_frame)
|
|
809
|
+
self.api_key_row.pack(fill=tk.X, padx=5, pady=2)
|
|
810
|
+
|
|
811
|
+
ttk.Label(self.api_key_row, text="AWS Bedrock API Key:").pack(side=tk.LEFT)
|
|
812
|
+
# Get decrypted API key for display
|
|
813
|
+
decrypted_key = self.get_api_key_for_provider(provider_name, settings)
|
|
814
|
+
api_key_var = tk.StringVar(value=decrypted_key if decrypted_key else "")
|
|
815
|
+
api_key_entry = ttk.Entry(self.api_key_row, textvariable=api_key_var, show="*", width=40)
|
|
816
|
+
api_key_entry.pack(side=tk.LEFT, padx=(5, 0))
|
|
817
|
+
api_key_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
818
|
+
self.ai_widgets[provider_name]["API_KEY"] = api_key_var
|
|
819
|
+
|
|
820
|
+
# Get API Key link
|
|
821
|
+
get_key_link = ttk.Label(self.api_key_row, text="Get API Key", foreground="blue", cursor="hand2")
|
|
822
|
+
get_key_link.pack(side=tk.LEFT, padx=(10, 0))
|
|
823
|
+
get_key_link.bind("<Button-1>", lambda e: webbrowser.open("https://console.aws.amazon.com/bedrock/home"))
|
|
824
|
+
|
|
825
|
+
# Access Key ID row
|
|
826
|
+
self.access_key_row = ttk.Frame(self.aws_creds_frame)
|
|
827
|
+
self.access_key_row.pack(fill=tk.X, padx=5, pady=2)
|
|
828
|
+
|
|
829
|
+
ttk.Label(self.access_key_row, text="AWS Bedrock IAM Access ID:").pack(side=tk.LEFT)
|
|
830
|
+
# Get decrypted AWS Access Key for display
|
|
831
|
+
decrypted_access_key = self.get_aws_credential(settings, "AWS_ACCESS_KEY_ID")
|
|
832
|
+
access_key_var = tk.StringVar(value=decrypted_access_key)
|
|
833
|
+
access_key_entry = ttk.Entry(self.access_key_row, textvariable=access_key_var, show="*", width=30)
|
|
834
|
+
access_key_entry.pack(side=tk.LEFT, padx=(5, 0))
|
|
835
|
+
access_key_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
836
|
+
self.ai_widgets[provider_name]["AWS_ACCESS_KEY_ID"] = access_key_var
|
|
837
|
+
|
|
838
|
+
# Secret Access Key row
|
|
839
|
+
self.secret_key_row = ttk.Frame(self.aws_creds_frame)
|
|
840
|
+
self.secret_key_row.pack(fill=tk.X, padx=5, pady=2)
|
|
841
|
+
|
|
842
|
+
ttk.Label(self.secret_key_row, text="AWS Bedrock IAM Access Key:").pack(side=tk.LEFT)
|
|
843
|
+
# Get decrypted AWS Secret Key for display
|
|
844
|
+
decrypted_secret_key = self.get_aws_credential(settings, "AWS_SECRET_ACCESS_KEY")
|
|
845
|
+
secret_key_var = tk.StringVar(value=decrypted_secret_key)
|
|
846
|
+
secret_key_entry = ttk.Entry(self.secret_key_row, textvariable=secret_key_var, show="*", width=30)
|
|
847
|
+
secret_key_entry.pack(side=tk.LEFT, padx=(5, 0))
|
|
848
|
+
secret_key_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
|
|
849
|
+
self.ai_widgets[provider_name]["AWS_SECRET_ACCESS_KEY"] = secret_key_var
|
|
850
|
+
|
|
851
|
+
# Session Token row
|
|
852
|
+
self.session_token_row = ttk.Frame(self.aws_creds_frame)
|
|
853
|
+
self.session_token_row.pack(fill=tk.X, padx=5, pady=2)
|
|
854
|
+
|
|
855
|
+
+        ttk.Label(self.session_token_row, text="AWS Bedrock Session Token:").pack(side=tk.LEFT)
+        # Get decrypted AWS Session Token for display
+        decrypted_session_token = self.get_aws_credential(settings, "AWS_SESSION_TOKEN")
+        session_token_var = tk.StringVar(value=decrypted_session_token)
+        session_token_entry = ttk.Entry(self.session_token_row, textvariable=session_token_var, show="*", width=30)
+        session_token_entry.pack(side=tk.LEFT, padx=(5, 0))
+        session_token_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+        self.ai_widgets[provider_name]["AWS_SESSION_TOKEN"] = session_token_var
+
+        # Content section (renamed from Model Configuration)
+        content_frame = ttk.LabelFrame(main_frame, text="Content")
+        content_frame.pack(fill=tk.X, pady=(5, 0))
+
+        content_row = ttk.Frame(content_frame)
+        content_row.pack(fill=tk.X, padx=5, pady=5)
+
+        # Context Window
+        ttk.Label(content_row, text="Model context window:").pack(side=tk.LEFT)
+        context_window_var = tk.StringVar(value=settings.get("CONTEXT_WINDOW", "8192"))
+        context_window_entry = ttk.Entry(content_row, textvariable=context_window_var, width=10)
+        context_window_entry.pack(side=tk.LEFT, padx=(5, 20))
+        context_window_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+        self.ai_widgets[provider_name]["CONTEXT_WINDOW"] = context_window_var
+
+        # Max Output Tokens
+        ttk.Label(content_row, text="Model max output tokens:").pack(side=tk.LEFT)
+        max_output_tokens_var = tk.StringVar(value=settings.get("MAX_OUTPUT_TOKENS", "4096"))
+        max_output_tokens_entry = ttk.Entry(content_row, textvariable=max_output_tokens_var, width=10)
+        max_output_tokens_entry.pack(side=tk.LEFT, padx=(5, 0))
+        max_output_tokens_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+        self.ai_widgets[provider_name]["MAX_OUTPUT_TOKENS"] = max_output_tokens_var
+
+        # Add IAM role info frame
+        self.iam_role_info_frame = ttk.Frame(self.aws_creds_frame)
+        self.iam_role_info_frame.pack(fill=tk.X, padx=5, pady=5)
+
+        info_label = ttk.Label(self.iam_role_info_frame,
+                               text="IAM Role authentication uses the AWS credentials configured on this system.\nEnsure your AWS CLI is configured or EC2 instance has proper IAM role.",
+                               foreground="gray")
+        info_label.pack(side=tk.LEFT)
+
+        # Initialize field visibility based on current auth method
+        # Use after_idle to ensure all widgets are created before updating visibility
+        self.after_idle(lambda: self.update_aws_credentials_fields(provider_name))
+
+        # Process button section
+        process_frame = ttk.Frame(top_frame)
+        process_frame.pack(side=tk.LEFT, fill=tk.Y)
+
+        ttk.Button(process_frame, text="Process",
+                   command=self.run_ai_in_thread).pack(padx=5, pady=10)
+
+        # System prompt
+        system_frame = ttk.LabelFrame(main_frame, text="System Prompt")
+        system_frame.pack(fill=tk.X, pady=(0, 5))
+
+        system_prompt_key = "system_prompt"
+        if provider_name == "Anthropic AI":
+            system_prompt_key = "system"
+        elif provider_name == "Cohere AI":
+            system_prompt_key = "preamble"
+
+        system_text = tk.Text(system_frame, height=2, wrap=tk.WORD)
+
+        # Apply current font settings from main app
+        try:
+            if hasattr(self.app, 'get_best_font'):
+                text_font_family, text_font_size = self.app.get_best_font("text")
+                system_text.configure(font=(text_font_family, text_font_size))
+        except Exception:
+            pass  # Use default font if font settings not available
+
+        system_text.pack(fill=tk.X, padx=5, pady=3)
+        system_text.insert("1.0", settings.get(system_prompt_key, "You are a helpful assistant."))
+
+        self.ai_widgets[provider_name][system_prompt_key] = system_text
+
+        # Parameters notebook with minimal height to reduce empty space (skipped for AWS Bedrock and LM Studio)
+        # Note: Azure AI keeps the notebook because it uses standard OpenAI-style params
+        if provider_name not in ["AWS Bedrock", "LM Studio"]:
+            params_notebook = ttk.Notebook(main_frame)
+            # Much smaller height to eliminate wasted space - users can scroll if needed
+            params_notebook.pack(fill=tk.X, pady=(5, 0))
+            params_notebook.configure(height=120)  # Significantly reduced height
+
+            # Create parameter tabs
+            self.create_parameter_tabs(params_notebook, provider_name, settings)
+
+        # Bind change events
+        model_var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+        system_text.bind("<KeyRelease>", lambda *args: self.on_setting_change(provider_name))
+
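The rows above wire persistence directly to the widgets: every tk.StringVar registers a "write" trace that funnels into on_setting_change, so edits save without an explicit Save button. A minimal standalone sketch of that pattern, with a hypothetical save_setting sink standing in for the app's real settings store:

    import tkinter as tk
    from tkinter import ttk

    def save_setting(key, value):
        # Stand-in for the app's persistence layer; just prints.
        print(f"saved {key!r} = {value!r}")

    root = tk.Tk()
    context_var = tk.StringVar(value="8192")
    # The write trace fires on every change, including each keystroke.
    context_var.trace_add("write", lambda *args: save_setting("CONTEXT_WINDOW", context_var.get()))
    ttk.Entry(root, textvariable=context_var, width=10).pack(padx=10, pady=10)
    root.mainloop()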
+    def create_parameter_tabs(self, notebook, provider_name, settings):
+        """Create parameter configuration tabs."""
+        # Get parameter configuration for this provider
+        params_config = self._get_ai_params_config(provider_name)
+
+        # Group parameters by tab
+        tabs_data = {}
+        for param, config in params_config.items():
+            tab_name = config.get("tab", "general")
+            if tab_name not in tabs_data:
+                tabs_data[tab_name] = {}
+            tabs_data[tab_name][param] = config
+
+        # Create tabs
+        for tab_name, params in tabs_data.items():
+            tab_frame = ttk.Frame(notebook)
+            notebook.add(tab_frame, text=tab_name.title())
+
+            # Create scrollable frame with improved scrolling
+            canvas = tk.Canvas(tab_frame, highlightthickness=0)
+            scrollbar = ttk.Scrollbar(tab_frame, orient="vertical", command=canvas.yview)
+            scrollable_frame = ttk.Frame(canvas)
+
+            # Bind the loop's canvas via a default argument so each tab's handlers
+            # keep their own canvas (avoids Python's late-binding closure pitfall)
+            def configure_scroll_region(event=None, canvas=canvas):
+                canvas.configure(scrollregion=canvas.bbox("all"))
+
+            def on_mousewheel(event, canvas=canvas):
+                # Handle cross-platform mouse wheel events
+                if event.delta:
+                    # Windows
+                    canvas.yview_scroll(int(-1 * (event.delta / 120)), "units")
+                else:
+                    # Linux
+                    if event.num == 4:
+                        canvas.yview_scroll(-1, "units")
+                    elif event.num == 5:
+                        canvas.yview_scroll(1, "units")
+
+            scrollable_frame.bind("<Configure>", configure_scroll_region)
+
+            # Bind mouse wheel to canvas and scrollable frame (cross-platform)
+            canvas.bind("<MouseWheel>", on_mousewheel)  # Windows
+            canvas.bind("<Button-4>", on_mousewheel)    # Linux scroll up
+            canvas.bind("<Button-5>", on_mousewheel)    # Linux scroll down
+            scrollable_frame.bind("<MouseWheel>", on_mousewheel)
+            scrollable_frame.bind("<Button-4>", on_mousewheel)
+            scrollable_frame.bind("<Button-5>", on_mousewheel)
+
+            # Make sure mouse wheel works when hovering over child widgets
+            def bind_mousewheel_to_children(widget, on_mousewheel=on_mousewheel):
+                widget.bind("<MouseWheel>", on_mousewheel)
+                widget.bind("<Button-4>", on_mousewheel)
+                widget.bind("<Button-5>", on_mousewheel)
+                for child in widget.winfo_children():
+                    bind_mousewheel_to_children(child, on_mousewheel)
+
+            canvas.create_window((0, 0), window=scrollable_frame, anchor="nw")
+            canvas.configure(yscrollcommand=scrollbar.set)
+
+            canvas.pack(side="left", fill="both", expand=True)
+            scrollbar.pack(side="right", fill="y")
+
+            # Store references for later mouse wheel binding
+            canvas._scrollable_frame = scrollable_frame
+            canvas._bind_mousewheel_to_children = bind_mousewheel_to_children
+
+            # Add parameters to scrollable frame
+            row = 0
+            for param, config in params.items():
+                self.create_parameter_widget(scrollable_frame, provider_name, param, config, settings, row)
+                row += 1
+
+            # Bind mouse wheel to all child widgets after they're created
+            canvas._bind_mousewheel_to_children(scrollable_frame)
+
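The handler functions above are created inside a loop, one set per tab, and bind the current canvas through default arguments. Without that, Python's closures would resolve the name canvas at call time, and every tab would scroll the last canvas created. The pitfall in miniature:

    # Late binding: all three closures see the loop variable's final value.
    late = [lambda: i for i in range(3)]
    print([f() for f in late])    # [2, 2, 2]

    # A default argument captures the value at definition time instead.
    bound = [lambda i=i: i for i in range(3)]
    print([f() for f in bound])   # [0, 1, 2]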
+    def create_parameter_widget(self, parent, provider_name, param, config, settings, row):
+        """Create a widget for a specific parameter."""
+        # Label
+        ttk.Label(parent, text=param.replace("_", " ").title() + ":").grid(row=row, column=0, sticky="w", padx=(5, 10), pady=2)
+
+        # Get current value
+        current_value = settings.get(param, "")
+
+        # Create appropriate widget based on type
+        if config["type"] == "scale":
+            var = tk.DoubleVar(value=float(current_value) if current_value else config["range"][0])
+            scale = ttk.Scale(parent, from_=config["range"][0], to=config["range"][1],
+                              variable=var, orient=tk.HORIZONTAL, length=200)
+            scale.grid(row=row, column=1, sticky="ew", padx=(0, 10), pady=2)
+
+            # Value label
+            value_label = ttk.Label(parent, text=f"{var.get():.2f}")
+            value_label.grid(row=row, column=2, padx=(0, 5), pady=2)
+
+            # Update label when scale changes
+            def update_label(*args):
+                value_label.config(text=f"{var.get():.2f}")
+                self.on_setting_change(provider_name)
+
+            var.trace_add("write", update_label)
+
+        elif config["type"] == "combo":
+            var = tk.StringVar(value=current_value)
+            combo = ttk.Combobox(parent, textvariable=var, values=config["values"], width=20)
+            combo.grid(row=row, column=1, sticky="ew", padx=(0, 10), pady=2)
+            var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+
+        elif config["type"] == "checkbox":
+            # Convert string values to boolean for checkbox
+            if isinstance(current_value, str):
+                checkbox_value = current_value.lower() in ('true', '1', 'yes', 'on')
+            else:
+                checkbox_value = bool(current_value)
+
+            var = tk.BooleanVar(value=checkbox_value)
+            checkbox = ttk.Checkbutton(parent, variable=var)
+            checkbox.grid(row=row, column=1, sticky="w", padx=(0, 10), pady=2)
+            var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+
+        else:  # entry
+            var = tk.StringVar(value=current_value)
+            entry = ttk.Entry(parent, textvariable=var, width=30)
+            entry.grid(row=row, column=1, sticky="ew", padx=(0, 10), pady=2)
+            var.trace_add("write", lambda *args: self.on_setting_change(provider_name))
+
+        # Store widget reference
+        self.ai_widgets[provider_name][param] = var
+
+        # Tooltip
+        if "tip" in config:
+            self.create_tooltip(parent.grid_slaves(row=row, column=1)[0], config["tip"])
+
+        # Configure column weights
+        parent.columnconfigure(1, weight=1)
+
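create_parameter_widget dispatches on config["type"] and reads "range", "values", "tab", and "tip" as needed; the actual per-provider schema comes from _get_ai_params_config, which sits outside this hunk. A hypothetical config entry consistent with those keys (illustrative only, not the package's real defaults):

    # Hypothetical example; the real values come from _get_ai_params_config().
    params_config = {
        "temperature": {
            "type": "scale", "range": (0.0, 2.0),
            "tab": "sampling", "tip": "Higher values make output more random.",
        },
        "response_format": {
            "type": "combo", "values": ["text", "json_object"],
            "tab": "general", "tip": "Preferred response format.",
        },
        "stream": {"type": "checkbox", "tab": "general", "tip": "Stream tokens as they arrive."},
        "stop": {"type": "entry", "tab": "sampling", "tip": "Optional stop sequence."},
    }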
+    def create_tooltip(self, widget, text):
+        """Create a tooltip for a widget with proper delay."""
+        tooltip_window = None
+        tooltip_timer = None
+
+        def show_tooltip_delayed():
+            nonlocal tooltip_window
+            if tooltip_window is None:
+                x, y = widget.winfo_rootx() + 25, widget.winfo_rooty() + 25
+                tooltip_window = tk.Toplevel()
+                tooltip_window.wm_overrideredirect(True)
+                tooltip_window.wm_geometry(f"+{x}+{y}")
+
+                label = ttk.Label(tooltip_window, text=text, background="#ffffe0",
+                                  relief="solid", borderwidth=1, wraplength=250)
+                label.pack()
+
+        def on_enter(event):
+            nonlocal tooltip_timer
+            # Cancel any existing timer
+            if tooltip_timer:
+                widget.after_cancel(tooltip_timer)
+            # Start new timer with 750ms delay (standard for most applications)
+            tooltip_timer = widget.after(750, show_tooltip_delayed)
+
+        def on_leave(event):
+            nonlocal tooltip_window, tooltip_timer
+            # Cancel the timer if we leave before tooltip shows
+            if tooltip_timer:
+                widget.after_cancel(tooltip_timer)
+                tooltip_timer = None
+            # Hide tooltip if it's showing
+            if tooltip_window:
+                tooltip_window.destroy()
+                tooltip_window = None
+
+        widget.bind("<Enter>", on_enter)
+        widget.bind("<Leave>", on_leave)
+
+    def on_setting_change(self, provider_name):
+        """Handle setting changes for a provider."""
+        try:
+            # Update settings in parent app
+            if provider_name not in self.app.settings["tool_settings"]:
+                self.app.settings["tool_settings"][provider_name] = {}
+
+            # Collect all widget values first
+            updated_settings = {}
+            for param, widget in self.ai_widgets[provider_name].items():
+                if isinstance(widget, tk.Text):
+                    value = widget.get("1.0", tk.END).strip()
+                else:
+                    value = widget.get()
+
+                # Encrypt sensitive credentials before saving (except for LM Studio)
+                if provider_name != "LM Studio" and param in ["API_KEY", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
+                    if value and value != "putinyourkey":
+                        value = encrypt_api_key(value)
+
+                updated_settings[param] = value
+
+            # Update settings using database manager directly for better reliability
+            if hasattr(self.app, 'db_settings_manager') and self.app.db_settings_manager:
+                # Use database manager's tool setting method for atomic updates
+                for param, value in updated_settings.items():
+                    self.app.db_settings_manager.set_tool_setting(provider_name, param, value)
+
+                self.logger.info(f"Saved {len(updated_settings)} settings for {provider_name} via database manager")
+            else:
+                # Fallback to proxy method
+                for param, value in updated_settings.items():
+                    self.app.settings["tool_settings"][provider_name][param] = value
+
+                # Force save
+                self.app.save_settings()
+                self.logger.info(f"Saved {len(updated_settings)} settings for {provider_name} via proxy")
+
+        except Exception as e:
+            self.logger.error(f"Failed to save settings for {provider_name}: {e}", exc_info=True)
+            # Show user-friendly error
+            self._show_info("Error", f"Failed to save {provider_name} settings: {str(e)}", "error")
+
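on_setting_change runs credential values through encrypt_api_key before they reach the settings store, and the AWS fields are decrypted again via get_aws_credential when redisplayed. The cipher behind encrypt_api_key is not part of this hunk; purely as an illustration of the round trip, here is a sketch using Fernet (an assumption, not the package's actual implementation, and a real app must persist or derive its key rather than regenerate it):

    from cryptography.fernet import Fernet

    key = Fernet.generate_key()  # Illustrative only; a fixed, protected key is required in practice.
    cipher = Fernet(key)

    def encrypt_api_key_demo(value: str) -> str:
        return cipher.encrypt(value.encode("utf-8")).decode("ascii")

    def decrypt_api_key_demo(token: str) -> str:
        return cipher.decrypt(token.encode("ascii")).decode("utf-8")

    stored = encrypt_api_key_demo("sk-example")
    assert decrypt_api_key_demo(stored) == "sk-example"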
+    def refresh_lm_studio_models(self, provider_name):
+        """Refresh the model list from LM Studio server."""
+        if provider_name != "LM Studio":
+            return
+
+        base_url = self.ai_widgets[provider_name]["BASE_URL"].get().strip()
+        if not base_url:
+            self._show_error("Error", "Please enter a valid Base URL")
+            return
+
+        try:
+            # Remove trailing slash if present
+            base_url = base_url.rstrip('/')
+            models_url = f"{base_url}/v1/models"
+
+            response = requests.get(models_url, timeout=10)
+            response.raise_for_status()
+
+            data = response.json()
+            models = [model["id"] for model in data.get("data", [])]
+
+            if models:
+                # Update the model combobox using stored reference
+                model_combo = self.ai_widgets[provider_name].get("MODEL_COMBO")
+                if model_combo:
+                    model_combo.configure(values=models)
+                    # Set first model as default if no model is currently selected
+                    if models and not self.ai_widgets[provider_name]["MODEL"].get():
+                        self.ai_widgets[provider_name]["MODEL"].set(models[0])
+
+                # Update settings
+                self.app.settings["tool_settings"][provider_name]["MODELS_LIST"] = models
+                self.app.save_settings()
+
+                self._show_info("Success", f"Found {len(models)} models from LM Studio")
+            else:
+                self._show_warning("Warning", "No models found. Make sure LM Studio is running and has models loaded.")
+
+        except requests.exceptions.RequestException as e:
+            self._show_error("Connection Error", f"Could not connect to LM Studio at {base_url}\n\nError: {e}\n\nMake sure LM Studio is running and the Base URL is correct.")
+        except Exception as e:
+            self._show_error("Error", f"Error refreshing models: {e}")
+
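The LM Studio refresh depends only on the OpenAI-compatible GET /v1/models route that LM Studio's local server exposes. A standalone connectivity check, assuming the default local address:

    import requests

    base_url = "http://127.0.0.1:1234"  # LM Studio default; adjust if you changed it.
    resp = requests.get(f"{base_url}/v1/models", timeout=10)
    resp.raise_for_status()
    for model in resp.json().get("data", []):
        print(model["id"])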
+    def refresh_google_ai_models(self, provider_name):
+        """Refresh the model list from Google AI (Gemini) API."""
+        if provider_name != "Google AI":
+            return
+
+        settings = self.get_current_settings()
+        api_key = self.get_api_key_for_provider(provider_name, settings)
+
+        if not api_key or api_key == "putinyourkey":
+            self._show_error("Error", "Please enter a valid Google AI API key first")
+            return
+
+        try:
+            # Google AI models endpoint
+            models_url = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
+
+            response = requests.get(models_url, timeout=30)
+            response.raise_for_status()
+
+            data = response.json()
+            models = []
+
+            # Filter for generative models (not embedding models)
+            for model in data.get("models", []):
+                model_name = model.get("name", "")
+                # Remove "models/" prefix
+                if model_name.startswith("models/"):
+                    model_name = model_name[7:]
+
+                # Filter for text generation models (gemini models)
+                supported_methods = model.get("supportedGenerationMethods", [])
+                if "generateContent" in supported_methods:
+                    models.append(model_name)
+
+            if models:
+                # Sort models (prefer newer versions)
+                models.sort(reverse=True)
+
+                # Update the model combobox
+                if provider_name in self.ai_widgets and "MODEL" in self.ai_widgets[provider_name]:
+                    # Find and update the combobox in the tab
+                    for provider, tab_frame in self.tabs.items():
+                        if provider == provider_name:
+                            # Update the model variable and refresh the UI
+                            self.ai_widgets[provider_name]["MODEL"].set(models[0] if models else "")
+                            for widget in tab_frame.winfo_children():
+                                widget.destroy()
+                            self.create_provider_widgets(tab_frame, provider_name)
+                            break
+
+                # Update settings
+                self.app.settings["tool_settings"][provider_name]["MODELS_LIST"] = models
+                self.app.settings["tool_settings"][provider_name]["MODEL"] = models[0] if models else ""
+                self.app.save_settings()
+
+                self._show_info("Success", f"Found {len(models)} generative models from Google AI")
+            else:
+                self._show_warning("Warning", "No generative models found. Check your API key permissions.")
+
+        except requests.exceptions.RequestException as e:
+            error_msg = f"Could not connect to Google AI API\n\nError: {e}"
+            if hasattr(e, 'response') and e.response is not None:
+                try:
+                    error_detail = e.response.json()
+                    error_msg += f"\n\nDetails: {json.dumps(error_detail, indent=2)}"
+                except Exception:
+                    error_msg += f"\n\nResponse: {e.response.text}"
+            self._show_error("Connection Error", error_msg)
+        except Exception as e:
+            self._show_error("Error", f"Error refreshing Google AI models: {e}")
+
+    def refresh_openrouter_models(self, provider_name):
+        """Refresh the model list from OpenRouter API."""
+        if provider_name != "OpenRouterAI":
+            return
+
+        settings = self.get_current_settings()
+        api_key = self.get_api_key_for_provider(provider_name, settings)
+
+        # OpenRouter models endpoint is public, but API key is recommended
+        try:
+            headers = {"Content-Type": "application/json"}
+            if api_key and api_key != "putinyourkey":
+                headers["Authorization"] = f"Bearer {api_key}"
+
+            # OpenRouter models endpoint
+            models_url = "https://openrouter.ai/api/v1/models"
+
+            response = requests.get(models_url, headers=headers, timeout=30)
+            response.raise_for_status()
+
+            data = response.json()
+            models = []
+
+            # Extract model IDs from the response
+            for model in data.get("data", []):
+                model_id = model.get("id", "")
+                if model_id:
+                    models.append(model_id)
+
+            if models:
+                # Sort models alphabetically
+                models.sort()
+
+                # Update the model combobox
+                if provider_name in self.ai_widgets and "MODEL" in self.ai_widgets[provider_name]:
+                    # Find and update the combobox in the tab
+                    for provider, tab_frame in self.tabs.items():
+                        if provider == provider_name:
+                            # Update the model variable and refresh the UI
+                            current_model = self.ai_widgets[provider_name]["MODEL"].get()
+                            if not current_model or current_model not in models:
+                                self.ai_widgets[provider_name]["MODEL"].set(models[0] if models else "")
+                            for widget in tab_frame.winfo_children():
+                                widget.destroy()
+                            self.create_provider_widgets(tab_frame, provider_name)
+                            break
+
+                # Update settings
+                self.app.settings["tool_settings"][provider_name]["MODELS_LIST"] = models
+                if not self.app.settings["tool_settings"].get(provider_name, {}).get("MODEL"):
+                    self.app.settings["tool_settings"][provider_name]["MODEL"] = models[0] if models else ""
+                self.app.save_settings()
+
+                self._show_info("Success", f"Found {len(models)} models from OpenRouter")
+            else:
+                self._show_warning("Warning", "No models found from OpenRouter.")
+
+        except requests.exceptions.RequestException as e:
+            error_msg = f"Could not connect to OpenRouter API\n\nError: {e}"
+            if hasattr(e, 'response') and e.response is not None:
+                try:
+                    error_detail = e.response.json()
+                    error_msg += f"\n\nDetails: {json.dumps(error_detail, indent=2)}"
+                except Exception:
+                    error_msg += f"\n\nResponse: {e.response.text}"
+            self._show_error("Connection Error", error_msg)
+        except Exception as e:
+            self._show_error("Error", f"Error refreshing OpenRouter models: {e}")
+
+    def refresh_bedrock_models(self, provider_name):
+        """Refresh the model list from AWS Bedrock ListFoundationModels API."""
+        if provider_name != "AWS Bedrock":
+            return
+
+        settings = self.app.settings["tool_settings"].get(provider_name, {})
+        auth_method = settings.get("AUTH_METHOD", "api_key")
+        region = settings.get("AWS_REGION", "us-west-2")
+
+        # AWS Bedrock ListFoundationModels API requires AWS IAM credentials
+        access_key = self.get_aws_credential(settings, "AWS_ACCESS_KEY_ID")
+        secret_key = self.get_aws_credential(settings, "AWS_SECRET_ACCESS_KEY")
+
+        if not access_key or not secret_key:
+            self._show_error("Error", "Please enter your AWS IAM credentials (Access Key ID and Secret Access Key) first")
+            return
+
+        try:
+            # Build ListFoundationModels API URL
+            list_models_url = f"https://bedrock.{region}.amazonaws.com/foundation-models"
+
+            # Always use AWS SigV4 signing for ListFoundationModels API
+            session_token = self.get_aws_credential(settings, "AWS_SESSION_TOKEN") if auth_method == "sessionToken" else None
+
+            # Sign the request (GET method, empty payload)
+            signed_headers = self.sign_aws_request(
+                "GET", list_models_url, "", access_key, secret_key,
+                session_token, region, "bedrock"
+            )
+
+            # Make the API request with signed headers
+            response = requests.get(list_models_url, headers=signed_headers, timeout=30)
+            response.raise_for_status()
+
+            data = response.json()
+
+            # Extract model IDs from the response, filtering out embedding and image models
+            models = []
+            if "modelSummaries" in data:
+                for model in data["modelSummaries"]:
+                    model_id = model.get("modelId", "")
+                    model_name = model.get("modelName", "")
+
+                    # Filter out embedding models and image generation models
+                    # Embedding models: contain "embed" in ID or name
+                    # Image models: contain "image", "stable-diffusion", "titan-image", "nova-canvas", "nova-reel"
+                    if model_id and not any(keyword in model_id.lower() for keyword in [
+                        "embed", "embedding", "image", "stable-diffusion",
+                        "titan-image", "nova-canvas", "nova-reel", "nova-sonic"
+                    ]):
+                        # Also check model name for additional filtering
+                        if not any(keyword in model_name.lower() for keyword in [
+                            "embed", "embedding", "image", "vision"
+                        ]):
+                            models.append(model_id)
+
+            if models:
+                # Add inference profile versions for models that require them
+                enhanced_models = []
+                inference_profile_mapping = {
+                    # Claude 4.5 models (newest)
+                    "anthropic.claude-haiku-4-5-20251001-v1:0": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
+                    "anthropic.claude-sonnet-4-5-20250929-v1:0": "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
+
+                    # Claude 4.1 models
+                    "anthropic.claude-opus-4-1-20250805-v1:0": "us.anthropic.claude-opus-4-1-20250805-v1:0",
+
+                    # Claude 3.7 models
+                    "anthropic.claude-3-7-sonnet-20250219-v1:0": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+
+                    # Claude 3.5 models (v2)
+                    "anthropic.claude-3-5-haiku-20241022-v1:0": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+                    "anthropic.claude-3-5-sonnet-20241022-v2:0": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+
+                    # Claude 3.5 models (v1)
+                    "anthropic.claude-3-5-sonnet-20240620-v1:0": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+
+                    # Claude 3 models (original)
+                    "anthropic.claude-3-opus-20240229-v1:0": "us.anthropic.claude-3-opus-20240229-v1:0",
+                    "anthropic.claude-3-sonnet-20240229-v1:0": "us.anthropic.claude-3-sonnet-20240229-v1:0",
+                    "anthropic.claude-3-haiku-20240307-v1:0": "us.anthropic.claude-3-haiku-20240307-v1:0"
+                }
+
+                for model_id in models:
+                    enhanced_models.append(model_id)
+                    # If this model has an inference profile, add it as an option too
+                    if model_id in inference_profile_mapping:
+                        profile_id = inference_profile_mapping[model_id]
+                        enhanced_models.append(f"{profile_id} (Inference Profile)")
+
+                # Update the model combobox
+                model_combo = self.ai_widgets[provider_name].get("MODEL_COMBO")
+                if model_combo:
+                    model_combo.configure(values=enhanced_models)
+                    # Set first model as default if no model is currently selected
+                    if enhanced_models and not self.ai_widgets[provider_name]["MODEL"].get():
+                        self.ai_widgets[provider_name]["MODEL"].set(enhanced_models[0])
+
+                # Update settings (store the enhanced list)
+                self.app.settings["tool_settings"][provider_name]["MODELS_LIST"] = enhanced_models
+                self.app.save_settings()
+
+                profile_count = len([m for m in enhanced_models if "Inference Profile" in m])
+                self._show_info("Success", f"Found {len(models)} models from AWS Bedrock ({profile_count} with inference profiles)")
+            else:
+                self._show_warning("Warning", "No models found. Please check your credentials and region.")
+
+        except requests.exceptions.RequestException as e:
+            error_msg = f"Could not connect to AWS Bedrock API\n\nError: {e}"
+            if hasattr(e, 'response') and e.response is not None:
+                try:
+                    error_data = e.response.json()
+                    if "message" in error_data:
+                        error_msg += f"\n\nAWS Error: {error_data['message']}"
+                except Exception:
+                    error_msg += f"\n\nHTTP {e.response.status_code}: {e.response.text}"
+            self._show_error("Connection Error", error_msg)
+        except Exception as e:
+            self._show_error("Error", f"Error refreshing models: {e}")
+
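The keyword filter above is the part of the refresh most worth unit-testing; extracted as a pure predicate with the same keyword lists, it can be exercised without any AWS call:

    EXCLUDED_ID_KEYWORDS = ("embed", "embedding", "image", "stable-diffusion",
                            "titan-image", "nova-canvas", "nova-reel", "nova-sonic")
    EXCLUDED_NAME_KEYWORDS = ("embed", "embedding", "image", "vision")

    def is_text_generation_model(model_id: str, model_name: str) -> bool:
        """Mirror of the inline filter: keep only text-generation model IDs."""
        if not model_id:
            return False
        if any(k in model_id.lower() for k in EXCLUDED_ID_KEYWORDS):
            return False
        return not any(k in model_name.lower() for k in EXCLUDED_NAME_KEYWORDS)

    assert is_text_generation_model("anthropic.claude-3-haiku-20240307-v1:0", "Claude 3 Haiku")
    assert not is_text_generation_model("amazon.titan-embed-text-v2:0", "Titan Text Embeddings V2")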
+    def update_aws_credentials_fields(self, provider_name):
+        """Update AWS credentials field visibility based on authentication method."""
+        if provider_name != "AWS Bedrock" or not hasattr(self, 'aws_creds_frame'):
+            self.logger.debug(f"Skipping AWS credentials field update: provider={provider_name}, has_frame={hasattr(self, 'aws_creds_frame')}")
+            return
+
+        # Get the stored value from settings
+        stored_auth = self.app.settings["tool_settings"].get(provider_name, {}).get("AUTH_METHOD", "api_key")
+        self.logger.debug(f"AWS Bedrock auth method: {stored_auth}")
+
+        # Hide all credential fields first
+        fields_to_hide = ['api_key_row', 'access_key_row', 'secret_key_row', 'session_token_row', 'iam_role_info_frame']
+        for field_name in fields_to_hide:
+            if hasattr(self, field_name):
+                try:
+                    getattr(self, field_name).pack_forget()
+                except Exception as e:
+                    self.logger.debug(f"Error hiding {field_name}: {e}")
+
+        # Show fields based on authentication method
+        try:
+            if stored_auth == "api_key":  # API Key (Bearer Token)
+                if hasattr(self, 'api_key_row'):
+                    self.api_key_row.pack(fill=tk.X, padx=5, pady=2)
+                    self.logger.debug("Showing API key field")
+                else:
+                    self.logger.warning("API key row not found!")
+            elif stored_auth == "iam":  # IAM (Explicit Credentials)
+                if hasattr(self, 'access_key_row') and hasattr(self, 'secret_key_row'):
+                    self.access_key_row.pack(fill=tk.X, padx=5, pady=2)
+                    self.secret_key_row.pack(fill=tk.X, padx=5, pady=2)
+                    self.logger.debug("Showing IAM credential fields")
+            elif stored_auth == "sessionToken":  # Session Token (Temporary Credentials)
+                if hasattr(self, 'access_key_row') and hasattr(self, 'secret_key_row') and hasattr(self, 'session_token_row'):
+                    self.access_key_row.pack(fill=tk.X, padx=5, pady=2)
+                    self.secret_key_row.pack(fill=tk.X, padx=5, pady=2)
+                    self.session_token_row.pack(fill=tk.X, padx=5, pady=2)
+                    self.logger.debug("Showing session token credential fields")
+            elif stored_auth == "iam_role":  # IAM (Implied Credentials)
+                if hasattr(self, 'iam_role_info_frame'):
+                    self.iam_role_info_frame.pack(fill=tk.X, padx=5, pady=5)
+                    self.logger.debug("Showing IAM role info")
+            else:
+                self.logger.warning(f"Unknown auth method: {stored_auth}, defaulting to API key")
+                if hasattr(self, 'api_key_row'):
+                    self.api_key_row.pack(fill=tk.X, padx=5, pady=2)
+        except Exception as e:
+            self.logger.error(f"Error updating AWS credentials fields: {e}", exc_info=True)
+
+    def on_aws_auth_change(self, provider_name):
+        """Handle AWS authentication method change and convert display name to stored value."""
+        if provider_name != "AWS Bedrock":
+            return
+
+        display_value = self.ai_widgets[provider_name]["AUTH_METHOD"].get()
+
+        # Convert display name to stored value
+        if display_value == "API Key (Bearer Token)":
+            stored_value = "api_key"
+        elif display_value == "IAM (Explicit Credentials)":
+            stored_value = "iam"
+        elif display_value == "Session Token (Temporary Credentials)":
+            stored_value = "sessionToken"
+        elif display_value == "IAM (Implied Credentials)":
+            stored_value = "iam_role"
+        else:
+            stored_value = "api_key"  # default
+
+        # Update settings with the stored value
+        if provider_name not in self.app.settings["tool_settings"]:
+            self.app.settings["tool_settings"][provider_name] = {}
+
+        self.app.settings["tool_settings"][provider_name]["AUTH_METHOD"] = stored_value
+        self.app.save_settings()
+
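The display-name conversion above is a fixed four-way mapping, so a lookup table is an equivalent formulation; a sketch of that alternative, with dict.get supplying the same "api_key" default:

    AUTH_METHOD_BY_DISPLAY = {
        "API Key (Bearer Token)": "api_key",
        "IAM (Explicit Credentials)": "iam",
        "Session Token (Temporary Credentials)": "sessionToken",
        "IAM (Implied Credentials)": "iam_role",
    }

    def to_stored_auth(display_value: str) -> str:
        return AUTH_METHOD_BY_DISPLAY.get(display_value, "api_key")

    print(to_stored_auth("IAM (Explicit Credentials)"))  # iam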
+    def sign_aws_request(self, method, url, payload, access_key, secret_key, session_token=None, region="us-west-2", service="bedrock"):
+        """Sign AWS request using Signature Version 4."""
+        try:
+            # Parse URL
+            parsed_url = urllib.parse.urlparse(url)
+            host = parsed_url.netloc
+            path = parsed_url.path
+
+            # Create timestamp
+            t = datetime.utcnow()
+            amz_date = t.strftime('%Y%m%dT%H%M%SZ')
+            date_stamp = t.strftime('%Y%m%d')
+
+            # Create canonical request
+            canonical_uri = path
+            canonical_querystring = ''
+            canonical_headers = f'host:{host}\nx-amz-date:{amz_date}\n'
+            signed_headers = 'host;x-amz-date'
+
+            if session_token:
+                canonical_headers += f'x-amz-security-token:{session_token}\n'
+                signed_headers += ';x-amz-security-token'
+
+            payload_hash = hashlib.sha256(payload.encode('utf-8')).hexdigest()
+            canonical_request = f'{method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{payload_hash}'
+
+            # Create string to sign
+            algorithm = 'AWS4-HMAC-SHA256'
+            credential_scope = f'{date_stamp}/{region}/{service}/aws4_request'
+            string_to_sign = f'{algorithm}\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}'
+
+            # Calculate signature
+            def sign(key, msg):
+                return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
+
+            def get_signature_key(key, date_stamp, region_name, service_name):
+                k_date = sign(('AWS4' + key).encode('utf-8'), date_stamp)
+                k_region = sign(k_date, region_name)
+                k_service = sign(k_region, service_name)
+                k_signing = sign(k_service, 'aws4_request')
+                return k_signing
+
+            signing_key = get_signature_key(secret_key, date_stamp, region, service)
+            signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()
+
+            # Create authorization header
+            authorization_header = f'{algorithm} Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}'
+
+            # Build headers
+            headers = {
+                'Content-Type': 'application/json',
+                'X-Amz-Date': amz_date,
+                'Authorization': authorization_header,
+                'X-Amz-Content-Sha256': payload_hash
+            }
+
+            if session_token:
+                headers['X-Amz-Security-Token'] = session_token
+
+            return headers
+
+        except Exception as e:
+            self.logger.error(f"Error signing AWS request: {e}")
+            return {}
+
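SigV4 derives the signing key as an HMAC chain over the date, region, service, and the literal "aws4_request", which is also why the canonical headers above stay in alphabetical order (host, x-amz-date, then x-amz-security-token). The derivation in isolation, with a placeholder secret:

    import hashlib
    import hmac

    def _hmac(key: bytes, msg: str) -> bytes:
        return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

    def derive_sigv4_key(secret_key: str, date_stamp: str, region: str, service: str) -> bytes:
        k_date = _hmac(("AWS4" + secret_key).encode("utf-8"), date_stamp)
        k_region = _hmac(k_date, region)
        k_service = _hmac(k_region, service)
        return _hmac(k_service, "aws4_request")

    # Placeholder secret, not a real credential.
    print(derive_sigv4_key("EXAMPLESECRETKEY", "20250101", "us-west-2", "bedrock").hex())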
+    def get_current_provider(self):
+        """Get the currently selected provider."""
+        return self.current_provider
+
+    def get_current_settings(self):
+        """Get settings for the current provider."""
+        return self.app.settings["tool_settings"].get(self.current_provider, {})
+
+    def run_ai_in_thread(self):
+        """Start AI processing in a separate thread."""
+        if hasattr(self, '_ai_thread') and self._ai_thread and self._ai_thread.is_alive():
+            return
+
+        self.app.update_output_text("Generating response from AI...")
+        self._ai_thread = threading.Thread(target=self.process_ai_request, daemon=True)
+        self._ai_thread.start()
+
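run_ai_in_thread exists because Tkinter widgets may only be touched from the thread that created them: the worker does the slow network call and marshals every UI update back through after(...), never calling widget methods directly. The pattern in miniature:

    import threading
    import time
    import tkinter as tk

    root = tk.Tk()
    label = tk.Label(root, text="idle")
    label.pack(padx=20, pady=20)

    def worker():
        time.sleep(1)  # Stand-in for a slow API call.
        # Schedule the widget update on the Tk main loop instead of calling it here.
        root.after(0, lambda: label.config(text="done"))

    threading.Thread(target=worker, daemon=True).start()
    root.mainloop()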
+    def process_ai_request(self):
+        """Process the AI request."""
+        provider_name = self.current_provider
+        settings = self.get_current_settings()
+        api_key = self.get_api_key_for_provider(provider_name, settings)
+
+        # Get input text from parent app
+        active_input_tab = self.app.input_tabs[self.app.input_notebook.index(self.app.input_notebook.select())]
+        prompt = active_input_tab.text.get("1.0", tk.END).strip()
+
+        # Validate Vertex AI credentials
+        if provider_name == "Vertex AI":
+            credentials = self.get_vertex_ai_credentials()
+            if not credentials:
+                self.app.after(0, self.app.update_output_text, "Error: Vertex AI requires a service account JSON file. Please upload it using the 'Upload JSON' button.")
+                return
+            project_id = settings.get("PROJECT_ID")
+            if not project_id:
+                self.app.after(0, self.app.update_output_text, "Error: Project ID not found. Please upload the service account JSON file.")
+                return
+
+        # LM Studio doesn't require an API key; AWS Bedrock has multiple auth methods
+        if provider_name == "AWS Bedrock":
+            # Validate AWS Bedrock credentials
+            auth_method = settings.get("AUTH_METHOD", "api_key")
+
+            # Handle both display names and internal values for backward compatibility
+            is_api_key_auth = auth_method in ["api_key", "API Key (Bearer Token)"]
+            is_iam_auth = auth_method in ["iam", "IAM (Explicit Credentials)"]
+            is_session_token_auth = auth_method in ["sessionToken", "Session Token (Temporary Credentials)"]
+
+            if is_api_key_auth:
+                api_key = self.get_api_key_for_provider(provider_name, settings)
+                if not api_key or api_key == "putinyourkey":
+                    self.app.after(0, self.app.update_output_text, "Error: AWS Bedrock requires an API Key. Please enter your AWS Bedrock API Key.")
+                    return
+            elif is_iam_auth or is_session_token_auth:
+                access_key = self.get_aws_credential(settings, "AWS_ACCESS_KEY_ID")
+                secret_key = self.get_aws_credential(settings, "AWS_SECRET_ACCESS_KEY")
+                if not access_key or not secret_key:
+                    self.app.after(0, self.app.update_output_text, "Error: AWS Bedrock requires an Access Key ID and Secret Access Key.")
+                    return
+                if is_session_token_auth:
+                    session_token = self.get_aws_credential(settings, "AWS_SESSION_TOKEN")
+                    if not session_token:
+                        self.app.after(0, self.app.update_output_text, "Error: AWS Bedrock requires a Session Token for temporary credentials.")
+                        return
+        elif provider_name == "Azure AI":
+            # Validate Azure AI credentials
+            endpoint = settings.get("ENDPOINT", "").strip()
+            if not endpoint:
+                self.app.after(0, self.app.update_output_text, "Error: Azure AI requires a Resource Endpoint. Please enter your endpoint URL.")
+                return
+            if not api_key or api_key == "putinyourkey":
+                self.app.after(0, self.app.update_output_text, "Error: Azure AI requires an API Key. Please enter your API key.")
+                return
+        elif provider_name not in ["LM Studio", "Vertex AI"] and (not api_key or api_key == "putinyourkey"):
+            self.app.after(0, self.app.update_output_text, f"Error: Please enter a valid {provider_name} API Key in the settings.")
+            return
+        if not prompt:
+            self.app.after(0, self.app.update_output_text, "Error: Input text cannot be empty.")
+            return
+
+        self.logger.info(f"Submitting prompt to {provider_name} with model {settings.get('MODEL')}")
+
+        # Handle HuggingFace separately (uses different client)
+        if provider_name == "HuggingFace AI":
+            if not api_key or api_key == "putinyourkey":
+                error_msg = "Please configure your HuggingFace API key in the settings."
+                self.logger.warning(error_msg)
+                self.app.after(0, self.app.update_output_text, error_msg)
+                return
+
+            if HUGGINGFACE_HELPER_AVAILABLE:
+                try:
+                    # Use the huggingface_helper module for proper task detection
+                    def update_callback(response):
+                        # Use unified display method (handles streaming automatically)
+                        self.display_ai_response(response)
+
+                    self.logger.debug(f"Calling HuggingFace helper with model: {settings.get('MODEL', 'unknown')}")
+                    process_huggingface_request(api_key, prompt, settings, update_callback, self.logger)
+
+                except Exception as e:
+                    error_msg = f"HuggingFace processing failed: {str(e)}"
+                    self.logger.error(error_msg, exc_info=True)
+                    self.app.after(0, self.app.update_output_text, error_msg)
+            else:
+                error_msg = "HuggingFace helper module not available. Please check your installation."
+                self.logger.error(error_msg)
+                self.app.after(0, self.app.update_output_text, error_msg)
+            return
+
+        # All other providers via REST helper
+        try:
+            if provider_name == "AWS Bedrock":
+                model_id = settings.get("MODEL", "")
+                # Check if it's an embedding or image model
+                if any(keyword in model_id.lower() for keyword in [
+                    "embed", "embedding", "image", "stable-diffusion",
+                    "titan-image", "nova-canvas", "nova-reel", "nova-sonic"
+                ]):
+                    error_msg = (
+                        f"Error: '{model_id}' is not a text generation model.\n\n"
+                        "You've selected an embedding or image model which cannot generate text.\n\n"
+                        "Please select a text generation model such as:\n"
+                        "• amazon.nova-pro-v1:0\n"
+                        "• anthropic.claude-3-5-sonnet-20241022-v2:0\n"
+                        "• meta.llama3-1-70b-instruct-v1:0\n"
+                        "• mistral.mistral-large-2402-v1:0\n\n"
+                        "Use the 'Refresh Models' button to get an updated list of text generation models."
+                    )
+                    self.logger.error(error_msg)
+                    self.app.after(0, self.app.update_output_text, error_msg)
+                    return
+
+            url, payload, headers = self._build_api_request(provider_name, api_key, prompt, settings)
+
+            self.logger.debug(f"{provider_name} payload: {json.dumps(payload, indent=2)}")
+
+            # Log request details for Vertex AI (without sensitive token)
+            if provider_name == "Vertex AI":
+                self.logger.debug(f"Vertex AI Request URL: {url}")
+                safe_headers = {k: ('***REDACTED***' if k == 'Authorization' else v) for k, v in headers.items()}
+                self.logger.debug(f"Vertex AI Headers: {json.dumps(safe_headers, indent=2)}")
+
+            # Check if provider supports streaming and streaming is enabled
+            streaming_providers = ["OpenAI", "Groq AI", "OpenRouterAI", "Azure AI", "Anthropic AI"]
+            use_streaming = (
+                self.is_streaming_enabled() and
+                provider_name in streaming_providers
+            )
+
+            # Retry logic with exponential backoff
+            max_retries = 5
+            base_delay = 1
+
+            for i in range(max_retries):
+                try:
+                    if use_streaming:
+                        # Use streaming API call
+                        self.logger.info(f"Using streaming mode for {provider_name}")
+                        streaming_payload = payload.copy()
+                        streaming_payload["stream"] = True
+
+                        self._call_streaming_api(url, streaming_payload, headers, provider_name)
+                        return
+                    else:
+                        # Non-streaming API call
+                        response = requests.post(url, json=payload, headers=headers, timeout=60)
+                        response.raise_for_status()
+
+                        data = response.json()
+                        self.logger.debug(f"{provider_name} Response: {data}")
+
+                        result_text = self._extract_response_text(provider_name, data)
+                        self.logger.debug(f"FINAL: About to display result_text: {str(result_text)[:100]}...")
+
+                        # Use unified display method (handles streaming automatically)
+                        self.display_ai_response(result_text)
+                        return
+
+                except requests.exceptions.HTTPError as e:
+                    if e.response.status_code == 429 and i < max_retries - 1:
+                        delay = base_delay * (2 ** i) + random.uniform(0, 1)
+                        self.logger.warning(f"Rate limit exceeded. Retrying in {delay:.2f} seconds...")
+                        time.sleep(delay)
+                    else:
+                        # Get full error response
+                        try:
+                            error_response = e.response.text if hasattr(e, 'response') and e.response else str(e)
+                            error_json = e.response.json() if hasattr(e, 'response') and e.response and e.response.headers.get('content-type', '').startswith('application/json') else None
+                        except Exception:
+                            error_response = str(e)
+                            error_json = None
+
+                        # Log detailed error for Vertex AI
+                        if provider_name == "Vertex AI":
+                            self.logger.error(f"Vertex AI API Error - Status: {e.response.status_code if hasattr(e, 'response') and e.response else 'N/A'}")
+                            self.logger.error(f"Vertex AI Error Response: {error_response}")
+                            if error_json:
+                                self.logger.error(f"Vertex AI Error JSON: {json.dumps(error_json, indent=2)}")
+                            self.logger.error(f"Vertex AI Request URL: {url}")
+                            self.logger.debug(f"Vertex AI Headers (token redacted): {[(k, '***REDACTED***' if k == 'Authorization' else v) for k, v in headers.items()]}")
+
+                            # Provide helpful error message
+                            if e.response.status_code == 403:
+                                error_msg = "Vertex AI 403 Forbidden Error\n\n"
+                                error_msg += f"URL: {url}\n\n"
+                                if error_json:
+                                    error_msg += f"Error Details: {json.dumps(error_json, indent=2)}\n\n"
+                                else:
+                                    error_msg += f"Error Response: {error_response}\n\n"
+                                error_msg += "Common causes:\n"
+                                error_msg += "1. Service account doesn't have 'Vertex AI User' role\n"
+                                error_msg += "2. Vertex AI API not enabled for the project\n"
+                                error_msg += "3. Project ID format incorrect (check for encoding issues)\n"
+                                error_msg += "4. Model name not available in the selected region\n"
+                                error_msg += "5. Billing not enabled for the project\n\n"
+                                error_msg += "Solutions:\n"
+                                error_msg += "1. Enable Vertex AI API in Google Cloud Console\n"
+                                error_msg += "2. Grant 'Vertex AI User' role to service account\n"
+                                error_msg += "3. Ensure billing is enabled\n"
+                                error_msg += "4. Verify model name is correct (try gemini-1.5-flash or gemini-1.5-pro)\n"
+
+                                self.app.after(0, self.app.update_output_text, error_msg)
+                                return
+
+                        # Check for AWS Bedrock specific errors
+                        if provider_name == "AWS Bedrock":
+                            model_id = settings.get("MODEL", "unknown")
+                            auth_method = settings.get("AUTH_METHOD", "api_key")
+
+                            if e.response.status_code == 403:
+                                error_msg = "AWS Bedrock 403 Forbidden Error\n\n"
+                                error_msg += f"Model: {model_id}\n"
+                                error_msg += f"Auth Method: {auth_method}\n\n"
+                                error_msg += "This error typically means:\n"
+                                error_msg += "1. Your credentials don't have permission to access this model\n"
+                                error_msg += "2. The model is not enabled in your AWS account\n"
+                                error_msg += "3. The model is not available in your selected region\n"
+                                error_msg += "4. Your API key may be invalid or expired\n\n"
+                                error_msg += "Solutions:\n"
+                                error_msg += "1. Go to AWS Bedrock Console and enable model access\n"
+                                error_msg += "2. Verify your IAM permissions include 'bedrock:InvokeModel'\n"
+                                error_msg += "3. Try a different model (e.g., amazon.nova-lite-v1:0)\n"
+                                error_msg += "4. Try a different region (us-east-1, us-west-2)\n"
+                                error_msg += "5. If using API Key auth, try IAM credentials instead\n\n"
+                                error_msg += f"Original error: {error_response}"
+
+                                self.logger.error(error_msg)
+                                self.app.after(0, self.app.update_output_text, error_msg)
+                                return
+
+                            if "on-demand throughput isn't supported" in error_response:
+                                error_msg = f"AWS Bedrock Model Error: {model_id}\n\n"
+                                error_msg += "This model requires an inference profile instead of a direct model ID.\n\n"
+                                error_msg += "Solutions:\n"
+                                error_msg += "1. Use the 'Refresh Models' button to get an updated model list with inference profiles\n"
+                                error_msg += "2. Manually update the model ID with a regional prefix:\n"
+                                error_msg += f"   • US: us.{model_id}\n"
+                                error_msg += f"   • EU: eu.{model_id}\n"
+                                error_msg += f"   • APAC: apac.{model_id}\n"
+                                error_msg += "3. For Claude Sonnet 4.5, use the global profile: global.anthropic.claude-sonnet-4-5-20250929-v1:0\n\n"
+                                error_msg += f"Original error: {error_response}"
+
+                                self.logger.error(error_msg)
+                                self.app.after(0, self.app.update_output_text, error_msg)
+                            elif e.response.status_code == 400 and any(provider in model_id for provider in ["openai.", "qwen.", "twelvelabs."]):
+                                error_msg = f"AWS Bedrock Model Error: {model_id}\n\n"
+                                error_msg += "This third-party model may not be properly configured or available in your region.\n\n"
+                                error_msg += "Common issues:\n"
+                                error_msg += "1. Model may not be available in your selected region\n"
+                                error_msg += "2. Model may require special access or subscription\n"
+                                error_msg += "3. Model may have been deprecated or renamed\n"
+                                error_msg += "4. Payload format may not be compatible\n\n"
+                                error_msg += "Solutions:\n"
+                                error_msg += "1. Try a different region (us-east-1, us-west-2, eu-west-1)\n"
+                                error_msg += "2. Use 'Refresh Models' to get current available models\n"
+                                error_msg += "3. Try a similar model from Amazon, Anthropic, or Meta instead\n\n"
+                                error_msg += f"Original error: {error_response}"
+
+                                self.logger.error(error_msg)
+                                self.app.after(0, self.app.update_output_text, error_msg)
+                            elif e.response.status_code == 404:
+                                error_msg = f"AWS Bedrock Model Not Found: {model_id}\n\n"
+                                error_msg += "This model is not available or the model ID is incorrect.\n\n"
+                                error_msg += "Solutions:\n"
+                                error_msg += "1. Use the 'Refresh Models' button to get current available models\n"
+                                error_msg += "2. Check if the model ID has suffixes that need to be removed\n"
+                                error_msg += "3. Verify the model is available in your selected region\n"
+                                error_msg += "4. Try a similar model that's confirmed to be available\n\n"
+                                error_msg += f"Original error: {error_response}"
+
+                                self.logger.error(error_msg)
+                                self.app.after(0, self.app.update_output_text, error_msg)
+                            else:
+                                self.logger.error(f"AWS Bedrock API Request Error: {e}\nResponse: {error_response}")
+                                self.app.after(0, self.app.update_output_text, f"AWS Bedrock API Request Error: {e}\nResponse: {error_response}")
+                        else:
+                            self.logger.error(f"API Request Error: {e}\nResponse: {error_response}")
+                            self.app.after(0, self.app.update_output_text, f"API Request Error: {e}\nResponse: {error_response}")
+                        return
+                except requests.exceptions.RequestException as e:
+                    self.logger.error(f"Network Error: {e}")
+                    self.app.after(0, self.app.update_output_text, f"Network Error: {e}")
+                    return
+                except (KeyError, IndexError, json.JSONDecodeError) as e:
+                    self.logger.error(f"Error parsing AI response: {e}\n\nResponse:\n{response.text if 'response' in locals() else 'N/A'}")
+                    self.app.after(0, self.app.update_output_text, f"Error parsing AI response: {e}\n\nResponse:\n{response.text if 'response' in locals() else 'N/A'}")
+                    return
+
+            self.app.after(0, self.app.update_output_text, "Error: Max retries exceeded. The API is still busy.")
+
+        except Exception as e:
+            self.logger.error(f"Error configuring API for {provider_name}: {e}")
+            self.app.after(0, self.app.update_output_text, f"Error configuring API request: {e}")
+
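The retry loop in process_ai_request sleeps base_delay * 2**attempt plus up to one second of random jitter, and only on HTTP 429, so the five attempts wait roughly 1-2 s, 2-3 s, 4-5 s, and 8-9 s before giving up. The schedule in isolation:

    import random

    def backoff_delays(max_retries: int = 5, base_delay: float = 1.0):
        """Delays matching the retry loop: exponential growth plus 0-1 s jitter."""
        return [base_delay * (2 ** i) + random.uniform(0, 1) for i in range(max_retries - 1)]

    print([round(d, 2) for d in backoff_delays()])

The jitter term spreads retries from concurrent clients so they don't hammer the endpoint in lockstep.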
def _build_api_request(self, provider_name, api_key, prompt, settings):
|
|
1924
|
+
"""Build API request URL, payload, and headers."""
|
|
1925
|
+
provider_config = self.ai_providers[provider_name]
|
|
1926
|
+
|
|
1927
|
+
# Build URL
|
|
1928
|
+
if provider_name == "Vertex AI":
|
|
1929
|
+
# Get project_id and location from settings
|
|
1930
|
+
project_id = settings.get("PROJECT_ID", "")
|
|
1931
|
+
location = settings.get("LOCATION", "us-central1")
|
|
1932
|
+
model = settings.get("MODEL", "")
|
|
1933
|
+
|
|
1934
|
+
# Note: Project IDs in Google Cloud REST API URLs should be used as-is
|
|
1935
|
+
# If project_id contains colons (like project numbers), they're part of the format
|
|
1936
|
+
url = provider_config["url_template"].format(
|
|
1937
|
+
location=location,
|
|
1938
|
+
project_id=project_id,
|
|
1939
|
+
model=model
|
|
1940
|
+
)
|
|
1941
|
+
|
|
1942
|
+
self.logger.debug(f"Vertex AI URL components - project_id: {project_id}, location: {location}, model: {model}")
|
|
1943
|
+
elif provider_name == "Azure AI":
|
|
1944
|
+
endpoint = settings.get("ENDPOINT", "").strip().rstrip('/')
|
|
1945
|
+
model = settings.get("MODEL", "gpt-4.1")
|
|
1946
|
+
api_version = settings.get("API_VERSION", "2024-10-21")
|
|
1947
|
+
|
|
1948
|
+
# Auto-detect endpoint type and build URL accordingly
|
|
1949
|
+
# Azure AI Foundry: https://[resource].services.ai.azure.com
|
|
1950
|
+
# Azure OpenAI: https://[resource].openai.azure.com or https://[resource].cognitiveservices.azure.com
|
|
1951
|
+
|
|
1952
|
+
if ".services.ai.azure.com" in endpoint:
|
|
1953
|
+
# Azure AI Foundry - use /models/chat/completions format (model goes in request body, not URL)
|
|
1954
|
+
# Check if endpoint already includes /api/projects/[project-name]
|
|
1955
|
+
import re
|
|
1956
|
+
if "/api/projects/" in endpoint:
|
|
1957
|
+
# Project endpoint format - extract base resource endpoint
|
|
1958
|
+
match = re.search(r'https://([^.]+)\.services\.ai\.azure\.com', endpoint)
|
|
1959
|
+
if match:
|
|
1960
|
+
resource_name = match.group(1)
|
|
1961
|
+
endpoint = f"https://{resource_name}.services.ai.azure.com"
|
|
1962
|
+
# Use Foundry models endpoint format: /models/chat/completions
|
|
1963
|
+
url = f"{endpoint}/models/chat/completions?api-version={api_version}"
|
|
1964
|
+
elif ".openai.azure.com" in endpoint or ".cognitiveservices.azure.com" in endpoint:
|
|
1965
|
+
# Azure OpenAI - use /openai/deployments/[model]/chat/completions format
|
|
1966
|
+
# Both *.openai.azure.com and *.cognitiveservices.azure.com are Azure OpenAI endpoints
|
|
1967
|
+
url = f"{endpoint}/openai/deployments/{model}/chat/completions?api-version={api_version}"
|
|
1968
|
+
else:
|
|
1969
|
+
# Unknown format - assume Azure AI Foundry format by default
|
|
1970
|
+
# Most likely it's a Foundry endpoint if it's not explicitly OpenAI
|
|
1971
|
+
url = f"{endpoint}/models/chat/completions?api-version={api_version}"
|
|
1972
|
+
elif provider_name == "LM Studio":
|
|
1973
|
+
base_url = settings.get("BASE_URL", "http://127.0.0.1:1234").rstrip('/')
|
|
1974
|
+
url = provider_config["url_template"].format(base_url=base_url)
|
|
1975
|
+
elif provider_name == "AWS Bedrock":
|
|
1976
|
+
region = settings.get("AWS_REGION", "us-west-2")
|
|
1977
|
+
model_id = settings.get("MODEL", "meta.llama3-1-70b-instruct-v1:0")
|
|
1978
|
+
|
|
1979
|
+
# Handle inference profile selection from dropdown
|
|
1980
|
+
if " (Inference Profile)" in model_id:
|
|
1981
|
+
model_id = model_id.replace(" (Inference Profile)", "")
|
|
1982
|
+
self.logger.debug(f"Using inference profile directly: {model_id}")
|
|
1983
|
+
|
|
1984
|
+
# Clean up model ID suffixes that are metadata but not part of the actual model ID for API calls
|
|
1985
|
+
# These suffixes are used in the model list for information but need to be removed for API calls
|
|
1986
|
+
original_model_id = model_id
|
|
1987
|
+
if ":mm" in model_id: # Multimodal capability indicator
|
|
1988
|
+
model_id = model_id.replace(":mm", "")
|
|
1989
|
+
self.logger.debug(f"Removed multimodal suffix: {original_model_id} -> {model_id}")
|
|
1990
|
+
elif ":8k" in model_id: # Context length indicator
|
|
1991
|
+
model_id = model_id.replace(":8k", "")
|
|
1992
|
+
self.logger.debug(f"Removed context length suffix: {original_model_id} -> {model_id}")
|
|
1993
|
+
elif model_id.count(":") > 2: # Other suffixes (model should have max 2 colons: provider.model-name-version:number)
|
|
1994
|
+
# Keep only the first two parts (provider.model:version)
|
|
1995
|
+
parts = model_id.split(":")
|
|
1996
|
+
if len(parts) > 2:
|
|
1997
|
+
model_id = ":".join(parts[:2])
|
|
1998
|
+
self.logger.debug(f"Cleaned model ID: {original_model_id} -> {model_id}")
|
|
1999
|
+
|
|
2000
|
+
# Check if model_id is already an inference profile (has regional prefix)
|
|
2001
|
+
# If so, don't apply the mapping - use it as-is
|
|
2002
|
+
already_has_prefix = any(model_id.startswith(prefix) for prefix in ['us.', 'eu.', 'apac.', 'global.'])
|
|
2003
|
+
|
|
2004
|
+
if already_has_prefix:
|
|
2005
|
+
# Model already has inference profile prefix, use as-is
|
|
2006
|
+
final_model_id = model_id
|
|
2007
|
+
self.logger.debug(f"AWS Bedrock: Model '{model_id}' already has inference profile prefix")
|
|
2008
|
+
else:
|
|
2009
|
+
# AWS Bedrock requires inference profiles for newer Claude models
|
|
2010
|
+
# Based on AWS documentation and current model availability
|
|
2011
|
+
# Note: Only map base model IDs to inference profiles
|
|
2012
|
+
inference_profile_mapping = {
|
|
2013
|
+
# Claude 3.5 models (v2) - these require inference profiles
|
|
2014
|
+
"anthropic.claude-3-5-haiku-20241022-v1:0": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
|
2015
|
+
"anthropic.claude-3-5-sonnet-20241022-v2:0": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
|
2016
|
+
|
|
2017
|
+
# Claude 3.5 models (v1)
|
|
2018
|
+
"anthropic.claude-3-5-sonnet-20240620-v1:0": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
|
2019
|
+
|
|
2020
|
+
# Claude 3 models (original) - some may work without profiles
|
|
2021
|
+
"anthropic.claude-3-opus-20240229-v1:0": "us.anthropic.claude-3-opus-20240229-v1:0",
|
|
2022
|
+
"anthropic.claude-3-sonnet-20240229-v1:0": "us.anthropic.claude-3-sonnet-20240229-v1:0",
|
|
2023
|
+
"anthropic.claude-3-haiku-20240307-v1:0": "us.anthropic.claude-3-haiku-20240307-v1:0"
|
|
2024
|
+
}
|
|
2025
|
+
|
|
2026
|
+
# Use inference profile if available, otherwise use direct model ID
|
|
2027
|
+
final_model_id = inference_profile_mapping.get(model_id, model_id)
|
|
2028
|
+
|
|
2029
|
+
# If we're using an inference profile, log the conversion for debugging
|
|
2030
|
+
if final_model_id != model_id:
|
|
2031
|
+
self.logger.info(f"AWS Bedrock: Converting model ID '{model_id}' to inference profile '{final_model_id}'")
|
|
2032
|
+
|
|
2033
|
+
# Handle regional preferences for inference profiles
|
|
2034
|
+
# If user is in EU region and model supports EU profiles, use EU prefix
|
|
2035
|
+
if region.startswith('eu-') and final_model_id.startswith('us.anthropic.'):
|
|
2036
|
+
eu_model_id = final_model_id.replace('us.anthropic.', 'eu.anthropic.')
|
|
2037
|
+
self.logger.info(f"AWS Bedrock: Using EU inference profile '{eu_model_id}' for region '{region}'")
|
|
2038
|
+
final_model_id = eu_model_id
|
|
2039
|
+
elif region.startswith('ap-') and final_model_id.startswith('us.anthropic.'):
|
|
2040
|
+
apac_model_id = final_model_id.replace('us.anthropic.', 'apac.anthropic.')
|
|
2041
|
+
self.logger.info(f"AWS Bedrock: Using APAC inference profile '{apac_model_id}' for region '{region}'")
|
|
2042
|
+
final_model_id = apac_model_id
|
|
2043
|
+
|
|
2044
|
+
# Always use InvokeModel API - it's more reliable and works with both
|
|
2045
|
+
# inference profiles and base model IDs
|
|
2046
|
+
# The Converse API has compatibility issues with some authentication methods
|
|
2047
|
+
url = provider_config["url_invoke"].format(region=region, model=final_model_id)
|
|
2048
|
+
self.logger.info(f"AWS Bedrock: Using InvokeModel API for model '{final_model_id}'")
|
|
2049
|
+
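        # Example (illustrative, hypothetical helper): the model-ID resolution above as
        # a pure function. An explicit regional prefix skips the static mapping; the
        # EU/APAC rewrite is then applied on top in either case.
        #
        #   def resolve_bedrock_model_id(model_id, region, mapping):
        #       if any(model_id.startswith(p) for p in ('us.', 'eu.', 'apac.', 'global.')):
        #           resolved = model_id
        #       else:
        #           resolved = mapping.get(model_id, model_id)
        #       if region.startswith('eu-') and resolved.startswith('us.anthropic.'):
        #           resolved = resolved.replace('us.anthropic.', 'eu.anthropic.')
        #       elif region.startswith('ap-') and resolved.startswith('us.anthropic.'):
        #           resolved = resolved.replace('us.anthropic.', 'apac.anthropic.')
        #       return resolved
        #
        #   # resolve_bedrock_model_id("anthropic.claude-3-haiku-20240307-v1:0", "eu-west-1", mapping)
        #   # -> "eu.anthropic.claude-3-haiku-20240307-v1:0"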
elif "url_template" in provider_config:
|
|
2050
|
+
url = provider_config["url_template"].format(model=settings.get("MODEL"), api_key=api_key)
|
|
2051
|
+
else:
|
|
2052
|
+
url = provider_config["url"]
|
|
2053
|
+
|
|
2054
|
+
# Build payload first (needed for AWS signing)
|
|
2055
|
+
payload = self._build_payload(provider_name, prompt, settings)
|
|
2056
|
+
|
|
2057
|
+
# Build headers
|
|
2058
|
+
headers = {}
|
|
2059
|
+
for key, value in provider_config["headers_template"].items():
|
|
2060
|
+
if provider_name == "Vertex AI":
|
|
2061
|
+
# Vertex AI uses OAuth2 access token
|
|
2062
|
+
if "{access_token}" in value:
|
|
2063
|
+
access_token = self.get_vertex_ai_access_token()
|
|
2064
|
+
if not access_token:
|
|
2065
|
+
raise ValueError("Failed to obtain Vertex AI access token. Please check your service account JSON.")
|
|
2066
|
+
headers[key] = value.format(access_token=access_token)
|
|
2067
|
+
else:
|
|
2068
|
+
headers[key] = value
|
|
2069
|
+
elif provider_name == "Azure AI":
|
|
2070
|
+
# Azure AI uses api-key header (not Authorization Bearer)
|
|
2071
|
+
headers[key] = value.format(api_key=api_key)
|
|
2072
|
+
elif provider_name in ["LM Studio", "AWS Bedrock"]:
|
|
2073
|
+
# LM Studio and AWS Bedrock don't need API key in headers
|
|
2074
|
+
headers[key] = value
|
|
2075
|
+
else:
|
|
2076
|
+
headers[key] = value.format(api_key=api_key)
|
|
2077
|
+
|
|
2078
|
+
# AWS Bedrock authentication - following Roo Code's approach
|
|
2079
|
+
if provider_name == "AWS Bedrock":
|
|
2080
|
+
auth_method = settings.get("AUTH_METHOD", "api_key")
|
|
2081
|
+
region = settings.get("AWS_REGION", "us-west-2")
|
|
2082
|
+
|
|
2083
|
+
# Handle both display names and internal values for backward compatibility
|
|
2084
|
+
is_api_key_auth = auth_method in ["api_key", "API Key (Bearer Token)"]
|
|
2085
|
+
is_iam_auth = auth_method in ["iam", "IAM (Explicit Credentials)"]
|
|
2086
|
+
is_session_token_auth = auth_method in ["sessionToken", "Session Token (Temporary Credentials)"]
|
|
2087
|
+
is_iam_role_auth = auth_method in ["iam_role", "IAM (Implied Credentials)"]
|
|
2088
|
+
|
|
2089
|
+
# Based on Roo Code's implementation, they support API key authentication
|
|
2090
|
+
# Let's add that back and use Bearer token format like they do
|
|
2091
|
+
if is_api_key_auth:
|
|
2092
|
+
# Use API key/token-based authentication (Roo Code style)
|
|
2093
|
+
api_key_value = self.get_api_key_for_provider(provider_name, settings)
|
|
2094
|
+
self.logger.debug(f"AWS Bedrock API Key auth: key length = {len(api_key_value) if api_key_value else 0}")
|
|
2095
|
+
headers.update({
|
|
2096
|
+
"Authorization": f"Bearer {api_key_value}",
|
|
2097
|
+
"Content-Type": "application/json",
|
|
2098
|
+
"Accept": "application/json"
|
|
2099
|
+
})
|
|
2100
|
+
elif is_iam_auth or is_session_token_auth:
|
|
2101
|
+
# Use AWS SigV4 authentication
|
|
2102
|
+
access_key = self.get_aws_credential(settings, "AWS_ACCESS_KEY_ID")
|
|
2103
|
+
secret_key = self.get_aws_credential(settings, "AWS_SECRET_ACCESS_KEY")
|
|
2104
|
+
session_token = self.get_aws_credential(settings, "AWS_SESSION_TOKEN") if is_session_token_auth else None
|
|
2105
|
+
|
|
2106
|
+
if access_key and secret_key:
|
|
2107
|
+
payload_str = json.dumps(payload)
|
|
2108
|
+
signed_headers = self.sign_aws_request(
|
|
2109
|
+
"POST", url, payload_str, access_key, secret_key,
|
|
2110
|
+
session_token, region, "bedrock-runtime"
|
|
2111
|
+
)
|
|
2112
|
+
headers.update(signed_headers)
|
|
2113
|
+
elif is_iam_role_auth:
|
|
2114
|
+
# For IAM role, we would need to use boto3 or assume role
|
|
2115
|
+
# For now, add basic headers (this won't work without proper IAM role setup)
|
|
2116
|
+
headers.update({
|
|
2117
|
+
"Content-Type": "application/json",
|
|
2118
|
+
"Accept": "application/json"
|
|
2119
|
+
})
|
|
2120
|
+
|
|
2121
|
+
return url, payload, headers
|
|
2122
|
+
|
|
2123
|
+
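    # Example (illustrative sketch): the core of AWS Signature Version 4, the scheme
    # that sign_aws_request() above implements for the IAM auth paths. A simplified
    # outline assuming a query-less URL and only host/x-amz-date as signed headers;
    # a production signer (e.g. botocore) performs additional normalization.
    #
    #   import datetime, hashlib, hmac
    #   from urllib.parse import urlparse
    #
    #   def sigv4_authorization(method, url, body, access_key, secret_key, region, service):
    #       host, path = urlparse(url).netloc, (urlparse(url).path or "/")
    #       now = datetime.datetime.now(datetime.timezone.utc)
    #       amz_date, date_stamp = now.strftime("%Y%m%dT%H%M%SZ"), now.strftime("%Y%m%d")
    #       canonical_request = "\n".join([
    #           method, path, "",                                # no query string
    #           f"host:{host}\nx-amz-date:{amz_date}\n",         # canonical headers
    #           "host;x-amz-date",                               # signed header list
    #           hashlib.sha256(body.encode()).hexdigest()])      # payload hash
    #       scope = f"{date_stamp}/{region}/{service}/aws4_request"
    #       string_to_sign = "\n".join(["AWS4-HMAC-SHA256", amz_date, scope,
    #           hashlib.sha256(canonical_request.encode()).hexdigest()])
    #       key = ("AWS4" + secret_key).encode()
    #       for part in (date_stamp, region, service, "aws4_request"):
    #           key = hmac.new(key, part.encode(), hashlib.sha256).digest()
    #       signature = hmac.new(key, string_to_sign.encode(), hashlib.sha256).hexdigest()
    #       return (f"AWS4-HMAC-SHA256 Credential={access_key}/{scope}, "
    #               f"SignedHeaders=host;x-amz-date, Signature={signature}")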
    def _build_payload(self, provider_name, prompt, settings):
        """Build the API payload for the specific provider."""
        payload = {}

        if provider_name in ["Google AI", "Vertex AI"]:
            system_prompt = settings.get("system_prompt", "").strip()

            # Use the proper systemInstruction field instead of prepending the system
            # prompt to the user prompt; this is the recommended way to set system
            # prompts for Gemini models.
            payload = {"contents": [{"parts": [{"text": prompt}], "role": "user"}]}

            # Add systemInstruction as a separate field (proper Gemini API format)
            if system_prompt:
                payload["systemInstruction"] = {
                    "parts": [{"text": system_prompt}]
                }

            gen_config = {}
            self._add_param_if_valid(gen_config, settings, 'temperature', float)
            self._add_param_if_valid(gen_config, settings, 'topP', float)
            self._add_param_if_valid(gen_config, settings, 'topK', int)
            self._add_param_if_valid(gen_config, settings, 'maxOutputTokens', int)
            self._add_param_if_valid(gen_config, settings, 'candidateCount', int)

            stop_seq_str = str(settings.get('stopSequences', '')).strip()
            if stop_seq_str:
                gen_config['stopSequences'] = [s.strip() for s in stop_seq_str.split(',')]

            if gen_config:
                payload['generationConfig'] = gen_config
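            # Example: for a system prompt and temperature 0.7, the branch above yields
            # a request body of this shape (illustrative values):
            #
            #   {
            #     "contents": [{"parts": [{"text": "Summarize this note."}], "role": "user"}],
            #     "systemInstruction": {"parts": [{"text": "You are a concise assistant."}]},
            #     "generationConfig": {"temperature": 0.7}
            #   }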
elif provider_name == "Anthropic AI":
|
|
2155
|
+
payload = {"model": settings.get("MODEL"), "messages": [{"role": "user", "content": prompt}]}
|
|
2156
|
+
if settings.get("system"):
|
|
2157
|
+
payload["system"] = settings.get("system")
|
|
2158
|
+
|
|
2159
|
+
self._add_param_if_valid(payload, settings, 'max_tokens', int)
|
|
2160
|
+
self._add_param_if_valid(payload, settings, 'temperature', float)
|
|
2161
|
+
self._add_param_if_valid(payload, settings, 'top_p', float)
|
|
2162
|
+
self._add_param_if_valid(payload, settings, 'top_k', int)
|
|
2163
|
+
|
|
2164
|
+
stop_seq_str = str(settings.get('stop_sequences', '')).strip()
|
|
2165
|
+
if stop_seq_str:
|
|
2166
|
+
payload['stop_sequences'] = [s.strip() for s in stop_seq_str.split(',')]
|
|
2167
|
+
|
|
2168
|
+
elif provider_name == "Cohere AI":
|
|
2169
|
+
payload = {"model": settings.get("MODEL"), "message": prompt}
|
|
2170
|
+
if settings.get("preamble"):
|
|
2171
|
+
payload["preamble"] = settings.get("preamble")
|
|
2172
|
+
|
|
2173
|
+
self._add_param_if_valid(payload, settings, 'temperature', float)
|
|
2174
|
+
self._add_param_if_valid(payload, settings, 'p', float)
|
|
2175
|
+
self._add_param_if_valid(payload, settings, 'k', int)
|
|
2176
|
+
self._add_param_if_valid(payload, settings, 'max_tokens', int)
|
|
2177
|
+
self._add_param_if_valid(payload, settings, 'frequency_penalty', float)
|
|
2178
|
+
self._add_param_if_valid(payload, settings, 'presence_penalty', float)
|
|
2179
|
+
|
|
2180
|
+
if settings.get('citation_quality'):
|
|
2181
|
+
payload['citation_quality'] = settings['citation_quality']
|
|
2182
|
+
|
|
2183
|
+
stop_seq_str = str(settings.get('stop_sequences', '')).strip()
|
|
2184
|
+
if stop_seq_str:
|
|
2185
|
+
payload['stop_sequences'] = [s.strip() for s in stop_seq_str.split(',')]
|
|
2186
|
+
|
|
2187
|
+
elif provider_name == "Azure AI":
|
|
2188
|
+
# Azure AI uses OpenAI-compatible format
|
|
2189
|
+
# For Azure OpenAI: model is in URL, so don't include in payload (recommended)
|
|
2190
|
+
# For Azure AI Foundry: model must be in payload
|
|
2191
|
+
endpoint = settings.get("ENDPOINT", "").strip().rstrip('/')
|
|
2192
|
+
payload = {"messages": []}
|
|
2193
|
+
|
|
2194
|
+
# Only include model in payload for Azure AI Foundry
|
|
2195
|
+
# Azure OpenAI has model in URL path, so omit from payload for better compatibility
|
|
2196
|
+
if ".services.ai.azure.com" in endpoint:
|
|
2197
|
+
# Azure AI Foundry - model MUST be in payload
|
|
2198
|
+
payload["model"] = settings.get("MODEL")
|
|
2199
|
+
# For Azure OpenAI (openai.azure.com or cognitiveservices.azure.com), model is in URL
|
|
2200
|
+
# Some API versions accept model in payload too, but it's better to omit it
|
|
2201
|
+
|
|
2202
|
+
system_prompt = settings.get("system_prompt", "").strip()
|
|
2203
|
+
if system_prompt:
|
|
2204
|
+
payload["messages"].append({"role": "system", "content": system_prompt})
|
|
2205
|
+
payload["messages"].append({"role": "user", "content": prompt})
|
|
2206
|
+
|
|
2207
|
+
# Universal parameters supported by Azure AI Foundry
|
|
2208
|
+
self._add_param_if_valid(payload, settings, 'temperature', float)
|
|
2209
|
+
self._add_param_if_valid(payload, settings, 'top_p', float)
|
|
2210
|
+
self._add_param_if_valid(payload, settings, 'max_tokens', int)
|
|
2211
|
+
self._add_param_if_valid(payload, settings, 'frequency_penalty', float)
|
|
2212
|
+
self._add_param_if_valid(payload, settings, 'presence_penalty', float)
|
|
2213
|
+
self._add_param_if_valid(payload, settings, 'seed', int)
|
|
2214
|
+
|
|
2215
|
+
stop_str = str(settings.get('stop', '')).strip()
|
|
2216
|
+
if stop_str:
|
|
2217
|
+
payload['stop'] = [s.strip() for s in stop_str.split(',')]
|
|
2218
|
+
elif provider_name in ["OpenAI", "Groq AI", "OpenRouterAI", "LM Studio"]:
|
|
2219
|
+
payload = {"model": settings.get("MODEL"), "messages": []}
|
|
2220
|
+
system_prompt = settings.get("system_prompt", "").strip()
|
|
2221
|
+
if system_prompt:
|
|
2222
|
+
payload["messages"].append({"role": "system", "content": system_prompt})
|
|
2223
|
+
payload["messages"].append({"role": "user", "content": prompt})
|
|
2224
|
+
|
|
2225
|
+
# LM Studio specific parameters
|
|
2226
|
+
if provider_name == "LM Studio":
|
|
2227
|
+
max_tokens = settings.get("MAX_TOKENS", "2048")
|
|
2228
|
+
if max_tokens:
|
|
2229
|
+
try:
|
|
2230
|
+
payload["max_tokens"] = int(max_tokens)
|
|
2231
|
+
except ValueError:
|
|
2232
|
+
pass
|
|
2233
|
+
else:
|
|
2234
|
+
# Standard OpenAI-compatible parameters
|
|
2235
|
+
self._add_param_if_valid(payload, settings, 'temperature', float)
|
|
2236
|
+
self._add_param_if_valid(payload, settings, 'top_p', float)
|
|
2237
|
+
self._add_param_if_valid(payload, settings, 'max_tokens', int)
|
|
2238
|
+
self._add_param_if_valid(payload, settings, 'frequency_penalty', float)
|
|
2239
|
+
self._add_param_if_valid(payload, settings, 'presence_penalty', float)
|
|
2240
|
+
self._add_param_if_valid(payload, settings, 'seed', int)
|
|
2241
|
+
|
|
2242
|
+
stop_str = str(settings.get('stop', '')).strip()
|
|
2243
|
+
if stop_str:
|
|
2244
|
+
payload['stop'] = [s.strip() for s in stop_str.split(',')]
|
|
2245
|
+
|
|
2246
|
+
if settings.get("response_format") == "json_object":
|
|
2247
|
+
payload["response_format"] = {"type": "json_object"}
|
|
2248
|
+
|
|
2249
|
+
# OpenRouter specific parameters
|
|
2250
|
+
if provider_name == "OpenRouterAI":
|
|
2251
|
+
self._add_param_if_valid(payload, settings, 'top_k', int)
|
|
2252
|
+
self._add_param_if_valid(payload, settings, 'repetition_penalty', float)
|
|
2253
|
+
|
|
2254
|
+
elif provider_name == "AWS Bedrock":
|
|
2255
|
+
# AWS Bedrock InvokeModel API - model-specific payload formats
|
|
2256
|
+
# Using InvokeModel API for better compatibility with API Key authentication
|
|
2257
|
+
model_id = settings.get("MODEL", "")
|
|
2258
|
+
system_prompt = settings.get("system_prompt", "").strip()
|
|
2259
|
+
|
|
2260
|
+
max_tokens = settings.get("MAX_OUTPUT_TOKENS", "4096")
|
|
2261
|
+
try:
|
|
2262
|
+
max_tokens_int = int(max_tokens)
|
|
2263
|
+
except ValueError:
|
|
2264
|
+
max_tokens_int = 4096
|
|
2265
|
+
|
|
2266
|
+
self.logger.debug(f"Building InvokeModel payload for model: {model_id}")
|
|
2267
|
+
|
|
2268
|
+
if "anthropic.claude" in model_id:
|
|
2269
|
+
# Anthropic Claude models
|
|
2270
|
+
payload = {
|
|
2271
|
+
"anthropic_version": "bedrock-2023-05-31",
|
|
2272
|
+
"max_tokens": max_tokens_int,
|
|
2273
|
+
"messages": [{"role": "user", "content": prompt}]
|
|
2274
|
+
}
|
|
2275
|
+
if system_prompt:
|
|
2276
|
+
payload["system"] = system_prompt
|
|
2277
|
+
elif "amazon.nova" in model_id:
|
|
2278
|
+
# Amazon Nova models
|
|
2279
|
+
payload = {
|
|
2280
|
+
"messages": [{"role": "user", "content": [{"text": prompt}]}],
|
|
2281
|
+
"inferenceConfig": {"maxTokens": max_tokens_int}
|
|
2282
|
+
}
|
|
2283
|
+
if system_prompt:
|
|
2284
|
+
payload["system"] = [{"text": system_prompt}]
|
|
2285
|
+
elif "amazon.titan" in model_id:
|
|
2286
|
+
# Amazon Titan models
|
|
2287
|
+
payload = {
|
|
2288
|
+
"inputText": f"{system_prompt}\n\n{prompt}" if system_prompt else prompt,
|
|
2289
|
+
"textGenerationConfig": {
|
|
2290
|
+
"maxTokenCount": max_tokens_int,
|
|
2291
|
+
"temperature": 0.7,
|
|
2292
|
+
"topP": 0.9
|
|
2293
|
+
}
|
|
2294
|
+
}
|
|
2295
|
+
elif "meta.llama" in model_id:
|
|
2296
|
+
# Meta Llama models
|
|
2297
|
+
full_prompt = f"{system_prompt}\n\nHuman: {prompt}\n\nAssistant:" if system_prompt else f"Human: {prompt}\n\nAssistant:"
|
|
2298
|
+
payload = {
|
|
2299
|
+
"prompt": full_prompt,
|
|
2300
|
+
"max_gen_len": max_tokens_int,
|
|
2301
|
+
"temperature": 0.7,
|
|
2302
|
+
"top_p": 0.9
|
|
2303
|
+
}
|
|
2304
|
+
elif "mistral." in model_id or "mixtral." in model_id:
|
|
2305
|
+
# Mistral models
|
|
2306
|
+
payload = {
|
|
2307
|
+
"prompt": f"<s>[INST] {system_prompt}\n\n{prompt} [/INST]" if system_prompt else f"<s>[INST] {prompt} [/INST]",
|
|
2308
|
+
"max_tokens": max_tokens_int,
|
|
2309
|
+
"temperature": 0.7,
|
|
2310
|
+
"top_p": 0.9
|
|
2311
|
+
}
|
|
2312
|
+
elif "cohere.command" in model_id:
|
|
2313
|
+
# Cohere Command models
|
|
2314
|
+
payload = {
|
|
2315
|
+
"message": prompt,
|
|
2316
|
+
"max_tokens": max_tokens_int,
|
|
2317
|
+
"temperature": 0.7,
|
|
2318
|
+
"p": 0.9
|
|
2319
|
+
}
|
|
2320
|
+
if system_prompt:
|
|
2321
|
+
payload["preamble"] = system_prompt
|
|
2322
|
+
elif "ai21." in model_id:
|
|
2323
|
+
# AI21 models
|
|
2324
|
+
payload = {
|
|
2325
|
+
"prompt": f"{system_prompt}\n\n{prompt}" if system_prompt else prompt,
|
|
2326
|
+
"maxTokens": max_tokens_int,
|
|
2327
|
+
"temperature": 0.7,
|
|
2328
|
+
"topP": 0.9
|
|
2329
|
+
}
|
|
2330
|
+
else:
|
|
2331
|
+
# Default format - try messages format first (works for many models)
|
|
2332
|
+
payload = {
|
|
2333
|
+
"messages": [{"role": "user", "content": prompt}],
|
|
2334
|
+
"max_tokens": max_tokens_int,
|
|
2335
|
+
"temperature": 0.7
|
|
2336
|
+
}
|
|
2337
|
+
if system_prompt:
|
|
2338
|
+
payload["messages"].insert(0, {"role": "system", "content": system_prompt})
|
|
2339
|
+
|
|
2340
|
+
return payload
|
|
2341
|
+
|
|
2342
|
+
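    # Example: two of the InvokeModel payload shapes built above for the same request
    # (illustrative values only):
    #
    #   anthropic.claude-*: {"anthropic_version": "bedrock-2023-05-31", "max_tokens": 4096,
    #                        "messages": [{"role": "user", "content": "Hello"}],
    #                        "system": "Be brief."}
    #
    #   amazon.titan-*:     {"inputText": "Be brief.\n\nHello",
    #                        "textGenerationConfig": {"maxTokenCount": 4096,
    #                                                 "temperature": 0.7, "topP": 0.9}}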
    def _add_param_if_valid(self, param_dict, settings, key, param_type):
        """Add a parameter to the dict if it is present and converts cleanly."""
        val_str = str(settings.get(key, '')).strip()
        if val_str:
            try:
                converted_val = param_type(val_str)
                if converted_val:  # excludes empty strings, and also 0 and 0.0
                    param_dict[key] = converted_val
            except (ValueError, TypeError):
                self.logger.warning(f"Could not convert {key} value '{val_str}' to {param_type}")
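    # Example: note that the truthiness check in _add_param_if_valid drops zero-valued
    # parameters, so a user-supplied temperature of "0" is omitted and the provider
    # default applies (illustrative calls):
    #
    #   cfg = {}
    #   self._add_param_if_valid(cfg, {"temperature": "0.7"}, "temperature", float)  # cfg == {"temperature": 0.7}
    #   self._add_param_if_valid(cfg, {"topK": "0"}, "topK", int)                    # cfg unchanged: 0 is falsy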
    def _extract_response_text(self, provider_name, data):
        """Extract the response text from an API response."""
        result_text = f"Error: Could not parse response from {provider_name}."

        if provider_name in ["Google AI", "Vertex AI"]:
            result_text = data.get('candidates', [{}])[0].get('content', {}).get('parts', [{}])[0].get('text', result_text)
        elif provider_name == "Anthropic AI":
            result_text = data.get('content', [{}])[0].get('text', result_text)
        elif provider_name in ["OpenAI", "Groq AI", "OpenRouterAI", "LM Studio", "Azure AI"]:
            result_text = data.get('choices', [{}])[0].get('message', {}).get('content', result_text)
        elif provider_name == "Cohere AI":
            result_text = data.get('text', result_text)
        elif provider_name == "AWS Bedrock":
            # Extract the response from the AWS Bedrock Converse API.
            # Converse API response format:
            #   {'output': {'message': {'content': [{'text': '...'}], 'role': 'assistant'}}}
            self.logger.debug(f"AWS Bedrock response data: {data}")

            try:
                # Primary: Converse API format (recommended)
                if 'output' in data and 'message' in data['output']:
                    message_data = data['output']['message']
                    self.logger.debug("Using Converse API response format")

                    if 'content' in message_data and isinstance(message_data['content'], list):
                        # Extract the text from the content array
                        text_parts = []
                        for content_item in message_data['content']:
                            if isinstance(content_item, dict) and 'text' in content_item:
                                text_parts.append(content_item['text'])

                        if text_parts:
                            result_text = ''.join(text_parts)
                            self.logger.debug(f"Successfully extracted Converse API text: {result_text[:100]}...")
                        else:
                            self.logger.warning("Converse API response had no text content")
                            result_text = str(message_data.get('content', ''))
                    else:
                        result_text = str(message_data)

                # Fallbacks: legacy InvokeModel API formats
                elif 'content' in data and isinstance(data['content'], list) and len(data['content']) > 0:
                    # Anthropic Claude format (InvokeModel)
                    self.logger.debug("Using legacy Claude content format")
                    result_text = data['content'][0].get('text', result_text)
                elif 'generation' in data:
                    # Meta Llama format (InvokeModel)
                    self.logger.debug("Using legacy Llama generation format")
                    result_text = data['generation']
                elif 'results' in data and len(data['results']) > 0:
                    # Amazon Titan format (InvokeModel)
                    self.logger.debug("Using legacy Titan results format")
                    result_text = data['results'][0].get('outputText', result_text)
                elif 'text' in data:
                    # Direct text format
                    self.logger.debug("Using direct text format")
                    result_text = data['text']
                elif 'response' in data:
                    # Some models use a 'response' field
                    self.logger.debug("Using response field format")
                    result_text = data['response']
                elif 'choices' in data and len(data['choices']) > 0:
                    # OpenAI-style format
                    self.logger.debug("Using OpenAI-style choices format")
                    choice = data['choices'][0]
                    if 'message' in choice and 'content' in choice['message']:
                        result_text = choice['message']['content']
                    elif 'text' in choice:
                        result_text = choice['text']
                else:
                    # Fallback - try to find text in common locations
                    self.logger.debug("Using fallback format - no recognized structure")
                    result_text = data.get('text', data.get('output', data.get('response', str(data))))
            except Exception as e:
                self.logger.error(f"Error extracting AWS Bedrock response: {e}")
                result_text = str(data)

        return result_text
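    # Example: the Converse-format response the primary AWS Bedrock branch above
    # parses (illustrative payload):
    #
    #   {"output": {"message": {"role": "assistant",
    #                           "content": [{"text": "Hello!"}]}}}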
    def _call_streaming_api(self, url, payload, headers, provider_name):
        """
        Make a streaming API call and progressively display the response.

        Supports the OpenAI-compatible streaming format (SSE with a "data:" prefix).
        Works with OpenAI, Groq, OpenRouter, Azure AI, and Anthropic.

        Args:
            url: API endpoint URL
            payload: Request payload (should include "stream": True)
            headers: Request headers
            provider_name: Name of the AI provider
        """
        try:
            # Start the streaming display
            if not self.start_streaming_response():
                self.logger.warning("Failed to start streaming display, falling back to non-streaming")
                # Fall back to a non-streaming request
                payload_copy = payload.copy()
                payload_copy.pop("stream", None)
                response = requests.post(url, json=payload_copy, headers=headers, timeout=60)
                response.raise_for_status()
                data = response.json()
                result_text = self._extract_response_text(provider_name, data)
                self.display_ai_response(result_text)
                return

            # Make the streaming request
            response = requests.post(url, json=payload, headers=headers, timeout=120, stream=True)
            response.raise_for_status()

            accumulated_text = ""

            for line in response.iter_lines():
                if not line:
                    continue

                line_text = line.decode('utf-8')

                # Handle the SSE format ("data:" prefix)
                if line_text.startswith('data: '):
                    data_str = line_text[6:]  # remove the 'data: ' prefix

                    # Check for the stream end marker
                    if data_str.strip() == '[DONE]':
                        break

                    try:
                        chunk_data = json.loads(data_str)

                        # Extract the content based on the provider format
                        content = self._extract_streaming_chunk(chunk_data, provider_name)

                        if content:
                            accumulated_text += content
                            self.add_streaming_chunk(content)

                    except json.JSONDecodeError:
                        self.logger.debug(f"Skipping non-JSON line: {data_str[:50]}...")
                        continue

                # Handle Anthropic's event-based format
                elif line_text.startswith('event: '):
                    # Anthropic uses event: content_block_delta, etc.
                    continue

            # End streaming
            self.end_streaming_response()

            if not accumulated_text:
                self.logger.warning("No content received from streaming response")
                self.app.after(0, self.app.update_output_text, "Error: No content received from streaming response.")

        except requests.exceptions.RequestException as e:
            self.cancel_streaming()
            self.logger.error(f"Streaming API request failed: {e}")
            self.app.after(0, self.app.update_output_text, f"Streaming API Error: {e}")
        except Exception as e:
            self.cancel_streaming()
            self.logger.error(f"Streaming error: {e}", exc_info=True)
            self.app.after(0, self.app.update_output_text, f"Streaming Error: {e}")
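    # Example: the SSE wire format the loop above consumes. An OpenAI-compatible
    # stream arrives as one JSON object per "data:" line, terminated by [DONE]
    # (illustrative chunk contents):
    #
    #   data: {"choices": [{"delta": {"content": "Hel"}}]}
    #   data: {"choices": [{"delta": {"content": "lo"}}]}
    #   data: [DONE]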
    def _extract_streaming_chunk(self, chunk_data, provider_name):
        """
        Extract the text content from a streaming chunk based on the provider format.

        Args:
            chunk_data: Parsed JSON chunk data
            provider_name: Name of the AI provider

        Returns:
            The extracted text content, or an empty string
        """
        try:
            if provider_name == "Anthropic AI":
                # Anthropic format: {"type": "content_block_delta", "delta": {"text": "..."}}
                if chunk_data.get("type") == "content_block_delta":
                    return chunk_data.get("delta", {}).get("text", "")
                return ""
            else:
                # OpenAI-compatible format (OpenAI, Groq, OpenRouter, Azure AI):
                # {"choices": [{"delta": {"content": "..."}}]}
                choices = chunk_data.get("choices", [])
                if choices:
                    delta = choices[0].get("delta", {})
                    return delta.get("content", "")
                return ""
        except Exception as e:
            self.logger.debug(f"Error extracting streaming chunk: {e}")
            return ""
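    # Example: the two chunk shapes _extract_streaming_chunk distinguishes
    # (illustrative payloads):
    #
    #   Anthropic:          {"type": "content_block_delta", "delta": {"text": "Hi"}}
    #   OpenAI-compatible:  {"choices": [{"delta": {"content": "Hi"}}]}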
    def open_model_editor(self, provider_name):
        """Opens a Toplevel window to edit the model list for an AI provider."""
        dialog = tk.Toplevel(self.app)
        dialog.title(f"Edit {provider_name} Models")

        self.app.update_idletasks()
        dialog_width = 400
        dialog_height = 200
        main_x, main_y, main_width, main_height = self.app.winfo_x(), self.app.winfo_y(), self.app.winfo_width(), self.app.winfo_height()
        pos_x = main_x + (main_width // 2) - (dialog_width // 2)
        pos_y = main_y + (main_height // 2) - (dialog_height // 2)
        dialog.geometry(f"{dialog_width}x{dialog_height}+{pos_x}+{pos_y}")
        dialog.transient(self.app)
        dialog.grab_set()

        ttk.Label(dialog, text="One model per line. The first line is the default.").pack(pady=(10, 2))

        text_area = tk.Text(dialog, height=7, width=45, undo=True)
        text_area.pack(pady=5, padx=10)

        current_models = self.app.settings["tool_settings"].get(provider_name, {}).get("MODELS_LIST", [])
        text_area.insert("1.0", "\n".join(current_models))

        save_button = ttk.Button(dialog, text="Save Changes",
                                 command=lambda: self.save_model_list(provider_name, text_area, dialog))
        save_button.pack(pady=5)
    def save_model_list(self, provider_name, text_area, dialog):
        """Saves the edited model list back to settings."""
        content = text_area.get("1.0", tk.END)
        new_list = [line.strip() for line in content.splitlines() if line.strip()]

        if not new_list:
            self._show_warning("No Models", "Model list cannot be empty.")
            return

        self.app.settings["tool_settings"][provider_name]["MODELS_LIST"] = new_list
        self.app.settings["tool_settings"][provider_name]["MODEL"] = new_list[0]

        # Update the combobox values
        if provider_name in self.ai_widgets and "MODEL" in self.ai_widgets[provider_name]:
            # Find the provider's tab and refresh it
            for provider, tab_frame in self.tabs.items():
                if provider == provider_name:
                    # Update the model variable and refresh the UI
                    self.ai_widgets[provider_name]["MODEL"].set(new_list[0])
                    # Recreate the provider widgets so the combobox picks up the new values
                    for widget in tab_frame.winfo_children():
                        widget.destroy()
                    self.create_provider_widgets(tab_frame, provider_name)
                    break

        self.app.save_settings()
        dialog.destroy()
    def _get_ai_params_config(self, provider_name):
        """Get the parameter configuration for an AI provider."""
        configs = {
            "Google AI": {
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "topP": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Cumulative probability threshold for token selection."},
                "topK": {"tab": "sampling", "type": "scale", "range": (1, 100), "res": 1, "tip": "Limits token selection to top K candidates."},
                "maxOutputTokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "candidateCount": {"tab": "content", "type": "scale", "range": (1, 8), "res": 1, "tip": "Number of response candidates to generate."},
                "stopSequences": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."}
            },
            "Vertex AI": {
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "topP": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Cumulative probability threshold for token selection."},
                "topK": {"tab": "sampling", "type": "scale", "range": (1, 100), "res": 1, "tip": "Limits token selection to top K candidates."},
                "maxOutputTokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "candidateCount": {"tab": "content", "type": "scale", "range": (1, 8), "res": 1, "tip": "Number of response candidates to generate."},
                "stopSequences": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."}
            },
            "Anthropic AI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "top_p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Cumulative probability threshold for token selection."},
                "top_k": {"tab": "sampling", "type": "scale", "range": (1, 200), "res": 1, "tip": "Limits token selection to top K candidates."},
                "stop_sequences": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."}
            },
            "OpenAI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "top_p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Nucleus sampling threshold."},
                "frequency_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes frequent tokens."},
                "presence_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes tokens that have appeared."},
                "seed": {"tab": "content", "type": "entry", "tip": "Random seed for reproducible outputs."},
                "stop": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."},
                "response_format": {"tab": "content", "type": "combo", "values": ["text", "json_object"], "tip": "Force JSON output."}
            },
            "Cohere AI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Top-p/nucleus sampling threshold."},
                "k": {"tab": "sampling", "type": "scale", "range": (1, 500), "res": 1, "tip": "Top-k sampling threshold."},
                "frequency_penalty": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.1, "tip": "Penalizes frequent tokens."},
                "presence_penalty": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.1, "tip": "Penalizes tokens that have appeared."},
                "stop_sequences": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."},
                "citation_quality": {"tab": "content", "type": "combo", "values": ["accurate", "fast"], "tip": "Citation quality vs. speed."}
            },
            "HuggingFace AI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "top_p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Nucleus sampling threshold."},
                "seed": {"tab": "content", "type": "entry", "tip": "Random seed for reproducible outputs."},
                "stop_sequences": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."}
            },
            "Groq AI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "top_p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Nucleus sampling threshold."},
                "frequency_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes frequent tokens."},
                "presence_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes tokens that have appeared."},
                "seed": {"tab": "content", "type": "entry", "tip": "Random seed for reproducible outputs."},
                "stop": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."},
                "response_format": {"tab": "content", "type": "combo", "values": ["text", "json_object"], "tip": "Force JSON output."}
            },
            "OpenRouterAI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "top_p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Nucleus sampling threshold."},
                "top_k": {"tab": "sampling", "type": "scale", "range": (1, 100), "res": 1, "tip": "Limits token selection to top K candidates."},
                "frequency_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes frequent tokens."},
                "presence_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes tokens that have appeared."},
                "repetition_penalty": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Penalizes repetitive text."},
                "seed": {"tab": "content", "type": "entry", "tip": "Random seed for reproducible outputs."},
                "stop": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."}
            },
            "Azure AI": {
                "max_tokens": {"tab": "content", "type": "entry", "tip": "Maximum number of tokens to generate."},
                "temperature": {"tab": "sampling", "type": "scale", "range": (0.0, 2.0), "res": 0.1, "tip": "Controls randomness. Higher is more creative."},
                "top_p": {"tab": "sampling", "type": "scale", "range": (0.0, 1.0), "res": 0.05, "tip": "Nucleus sampling threshold."},
                "frequency_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes frequent tokens."},
                "presence_penalty": {"tab": "sampling", "type": "scale", "range": (-2.0, 2.0), "res": 0.1, "tip": "Penalizes tokens that have appeared."},
                "seed": {"tab": "content", "type": "entry", "tip": "Random seed for reproducible outputs."},
                "stop": {"tab": "content", "type": "entry", "tip": "Comma-separated list of strings to stop generation."}
            }
        }

        return configs.get(provider_name, {})
    # ==================== Streaming Support Methods ====================

    def enable_streaming(self, enabled: bool = True) -> bool:
        """
        Enable or disable streaming mode for AI responses.

        Args:
            enabled: Whether to enable streaming

        Returns:
            True if streaming was enabled/disabled successfully
        """
        if not STREAMING_AVAILABLE:
            self.logger.warning("Streaming is not available - module not loaded")
            return False

        self._streaming_enabled = enabled
        self.logger.info(f"Streaming mode {'enabled' if enabled else 'disabled'}")
        return True

    def is_streaming_enabled(self) -> bool:
        """Check whether streaming mode is enabled."""
        return self._streaming_enabled and STREAMING_AVAILABLE
    def _get_output_text_widget(self):
        """Get the current output text widget from the app."""
        try:
            current_tab_index = self.app.output_notebook.index(self.app.output_notebook.select())
            active_output_tab = self.app.output_tabs[current_tab_index]
            return active_output_tab.text
        except Exception as e:
            self.logger.error(f"Failed to get output text widget: {e}")
            return None

    def _init_streaming_handler(self, text_widget):
        """Initialize the streaming handler for a text widget."""
        if not STREAMING_AVAILABLE:
            return None

        try:
            config = StreamConfig(
                chunk_delay_ms=10,
                batch_size=3,
                auto_scroll=True,
                highlight_new_text=False,
                use_threading=True
            )

            self._streaming_manager = StreamingTextManager(
                text_widget,
                stream_config=config
            )

            return self._streaming_manager
        except Exception as e:
            self.logger.error(f"Failed to initialize streaming handler: {e}")
            return None
    def start_streaming_response(self, clear_existing: bool = True) -> bool:
        """
        Start streaming an AI response to the output widget.

        Args:
            clear_existing: Whether to clear existing content

        Returns:
            True if streaming started successfully
        """
        if not self.is_streaming_enabled():
            return False

        text_widget = self._get_output_text_widget()
        if not text_widget:
            return False

        # Enable the text widget for editing
        text_widget.config(state="normal")

        manager = self._init_streaming_handler(text_widget)
        if not manager:
            return False

        def on_progress(chars_received, total):
            self.logger.debug(f"Streaming progress: {chars_received} chars received")

        def on_complete(metrics):
            self.logger.info(
                f"Streaming complete: {metrics.total_characters} chars "
                f"in {metrics.duration:.2f}s ({metrics.chars_per_second:.0f} chars/s)"
            )
            # Disable the text widget after streaming
            self.app.after(0, lambda: text_widget.config(state="disabled"))
            # Update stats
            self.app.after(10, self.app.update_all_stats)

        return manager.start_streaming(
            clear_existing=clear_existing,
            on_progress=on_progress,
            on_complete=on_complete
        )
    def add_streaming_chunk(self, chunk: str) -> bool:
        """
        Add a chunk of text to the streaming response.

        Args:
            chunk: Text chunk to add

        Returns:
            True if the chunk was added successfully
        """
        if not self._streaming_manager:
            return False

        return self._streaming_manager.add_stream_chunk(chunk)

    def end_streaming_response(self):
        """End the streaming response and finalize."""
        if not self._streaming_manager:
            return None

        metrics = self._streaming_manager.end_streaming()

        # Save settings after streaming completes
        self.app.save_settings()

        return metrics

    def cancel_streaming(self):
        """Cancel the current streaming operation."""
        if self._streaming_manager:
            self._streaming_manager.cancel()
            self._streaming_manager = None
    def process_streaming_response(self, response_iterator):
        """
        Process a streaming response from an API.

        This method handles the full streaming lifecycle:
        1. Start streaming
        2. Process each chunk from the iterator
        3. End streaming

        Args:
            response_iterator: Iterator yielding response chunks

        Returns:
            The complete accumulated text, or None if streaming failed
        """
        if not self.start_streaming_response():
            self.logger.warning("Failed to start streaming, falling back to non-streaming")
            return None

        try:
            for chunk in response_iterator:
                if not self.add_streaming_chunk(chunk):
                    self.logger.warning("Failed to add chunk, stopping stream")
                    break

            self.end_streaming_response()
            return self._streaming_manager.get_accumulated_text() if self._streaming_manager else None

        except Exception as e:
            self.logger.error(f"Error during streaming: {e}")
            self.cancel_streaming()
            return None
    def display_text_with_streaming(self, text: str, chunk_size: int = 50):
        """
        Display text progressively using streaming, simulating a streaming response.

        Useful for displaying large text content progressively.

        Args:
            text: The text to display
            chunk_size: Size of each chunk to display
        """
        if not self.is_streaming_enabled():
            # Fall back to a regular display
            self.app.after(0, self.app.update_output_text, text)
            return

        def chunk_generator():
            for i in range(0, len(text), chunk_size):
                yield text[i:i + chunk_size]

        # Run in a background thread
        def stream_text():
            self.process_streaming_response(chunk_generator())

        thread = threading.Thread(target=stream_text, daemon=True)
        thread.start()

    def display_ai_response(self, text: str, min_streaming_length: int = 500):
        """
        Unified method to display an AI response, with automatic streaming for large responses.

        This method should be used by all AI providers to display their responses.
        It automatically decides whether to use streaming based on the response length.

        Args:
            text: The AI response text to display
            min_streaming_length: Minimum text length that triggers streaming (default 500)
        """
        if self.is_streaming_enabled() and len(text) > min_streaming_length:
            self.logger.debug(f"Using streaming display for response ({len(text)} chars)")
            self.app.after(0, lambda t=text: self.display_text_with_streaming(t))
        else:
            self.app.after(0, self.app.update_output_text, text)