rwkv-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_api/__init__.py +90 -0
- rwkv_api/_client.py +399 -0
- rwkv_api/_sync.py +359 -0
- rwkv_api/_task.py +197 -0
- rwkv_api/exceptions.py +85 -0
- rwkv_api/models.py +101 -0
- rwkv_api-0.1.0.dist-info/METADATA +285 -0
- rwkv_api-0.1.0.dist-info/RECORD +10 -0
- rwkv_api-0.1.0.dist-info/WHEEL +4 -0
- rwkv_api-0.1.0.dist-info/licenses/LICENSE +21 -0
rwkv_api/__init__.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""rwkv_api —— RWKV-Server Task API 的 Python SDK。
|
|
2
|
+
|
|
3
|
+
提供同步和异步两种接口,支持流式 SSE 响应和完整的任务生命周期管理。
|
|
4
|
+
|
|
5
|
+
快速开始(异步)::
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
from rwkv_api import AsyncClient
|
|
9
|
+
|
|
10
|
+
async def main():
|
|
11
|
+
async with AsyncClient("http://localhost:8000") as client:
|
|
12
|
+
task = await client.create("Hello, world!", max_tokens=50)
|
|
13
|
+
await task.wait()
|
|
14
|
+
print(task.result)
|
|
15
|
+
|
|
16
|
+
asyncio.run(main())
|
|
17
|
+
|
|
18
|
+
快速开始(同步)::
|
|
19
|
+
|
|
20
|
+
from rwkv_api import Client
|
|
21
|
+
|
|
22
|
+
with Client("http://localhost:8000") as client:
|
|
23
|
+
task = client.create("Hello, world!", max_tokens=50)
|
|
24
|
+
task.wait()
|
|
25
|
+
print(task.result)
|
|
26
|
+
|
|
27
|
+
流式生成(异步)::
|
|
28
|
+
|
|
29
|
+
async with AsyncClient("http://localhost:8000") as client:
|
|
30
|
+
task = await client.create("Hello", stream=True)
|
|
31
|
+
async for chunk in task:
|
|
32
|
+
print(chunk, end="")
|
|
33
|
+
|
|
34
|
+
流式生成(同步)::
|
|
35
|
+
|
|
36
|
+
with Client("http://localhost:8000") as client:
|
|
37
|
+
task = client.create("Hello", stream=True)
|
|
38
|
+
for chunk in task:
|
|
39
|
+
print(chunk, end="")
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
from . import exceptions
|
|
43
|
+
from ._client import AsyncClient
|
|
44
|
+
from ._sync import Client
|
|
45
|
+
from ._task import AsyncTask, Task
|
|
46
|
+
from .exceptions import (
|
|
47
|
+
ConnectionError as RWKVConnectionError,
|
|
48
|
+
RWKVError,
|
|
49
|
+
RWKVServerError,
|
|
50
|
+
RWKVValidationError,
|
|
51
|
+
TaskCancelledError,
|
|
52
|
+
TaskNotFoundError,
|
|
53
|
+
TimeoutError as RWKVTimeoutError,
|
|
54
|
+
)
|
|
55
|
+
from .models import (
|
|
56
|
+
FIMRequest,
|
|
57
|
+
Status,
|
|
58
|
+
TaskCreate,
|
|
59
|
+
TaskInfo,
|
|
60
|
+
TaskResponseModel,
|
|
61
|
+
TaskUpdate,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
__version__ = "0.1.0"
|
|
65
|
+
__all__ = [
|
|
66
|
+
# 客户端
|
|
67
|
+
"AsyncClient",
|
|
68
|
+
"Client",
|
|
69
|
+
# Task 对象
|
|
70
|
+
"Task",
|
|
71
|
+
"AsyncTask",
|
|
72
|
+
# 数据模型
|
|
73
|
+
"TaskCreate",
|
|
74
|
+
"TaskUpdate",
|
|
75
|
+
"TaskResponseModel",
|
|
76
|
+
"TaskInfo",
|
|
77
|
+
"FIMRequest",
|
|
78
|
+
"Status",
|
|
79
|
+
# 异常
|
|
80
|
+
"RWKVError",
|
|
81
|
+
"RWKVValidationError",
|
|
82
|
+
"RWKVServerError",
|
|
83
|
+
"TaskNotFoundError",
|
|
84
|
+
"TaskCancelledError",
|
|
85
|
+
"RWKVTimeoutError",
|
|
86
|
+
"RWKVConnectionError",
|
|
87
|
+
"exceptions",
|
|
88
|
+
]
|
|
89
|
+
|
|
90
|
+
|
rwkv_api/_client.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""AsyncClient —— RWKV-Server Task API 的异步客户端。
|
|
2
|
+
|
|
3
|
+
基于 httpx.AsyncClient 实现所有 API 调用,
|
|
4
|
+
支持流式 SSE 响应和完整的 Task 对象生命周期管理。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from typing import Any, AsyncIterator
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from . import _task
|
|
15
|
+
from .exceptions import (
|
|
16
|
+
ConnectionError as RWKVConnectionError,
|
|
17
|
+
RWKVError,
|
|
18
|
+
raise_for_status,
|
|
19
|
+
)
|
|
20
|
+
from .models import (
|
|
21
|
+
TaskInfo,
|
|
22
|
+
TaskResponseModel,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# API 基础路径
|
|
26
|
+
_TASKS_PREFIX = "/v1/tasks"
|
|
27
|
+
|
|
28
|
+
# 生成参数字段名集合,用于构建请求体
|
|
29
|
+
_GEN_PARAMS = frozenset({
|
|
30
|
+
"max_tokens", "temperature", "top_k", "top_p",
|
|
31
|
+
"presence_penalty", "repetition_penalty", "penalty_decay", "seed",
|
|
32
|
+
})
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class AsyncClient:
    """Asynchronous client for the RWKV-Server Task API.

    Args:
        base_url: Server address, e.g. ``http://localhost:8000``. A trailing
            slash is stripped.
        timeout: HTTP request timeout in seconds.
        headers: Extra HTTP headers sent with every request.
        httpx_client: A pre-configured ``httpx.AsyncClient`` (advanced usage).
            When given, it is used as-is: ``base_url``, ``timeout`` and
            ``headers`` are NOT applied to it, and :meth:`close` closes it.

    Usage::

        async with AsyncClient("http://localhost:8000") as client:
            task = await client.create("Hello, world!", max_tokens=50)
            result = await task.wait()
            print(result.result)

        # Live streaming
        async with AsyncClient("http://localhost:8000") as client:
            async for chunk in await client.create("Hello", stream=True):
                print(chunk, end="")
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        *,
        timeout: float = 120.0,
        headers: dict[str, str] | None = None,
        httpx_client: httpx.AsyncClient | None = None,
    ) -> None:
        if httpx_client is not None:
            # Caller-supplied client wins; the other arguments are ignored.
            self._client = httpx_client
        else:
            self._client = httpx.AsyncClient(
                base_url=base_url.rstrip("/"),
                timeout=httpx.Timeout(timeout),
                headers=headers,
            )

    async def __aenter__(self) -> AsyncClient:
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        await self._client.aclose()

    # ===================================================================
    # Streaming API
    # ===================================================================

    async def _stream_sse(
        self,
        method: str,
        endpoint: str,
        body: dict[str, Any],
    ) -> AsyncIterator[str]:
        """Issue a request and yield the ``data`` field of each SSE event.

        Only lines starting with ``data:`` are considered. The literal
        ``[DONE]`` payload ends the stream; payloads that are not valid JSON,
        are not JSON objects, or carry no ``data`` key (such as the initial
        ``prefill_time``-only event) are skipped.

        Raises:
            RWKVConnectionError: The connection could not be established.
            TimeoutError (from ``.exceptions``): The request timed out.
            Whatever ``raise_for_status`` raises for the response status.
        """
        # Do not attach an empty JSON document to a body-less GET request.
        payload = None if (method.upper() == "GET" and not body) else body
        try:
            async with self._client.stream(method, endpoint, json=payload) as resp:
                detail = ""
                if resp.status_code >= 400:
                    # A streamed response body must be read explicitly;
                    # forward it so the raised error carries the server's
                    # message instead of an empty string.
                    detail = (await resp.aread()).decode("utf-8", errors="replace")
                raise_for_status(resp.status_code, detail)
                async for line in resp.aiter_lines():
                    line = line.strip()
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        parsed = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    # Events without "data" (e.g. the first, prefill_time-only
                    # event) and non-object payloads are ignored.
                    if isinstance(parsed, dict) and "data" in parsed:
                        yield parsed["data"]
        except httpx.ConnectError as e:
            raise RWKVConnectionError(f"无法连接到服务器: {e}") from e
        except httpx.TimeoutException as e:
            from .exceptions import TimeoutError as _Timeout
            raise _Timeout(f"请求超时: {e}") from e

    async def create_stream(
        self,
        prompt: str | list[int],
        *,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        presence_penalty: float | None = None,
        repetition_penalty: float | None = None,
        penalty_decay: float | None = None,
        seed: int | None = None,
        persistent: bool = False,
    ) -> AsyncIterator[str]:
        """Create a task and yield generated chunks as they arrive.

        This is an async generator: iterate it with ``async for`` directly,
        do not ``await`` it. ``persistent`` selects the ``/create`` endpoint
        (True) or the ``/tmp`` endpoint (False).
        """
        body: dict[str, Any] = {"prompt": prompt, "stream": True}
        self._inject_gen_params(
            body,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            presence_penalty=presence_penalty,
            repetition_penalty=repetition_penalty,
            penalty_decay=penalty_decay,
            seed=seed,
        )
        endpoint = f"{_TASKS_PREFIX}/create" if persistent else f"{_TASKS_PREFIX}/tmp"
        async for chunk in self._stream_sse("POST", endpoint, body):
            yield chunk

    async def fim_stream(
        self,
        prefix: str,
        suffix: str = "",
        *,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        presence_penalty: float | None = None,
        repetition_penalty: float | None = None,
        penalty_decay: float | None = None,
        seed: int | None = None,
    ) -> AsyncIterator[str]:
        """Run Fill-In-Middle and yield generated chunks as they arrive.

        This is an async generator: iterate it with ``async for`` directly,
        do not ``await`` it.
        """
        body: dict[str, Any] = {"prefix": prefix, "suffix": suffix, "stream": True}
        self._inject_gen_params(
            body,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            presence_penalty=presence_penalty,
            repetition_penalty=repetition_penalty,
            penalty_decay=penalty_decay,
            seed=seed,
        )
        async for chunk in self._stream_sse("POST", f"{_TASKS_PREFIX}/fim", body):
            yield chunk

    # ===================================================================
    # Task API
    # ===================================================================

    async def create(
        self,
        prompt: str | list[int],
        *,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        presence_penalty: float | None = None,
        repetition_penalty: float | None = None,
        penalty_decay: float | None = None,
        stream: bool = False,
        seed: int | None = None,
        persistent: bool = False,
    ) -> _task.AsyncTask | AsyncIterator[str]:
        """Create a task.

        Args:
            prompt: Prompt text or a list of token ids.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_k: Top-K sampling parameter.
            top_p: Top-P sampling parameter.
            presence_penalty: Presence penalty.
            repetition_penalty: Repetition penalty.
            penalty_decay: Penalty decay.
            stream: When exactly ``True``, return an async iterator of chunks
                instead of a task object.
            seed: Random seed.
            persistent: Create a persistent task (True uses ``/create``,
                False uses ``/tmp``).

        Returns:
            An ``AsyncIterator[str]`` when ``stream=True``; otherwise an
            ``AsyncTask``.
        """
        if stream is True:
            # create_stream() is an async generator function, so calling it
            # already yields the iterator; awaiting that object would raise
            # TypeError.
            return self.create_stream(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                presence_penalty=presence_penalty,
                repetition_penalty=repetition_penalty,
                penalty_decay=penalty_decay,
                seed=seed,
                persistent=persistent,
            )

        body: dict[str, Any] = {"prompt": prompt, "stream": False}
        self._inject_gen_params(
            body,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            presence_penalty=presence_penalty,
            repetition_penalty=repetition_penalty,
            penalty_decay=penalty_decay,
            seed=seed,
        )
        endpoint = f"{_TASKS_PREFIX}/create" if persistent else f"{_TASKS_PREFIX}/tmp"
        resp_model = await self._post_json(endpoint, body, response_model=TaskResponseModel)
        return _task.AsyncTask(self, resp_model.task_id, response=resp_model)

    async def create_tmp(
        self, prompt: str | list[int] | None = None, /, **kwargs: Any
    ) -> _task.AsyncTask | AsyncIterator[str]:
        """Create a temporary task (same as ``create(..., persistent=False)``).

        ``prompt`` may be given positionally or as the ``prompt=`` keyword.
        """
        if prompt is not None:
            kwargs["prompt"] = prompt
        return await self.create(**kwargs, persistent=False)

    async def create_persistent(
        self, prompt: str | list[int] | None = None, /, **kwargs: Any
    ) -> _task.AsyncTask | AsyncIterator[str]:
        """Create a persistent task (same as ``create(..., persistent=True)``).

        ``prompt`` may be given positionally or as the ``prompt=`` keyword.
        """
        if prompt is not None:
            kwargs["prompt"] = prompt
        return await self.create(**kwargs, persistent=True)

    async def fim(
        self,
        prefix: str,
        suffix: str = "",
        *,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        presence_penalty: float | None = None,
        repetition_penalty: float | None = None,
        penalty_decay: float | None = None,
        stream: bool = False,
        seed: int | None = None,
    ) -> _task.AsyncTask | AsyncIterator[str]:
        """Fill In Middle: generate text between ``prefix`` and ``suffix``.

        Args:
            prefix: Text before the gap.
            suffix: Text after the gap.
            stream: When exactly ``True``, return an async iterator of chunks
                instead of a task object.
            The remaining parameters are the same as for :meth:`create`.

        Returns:
            An ``AsyncIterator[str]`` when ``stream=True``; otherwise an
            ``AsyncTask``.
        """
        if stream is True:
            # fim_stream() is an async generator function; it must not be
            # awaited (see create()).
            return self.fim_stream(
                prefix,
                suffix,
                max_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                presence_penalty=presence_penalty,
                repetition_penalty=repetition_penalty,
                penalty_decay=penalty_decay,
                seed=seed,
            )

        body: dict[str, Any] = {"prefix": prefix, "suffix": suffix, "stream": False}
        self._inject_gen_params(
            body,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            presence_penalty=presence_penalty,
            repetition_penalty=repetition_penalty,
            penalty_decay=penalty_decay,
            seed=seed,
        )
        resp_model = await self._post_json(
            f"{_TASKS_PREFIX}/fim", body, response_model=TaskResponseModel,
        )
        return _task.AsyncTask(self, resp_model.task_id, response=resp_model)

    async def get_task_result(self, task_id: str) -> TaskResponseModel:
        """Fetch the result of a task."""
        return await self._get_json(
            f"{_TASKS_PREFIX}/{task_id}/get_result", response_model=TaskResponseModel,
        )

    async def get_task_status(self, task_id: str) -> TaskInfo:
        """Fetch the status of a task."""
        return await self._get_json(
            f"{_TASKS_PREFIX}/{task_id}/status", response_model=TaskInfo,
        )

    async def fork_task(self, task_id: str, **overrides: Any) -> _task.AsyncTask:
        """Fork a task, optionally overriding its prompt / generation params."""
        # NOTE(review): passing stream=True here still goes through the JSON
        # (non-SSE) request path — confirm the server replies with JSON.
        body = self._build_update_body(overrides)
        resp_model = await self._post_json(
            f"{_TASKS_PREFIX}/{task_id}/fork", body, response_model=TaskResponseModel,
        )
        return _task.AsyncTask(self, resp_model.task_id, response=resp_model)

    async def continue_task(self, task_id: str, **overrides: Any) -> _task.AsyncTask:
        """Continue generation of a task, optionally overriding parameters."""
        body = self._build_update_body(overrides)
        resp_model = await self._post_json(
            f"{_TASKS_PREFIX}/{task_id}/continue", body, response_model=TaskResponseModel,
        )
        return _task.AsyncTask(self, resp_model.task_id, response=resp_model)

    async def stop_task(self, task_id: str) -> None:
        """Stop a task."""
        await self._post_no_content(f"{_TASKS_PREFIX}/{task_id}/stop", {})

    async def as_template(self, task_id: str) -> _task.AsyncTask:
        """Turn a task into a template."""
        resp_model = await self._post_json(
            f"{_TASKS_PREFIX}/{task_id}/as_template", {}, response_model=TaskResponseModel,
        )
        return _task.AsyncTask(self, resp_model.task_id, response=resp_model)

    async def stream_task(self, task_id: str) -> AsyncIterator[str]:
        """Subscribe to the live streamed output of an existing task."""
        async for chunk in self._stream_sse("GET", f"{_TASKS_PREFIX}/{task_id}/stream", {}):
            yield chunk

    async def delete_task(self, task_id: str, *, force: bool = False) -> None:
        """Delete a task; ``force`` is sent as the ``force`` query parameter."""
        await self._post_no_content(
            f"{_TASKS_PREFIX}/{task_id}/delete", {},
            params={"force": str(force).lower()},
        )

    async def list_tasks(self) -> dict[str, Any]:
        """List all tasks (returns the decoded JSON response)."""
        return await self._get_json(f"{_TASKS_PREFIX}/list")

    # ===================================================================
    # Internal HTTP helpers
    # ===================================================================

    async def _send(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
        """Send one non-streaming request and validate its status code.

        Transport failures are translated into the SDK's own exceptions so
        that every request helper reports them the same way.
        """
        try:
            resp = await self._client.request(method, path, **kwargs)
        except httpx.ConnectError as e:
            raise RWKVConnectionError(f"无法连接到服务器: {e}") from e
        except httpx.TimeoutException as e:
            from .exceptions import TimeoutError as _Timeout
            raise _Timeout(f"请求超时: {e}") from e
        raise_for_status(resp.status_code, resp.text)
        return resp

    @staticmethod
    def _decode(resp: Any, response_model: type[Any] | None) -> Any:
        """Decode a JSON response, validating it with ``response_model`` if given."""
        payload = resp.json()
        if response_model is not None:
            return response_model.model_validate(payload)
        return payload

    async def _post_json(
        self, path: str, body: dict[str, Any], *, response_model: type[Any] | None = None,
    ) -> Any:
        """Send a POST request and parse the JSON response."""
        resp = await self._send("POST", path, json=body)
        return self._decode(resp, response_model)

    async def _get_json(
        self, path: str, *, response_model: type[Any] | None = None,
    ) -> Any:
        """Send a GET request and parse the JSON response."""
        resp = await self._send("GET", path)
        return self._decode(resp, response_model)

    async def _post_no_content(
        self, path: str, body: dict[str, Any], *, params: dict[str, str] | None = None,
    ) -> None:
        """Send a POST request without parsing the response body."""
        await self._send("POST", path, json=body, params=params)

    # ===================================================================
    # Internal utilities
    # ===================================================================

    @staticmethod
    def _inject_gen_params(body: dict[str, Any], **params: Any) -> None:
        """Copy every generation parameter that is not ``None`` into ``body``."""
        body.update({key: value for key, value in params.items() if value is not None})

    @staticmethod
    def _build_update_body(overrides: dict[str, Any]) -> dict[str, Any]:
        """Build an update request body from keyword overrides.

        Only ``prompt``, the names in ``_GEN_PARAMS`` and ``stream`` are
        copied; any other key is ignored. ``stream`` defaults to ``False``.
        """
        body: dict[str, Any] = {"stream": False}
        for key in ("prompt", *_GEN_PARAMS, "stream"):
            if key in overrides:
                body[key] = overrides[key]
        return body
|