xinference 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (35)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +35 -1
  3. xinference/client/oscar/actor_client.py +2 -2
  4. xinference/client/restful/restful_client.py +2 -2
  5. xinference/conftest.py +5 -1
  6. xinference/core/metrics.py +83 -0
  7. xinference/core/model.py +148 -8
  8. xinference/core/status_guard.py +86 -0
  9. xinference/core/supervisor.py +57 -7
  10. xinference/core/worker.py +132 -13
  11. xinference/deploy/cmdline.py +57 -4
  12. xinference/deploy/local.py +32 -6
  13. xinference/deploy/worker.py +33 -5
  14. xinference/fields.py +4 -1
  15. xinference/model/llm/__init__.py +7 -0
  16. xinference/model/llm/ggml/llamacpp.py +3 -2
  17. xinference/model/llm/llm_family.json +70 -3
  18. xinference/model/llm/llm_family.py +11 -1
  19. xinference/model/llm/llm_family_modelscope.json +72 -3
  20. xinference/model/llm/pytorch/chatglm.py +70 -28
  21. xinference/model/llm/pytorch/core.py +11 -30
  22. xinference/model/llm/pytorch/internlm2.py +155 -0
  23. xinference/model/llm/pytorch/utils.py +0 -153
  24. xinference/model/llm/utils.py +37 -8
  25. xinference/model/llm/vllm/core.py +15 -3
  26. xinference/model/multimodal/__init__.py +15 -8
  27. xinference/model/multimodal/model_spec_modelscope.json +45 -0
  28. xinference/model/utils.py +7 -2
  29. xinference/types.py +2 -0
  30. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/METADATA +2 -1
  31. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/RECORD +35 -31
  32. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/LICENSE +0 -0
  33. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/WHEEL +0 -0
  34. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/entry_points.txt +0 -0
  35. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py CHANGED
@@ -15,7 +15,9 @@
 import asyncio
 import os
 import platform
+import queue
 import signal
+import threading
 from collections import defaultdict
 from logging import getLogger
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
@@ -25,8 +27,10 @@ from xoscar import MainActorPoolType
 
 from ..constants import XINFERENCE_CACHE_DIR
 from ..core import ModelActor
+from ..core.status_guard import LaunchStatus
 from ..model.core import ModelDescription, create_model_instance
 from ..utils import cuda_count
+from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
 from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
 
@@ -34,6 +38,12 @@ logger = getLogger(__name__)
 
 
 DEFAULT_NODE_HEARTBEAT_INTERVAL = 5
+MODEL_ACTOR_AUTO_RECOVER_LIMIT: Optional[int]
+_MODEL_ACTOR_AUTO_RECOVER_LIMIT = os.getenv("XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT")
+if _MODEL_ACTOR_AUTO_RECOVER_LIMIT is not None:
+    MODEL_ACTOR_AUTO_RECOVER_LIMIT = int(_MODEL_ACTOR_AUTO_RECOVER_LIMIT)
+else:
+    MODEL_ACTOR_AUTO_RECOVER_LIMIT = None
 
 
 class WorkerActor(xo.StatelessActor):
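
Note that the recover limit above is read once, at module import time; leaving the variable unset keeps unlimited recovery attempts. A minimal sketch of opting in, assuming only what the hunk shows (the value must be in the worker process's environment before xinference.core.worker is imported):

    import os

    # Cap each model actor at 3 automatic recreations; an unset variable
    # means MODEL_ACTOR_AUTO_RECOVER_LIMIT stays None (unlimited).
    os.environ["XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT"] = "3"

    import xinference.core.worker  # the limit is fixed from this point on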
@@ -42,6 +52,8 @@ class WorkerActor(xo.StatelessActor):
         supervisor_address: str,
         main_pool: MainActorPoolType,
         cuda_devices: List[int],
+        metrics_exporter_host: Optional[str] = None,
+        metrics_exporter_port: Optional[int] = None,
     ):
         super().__init__()
         # static attrs.
@@ -57,20 +69,71 @@ class WorkerActor(xo.StatelessActor):
         self._gpu_to_model_uid: Dict[int, str] = {}
         self._gpu_to_embedding_model_uids: Dict[int, Set[str]] = defaultdict(set)
         self._model_uid_to_addr: Dict[str, str] = {}
+        self._model_uid_to_recover_count: Dict[str, int] = {}
         self._model_uid_to_launch_args: Dict[str, Dict] = {}
 
+        # metrics export server.
+        if metrics_exporter_host is not None or metrics_exporter_port is not None:
+            logger.info(
+                f"Starting metrics export server at {metrics_exporter_host}:{metrics_exporter_port}"
+            )
+            q: queue.Queue = queue.Queue()
+            self._metrics_thread = threading.Thread(
+                name="Metrics Export Server",
+                target=launch_metrics_export_server,
+                args=(q, metrics_exporter_host, metrics_exporter_port),
+                daemon=True,
+            )
+            self._metrics_thread.start()
+            logger.info("Checking metrics export server...")
+            while self._metrics_thread.is_alive():
+                try:
+                    host, port = q.get(block=False)[:2]
+                    logger.info(f"Metrics server is started at: http://{host}:{port}")
+                    break
+                except queue.Empty:
+                    pass
+            else:
+                raise Exception("Metrics server thread exit.")
+
         self._lock = asyncio.Lock()
 
     async def recover_sub_pool(self, address):
-        logger.warning("Process %s is down, create model.", address)
+        logger.warning("Process %s is down.", address)
+        # Xoscar does not remove the address from sub_processes.
+        try:
+            await self._main_pool.remove_sub_pool(address)
+        except Exception:
+            pass
         for model_uid, addr in self._model_uid_to_addr.items():
             if addr == address:
                 launch_args = self._model_uid_to_launch_args.get(model_uid)
-                try:
-                    await self.terminate_model(model_uid)
-                except Exception:
-                    pass
-                await self.launch_builtin_model(**launch_args)
+                if launch_args is None:
+                    logger.warning(
+                        "Not recreate model because the it is down during launch."
+                    )
+                else:
+                    recover_count = self._model_uid_to_recover_count.get(model_uid)
+                    try:
+                        await self.terminate_model(model_uid)
+                    except Exception:
+                        pass
+                    if recover_count is not None:
+                        if recover_count > 0:
+                            logger.warning(
+                                "Recreating model actor %s, remain %s times ...",
+                                model_uid,
+                                recover_count - 1,
+                            )
+                            self._model_uid_to_recover_count[model_uid] = (
+                                recover_count - 1
+                            )
+                            await self.launch_builtin_model(**launch_args)
+                        else:
+                            logger.warning("Stop recreating model actor.")
+                    else:
+                        logger.warning("Recreating model actor %s ...", model_uid)
+                        await self.launch_builtin_model(**launch_args)
                 break
 
     @classmethod
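
Two things happen in this hunk: recover_sub_pool now skips models whose launch never completed (launch_args is None) and decrements the per-model recover count before each recreation, and __init__ gains a thread-and-queue start-up handshake for the metrics server, where the while/else raises only if the thread dies before reporting its bound address. A standalone sketch of that handshake pattern, with a stand-in _serve in place of launch_metrics_export_server:

    import queue
    import threading

    def _serve(q: queue.Queue, host: str, port: int) -> None:
        # Stand-in for launch_metrics_export_server: report the bound
        # address back to the parent, then block like a server loop would.
        q.put((host, port))
        threading.Event().wait()

    q: queue.Queue = queue.Queue()
    t = threading.Thread(target=_serve, args=(q, "127.0.0.1", 9998), daemon=True)
    t.start()
    while t.is_alive():
        try:
            host, port = q.get(block=False)[:2]
            print(f"metrics server ready at http://{host}:{port}")
            break
        except queue.Empty:
            pass
    else:
        # Reached only when the thread exits before reporting its address.
        raise Exception("Metrics server thread exit.")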
@@ -78,8 +141,14 @@ class WorkerActor(xo.StatelessActor):
         return "worker"
 
     async def __post_create__(self):
+        from .status_guard import StatusGuardActor
         from .supervisor import SupervisorActor
 
+        self._status_guard_ref: xo.ActorRefType[
+            "StatusGuardActor"
+        ] = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
         self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(
             address=self._supervisor_address, uid=SupervisorActor.uid()
         )
@@ -309,7 +378,12 @@ class WorkerActor(xo.StatelessActor):
 
         try:
             model_ref = await xo.create_actor(
-                ModelActor, address=subpool_address, uid=model_uid, model=model
+                ModelActor,
+                address=subpool_address,
+                uid=model_uid,
+                worker_address=self.address,
+                model=model,
+                model_description=model_description,
             )
             await model_ref.load()
         except:
@@ -324,6 +398,22 @@ class WorkerActor(xo.StatelessActor):
             self._gpu_to_model_uid[int(dev)] = model_uid
         self._model_uid_to_addr[model_uid] = subpool_address
 
+    async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
+        from ..model.llm.core import LLM
+
+        if model_type == "embedding":
+            return ["embed"]
+        elif model_type == "rerank":
+            return ["rerank"]
+        elif model_type == "image":
+            return ["text_to_image"]
+        elif model_type == "multimodal":
+            return ["multimodal"]
+        else:
+            assert model_type == "LLM"
+            assert isinstance(model, LLM)
+            return model.model_family.model_ability  # type: ignore
+
     @log_async(logger=logger)
     async def launch_builtin_model(
         self,
@@ -360,6 +450,7 @@ class WorkerActor(xo.StatelessActor):
             )
 
         try:
+            origin_uid, _, _ = parse_replica_model_uid(model_uid)
             model, model_description = await asyncio.to_thread(
                 create_model_instance,
                 subpool_address,
@@ -377,7 +468,9 @@ class WorkerActor(xo.StatelessActor):
                 ModelActor,
                 address=subpool_address,
                 uid=model_uid,
+                worker_address=self.address,
                 model=model,
+                model_description=model_description,
                 request_limits=request_limits,
             )
             await model_ref.load()
@@ -390,13 +483,27 @@ class WorkerActor(xo.StatelessActor):
         self._model_uid_to_model[model_uid] = model_ref
         self._model_uid_to_model_spec[model_uid] = model_description
         self._model_uid_to_addr[model_uid] = subpool_address
+        self._model_uid_to_recover_count.setdefault(
+            model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT
+        )
         self._model_uid_to_launch_args[model_uid] = launch_args
 
+        # update status to READY
+        abilities = await self._get_model_ability(model, model_type)
+        await self._status_guard_ref.update_instance_info(
+            origin_uid,
+            {"model_ability": abilities, "status": LaunchStatus.READY.name},
+        )
+
     @log_async(logger=logger)
     async def terminate_model(self, model_uid: str):
+        origin_uid, _, _ = parse_replica_model_uid(model_uid)
+        await self._status_guard_ref.update_instance_info(
+            origin_uid, {"status": LaunchStatus.TERMINATING.name}
+        )
         model_ref = self._model_uid_to_model.get(model_uid, None)
         if model_ref is None:
-            raise ValueError(f"Model not found in the model list, uid: {model_uid}")
+            logger.debug("Model not found, uid: %s", model_uid)
 
         try:
             await xo.destroy_actor(model_ref)
@@ -407,12 +514,20 @@ class WorkerActor(xo.StatelessActor):
         try:
             subpool_address = self._model_uid_to_addr[model_uid]
             await self._main_pool.remove_sub_pool(subpool_address)
+        except Exception as e:
+            logger.debug(
+                "Remove sub pool failed, model uid: %s, error: %s", model_uid, e
+            )
         finally:
-            del self._model_uid_to_model[model_uid]
-            del self._model_uid_to_model_spec[model_uid]
+            self._model_uid_to_model.pop(model_uid, None)
+            self._model_uid_to_model_spec.pop(model_uid, None)
             self.release_devices(model_uid)
-            del self._model_uid_to_addr[model_uid]
-            del self._model_uid_to_launch_args[model_uid]
+            self._model_uid_to_addr.pop(model_uid, None)
+            self._model_uid_to_recover_count.pop(model_uid, None)
+            self._model_uid_to_launch_args.pop(model_uid, None)
+            await self._status_guard_ref.update_instance_info(
+                origin_uid, {"status": LaunchStatus.TERMINATED.name}
+            )
 
     @log_async(logger=logger)
     async def list_models(self) -> Dict[str, Dict[str, Any]]:
@@ -427,7 +542,7 @@ class WorkerActor(xo.StatelessActor):
     def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
         model_ref = self._model_uid_to_model.get(model_uid, None)
         if model_ref is None:
-            raise ValueError(f"Model not found in the model list, uid: {model_uid}")
+            raise ValueError(f"Model not found, uid: {model_uid}")
         return model_ref
 
     @log_sync(logger=logger)
  @log_sync(logger=logger)
@@ -460,3 +575,7 @@ class WorkerActor(xo.StatelessActor):
460
575
  await asyncio.sleep(DEFAULT_NODE_HEARTBEAT_INTERVAL)
461
576
  except asyncio.CancelledError: # pragma: no cover
462
577
  break
578
+
579
+ @staticmethod
580
+ def record_metrics(name, op, kwargs):
581
+ record_metrics(name, op, kwargs)
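
Taken together, the worker hunks drive a per-instance lifecycle through the new supervisor-side StatusGuardActor. The enum below is a sketch for orientation only: READY, TERMINATING and TERMINATED appear in this diff, while the remaining members and all values are assumptions about xinference/core/status_guard.py (added in this release):

    from enum import Enum

    class LaunchStatus(Enum):
        # READY/TERMINATING/TERMINATED are referenced in the hunks above;
        # CREATING and ERROR are assumed pre-launch/failure states.
        CREATING = 1
        READY = 2
        ERROR = 3
        TERMINATING = 4
        TERMINATED = 5

    # Transitions driven by the worker code above:
    #   launch_builtin_model() -> READY   (after model_ref.load() succeeds)
    #   terminate_model()      -> TERMINATING, then TERMINATED in `finally`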
xinference/deploy/cmdline.py CHANGED
@@ -87,7 +87,12 @@ def get_stored_token(
 
 
 def start_local_cluster(
-    log_level: str, host: str, port: int, auth_config_file: Optional[str] = None
+    log_level: str,
+    host: str,
+    port: int,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
+    auth_config_file: Optional[str] = None,
 ):
     from .local import main
 
@@ -102,6 +107,8 @@ def start_local_cluster(
     main(
         host=host,
         port=port,
+        metrics_exporter_host=metrics_exporter_host,
+        metrics_exporter_port=metrics_exporter_port,
         logging_conf=dict_config,
         auth_config_file=auth_config_file,
     )
@@ -182,14 +189,41 @@ def cli(
     type=int,
     help="Specify the port number for the Xinference server.",
 )
+@click.option(
+    "--metrics-exporter-host",
+    "-MH",
+    default=None,
+    type=str,
+    help="Specify the host address for the Xinference metrics exporter server, default is the same as --host.",
+)
+@click.option(
+    "--metrics-exporter-port",
+    "-mp",
+    type=int,
+    help="Specify the port number for the Xinference metrics exporter server.",
+)
 @click.option(
     "--auth-config",
     type=str,
     help="Specify the auth config json file.",
 )
-def local(log_level: str, host: str, port: int, auth_config: Optional[str]):
+def local(
+    log_level: str,
+    host: str,
+    port: int,
+    metrics_exporter_host: Optional[str],
+    metrics_exporter_port: Optional[int],
+    auth_config: Optional[str],
+):
+    if metrics_exporter_host is None:
+        metrics_exporter_host = host
     start_local_cluster(
-        log_level=log_level, host=host, port=port, auth_config_file=auth_config
+        log_level=log_level,
+        host=host,
+        port=port,
+        metrics_exporter_host=metrics_exporter_host,
+        metrics_exporter_port=metrics_exporter_port,
+        auth_config_file=auth_config,
     )
 
 
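A sketch of exercising the new flags without starting a cluster, via click's test runner; it assumes local is exposed as a click command (its @click.option decorators appear above), and relies on --help being eager in click, so the options are parsed but start_local_cluster() never runs:

    from click.testing import CliRunner

    from xinference.deploy.cmdline import local

    runner = CliRunner()
    result = runner.invoke(
        local,
        [
            "--metrics-exporter-host", "127.0.0.1",
            "--metrics-exporter-port", "9998",
            "--help",  # eager: print help and exit before launching
        ],
    )
    print(result.output)
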
@@ -276,8 +310,25 @@ def supervisor(
     type=int,
     help="Specify the port number for the Xinference worker.",
 )
+@click.option(
+    "--metrics-exporter-host",
+    "-MH",
+    default=XINFERENCE_DEFAULT_DISTRIBUTED_HOST,
+    type=str,
+    help="Specify the host address for the metrics exporter server.",
+)
+@click.option(
+    "--metrics-exporter-port",
+    type=int,
+    help="Specify the port number for the Xinference metrics exporter worker.",
+)
 def worker(
-    log_level: str, endpoint: Optional[str], host: str, worker_port: Optional[int]
+    log_level: str,
+    endpoint: Optional[str],
+    host: str,
+    worker_port: Optional[int],
+    metrics_exporter_host: Optional[str],
+    metrics_exporter_port: Optional[int],
 ):
     from ..deploy.worker import main
 
@@ -298,6 +349,8 @@ def worker(
     main(
         address=address,
         supervisor_address=supervisor_internal_addr,
+        metrics_exporter_host=metrics_exporter_host,
+        metrics_exporter_port=metrics_exporter_port,
         logging_conf=dict_config,
     )
 
xinference/deploy/local.py CHANGED
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
 
 async def _start_local_cluster(
     address: str,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[Dict] = None,
 ):
     from .utils import create_worker_actor_pool
@@ -50,7 +52,11 @@ async def _start_local_cluster(
             SupervisorActor, address=address, uid=SupervisorActor.uid()
         )
         await start_worker_components(
-            address=address, supervisor_address=address, main_pool=pool
+            address=address,
+            supervisor_address=address,
+            main_pool=pool,
+            metrics_exporter_host=metrics_exporter_host,
+            metrics_exporter_port=metrics_exporter_port,
         )
         await pool.join()
     except asyncio.CancelledError:
@@ -58,7 +64,12 @@ async def _start_local_cluster(
         await pool.stop()
 
 
-def run(address: str, logging_conf: Optional[Dict] = None):
+def run(
+    address: str,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
+    logging_conf: Optional[Dict] = None,
+):
     def sigterm_handler(signum, frame):
         sys.exit(0)
 
@@ -66,15 +77,26 @@ def run(address: str, logging_conf: Optional[Dict] = None):
 
     loop = asyncio.get_event_loop()
     task = loop.create_task(
-        _start_local_cluster(address=address, logging_conf=logging_conf)
+        _start_local_cluster(
+            address=address,
+            metrics_exporter_host=metrics_exporter_host,
+            metrics_exporter_port=metrics_exporter_port,
+            logging_conf=logging_conf,
+        )
     )
     loop.run_until_complete(task)
 
 
 def run_in_subprocess(
-    address: str, logging_conf: Optional[Dict] = None
+    address: str,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
+    logging_conf: Optional[Dict] = None,
 ) -> multiprocessing.Process:
-    p = multiprocessing.Process(target=run, args=(address, logging_conf))
+    p = multiprocessing.Process(
+        target=run,
+        args=(address, metrics_exporter_host, metrics_exporter_port, logging_conf),
+    )
     p.start()
     return p
 
@@ -82,11 +104,15 @@ def run_in_subprocess(
 def main(
     host: str,
     port: int,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[Dict] = None,
     auth_config_file: Optional[str] = None,
 ):
     supervisor_address = f"{host}:{get_next_port()}"
-    local_cluster = run_in_subprocess(
-        supervisor_address, logging_conf
-    )
+    local_cluster = run_in_subprocess(
+        supervisor_address, metrics_exporter_host, metrics_exporter_port, logging_conf
+    )
 
     if not health_check(
         address=supervisor_address,
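
Because run_in_subprocess and run forward the new parameters positionally, the args tuple order must match run()'s signature exactly. End to end, a local launch with metrics enabled might look like the sketch below (host and port values are illustrative; main() blocks while the cluster runs):

    from xinference.deploy.local import main

    main(
        host="127.0.0.1",
        port=9997,
        metrics_exporter_host="127.0.0.1",
        metrics_exporter_port=9998,
    )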
xinference/deploy/worker.py CHANGED
@@ -27,7 +27,11 @@ logger = logging.getLogger(__name__)
 
 
 async def start_worker_components(
-    address: str, supervisor_address: str, main_pool: MainActorPoolType
+    address: str,
+    supervisor_address: str,
+    main_pool: MainActorPoolType,
+    metrics_exporter_host: Optional[str],
+    metrics_exporter_port: Optional[int],
 ):
     cuda_device_indices = []
     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
@@ -43,24 +47,48 @@ async def start_worker_components(
         supervisor_address=supervisor_address,
         main_pool=main_pool,
         cuda_devices=cuda_device_indices,
+        metrics_exporter_host=metrics_exporter_host,
+        metrics_exporter_port=metrics_exporter_port,
     )
 
 
 async def _start_worker(
-    address: str, supervisor_address: str, logging_conf: Any = None
+    address: str,
+    supervisor_address: str,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
+    logging_conf: Any = None,
 ):
     from .utils import create_worker_actor_pool
 
     pool = await create_worker_actor_pool(address=address, logging_conf=logging_conf)
     await start_worker_components(
-        address=address, supervisor_address=supervisor_address, main_pool=pool
+        address=address,
+        supervisor_address=supervisor_address,
+        main_pool=pool,
+        metrics_exporter_host=metrics_exporter_host,
+        metrics_exporter_port=metrics_exporter_port,
     )
     await pool.join()
 
 
-def main(address: str, supervisor_address: str, logging_conf: Optional[dict] = None):
+def main(
+    address: str,
+    supervisor_address: str,
+    metrics_exporter_host: Optional[str] = None,
+    metrics_exporter_port: Optional[int] = None,
+    logging_conf: Optional[dict] = None,
+):
     loop = asyncio.get_event_loop()
-    task = loop.create_task(_start_worker(address, supervisor_address, logging_conf))
+    task = loop.create_task(
+        _start_worker(
+            address,
+            supervisor_address,
+            metrics_exporter_host,
+            metrics_exporter_port,
+            logging_conf,
+        )
+    )
 
     try:
         loop.run_until_complete(task)
xinference/fields.py CHANGED
@@ -30,7 +30,10 @@ logprobs_field = Field(
 )
 
 max_tokens_field = Field(
-    default=128, ge=1, le=32768, description="The maximum number of tokens to generate."
+    default=1024,
+    ge=1,
+    le=32768,
+    description="The maximum number of tokens to generate.",
 )
 
 temperature_field = Field(
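
The only behavioral change here is the default: a request that omits max_tokens now generates up to 1024 tokens instead of 128. A small sketch of the effect, with an illustrative pydantic model standing in for the request types that use this field:

    from pydantic import BaseModel, Field

    max_tokens_field = Field(
        default=1024, ge=1, le=32768,
        description="The maximum number of tokens to generate.",
    )

    class GenerateConfig(BaseModel):
        # Hypothetical request model: omitting max_tokens now yields 1024.
        max_tokens: int = max_tokens_field

    assert GenerateConfig().max_tokens == 1024
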
xinference/model/llm/__init__.py CHANGED
@@ -21,6 +21,7 @@ from .llm_family import (
     BUILTIN_LLM_FAMILIES,
     BUILTIN_LLM_MODEL_CHAT_FAMILIES,
     BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
+    BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
     BUILTIN_LLM_PROMPT_STYLE,
     BUILTIN_MODELSCOPE_LLM_FAMILIES,
     LLM_CLASSES,
@@ -47,6 +48,7 @@ def _install():
     from .pytorch.chatglm import ChatglmPytorchChatModel
     from .pytorch.core import PytorchChatModel, PytorchModel
     from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
+    from .pytorch.internlm2 import Internlm2PytorchChatModel
     from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
     from .pytorch.vicuna import VicunaPytorchChatModel
     from .vllm.core import VLLMChatModel, VLLMModel
@@ -79,6 +81,7 @@ def _install():
             LlamaPytorchChatModel,
             PytorchChatModel,
             FalconPytorchModel,
+            Internlm2PytorchChatModel,
             PytorchModel,
         ]
     )
@@ -102,6 +105,8 @@ def _install():
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
         else:
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
+        if "tool_call" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
 
     modelscope_json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
@@ -123,6 +128,8 @@ def _install():
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
         else:
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
+        if "tool_call" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
 
     from ...constants import XINFERENCE_MODEL_DIR
 
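With the new registry populated during _install(), callers can check tool-call support by family name. A minimal sketch; whether any particular model (e.g. qwen-chat) is a member depends on the "tool_call" abilities declared in the JSON specs shipped with this release:

    from xinference.model.llm import BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES

    def supports_tool_calls(model_name: str) -> bool:
        # True only for families whose spec lists the "tool_call" ability.
        return model_name in BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES

    print(supports_tool_calls("qwen-chat"))
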
xinference/model/llm/ggml/llamacpp.py CHANGED
@@ -306,7 +306,8 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
 
         generate_config = self._sanitize_generate_config(generate_config)
         # TODO(codingl2k1): qwen hacky to set stop for function call.
-        if tools and self.model_family.model_name == "qwen-chat":
+        model_family = self.model_family.model_family or self.model_family.model_name
+        if tools and "qwen-chat" == model_family:
             stop = generate_config.get("stop")
             if isinstance(stop, str):
                 generate_config["stop"] = [stop, "Observation:"]
@@ -326,6 +327,6 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
         assert not isinstance(c, Iterator)
         if tools:
             return self._tool_calls_completion(
-                self.model_family.model_name, self.model_uid, c, tools
+                self.model_family, self.model_uid, c, tools
             )
         return self._to_chat_completion(c)
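
The fallback in the first hunk matters for user-registered models: a custom model derived from qwen-chat carries its base family in model_family, while builtin models leave that unset and are identified by model_name alone. A tiny sketch of the resolution rule (the function name is illustrative, not part of the package):

    from typing import Optional

    def resolve_family(model_family: Optional[str], model_name: str) -> str:
        # Mirrors the hunk above: prefer the declared base family and
        # fall back to the model's own name for builtin models.
        return model_family or model_name

    assert resolve_family(None, "qwen-chat") == "qwen-chat"                 # builtin
    assert resolve_family("qwen-chat", "my-qwen-finetune") == "qwen-chat"   # custom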