xinference 0.1.2.tar.gz → 0.1.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (58)
  1. {xinference-0.1.2/xinference.egg-info → xinference-0.1.3}/PKG-INFO +105 -1
  2. {xinference-0.1.2 → xinference-0.1.3}/README.md +104 -0
  3. {xinference-0.1.2 → xinference-0.1.3}/setup.cfg +0 -2
  4. {xinference-0.1.2 → xinference-0.1.3}/xinference/_version.py +3 -3
  5. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/worker.py +1 -0
  6. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/cmdline.py +35 -8
  7. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/worker.py +2 -2
  8. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/ggml/llamacpp.py +1 -0
  9. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/llm_family.json +30 -15
  10. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/llm_family.py +8 -6
  11. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/core.py +63 -40
  12. {xinference-0.1.2 → xinference-0.1.3/xinference.egg-info}/PKG-INFO +105 -1
  13. {xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/requires.txt +0 -2
  14. {xinference-0.1.2 → xinference-0.1.3}/LICENSE +0 -0
  15. {xinference-0.1.2 → xinference-0.1.3}/MANIFEST.in +0 -0
  16. {xinference-0.1.2 → xinference-0.1.3}/pyproject.toml +0 -0
  17. {xinference-0.1.2 → xinference-0.1.3}/setup.py +0 -0
  18. {xinference-0.1.2 → xinference-0.1.3}/versioneer.py +0 -0
  19. {xinference-0.1.2 → xinference-0.1.3}/xinference/__init__.py +0 -0
  20. {xinference-0.1.2 → xinference-0.1.3}/xinference/client.py +0 -0
  21. {xinference-0.1.2 → xinference-0.1.3}/xinference/constants.py +0 -0
  22. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/__init__.py +0 -0
  23. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/api.py +0 -0
  24. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/gradio.py +0 -0
  25. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/model.py +0 -0
  26. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/resource.py +0 -0
  27. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/restful_api.py +0 -0
  28. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/supervisor.py +0 -0
  29. {xinference-0.1.2 → xinference-0.1.3}/xinference/core/utils.py +0 -0
  30. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/__init__.py +0 -0
  31. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/local.py +0 -0
  32. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/supervisor.py +0 -0
  33. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/test/__init__.py +0 -0
  34. {xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/utils.py +0 -0
  35. {xinference-0.1.2 → xinference-0.1.3}/xinference/isolation.py +0 -0
  36. {xinference-0.1.2 → xinference-0.1.3}/xinference/locale/__init__.py +0 -0
  37. {xinference-0.1.2 → xinference-0.1.3}/xinference/locale/utils.py +0 -0
  38. {xinference-0.1.2 → xinference-0.1.3}/xinference/locale/zh_CN.json +0 -0
  39. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/__init__.py +0 -0
  40. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/core.py +0 -0
  41. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/__init__.py +0 -0
  42. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/core.py +0 -0
  43. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/ggml/__init__.py +0 -0
  44. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/ggml/chatglm.py +0 -0
  45. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/__init__.py +0 -0
  46. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/baichuan.py +0 -0
  47. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/chatglm.py +0 -0
  48. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/compression.py +0 -0
  49. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/falcon.py +0 -0
  50. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/utils.py +0 -0
  51. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/vicuna.py +0 -0
  52. {xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/utils.py +0 -0
  53. {xinference-0.1.2 → xinference-0.1.3}/xinference/types.py +0 -0
  54. {xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/SOURCES.txt +0 -0
  55. {xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/dependency_links.txt +0 -0
  56. {xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/entry_points.txt +0 -0
  57. {xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/not-zip-safe +0 -0
  58. {xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/top_level.txt +0 -0
{xinference-0.1.2/xinference.egg-info → xinference-0.1.3}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: xinference
- Version: 0.1.2
+ Version: 0.1.3
  Summary: Model Serving Made Easy
  Home-page: https://github.com/xorbitsai/inference
  Author: Qin Xuye
@@ -238,6 +238,110 @@ $ xinference list --all
  - If you want to use Apple Metal GPU for acceleration, please choose the q4_0 and q4_1 quantization methods.
  - `llama-2-chat` 70B ggmlv3 model only supports q4_0 quantization currently.

+ ## Custom models \[Experimental\]
+ Custom models are currently an experimental feature and are expected to be officially released in version v0.2.0.
+
+ Define a custom model based on the following template:
+ ```python
+ custom_model = {
+     "version": 1,
+     # model name. must start with a letter or a
+     # digit, and can only contain letters, digits,
+     # underscores, or dashes.
+     "model_name": "nsql-2B",
+     # supported languages
+     "model_lang": [
+         "en"
+     ],
+     # model abilities. could be "embed", "generate"
+     # and "chat".
+     "model_ability": [
+         "generate"
+     ],
+     # model specifications.
+     "model_specs": [
+         {
+             # model format.
+             "model_format": "pytorch",
+             "model_size_in_billions": 2,
+             # quantizations.
+             "quantizations": [
+                 "4-bit",
+                 "8-bit",
+                 "none"
+             ],
+             # hugging face model ID.
+             "model_id": "NumbersStation/nsql-2B"
+         }
+     ],
+     # prompt style, required by chat models.
+     # for more details, see: xinference/model/llm/tests/test_utils.py
+     "prompt_style": None
+ }
+ ```
+
+ Register the custom model:
+ ```python
+ import json
+
+ from xinference.client import Client
+
+ # replace with real xinference endpoint
+ endpoint = "http://localhost:9997"
+ client = Client(endpoint)
+ client.register_model(model_type="LLM", model=json.dumps(custom_model), persist=False)
+ ```
+
+ Load the custom model:
+ ```python
+ uid = client.launch_model(model_name='nsql-2B')
+ ```
+
+ Run the custom model:
+ ```python
+ text = """CREATE TABLE work_orders (
+     ID NUMBER,
+     CREATED_AT TEXT,
+     COST FLOAT,
+     INVOICE_AMOUNT FLOAT,
+     IS_DUE BOOLEAN,
+     IS_OPEN BOOLEAN,
+     IS_OVERDUE BOOLEAN,
+     COUNTRY_NAME TEXT,
+ )
+
+ -- Using valid SQLite, answer the following questions for the tables provided above.
+
+ -- how many work orders are open?
+
+ SELECT"""
+
+ model = client.get_model(model_uid=uid)
+ model.generate(prompt=text)
+ ```
+
+ Result:
+ ```json
+ {
+     "id":"aeb5c87a-352e-11ee-89ad-9af9f16816c5",
+     "object":"text_completion",
+     "created":1691418511,
+     "model":"3b912fc4-352e-11ee-8e66-9af9f16816c5",
+     "choices":[
+         {
+             "text":" COUNT(*) FROM work_orders WHERE IS_OPEN = '1';",
+             "index":0,
+             "logprobs":"None",
+             "finish_reason":"stop"
+         }
+     ],
+     "usage":{
+         "prompt_tokens":117,
+         "completion_tokens":17,
+         "total_tokens":134
+     }
+ }
+ ```

  ## Pytorch Model Best Practices

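Taken together, the snippets in the section above form a complete register-launch-generate workflow. The helper below is a minimal sketch of that workflow using only the client calls shown in the diff (`register_model`, `launch_model`, `get_model`, `generate`); the wrapper function itself and the `persist=True` flag are illustrative assumptions, not part of the documented example.

```python
import json

from xinference.client import Client


def register_and_generate(endpoint: str, model_def: dict, prompt: str) -> dict:
    """Register a custom model definition, launch it, and run one completion.

    A sketch assuming an xinference endpoint is already running at `endpoint`.
    """
    client = Client(endpoint)
    # persist=True is an assumption; the README example uses persist=False.
    client.register_model(model_type="LLM", model=json.dumps(model_def), persist=True)
    uid = client.launch_model(model_name=model_def["model_name"])
    model = client.get_model(model_uid=uid)
    return model.generate(prompt=prompt)
```

With the objects defined above, this would be called as `register_and_generate("http://localhost:9997", custom_model, text)` and should return a completion dict shaped like the `Result` JSON shown earlier.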
{xinference-0.1.2 → xinference-0.1.3}/README.md

@@ -210,6 +210,110 @@ $ xinference list --all
  - If you want to use Apple Metal GPU for acceleration, please choose the q4_0 and q4_1 quantization methods.
  - `llama-2-chat` 70B ggmlv3 model only supports q4_0 quantization currently.

+ ## Custom models \[Experimental\]
+ (The 104 added lines are identical to the "Custom models" section shown in the PKG-INFO diff above.)

  ## Pytorch Model Best Practices

{xinference-0.1.2 → xinference-0.1.3}/setup.cfg

@@ -60,7 +60,6 @@ dev =
      flake8>=3.8.0
      black
  all =
-     chatglm-cpp
      llama-cpp-python>=0.1.77
      transformers>=4.31.0
      torch
@@ -72,7 +71,6 @@ all =
      einops
      tiktoken
  ggml =
-     chatglm-cpp
      llama-cpp-python>=0.1.77
  pytorch =
      transformers>=4.31.0
{xinference-0.1.2 → xinference-0.1.3}/xinference/_version.py

@@ -8,11 +8,11 @@ import json

  version_json = '''
  {
-  "date": "2023-08-04T18:35:56+0800",
+  "date": "2023-08-09T18:43:41+0800",
   "dirty": false,
   "error": null,
-  "full-revisionid": "98765f249b05b51514078cc97b88e92ce40e6948",
-  "version": "0.1.2"
+  "full-revisionid": "4d2f61cb6591ac94624f035b37259a89002abefd",
+  "version": "0.1.3"
  }
  '''  # END VERSION_JSON

{xinference-0.1.2 → xinference-0.1.3}/xinference/core/worker.py

@@ -108,6 +108,7 @@ class WorkerActor(xo.Actor):
              "model_format": llm_spec.model_format,
              "model_size_in_billions": llm_spec.model_size_in_billions,
              "quantization": quantization,
+             "revision": llm_spec.model_revision,
          }

      @log_sync(logger=logger)
{xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/cmdline.py

@@ -11,8 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
-
+ import configparser
  import logging
  import os
  import sys
@@ -31,6 +30,32 @@ from ..constants import (
  )


+ def get_config_string(log_level: str) -> str:
+     return f"""
+ [loggers]
+ keys=root
+
+ [handlers]
+ keys=stream_handler
+
+ [formatters]
+ keys=formatter
+
+ [logger_root]
+ level={log_level.upper()}
+ handlers=stream_handler
+
+ [handler_stream_handler]
+ class=StreamHandler
+ formatter=formatter
+ level={log_level.upper()}
+ args=(sys.stderr,)
+
+ [formatter_formatter]
+ format=%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s
+ """
+
+
  def get_endpoint(endpoint: Optional[str]) -> str:
      # user didn't specify the endpoint.
      if endpoint is None:
@@ -58,9 +83,10 @@ def cli(
      if ctx.invoked_subcommand is None:
          from .local import main

-         if log_level:
-             logging.basicConfig(level=logging.getLevelName(log_level.upper()))
-             logging_conf = dict(level=log_level.upper())
+         logging_conf = configparser.RawConfigParser()
+         logger_config_string = get_config_string(log_level)
+         logging_conf.read_string(logger_config_string)
+         logging.config.fileConfig(logging_conf)  # type: ignore

          address = f"{host}:{get_next_port()}"

@@ -103,9 +129,10 @@ def supervisor(
  def worker(log_level: str, endpoint: Optional[str], host: str):
      from ..deploy.worker import main

-     if log_level:
-         logging.basicConfig(level=logging.getLevelName(log_level.upper()))
-         logging_conf = dict(level=log_level.upper())
+     logging_conf = configparser.RawConfigParser()
+     logger_config_string = get_config_string(log_level)
+     logging_conf.read_string(logger_config_string)
+     logging.config.fileConfig(level=logging.getLevelName(log_level.upper()))  # type: ignore

      endpoint = get_endpoint(endpoint)

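Both the default `xinference` command and `xinference worker` now build their logging setup from a declarative config string instead of `logging.basicConfig`. The standalone sketch below shows how such a string is consumed: `RawConfigParser` (which does not interpolate the `%` placeholders in the format line) parses it, and `logging.config.fileConfig` accepts the parser object directly. The config text mirrors `get_config_string` from the diff; the wrapper function and the `disable_existing_loggers` flag are illustrative choices, not xinference's exact code.

```python
import configparser
import logging
import logging.config


def configure_logging(log_level: str = "INFO") -> None:
    """Apply a file-style logging config built from a string.

    A standalone sketch of the pattern introduced in cmdline.py; the
    config text mirrors get_config_string() in the diff above.
    """
    config_string = f"""
[loggers]
keys=root

[handlers]
keys=stream_handler

[formatters]
keys=formatter

[logger_root]
level={log_level.upper()}
handlers=stream_handler

[handler_stream_handler]
class=StreamHandler
formatter=formatter
level={log_level.upper()}
args=(sys.stderr,)

[formatter_formatter]
format=%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s
"""
    parser = configparser.RawConfigParser()
    parser.read_string(config_string)
    # fileConfig accepts a RawConfigParser instance as well as a file path.
    logging.config.fileConfig(parser, disable_existing_loggers=False)


configure_logging("debug")
logging.getLogger(__name__).debug("logging configured")
```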
{xinference-0.1.2 → xinference-0.1.3}/xinference/deploy/worker.py

@@ -14,7 +14,7 @@

  import asyncio
  import logging
- from typing import Dict, Optional
+ from typing import Any, Dict, Optional

  import xoscar as xo

@@ -53,7 +53,7 @@ async def _start_worker(
      await pool.join()


- def main(address: str, supervisor_address: str, logging_conf: Optional[Dict] = None):
+ def main(address: str, supervisor_address: str, logging_conf: Any = None):
      loop = asyncio.get_event_loop()
      task = loop.create_task(_start_worker(address, supervisor_address, logging_conf))

{xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/ggml/llamacpp.py

@@ -139,6 +139,7 @@ class LlamaCppModel(LLM):
              llamacpp_model_config["n_gqa"] = 8

          if self._is_darwin_and_apple_silicon() and self._can_apply_metal():
+             # TODO: platform.processor() is not safe, need to be replaced to other method.
              llamacpp_model_config.setdefault("n_gpu_layers", 1)
          elif self._is_linux() and self._can_apply_cublas():
              llamacpp_model_config.setdefault("n_gpu_layers", self._gpu_layers)
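The only change here is a TODO noting that `platform.processor()` (used by `_is_darwin_and_apple_silicon`) is not a reliable way to detect Apple Silicon. One commonly suggested alternative, sketched below as an assumption rather than the project's actual fix, is to check `platform.machine()` instead; the `llamacpp_model_config` usage mirrors the `n_gpu_layers` default set in the diff.

```python
import platform
import sys


def is_apple_silicon() -> bool:
    """One possible replacement for the platform.processor() check the TODO
    refers to: platform.machine() reports "arm64" on Apple Silicon, whereas an
    x86_64 interpreter running under Rosetta reports "x86_64"."""
    return sys.platform == "darwin" and platform.machine() == "arm64"


# Hypothetical usage mirroring the diff: offload one layer to Metal on
# Apple Silicon, otherwise leave n_gpu_layers for the CUDA branch.
llamacpp_model_config: dict = {}
if is_apple_silicon():
    llamacpp_model_config.setdefault("n_gpu_layers", 1)
```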
{xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/llm_family.json

@@ -41,7 +41,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "baichuan-inc/Baichuan-7B"
+     "model_id": "baichuan-inc/Baichuan-7B",
+     "model_revision": "c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756"
    },
    {
      "model_format": "pytorch",
@@ -51,7 +52,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "baichuan-inc/Baichuan-13B-Base"
+     "model_id": "baichuan-inc/Baichuan-13B-Base",
+     "model_revision": "0ef0739c7bdd34df954003ef76d80f3dabca2ff9"
    }
  ],
  "prompt_style": null
@@ -98,7 +100,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "baichuan-inc/Baichuan-13B-Chat"
+     "model_id": "baichuan-inc/Baichuan-13B-Chat",
+     "model_revision": "19ef51ba5bad8935b03acd20ff04a269210983bc"
    }
  ],
  "prompt_style": {
@@ -267,7 +270,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "lmsys/vicuna-33b-v1.3"
+     "model_id": "lmsys/vicuna-33b-v1.3",
+     "model_revision": "ef8d6becf883fb3ce52e3706885f761819477ab4"
    },
    {
      "model_format": "pytorch",
@@ -277,7 +281,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "lmsys/vicuna-13b-v1.3"
+     "model_id": "lmsys/vicuna-13b-v1.3",
+     "model_revision": "6566e9cb1787585d1147dcf4f9bc48f29e1328d2"
    },
    {
      "model_format": "pytorch",
@@ -287,7 +292,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "lmsys/vicuna-7b-v1.3"
+     "model_id": "lmsys/vicuna-7b-v1.3",
+     "model_revision": "236eeeab96f0dc2e463f2bebb7bb49809279c6d6"
    }
  ],
  "prompt_style": {
@@ -395,7 +401,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "THUDM/chatglm-6b"
+     "model_id": "THUDM/chatglm-6b",
+     "model_revision": "b1502f4f75c71499a3d566b14463edd62620ce9f"
    }
  ],
  "prompt_style": {
@@ -441,7 +448,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "THUDM/chatglm2-6b"
+     "model_id": "THUDM/chatglm2-6b",
+     "model_revision": "b1502f4f75c71499a3d566b14463edd62620ce9f"
    }
  ],
  "prompt_style": {
@@ -474,7 +482,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "THUDM/chatglm2-6b-32k"
+     "model_id": "THUDM/chatglm2-6b-32k",
+     "model_revision": "455746d4706479a1cbbd07179db39eb2741dc692"
    }
  ],
  "prompt_style": {
@@ -643,7 +652,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "facebook/opt-125m"
+     "model_id": "facebook/opt-125m",
+     "model_revision": "3d2b5f275bdf882b8775f902e1bfdb790e2cfc32"
    }
  ],
  "prompt_style": null
@@ -667,7 +677,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "tiiuae/falcon-40b"
+     "model_id": "tiiuae/falcon-40b",
+     "model_revision": "561820f7eef0cc56a31ea38af15ca1acb07fab5d"
    },
    {
      "model_format": "pytorch",
@@ -677,7 +688,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "tiiuae/falcon-7b"
+     "model_id": "tiiuae/falcon-7b",
+     "model_revision": "378337427557d1df3e742264a2901a49f25d4eb1"
    }
  ],
  "prompt_style": null
@@ -701,7 +713,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "tiiuae/falcon-7b-instruct"
+     "model_id": "tiiuae/falcon-7b-instruct",
+     "model_revision": "eb410fb6ffa9028e97adb801f0d6ec46d02f8b07"
    },
    {
      "model_format": "pytorch",
@@ -711,7 +724,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "tiiuae/falcon-40b-instruct"
+     "model_id": "tiiuae/falcon-40b-instruct",
+     "model_revision": "ca78eac0ed45bf64445ff0687fabba1598daebf3"
    }
  ],
  "prompt_style": {
@@ -759,7 +773,8 @@
        "8-bit",
        "none"
      ],
-     "model_id": "Qwen/Qwen-7B-Chat"
+     "model_id": "Qwen/Qwen-7B-Chat",
+     "model_revision": "5c611a5cde5769440581f91e8b4bba050f62b1af"
    }
  ],
  "prompt_style": {
{xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/llm_family.py

@@ -34,6 +34,7 @@ class GgmlLLMSpecV1(BaseModel):
      model_id: str
      model_file_name_template: str
      model_uri: Optional[str]
+     model_revision: Optional[str]


  class PytorchLLMSpecV1(BaseModel):
@@ -42,6 +43,7 @@ class PytorchLLMSpecV1(BaseModel):
      quantizations: List[str]
      model_id: str
      model_uri: Optional[str]
+     model_revision: Optional[str]


  class PromptStyleV1(BaseModel):
@@ -139,6 +141,7 @@ def cache_from_huggingface(
          assert isinstance(llm_spec, PytorchLLMSpecV1)
          huggingface_hub.snapshot_download(
              llm_spec.model_id,
+             revision=llm_spec.model_revision,
              local_dir=cache_dir,
              local_dir_use_symlinks=True,
          )
@@ -147,6 +150,7 @@
          file_name = llm_spec.model_file_name_template.format(quantization=quantization)
          huggingface_hub.hf_hub_download(
              llm_spec.model_id,
+             revision=llm_spec.model_revision,
              filename=file_name,
              local_dir=cache_dir,
              local_dir_use_symlinks=True,
@@ -160,13 +164,11 @@ def _is_linux():


  def _has_cuda_device():
-     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
-     if cuda_visible_devices:
-         return True
-     else:
-         from xorbits._mars.resource import cuda_count
+     # `cuda_count` method already contains the logic for the
+     # number of GPUs specified by `CUDA_VISIBLE_DEVICES`.
+     from xorbits._mars.resource import cuda_count

-         return cuda_count() > 0
+     return cuda_count() > 0


  def get_user_defined_llm_families():
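In this file the new `model_revision` field is passed straight through to `huggingface_hub`, so cached weights are pinned to a specific commit instead of whatever the default branch currently points at. A minimal standalone sketch of the same pinning, using the revision recorded for `baichuan-inc/Baichuan-7B` in `llm_family.json` above; the cache path is a placeholder, not xinference's real cache layout.

```python
from huggingface_hub import snapshot_download

# Pin the download to the exact commit recorded in llm_family.json so a
# cached model cannot silently change when the upstream repo moves its branch.
cache_dir = "/tmp/xinference/baichuan-7b-pytorch"  # illustrative path
snapshot_download(
    "baichuan-inc/Baichuan-7B",
    revision="c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756",
    local_dir=cache_dir,
    local_dir_use_symlinks=True,
)
# With revision=None (the default when model_revision is absent), the latest
# commit on the default branch is fetched, matching the old behavior.
```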
{xinference-0.1.2 → xinference-0.1.3}/xinference/model/llm/pytorch/core.py

@@ -47,7 +47,7 @@ class PytorchGenerateConfig(TypedDict, total=False):


  class PytorchModelConfig(TypedDict, total=False):
-     revision: str
+     revision: Optional[str]
      device: str
      gpus: Optional[str]
      num_gpus: int
@@ -79,17 +79,14 @@ class PytorchModel(LLM):
      ) -> PytorchModelConfig:
          if pytorch_model_config is None:
              pytorch_model_config = PytorchModelConfig()
-         pytorch_model_config.setdefault("revision", "main")
+         pytorch_model_config.setdefault("revision", self.model_spec.model_revision)
          pytorch_model_config.setdefault("gpus", None)
          pytorch_model_config.setdefault("num_gpus", 1)
          pytorch_model_config.setdefault("gptq_ckpt", None)
          pytorch_model_config.setdefault("gptq_wbits", 16)
          pytorch_model_config.setdefault("gptq_groupsize", -1)
          pytorch_model_config.setdefault("gptq_act_order", False)
-         if self._is_darwin_and_apple_silicon():
-             pytorch_model_config.setdefault("device", "mps")
-         else:
-             pytorch_model_config.setdefault("device", "cuda")
+         pytorch_model_config.setdefault("device", "auto")
          return pytorch_model_config

      def _sanitize_generate_config(
@@ -142,26 +139,35 @@ class PytorchModel(LLM):

          quantization = self.quantization
          num_gpus = self._pytorch_model_config.get("num_gpus", 1)
-         if self._is_darwin_and_apple_silicon():
-             device = self._pytorch_model_config.get("device", "mps")
-         else:
-             device = self._pytorch_model_config.get("device", "cuda")
+         device = self._pytorch_model_config.get("device", "auto")
+         self._pytorch_model_config["device"] = self._select_device(device)
+         self._device = self._pytorch_model_config["device"]

-         if device == "cpu":
+         if self._device == "cpu":
              kwargs = {"torch_dtype": torch.float32}
-         elif device == "cuda":
+         elif self._device == "cuda":
              kwargs = {"torch_dtype": torch.float16}
-         elif device == "mps":
+         elif self._device == "mps":
              kwargs = {"torch_dtype": torch.float16}
          else:
-             raise ValueError(f"Device {device} is not supported in temporary")
-         kwargs["revision"] = self._pytorch_model_config.get("revision", "main")
+             raise ValueError(f"Device {self._device} is not supported in temporary")
+
+         kwargs["revision"] = self._pytorch_model_config.get(
+             "revision", self.model_spec.model_revision
+         )

          if quantization != "none":
-             if device == "cuda" and self._is_linux():
+             if self._device == "cuda" and self._is_linux():
                  kwargs["device_map"] = "auto"
                  if quantization == "4-bit":
                      kwargs["load_in_4bit"] = True
+                     kwargs["bnb_4bit_compute_dtype"] = torch.float16
+                     kwargs["bnb_4bit_use_double_quant"] = True
+                     kwargs["llm_int8_skip_modules"] = [
+                         "lm_head",
+                         "encoder",
+                         "EncDecAttention",
+                     ]
                  elif quantization == "8-bit":
                      kwargs["load_in_8bit"] = True
                  else:
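The extra 4-bit keyword arguments above are forwarded to the transformers loading path together with `load_in_4bit`. The sketch below shows a roughly equivalent explicit form using `BitsAndBytesConfig` (available in transformers>=4.31, which setup.cfg already requires); it assumes a Linux machine with CUDA and bitsandbytes installed, and the model path is a placeholder rather than xinference's real cache location.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# An explicit, roughly equivalent form of the 4-bit options added above.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    llm_int8_skip_modules=["lm_head", "encoder", "EncDecAttention"],
)
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/cached/model",  # placeholder for the locally cached weights
    device_map="auto",
    quantization_config=quant_config,
)
```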
@@ -178,7 +184,7 @@ class PytorchModel(LLM):
          else:
              self._model, self._tokenizer = load_compress_model(
                  model_path=self.model_path,
-                 device=device,
+                 device=self._device,
                  torch_dtype=kwargs["torch_dtype"],
                  use_fast=self._use_fast_tokenizer,
                  revision=kwargs["revision"],
@@ -189,11 +195,37 @@ class PytorchModel(LLM):
          self._model, self._tokenizer = self._load_model(kwargs)

          if (
-             device == "cuda" and num_gpus == 1 and quantization == "none"
-         ) or device == "mps":
-             self._model.to(device)
+             self._device == "cuda" and num_gpus == 1 and quantization == "none"
+         ) or self._device == "mps":
+             self._model.to(self._device)
          logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")

+     def _select_device(self, device: str) -> str:
+         try:
+             import torch
+         except ImportError:
+             raise ImportError(
+                 f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+             )
+
+         if device == "auto":
+             if torch.cuda.is_available():
+                 return "cuda"
+             elif torch.backends.mps.is_available():
+                 return "mps"
+             return "cpu"
+         elif device == "cuda":
+             if not torch.cuda.is_available():
+                 raise ValueError("cuda is unavailable in your environment")
+         elif device == "mps":
+             if not torch.backends.mps.is_available():
+                 raise ValueError("mps is unavailable in your environment")
+         elif device == "cpu":
+             pass
+         else:
+             raise ValueError(f"Device {device} is not supported in temporary")
+         return device
+
      @classmethod
      def match(cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1") -> bool:
          if llm_spec.model_format != "pytorch":
@@ -222,21 +254,21 @@ class PytorchModel(LLM):
          )

          def generator_wrapper(
-             prompt: str, device: str, generate_config: PytorchGenerateConfig
+             prompt: str, generate_config: PytorchGenerateConfig
          ) -> Iterator[CompletionChunk]:
              if "falcon" in self.model_family.model_name:
                  for completion_chunk, _ in generate_stream_falcon(
-                     self._model, self._tokenizer, prompt, device, generate_config
+                     self._model, self._tokenizer, prompt, self._device, generate_config
                  ):
                      yield completion_chunk
              elif "chatglm" in self.model_family.model_name:
                  for completion_chunk, _ in generate_stream_chatglm(
-                     self._model, self._tokenizer, prompt, device, generate_config
+                     self._model, self._tokenizer, prompt, self._device, generate_config
                  ):
                      yield completion_chunk
              else:
                  for completion_chunk, _ in generate_stream(
-                     self._model, self._tokenizer, prompt, device, generate_config
+                     self._model, self._tokenizer, prompt, self._device, generate_config
                  ):
                      yield completion_chunk

@@ -250,24 +282,20 @@ class PytorchModel(LLM):
          assert self._tokenizer is not None

          stream = generate_config.get("stream", False)
-         if self._is_darwin_and_apple_silicon():
-             device = self._pytorch_model_config.get("device", "mps")
-         else:
-             device = self._pytorch_model_config.get("device", "cuda")
          if not stream:
              if "falcon" in self.model_family.model_name:
                  for completion_chunk, completion_usage in generate_stream_falcon(
-                     self._model, self._tokenizer, prompt, device, generate_config
+                     self._model, self._tokenizer, prompt, self._device, generate_config
                  ):
                      pass
              elif "chatglm" in self.model_family.model_name:
                  for completion_chunk, completion_usage in generate_stream_chatglm(
-                     self._model, self._tokenizer, prompt, device, generate_config
+                     self._model, self._tokenizer, prompt, self._device, generate_config
                  ):
                      pass
              else:
                  for completion_chunk, completion_usage in generate_stream(
-                     self._model, self._tokenizer, prompt, device, generate_config
+                     self._model, self._tokenizer, prompt, self._device, generate_config
                  ):
                      pass
              completion = Completion(
@@ -280,7 +308,7 @@ class PytorchModel(LLM):
              )
              return completion
          else:
-             return generator_wrapper(prompt, device, generate_config)
+             return generator_wrapper(prompt, generate_config)

      def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
          try:
@@ -291,11 +319,6 @@ class PytorchModel(LLM):
                  "Could not import torch. Please install it with `pip install torch`."
              ) from e

-         if self._is_darwin_and_apple_silicon():
-             device = self._pytorch_model_config.get("device", "mps")
-         else:
-             device = self._pytorch_model_config.get("device", "cuda")
-
          if isinstance(input, str):
              inputs = [input]
          else:
@@ -308,8 +331,8 @@ class PytorchModel(LLM):
          encoding = tokenizer.batch_encode_plus(
              inputs, padding=True, return_tensors="pt"
          )
-         input_ids = encoding["input_ids"].to(device)
-         attention_mask = encoding["attention_mask"].to(device)
+         input_ids = encoding["input_ids"].to(self._device)
+         attention_mask = encoding["attention_mask"].to(self._device)
          model_output = self._model(
              input_ids, attention_mask, output_hidden_states=True
          )
@@ -342,7 +365,7 @@ class PytorchModel(LLM):
          embedding = []
          token_num = 0
          for index, text in enumerate(inputs):
-             input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
+             input_ids = tokenizer.encode(text, return_tensors="pt").to(self._device)
              model_output = self._model(input_ids, output_hidden_states=True)
              if is_chatglm:
                  data = (model_output.hidden_states[-1].transpose(0, 1))[0]
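The new `_select_device` helper centralizes device resolution: `"auto"` prefers CUDA, then Apple MPS, then falls back to CPU, while explicitly requested devices are validated against availability. Below is a standalone sketch of the same policy, handy for checking what a given machine would resolve to; the function name and the prints are illustrative, not part of xinference's API.

```python
import torch


def select_device(device: str = "auto") -> str:
    """Standalone version of the resolution order used by _select_device:
    prefer CUDA, then Apple MPS, then fall back to CPU."""
    if device == "auto":
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"
    available = {
        "cuda": torch.cuda.is_available(),
        "mps": torch.backends.mps.is_available(),
        "cpu": True,
    }
    if device not in available:
        raise ValueError(f"Device {device} is not supported")
    if not available[device]:
        raise ValueError(f"{device} is unavailable in your environment")
    return device


print(select_device())       # e.g. "cuda" on a GPU box, "mps" on an M1/M2 Mac
print(select_device("cpu"))  # explicit devices are validated, not overridden
```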
{xinference-0.1.2 → xinference-0.1.3/xinference.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: xinference
- Version: 0.1.2
+ Version: 0.1.3
  Summary: Model Serving Made Easy
  Home-page: https://github.com/xorbitsai/inference
  Author: Qin Xuye
@@ -238,6 +238,110 @@ $ xinference list --all
  - If you want to use Apple Metal GPU for acceleration, please choose the q4_0 and q4_1 quantization methods.
  - `llama-2-chat` 70B ggmlv3 model only supports q4_0 quantization currently.

+ ## Custom models \[Experimental\]
+ (The 104 added lines are identical to the "Custom models" section shown in the first PKG-INFO diff above.)

  ## Pytorch Model Best Practices

{xinference-0.1.2 → xinference-0.1.3}/xinference.egg-info/requires.txt

@@ -13,7 +13,6 @@ huggingface-hub<1.0,>=0.14.1
  typing_extensions

  [all]
- chatglm-cpp
  llama-cpp-python>=0.1.77
  transformers>=4.31.0
  torch
@@ -51,7 +50,6 @@ pydata-sphinx-theme>=0.3.0
  sphinx-intl>=0.9.9

  [ggml]
- chatglm-cpp
  llama-cpp-python>=0.1.77

  [pytorch]