xinference 0.7.5__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (120)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/__init__.py +13 -0
  3. xinference/api/oauth2/common.py +14 -0
  4. xinference/api/oauth2/core.py +93 -0
  5. xinference/api/oauth2/types.py +36 -0
  6. xinference/api/oauth2/utils.py +44 -0
  7. xinference/api/restful_api.py +216 -27
  8. xinference/client/oscar/actor_client.py +18 -18
  9. xinference/client/restful/restful_client.py +96 -33
  10. xinference/conftest.py +63 -1
  11. xinference/constants.py +1 -0
  12. xinference/core/chat_interface.py +143 -3
  13. xinference/core/metrics.py +83 -0
  14. xinference/core/model.py +244 -181
  15. xinference/core/status_guard.py +86 -0
  16. xinference/core/supervisor.py +57 -7
  17. xinference/core/worker.py +134 -13
  18. xinference/deploy/cmdline.py +142 -16
  19. xinference/deploy/local.py +39 -7
  20. xinference/deploy/supervisor.py +2 -0
  21. xinference/deploy/worker.py +33 -5
  22. xinference/fields.py +4 -1
  23. xinference/model/core.py +8 -1
  24. xinference/model/embedding/core.py +3 -2
  25. xinference/model/embedding/model_spec_modelscope.json +60 -18
  26. xinference/model/image/stable_diffusion/core.py +4 -3
  27. xinference/model/llm/__init__.py +7 -0
  28. xinference/model/llm/ggml/llamacpp.py +3 -2
  29. xinference/model/llm/llm_family.json +87 -3
  30. xinference/model/llm/llm_family.py +15 -5
  31. xinference/model/llm/llm_family_modelscope.json +92 -3
  32. xinference/model/llm/pytorch/chatglm.py +70 -28
  33. xinference/model/llm/pytorch/core.py +11 -30
  34. xinference/model/llm/pytorch/internlm2.py +155 -0
  35. xinference/model/llm/pytorch/utils.py +0 -153
  36. xinference/model/llm/utils.py +37 -8
  37. xinference/model/llm/vllm/core.py +15 -3
  38. xinference/model/multimodal/__init__.py +15 -8
  39. xinference/model/multimodal/core.py +8 -1
  40. xinference/model/multimodal/model_spec.json +9 -0
  41. xinference/model/multimodal/model_spec_modelscope.json +45 -0
  42. xinference/model/multimodal/qwen_vl.py +5 -9
  43. xinference/model/utils.py +7 -2
  44. xinference/types.py +2 -0
  45. xinference/web/ui/build/asset-manifest.json +3 -3
  46. xinference/web/ui/build/index.html +1 -1
  47. xinference/web/ui/build/static/js/main.b83095c2.js +3 -0
  48. xinference/web/ui/build/static/js/{main.236e72e7.js.LICENSE.txt → main.b83095c2.js.LICENSE.txt} +7 -0
  49. xinference/web/ui/build/static/js/main.b83095c2.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0a853b2fa1902551e262a2f1a4b7894341f27b3dd9587f2ef7aaea195af89518.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/193e7ba39e70d4bb2895a5cb317f6f293a5fd02e7e324c02a1eba2f83216419c.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/22858de5265f2d279fca9f2f54dfb147e4b2704200dfb5d2ad3ec9769417328f.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/27696db5fcd4fcf0e7974cadf1e4a2ab89690474045c3188eafd586323ad13bb.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/27bdbe25deab8cf08f7fab8f05f8f26cf84a98809527a37986a4ab73a57ba96a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/2bee7b8bd3d52976a45d6068e1333df88b943e0e679403c809e45382e3818037.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/30670751f55508ef3b861e13dd71b9e5a10d2561373357a12fc3831a0b77fd93.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/3605cd3a96ff2a3b443c70a101575482279ad26847924cab0684d165ba0d2492.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/3789ef437d3ecbf945bb9cea39093d1f16ebbfa32dbe6daf35abcfb6d48de6f1.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/4d933e35e0fe79867d3aa6c46db28804804efddf5490347cb6c2c2879762a157.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/4d96f071168af43965e0fab2ded658fa0a15b8d9ca03789a5ef9c5c16a4e3cee.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/5c408307c982f07f9c09c85c98212d1b1c22548a9194c69548750a3016b91b88.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/663adbcb60b942e9cf094c8d9fabe57517f5e5e6e722d28b4948a40b7445a3b8.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/666bb2e1b250dc731311a7e4880886177885dfa768508d2ed63e02630cc78725.json +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/8b246d79cd3f6fc78f11777e6a6acca6a2c5d4ecce7f2dd4dcf9a48126440d3c.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/b4e4fccaf8f2489a29081f0bf3b191656bd452fb3c8b5e3c6d92d94f680964d5.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/b53eb7c7967f6577bd3e678293c44204fb03ffa7fdc1dd59d3099015c68f6f7f.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +1 -0
  79. xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +1 -0
  80. xinference/web/ui/node_modules/.cache/babel-loader/d06af85a84e5c5a29d3acf2dbb5b30c0cf75c8aec4ab5f975e6096f944ee4324.json +1 -0
  81. xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/d5e150bff31715977d8f537c970f06d4fe3de9909d7e8342244a83a9f6447121.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/de36e5c08fd524e341d664883dda6cb1745acc852a4f1b011a35a0b4615f72fa.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/f23ab356a8603d4a2aaa74388c2f381675c207d37c4d1c832df922e9655c9a6b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/f7c23b0922f4087b9e2e3e46f15c946b772daa46c28c3a12426212ecaf481deb.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/f95a8bd358eeb55fa2f49f1224cc2f4f36006359856744ff09ae4bb295f59ec1.json +1 -0
  88. xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +1 -0
  89. xinference/web/ui/node_modules/.package-lock.json +36 -0
  90. xinference/web/ui/node_modules/@types/cookie/package.json +30 -0
  91. xinference/web/ui/node_modules/@types/hoist-non-react-statics/package.json +33 -0
  92. xinference/web/ui/node_modules/react-cookie/package.json +55 -0
  93. xinference/web/ui/node_modules/universal-cookie/package.json +48 -0
  94. xinference/web/ui/package-lock.json +37 -0
  95. xinference/web/ui/package.json +3 -2
  96. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/METADATA +17 -6
  97. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/RECORD +101 -66
  98. xinference/web/ui/build/static/js/main.236e72e7.js +0 -3
  99. xinference/web/ui/build/static/js/main.236e72e7.js.map +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/0cccfbe5d963b8e31eb679f9d9677392839cedd04aa2956ac6b33cf19599d597.json +0 -1
  101. xinference/web/ui/node_modules/.cache/babel-loader/0f3b6cc71b7c83bdc85aa4835927aeb86af2ce0d2ac241917ecfbf90f75c6d27.json +0 -1
  102. xinference/web/ui/node_modules/.cache/babel-loader/2f651cf60b1bde50c0601c7110f77dd44819fb6e2501ff748a631724d91445d4.json +0 -1
  103. xinference/web/ui/node_modules/.cache/babel-loader/42bb623f337ad08ed076484185726e072ca52bb88e373d72c7b052db4c273342.json +0 -1
  104. xinference/web/ui/node_modules/.cache/babel-loader/57af83639c604bd3362d0f03f7505e81c6f67ff77bee7c6bb31f6e5523eba185.json +0 -1
  105. xinference/web/ui/node_modules/.cache/babel-loader/667753ce39ce1d4bcbf9a5f1a103d653be1d19d42f4e1fbaceb9b507679a52c7.json +0 -1
  106. xinference/web/ui/node_modules/.cache/babel-loader/66ed1bd4c06748c1b176a625c25c856997edc787856c73162f82f2b465c5d956.json +0 -1
  107. xinference/web/ui/node_modules/.cache/babel-loader/78f2521da2e2a98b075a2666cb782c7e2c019cd3c72199eecd5901c82d8655df.json +0 -1
  108. xinference/web/ui/node_modules/.cache/babel-loader/8d2b0b3c6988d1894694dcbbe708ef91cfe62d62dac317031f09915ced637953.json +0 -1
  109. xinference/web/ui/node_modules/.cache/babel-loader/9427ae7f1e94ae8dcd2333fb361e381f4054fde07394fe5448658e3417368476.json +0 -1
  110. xinference/web/ui/node_modules/.cache/babel-loader/bcee2b4e76b07620f9087989eb86d43c645ba3c7a74132cf926260af1164af0e.json +0 -1
  111. xinference/web/ui/node_modules/.cache/babel-loader/cc2ddd02ccc1dad1a2737ac247c79e6f6ed2c7836c6b68e511e3048f666b64af.json +0 -1
  112. xinference/web/ui/node_modules/.cache/babel-loader/d2e8e6665a7efc832b43907dadf4e3c896a59eaf8129f9a520882466c8f2e489.json +0 -1
  113. xinference/web/ui/node_modules/.cache/babel-loader/d8a42e9df7157de9f28eecefdf178fd113bf2280d28471b6e32a8a45276042df.json +0 -1
  114. xinference/web/ui/node_modules/.cache/babel-loader/e26750d9556e9741912333349e4da454c53dbfddbfc6002ab49518dcf02af745.json +0 -1
  115. xinference/web/ui/node_modules/.cache/babel-loader/ef42ec014d7bc373b874b2a1ff0dcd785490f125e913698bc049b0bd778e4d66.json +0 -1
  116. xinference/web/ui/node_modules/.cache/babel-loader/fe3eb4d76c79ca98833f686d642224eeeb94cc83ad14300d281623796d087f0a.json +0 -1
  117. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/LICENSE +0 -0
  118. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/WHEEL +0 -0
  119. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/entry_points.txt +0 -0
  120. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/top_level.txt +0 -0
--- a/xinference/model/llm/pytorch/chatglm.py
+++ b/xinference/model/llm/pytorch/chatglm.py
@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import time
+import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

 from ....types import (
     SPECIAL_TOOL_PROMPT,
     ChatCompletion,
+    ChatCompletionChoice,
     ChatCompletionChunk,
     ChatCompletionMessage,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
     PytorchGenerateConfig,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
@@ -106,38 +112,74 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         tools = self._handle_tools(generate_config)
+        kwargs: Dict[str, Any] = {}
+        generate_config = generate_config or {}
+        temperature = generate_config.get("temperature")
+        if temperature is not None:
+            kwargs["temperature"] = float(temperature)
+        top_p = generate_config.get("top_p")
+        if top_p is not None:
+            kwargs["top_p"] = float(top_p)
+        max_length = generate_config.get("max_tokens")
+        if max_length is not None:
+            kwargs["max_length"] = int(max_length)
+        # Tool calls only works for non stream, so we call chat directly.
+        if prompt == SPECIAL_TOOL_PROMPT and chat_history:
+            tool_message = chat_history.pop()
+            content = tool_message.get("content")
+            assert content is not None
+            prompt = content
+            kwargs["role"] = "observation"
+            chat_history = [h for h in chat_history if not h.get("tool_calls")]
+        if not chat_history:
+            chat_history = []
         if tools:
-            # Tool calls only works for non stream, so we call chat directly.
-            kwargs: Dict[str, Any] = {}
-            generate_config = generate_config or {}
-            temperature = generate_config.get("temperature")
-            if temperature is not None:
-                kwargs["temperature"] = float(temperature)
-            top_p = generate_config.get("top_p")
-            if top_p is not None:
-                kwargs["top_p"] = float(top_p)
-            max_length = generate_config.get("max_tokens")
-            if max_length is not None:
-                kwargs["max_length"] = int(max_length)
-            if prompt == SPECIAL_TOOL_PROMPT and chat_history:
-                tool_message = chat_history.pop()
-                content = tool_message.get("content")
-                assert content is not None
-                prompt = content
-                kwargs["role"] = "observation"
-                chat_history = [h for h in chat_history if not h.get("tool_calls")]
-            if not chat_history:
-                chat_history = []
             msg = self._model.chat(
                 self._tokenizer, prompt, [tools] + chat_history, **kwargs
             )
             return self._tool_calls_completion(
-                self.model_family.model_name, self.model_uid, msg, tools
+                self.model_family, self.model_uid, msg, tools
             )
         else:
-            return super().chat(
-                prompt=prompt,
-                system_prompt=system_prompt,
-                chat_history=chat_history,
-                generate_config=generate_config,
-            )
+            stream = generate_config.get("stream", False)
+            if stream:
+
+                def _stream_generator():
+                    last_chunk_text_length = 0
+                    for chunk_text, _ in self._model.stream_chat(
+                        self._tokenizer, prompt, chat_history, **kwargs
+                    ):
+                        chunk_text = chunk_text[last_chunk_text_length:]
+                        last_chunk_text_length += len(chunk_text)
+                        completion_choice = CompletionChoice(
+                            text=chunk_text, index=0, logprobs=None, finish_reason=None
+                        )
+                        yield CompletionChunk(
+                            id=str(uuid.uuid1()),
+                            object="text_completion",
+                            created=int(time.time()),
+                            model=self.model_uid,
+                            choices=[completion_choice],
+                        )
+
+                return self._to_chat_completion_chunks(_stream_generator())
+            else:
+                response, _ = self._model.chat(
+                    self._tokenizer, prompt, chat_history, **kwargs
+                )
+                return ChatCompletion(
+                    id="chat" + str(uuid.uuid1()),
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=self.model_uid,
+                    choices=[
+                        ChatCompletionChoice(
+                            index=0,
+                            message={"role": "assistant", "content": response},
+                            finish_reason="stop",
+                        )
+                    ],
+                    usage=CompletionUsage(
+                        prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                    ),
+                )
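
A note on the streaming branch above: ChatGLM's stream_chat yields the cumulative response on every step, so _stream_generator keeps last_chunk_text_length and slices off what was already emitted before wrapping each delta in a CompletionChunk. A minimal, self-contained sketch of that delta logic (fake_stream_chat is a purely illustrative stand-in for the model's stream_chat):

    from typing import Iterator, Tuple


    def fake_stream_chat() -> Iterator[Tuple[str, None]]:
        # Stand-in for model.stream_chat: yields growing prefixes of the answer.
        full = "Hello! How can I help you today?"
        for i in range(5, len(full), 5):
            yield full[:i], None
        yield full, None  # final step carries the complete response


    def deltas() -> Iterator[str]:
        # The same slicing trick used by _stream_generator in the hunk above.
        last_len = 0
        for cumulative, _ in fake_stream_chat():
            delta = cumulative[last_len:]
            last_len += len(delta)
            yield delta


    assert "".join(deltas()) == "Hello! How can I help you today?"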
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -192,7 +192,8 @@ class PytorchModel(LLM):
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq"]:
             return False
-        if llm_family.model_name in [
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family in [
             "baichuan-chat",
             "vicuna-v1.3",
             "falcon",
@@ -211,11 +212,7 @@ class PytorchModel(LLM):
     def generate(
         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
-        from .utils import (
-            generate_stream,
-            generate_stream_chatglm,
-            generate_stream_falcon,
-        )
+        from .utils import generate_stream, generate_stream_falcon

         model_family_name = self.model_family.model_name.lower()

@@ -223,17 +220,7 @@
             prompt: str, generate_config: PytorchGenerateConfig
         ) -> Iterator[CompletionChunk]:
             if "falcon" in model_family_name:
-                for completion_chunk, _ in generate_stream_falcon(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    yield completion_chunk
-            elif "chatglm" in model_family_name:
-                for completion_chunk, _ in generate_stream_chatglm(
+                for completion_chunk, completion_usage in generate_stream_falcon(
                     self.model_uid,
                     self._model,
                     self._tokenizer,
@@ -241,9 +228,10 @@
                     self._device,
                     generate_config,
                 ):
+                    completion_chunk["usage"] = completion_usage
                     yield completion_chunk
             else:
-                for completion_chunk, _ in generate_stream(
+                for completion_chunk, completion_usage in generate_stream(
                     self.model_uid,
                     self._model,
                     self._tokenizer,
@@ -251,6 +239,7 @@
                     self._device,
                     generate_config,
                 ):
+                    completion_chunk["usage"] = completion_usage
                     yield completion_chunk

         logger.debug(
@@ -274,16 +263,6 @@
                     generate_config,
                 ):
                     pass
-            elif "chatglm" in model_family_name:
-                for completion_chunk, completion_usage in generate_stream_chatglm(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    pass
             else:
                 for completion_chunk, completion_usage in generate_stream(
                     self.model_uid,
(xinference/model/llm/pytorch/core.py, continued)
@@ -442,6 +421,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             "chatglm2-32k",
             "llama-2",
             "llama-2-chat",
+            "internlm2-chat",
         ]:
             return False
         if "chat" not in llm_family.model_ability:
@@ -465,7 +445,8 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):

         generate_config = self._sanitize_generate_config(generate_config)
         # TODO(codingl2k1): qwen hacky to set stop for function call.
-        if tools and self.model_family.model_name == "qwen-chat":
+        model_family = self.model_family.model_family or self.model_family.model_name
+        if tools and "qwen-chat" == model_family:
             stop = generate_config.get("stop")
             if isinstance(stop, str):
                 generate_config["stop"] = [stop, "Observation:"]
@@ -485,6 +466,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         assert not isinstance(c, Iterator)
         if tools:
             return self._tool_calls_completion(
-                self.model_family.model_name, self.model_uid, c, tools
+                self.model_family, self.model_uid, c, tools
             )
         return self._to_chat_completion(c)
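
A recurring pattern in this release: capability checks now read llm_family.model_family or llm_family.model_name instead of the bare model name, so a custom-registered model that declares its base family inherits that family's handling (tool calls, the qwen-chat stop hack, internlm2 matching, and so on). A minimal sketch, with a stand-in dataclass in place of the real LLMFamilyV1:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Family:  # stand-in for LLMFamilyV1, attribute names follow the diff
        model_name: str
        model_family: Optional[str] = None


    def effective_family(f: Family) -> str:
        # The fallback used throughout the 0.8.1 hunks above.
        return f.model_family or f.model_name


    builtin = Family(model_name="qwen-chat")
    custom = Family(model_name="my-finetuned-qwen", model_family="qwen-chat")

    assert effective_family(builtin) == "qwen-chat"
    assert effective_family(custom) == "qwen-chat"  # custom model inherits qwen-chat handling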
--- /dev/null
+++ b/xinference/model/llm/pytorch/internlm2.py
@@ -0,0 +1,155 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+import uuid
+from typing import Any, Dict, Iterator, List, Optional, Union
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChoice,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+    PytorchGenerateConfig,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchModelConfig
+
+
+class Internlm2PytorchChatModel(PytorchChatModel):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        pytorch_model_config: Optional[PytorchModelConfig] = None,
+    ):
+        super().__init__(
+            model_uid,
+            model_family,
+            model_spec,
+            quantization,
+            model_path,
+            pytorch_model_config=pytorch_model_config,
+        )
+
+    def _load_model(self, **kwargs):
+        try:
+            from transformers import AutoModel, AutoTokenizer
+        except ImportError:
+            error_message = "Failed to import module 'transformers'"
+            installation_guide = [
+                "Please make sure 'transformers' is installed. ",
+                "You can install it by `pip install transformers`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=kwargs["trust_remote_code"],
+            encode_special_tokens=True,
+            revision=kwargs["revision"],
+        )
+        model = AutoModel.from_pretrained(
+            self.model_path,
+            **kwargs,
+        )
+        return model, tokenizer
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family != "internlm2-chat":
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    def chat(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        kwargs: Dict[str, Any] = {}
+        generate_config = generate_config or {}
+        temperature = generate_config.get("temperature")
+        if temperature is not None:
+            kwargs["temperature"] = float(temperature)
+        top_p = generate_config.get("top_p")
+        if top_p is not None:
+            kwargs["top_p"] = float(top_p)
+        max_new_tokens = generate_config.get("max_tokens")
+        if max_new_tokens is not None:
+            kwargs["max_length"] = int(max_new_tokens)
+
+        stream = generate_config.get("stream", False)
+        if chat_history:
+            input_history = [
+                (chat_history[i]["content"], (chat_history[i + 1]["content"]))
+                for i in range(0, len(chat_history), 2)
+            ]
+        else:
+            input_history = []
+        if stream:
+
+            def _stream_generator():
+                last_chunk_text_length = 0
+                for chunk_text, _ in self._model.stream_chat(
+                    self._tokenizer, prompt, input_history, **kwargs
+                ):
+                    chunk_text = chunk_text[last_chunk_text_length:]
+                    last_chunk_text_length += len(chunk_text)
+                    completion_choice = CompletionChoice(
+                        text=chunk_text, index=0, logprobs=None, finish_reason=None
+                    )
+                    yield CompletionChunk(
+                        id=str(uuid.uuid1()),
+                        object="text_completion",
+                        created=int(time.time()),
+                        model=self.model_uid,
+                        choices=[completion_choice],
+                    )
+
+            return self._to_chat_completion_chunks(_stream_generator())
+        else:
+            response, _ = self._model.chat(
+                self._tokenizer, prompt, input_history, **kwargs
+            )
+            return ChatCompletion(
+                id="chat" + str(uuid.uuid1()),
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    ChatCompletionChoice(
+                        index=0,
+                        message={"role": "assistant", "content": response},
+                        finish_reason="stop",
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
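
With the new file above registered, internlm2-chat becomes loadable as a PyTorch chat model. A hedged usage sketch against the RESTful client; the endpoint address assumes a default local deployment, and additional launch parameters (model size, quantization) may be required depending on the spec:

    from xinference.client import RESTfulClient

    # Assumes a local xinference supervisor listening on the default port.
    client = RESTfulClient("http://127.0.0.1:9997")

    # Internlm2PytorchChatModel.match() only accepts model_format="pytorch".
    model_uid = client.launch_model(
        model_name="internlm2-chat",
        model_format="pytorch",
    )
    model = client.get_model(model_uid)

    # Non-stream chat returns a ChatCompletion dict with a single choice.
    print(model.chat("Briefly introduce yourself.")["choices"][0]["message"]["content"])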
--- a/xinference/model/llm/pytorch/utils.py
+++ b/xinference/model/llm/pytorch/utils.py
@@ -14,7 +14,6 @@

 import gc
 import logging
-import re
 import time
 import uuid
 from threading import Thread
@@ -23,7 +22,6 @@ from typing import Iterable, Iterator, Tuple
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 from transformers.generation.logits_process import (
-    LogitsProcessor,
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
@@ -480,154 +478,3 @@ def generate_stream_falcon(
     # clean
     gc.collect()
     torch.cuda.empty_cache()
-
-
-class InvalidScoreLogitsProcessor(LogitsProcessor):
-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 5] = 5e4
-        return scores
-
-
-invalid_score_processor = InvalidScoreLogitsProcessor()
-
-
-def process_response(response):
-    response = response.strip()
-    response = response.replace("[[训练时间]]", "2023年")
-    punkts = [
-        [",", "，"],
-        ["!", "！"],
-        [":", "："],
-        [";", "；"],
-        ["\\?", "？"],
-    ]
-    for item in punkts:
-        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
-        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
-    return response
-
-
-@torch.inference_mode()
-def generate_stream_chatglm(
-    model_uid,
-    model,
-    tokenizer,
-    prompt,
-    device,
-    generate_config,
-    judge_sent_end=False,
-):
-    stream = generate_config.get("stream", False)
-    temperature = float(generate_config.get("temperature", 1.0))
-    repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
-    top_p = float(generate_config.get("top_p", 1.0))
-    max_new_tokens = int(generate_config.get("max_tokens", 256))
-    echo = generate_config.get("echo", False)
-    stop_str = generate_config.get("stop", None)
-    eos_token_id = generate_config.get("stop_token_ids", [])
-    eos_token_id.append(tokenizer.eos_token_id)
-
-    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-    input_echo_len = len(inputs["input_ids"][0])
-    gen_kwargs = {
-        "max_length": max_new_tokens + input_echo_len,
-        "do_sample": True if temperature > 1e-5 else False,
-        "top_p": top_p,
-        "repetition_penalty": repetition_penalty,
-        "logits_processor": [invalid_score_processor],
-    }
-    if temperature > 1e-5:
-        gen_kwargs["temperature"] = temperature
-
-    total_len = 0
-    last_response_length = 0
-    for total_ids in model.stream_generate(
-        **inputs, eos_token_id=eos_token_id, **gen_kwargs
-    ):
-        total_ids = total_ids.tolist()[0]
-        total_len = len(total_ids)
-        if echo:
-            output_ids = total_ids
-        else:
-            output_ids = total_ids[input_echo_len:]
-        response = tokenizer.decode(output_ids)
-        response = process_response(response)
-
-        partially_stopped = False
-        stopped = False
-        if stop_str:
-            if isinstance(stop_str, str):
-                pos = response.rfind(stop_str, 0)
-                if pos != -1:
-                    response = response[:pos]
-                    stopped = True
-                else:
-                    partially_stopped = is_partial_stop(response, stop_str)
-            elif isinstance(stop_str, Iterable):
-                for each_stop in stop_str:
-                    pos = response.rfind(each_stop, 0)
-                    if pos != -1:
-                        response = response[:pos]
-                        stopped = True
-                        break
-                    else:
-                        partially_stopped = is_partial_stop(response, each_stop)
-                        if partially_stopped:
-                            break
-            else:
-                raise ValueError("Invalid stop field type.")
-
-        if stream:
-            response = response.strip("�")
-            tmp_response_length = len(response)
-            response = response[last_response_length:]
-            last_response_length = tmp_response_length
-
-        if not partially_stopped:
-            completion_choice = CompletionChoice(
-                text=response, index=0, logprobs=None, finish_reason=None
-            )
-            completion_chunk = CompletionChunk(
-                id=str(uuid.uuid1()),
-                object="text_completion",
-                created=int(time.time()),
-                model=model_uid,
-                choices=[completion_choice],
-            )
-            completion_usage = CompletionUsage(
-                prompt_tokens=input_echo_len,
-                completion_tokens=(total_len - input_echo_len),
-                total_tokens=total_len,
-            )
-
-            yield completion_chunk, completion_usage
-
-        if stopped:
-            break
-
-    if total_len - input_echo_len == max_new_tokens - 1:
-        finish_reason = "length"
-    else:
-        finish_reason = "stop"
-
-    completion_choice = CompletionChoice(
-        text=response, index=0, logprobs=None, finish_reason=finish_reason
-    )
-    completion_chunk = CompletionChunk(
-        id=str(uuid.uuid1()),
-        object="text_completion",
-        created=int(time.time()),
-        model=model_uid,
-        choices=[completion_choice],
-    )
-    completion_usage = CompletionUsage(
-        prompt_tokens=input_echo_len,
-        completion_tokens=(total_len - input_echo_len),
-        total_tokens=total_len,
-    )
-
-    yield completion_chunk, completion_usage
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -16,7 +16,7 @@ import json
 import logging
 import time
 import uuid
-from typing import AsyncGenerator, Dict, Iterator, List, Optional
+from typing import AsyncGenerator, Dict, Iterator, List, Optional, cast

 from xinference.model.llm.llm_family import PromptStyleV1

@@ -299,6 +299,24 @@ Begin!"""
             )
             ret += chat_history[-1]["role"] + ":"
             return ret
+        elif prompt_style.style_name == "INTERNLM2":
+            ret = (
+                "<s>"
+                if prompt_style.system_prompt == ""
+                else "<s>[UNUSED_TOKEN_146]system\n"
+                + prompt_style.system_prompt
+                + prompt_style.intra_message_sep
+                + "\n"
+            )
+            for message in chat_history:
+                role = message["role"]
+                content = message["content"]
+
+                if content:
+                    ret += role + "\n" + content + prompt_style.intra_message_sep + "\n"
+                else:
+                    ret += role + "\n"
+            return ret
         elif prompt_style.style_name == "ADD_COLON_SINGLE_COT":
             ret = prompt_style.system_prompt + prompt_style.intra_message_sep
             for message in chat_history:
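
To make the new INTERNLM2 branch concrete, here is a standalone rendering of the same logic. The SEP constant is an assumption standing in for the family's configured intra_message_sep; the real value comes from the prompt style in llm_family.json:

    from typing import Dict, List

    SEP = "[UNUSED_TOKEN_145]"  # assumed intra_message_sep, for illustration only


    def render_internlm2(system_prompt: str, chat_history: List[Dict[str, str]]) -> str:
        # Mirrors the INTERNLM2 branch above: optional system block, then one
        # "<role>\n<content><sep>\n" segment per non-empty message.
        ret = (
            "<s>"
            if system_prompt == ""
            else "<s>[UNUSED_TOKEN_146]system\n" + system_prompt + SEP + "\n"
        )
        for message in chat_history:
            role, content = message["role"], message["content"]
            if content:
                ret += role + "\n" + content + SEP + "\n"
            else:
                ret += role + "\n"  # a trailing empty message opens the assistant turn
        return ret


    print(render_internlm2("You are helpful.", [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": ""},
    ]))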
(xinference/model/llm/utils.py, continued)
@@ -360,7 +378,7 @@ Begin!"""

     @classmethod
     def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
-        return {
+        chat_chunk = {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],
             "created": chunk["created"],
@@ -376,12 +394,16 @@
                 for i, choice in enumerate(chunk["choices"])
             ],
         }
+        usage = chunk.get("usage")
+        if usage is not None:
+            chat_chunk["usage"] = usage
+        return cast(ChatCompletionChunk, chat_chunk)

     @classmethod
     def _get_first_chat_completion_chunk(
         cls, chunk: CompletionChunk
     ) -> ChatCompletionChunk:
-        return {
+        chat_chunk = {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],
             "created": chunk["created"],
@@ -397,6 +419,10 @@
                 for i, choice in enumerate(chunk["choices"])
             ],
         }
+        usage = chunk.get("usage")
+        if usage is not None:
+            chat_chunk["usage"] = usage
+        return cast(ChatCompletionChunk, chat_chunk)

     @classmethod
     def _to_chat_completion_chunks(
@@ -494,16 +520,19 @@
         return text, None, None

     @classmethod
-    def _tool_calls_completion(cls, model_name, model_uid, c, tools):
+    def _tool_calls_completion(cls, model_family, model_uid, c, tools):
         _id = str(uuid.uuid4())
-        if model_name == "gorilla-openfunctions-v1":
+        family = model_family.model_family or model_family.model_name
+        if "gorilla-openfunctions-v1" == family:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif model_name == "chatglm3":
+        elif "chatglm3" == family:
             content, func, args = cls._eval_chatglm3_arguments(c, tools)
-        elif model_name == "qwen-chat":
+        elif "qwen-chat" == family:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
         else:
-            raise Exception(f"Model {model_name} is not support tool calls.")
+            raise Exception(
+                f"Model {model_family.model_name} is not support tool calls."
+            )
         logger.debug("Tool call content: %s, func: %s, args: %s", content, func, args)

         if content: