huggingface-hub 0.31.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of huggingface-hub might be problematic.
- huggingface_hub/__init__.py +42 -4
- huggingface_hub/_local_folder.py +8 -0
- huggingface_hub/_oauth.py +464 -0
- huggingface_hub/_snapshot_download.py +11 -3
- huggingface_hub/_upload_large_folder.py +16 -36
- huggingface_hub/commands/huggingface_cli.py +2 -0
- huggingface_hub/commands/repo.py +147 -0
- huggingface_hub/commands/user.py +2 -108
- huggingface_hub/constants.py +9 -1
- huggingface_hub/dataclasses.py +2 -2
- huggingface_hub/file_download.py +13 -11
- huggingface_hub/hf_api.py +48 -19
- huggingface_hub/hub_mixin.py +2 -2
- huggingface_hub/inference/_client.py +8 -7
- huggingface_hub/inference/_generated/_async_client.py +8 -7
- huggingface_hub/inference/_generated/types/__init__.py +4 -1
- huggingface_hub/inference/_generated/types/chat_completion.py +43 -9
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/agent.py +99 -0
- huggingface_hub/inference/_mcp/cli.py +154 -0
- huggingface_hub/inference/_mcp/constants.py +80 -0
- huggingface_hub/inference/_mcp/mcp_client.py +322 -0
- huggingface_hub/inference/_mcp/utils.py +123 -0
- huggingface_hub/inference/_providers/__init__.py +13 -1
- huggingface_hub/inference/_providers/_common.py +1 -0
- huggingface_hub/inference/_providers/cerebras.py +1 -1
- huggingface_hub/inference/_providers/cohere.py +20 -3
- huggingface_hub/inference/_providers/fireworks_ai.py +18 -0
- huggingface_hub/inference/_providers/hf_inference.py +8 -1
- huggingface_hub/inference/_providers/nebius.py +28 -0
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/sambanova.py +14 -0
- huggingface_hub/inference/_providers/together.py +15 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/insecure_hashlib.py +8 -4
- {huggingface_hub-0.31.3.dist-info → huggingface_hub-0.32.0.dist-info}/METADATA +30 -8
- {huggingface_hub-0.31.3.dist-info → huggingface_hub-0.32.0.dist-info}/RECORD +41 -32
- {huggingface_hub-0.31.3.dist-info → huggingface_hub-0.32.0.dist-info}/entry_points.txt +1 -0
- {huggingface_hub-0.31.3.dist-info → huggingface_hub-0.32.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.31.3.dist-info → huggingface_hub-0.32.0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.31.3.dist-info → huggingface_hub-0.32.0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_client.py

@@ -66,6 +66,7 @@ from huggingface_hub.inference._generated.types import (
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
     ChatCompletionInputGrammarType,
+    ChatCompletionInputMessage,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionInputToolChoiceClass,
@@ -100,7 +101,7 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._providers import
+from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
 from huggingface_hub.utils._auth import get_token
 from huggingface_hub.utils._deprecation import _deprecate_method
@@ -133,7 +134,7 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
@@ -164,7 +165,7 @@ class InferenceClient:
         self,
         model: Optional[str] = None,
         *,
-        provider:
+        provider: Optional[PROVIDER_OR_POLICY_T] = None,
         token: Optional[str] = None,
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,
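The `provider` argument is now typed as `Optional[PROVIDER_OR_POLICY_T]`, i.e. either a concrete provider name or the `"auto"` routing policy described in the docstring above. A minimal sketch of the two modes (model and provider choices here are illustrative, not taken from the diff):

```python
from huggingface_hub import InferenceClient

# Pin a concrete provider (this release adds "nscale" to the list):
client = InferenceClient(provider="nscale", api_key="hf_...")

# Or let the Hub route to the first provider available for the model,
# following the order configured at https://hf.co/settings/inference-providers:
client = InferenceClient(provider="auto")
```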
@@ -446,7 +447,7 @@ class InferenceClient:
     @overload
     def chat_completion(  # type: ignore
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: Literal[False] = False,

@@ -472,7 +473,7 @@ class InferenceClient:
     @overload
     def chat_completion(  # type: ignore
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: Literal[True] = True,

@@ -498,7 +499,7 @@ class InferenceClient:
     @overload
     def chat_completion(
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: bool = False,

@@ -523,7 +524,7 @@ class InferenceClient:

     def chat_completion(
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: bool = False,
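All four `chat_completion` overloads (and the implementation) now accept `ChatCompletionInputMessage` instances alongside plain dicts. A sketch of mixing the two, assuming an illustrative model id:

```python
from huggingface_hub import ChatCompletionInputMessage, InferenceClient

client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")

# Plain dicts and typed messages can now appear in the same list:
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    ChatCompletionInputMessage(role="user", content="What is MCP?"),
]
response = client.chat_completion(messages, max_tokens=64)
print(response.choices[0].message.content)
```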
huggingface_hub/inference/_generated/_async_client.py

@@ -51,6 +51,7 @@ from huggingface_hub.inference._generated.types import (
     AudioToAudioOutputElement,
     AutomaticSpeechRecognitionOutput,
     ChatCompletionInputGrammarType,
+    ChatCompletionInputMessage,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionInputToolChoiceClass,

@@ -85,7 +86,7 @@ from huggingface_hub.inference._generated.types import (
     ZeroShotClassificationOutputElement,
     ZeroShotImageClassificationOutputElement,
 )
-from huggingface_hub.inference._providers import
+from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
 from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
 from huggingface_hub.utils._auth import get_token
 from huggingface_hub.utils._deprecation import _deprecate_method

@@ -121,7 +122,7 @@ class AsyncInferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):

@@ -154,7 +155,7 @@ class AsyncInferenceClient:
         self,
         model: Optional[str] = None,
         *,
-        provider:
+        provider: Optional[PROVIDER_OR_POLICY_T] = None,
         token: Optional[str] = None,
         timeout: Optional[float] = None,
         headers: Optional[Dict[str, str]] = None,

@@ -480,7 +481,7 @@ class AsyncInferenceClient:
     @overload
     async def chat_completion(  # type: ignore
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: Literal[False] = False,

@@ -506,7 +507,7 @@ class AsyncInferenceClient:
     @overload
     async def chat_completion(  # type: ignore
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: Literal[True] = True,

@@ -532,7 +533,7 @@ class AsyncInferenceClient:
     @overload
     async def chat_completion(
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: bool = False,

@@ -557,7 +558,7 @@ class AsyncInferenceClient:

     async def chat_completion(
         self,
-        messages: List[Dict],
+        messages: List[Union[Dict, ChatCompletionInputMessage]],
         *,
         model: Optional[str] = None,
         stream: bool = False,
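The async client mirrors the sync changes one-for-one. A short sketch of streaming with `AsyncInferenceClient` (model is illustrative):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="auto", model="meta-llama/Meta-Llama-3-8B-Instruct")
    stream = await client.chat_completion(
        [{"role": "user", "content": "Say hi"}], stream=True, max_tokens=32
    )
    async for chunk in stream:
        # Each chunk is a ChatCompletionStreamOutput carrying an incremental delta.
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```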
huggingface_hub/inference/_generated/types/__init__.py

@@ -24,10 +24,13 @@ from .chat_completion import (
     ChatCompletionInputFunctionDefinition,
     ChatCompletionInputFunctionName,
     ChatCompletionInputGrammarType,
-
+    ChatCompletionInputJSONSchema,
     ChatCompletionInputMessage,
     ChatCompletionInputMessageChunk,
     ChatCompletionInputMessageChunkType,
+    ChatCompletionInputResponseFormatJSONObject,
+    ChatCompletionInputResponseFormatJSONSchema,
+    ChatCompletionInputResponseFormatText,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionInputToolCall,
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# See:
|
|
4
4
|
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
|
5
5
|
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
|
6
|
-
from typing import Any, List, Literal, Optional, Union
|
|
6
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
7
7
|
|
|
8
8
|
from .base import BaseInferenceType, dataclass_with_extra
|
|
9
9
|
|
|
@@ -45,17 +45,51 @@ class ChatCompletionInputMessage(BaseInferenceType):
     tool_calls: Optional[List[ChatCompletionInputToolCall]] = None


-
+@dataclass_with_extra
+class ChatCompletionInputJSONSchema(BaseInferenceType):
+    name: str
+    """
+    The name of the response format.
+    """
+    description: Optional[str] = None
+    """
+    A description of what the response format is for, used by the model to determine
+    how to respond in the format.
+    """
+    schema: Optional[Dict[str, object]] = None
+    """
+    The schema for the response format, described as a JSON Schema object. Learn how
+    to build JSON schemas [here](https://json-schema.org/).
+    """
+    strict: Optional[bool] = None
+    """
+    Whether to enable strict schema adherence when generating the output. If set to
+    true, the model will always follow the exact schema defined in the `schema`
+    field.
+    """


 @dataclass_with_extra
-class
-    type: "
-
-
-
-    ""
+class ChatCompletionInputResponseFormatText(BaseInferenceType):
+    type: Literal["text"]
+
+
+@dataclass_with_extra
+class ChatCompletionInputResponseFormatJSONSchema(BaseInferenceType):
+    type: Literal["json_schema"]
+    json_schema: ChatCompletionInputJSONSchema
+
+
+@dataclass_with_extra
+class ChatCompletionInputResponseFormatJSONObject(BaseInferenceType):
+    type: Literal["json_object"]
+
+
+ChatCompletionInputGrammarType = Union[
+    ChatCompletionInputResponseFormatText,
+    ChatCompletionInputResponseFormatJSONSchema,
+    ChatCompletionInputResponseFormatJSONObject,
+]


 @dataclass_with_extra
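`ChatCompletionInputGrammarType` changes from a single dataclass into a union of three response-format shapes, discriminated by `type` (`"text"` / `"json_schema"` / `"json_object"`, in the OpenAI style). A sketch of how the new `json_schema` variant might be passed via `response_format` (model and schema are illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")

response = client.chat_completion(
    [{"role": "user", "content": "Name one ocean."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "answer",
            "schema": {
                "type": "object",
                "properties": {"ocean": {"type": "string"}},
                "required": ["ocean"],
            },
            "strict": True,  # model must follow the schema exactly
        },
    },
)
print(response.choices[0].message.content)  # e.g. '{"ocean": "Pacific"}'
```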
huggingface_hub/inference/_mcp/__init__.py

(new file, no content)

huggingface_hub/inference/_mcp/agent.py

@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import asyncio
+from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
+
+from huggingface_hub import ChatCompletionInputMessage, ChatCompletionStreamOutput, MCPClient
+
+from .._providers import PROVIDER_OR_POLICY_T
+from .constants import DEFAULT_SYSTEM_PROMPT, EXIT_LOOP_TOOLS, MAX_NUM_TURNS
+
+
+class Agent(MCPClient):
+    """
+    Implementation of a Simple Agent, which is a simple while loop built right on top of an [`MCPClient`].
+
+    <Tip warning={true}>
+
+    This class is experimental and might be subject to breaking changes in the future without prior notice.
+
+    </Tip>
+
+    Args:
+        model (`str`):
+            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
+            or a URL to a deployed Inference Endpoint or other local or remote endpoint.
+        servers (`Iterable[Dict]`):
+            MCP servers to connect to. Each server is a dictionary containing a `type` key and a `config` key. The `type` key can be `"stdio"` or `"sse"`, and the `config` key is a dictionary of arguments for the server.
+        provider (`str`, *optional*):
+            Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
+            If model is a URL or `base_url` is passed, then `provider` is not used.
+        api_key (`str`, *optional*):
+            Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service.
+        prompt (`str`, *optional*):
+            The system prompt to use for the agent. Defaults to the default system prompt in `constants.py`.
+    """
+
+    def __init__(
+        self,
+        *,
+        model: str,
+        servers: Iterable[Dict],
+        provider: Optional[PROVIDER_OR_POLICY_T] = None,
+        api_key: Optional[str] = None,
+        prompt: Optional[str] = None,
+    ):
+        super().__init__(model=model, provider=provider, api_key=api_key)
+        self._servers_cfg = list(servers)
+        self.messages: List[Union[Dict, ChatCompletionInputMessage]] = [
+            {"role": "system", "content": prompt or DEFAULT_SYSTEM_PROMPT}
+        ]
+
+    async def load_tools(self) -> None:
+        for cfg in self._servers_cfg:
+            await self.add_mcp_server(cfg["type"], **cfg["config"])
+
+    async def run(
+        self,
+        user_input: str,
+        *,
+        abort_event: Optional[asyncio.Event] = None,
+    ) -> AsyncGenerator[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage], None]:
+        """
+        Run the agent with the given user input.
+
+        Args:
+            user_input (`str`):
+                The user input to run the agent with.
+            abort_event (`asyncio.Event`, *optional*):
+                An event that can be used to abort the agent. If the event is set, the agent will stop running.
+        """
+        self.messages.append({"role": "user", "content": user_input})
+
+        num_turns: int = 0
+        next_turn_should_call_tools = True
+
+        while True:
+            if abort_event and abort_event.is_set():
+                return
+
+            async for item in self.process_single_turn_with_tools(
+                self.messages,
+                exit_loop_tools=EXIT_LOOP_TOOLS,
+                exit_if_first_chunk_no_tool=(num_turns > 0 and next_turn_should_call_tools),
+            ):
+                yield item
+
+            num_turns += 1
+            last = self.messages[-1]
+
+            if last.get("role") == "tool" and last.get("name") in {t.function.name for t in EXIT_LOOP_TOOLS}:
+                return
+
+            if last.get("role") != "tool" and num_turns > MAX_NUM_TURNS:
+                return
+
+            if last.get("role") != "tool" and next_turn_should_call_tools:
+                return
+
+            next_turn_should_call_tools = last.get("role") != "tool"
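Putting the new `Agent` together: a minimal driver loop, closely following how `cli.py` below uses it (the server config and model id are illustrative, and the class is flagged experimental):

```python
import asyncio

from huggingface_hub.inference._mcp.agent import Agent


async def main() -> None:
    # Agent is used as an async context manager, as in cli.py below.
    async with Agent(
        model="Qwen/Qwen2.5-72B-Instruct",
        provider="auto",
        servers=[
            {"type": "stdio", "config": {"command": "npx", "args": ["@playwright/mcp@latest"]}},
        ],
    ) as agent:
        await agent.load_tools()
        async for item in agent.run("Open huggingface.co and summarize the homepage"):
            # Items are either streaming chunks (with .choices) or tool-result messages.
            if hasattr(item, "choices") and item.choices[0].delta.content:
                print(item.choices[0].delta.content, end="")


asyncio.run(main())
```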
huggingface_hub/inference/_mcp/cli.py

@@ -0,0 +1,154 @@
+import asyncio
+import os
+import signal
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import typer
+from rich import print
+
+from .agent import Agent
+from .utils import _load_agent_config
+
+
+app = typer.Typer(
+    rich_markup_mode="rich",
+    help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.",
+)
+
+run_cli = typer.Typer(
+    name="run",
+    help="Run the Agent in the CLI",
+    invoke_without_command=True,
+)
+app.add_typer(run_cli, name="run")
+
+
+async def _ainput(prompt: str = "» ") -> str:
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, partial(typer.prompt, prompt, prompt_suffix=" "))
+
+
+async def run_agent(
+    agent_path: Optional[str],
+) -> None:
+    """
+    Tiny Agent loop.
+
+    Args:
+        agent_path (`str`, *optional*):
+            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` file or a built-in agent stored in a Hugging Face dataset.
+
+    """
+    config, prompt = _load_agent_config(agent_path)
+
+    servers: List[Dict[str, Any]] = config.get("servers", [])
+
+    abort_event = asyncio.Event()
+    first_sigint = True
+
+    loop = asyncio.get_running_loop()
+    original_sigint_handler = signal.getsignal(signal.SIGINT)
+
+    def _sigint_handler() -> None:
+        nonlocal first_sigint
+        if first_sigint:
+            first_sigint = False
+            abort_event.set()
+            print("\n[red]Interrupted. Press Ctrl+C again to quit.[/red]", flush=True)
+            return
+
+        print("\n[red]Exiting...[/red]", flush=True)
+
+        os._exit(130)
+
+    try:
+        loop.add_signal_handler(signal.SIGINT, _sigint_handler)
+
+        async with Agent(
+            provider=config["provider"],
+            model=config["model"],
+            servers=servers,
+            prompt=prompt,
+        ) as agent:
+            await agent.load_tools()
+            print(f"[bold blue]Agent loaded with {len(agent.available_tools)} tools:[/bold blue]")
+            for t in agent.available_tools:
+                print(f"[blue] • {t.function.name}[/blue]")
+
+            while True:
+                abort_event.clear()
+
+                try:
+                    user_input = await _ainput()
+                    first_sigint = True
+                except EOFError:
+                    print("\n[red]EOF received, exiting.[/red]", flush=True)
+                    break
+                except KeyboardInterrupt:
+                    if not first_sigint and abort_event.is_set():
+                        continue
+                    else:
+                        print("\n[red]Keyboard interrupt during input processing.[/red]", flush=True)
+                        break
+
+                try:
+                    async for chunk in agent.run(user_input, abort_event=abort_event):
+                        if abort_event.is_set() and not first_sigint:
+                            break
+
+                        if hasattr(chunk, "choices"):
+                            delta = chunk.choices[0].delta
+                            if delta.content:
+                                print(delta.content, end="", flush=True)
+                            if delta.tool_calls:
+                                for call in delta.tool_calls:
+                                    if call.id:
+                                        print(f"<Tool {call.id}>", end="")
+                                    if call.function.name:
+                                        print(f"{call.function.name}", end=" ")
+                                    if call.function.arguments:
+                                        print(f"{call.function.arguments}", end="")
+                        else:
+                            print(
+                                f"\n\n[green]Tool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}[/green]\n",
+                                flush=True,
+                            )
+
+                    print()
+
+                except Exception as e:
+                    print(f"\n[bold red]Error during agent run: {e}[/bold red]", flush=True)
+                    first_sigint = True  # Allow graceful interrupt for the next command
+
+    finally:
+        if loop and not loop.is_closed():
+            loop.remove_signal_handler(signal.SIGINT)
+        elif original_sigint_handler:
+            signal.signal(signal.SIGINT, original_sigint_handler)
+
+
+@run_cli.callback()
+def run(
+    path: Optional[str] = typer.Argument(
+        None,
+        help=(
+            "Path to a local folder containing an agent.json file or a built-in agent "
+            "stored in the 'tiny-agents/tiny-agents' Hugging Face dataset "
+            "(https://huggingface.co/datasets/tiny-agents/tiny-agents)"
+        ),
+        show_default=False,
+    ),
+):
+    try:
+        asyncio.run(run_agent(path))
+    except KeyboardInterrupt:
+        print("\n[red]Application terminated by KeyboardInterrupt.[/red]", flush=True)
+        raise typer.Exit(code=130)
+    except Exception as e:
+        print(f"\n[bold red]An unexpected error occurred: {e}[/bold red]", flush=True)
+        raise e
+
+
+if __name__ == "__main__":
+    app()
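The Typer `app` can be exercised without spawning an agent, e.g. to check the CLI wiring. A sketch using Typer's standard test runner:

```python
from typer.testing import CliRunner

from huggingface_hub.inference._mcp.cli import app

runner = CliRunner()
result = runner.invoke(app, ["run", "--help"])  # prints the `run` subcommand help
print(result.stdout)
```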
huggingface_hub/inference/_mcp/constants.py

@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from typing import List
+
+from huggingface_hub import ChatCompletionInputTool
+
+
+FILENAME_CONFIG = "agent.json"
+FILENAME_PROMPT = "PROMPT.md"
+
+DEFAULT_AGENT = {
+    "model": "Qwen/Qwen2.5-72B-Instruct",
+    "provider": "nebius",
+    "servers": [
+        {
+            "type": "stdio",
+            "config": {
+                "command": "npx",
+                "args": [
+                    "-y",
+                    "@modelcontextprotocol/server-filesystem",
+                    str(Path.home() / ("Desktop" if sys.platform == "darwin" else "")),
+                ],
+            },
+        },
+        {
+            "type": "stdio",
+            "config": {
+                "command": "npx",
+                "args": ["@playwright/mcp@latest"],
+            },
+        },
+    ],
+}
+
+
+DEFAULT_SYSTEM_PROMPT = """
+You are an agent - please keep going until the user’s query is completely
+resolved, before ending your turn and yielding back to the user. Only terminate
+your turn when you are sure that the problem is solved, or if you need more
+info from the user to solve the problem.
+If you are not sure about anything pertaining to the user’s request, use your
+tools to read files and gather the relevant information: do NOT guess or make
+up an answer.
+You MUST plan extensively before each function call, and reflect extensively
+on the outcomes of the previous function calls. DO NOT do this entire process
+by making function calls only, as this can impair your ability to solve the
+problem and think insightfully.
+""".strip()
+
+MAX_NUM_TURNS = 10
+
+TASK_COMPLETE_TOOL: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj(  # type: ignore[assignment]
+    {
+        "type": "function",
+        "function": {
+            "name": "task_complete",
+            "description": "Call this tool when the task given by the user is complete",
+            "parameters": {"type": "object", "properties": {}},
+        },
+    }
+)
+
+ASK_QUESTION_TOOL: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj(  # type: ignore[assignment]
+    {
+        "type": "function",
+        "function": {
+            "name": "ask_question",
+            "description": "Ask the user for more info required to solve or clarify their problem.",
+            "parameters": {"type": "object", "properties": {}},
+        },
+    }
+)
+
+EXIT_LOOP_TOOLS: List[ChatCompletionInputTool] = [TASK_COMPLETE_TOOL, ASK_QUESTION_TOOL]
+
+
+DEFAULT_REPO_ID = "tiny-agents/tiny-agents"