xinference 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

@@ -14,12 +14,15 @@
 
 import logging
 import os
-from typing import List, Optional, Union
+import platform
+from threading import Lock
+from typing import List, Optional, Tuple, Type, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Literal
 
-from xinference.constants import XINFERENCE_CACHE_DIR
+from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
+from . import LLM
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +33,8 @@ class GgmlLLMSpecV1(BaseModel):
     quantizations: List[str]
     model_id: str
     model_file_name_template: str
-    model_local_path: Optional[str]
+    model_uri: Optional[str]
+    model_revision: Optional[str]
 
 
 class PytorchLLMSpecV1(BaseModel):
@@ -38,7 +42,8 @@ class PytorchLLMSpecV1(BaseModel):
     model_size_in_billions: int
     quantizations: List[str]
     model_id: str
-    model_local_path: Optional[str]
+    model_uri: Optional[str]
+    model_revision: Optional[str]
 
 
 class PromptStyleV1(BaseModel):
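
Both spec classes gain the same pair of optional fields. A minimal sketch (hypothetical class name, not the real xinference model; the model ID is only an example) showing that, under pydantic v1, existing spec payloads without the new keys still validate:

```python
from typing import List, Optional

from pydantic import BaseModel


class GgmlSpecSketch(BaseModel):
    # mirrors the fields above; model_uri and model_revision are
    # new in 0.1.3 and default to None when omitted
    quantizations: List[str]
    model_id: str
    model_file_name_template: str
    model_uri: Optional[str]
    model_revision: Optional[str]


spec = GgmlSpecSketch(
    quantizations=["q4_0"],
    model_id="TheBloke/Llama-2-7B-GGML",
    model_file_name_template="llama-2-7b.ggmlv3.{quantization}.bin",
)
print(spec.model_uri, spec.model_revision)  # None None
```
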
@@ -68,7 +73,14 @@ LLMSpecV1 = Annotated[
 
 LLMFamilyV1.update_forward_refs()
 
-LLM_FAMILIES: List[LLMFamilyV1] = []
+
+LLM_CLASSES: List[Type[LLM]] = []
+
+BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
+
+UD_LLM_FAMILIES: List["LLMFamilyV1"] = []
+
+UD_LLM_FAMILIES_LOCK = Lock()
 
 
 def get_legacy_cache_path(
@@ -96,7 +108,18 @@ def cache(
         logger.debug("Legacy cache path exists: %s", legacy_cache_path)
         return os.path.dirname(legacy_cache_path)
     else:
-        return cache_from_huggingface(llm_family, llm_spec, quantization)
+        if llm_spec.model_uri is not None:
+            return cache_from_uri(llm_family, llm_spec, quantization)
+        else:
+            return cache_from_huggingface(llm_family, llm_spec, quantization)
+
+
+def cache_from_uri(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    raise NotImplementedError
 
 
 def cache_from_huggingface(
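
Note that `cache_from_uri` is only a stub in this release: the `model_uri` branch of `cache` immediately raises `NotImplementedError`, so URI-based specs cannot actually be cached yet.
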
@@ -110,7 +133,7 @@ def cache_from_huggingface(
     import huggingface_hub
 
     cache_dir_name = f"{llm_family.model_name}-{llm_spec.model_format}-{llm_spec.model_size_in_billions}b"
-    cache_dir = os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name)
+    cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name))
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir, exist_ok=True)
 
@@ -118,6 +141,7 @@ def cache_from_huggingface(
         assert isinstance(llm_spec, PytorchLLMSpecV1)
         huggingface_hub.snapshot_download(
             llm_spec.model_id,
+            revision=llm_spec.model_revision,
             local_dir=cache_dir,
             local_dir_use_symlinks=True,
         )
@@ -126,9 +150,130 @@ def cache_from_huggingface(
         file_name = llm_spec.model_file_name_template.format(quantization=quantization)
         huggingface_hub.hf_hub_download(
             llm_spec.model_id,
+            revision=llm_spec.model_revision,
             filename=file_name,
             local_dir=cache_dir,
             local_dir_use_symlinks=True,
         )
 
     return cache_dir
+
+
+def _is_linux():
+    return platform.system() == "Linux"
+
+
+def _has_cuda_device():
+    # `cuda_count` method already contains the logic for the
+    # number of GPUs specified by `CUDA_VISIBLE_DEVICES`.
+    from xorbits._mars.resource import cuda_count
+
+    return cuda_count() > 0
+
+
+def get_user_defined_llm_families():
+    with UD_LLM_FAMILIES_LOCK:
+        return UD_LLM_FAMILIES.copy()
+
+
+def match_llm(
+    model_name: str,
+    model_format: Optional[str] = None,
+    model_size_in_billions: Optional[int] = None,
+    quantization: Optional[str] = None,
+    is_local_deployment: bool = False,
+) -> Optional[Tuple[LLMFamilyV1, LLMSpecV1, str]]:
+    """
+    Find an LLM family, spec, and quantization that satisfy given criteria.
+    """
+    user_defined_llm_families = get_user_defined_llm_families()
+
+    for family in BUILTIN_LLM_FAMILIES + user_defined_llm_families:
+        if model_name != family.model_name:
+            continue
+        for spec in family.model_specs:
+            if (
+                model_format
+                and model_format != spec.model_format
+                or model_size_in_billions
+                and model_size_in_billions != spec.model_size_in_billions
+                or quantization
+                and quantization not in spec.quantizations
+            ):
+                continue
+            if quantization:
+                return family, spec, quantization
+            else:
+                # by default, choose the most coarse-grained quantization.
+                # TODO: too hacky.
+                quantizations = spec.quantizations
+                quantizations.sort()
+                for q in quantizations:
+                    if (
+                        is_local_deployment
+                        and not (_is_linux() and _has_cuda_device())
+                        and q == "4-bit"
+                    ):
+                        logger.warning(
+                            "Skipping %s for non-linux or non-cuda local deployment .",
+                            q,
+                        )
+                        continue
+                    return family, spec, q
+    return None
+
+
+def register_llm(llm_family: LLMFamilyV1, persist: bool):
+    from .utils import is_valid_model_name
+
+    if not is_valid_model_name(llm_family.model_name):
+        raise ValueError(
+            f"Invalid model name {llm_family.model_name}. The model name must start with a letter"
+            f" or a digit, and can only contain letters, digits, underscores, or dashes."
+        )
+
+    with UD_LLM_FAMILIES_LOCK:
+        for family in BUILTIN_LLM_FAMILIES + UD_LLM_FAMILIES:
+            if llm_family.model_name == family.model_name:
+                raise ValueError(
+                    f"Model name conflicts with existing model {family.model_name}"
+                )
+
+        UD_LLM_FAMILIES.append(llm_family)
+
+    if persist:
+        persist_path = os.path.join(
+            XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
+        )
+        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
+        with open(persist_path, mode="w") as fd:
+            fd.write(llm_family.json())
+
+
+def unregister_llm(model_name: str):
+    with UD_LLM_FAMILIES_LOCK:
+        llm_family = None
+        for i, f in enumerate(UD_LLM_FAMILIES):
+            if f.model_name == model_name:
+                llm_family = f
+                break
+        if llm_family:
+            UD_LLM_FAMILIES.remove(llm_family)
+
+            persist_path = os.path.join(
+                XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
+            )
+            if os.path.exists(persist_path):
+                os.remove(persist_path)
+        else:
+            raise ValueError(f"Model {model_name} not found")
+
+
+def match_llm_cls(family: LLMFamilyV1, llm_spec: "LLMSpecV1") -> Optional[Type[LLM]]:
+    """
+    Find an LLM implementation for given LLM family and spec.
+    """
+    for cls in LLM_CLASSES:
+        if cls.match(family, llm_spec):
+            return cls
+    return None
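
A hedged usage sketch of the new registry helpers (assumes xinference 0.1.3 is importable from `xinference.model.llm.llm_family`; the model names and payload are illustrative only, mirroring the custom-model template in the README section below):

```python
from xinference.model.llm.llm_family import (
    LLMFamilyV1,
    match_llm,
    register_llm,
    unregister_llm,
)

# resolve a built-in family (returns None when nothing matches)
matched = match_llm("llama-2-chat", model_format="ggmlv3")
if matched is not None:
    family, spec, quantization = matched
    print(family.model_name, spec.model_format, quantization)

# register, then remove, a user-defined family
family = LLMFamilyV1.parse_obj({
    "version": 1,
    "model_name": "my-model",
    "model_lang": ["en"],
    "model_ability": ["generate"],
    "model_specs": [{
        "model_format": "pytorch",
        "model_size_in_billions": 2,
        "quantizations": ["none"],
        "model_id": "NumbersStation/nsql-2B",
    }],
    "prompt_style": None,
})
register_llm(family, persist=False)
unregister_llm("my-model")
```
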
@@ -47,7 +47,7 @@ class PytorchGenerateConfig(TypedDict, total=False):
 
 
 class PytorchModelConfig(TypedDict, total=False):
-    revision: str
+    revision: Optional[str]
     device: str
     gpus: Optional[str]
     num_gpus: int
@@ -79,17 +79,14 @@ class PytorchModel(LLM):
     ) -> PytorchModelConfig:
         if pytorch_model_config is None:
             pytorch_model_config = PytorchModelConfig()
-        pytorch_model_config.setdefault("revision", "main")
+        pytorch_model_config.setdefault("revision", self.model_spec.model_revision)
         pytorch_model_config.setdefault("gpus", None)
         pytorch_model_config.setdefault("num_gpus", 1)
        pytorch_model_config.setdefault("gptq_ckpt", None)
         pytorch_model_config.setdefault("gptq_wbits", 16)
         pytorch_model_config.setdefault("gptq_groupsize", -1)
         pytorch_model_config.setdefault("gptq_act_order", False)
-        if self._is_darwin_and_apple_silicon():
-            pytorch_model_config.setdefault("device", "mps")
-        else:
-            pytorch_model_config.setdefault("device", "cuda")
+        pytorch_model_config.setdefault("device", "auto")
         return pytorch_model_config
 
     def _sanitize_generate_config(
@@ -142,26 +139,35 @@ class PytorchModel(LLM):
 
         quantization = self.quantization
         num_gpus = self._pytorch_model_config.get("num_gpus", 1)
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
+        device = self._pytorch_model_config.get("device", "auto")
+        self._pytorch_model_config["device"] = self._select_device(device)
+        self._device = self._pytorch_model_config["device"]
 
-        if device == "cpu":
+        if self._device == "cpu":
             kwargs = {"torch_dtype": torch.float32}
-        elif device == "cuda":
+        elif self._device == "cuda":
             kwargs = {"torch_dtype": torch.float16}
-        elif device == "mps":
+        elif self._device == "mps":
             kwargs = {"torch_dtype": torch.float16}
         else:
-            raise ValueError(f"Device {device} is not supported in temporary")
-        kwargs["revision"] = self._pytorch_model_config.get("revision", "main")
+            raise ValueError(f"Device {self._device} is not supported in temporary")
+
+        kwargs["revision"] = self._pytorch_model_config.get(
+            "revision", self.model_spec.model_revision
+        )
 
         if quantization != "none":
-            if device == "cuda" and self._is_linux():
+            if self._device == "cuda" and self._is_linux():
                 kwargs["device_map"] = "auto"
                 if quantization == "4-bit":
                     kwargs["load_in_4bit"] = True
+                    kwargs["bnb_4bit_compute_dtype"] = torch.float16
+                    kwargs["bnb_4bit_use_double_quant"] = True
+                    kwargs["llm_int8_skip_modules"] = [
+                        "lm_head",
+                        "encoder",
+                        "EncDecAttention",
+                    ]
                 elif quantization == "8-bit":
                     kwargs["load_in_8bit"] = True
                 else:
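
The extra 4-bit kwargs map onto transformers' bitsandbytes integration. A rough standalone equivalent using an explicit `BitsAndBytesConfig` (a sketch, assuming a CUDA Linux machine with `bitsandbytes` installed; the model ID is just an example):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# same settings the diff passes as loose kwargs, expressed as a config object
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["lm_head", "encoder", "EncDecAttention"],
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", device_map="auto", quantization_config=bnb_config
)
```
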
@@ -178,7 +184,7 @@ class PytorchModel(LLM):
         else:
             self._model, self._tokenizer = load_compress_model(
                 model_path=self.model_path,
-                device=device,
+                device=self._device,
                 torch_dtype=kwargs["torch_dtype"],
                 use_fast=self._use_fast_tokenizer,
                 revision=kwargs["revision"],
@@ -189,11 +195,37 @@ class PytorchModel(LLM):
         self._model, self._tokenizer = self._load_model(kwargs)
 
         if (
-            device == "cuda" and num_gpus == 1 and quantization == "none"
-        ) or device == "mps":
-            self._model.to(device)
+            self._device == "cuda" and num_gpus == 1 and quantization == "none"
+        ) or self._device == "mps":
+            self._model.to(self._device)
         logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")
 
+    def _select_device(self, device: str) -> str:
+        try:
+            import torch
+        except ImportError:
+            raise ImportError(
+                f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+            )
+
+        if device == "auto":
+            if torch.cuda.is_available():
+                return "cuda"
+            elif torch.backends.mps.is_available():
+                return "mps"
+            return "cpu"
+        elif device == "cuda":
+            if not torch.cuda.is_available():
+                raise ValueError("cuda is unavailable in your environment")
+        elif device == "mps":
+            if not torch.backends.mps.is_available():
+                raise ValueError("mps is unavailable in your environment")
+        elif device == "cpu":
+            pass
+        else:
+            raise ValueError(f"Device {device} is not supported in temporary")
+        return device
+
     @classmethod
     def match(cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1") -> bool:
         if llm_spec.model_format != "pytorch":
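
With the new `"auto"` default, device resolution happens once in `load` and is cached on `self._device`. What `"auto"` resolves to on a given machine (illustrative standalone snippet mirroring the logic above):

```python
import torch

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)  # e.g. "cpu" on a machine without CUDA or Apple Silicon
```
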
@@ -222,21 +254,21 @@ class PytorchModel(LLM):
         )
 
         def generator_wrapper(
-            prompt: str, device: str, generate_config: PytorchGenerateConfig
+            prompt: str, generate_config: PytorchGenerateConfig
         ) -> Iterator[CompletionChunk]:
             if "falcon" in self.model_family.model_name:
                 for completion_chunk, _ in generate_stream_falcon(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     yield completion_chunk
             elif "chatglm" in self.model_family.model_name:
                 for completion_chunk, _ in generate_stream_chatglm(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     yield completion_chunk
             else:
                 for completion_chunk, _ in generate_stream(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     yield completion_chunk
 
@@ -250,24 +282,20 @@ class PytorchModel(LLM):
         assert self._tokenizer is not None
 
         stream = generate_config.get("stream", False)
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
         if not stream:
             if "falcon" in self.model_family.model_name:
                 for completion_chunk, completion_usage in generate_stream_falcon(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     pass
             elif "chatglm" in self.model_family.model_name:
                 for completion_chunk, completion_usage in generate_stream_chatglm(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     pass
             else:
                 for completion_chunk, completion_usage in generate_stream(
-                    self._model, self._tokenizer, prompt, device, generate_config
+                    self._model, self._tokenizer, prompt, self._device, generate_config
                 ):
                     pass
             completion = Completion(
@@ -280,7 +308,7 @@ class PytorchModel(LLM):
             )
             return completion
         else:
-            return generator_wrapper(prompt, device, generate_config)
+            return generator_wrapper(prompt, generate_config)
 
     def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
         try:
@@ -291,11 +319,6 @@ class PytorchModel(LLM):
                 "Could not import torch. Please install it with `pip install torch`."
             ) from e
 
-        if self._is_darwin_and_apple_silicon():
-            device = self._pytorch_model_config.get("device", "mps")
-        else:
-            device = self._pytorch_model_config.get("device", "cuda")
-
         if isinstance(input, str):
             inputs = [input]
         else:
@@ -308,8 +331,8 @@ class PytorchModel(LLM):
         encoding = tokenizer.batch_encode_plus(
             inputs, padding=True, return_tensors="pt"
         )
-        input_ids = encoding["input_ids"].to(device)
-        attention_mask = encoding["attention_mask"].to(device)
+        input_ids = encoding["input_ids"].to(self._device)
+        attention_mask = encoding["attention_mask"].to(self._device)
         model_output = self._model(
             input_ids, attention_mask, output_hidden_states=True
         )
@@ -342,7 +365,7 @@ class PytorchModel(LLM):
         embedding = []
         token_num = 0
         for index, text in enumerate(inputs):
-            input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
+            input_ids = tokenizer.encode(text, return_tensors="pt").to(self._device)
             model_output = self._model(input_ids, output_hidden_states=True)
             if is_chatglm:
                 data = (model_output.hidden_states[-1].transpose(0, 1))[0]
@@ -104,7 +104,11 @@ def generate_stream(
         temperature, repetition_penalty, top_p, top_k
     )
 
-    input_ids = tokenizer(prompt).input_ids
+    if "qwen" in str(type(model)).lower():
+        # TODO: hacky
+        input_ids = tokenizer(prompt, allowed_special="all").input_ids
+    else:
+        input_ids = tokenizer(prompt).input_ids
     output_ids = list(input_ids)
 
     if model.config.is_encoder_decoder:
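
The Qwen special case exists because Qwen's tokenizer is tiktoken-based, and tiktoken rejects text containing special-token markers unless `allowed_special` is passed. The same behaviour can be reproduced with plain tiktoken (newly added to the dependencies; the encoding name here is illustrative):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "<|endoftext|> SELECT 1;"

try:
    enc.encode(text)  # special tokens are disallowed by default
except ValueError as err:
    print("rejected:", err)

print(enc.encode(text, allowed_special="all"))  # now accepted
```
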
@@ -192,3 +192,9 @@ class ChatModelMixin:
             ],
             "usage": completion["usage"],
         }
+
+
+def is_valid_model_name(model_name: str) -> bool:
+    import re
+
+    return re.match(r"^[A-Za-z0-9][A-Za-z0-9_\-]*$", model_name) is not None
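
The naming rule enforced by `register_llm`, at a glance (helper copied verbatim from the diff, plus a few illustrative checks):

```python
import re


def is_valid_model_name(model_name: str) -> bool:
    return re.match(r"^[A-Za-z0-9][A-Za-z0-9_\-]*$", model_name) is not None


assert is_valid_model_name("nsql-2B")
assert is_valid_model_name("my_model-v2")
assert not is_valid_model_name("-leading-dash")  # must start with letter/digit
assert not is_valid_model_name("has space")      # no whitespace allowed
```
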
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: xinference
-Version: 0.1.1
+Version: 0.1.3
 Summary: Model Serving Made Easy
 Home-page: https://github.com/xorbitsai/inference
 Author: Qin Xuye
@@ -21,62 +21,62 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: xoscar
 Requires-Dist: xorbits
-Requires-Dist: gradio (>=3.35.0)
+Requires-Dist: gradio >=3.35.0
 Requires-Dist: click
-Requires-Dist: tqdm (>=4.27)
+Requires-Dist: tqdm >=4.27
 Requires-Dist: tabulate
 Requires-Dist: requests
-Requires-Dist: pydantic (<2)
+Requires-Dist: pydantic <2
 Requires-Dist: fastapi
 Requires-Dist: uvicorn
 Requires-Dist: sse-starlette
-Requires-Dist: huggingface-hub (<1.0,>=0.14.1)
+Requires-Dist: huggingface-hub <1.0,>=0.14.1
 Requires-Dist: typing-extensions
 Provides-Extra: all
-Requires-Dist: chatglm-cpp ; extra == 'all'
-Requires-Dist: llama-cpp-python (>=0.1.77) ; extra == 'all'
-Requires-Dist: transformers (>=4.31.0) ; extra == 'all'
+Requires-Dist: llama-cpp-python >=0.1.77 ; extra == 'all'
+Requires-Dist: transformers >=4.31.0 ; extra == 'all'
 Requires-Dist: torch ; extra == 'all'
-Requires-Dist: accelerate (>=0.20.3) ; extra == 'all'
+Requires-Dist: accelerate >=0.20.3 ; extra == 'all'
 Requires-Dist: sentencepiece ; extra == 'all'
 Requires-Dist: transformers-stream-generator ; extra == 'all'
 Requires-Dist: bitsandbytes ; extra == 'all'
 Requires-Dist: protobuf ; extra == 'all'
 Requires-Dist: einops ; extra == 'all'
+Requires-Dist: tiktoken ; extra == 'all'
 Provides-Extra: benchmark
 Requires-Dist: psutil ; extra == 'benchmark'
 Requires-Dist: pynvml ; extra == 'benchmark'
 Provides-Extra: dev
-Requires-Dist: cython (>=0.29) ; extra == 'dev'
-Requires-Dist: pytest (>=3.5.0) ; extra == 'dev'
-Requires-Dist: pytest-cov (>=2.5.0) ; extra == 'dev'
-Requires-Dist: pytest-timeout (>=1.2.0) ; extra == 'dev'
-Requires-Dist: pytest-forked (>=1.0) ; extra == 'dev'
-Requires-Dist: pytest-asyncio (>=0.14.0) ; extra == 'dev'
-Requires-Dist: ipython (>=6.5.0) ; extra == 'dev'
-Requires-Dist: sphinx (<5.0.0,>=3.0.0) ; extra == 'dev'
-Requires-Dist: pydata-sphinx-theme (>=0.3.0) ; extra == 'dev'
-Requires-Dist: sphinx-intl (>=0.9.9) ; extra == 'dev'
-Requires-Dist: jieba (>=0.42.0) ; extra == 'dev'
-Requires-Dist: flake8 (>=3.8.0) ; extra == 'dev'
+Requires-Dist: cython >=0.29 ; extra == 'dev'
+Requires-Dist: pytest >=3.5.0 ; extra == 'dev'
+Requires-Dist: pytest-cov >=2.5.0 ; extra == 'dev'
+Requires-Dist: pytest-timeout >=1.2.0 ; extra == 'dev'
+Requires-Dist: pytest-forked >=1.0 ; extra == 'dev'
+Requires-Dist: pytest-asyncio >=0.14.0 ; extra == 'dev'
+Requires-Dist: ipython >=6.5.0 ; extra == 'dev'
+Requires-Dist: sphinx <5.0.0,>=3.0.0 ; extra == 'dev'
+Requires-Dist: pydata-sphinx-theme >=0.3.0 ; extra == 'dev'
+Requires-Dist: sphinx-intl >=0.9.9 ; extra == 'dev'
+Requires-Dist: jieba >=0.42.0 ; extra == 'dev'
+Requires-Dist: flake8 >=3.8.0 ; extra == 'dev'
 Requires-Dist: black ; extra == 'dev'
 Provides-Extra: doc
-Requires-Dist: ipython (>=6.5.0) ; extra == 'doc'
-Requires-Dist: sphinx (<5.0.0,>=3.0.0) ; extra == 'doc'
-Requires-Dist: pydata-sphinx-theme (>=0.3.0) ; extra == 'doc'
-Requires-Dist: sphinx-intl (>=0.9.9) ; extra == 'doc'
+Requires-Dist: ipython >=6.5.0 ; extra == 'doc'
+Requires-Dist: sphinx <5.0.0,>=3.0.0 ; extra == 'doc'
+Requires-Dist: pydata-sphinx-theme >=0.3.0 ; extra == 'doc'
+Requires-Dist: sphinx-intl >=0.9.9 ; extra == 'doc'
 Provides-Extra: ggml
-Requires-Dist: chatglm-cpp ; extra == 'ggml'
-Requires-Dist: llama-cpp-python (>=0.1.77) ; extra == 'ggml'
+Requires-Dist: llama-cpp-python >=0.1.77 ; extra == 'ggml'
 Provides-Extra: pytorch
-Requires-Dist: transformers (>=4.31.0) ; extra == 'pytorch'
+Requires-Dist: transformers >=4.31.0 ; extra == 'pytorch'
 Requires-Dist: torch ; extra == 'pytorch'
-Requires-Dist: accelerate (>=0.20.3) ; extra == 'pytorch'
+Requires-Dist: accelerate >=0.20.3 ; extra == 'pytorch'
 Requires-Dist: sentencepiece ; extra == 'pytorch'
 Requires-Dist: transformers-stream-generator ; extra == 'pytorch'
 Requires-Dist: bitsandbytes ; extra == 'pytorch'
 Requires-Dist: protobuf ; extra == 'pytorch'
 Requires-Dist: einops ; extra == 'pytorch'
+Requires-Dist: tiktoken ; extra == 'pytorch'
 
 <div align="center">
 <img src="./assets/xorbits-logo.png" width="180px" alt="xorbits" />
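
Most of the metadata churn is cosmetic: newer wheel tooling emits PEP 508 version specifiers without parentheses, and both spellings are equivalent. The substantive changes are the dropped `chatglm-cpp` dependency and the newly added `tiktoken`. A quick equivalence check (a sketch, assuming the `packaging` library is installed):

```python
from packaging.requirements import Requirement

old = Requirement("gradio (>=3.35.0)")  # 0.1.1 spelling
new = Requirement("gradio >=3.35.0")    # 0.1.3 spelling
print(old.name == new.name and old.specifier == new.specifier)  # True
```
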
@@ -290,6 +290,110 @@ $ xinference list --all
 - If you want to use Apple Metal GPU for acceleration, please choose the q4_0 and q4_1 quantization methods.
 - `llama-2-chat` 70B ggmlv3 model only supports q4_0 quantization currently.
 
+## Custom models \[Experimental\]
+Custom models are currently an experimental feature and are expected to be officially released in version v0.2.0.
+
+Define a custom model based on the following template:
+```python
+custom_model = {
+  "version": 1,
+  # model name. must start with a letter or a
+  # digit, and can only contain letters, digits,
+  # underscores, or dashes.
+  "model_name": "nsql-2B",
+  # supported languages
+  "model_lang": [
+    "en"
+  ],
+  # model abilities. could be "embed", "generate"
+  # and "chat".
+  "model_ability": [
+    "generate"
+  ],
+  # model specifications.
+  "model_specs": [
+    {
+      # model format.
+      "model_format": "pytorch",
+      "model_size_in_billions": 2,
+      # quantizations.
+      "quantizations": [
+        "4-bit",
+        "8-bit",
+        "none"
+      ],
+      # hugging face model ID.
+      "model_id": "NumbersStation/nsql-2B"
+    }
+  ],
+  # prompt style, required by chat models.
+  # for more details, see: xinference/model/llm/tests/test_utils.py
+  "prompt_style": None
+}
+```
+
+Register the custom model:
+```python
+import json
+
+from xinference.client import Client
+
+# replace with real xinference endpoint
+endpoint = "http://localhost:9997"
+client = Client(endpoint)
+client.register_model(model_type="LLM", model=json.dumps(custom_model), persist=False)
+```
+
+Load the custom model:
+```python
+uid = client.launch_model(model_name='nsql-2B')
+```
+
+Run the custom model:
+```python
+text = """CREATE TABLE work_orders (
+    ID NUMBER,
+    CREATED_AT TEXT,
+    COST FLOAT,
+    INVOICE_AMOUNT FLOAT,
+    IS_DUE BOOLEAN,
+    IS_OPEN BOOLEAN,
+    IS_OVERDUE BOOLEAN,
+    COUNTRY_NAME TEXT,
+)
+
+-- Using valid SQLite, answer the following questions for the tables provided above.
+
+-- how many work orders are open?
+
+SELECT"""
+
+model = client.get_model(model_uid=uid)
+model.generate(prompt=text)
+```
+
+Result:
+```json
+{
+  "id":"aeb5c87a-352e-11ee-89ad-9af9f16816c5",
+  "object":"text_completion",
+  "created":1691418511,
+  "model":"3b912fc4-352e-11ee-8e66-9af9f16816c5",
+  "choices":[
+    {
+      "text":" COUNT(*) FROM work_orders WHERE IS_OPEN = '1';",
+      "index":0,
+      "logprobs":"None",
+      "finish_reason":"stop"
+    }
+  ],
+  "usage":{
+    "prompt_tokens":117,
+    "completion_tokens":17,
+    "total_tokens":134
+  }
+}
+```
 
 ## Pytorch Model Best Practices