rwkv-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rwkv_api/__init__.py ADDED
@@ -0,0 +1,90 @@
1
+ """rwkv_api —— RWKV-Server Task API 的 Python SDK。
2
+
3
+ 提供同步和异步两种接口,支持流式 SSE 响应和完整的任务生命周期管理。
4
+
5
+ 快速开始(异步)::
6
+
7
+ import asyncio
8
+ from rwkv_api import AsyncClient
9
+
10
+ async def main():
11
+ async with AsyncClient("http://localhost:8000") as client:
12
+ task = await client.create("Hello, world!", max_tokens=50)
13
+ await task.wait()
14
+ print(task.result)
15
+
16
+ asyncio.run(main())
17
+
18
+ 快速开始(同步)::
19
+
20
+ from rwkv_api import Client
21
+
22
+ with Client("http://localhost:8000") as client:
23
+ task = client.create("Hello, world!", max_tokens=50)
24
+ task.wait()
25
+ print(task.result)
26
+
27
+ 流式生成(异步)::
28
+
29
+ async with AsyncClient("http://localhost:8000") as client:
30
+ task = await client.create("Hello", stream=True)
31
+ async for chunk in task:
32
+ print(chunk, end="")
33
+
34
+ 流式生成(同步)::
35
+
36
+ with Client("http://localhost:8000") as client:
37
+ task = client.create("Hello", stream=True)
38
+ for chunk in task:
39
+ print(chunk, end="")
40
+ """
41
+
42
+ from . import exceptions
43
+ from ._client import AsyncClient
44
+ from ._sync import Client
45
+ from ._task import AsyncTask, Task
46
+ from .exceptions import (
47
+ ConnectionError as RWKVConnectionError,
48
+ RWKVError,
49
+ RWKVServerError,
50
+ RWKVValidationError,
51
+ TaskCancelledError,
52
+ TaskNotFoundError,
53
+ TimeoutError as RWKVTimeoutError,
54
+ )
55
+ from .models import (
56
+ FIMRequest,
57
+ Status,
58
+ TaskCreate,
59
+ TaskInfo,
60
+ TaskResponseModel,
61
+ TaskUpdate,
62
+ )
63
+
64
# Package version string.
__version__ = "0.1.0"

# Public API of the package; the order mirrors the groups below.
__all__ = [
    # Clients
    "AsyncClient", "Client",
    # Task handles
    "Task", "AsyncTask",
    # Data models
    "TaskCreate", "TaskUpdate", "TaskResponseModel",
    "TaskInfo", "FIMRequest", "Status",
    # Exceptions (re-exported, some under RWKV-prefixed aliases)
    "RWKVError", "RWKVValidationError", "RWKVServerError",
    "TaskNotFoundError", "TaskCancelledError",
    "RWKVTimeoutError", "RWKVConnectionError",
    # The exceptions submodule itself
    "exceptions",
]
89
+
90
+
rwkv_api/_client.py ADDED
@@ -0,0 +1,399 @@
"""AsyncClient: asynchronous client for the RWKV-Server Task API.

Every API call is implemented on top of ``httpx.AsyncClient``, with support
for streaming SSE responses and lifecycle management of Task objects.
"""
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from typing import Any, AsyncIterator
11
+
12
+ import httpx
13
+
14
+ from . import _task
15
+ from .exceptions import (
16
+ ConnectionError as RWKVConnectionError,
17
+ RWKVError,
18
+ raise_for_status,
19
+ )
20
+ from .models import (
21
+ TaskInfo,
22
+ TaskResponseModel,
23
+ )
24
+
25
# Path prefix shared by every Task API endpoint.
_TASKS_PREFIX = "/v1/tasks"

