superqode-0.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- superqode/__init__.py +33 -0
- superqode/acp/__init__.py +23 -0
- superqode/acp/client.py +913 -0
- superqode/acp/permission_screen.py +457 -0
- superqode/acp/types.py +480 -0
- superqode/acp_discovery.py +856 -0
- superqode/agent/__init__.py +22 -0
- superqode/agent/edit_strategies.py +334 -0
- superqode/agent/loop.py +892 -0
- superqode/agent/qe_report_templates.py +39 -0
- superqode/agent/system_prompts.py +353 -0
- superqode/agent_output.py +721 -0
- superqode/agent_stream.py +953 -0
- superqode/agents/__init__.py +59 -0
- superqode/agents/acp_registry.py +305 -0
- superqode/agents/client.py +249 -0
- superqode/agents/data/augmentcode.com.toml +51 -0
- superqode/agents/data/cagent.dev.toml +51 -0
- superqode/agents/data/claude.com.toml +60 -0
- superqode/agents/data/codeassistant.dev.toml +51 -0
- superqode/agents/data/codex.openai.com.toml +57 -0
- superqode/agents/data/fastagent.ai.toml +66 -0
- superqode/agents/data/geminicli.com.toml +77 -0
- superqode/agents/data/goose.block.xyz.toml +54 -0
- superqode/agents/data/junie.jetbrains.com.toml +56 -0
- superqode/agents/data/kimi.moonshot.cn.toml +57 -0
- superqode/agents/data/llmlingagent.dev.toml +51 -0
- superqode/agents/data/molt.bot.toml +49 -0
- superqode/agents/data/opencode.ai.toml +60 -0
- superqode/agents/data/stakpak.dev.toml +51 -0
- superqode/agents/data/vtcode.dev.toml +51 -0
- superqode/agents/discovery.py +266 -0
- superqode/agents/messaging.py +160 -0
- superqode/agents/persona.py +166 -0
- superqode/agents/registry.py +421 -0
- superqode/agents/schema.py +72 -0
- superqode/agents/unified.py +367 -0
- superqode/app/__init__.py +111 -0
- superqode/app/constants.py +314 -0
- superqode/app/css.py +366 -0
- superqode/app/models.py +118 -0
- superqode/app/suggester.py +125 -0
- superqode/app/widgets.py +1591 -0
- superqode/app_enhanced.py +399 -0
- superqode/app_main.py +17187 -0
- superqode/approval.py +312 -0
- superqode/atomic.py +296 -0
- superqode/commands/__init__.py +1 -0
- superqode/commands/acp.py +965 -0
- superqode/commands/agents.py +180 -0
- superqode/commands/auth.py +278 -0
- superqode/commands/config.py +374 -0
- superqode/commands/init.py +826 -0
- superqode/commands/providers.py +819 -0
- superqode/commands/qe.py +1145 -0
- superqode/commands/roles.py +380 -0
- superqode/commands/serve.py +172 -0
- superqode/commands/suggestions.py +127 -0
- superqode/commands/superqe.py +460 -0
- superqode/config/__init__.py +51 -0
- superqode/config/loader.py +812 -0
- superqode/config/schema.py +498 -0
- superqode/core/__init__.py +111 -0
- superqode/core/roles.py +281 -0
- superqode/danger.py +386 -0
- superqode/data/superqode-template.yaml +1522 -0
- superqode/design_system.py +1080 -0
- superqode/dialogs/__init__.py +6 -0
- superqode/dialogs/base.py +39 -0
- superqode/dialogs/model.py +130 -0
- superqode/dialogs/provider.py +870 -0
- superqode/diff_view.py +919 -0
- superqode/enterprise.py +21 -0
- superqode/evaluation/__init__.py +25 -0
- superqode/evaluation/adapters.py +93 -0
- superqode/evaluation/behaviors.py +89 -0
- superqode/evaluation/engine.py +209 -0
- superqode/evaluation/scenarios.py +96 -0
- superqode/execution/__init__.py +36 -0
- superqode/execution/linter.py +538 -0
- superqode/execution/modes.py +347 -0
- superqode/execution/resolver.py +283 -0
- superqode/execution/runner.py +642 -0
- superqode/file_explorer.py +811 -0
- superqode/file_viewer.py +471 -0
- superqode/flash.py +183 -0
- superqode/guidance/__init__.py +58 -0
- superqode/guidance/config.py +203 -0
- superqode/guidance/prompts.py +71 -0
- superqode/harness/__init__.py +54 -0
- superqode/harness/accelerator.py +291 -0
- superqode/harness/config.py +319 -0
- superqode/harness/validator.py +147 -0
- superqode/history.py +279 -0
- superqode/integrations/superopt_runner.py +124 -0
- superqode/logging/__init__.py +49 -0
- superqode/logging/adapters.py +219 -0
- superqode/logging/formatter.py +923 -0
- superqode/logging/integration.py +341 -0
- superqode/logging/sinks.py +170 -0
- superqode/logging/unified_log.py +417 -0
- superqode/lsp/__init__.py +26 -0
- superqode/lsp/client.py +544 -0
- superqode/main.py +1069 -0
- superqode/mcp/__init__.py +89 -0
- superqode/mcp/auth_storage.py +380 -0
- superqode/mcp/client.py +1236 -0
- superqode/mcp/config.py +319 -0
- superqode/mcp/integration.py +337 -0
- superqode/mcp/oauth.py +436 -0
- superqode/mcp/oauth_callback.py +385 -0
- superqode/mcp/types.py +290 -0
- superqode/memory/__init__.py +31 -0
- superqode/memory/feedback.py +342 -0
- superqode/memory/store.py +522 -0
- superqode/notifications.py +369 -0
- superqode/optimization/__init__.py +5 -0
- superqode/optimization/config.py +33 -0
- superqode/permissions/__init__.py +25 -0
- superqode/permissions/rules.py +488 -0
- superqode/plan.py +323 -0
- superqode/providers/__init__.py +33 -0
- superqode/providers/gateway/__init__.py +165 -0
- superqode/providers/gateway/base.py +228 -0
- superqode/providers/gateway/litellm_gateway.py +1170 -0
- superqode/providers/gateway/openresponses_gateway.py +436 -0
- superqode/providers/health.py +297 -0
- superqode/providers/huggingface/__init__.py +74 -0
- superqode/providers/huggingface/downloader.py +472 -0
- superqode/providers/huggingface/endpoints.py +442 -0
- superqode/providers/huggingface/hub.py +531 -0
- superqode/providers/huggingface/inference.py +394 -0
- superqode/providers/huggingface/transformers_runner.py +516 -0
- superqode/providers/local/__init__.py +100 -0
- superqode/providers/local/base.py +438 -0
- superqode/providers/local/discovery.py +418 -0
- superqode/providers/local/lmstudio.py +256 -0
- superqode/providers/local/mlx.py +457 -0
- superqode/providers/local/ollama.py +486 -0
- superqode/providers/local/sglang.py +268 -0
- superqode/providers/local/tgi.py +260 -0
- superqode/providers/local/tool_support.py +477 -0
- superqode/providers/local/vllm.py +258 -0
- superqode/providers/manager.py +1338 -0
- superqode/providers/models.py +1016 -0
- superqode/providers/models_dev.py +578 -0
- superqode/providers/openresponses/__init__.py +87 -0
- superqode/providers/openresponses/converters/__init__.py +17 -0
- superqode/providers/openresponses/converters/messages.py +343 -0
- superqode/providers/openresponses/converters/tools.py +268 -0
- superqode/providers/openresponses/schema/__init__.py +56 -0
- superqode/providers/openresponses/schema/models.py +585 -0
- superqode/providers/openresponses/streaming/__init__.py +5 -0
- superqode/providers/openresponses/streaming/parser.py +338 -0
- superqode/providers/openresponses/tools/__init__.py +21 -0
- superqode/providers/openresponses/tools/apply_patch.py +352 -0
- superqode/providers/openresponses/tools/code_interpreter.py +290 -0
- superqode/providers/openresponses/tools/file_search.py +333 -0
- superqode/providers/openresponses/tools/mcp_adapter.py +252 -0
- superqode/providers/registry.py +716 -0
- superqode/providers/usage.py +332 -0
- superqode/pure_mode.py +384 -0
- superqode/qr/__init__.py +23 -0
- superqode/qr/dashboard.py +781 -0
- superqode/qr/generator.py +1018 -0
- superqode/qr/templates.py +135 -0
- superqode/safety/__init__.py +41 -0
- superqode/safety/sandbox.py +413 -0
- superqode/safety/warnings.py +256 -0
- superqode/server/__init__.py +33 -0
- superqode/server/lsp_server.py +775 -0
- superqode/server/web.py +250 -0
- superqode/session/__init__.py +25 -0
- superqode/session/persistence.py +580 -0
- superqode/session/sharing.py +477 -0
- superqode/session.py +475 -0
- superqode/sidebar.py +2991 -0
- superqode/stream_view.py +648 -0
- superqode/styles/__init__.py +3 -0
- superqode/superqe/__init__.py +184 -0
- superqode/superqe/acp_runner.py +1064 -0
- superqode/superqe/constitution/__init__.py +62 -0
- superqode/superqe/constitution/evaluator.py +308 -0
- superqode/superqe/constitution/loader.py +432 -0
- superqode/superqe/constitution/schema.py +250 -0
- superqode/superqe/events.py +591 -0
- superqode/superqe/frameworks/__init__.py +65 -0
- superqode/superqe/frameworks/base.py +234 -0
- superqode/superqe/frameworks/e2e.py +263 -0
- superqode/superqe/frameworks/executor.py +237 -0
- superqode/superqe/frameworks/javascript.py +409 -0
- superqode/superqe/frameworks/python.py +373 -0
- superqode/superqe/frameworks/registry.py +92 -0
- superqode/superqe/mcp_tools/__init__.py +47 -0
- superqode/superqe/mcp_tools/core_tools.py +418 -0
- superqode/superqe/mcp_tools/registry.py +230 -0
- superqode/superqe/mcp_tools/testing_tools.py +167 -0
- superqode/superqe/noise.py +89 -0
- superqode/superqe/orchestrator.py +778 -0
- superqode/superqe/roles.py +609 -0
- superqode/superqe/session.py +713 -0
- superqode/superqe/skills/__init__.py +57 -0
- superqode/superqe/skills/base.py +106 -0
- superqode/superqe/skills/core_skills.py +899 -0
- superqode/superqe/skills/registry.py +90 -0
- superqode/superqe/verifier.py +101 -0
- superqode/superqe_cli.py +76 -0
- superqode/tool_call.py +358 -0
- superqode/tools/__init__.py +93 -0
- superqode/tools/agent_tools.py +496 -0
- superqode/tools/base.py +324 -0
- superqode/tools/batch_tool.py +133 -0
- superqode/tools/diagnostics.py +311 -0
- superqode/tools/edit_tools.py +653 -0
- superqode/tools/enhanced_base.py +515 -0
- superqode/tools/file_tools.py +269 -0
- superqode/tools/file_tracking.py +45 -0
- superqode/tools/lsp_tools.py +610 -0
- superqode/tools/network_tools.py +350 -0
- superqode/tools/permissions.py +400 -0
- superqode/tools/question_tool.py +324 -0
- superqode/tools/search_tools.py +598 -0
- superqode/tools/shell_tools.py +259 -0
- superqode/tools/todo_tools.py +121 -0
- superqode/tools/validation.py +80 -0
- superqode/tools/web_tools.py +639 -0
- superqode/tui.py +1152 -0
- superqode/tui_integration.py +875 -0
- superqode/tui_widgets/__init__.py +27 -0
- superqode/tui_widgets/widgets/__init__.py +18 -0
- superqode/tui_widgets/widgets/progress.py +185 -0
- superqode/tui_widgets/widgets/tool_display.py +188 -0
- superqode/undo_manager.py +574 -0
- superqode/utils/__init__.py +5 -0
- superqode/utils/error_handling.py +323 -0
- superqode/utils/fuzzy.py +257 -0
- superqode/widgets/__init__.py +477 -0
- superqode/widgets/agent_collab.py +390 -0
- superqode/widgets/agent_store.py +936 -0
- superqode/widgets/agent_switcher.py +395 -0
- superqode/widgets/animation_manager.py +284 -0
- superqode/widgets/code_context.py +356 -0
- superqode/widgets/command_palette.py +412 -0
- superqode/widgets/connection_status.py +537 -0
- superqode/widgets/conversation_history.py +470 -0
- superqode/widgets/diff_indicator.py +155 -0
- superqode/widgets/enhanced_status_bar.py +385 -0
- superqode/widgets/enhanced_toast.py +476 -0
- superqode/widgets/file_browser.py +809 -0
- superqode/widgets/file_reference.py +585 -0
- superqode/widgets/issue_timeline.py +340 -0
- superqode/widgets/leader_key.py +264 -0
- superqode/widgets/mode_switcher.py +445 -0
- superqode/widgets/model_picker.py +234 -0
- superqode/widgets/permission_preview.py +1205 -0
- superqode/widgets/prompt.py +358 -0
- superqode/widgets/provider_connect.py +725 -0
- superqode/widgets/pty_shell.py +587 -0
- superqode/widgets/qe_dashboard.py +321 -0
- superqode/widgets/resizable_sidebar.py +377 -0
- superqode/widgets/response_changes.py +218 -0
- superqode/widgets/response_display.py +528 -0
- superqode/widgets/rich_tool_display.py +613 -0
- superqode/widgets/sidebar_panels.py +1180 -0
- superqode/widgets/slash_complete.py +356 -0
- superqode/widgets/split_view.py +612 -0
- superqode/widgets/status_bar.py +273 -0
- superqode/widgets/superqode_display.py +786 -0
- superqode/widgets/thinking_display.py +815 -0
- superqode/widgets/throbber.py +87 -0
- superqode/widgets/toast.py +206 -0
- superqode/widgets/unified_output.py +1073 -0
- superqode/workspace/__init__.py +75 -0
- superqode/workspace/artifacts.py +472 -0
- superqode/workspace/coordinator.py +353 -0
- superqode/workspace/diff_tracker.py +429 -0
- superqode/workspace/git_guard.py +373 -0
- superqode/workspace/git_snapshot.py +526 -0
- superqode/workspace/manager.py +750 -0
- superqode/workspace/snapshot.py +357 -0
- superqode/workspace/watcher.py +535 -0
- superqode/workspace/worktree.py +440 -0
- superqode-0.1.5.dist-info/METADATA +204 -0
- superqode-0.1.5.dist-info/RECORD +288 -0
- superqode-0.1.5.dist-info/WHEEL +5 -0
- superqode-0.1.5.dist-info/entry_points.txt +3 -0
- superqode-0.1.5.dist-info/licenses/LICENSE +648 -0
- superqode-0.1.5.dist-info/top_level.txt +1 -0
superqode/providers/huggingface/inference.py

@@ -0,0 +1,394 @@
+"""HuggingFace Inference API client with streaming support.
+
+This module provides access to the HuggingFace Inference API (serverless)
+for text generation with any compatible model.
+"""
+
+import asyncio
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Any, AsyncIterator, Dict, List, Optional, Union
+from urllib.error import HTTPError, URLError
+from urllib.request import Request, urlopen
+
+
+# HuggingFace Inference API endpoints
+HF_INFERENCE_API = "https://api-inference.huggingface.co/models"
+HF_ROUTER_API = "https://router.huggingface.co/hf"  # New router for free inference
+
+
+@dataclass
+class InferenceResponse:
+    """Response from HF Inference API.
+
+    Attributes:
+        content: Generated text content
+        model: Model that generated the response
+        finish_reason: Why generation stopped
+        usage: Token usage information
+        tool_calls: Tool calls if any
+        error: Error message if failed
+    """
+
+    content: str = ""
+    model: str = ""
+    finish_reason: str = ""
+    usage: Dict[str, int] = field(default_factory=dict)
+    tool_calls: List[Dict] = field(default_factory=list)
+    error: str = ""
+
+
+# Recommended models for different use cases
+RECOMMENDED_MODELS = {
+    "general": [
+        "meta-llama/Llama-3.3-70B-Instruct",
+        "Qwen/Qwen2.5-72B-Instruct",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        "microsoft/Phi-3.5-mini-instruct",
+    ],
+    "coding": [
+        "Qwen/Qwen2.5-Coder-32B-Instruct",
+        "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+        "codellama/CodeLlama-34b-Instruct-hf",
+        "bigcode/starcoder2-15b-instruct-v0.1",
+    ],
+    "small": [
+        "microsoft/Phi-3.5-mini-instruct",
+        "google/gemma-2-2b-it",
+        "Qwen/Qwen2.5-3B-Instruct",
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    ],
+    "chat": [
+        "meta-llama/Llama-3.2-3B-Instruct",
+        "HuggingFaceH4/zephyr-7b-beta",
+        "openchat/openchat-3.5-0106",
+        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+    ],
+}
+
+
+class HFInferenceClient:
+    """HuggingFace Inference API client.
+
+    Provides access to HF's serverless inference API for text generation.
+    Supports both the free tier and Pro tier.
+
+    Environment:
+        HF_TOKEN: HuggingFace token for authentication (optional but recommended)
+        HF_INFERENCE_ENDPOINT: Custom inference endpoint (optional)
+    """
+
+    def __init__(
+        self, token: Optional[str] = None, endpoint: Optional[str] = None, use_router: bool = True
+    ):
+        """Initialize the Inference API client.
+
+        Args:
+            token: HF token. Falls back to HF_TOKEN env var.
+            endpoint: Custom inference endpoint. Falls back to HF_INFERENCE_ENDPOINT.
+            use_router: Use the new router API for better availability.
+        """
+        self._token = (
+            token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
+        )
+        self._custom_endpoint = endpoint or os.environ.get("HF_INFERENCE_ENDPOINT")
+        self._use_router = use_router
+
+    @property
+    def is_authenticated(self) -> bool:
+        """Check if we have authentication."""
+        return self._token is not None and len(self._token) > 0
+
+    def get_endpoint(self, model_id: str) -> str:
+        """Get the API endpoint for a model.
+
+        Args:
+            model_id: Model ID.
+
+        Returns:
+            Full API endpoint URL.
+        """
+        if self._custom_endpoint:
+            return f"{self._custom_endpoint}/{model_id}"
+
+        if self._use_router:
+            return f"{HF_ROUTER_API}/{model_id}/v1/chat/completions"
+
+        return f"{HF_INFERENCE_API}/{model_id}"
+
+    def _request(
+        self, endpoint: str, data: Dict[str, Any], timeout: float = 120.0
+    ) -> Dict[str, Any]:
+        """Make a request to the Inference API.
+
+        Args:
+            endpoint: Full API endpoint URL.
+            data: Request body.
+            timeout: Request timeout.
+
+        Returns:
+            JSON response.
+        """
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+
+        if self._token:
+            headers["Authorization"] = f"Bearer {self._token}"
+
+        body = json.dumps(data).encode("utf-8")
+        request = Request(endpoint, data=body, headers=headers, method="POST")
+
+        with urlopen(request, timeout=timeout) as response:
+            return json.loads(response.read().decode("utf-8"))
+
+    async def _async_request(
+        self, endpoint: str, data: Dict[str, Any], timeout: float = 120.0
+    ) -> Dict[str, Any]:
+        """Async wrapper for _request."""
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, lambda: self._request(endpoint, data, timeout))
+
+    async def chat(
+        self,
+        messages: List[Dict[str, str]],
+        model: str = "meta-llama/Llama-3.3-70B-Instruct",
+        max_tokens: int = 2048,
+        temperature: float = 0.7,
+        top_p: float = 0.9,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: Optional[str] = None,
+        stream: bool = False,
+    ) -> InferenceResponse:
+        """Send a chat completion request.
+
+        Args:
+            messages: Chat messages in OpenAI format.
+            model: Model ID to use.
+            max_tokens: Maximum tokens to generate.
+            temperature: Sampling temperature.
+            top_p: Nucleus sampling threshold.
+            tools: Tool definitions for function calling.
+            tool_choice: Tool choice mode ("auto", "none", "required").
+            stream: Whether to stream the response (not yet implemented).
+
+        Returns:
+            InferenceResponse with generated content.
+        """
+        endpoint = self.get_endpoint(model)
+
+        # Build request payload
+        payload: Dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "stream": False,  # Streaming handled separately
+        }
+
+        if tools:
+            payload["tools"] = tools
+
+        if tool_choice:
+            payload["tool_choice"] = tool_choice
+
+        try:
+            response = await self._async_request(endpoint, payload)
+            return self._parse_chat_response(response, model)
+
+        except HTTPError as e:
+            error_body = ""
+            try:
+                error_body = e.read().decode("utf-8")
+            except Exception:
+                pass
+
+            return InferenceResponse(model=model, error=f"HTTP {e.code}: {error_body or e.reason}")
+
+        except Exception as e:
+            return InferenceResponse(model=model, error=str(e))
+
+    async def generate(
+        self,
+        prompt: str,
+        model: str = "meta-llama/Llama-3.3-70B-Instruct",
+        max_tokens: int = 2048,
+        temperature: float = 0.7,
+        stop: Optional[List[str]] = None,
+    ) -> InferenceResponse:
+        """Send a text generation request (non-chat format).
+
+        This uses the older text generation API format for models
+        that don't support chat templates.
+
+        Args:
+            prompt: Text prompt.
+            model: Model ID.
+            max_tokens: Maximum tokens.
+            temperature: Sampling temperature.
+            stop: Stop sequences.
+
+        Returns:
+            InferenceResponse with generated text.
+        """
+        # Use direct inference API for non-chat models
+        endpoint = f"{HF_INFERENCE_API}/{model}"
+
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": max_tokens,
+                "temperature": temperature,
+                "return_full_text": False,
+            },
+        }
+
+        if stop:
+            payload["parameters"]["stop_sequences"] = stop
+
+        try:
+            response = await self._async_request(endpoint, payload)
+
+            # Parse text generation response
+            if isinstance(response, list) and len(response) > 0:
+                text = response[0].get("generated_text", "")
+                return InferenceResponse(content=text, model=model)
+
+            return InferenceResponse(model=model, error="Unexpected response format")
+
+        except Exception as e:
+            return InferenceResponse(model=model, error=str(e))
+
+    async def check_model_status(self, model: str) -> Dict[str, Any]:
+        """Check the status of a model on the Inference API.
+
+        Args:
+            model: Model ID.
+
+        Returns:
+            Dict with status information.
+        """
+        # Try a minimal request to check status
+        try:
+            response = await self.chat(
+                messages=[{"role": "user", "content": "Hi"}],
+                model=model,
+                max_tokens=1,
+            )
+
+            if response.error:
+                # Check for common error patterns
+                if "loading" in response.error.lower():
+                    return {
+                        "available": False,
+                        "loading": True,
+                        "error": "Model is loading",
+                    }
+                if "rate limit" in response.error.lower():
+                    return {
+                        "available": True,
+                        "rate_limited": True,
+                        "error": response.error,
+                    }
+                return {
+                    "available": False,
+                    "error": response.error,
+                }
+
+            return {
+                "available": True,
+                "loading": False,
+            }
+
+        except Exception as e:
+            return {
+                "available": False,
+                "error": str(e),
+            }
+
+    async def list_available_models(self) -> List[str]:
+        """Get list of recommended available models.
+
+        Returns:
+            List of model IDs known to work well with the Inference API.
+        """
+        # Return all recommended models
+        all_models = []
+        for category_models in RECOMMENDED_MODELS.values():
+            all_models.extend(category_models)
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique = []
+        for m in all_models:
+            if m not in seen:
+                seen.add(m)
+                unique.append(m)
+
+        return unique
+
+    def get_recommended_models(self, category: str = "general") -> List[str]:
+        """Get recommended models for a category.
+
+        Args:
+            category: Model category (general, coding, small, chat).
+
+        Returns:
+            List of recommended model IDs.
+        """
+        return RECOMMENDED_MODELS.get(category, RECOMMENDED_MODELS["general"])
+
+    def _parse_chat_response(self, response: Dict[str, Any], model: str) -> InferenceResponse:
+        """Parse a chat completion response."""
+        # Handle OpenAI-compatible format
+        choices = response.get("choices", [])
+
+        if not choices:
+            # Check for error
+            if "error" in response:
+                return InferenceResponse(
+                    model=model,
+                    error=response.get("error", {}).get("message", str(response["error"])),
+                )
+            return InferenceResponse(model=model, error="No response choices")
+
+        choice = choices[0]
+        message = choice.get("message", {})
+
+        content = message.get("content", "")
+        tool_calls = message.get("tool_calls", [])
+        finish_reason = choice.get("finish_reason", "")
+
+        # Parse usage
+        usage = response.get("usage", {})
+
+        return InferenceResponse(
+            content=content,
+            model=model,
+            finish_reason=finish_reason,
+            usage={
+                "prompt_tokens": usage.get("prompt_tokens", 0),
+                "completion_tokens": usage.get("completion_tokens", 0),
+                "total_tokens": usage.get("total_tokens", 0),
+            },
+            tool_calls=tool_calls,
+        )
+
+
+# Singleton instance
+_inference_client: Optional[HFInferenceClient] = None
+
+
+def get_hf_inference_client() -> HFInferenceClient:
+    """Get the global HF Inference API client instance.
+
+    Returns:
+        HFInferenceClient instance.
+    """
+    global _inference_client
+    if _inference_client is None:
+        _inference_client = HFInferenceClient()
+    return _inference_client
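For orientation, here is a minimal usage sketch based only on the code in the diff above (the `get_hf_inference_client()` singleton plus `chat()`). How superqode itself wires this client into its provider layer is not part of this diff, so the driver scaffolding here is an assumption:

```python
import asyncio

# Assumes the wheel above is installed; import path taken from the file listing.
from superqode.providers.huggingface.inference import get_hf_inference_client


async def main() -> None:
    # Singleton reads HF_TOKEN / HUGGING_FACE_HUB_TOKEN from the environment.
    client = get_hf_inference_client()

    response = await client.chat(
        messages=[{"role": "user", "content": "Summarize PEP 8 in one sentence."}],
        model="Qwen/Qwen2.5-Coder-32B-Instruct",  # any entry from RECOMMENDED_MODELS
        max_tokens=128,
    )

    # Errors are returned in-band on InferenceResponse rather than raised.
    if response.error:
        print(f"Request failed: {response.error}")
    else:
        print(response.content)
        print(response.usage)  # prompt_tokens / completion_tokens / total_tokens


asyncio.run(main())
```

Note the split visible in the source: `chat()` resolves its URL via `get_endpoint()`, which defaults to the OpenAI-compatible router (`router.huggingface.co/hf/<model>/v1/chat/completions`), while `generate()` always targets the legacy `api-inference.huggingface.co/models` text-generation endpoint, and both funnel through a blocking `urllib` request run in an executor.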