pomera-ai-commander 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +680 -0
- package/bin/pomera-ai-commander.js +62 -0
- package/core/__init__.py +66 -0
- package/core/__pycache__/__init__.cpython-313.pyc +0 -0
- package/core/__pycache__/app_context.cpython-313.pyc +0 -0
- package/core/__pycache__/async_text_processor.cpython-313.pyc +0 -0
- package/core/__pycache__/backup_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/backup_recovery_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/content_hash_cache.cpython-313.pyc +0 -0
- package/core/__pycache__/context_menu.cpython-313.pyc +0 -0
- package/core/__pycache__/data_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/database_connection_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_curl_settings_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_promera_ai_settings_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_schema.cpython-313.pyc +0 -0
- package/core/__pycache__/database_schema_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_settings_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/database_settings_manager_interface.cpython-313.pyc +0 -0
- package/core/__pycache__/dialog_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/efficient_line_numbers.cpython-313.pyc +0 -0
- package/core/__pycache__/error_handler.cpython-313.pyc +0 -0
- package/core/__pycache__/error_service.cpython-313.pyc +0 -0
- package/core/__pycache__/event_consolidator.cpython-313.pyc +0 -0
- package/core/__pycache__/memory_efficient_text_widget.cpython-313.pyc +0 -0
- package/core/__pycache__/migration_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/migration_test_suite.cpython-313.pyc +0 -0
- package/core/__pycache__/migration_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/optimized_find_replace.cpython-313.pyc +0 -0
- package/core/__pycache__/optimized_pattern_engine.cpython-313.pyc +0 -0
- package/core/__pycache__/optimized_search_highlighter.cpython-313.pyc +0 -0
- package/core/__pycache__/performance_monitor.cpython-313.pyc +0 -0
- package/core/__pycache__/persistence_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/progressive_stats_calculator.cpython-313.pyc +0 -0
- package/core/__pycache__/regex_pattern_cache.cpython-313.pyc +0 -0
- package/core/__pycache__/regex_pattern_library.cpython-313.pyc +0 -0
- package/core/__pycache__/search_operation_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_defaults_registry.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_integrity_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_serializer.cpython-313.pyc +0 -0
- package/core/__pycache__/settings_validator.cpython-313.pyc +0 -0
- package/core/__pycache__/smart_stats_calculator.cpython-313.pyc +0 -0
- package/core/__pycache__/statistics_update_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/stats_config_manager.cpython-313.pyc +0 -0
- package/core/__pycache__/streaming_text_handler.cpython-313.pyc +0 -0
- package/core/__pycache__/task_scheduler.cpython-313.pyc +0 -0
- package/core/__pycache__/visibility_monitor.cpython-313.pyc +0 -0
- package/core/__pycache__/widget_cache.cpython-313.pyc +0 -0
- package/core/app_context.py +482 -0
- package/core/async_text_processor.py +422 -0
- package/core/backup_manager.py +656 -0
- package/core/backup_recovery_manager.py +1034 -0
- package/core/content_hash_cache.py +509 -0
- package/core/context_menu.py +313 -0
- package/core/data_validator.py +1067 -0
- package/core/database_connection_manager.py +745 -0
- package/core/database_curl_settings_manager.py +609 -0
- package/core/database_promera_ai_settings_manager.py +447 -0
- package/core/database_schema.py +412 -0
- package/core/database_schema_manager.py +396 -0
- package/core/database_settings_manager.py +1508 -0
- package/core/database_settings_manager_interface.py +457 -0
- package/core/dialog_manager.py +735 -0
- package/core/efficient_line_numbers.py +511 -0
- package/core/error_handler.py +747 -0
- package/core/error_service.py +431 -0
- package/core/event_consolidator.py +512 -0
- package/core/mcp/__init__.py +43 -0
- package/core/mcp/__pycache__/__init__.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/protocol.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/schema.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/server_stdio.cpython-313.pyc +0 -0
- package/core/mcp/__pycache__/tool_registry.cpython-313.pyc +0 -0
- package/core/mcp/protocol.py +288 -0
- package/core/mcp/schema.py +251 -0
- package/core/mcp/server_stdio.py +299 -0
- package/core/mcp/tool_registry.py +2345 -0
- package/core/memory_efficient_text_widget.py +712 -0
- package/core/migration_manager.py +915 -0
- package/core/migration_test_suite.py +1086 -0
- package/core/migration_validator.py +1144 -0
- package/core/optimized_find_replace.py +715 -0
- package/core/optimized_pattern_engine.py +424 -0
- package/core/optimized_search_highlighter.py +553 -0
- package/core/performance_monitor.py +675 -0
- package/core/persistence_manager.py +713 -0
- package/core/progressive_stats_calculator.py +632 -0
- package/core/regex_pattern_cache.py +530 -0
- package/core/regex_pattern_library.py +351 -0
- package/core/search_operation_manager.py +435 -0
- package/core/settings_defaults_registry.py +1087 -0
- package/core/settings_integrity_validator.py +1112 -0
- package/core/settings_serializer.py +558 -0
- package/core/settings_validator.py +1824 -0
- package/core/smart_stats_calculator.py +710 -0
- package/core/statistics_update_manager.py +619 -0
- package/core/stats_config_manager.py +858 -0
- package/core/streaming_text_handler.py +723 -0
- package/core/task_scheduler.py +596 -0
- package/core/update_pattern_library.py +169 -0
- package/core/visibility_monitor.py +596 -0
- package/core/widget_cache.py +498 -0
- package/mcp.json +61 -0
- package/package.json +57 -0
- package/pomera.py +7483 -0
- package/pomera_mcp_server.py +144 -0
- package/tools/__init__.py +5 -0
- package/tools/__pycache__/__init__.cpython-313.pyc +0 -0
- package/tools/__pycache__/ai_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/ascii_art_generator.cpython-313.pyc +0 -0
- package/tools/__pycache__/base64_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/base_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/case_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/column_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/cron_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_history.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_processor.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_settings.cpython-313.pyc +0 -0
- package/tools/__pycache__/curl_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/diff_viewer.cpython-313.pyc +0 -0
- package/tools/__pycache__/email_extraction_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/email_header_analyzer.cpython-313.pyc +0 -0
- package/tools/__pycache__/extraction_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/find_replace.cpython-313.pyc +0 -0
- package/tools/__pycache__/folder_file_reporter.cpython-313.pyc +0 -0
- package/tools/__pycache__/folder_file_reporter_adapter.cpython-313.pyc +0 -0
- package/tools/__pycache__/generator_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/hash_generator.cpython-313.pyc +0 -0
- package/tools/__pycache__/html_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/huggingface_helper.cpython-313.pyc +0 -0
- package/tools/__pycache__/jsonxml_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/line_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/list_comparator.cpython-313.pyc +0 -0
- package/tools/__pycache__/markdown_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/mcp_widget.cpython-313.pyc +0 -0
- package/tools/__pycache__/notes_widget.cpython-313.pyc +0 -0
- package/tools/__pycache__/number_base_converter.cpython-313.pyc +0 -0
- package/tools/__pycache__/regex_extractor.cpython-313.pyc +0 -0
- package/tools/__pycache__/slug_generator.cpython-313.pyc +0 -0
- package/tools/__pycache__/sorter_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/string_escape_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/text_statistics_tool.cpython-313.pyc +0 -0
- package/tools/__pycache__/text_wrapper.cpython-313.pyc +0 -0
- package/tools/__pycache__/timestamp_converter.cpython-313.pyc +0 -0
- package/tools/__pycache__/tool_loader.cpython-313.pyc +0 -0
- package/tools/__pycache__/translator_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/url_link_extractor.cpython-313.pyc +0 -0
- package/tools/__pycache__/url_parser.cpython-313.pyc +0 -0
- package/tools/__pycache__/whitespace_tools.cpython-313.pyc +0 -0
- package/tools/__pycache__/word_frequency_counter.cpython-313.pyc +0 -0
- package/tools/ai_tools.py +2892 -0
- package/tools/ascii_art_generator.py +353 -0
- package/tools/base64_tools.py +184 -0
- package/tools/base_tool.py +511 -0
- package/tools/case_tool.py +309 -0
- package/tools/column_tools.py +396 -0
- package/tools/cron_tool.py +885 -0
- package/tools/curl_history.py +601 -0
- package/tools/curl_processor.py +1208 -0
- package/tools/curl_settings.py +503 -0
- package/tools/curl_tool.py +5467 -0
- package/tools/diff_viewer.py +1072 -0
- package/tools/email_extraction_tool.py +249 -0
- package/tools/email_header_analyzer.py +426 -0
- package/tools/extraction_tools.py +250 -0
- package/tools/find_replace.py +1751 -0
- package/tools/folder_file_reporter.py +1463 -0
- package/tools/folder_file_reporter_adapter.py +480 -0
- package/tools/generator_tools.py +1217 -0
- package/tools/hash_generator.py +256 -0
- package/tools/html_tool.py +657 -0
- package/tools/huggingface_helper.py +449 -0
- package/tools/jsonxml_tool.py +730 -0
- package/tools/line_tools.py +419 -0
- package/tools/list_comparator.py +720 -0
- package/tools/markdown_tools.py +562 -0
- package/tools/mcp_widget.py +1417 -0
- package/tools/notes_widget.py +973 -0
- package/tools/number_base_converter.py +373 -0
- package/tools/regex_extractor.py +572 -0
- package/tools/slug_generator.py +311 -0
- package/tools/sorter_tools.py +459 -0
- package/tools/string_escape_tool.py +393 -0
- package/tools/text_statistics_tool.py +366 -0
- package/tools/text_wrapper.py +431 -0
- package/tools/timestamp_converter.py +422 -0
- package/tools/tool_loader.py +710 -0
- package/tools/translator_tools.py +523 -0
- package/tools/url_link_extractor.py +262 -0
- package/tools/url_parser.py +205 -0
- package/tools/whitespace_tools.py +356 -0
- package/tools/word_frequency_counter.py +147 -0
package/tools/huggingface_helper.py
@@ -0,0 +1,449 @@
"""Helper functions for HuggingFace model integration."""
import json
from typing import Dict, Any, Optional, Union, List
from huggingface_hub import InferenceClient, model_info
from huggingface_hub.utils import HfHubHTTPError

def process_huggingface_request(api_key: str, prompt: str, settings: Dict[str, Any],
                                update_callback, logger) -> None:
    """Process HuggingFace AI request with proper task handling.

    Args:
        api_key: HuggingFace API key
        prompt: User's input prompt
        settings: Dictionary containing model settings
        update_callback: Function to call with the result or error message
        logger: Logger instance for logging
    """
    try:
        # Add timeout configuration for better reliability (default 60 seconds)
        timeout = int(settings.get("timeout", 60))
        client = InferenceClient(token=api_key, timeout=timeout)
        model_name = settings.get("MODEL", "")

        if not model_name:
            update_callback("Error: No model specified in settings.")
            return

        # Detect supported tasks from HuggingFace API
        logger.info(f"Detecting supported tasks for model: {model_name}")
        supported_tasks = get_model_supported_tasks(model_name, api_key, logger)

        if not supported_tasks:
            # Try to determine if this is an API key issue or model name issue
            logger.warning(f"Could not determine supported tasks for model '{model_name}'")
            # Don't return here - let the routing logic handle the fallback
        else:
            logger.info(f"Model '{model_name}' supports tasks: {supported_tasks}")

        # Route to appropriate handler based on supported tasks
        if "text-classification" in supported_tasks:
            handle_text_classification(client, prompt, model_name, update_callback, logger)
        elif any(task in supported_tasks for task in ["text-generation", "conversational"]):
            handle_chat_completion(client, prompt, model_name, settings, update_callback, logger)
        else:
            # Enhanced model type detection
            is_base_model = detect_base_model(model_name, logger)
            is_chat_model = detect_chat_model(model_name, logger)

            if is_chat_model and not is_base_model:
                # Definitely a chat model, try chat completion first
                logger.info(f"Detected chat model '{model_name}', trying chat completion")
                try:
                    handle_chat_completion(client, prompt, model_name, settings, update_callback, logger)
                    return
                except Exception as chat_error:
                    logger.warning(f"Chat completion failed: {chat_error}")

                # Try text generation as fallback
                try:
                    logger.info("Attempting text generation as fallback")
                    handle_text_generation(client, prompt, model_name, settings, update_callback, logger)
                    return
                except Exception as text_gen_fallback_error:
                    logger.warning(f"Text generation fallback failed: {text_gen_fallback_error}")

            elif is_base_model or not supported_tasks:
                # Base model or no tasks detected, try text generation first
                logger.info(f"Detected base model or no tasks found for '{model_name}', trying text generation")
                try:
                    handle_text_generation(client, prompt, model_name, settings, update_callback, logger)
                    return
                except Exception as text_gen_error:
                    logger.warning(f"Text generation failed: {text_gen_error}")

                # Try chat completion as final fallback
                try:
                    logger.info("Attempting chat completion as final fallback")
                    handle_chat_completion(client, prompt, model_name, settings, update_callback, logger)
                    return
                except Exception as chat_fallback_error:
                    logger.warning(f"Chat completion fallback failed: {chat_fallback_error}")

            else:
                # Unknown model type, try both approaches
                logger.info(f"Unknown model type for '{model_name}', trying both approaches")

                # Try text generation first (more common for base models)
                try:
                    handle_text_generation(client, prompt, model_name, settings, update_callback, logger)
                    return
                except Exception as text_gen_error:
                    logger.warning(f"Text generation failed: {text_gen_error}")

                # Try chat completion as fallback
                try:
                    logger.info("Attempting chat completion as fallback")
                    handle_chat_completion(client, prompt, model_name, settings, update_callback, logger)
                    return
                except Exception as chat_fallback_error:
                    logger.warning(f"Chat completion fallback failed: {chat_fallback_error}")

            # All methods failed
            error_msg = f"Model '{model_name}' is not supported by HuggingFace Inference API.\n\n"

            if "isn't deployed by any Inference Provider" in str(supported_tasks):
                error_msg += "This model is not deployed by any Inference Provider.\n\n"

            if supported_tasks:
                error_msg += f"Detected tasks: {', '.join(supported_tasks)}\n\n"
            else:
                error_msg += "Could not determine supported tasks.\n\n"

            error_msg += "Solutions:\n"
            error_msg += "1. Try a chat-optimized version (e.g., add '-chat' to model name)\n"
            error_msg += "2. Use a model that's deployed on HuggingFace Inference API\n"
            error_msg += "3. Verify your HuggingFace API key is valid\n"
            error_msg += "4. Check the model page for inference provider availability"

            logger.warning(error_msg)
            update_callback(error_msg)

    except HfHubHTTPError as e:
        error_msg = format_hf_http_error(e, settings.get("MODEL"))
        logger.error(error_msg, exc_info=True)
        update_callback(error_msg)
    except Exception as e:
        error_msg = format_generic_error(e, settings.get("MODEL"))
        logger.error(error_msg, exc_info=True)
        update_callback(error_msg)

def detect_base_model(model_name: str, logger) -> bool:
    """Detect if a model is a base model (not fine-tuned for chat)."""
    model_lower = model_name.lower()

    # Patterns that indicate base models
    base_patterns = [
        '-hf',          # HuggingFace format indicator
        'base',         # Explicitly named base models
        'pretrained',   # Pre-trained models
        'foundation',   # Foundation models
        'raw',          # Raw/untuned models
        'original',     # Original models
    ]

    # Patterns that indicate NOT base models (chat/instruct models)
    non_base_patterns = [
        'chat',
        'instruct',
        'assistant',
        'conversation',
        'dialogue',
        'it',    # Instruction tuned
        'sft',   # Supervised fine-tuned
        'dpo',   # Direct preference optimization
        'rlhf',  # Reinforcement learning from human feedback
    ]

    # Check for non-base patterns first (these override base patterns)
    has_non_base = any(pattern in model_lower for pattern in non_base_patterns)
    if has_non_base:
        logger.debug(f"Model '{model_name}' has non-base patterns, not a base model")
        return False

    # Check for base patterns
    has_base = any(pattern in model_lower for pattern in base_patterns)
    if has_base:
        logger.debug(f"Model '{model_name}' has base patterns, likely a base model")
        return True

    # Additional heuristics based on model naming conventions
    # Models without specific suffixes are often base models
    if not any(suffix in model_lower for suffix in ['-chat', '-instruct', '-it', '-sft']):
        # Check if it's a well-known base model pattern
        if any(pattern in model_lower for pattern in ['llama', 'mistral', 'qwen', 'phi']):
            # These are often base models unless explicitly marked otherwise
            logger.debug(f"Model '{model_name}' appears to be a base model based on naming convention")
            return True

    logger.debug(f"Model '{model_name}' does not appear to be a base model")
    return False

def detect_chat_model(model_name: str, logger) -> bool:
    """Detect if a model is specifically designed for chat/conversation."""
    model_lower = model_name.lower()

    # Patterns that strongly indicate chat models
    chat_patterns = [
        'chat',
        'instruct',
        'assistant',
        'conversation',
        'dialogue',
        'dialog',    # Alternative spelling
        'it',        # Instruction tuned
        'sft',       # Supervised fine-tuned
        'dpo',       # Direct preference optimization
        'rlhf',      # Reinforcement learning from human feedback
        'alpaca',    # Alpaca models are instruction-tuned
        'vicuna',    # Vicuna models are chat-tuned
        'wizard',    # WizardLM models are instruction-tuned
        'dialogpt',  # DialoGPT models are for dialogue
    ]

    has_chat_pattern = any(pattern in model_lower for pattern in chat_patterns)
    if has_chat_pattern:
        logger.debug(f"Model '{model_name}' has chat patterns, likely a chat model")
        return True

    logger.debug(f"Model '{model_name}' does not appear to be a chat model")
    return False

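# Illustrative expectations for the two heuristics above (matching is plain
# substring search, so short patterns such as 'it' can also fire incidentally,
# e.g. inside 'stabilityai'):
#   detect_base_model("meta-llama/Llama-2-7b-hf", logger)           -> True  ('-hf')
#   detect_base_model("meta-llama/Llama-2-7b-chat-hf", logger)      -> False ('chat' overrides)
#   detect_chat_model("mistralai/Mistral-7B-Instruct-v0.2", logger) -> True  ('instruct')
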
def get_model_supported_tasks(model_name: str, api_key: str, logger) -> List[str]:
    """Query HuggingFace API to get the supported tasks for a model.

    Args:
        model_name: Name of the HuggingFace model
        api_key: HuggingFace API key
        logger: Logger instance

    Returns:
        List of supported task names (e.g., ['text-classification', 'text-generation'])
    """
    try:
        info = model_info(model_name, token=api_key)

        # Check if this is a LoRA adapter model
        is_lora_adapter = False
        if hasattr(info, 'tags') and info.tags:
            is_lora_adapter = any('lora' in tag.lower() or 'peft' in tag.lower() for tag in info.tags)

        # Check model name patterns for LoRA adapters
        if not is_lora_adapter:
            lora_patterns = ['lora', 'peft', 'adapter', 'fingpt']
            is_lora_adapter = any(pattern in model_name.lower() for pattern in lora_patterns)

        if is_lora_adapter:
            logger.info(f"Detected LoRA adapter model: {model_name}")
            # LoRA adapters typically don't have inference providers
            # Return empty list to trigger special handling
            return []

        # Get pipeline_tag (primary task)
        tasks = []
        if hasattr(info, 'pipeline_tag') and info.pipeline_tag:
            tasks.append(info.pipeline_tag)

        # Also check tags for additional supported tasks
        if hasattr(info, 'tags') and info.tags:
            task_tags = [tag for tag in info.tags if any(
                keyword in tag for keyword in
                ['classification', 'generation', 'conversational', 'sentiment']
            )]
            tasks.extend(task_tags)

        # Remove duplicates while preserving order
        seen = set()
        unique_tasks = []
        for task in tasks:
            if task not in seen:
                seen.add(task)
                unique_tasks.append(task)

        return unique_tasks

    except Exception as e:
        logger.warning(f"Could not fetch model info for '{model_name}': {e}")
        return []

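# For example, a sentiment model such as
# "distilbert-base-uncased-finetuned-sst-2-english" typically reports
# pipeline_tag="text-classification", so this function would return
# ['text-classification'] plus any matching task tags.
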
def handle_text_classification(client: InferenceClient, prompt: str, model_name: str,
                               update_callback, logger) -> None:
    """Handle text classification models (e.g., sentiment analysis, categorization)."""
    try:
        logger.info(f"Running text classification on model: {model_name}")
        result = client.text_classification(prompt, model=model_name)

        # Format the classification results
        if hasattr(result, 'label') and hasattr(result, 'score'):
            response_text = f"Classification Result:\n\nLabel: {result.label}\nConfidence: {result.score:.4f} ({result.score*100:.2f}%)"
        elif isinstance(result, list) and len(result) > 0:
            response_text = "Classification Results:\n\n"
            for i, item in enumerate(result, 1):
                if hasattr(item, 'label') and hasattr(item, 'score'):
                    response_text += f"{i}. {item.label}: {item.score:.4f} ({item.score*100:.2f}%)\n"
                else:
                    response_text += f"{i}. {item}\n"
        else:
            response_text = f"Classification Result:\n\n{str(result)}"

        logger.info("Text classification completed successfully")
        update_callback(response_text)

    except Exception as e:
        error_msg = f"Text Classification Error: {str(e)}\n\n"
        error_msg += f"Failed to classify text using model '{model_name}'.\n"
        error_msg += "Please verify the model supports text classification and try again."
        logger.error(error_msg, exc_info=True)
        update_callback(error_msg)

def handle_text_generation(client: InferenceClient, prompt: str, model_name: str,
                           settings: Dict[str, Any], update_callback, logger) -> None:
    """Handle text generation models (base models without chat formatting)."""
    try:
        logger.info(f"Running text generation on model: {model_name}")

        # Build parameters for text generation
        params = {"model": model_name}

        # Add supported parameters
        for param_name, param_type in [
            ("max_new_tokens", int),
            ("temperature", float),
            ("top_p", float),
            ("top_k", int),
            ("repetition_penalty", float),
            ("do_sample", bool)
        ]:
            if param_name in settings:
                try:
                    if param_type == bool:
                        # Handle boolean conversion for do_sample
                        if isinstance(settings[param_name], str):
                            params[param_name] = settings[param_name].lower() in ('true', '1', 'yes', 'on')
                        else:
                            params[param_name] = bool(settings[param_name])
                    else:
                        params[param_name] = param_type(settings[param_name])
                except (ValueError, TypeError):
                    logger.warning(f"Could not convert {param_name} value '{settings[param_name]}' to {param_type}")

        # Handle stop sequences
        stop_seq_str = str(settings.get("stop_sequences", '')).strip()
        if stop_seq_str:
            params["stop_sequences"] = [s.strip() for s in stop_seq_str.split(',')]

        # Set default parameters if not provided
        if "max_new_tokens" not in params:
            params["max_new_tokens"] = 512
        if "temperature" not in params:
            params["temperature"] = 0.7
        if "do_sample" not in params:
            params["do_sample"] = True

        logger.debug(f"HuggingFace text generation payload: {json.dumps(params, indent=2)}")

        # Call text generation
        response = client.text_generation(prompt, **params)

        # Handle response
        if hasattr(response, 'generated_text'):
            result_text = response.generated_text
        elif isinstance(response, str):
            result_text = response
        else:
            result_text = str(response)

        # Clean up the response (remove the original prompt if it's included)
        if result_text.startswith(prompt):
            result_text = result_text[len(prompt):].strip()

        logger.info("Text generation completed successfully")
        update_callback(result_text)

    except Exception as e:
        error_msg = f"Text Generation Error: {str(e)}\n\n"
        error_msg += f"Failed to generate text using model '{model_name}'.\n"

        if "doesn't support task 'text-generation'" in str(e):
            error_msg += "\nThis model doesn't support text generation. It may be a specialized model (e.g., classification, embedding).\n"
            error_msg += "Try using a different model or check the model's documentation for supported tasks."
        elif "isn't deployed by any Inference Provider" in str(e):
            error_msg += "\nThis model is not deployed by any Inference Provider on HuggingFace.\n"
            error_msg += "Solutions:\n"
            error_msg += "1. Try a similar model that's available on the Inference API\n"
            error_msg += "2. Use HuggingFace Spaces or deploy the model yourself\n"
            error_msg += "3. Check the model page for inference provider availability"
        else:
            error_msg += "Please verify the model supports text generation and try again."

        logger.error(error_msg, exc_info=True)
        update_callback(error_msg)

def handle_chat_completion(client: InferenceClient, prompt: str, model_name: str,
                           settings: Dict[str, Any], update_callback, logger) -> None:
    """Handle chat completion models."""
    try:
        messages = []
        system_prompt = settings.get("system_prompt", "").strip()
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        params = {"messages": messages, "model": model_name}

        # Add supported parameters
        for param_name, param_type in [
            ("max_tokens", int),
            ("seed", int),
            ("temperature", float),
            ("top_p", float)
        ]:
            if param_name in settings:
                try:
                    params[param_name] = param_type(settings[param_name])
                except (ValueError, TypeError):
                    pass

        stop_seq_str = str(settings.get("stop_sequences", '')).strip()
        if stop_seq_str:
            params["stop"] = [s.strip() for s in stop_seq_str.split(',')]

        logger.debug(f"HuggingFace chat completion payload: {json.dumps(params, indent=2)}")
        response_obj = client.chat_completion(**params)
        update_callback(response_obj.choices[0].message.content)

    except Exception as e:
        error_msg = f"HuggingFace Chat Error: {str(e)}\n\n"
        error_msg += "This model may not support chat completion. Please try a different model or check the model's documentation."
        if "doesn't support task 'conversational'" in str(e):
            error_msg += "\n\nNote: This appears to be a text classification model, not a chat model. It's designed to analyze text and return categories/sentiment, not generate responses."
        logger.error(error_msg, exc_info=True)
        update_callback(error_msg)

def format_hf_http_error(error: HfHubHTTPError, model_name: str = "") -> str:
    """Format HuggingFace HTTP error messages."""
    error_msg = f"HuggingFace API Error: {error.response.status_code} - {error.response.reason}\n\n{error.response.text}"

    if error.response.status_code == 401:
        error_msg += "\n\nThis means your API token is invalid or expired. Please check your API key."
    elif error.response.status_code == 403:
        error_msg += f"\n\nThis is a 'gated model'. You MUST accept the terms on the model page:\nhttps://huggingface.co/{model_name}"
    elif error.response.status_code == 404:
        error_msg += "\n\nThe model was not found. Please check the model name and try again."

    return error_msg


def format_generic_error(error: Exception, model_name: str = "") -> str:
    """Format generic error messages."""
    error_msg = f"HuggingFace Error: {str(error)}\n\n"
    error_msg += "Please check that the model supports the task you're trying to perform.\n"
    error_msg += f"Model: {model_name or 'Not specified'}\n"
    error_msg += "\nCommon issues:\n"
    error_msg += "1. The model may not support chat completion\n"
    error_msg += "2. The model may require a different task type (e.g., text-classification)\n"
    error_msg += "3. The model may be gated - check if you need to accept terms at https://huggingface.co/models"

    return error_msg
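
For reference, a minimal sketch of how this helper might be driven. The settings keys ("MODEL", "timeout", "system_prompt", "max_tokens", "temperature") come from the code above; the logger name, callback, model choice, and token placeholder are illustrative assumptions, and a real HuggingFace API token is required:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pomera.hf")  # hypothetical logger name

settings = {
    "MODEL": "mistralai/Mistral-7B-Instruct-v0.2",  # any Inference-API-deployed model
    "timeout": 60,
    "system_prompt": "You are a helpful assistant.",
    "max_tokens": 256,
    "temperature": 0.7,
}

def on_result(text: str) -> None:
    # The helper reports both successes and formatted errors through this callback
    print(text)

process_huggingface_request("hf_...", "Summarize MCP in one sentence.", settings, on_result, logger)

Note the design choice: process_huggingface_request does not raise on failure; every error path is formatted into a human-readable message and delivered through the same callback as a successful result.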