# Names of the sampling/generation parameters that may be forwarded in a
# request body (used by AsyncClient._build_update_body).
_GEN_PARAMS = frozenset(
    (
        "max_tokens",
        "temperature",
        "top_k",
        "top_p",
        "presence_penalty",
        "repetition_penalty",
        "penalty_decay",
        "seed",
    )
)
33
+
34
+
35
+ class AsyncClient:
36
+ """RWKV-Server Task API 异步客户端。
37
+
38
+ Args:
39
+ base_url: 服务地址,如 ``http://localhost:8000``。
40
+ timeout: HTTP 请求超时(秒)。
41
+ headers: 额外的 HTTP 请求头。
42
+ httpx_client: 自定义的 httpx.AsyncClient 实例(高级用法)。
43
+
44
+ Usage::
45
+
46
+ async with AsyncClient("http://localhost:8000") as client:
47
+ task = await client.create("Hello, world!", max_tokens=50)
48
+ result = await task.wait()
49
+ print(result.result)
50
+
51
+ # 实时流式
52
+ async with AsyncClient("http://localhost:8000") as client:
53
+ async for chunk in await client.create("Hello", stream=True):
54
+ print(chunk, end="")
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ base_url: str = "http://localhost:8000",
60
+ *,
61
+ timeout: float = 120.0,
62
+ headers: dict[str, str] | None = None,
63
+ httpx_client: httpx.AsyncClient | None = None,
64
+ ) -> None:
65
+ if httpx_client is not None:
66
+ self._client = httpx_client
67
+ else:
68
+ self._client = httpx.AsyncClient(
69
+ base_url=base_url.rstrip("/"),
70
+ timeout=httpx.Timeout(timeout),
71
+ headers=headers,
72
+ )
73
+
74
+ async def __aenter__(self) -> AsyncClient:
75
+ return self
76
+
77
+ async def __aexit__(self, *args: Any) -> None:
78
+ await self.close()
79
+
80
+ async def close(self) -> None:
81
+ """关闭底层 HTTP 连接池。"""
82
+ await self._client.aclose()
83
+
84
+ # ===================================================================
85
+ # 流式 API
86
+ # ===================================================================
87
+
88
+ async def _stream_sse(
89
+ self,
90
+ method: str,
91
+ endpoint: str,
92
+ body: dict[str, Any],
93
+ ) -> AsyncIterator[str]:
94
+ """通用 SSE 流式请求,逐 chunk yield。"""
95
+ try:
96
+ async with self._client.stream(method, endpoint, json=body) as resp:
97
+ raise_for_status(resp.status_code, "")
98
+ async for line in resp.aiter_lines():
99
+ line = line.strip()
100
+ if not line or not line.startswith("data:"):
101
+ continue
102
+ data_str = line[len("data:"):].strip()
103
+ if data_str == "[DONE]":
104
+ break
105
+ try:
106
+ parsed = json.loads(data_str)
107
+ except json.JSONDecodeError:
108
+ continue
109
+ # 跳过首事件(仅含 prefill_time,无 data)
110
+ if "prefill_time" in parsed and "data" not in parsed:
111
+ continue
112
+ if "data" in parsed:
113
+ yield parsed["data"]
114
+ except httpx.ConnectError as e:
115
+ raise RWKVConnectionError(f"无法连接到服务器: {e}") from e
116
+ except httpx.TimeoutException as e:
117
+ from .exceptions import TimeoutError as _Timeout
118
+ raise _Timeout(f"请求超时: {e}") from e
119
+
120
+ async def create_stream(
121
+ self,
122
+ prompt: str | list[int],
123
+ *,
124
+ max_tokens: int | None = None,
125
+ temperature: float | None = None,
126
+ top_k: int | None = None,
127
+ top_p: float | None = None,
128
+ presence_penalty: float | None = None,
129
+ repetition_penalty: float | None = None,
130
+ penalty_decay: float | None = None,
131
+ seed: int | None = None,
132
+ persistent: bool = False,
133
+ ) -> AsyncIterator[str]:
134
+ """创建任务并实时流式返回生成内容。"""
135
+ body: dict[str, Any] = {"prompt": prompt, "stream": True}
136
+ self._inject_gen_params(body, max_tokens=max_tokens, temperature=temperature,
137
+ top_k=top_k, top_p=top_p, presence_penalty=presence_penalty,
138
+ repetition_penalty=repetition_penalty, penalty_decay=penalty_decay,
139
+ seed=seed)
140
+ endpoint = f"{_TASKS_PREFIX}/create" if persistent else f"{_TASKS_PREFIX}/tmp"
141
+ async for chunk in self._stream_sse("POST", endpoint, body):
142
+ yield chunk
143
+
144
+ async def fim_stream(
145
+ self,
146
+ prefix: str,
147
+ suffix: str = "",
148
+ *,
149
+ max_tokens: int | None = None,
150
+ temperature: float | None = None,
151
+ top_k: int | None = None,
152
+ top_p: float | None = None,
153
+ presence_penalty: float | None = None,
154
+ repetition_penalty: float | None = None,
155
+ penalty_decay: float | None = None,
156
+ seed: int | None = None,
157
+ ) -> AsyncIterator[str]:
158
+ """FIM 实时流式返回生成内容。"""
159
+ body: dict[str, Any] = {"prefix": prefix, "suffix": suffix, "stream": True}
160
+ self._inject_gen_params(body, max_tokens=max_tokens, temperature=temperature,
161
+ top_k=top_k, top_p=top_p, presence_penalty=presence_penalty,
162
+ repetition_penalty=repetition_penalty, penalty_decay=penalty_decay,
163
+ seed=seed)
164
+ async for chunk in self._stream_sse("POST", f"{_TASKS_PREFIX}/fim", body):
165
+ yield chunk
166
+
167
+ # ===================================================================
168
+ # Task API
169
+ # ===================================================================
170
+
171
+ async def create(
172
+ self,
173
+ prompt: str | list[int],
174
+ *,
175
+ max_tokens: int | None = None,
176
+ temperature: float | None = None,
177
+ top_k: int | None = None,
178
+ top_p: float | None = None,
179
+ presence_penalty: float | None = None,
180
+ repetition_penalty: float | None = None,
181
+ penalty_decay: float | None = None,
182
+ stream: bool = False,
183
+ seed: int | None = None,
184
+ persistent: bool = False,
185
+ ) -> _task.AsyncTask | AsyncIterator[str]:
186
+ """创建任务。
187
+
188
+ Args:
189
+ prompt: 提示文本或 token id 列表。
190
+ max_tokens: 最大生成 token 数。
191
+ temperature: 采样温度。
192
+ top_k: Top-K 采样参数。
193
+ top_p: Top-P 采样参数。
194
+ presence_penalty: 存在惩罚。
195
+ repetition_penalty: 重复惩罚。
196
+ penalty_decay: 惩罚衰减。
197
+ stream: 是否使用流式响应。True 时返回异步生成器。
198
+ seed: 随机种子。
199
+ persistent: 是否创建持久化任务(True 使用 /create,False 使用 /tmp)。
200
+
201
+ Returns:
202
+ stream=True 时返回 AsyncIterator[str];
203
+ stream=False 时返回 AsyncTask 对象。
204
+ """
205
+ if stream is True:
206
+ return await self.create_stream(
207
+ prompt, max_tokens=max_tokens, temperature=temperature,
208
+ top_k=top_k, top_p=top_p, presence_penalty=presence_penalty,
209
+ repetition_penalty=repetition_penalty, penalty_decay=penalty_decay,
210
+ seed=seed, persistent=persistent,
211
+ )
212
+
213
+ body: dict[str, Any] = {"prompt": prompt, "stream": False}
214
+ self._inject_gen_params(body, max_tokens=max_tokens, temperature=temperature,
215
+ top_k=top_k, top_p=top_p, presence_penalty=presence_penalty,
216
+ repetition_penalty=repetition_penalty, penalty_decay=penalty_decay,
217
+ seed=seed)
218
+ endpoint = f"{_TASKS_PREFIX}/create" if persistent else f"{_TASKS_PREFIX}/tmp"
219
+ resp_model = await self._post_json(endpoint, body, response_model=TaskResponseModel)
220
+ return _task.AsyncTask(self, resp_model.task_id, response=resp_model)
221
+
222
+ async def create_tmp(
223
+ self, prompt: str | list[int] | None = None, /, **kwargs: Any
224
+ ) -> _task.AsyncTask | AsyncIterator[str]:
225
+ """创建临时任务(等同于 create(..., persistent=False))。"""
226
+ if prompt is not None:
227
+ kwargs["prompt"] = prompt
228
+ return await self.create(**kwargs, persistent=False)
229
+
230
+ async def create_persistent(
231
+ self, prompt: str | list[int] | None = None, /, **kwargs: Any
232
+ ) -> _task.AsyncTask | AsyncIterator[str]:
233
+ """创建持久化任务(等同于 create(..., persistent=True))。"""
234
+ if prompt is not None:
235
+ kwargs["prompt"] = prompt
236
+ return await self.create(**kwargs, persistent=True)
237
+
238
+ async def fim(
239
+ self,
240
+ prefix: str,
241
+ suffix: str = "",
242
+ *,
243
+ max_tokens: int | None = None,
244
+ temperature: float | None = None,
245
+ top_k: int | None = None,
246
+ top_p: float | None = None,
247
+ presence_penalty: float | None = None,
248
+ repetition_penalty: float | None = None,
249
+ penalty_decay: float | None = None,
250
+ stream: bool = False,
251
+ seed: int | None = None,
252
+ ) -> _task.AsyncTask | AsyncIterator[str]:
253
+ """Fill In Middle —— 在 prefix 和 suffix 之间生成文本。
254
+
255
+ Args:
256
+ prefix: 前缀文本。
257
+ suffix: 后缀文本。
258
+ stream: 是否使用流式响应。True 时返回异步生成器。
259
+ 其余参数同 create()。
260
+
261
+ Returns:
262
+ stream=True 时返回 AsyncIterator[str];
263
+ stream=False 时返回 AsyncTask 对象。
264
+ """
265
+ if stream is True:
266
+ return await self.fim_stream(
267
+ prefix, suffix, max_tokens=max_tokens, temperature=temperature,
268
+ top_k=top_k, top_p=top_p, presence_penalty=presence_penalty,
269
+ repetition_penalty=repetition_penalty, penalty_decay=penalty_decay,
270
+ seed=seed,
271
+ )
272
+
273
+ body: dict[str, Any] = {"prefix": prefix, "suffix": suffix, "stream": False}
274
+ self._inject_gen_params(body, max_tokens=max_tokens, temperature=temperature,
275
+ top_k=top_k, top_p=top_p, presence_penalty=presence_penalty,
276
+ repetition_penalty=repetition_penalty, penalty_decay=penalty_decay,
277
+ seed=seed)
278
+ resp_model = await self._post_json(f"{_TASKS_PREFIX}/fim", body, response_model=TaskResponseModel)
279
+ return _task.AsyncTask(self, resp_model.task_id, response=resp_model)
280
+
281
+ async def get_task_result(self, task_id: str) -> TaskResponseModel:
282
+ """获取任务结果。"""
283
+ return await self._get_json(f"{_TASKS_PREFIX}/{task_id}/get_result", response_model=TaskResponseModel)
284
+
285
+ async def get_task_status(self, task_id: str) -> TaskInfo:
286
+ """获取任务状态。"""
287
+ return await self._get_json(f"{_TASKS_PREFIX}/{task_id}/status", response_model=TaskInfo)
288
+
289
+ async def fork_task(self, task_id: str, **overrides: Any) -> _task.AsyncTask:
290
+ """Fork 任务。"""
291
+ body = self._build_update_body(overrides)
292
+ resp_model = await self._post_json(
293
+ f"{_TASKS_PREFIX}/{task_id}/fork", body, response_model=TaskResponseModel,
294
+ )
295
+ return _task.AsyncTask(self, resp_model.task_id, response=resp_model)
296
+
297
+ async def continue_task(self, task_id: str, **overrides: Any) -> _task.AsyncTask:
298
+ """继续生成。"""
299
+ body = self._build_update_body(overrides)
300
+ resp_model = await self._post_json(
301
+ f"{_TASKS_PREFIX}/{task_id}/continue", body, response_model=TaskResponseModel,
302
+ )
303
+ return _task.AsyncTask(self, resp_model.task_id, response=resp_model)
304
+
305
+ async def stop_task(self, task_id: str) -> None:
306
+ """停止任务。"""
307
+ await self._post_no_content(f"{_TASKS_PREFIX}/{task_id}/stop", {})
308
+
309
+ async def as_template(self, task_id: str) -> _task.AsyncTask:
310
+ """将任务转为模板。"""
311
+ resp_model = await self._post_json(
312
+ f"{_TASKS_PREFIX}/{task_id}/as_template", {}, response_model=TaskResponseModel,
313
+ )
314
+ return _task.AsyncTask(self, resp_model.task_id, response=resp_model)
315
+
316
+ async def stream_task(self, task_id: str) -> AsyncIterator[str]:
317
+ """订阅已创建任务的实时流式输出。"""
318
+ async for chunk in self._stream_sse("GET", f"{_TASKS_PREFIX}/{task_id}/stream", {}):
319
+ yield chunk
320
+
321
+ async def delete_task(self, task_id: str, *, force: bool = False) -> None:
322
+ """删除任务。"""
323
+ await self._post_no_content(
324
+ f"{_TASKS_PREFIX}/{task_id}/delete", {},
325
+ params={"force": str(force).lower()},
326
+ )
327
+
328
+ async def list_tasks(self) -> dict[str, Any]:
329
+ """列出所有任务。"""
330
+ return await self._get_json(f"{_TASKS_PREFIX}/list")
331
+
332
+ # ===================================================================
333
+ # 内部 HTTP 方法
334
+ # ===================================================================
335
+
336
+ async def _post_json(
337
+ self, path: str, body: dict[str, Any], *, response_model: type[Any] | None = None,
338
+ ) -> Any:
339
+ """发送 POST 请求,解析 JSON 响应。"""
340
+ try:
341
+ resp = await self._client.post(path, json=body)
342
+ except httpx.ConnectError as e:
343
+ raise RWKVConnectionError(f"无法连接到服务器: {e}") from e
344
+ except httpx.TimeoutException as e:
345
+ from .exceptions import TimeoutError as _Timeout
346
+ raise _Timeout(f"请求超时: {e}") from e
347
+ raise_for_status(resp.status_code, resp.text)
348
+ if response_model is not None:
349
+ return response_model.model_validate(resp.json())
350
+ return resp.json()
351
+
352
+ async def _get_json(
353
+ self, path: str, *, response_model: type[Any] | None = None,
354
+ ) -> Any:
355
+ """发送 GET 请求,解析 JSON 响应。"""
356
+ try:
357
+ resp = await self._client.get(path)
358
+ except httpx.ConnectError as e:
359
+ raise RWKVConnectionError(f"无法连接到服务器: {e}") from e
360
+ except httpx.TimeoutException as e:
361
+ from .exceptions import TimeoutError as _Timeout
362
+ raise _Timeout(f"请求超时: {e}") from e
363
+ raise_for_status(resp.status_code, resp.text)
364
+ if response_model is not None:
365
+ return response_model.model_validate(resp.json())
366
+ return resp.json()
367
+
368
+ async def _post_no_content(
369
+ self, path: str, body: dict[str, Any], *, params: dict[str, str] | None = None,
370
+ ) -> None:
371
+ """发送 POST 请求,不解析响应体。"""
372
+ try:
373
+ resp = await self._client.post(path, json=body, params=params)
374
+ except httpx.ConnectError as e:
375
+ raise RWKVConnectionError(f"无法连接到服务器: {e}") from e
376
+ except httpx.TimeoutException as e:
377
+ from .exceptions import TimeoutError as _Timeout
378
+ raise _Timeout(f"请求超时: {e}") from e
379
+ raise_for_status(resp.status_code, resp.text)
380
+
381
+ # ===================================================================
382
+ # 内部辅助
383
+ # ===================================================================
384
+
385
+ @staticmethod
386
+ def _inject_gen_params(body: dict[str, Any], **params: Any) -> None:
387
+ """将非 None 的生成参数注入请求体。"""
388
+ for key, value in params.items():
389
+ if value is not None:
390
+ body[key] = value
391
+
392
+ @staticmethod
393
+ def _build_update_body(overrides: dict[str, Any]) -> dict[str, Any]:
394
+ """从关键字参数构建更新请求体。"""
395
+ body: dict[str, Any] = {"stream": False}
396
+ for key in ("prompt", *_GEN_PARAMS, "stream"):
397
+ if key in overrides:
398
+ body[key] = overrides[key]
399
+ return body