xinference 0.7.4__py3-none-any.whl → 0.7.5__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.


Files changed (34)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +22 -8
  3. xinference/client/oscar/actor_client.py +78 -8
  4. xinference/client/restful/restful_client.py +86 -0
  5. xinference/core/model.py +14 -7
  6. xinference/core/supervisor.py +12 -0
  7. xinference/deploy/cmdline.py +16 -0
  8. xinference/deploy/test/test_cmdline.py +1 -0
  9. xinference/model/embedding/model_spec.json +40 -0
  10. xinference/model/llm/__init__.py +14 -1
  11. xinference/model/llm/llm_family.json +10 -1
  12. xinference/model/llm/llm_family.py +38 -2
  13. xinference/model/llm/llm_family_modelscope.json +10 -1
  14. xinference/model/llm/pytorch/chatglm.py +1 -0
  15. xinference/model/llm/pytorch/core.py +1 -1
  16. xinference/model/llm/pytorch/utils.py +50 -18
  17. xinference/model/llm/utils.py +2 -2
  18. xinference/model/llm/vllm/core.py +13 -4
  19. xinference/model/multimodal/core.py +1 -1
  20. xinference/model/multimodal/qwen_vl.py +34 -2
  21. xinference/web/ui/build/asset-manifest.json +3 -3
  22. xinference/web/ui/build/index.html +1 -1
  23. xinference/web/ui/build/static/js/{main.31d347d8.js → main.236e72e7.js} +3 -3
  24. xinference/web/ui/build/static/js/main.236e72e7.js.map +1 -0
  25. xinference/web/ui/node_modules/.cache/babel-loader/78f2521da2e2a98b075a2666cb782c7e2c019cd3c72199eecd5901c82d8655df.json +1 -0
  26. {xinference-0.7.4.dist-info → xinference-0.7.5.dist-info}/METADATA +9 -2
  27. {xinference-0.7.4.dist-info → xinference-0.7.5.dist-info}/RECORD +32 -32
  28. xinference/web/ui/build/static/js/main.31d347d8.js.map +0 -1
  29. xinference/web/ui/node_modules/.cache/babel-loader/ca8515ecefb4a06c5305417bfd9c04e13cf6b9103f52a47c925921b26c0a9f9d.json +0 -1
  30. /xinference/web/ui/build/static/js/{main.31d347d8.js.LICENSE.txt → main.236e72e7.js.LICENSE.txt} +0 -0
  31. {xinference-0.7.4.dist-info → xinference-0.7.5.dist-info}/LICENSE +0 -0
  32. {xinference-0.7.4.dist-info → xinference-0.7.5.dist-info}/WHEEL +0 -0
  33. {xinference-0.7.4.dist-info → xinference-0.7.5.dist-info}/entry_points.txt +0 -0
  34. {xinference-0.7.4.dist-info → xinference-0.7.5.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2023-12-29T11:58:30+0800",
+ "date": "2024-01-05T15:29:43+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "6ffee492d79c1bf9a6bdef8291dc0a56117abe06",
- "version": "0.7.4"
+ "full-revisionid": "56b28b3e4149b0a9ab6f5322401b1c3f1fc95c1a",
+ "version": "0.7.5"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -160,6 +160,9 @@ class RESTfulAPI:
         self._router.add_api_route(
             "/v1/models/prompts", self._get_builtin_prompts, methods=["GET"]
         )
+        self._router.add_api_route(
+            "/v1/models/families", self._get_builtin_families, methods=["GET"]
+        )
         self._router.add_api_route(
             "/v1/cluster/devices", self._get_devices_count, methods=["GET"]
         )
@@ -312,6 +315,17 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def _get_builtin_families(self) -> JSONResponse:
+        """
+        For internal usage
+        """
+        try:
+            data = await (await self._get_supervisor_ref()).get_builtin_families()
+            return JSONResponse(content=data)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def _get_devices_count(self) -> JSONResponse:
         """
         For internal usage
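
The new `/v1/models/families` route simply relays `SupervisorActor.get_builtin_families()` (shown later in this diff), which groups built-in LLM names into chat-capable and generate-only families. A minimal sketch of calling it, assuming a locally running server; the host, port and printed names are illustrative:

```python
# Hedged sketch: querying the /v1/models/families endpoint added in 0.7.5.
# The base URL assumes a local xinference server on the default port.
import requests

resp = requests.get("http://127.0.0.1:9997/v1/models/families")
resp.raise_for_status()
families = resp.json()
# Expected shape: {"chat": [<chat model names>], "generate": [<generate-only model names>]}
print(sorted(families["chat"])[:5])
```
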
@@ -565,7 +579,7 @@
                 except RuntimeError as re:
                     self.handle_request_limit_error(re)
                 async for item in iterator:
-                    yield dict(data=json.dumps(item))
+                    yield item
             except Exception as ex:
                 if iterator is not None:
                     await iterator.destroy()
@@ -577,7 +591,7 @@
         else:
             try:
                 data = await model.generate(body.prompt, kwargs)
-                return JSONResponse(content=data)
+                return Response(data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
                 self.handle_request_limit_error(e)
@@ -634,7 +648,7 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def create_images(self, request: TextToImageRequest) -> JSONResponse:
+    async def create_images(self, request: TextToImageRequest) -> Response:
         model_uid = request.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -655,7 +669,7 @@
                 response_format=request.response_format,
                 **kwargs,
             )
-            return JSONResponse(content=image_list)
+            return Response(content=image_list, media_type="application/json")
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             self.handle_request_limit_error(re)
@@ -674,7 +688,7 @@
         response_format: Optional[str] = Form("url"),
         size: Optional[str] = Form("1024*1024"),
         kwargs: Optional[str] = Form(None),
-    ) -> JSONResponse:
+    ) -> Response:
         model_uid = model
         try:
             model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -697,7 +711,7 @@
                 response_format=response_format,
                 **kwargs,
             )
-            return JSONResponse(content=image_list)
+            return Response(content=image_list, media_type="application/json")
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -828,7 +842,7 @@
                 except RuntimeError as re:
                     self.handle_request_limit_error(re)
                 async for item in iterator:
-                    yield dict(data=json.dumps(item))
+                    yield item
             except Exception as ex:
                 if iterator is not None:
                     await iterator.destroy()
@@ -843,7 +857,7 @@
                     data = await model.chat(prompt, chat_history, kwargs)
                 else:
                     data = await model.chat(prompt, system_prompt, chat_history, kwargs)
-                return JSONResponse(content=data)
+                return Response(content=data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
                 self.handle_request_limit_error(e)
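
The streaming generators now yield whatever `ModelActor.next()` hands back, which (per the core/model.py changes below) are already-serialized SSE frames, while non-streaming responses are returned as raw JSON bytes through `Response`. A hedged sketch of a client consuming the streaming completion endpoint; the host, model uid and prompt are illustrative:

```python
# Hedged sketch: reading the SSE stream from the OpenAI-compatible completions route.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/completions",
    json={"model": "my-model-uid", "prompt": "Hello", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if line.startswith(b"data:"):
        payload = line[len(b"data:"):].strip()
        if payload == b"[DONE]":  # some deployments terminate the stream this way
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)
```
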
xinference/client/oscar/actor_client.py CHANGED
@@ -13,11 +13,13 @@
 # limitations under the License.
 
 import asyncio
+import re
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 
+import orjson
 import xoscar as xo
 
-from ...core.model import ModelActor
+from ...core.model import IteratorWrapper, ModelActor
 from ...core.supervisor import SupervisorActor
 from ...isolation import Isolation
 from ..restful.restful_client import Client
@@ -38,6 +40,52 @@ if TYPE_CHECKING:
     )
 
 
+class SSEEvent(object):
+    # https://github.com/btubbs/sseclient/blob/master/sseclient.py
+    sse_line_pattern = re.compile("(?P<name>[^:]*):?( ?(?P<value>.*))?")
+
+    def __init__(self, data="", event="message", id=None, retry=None):
+        self.data = data
+        self.event = event
+        self.id = id
+        self.retry = retry
+
+    @classmethod
+    def parse(cls, raw):
+        """
+        Given a possibly-multiline string representing an SSE message, parse it
+        and return a Event object.
+        """
+        msg = cls()
+        for line in raw.splitlines():
+            m = cls.sse_line_pattern.match(line)
+            if m is None:
+                # Malformed line. Discard but warn.
+                continue
+
+            name = m.group("name")
+            if name == "":
+                # line began with a ":", so is a comment. Ignore
+                continue
+            value = m.group("value")
+
+            if name == "data":
+                # If we already have some data, then join to it with a newline.
+                # Else this is it.
+                if msg.data:
+                    msg.data = "%s\n%s" % (msg.data, value)
+                else:
+                    msg.data = value
+            elif name == "event":
+                msg.event = value
+            elif name == "id":
+                msg.id = value
+            elif name == "retry":
+                msg.retry = int(value)
+
+        return msg
+
+
 class ModelHandle:
     """
     A sync model interface (for rpc client) which provides type hints that makes it much easier to use xinference
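
The vendored `SSEEvent` helper mirrors sseclient's frame format so the oscar client can decode the SSE bytes produced by the model actor. A minimal sketch of a frame round-tripping through it; the payload is illustrative:

```python
# Hedged sketch: parsing a single SSE frame with the SSEEvent helper above.
raw = 'data: {"choices": [{"text": "Hello"}]}\n\n'
event = SSEEvent.parse(raw)
assert event.event == "message"  # default event type
assert event.data == '{"choices": [{"text": "Hello"}]}'
```
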
@@ -49,6 +97,19 @@ class ModelHandle:
         self._isolation = isolation
 
 
+class ClientIteratorWrapper(IteratorWrapper):
+    async def __anext__(self):
+        r = await super().__anext__()
+        text = r.decode("utf-8")
+        return orjson.loads(SSEEvent.parse(text).data)
+
+    @classmethod
+    def wrap(cls, iterator_wrapper):
+        c = cls.__new__(cls)
+        c.__dict__.update(iterator_wrapper.__dict__)
+        return c
+
+
 class EmbeddingModelHandle(ModelHandle):
     def create_embedding(self, input: Union[str, List[str]]) -> bytes:
         """
@@ -68,7 +129,7 @@ class EmbeddingModelHandle(ModelHandle):
         """
 
         coro = self._model_ref.create_embedding(input)
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))
 
 
 class RerankModelHandle(ModelHandle):
@@ -104,7 +165,7 @@ class RerankModelHandle(ModelHandle):
         coro = self._model_ref.rerank(
             documents, query, top_n, max_chunks_per_doc, return_documents
         )
-        results = self._isolation.call(coro)
+        results = orjson.loads(self._isolation.call(coro))
         for r in results["results"]:
             r["document"] = documents[r["index"]]
         return results
@@ -140,7 +201,10 @@ class GenerateModelHandle(EmbeddingModelHandle):
         """
 
         coro = self._model_ref.generate(prompt, generate_config)
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)
 
 
 class ChatModelHandle(GenerateModelHandle):
@@ -185,7 +249,10 @@ class ChatModelHandle(GenerateModelHandle):
         coro = self._model_ref.chat(
             prompt, system_prompt, chat_history, generate_config
         )
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)
 
 
 class ChatglmCppChatModelHandle(EmbeddingModelHandle):
@@ -217,7 +284,10 @@ class ChatglmCppChatModelHandle(EmbeddingModelHandle):
         """
 
         coro = self._model_ref.chat(prompt, chat_history, generate_config)
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)
 
 
 class ImageModelHandle(ModelHandle):
@@ -249,7 +319,7 @@ class ImageModelHandle(ModelHandle):
         """
 
         coro = self._model_ref.text_to_image(prompt, n, size, response_format, **kwargs)
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))
 
     def image_to_image(
         self,
@@ -294,7 +364,7 @@ class ImageModelHandle(ModelHandle):
         coro = self._model_ref.image_to_image(
             image, prompt, negative_prompt, n, size, response_format, **kwargs
         )
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))
 
 
 class ActorClient:
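
`ClientIteratorWrapper.wrap` re-dresses an existing `IteratorWrapper` as the decoding subclass without re-running `__init__`, so every streamed frame is parsed on the client side. A self-contained sketch of the same pattern; the class names below are illustrative and not part of xinference's API:

```python
# Hedged sketch of the wrap() pattern: reuse an existing iterator's state in a
# subclass that post-processes every item, mirroring orjson.loads in the client.
import json


class RawIterator:
    def __init__(self, items):
        self._items = iter(items)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._items)


class ParsedIterator(RawIterator):
    def __next__(self):
        # Decode each raw JSON payload into a dict.
        return json.loads(super().__next__())

    @classmethod
    def wrap(cls, raw):
        # Copy the original object's state instead of calling __init__ again.
        obj = cls.__new__(cls)
        obj.__dict__.update(raw.__dict__)
        return obj


raw = RawIterator(['{"text": "he"}', '{"text": "llo"}'])
for chunk in ParsedIterator.wrap(raw):
    print(chunk["text"], end="")
```
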
xinference/client/restful/restful_client.py CHANGED
@@ -398,6 +398,90 @@ class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
         return response_data
 
 
+class RESTfulMultimodalModelHandle(RESTfulModelHandle):
+    def chat(
+        self,
+        prompt: Any,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List["ChatCompletionMessage"]] = None,
+        tools: Optional[List[Dict]] = None,
+        generate_config: Optional[
+            Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
+        ] = None,
+    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
+        """
+        Given a list of messages comprising a conversation, the model will return a response via RESTful APIs.
+
+        Parameters
+        ----------
+        prompt: str
+            The user's input.
+        system_prompt: Optional[str]
+            The system context provide to Model prior to any chats.
+        chat_history: Optional[List["ChatCompletionMessage"]]
+            A list of messages comprising the conversation so far.
+        tools: Optional[List[Dict]]
+            A tool list.
+        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
+            Additional configuration for the chat generation.
+            "LlamaCppGenerateConfig" -> configuration for ggml model
+            "PytorchGenerateConfig" -> configuration for pytorch model
+
+        Returns
+        -------
+        Union["ChatCompletion", Iterator["ChatCompletionChunk"]]
+            Stream is a parameter in generate_config.
+            When stream is set to True, the function will return Iterator["ChatCompletionChunk"].
+            When stream is set to False, the function will return "ChatCompletion".
+
+        Raises
+        ------
+        RuntimeError
+            Report the failure to generate the chat from the server. Detailed information provided in error message.
+
+        """
+
+        url = f"{self._base_url}/v1/chat/completions"
+
+        if chat_history is None:
+            chat_history = []
+
+        if chat_history and chat_history[0]["role"] == "system":
+            if system_prompt is not None:
+                chat_history[0]["content"] = system_prompt
+
+        else:
+            if system_prompt is not None:
+                chat_history.insert(0, {"role": "system", "content": system_prompt})
+
+        chat_history.append({"role": "user", "content": prompt})
+
+        request_body: Dict[str, Any] = {
+            "model": self._model_uid,
+            "messages": chat_history,
+        }
+        if tools is not None:
+            raise RuntimeError("Multimodal does not support function call.")
+
+        if generate_config is not None:
+            for key, value in generate_config.items():
+                request_body[key] = value
+
+        stream = bool(generate_config and generate_config.get("stream"))
+        response = requests.post(url, json=request_body, stream=stream)
+
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to generate chat completion, detail: {_get_error_string(response)}"
+            )
+
+        if stream:
+            return streaming_response_iterator(response.iter_lines())
+
+        response_data = response.json()
+        return response_data
+
+
 class RESTfulChatglmCppChatModelHandle(RESTfulEmbeddingModelHandle):
     def chat(
         self,
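
With the dispatch added to `Client.get_model` in the next hunk, a launched multimodal model is driven through the same `chat()` call. A hedged sketch, assuming a running server; the endpoint, model name, the `model_type` value passed to `launch_model`, and the OpenAI-style text/image prompt structure are illustrative assumptions:

```python
# Hedged sketch: chatting with a multimodal model via the RESTful client.
from xinference.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(model_name="qwen-vl-chat", model_type="multimodal")
model = client.get_model(model_uid)  # dispatched to RESTfulMultimodalModelHandle

completion = model.chat(
    prompt=[
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
)
print(completion["choices"][0]["message"]["content"])
```
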
@@ -744,6 +828,8 @@ class Client:
             return RESTfulImageModelHandle(model_uid, self.base_url)
         elif desc["model_type"] == "rerank":
             return RESTfulRerankModelHandle(model_uid, self.base_url)
+        elif desc["model_type"] == "multimodal":
+            return RESTfulMultimodalModelHandle(model_uid, self.base_url)
         else:
             raise ValueError(f"Unknown model type:{desc['model_type']}")
 
xinference/core/model.py CHANGED
@@ -14,6 +14,7 @@
 
 import asyncio
 import inspect
+import json
 import os
 import uuid
 from typing import (
@@ -30,6 +31,7 @@ from typing import (
     Union,
 )
 
+import sse_starlette.sse
 import xoscar as xo
 
 if TYPE_CHECKING:
@@ -186,7 +188,7 @@
             )
         )
 
-    async def _wrap_generator(self, ret: Any):
+    def _wrap_generator(self, ret: Any):
         if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
             if self._lock is not None and self._generators:
                 raise Exception("Parallel generation is not supported by ggml.")
@@ -199,7 +201,7 @@
                 model_actor_uid=self.uid,
            )
         else:
-            return ret
+            return json_dumps(ret)
 
     async def _call_wrapper(self, _wrapper: Callable):
         try:
@@ -335,9 +337,10 @@
         )
 
         def _wrapper():
-            return getattr(self._model, "text_to_image")(
+            r = getattr(self._model, "text_to_image")(
                 prompt, n, size, response_format, *args, **kwargs
             )
+            return json_dumps(r)
 
         return await self._call_wrapper(_wrapper)
 
@@ -358,7 +361,7 @@
         )
 
         def _wrapper():
-            return getattr(self._model, "image_to_image")(
+            r = getattr(self._model, "image_to_image")(
                 image,
                 prompt,
                 negative_prompt,
@@ -368,10 +371,10 @@
                 *args,
                 **kwargs,
             )
+            return json_dumps(r)
 
         return await self._call_wrapper(_wrapper)
 
-    @log_async(logger=logger)
     async def next(
         self, generator_uid: str
     ) -> Union["ChatCompletionChunk", "CompletionChunk"]:
@@ -381,14 +384,18 @@
 
         def _wrapper():
             try:
-                return next(gen)
+                v = dict(data=json.dumps(next(gen)))
+                return sse_starlette.sse.ensure_bytes(v, None)
             except StopIteration:
                 return stop
 
         async def _async_wrapper():
             try:
                 # anext is only available for Python >= 3.10
-                return await gen.__anext__()  # noqa: F821
+                v = await gen.__anext__()
+                v = await asyncio.to_thread(json.dumps, v)
+                v = dict(data=v)  # noqa: F821
+                return await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
             except StopAsyncIteration:
                 return stop
 
xinference/core/supervisor.py CHANGED
@@ -114,6 +114,18 @@ class SupervisorActor(xo.StatelessActor):
                 data[k] = v.dict()
         return data
 
+    @staticmethod
+    async def get_builtin_families() -> Dict[str, List[str]]:
+        from ..model.llm.llm_family import (
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES,
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
+        )
+
+        return {
+            "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
+            "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
+        }
+
     async def get_devices_count(self) -> int:
         from ..utils import cuda_count
 
xinference/deploy/cmdline.py CHANGED
@@ -402,6 +402,22 @@ def list_model_registrations(
             tabulate(table, headers=["Type", "Name", "Family", "Is-built-in"]),
             file=sys.stderr,
         )
+    elif model_type == "multimodal":
+        for registration in registrations:
+            model_name = registration["model_name"]
+            model_family = client.get_model_registration(model_type, model_name)
+            table.append(
+                [
+                    model_type,
+                    model_family["model_name"],
+                    model_family["model_lang"],
+                    registration["is_builtin"],
+                ]
+            )
+        print(
+            tabulate(table, headers=["Type", "Name", "Language", "Is-built-in"]),
+            file=sys.stderr,
+        )
     else:
         raise NotImplementedError(f"List {model_type} is not implemented.")
 
xinference/deploy/test/test_cmdline.py CHANGED
@@ -159,6 +159,7 @@ def test_cmdline_of_custom_model(setup):
     "embed",
     "chat"
   ],
+  "model_family": "other",
   "model_specs": [
     {
       "model_format": "pytorch",
xinference/model/embedding/model_spec.json CHANGED
@@ -142,5 +142,45 @@
     "language": ["en"],
     "model_id": "jinaai/jina-embeddings-v2-base-en",
     "model_revision": "7302ac470bed880590f9344bfeee32ff8722d0e5"
+  },
+  {
+    "model_name": "text2vec-large-chinese",
+    "dimensions": 1024,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-bge-large-chinese",
+    "model_revision": "f5027ca48ea8316d63ee26d2b9bd27a061de33a3"
+  },
+  {
+    "model_name": "text2vec-base-chinese",
+    "dimensions": 768,
+    "max_tokens": 128,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-chinese",
+    "model_revision": "8acc1289891d75f6b665ad623359798b55f86adb"
+  },
+  {
+    "model_name": "text2vec-base-chinese-paraphrase",
+    "dimensions": 768,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-chinese-paraphrase",
+    "model_revision": "beaf10481a5d9ca3b0daa9f0df6831ec956bf739"
+  },
+  {
+    "model_name": "text2vec-base-chinese-sentence",
+    "dimensions": 768,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-chinese-sentence",
+    "model_revision": "e73a94e821f22c6163166bfab9408d03933a5525"
+  },
+  {
+    "model_name": "text2vec-base-multilingual",
+    "dimensions": 384,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-multilingual",
+    "model_revision": "f241877385fa56ebcc75f04d1850e1579cfa661d"
   }
 ]
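
The five shibing624 text2vec entries above become selectable as built-in embedding models. A hedged sketch of launching one and embedding a sentence, assuming a running server; the endpoint is illustrative:

```python
# Hedged sketch: using one of the newly added text2vec embedding models.
from xinference.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(model_name="text2vec-base-chinese", model_type="embedding")
model = client.get_model(model_uid)

result = model.create_embedding("今天天气真好")
print(len(result["data"][0]["embedding"]))  # 768 dimensions per the spec above
```
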
xinference/model/llm/__init__.py CHANGED
@@ -19,9 +19,12 @@ import os
 from .core import LLM
 from .llm_family import (
     BUILTIN_LLM_FAMILIES,
+    BUILTIN_LLM_MODEL_CHAT_FAMILIES,
+    BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
     BUILTIN_LLM_PROMPT_STYLE,
     BUILTIN_MODELSCOPE_LLM_FAMILIES,
     LLM_CLASSES,
+    CustomLLMFamilyV1,
     GgmlLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
@@ -94,6 +97,11 @@ def _install():
         # note that the key is the model name,
         # since there are multiple representations of the same prompt style name in json.
         BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
 
     modelscope_json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
@@ -110,6 +118,11 @@ def _install():
             and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
         ):
             BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
 
     from ...constants import XINFERENCE_MODEL_DIR
 
@@ -119,5 +132,5 @@ def _install():
             with codecs.open(
                 os.path.join(user_defined_llm_dir, f), encoding="utf-8"
             ) as fd:
-                user_defined_llm_family = LLMFamilyV1.parse_obj(json.load(fd))
+                user_defined_llm_family = CustomLLMFamilyV1.parse_obj(json.load(fd))
                 register_llm(user_defined_llm_family, persist=False)
xinference/model/llm/llm_family.json CHANGED
@@ -557,7 +557,7 @@
         "none"
       ],
       "model_id": "THUDM/chatglm3-6b",
-      "model_revision": "e46a14881eae613281abbd266ee918e93a56018f"
+      "model_revision": "b098244a71fbe69ce149682d9072a7629f7e908c"
     }
   ],
   "prompt_style": {
@@ -566,6 +566,15 @@
       "roles": [
         "user",
         "assistant"
+      ],
+      "stop_token_ids": [
+        64795,
+        64797,
+        2
+      ],
+      "stop":[
+        "<|user|>",
+        "<|observation|>"
       ]
     }
   },