xinference 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +415 -1
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +29 -1
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/embedding/sentence_transformers/core.py +4 -4
- xinference/model/embedding/vllm/core.py +7 -1
- xinference/model/image/model_spec.json +2 -3
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +1 -0
- xinference/model/llm/llm_family.json +40 -20
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +2 -44
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +128 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +1 -1
- xinference/model/llm/utils.py +127 -45
- xinference/model/llm/vllm/core.py +2 -61
- xinference/types.py +105 -2
- {xinference-1.9.1.dist-info → xinference-1.10.0.dist-info}/METADATA +7 -3
- {xinference-1.9.1.dist-info → xinference-1.10.0.dist-info}/RECORD +34 -26
- {xinference-1.9.1.dist-info → xinference-1.10.0.dist-info}/WHEEL +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.1.dist-info → xinference-1.10.0.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-
+ "date": "2025-09-12T21:20:52+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.
+ "full-revisionid": "b018733c97029fb59e8ffe55fadc6473232fbf23",
+ "version": "1.10.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -14,6 +14,7 @@
 
 import asyncio
 import inspect
+import ipaddress
 import json
 import logging
 import multiprocessing
@@ -21,6 +22,7 @@ import os
 import pprint
 import sys
 import time
+import uuid
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
@@ -53,6 +55,7 @@ from xoscar.utils import get_next_port
 from .._compat import BaseModel, Field
 from .._version import get_versions
 from ..constants import (
+    XINFERENCE_ALLOWED_IPS,
     XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_DEFAULT_ENDPOINT_PORT,
     XINFERENCE_DISABLE_METRICS,
@@ -61,11 +64,16 @@ from ..constants import (
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
 from ..core.utils import CancelMixin, json_dumps
+
+# Import Anthropic-related types and availability flag
 from ..types import (
+    ANTHROPIC_AVAILABLE,
+    AnthropicMessage,
     ChatCompletion,
     Completion,
     CreateChatCompletion,
     CreateCompletion,
+    CreateMessage,
     ImageList,
     PeftModelConfig,
     SDAPIResult,
@@ -213,6 +221,9 @@ class BuildGradioMediaInterfaceRequest(BaseModel):
 
 
 class RESTfulAPI(CancelMixin):
+    # Add new class attributes
+    _allowed_ip_list: Optional[List[ipaddress.IPv4Network]] = None
+
     def __init__(
         self,
         supervisor_address: str,
@@ -229,6 +240,45 @@ class RESTfulAPI(CancelMixin):
         self._auth_service = AuthService(auth_config_file)
         self._router = APIRouter()
         self._app = FastAPI()
+        # Initialize allowed IP list once
+        self._init_allowed_ip_list()
+
+    def _init_allowed_ip_list(self):
+        """Initialize the allowed IP list from environment variable."""
+        if RESTfulAPI._allowed_ip_list is None:
+            # ie: export XINFERENCE_ALLOWED_IPS=192.168.1.0/24
+            allowed_ips = XINFERENCE_ALLOWED_IPS
+            if allowed_ips:
+                RESTfulAPI._allowed_ip_list = []
+                for ip in allowed_ips.split(","):
+                    ip = ip.strip()
+                    try:
+                        # Try parsing as network/CIDR
+                        if "/" in ip:
+                            RESTfulAPI._allowed_ip_list.append(ipaddress.ip_network(ip))
+                        else:
+                            # Parse as single IP
+                            RESTfulAPI._allowed_ip_list.append(
+                                ipaddress.ip_network(f"{ip}/32")
+                            )
+                    except ValueError:
+                        logger.error(
+                            f"Invalid IP address or network: {ip}", exc_info=True
+                        )
+                        continue
+
+    def _is_ip_allowed(self, ip: str) -> bool:
+        """Check if an IP is allowed based on configured rules."""
+        if not RESTfulAPI._allowed_ip_list:
+            return True
+
+        try:
+            client_ip = ipaddress.ip_address(ip)
+            return any(
+                client_ip in allowed_net for allowed_net in RESTfulAPI._allowed_ip_list
+            )
+        except ValueError:
+            return False
 
     def is_authenticated(self):
         return False if self._auth_service.config is None else True
@@ -287,6 +337,16 @@ class RESTfulAPI(CancelMixin):
             allow_headers=["*"],
         )
 
+        @self._app.middleware("http")
+        async def ip_restriction_middleware(request: Request, call_next):
+            client_ip = request.client.host
+            if not self._is_ip_allowed(client_ip):
+                return PlainTextResponse(
+                    status_code=403, content=f"Access denied for IP: {client_ip}\n"
+                )
+            response = await call_next(request)
+            return response
+
         @self._app.exception_handler(500)
         async def internal_exception_handler(request: Request, exc: Exception):
             logger.exception("Handling request %s failed: %s", request.url, exc)
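A note on the two hunks above: every request now passes through ip_restriction_middleware, which rejects clients outside the configured networks with HTTP 403; with no configuration, everything is allowed. The matching is plain stdlib ipaddress. A minimal standalone sketch of the same logic (function names here are illustrative, not xinference API):

import ipaddress

def parse_allowed(spec):
    # Accept a comma-separated list of addresses and CIDR blocks,
    # e.g. "127.0.0.1,192.168.1.0/24"; bare addresses become /32 hosts.
    nets = []
    for part in spec.split(","):
        part = part.strip()
        try:
            nets.append(ipaddress.ip_network(part if "/" in part else f"{part}/32"))
        except ValueError:
            continue  # skip malformed entries, as the patch does
    return nets

def is_allowed(client_ip, nets):
    if not nets:
        return True  # empty/unset config means no restriction
    try:
        addr = ipaddress.ip_address(client_ip)
    except ValueError:
        return False
    return any(addr in net for net in nets)

nets = parse_allowed("127.0.0.1, 192.168.1.0/24")
assert is_allowed("192.168.1.42", nets)
assert not is_allowed("10.0.0.1", nets)

One caveat carried over from the patch itself: the hard-coded /32 suffix assumes IPv4, so a bare IPv6 address in the list fails to parse and is skipped.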
@@ -532,6 +592,40 @@ class RESTfulAPI(CancelMixin):
                 else None
             ),
         )
+        # Register messages endpoint only if Anthropic is available
+        if ANTHROPIC_AVAILABLE:
+            self._router.add_api_route(
+                "/anthropic/v1/messages",
+                self.create_message,
+                methods=["POST"],
+                response_model=AnthropicMessage,
+                dependencies=(
+                    [Security(self._auth_service, scopes=["models:read"])]
+                    if self.is_authenticated()
+                    else None
+                ),
+            )
+            # Register Anthropic models endpoints
+            self._router.add_api_route(
+                "/anthropic/v1/models",
+                self.anthropic_list_models,
+                methods=["GET"],
+                dependencies=(
+                    [Security(self._auth_service, scopes=["models:list"])]
+                    if self.is_authenticated()
+                    else None
+                ),
+            )
+            self._router.add_api_route(
+                "/anthropic/v1/models/{model_id}",
+                self.anthropic_get_model,
+                methods=["GET"],
+                dependencies=(
+                    [Security(self._auth_service, scopes=["models:list"])]
+                    if self.is_authenticated()
+                    else None
+                ),
+            )
         self._router.add_api_route(
             "/v1/embeddings",
             self.create_embedding,
@@ -994,6 +1088,58 @@ class RESTfulAPI(CancelMixin):
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def anthropic_list_models(self) -> JSONResponse:
+        """Anthropic-compatible models endpoint"""
+        try:
+
+            # Get running models from xinference
+            running_models = await (await self._get_supervisor_ref()).list_models()
+
+            # For backward compatibility with tests, only return running models by default
+            model_list = []
+
+            # Add running models to the list
+            for model_id, model_info in running_models.items():
+                anthropic_model = {
+                    "id": model_id,
+                    "object": "model",
+                    "created": 0,
+                    "display_name": model_info.get("model_name", model_id),
+                    "type": model_info.get("model_type", "model"),
+                    "max_tokens": model_info.get("context_length", 4096),
+                }
+                model_list.append(anthropic_model)
+
+            return JSONResponse(content=model_list)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def anthropic_get_model(self, model_id: str) -> JSONResponse:
+        """Anthropic-compatible model retrieval endpoint"""
+        try:
+            models = await (await self._get_supervisor_ref()).list_models()
+
+            model_info = models[model_id]
+
+            # Convert to Anthropic format
+            anthropic_model = {
+                "id": model_id,  # Return the original requested ID
+                "object": "model",
+                "created": 0,
+                "display_name": model_info.get("model_name", model_id),
+                "type": model_info.get("model_type", "model"),
+                "max_tokens": model_info.get("context_length", 4096),
+                **model_info,
+            }
+
+            return JSONResponse(content=anthropic_model)
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def describe_model(self, model_uid: str) -> JSONResponse:
         try:
             data = await (await self._get_supervisor_ref()).describe_model(model_uid)
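Once a server is up, the two read-only endpoints are easy to smoke-test; a sketch assuming the default local endpoint and the requests library:

import requests

base = "http://127.0.0.1:9997"  # assumed xinference endpoint

# GET /anthropic/v1/models returns a bare JSON list of running models.
models = requests.get(f"{base}/anthropic/v1/models").json()
for m in models:
    print(m["id"], m["type"], m["max_tokens"])

# GET /anthropic/v1/models/{model_id} returns a single entry, merged
# with the raw model_info via **model_info as shown above.
if models:
    detail = requests.get(f"{base}/anthropic/v1/models/{models[0]['id']}").json()
    print(detail["display_name"])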
@@ -1417,6 +1563,151 @@
             self.handle_request_limit_error(e)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_message(self, request: Request) -> Response:
+        raw_body = await request.json()
+        body = CreateMessage.parse_obj(raw_body)
+
+        exclude = {
+            "model",
+            "messages",
+            "stream",
+            "stop_sequences",
+            "metadata",
+            "tool_choice",
+            "tools",
+        }
+        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
+        kwargs = body.dict(exclude_unset=True, exclude=exclude)
+
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
+        # TODO: Decide if this default value override is necessary #1061
+        if body.max_tokens is None:
+            kwargs["max_tokens"] = max_tokens_field.default
+
+        messages = body.messages and list(body.messages)
+
+        if not messages or messages[-1].get("role") not in ["user", "assistant"]:
+            raise HTTPException(
+                status_code=400, detail="Invalid input. Please specify the prompt."
+            )
+
+        # Handle tools parameter
+        if hasattr(body, "tools") and body.tools:
+            kwargs["tools"] = body.tools
+
+        # Handle tool_choice parameter
+        if hasattr(body, "tool_choice") and body.tool_choice:
+            kwargs["tool_choice"] = body.tool_choice
+
+        # Get model mapping
+        try:
+            running_models = await (await self._get_supervisor_ref()).list_models()
+        except Exception as e:
+            logger.error(f"Failed to get model mapping: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail="Failed to get model mapping")
+
+        if not running_models:
+            raise HTTPException(
+                status_code=400,
+                detail=f"No running models available. Please start a model in xinference first.",
+            )
+
+        requested_model_id = body.model
+        if "claude" in requested_model_id:
+            requested_model_id = list(running_models.keys())[0]
+
+        if requested_model_id not in running_models:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Model '{requested_model_id}' is not available. Available models: {list(running_models.keys())}",
+            )
+        else:
+            model_uid = requested_model_id
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        if body.stream:
+
+            async def stream_results():
+                iterator = None
+                try:
+                    try:
+                        iterator = await model.chat(
+                            messages, kwargs, raw_params=raw_kwargs
+                        )
+                    except RuntimeError as re:
+                        self.handle_request_limit_error(re)
+
+                    # Check if iterator is actually an async iterator
+                    if hasattr(iterator, "__aiter__"):
+                        async for item in iterator:
+                            yield item
+                    elif isinstance(iterator, (str, bytes)):
+                        # Handle case where chat returns bytes/string instead of iterator
+                        if isinstance(iterator, bytes):
+                            try:
+                                content = iterator.decode("utf-8")
+                            except UnicodeDecodeError:
+                                content = str(iterator)
+                        else:
+                            content = iterator
+                        yield dict(data=json.dumps({"content": content}))
+                    else:
+                        # Fallback: try to iterate normally
+                        try:
+                            for item in iterator:
+                                yield item
+                        except TypeError:
+                            # If not iterable, yield as single result
+                            yield dict(data=json.dumps({"content": str(iterator)}))
+
+                    yield "[DONE]"
+                except asyncio.CancelledError:
+                    logger.info(
+                        f"Disconnected from client (via refresh/close) {request.client} during chat."
+                    )
+                    return
+                except Exception as ex:
+                    ex = await self._get_model_last_error(model.uid, ex)
+                    logger.exception("Message stream got an error: %s", ex)
+                    await self._report_error_event(model_uid, str(ex))
+                    yield dict(data=json.dumps({"error": str(ex)}))
+                    return
+                finally:
+                    await model.decrease_serve_count()
+
+            return EventSourceResponse(
+                stream_results(), ping=XINFERENCE_SSE_PING_ATTEMPTS_SECONDS
+            )
+        else:
+            try:
+                data = await model.chat(messages, kwargs, raw_params=raw_kwargs)
+                # Convert OpenAI format to Anthropic format
+                openai_response = json.loads(data)
+                anthropic_response = self._convert_openai_to_anthropic(
+                    openai_response, body.model
+                )
+                return Response(
+                    json.dumps(anthropic_response), media_type="application/json"
+                )
+            except Exception as e:
+                e = await self._get_model_last_error(model.uid, e)
+                logger.error(e, exc_info=True)
+                await self._report_error_event(model_uid, str(e))
+                self.handle_request_limit_error(e)
+                raise HTTPException(status_code=500, detail=str(e))
+
     async def create_embedding(self, request: Request) -> Response:
         payload = await request.json()
         body = CreateEmbeddingRequest.parse_obj(payload)
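A non-streaming request against the new endpoint might look like the sketch below (host, port, and model uid are assumptions). Note the routing quirk visible above: a model name containing "claude" is silently remapped to the first running model, so Anthropic clients that default to a claude-* name still get a response.

import requests

resp = requests.post(
    "http://127.0.0.1:9997/anthropic/v1/messages",  # assumed endpoint
    json={
        "model": "my-chat-model",  # a running model uid (hypothetical)
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
reply = resp.json()
print(reply["stop_reason"])
for block in reply["content"]:
    if block["type"] == "text":
        print(block["text"])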
@@ -2371,7 +2662,14 @@ class RESTfulAPI(CancelMixin):
             data = await (await self._get_supervisor_ref()).list_model_registrations(
                 model_type, detailed=detailed
             )
-            return JSONResponse(content=data)
+            # Remove duplicate model names.
+            model_names = set()
+            final_data = []
+            for item in data:
+                if item["model_name"] not in model_names:
+                    model_names.add(item["model_name"])
+                    final_data.append(item)
+            return JSONResponse(content=final_data)
         except ValueError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -2560,6 +2858,19 @@ class RESTfulAPI(CancelMixin):
     def extract_guided_params(raw_body: dict) -> dict:
         kwargs = {}
         raw_extra_body: dict = raw_body.get("extra_body")  # type: ignore
+        # Convert OpenAI response_format to vLLM guided decoding
+        response_format = raw_body.get("response_format")
+        if response_format is not None:
+            if isinstance(response_format, dict):
+                format_type = response_format.get("type")
+                if format_type == "json_schema":
+                    json_schema = response_format.get("json_schema")
+                    if isinstance(json_schema, dict):
+                        schema = json_schema.get("schema")
+                        if schema is not None:
+                            kwargs["guided_json"] = schema
+                elif format_type == "json_object":
+                    kwargs["guided_json_object"] = True
         if raw_body.get("guided_json"):
             kwargs["guided_json"] = raw_body.get("guided_json")
         if raw_body.get("guided_regex") is not None:
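The new branch is a small dict-to-dict translation from OpenAI structured-output syntax to vLLM-style guided decoding. For example, a request body containing

raw_body = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "schema": {"type": "object", "properties": {"name": {"type": "string"}}}
        },
    },
}

now yields kwargs["guided_json"] set to that inner schema, while {"type": "json_object"} instead yields kwargs["guided_json_object"] = True.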
@@ -2578,6 +2889,19 @@ class RESTfulAPI(CancelMixin):
         )
         # Parse OpenAI extra_body
         if raw_extra_body is not None:
+            # Convert OpenAI response_format to vLLM guided decoding
+            extra_response_format = raw_extra_body.get("response_format")
+            if extra_response_format is not None:
+                if isinstance(extra_response_format, dict):
+                    format_type = extra_response_format.get("type")
+                    if format_type == "json_schema":
+                        json_schema = extra_response_format.get("json_schema")
+                        if isinstance(json_schema, dict):
+                            schema = json_schema.get("schema")
+                            if schema is not None:
+                                kwargs["guided_json"] = schema
+                    elif format_type == "json_object":
+                        kwargs["guided_json_object"] = True
             if raw_extra_body.get("guided_json"):
                 kwargs["guided_json"] = raw_extra_body.get("guided_json")
             if raw_extra_body.get("guided_regex") is not None:
@@ -2603,6 +2927,96 @@ class RESTfulAPI(CancelMixin):
 
         return kwargs
 
+    def _convert_openai_to_anthropic(self, openai_response: dict, model: str) -> dict:
+        """
+        Convert OpenAI response format to Anthropic response format.
+
+        Args:
+            openai_response: OpenAI format response
+            model: Model name
+
+        Returns:
+            Anthropic format response
+        """
+
+        # Extract content and tool calls from OpenAI response
+        content_blocks = []
+        stop_reason = "stop"
+
+        if "choices" in openai_response and len(openai_response["choices"]) > 0:
+            choice = openai_response["choices"][0]
+            message = choice.get("message", {})
+
+            # Handle content text
+            content = message.get("content", "")
+            if content:
+                if isinstance(content, str):
+                    # If content is a string, use it directly
+                    content_blocks.append({"type": "text", "text": content})
+                elif isinstance(content, list):
+                    # If content is a list, extract text from each content block
+                    for content_block in content:
+                        if isinstance(content_block, dict):
+                            if content_block.get("type") == "text":
+                                text = content_block.get("text", "")
+                                if text:
+                                    content_blocks.append(
+                                        {"type": "text", "text": text}
+                                    )
+                            elif "text" in content_block:
+                                # Handle different content block format
+                                text = content_block.get("text", "")
+                                if text:
+                                    content_blocks.append(
+                                        {"type": "text", "text": text}
+                                    )
+
+            # Handle tool calls
+            tool_calls = message.get("tool_calls", [])
+            for tool_call in tool_calls:
+                function = tool_call.get("function", {})
+                arguments = function.get("arguments", "{}")
+                try:
+                    input_data = json.loads(arguments)
+                except json.JSONDecodeError:
+                    input_data = {}
+                tool_use_block = {
+                    "type": "tool_use",
+                    "cache_control": {"type": "ephemeral"},
+                    "id": tool_call.get("id", str(uuid.uuid4())),
+                    "name": function.get("name", ""),
+                    "input": input_data,
+                }
+                content_blocks.append(tool_use_block)
+
+            # Set stop reason based on finish reason
+            finish_reason = choice.get("finish_reason", "stop")
+            if finish_reason == "tool_calls":
+                stop_reason = "tool_use"
+
+        # Build Anthropic response
+        anthropic_response = {
+            "id": str(uuid.uuid4()),
+            "type": "message",
+            "role": "assistant",
+            "content": content_blocks,
+            "model": model,
+            "stop_reason": stop_reason,
+            "stop_sequence": None,
+            "usage": {
+                "input_tokens": openai_response.get("usage", {}).get(
+                    "prompt_tokens", 0
+                ),
+                "output_tokens": openai_response.get("usage", {}).get(
+                    "completion_tokens", 0
+                ),
+                "cache_creation_input_tokens": 0,
+                "cache_read_input_tokens": 0,
+            },
+        }
+
+        return anthropic_response
+
 
 def run(
     supervisor_address: str,
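A worked example of the conversion (ids are fresh UUIDs in the real code):

openai_response = {
    "choices": [{
        "message": {"content": "Hi there!", "tool_calls": []},
        "finish_reason": "stop",
    }],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3},
}

# _convert_openai_to_anthropic(openai_response, "my-model") returns roughly:
# {
#   "id": "<uuid4>", "type": "message", "role": "assistant",
#   "content": [{"type": "text", "text": "Hi there!"}],
#   "model": "my-model", "stop_reason": "stop", "stop_sequence": None,
#   "usage": {"input_tokens": 5, "output_tokens": 3,
#             "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0},
# }

A finish_reason of "tool_calls" maps to stop_reason "tool_use", with each tool call emitted as a tool_use content block.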
xinference/constants.py
CHANGED
@@ -33,6 +33,7 @@ XINFERENCE_ENV_VIRTUAL_ENV = "XINFERENCE_ENABLE_VIRTUAL_ENV"
 XINFERENCE_ENV_VIRTUAL_ENV_SKIP_INSTALLED = "XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED"
 XINFERENCE_ENV_SSE_PING_ATTEMPTS_SECONDS = "XINFERENCE_SSE_PING_ATTEMPTS_SECONDS"
 XINFERENCE_ENV_MAX_TOKENS = "XINFERENCE_MAX_TOKENS"
+XINFERENCE_ENV_ALLOWED_IPS = "XINFERENCE_ALLOWED_IPS"
 
 
 def get_xinference_home() -> str:
@@ -110,3 +111,4 @@ XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED = (
     if os.getenv(XINFERENCE_ENV_VIRTUAL_ENV_SKIP_INSTALLED)
     else None
 )
+XINFERENCE_ALLOWED_IPS = os.getenv(XINFERENCE_ENV_ALLOWED_IPS)
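Because the value is captured with os.getenv at import time, XINFERENCE_ALLOWED_IPS must be set in the environment before the server process starts; leaving it unset disables the restriction entirely. For example (illustrative value):

import os

# Set before xinference is imported/launched; read once at import time.
os.environ["XINFERENCE_ALLOWED_IPS"] = "127.0.0.1,192.168.1.0/24"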
xinference/core/supervisor.py
CHANGED
@@ -886,6 +886,10 @@ class SupervisorActor(xo.StatelessActor):
                 await self._cache_tracker_ref.record_model_version(
                     generate_fn(model_spec), self.address
                 )
+                await self._sync_register_model(
+                    model_type, model, persist, model_spec.model_name
+                )
+
             except ValueError as e:
                 raise e
             except Exception as e:
@@ -894,6 +898,30 @@ class SupervisorActor(xo.StatelessActor):
         else:
             raise ValueError(f"Unsupported model type: {model_type}")
 
+    async def _sync_register_model(
+        self, model_type: str, model: str, persist: bool, model_name: str
+    ):
+        logger.info(f"begin sync model:{model_name} to worker")
+        try:
+            # Sync model to all workers.
+            for name, worker in self._worker_address_to_worker.items():
+                logger.info(f"sync model:{model_name} to {name}")
+                if name == self.address:
+                    # Ignore: when worker and supervisor at the same node.
+                    logger.info(
+                        f"ignore sync model:{model_name} to {name} for same node"
+                    )
+                else:
+                    await worker.register_model(model_type, model, persist)
+                    logger.info(f"success sync model:{model_name} to {name}")
+        except Exception as e:
+            # If sync fails, unregister the model in all workers.
+            for name, worker in self._worker_address_to_worker.items():
+                logger.warning(f"ready to unregister model for {name}")
+                await worker.unregister_model(model_type, model_name)
+                logger.warning(f"finish unregister model:{model} for {name}")
+            raise e
+
     @log_async(logger=logger)
     async def unregister_model(self, model_type: str, model_name: str):
         if model_type in self._custom_register_type_to_cls:
@@ -1014,7 +1042,7 @@ class SupervisorActor(xo.StatelessActor):
             )
 
         # search in worker first
-        if not self.is_local_deployment():
+        if not self.is_local_deployment() and worker_ip is None:
             workers = list(self._worker_address_to_worker.values())
             for worker in workers:
                 res = await worker.get_model_registration(model_type, model_name)
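The net effect: registering a custom model on the supervisor now fans the registration out to every worker (skipping the supervisor's own node), and any failure rolls the registration back on all workers before re-raising. From the client side this stays a single call; a sketch assuming the standard client and a prepared custom-model JSON (file name hypothetical):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
with open("my_custom_llm.json") as f:     # hypothetical model spec
    client.register_model(model_type="LLM", model=f.read(), persist=True)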
xinference/model/audio/core.py
CHANGED
@@ -25,6 +25,7 @@ from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .kokoro import KokoroModel
 from .kokoro_mlx import KokoroMLXModel
+from .kokoro_zh import KokoroZHModel
 from .megatts import MegaTTSModel
 from .melotts import MeloTTSModel
 from .whisper import WhisperModel
@@ -140,6 +141,7 @@ def create_audio_model_instance(
         MeloTTSModel,
         KokoroModel,
         KokoroMLXModel,
+        KokoroZHModel,
         MegaTTSModel,
     ]:
         from ..cache_manager import CacheManager
@@ -160,6 +162,7 @@ def create_audio_model_instance(
         MeloTTSModel,
         KokoroModel,
         KokoroMLXModel,
+        KokoroZHModel,
         MegaTTSModel,
     ]
     if model_spec.model_family == "whisper":
@@ -183,6 +186,8 @@ def create_audio_model_instance(
         model = MeloTTSModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "Kokoro":
         model = KokoroModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "Kokoro-zh":
+        model = KokoroZHModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "Kokoro-MLX":
         model = KokoroMLXModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "MegaTTS":
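Launching the new family should follow the usual audio-model flow; a sketch assuming the model_spec.json entries (+20 lines, not expanded in this diff) register it under the name "Kokoro-zh":

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
uid = client.launch_model(model_name="Kokoro-zh", model_type="audio")
model = client.get_model(uid)
audio = model.speech("你好，世界")  # returns synthesized audio bytes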
xinference/model/audio/kokoro.py
CHANGED
|
@@ -81,7 +81,7 @@ class KokoroModel:
|
|
|
81
81
|
logger.info("Launching Kokoro model with language code: %s", lang_code)
|
|
82
82
|
self._model = KPipeline(
|
|
83
83
|
lang_code=lang_code,
|
|
84
|
-
model=KModel(config=config_path, model=model_path),
|
|
84
|
+
model=KModel(config=config_path, model=model_path).to(self._device),
|
|
85
85
|
device=self._device,
|
|
86
86
|
)
|
|
87
87
|
|