xinference-1.8.1-py3-none-any.whl → xinference-1.9.0-py3-none-any.whl

This diff shows the changes between the publicly released 1.8.1 and 1.9.0 packages as they appear in their public registries. It is provided for informational purposes only.

This version of xinference has been flagged as potentially problematic.

Files changed (64)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +5 -0
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +3 -4
  6. xinference/deploy/local.py +5 -0
  7. xinference/deploy/worker.py +6 -0
  8. xinference/model/core.py +3 -0
  9. xinference/model/embedding/sentence_transformers/core.py +3 -4
  10. xinference/model/embedding/vllm/core.py +4 -3
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +22 -0
  13. xinference/model/llm/cache_manager.py +17 -3
  14. xinference/model/llm/harmony.py +245 -0
  15. xinference/model/llm/llm_family.json +293 -8
  16. xinference/model/llm/llm_family.py +1 -1
  17. xinference/model/llm/sglang/core.py +108 -5
  18. xinference/model/llm/transformers/core.py +15 -7
  19. xinference/model/llm/transformers/gemma3.py +1 -1
  20. xinference/model/llm/transformers/gpt_oss.py +91 -0
  21. xinference/model/llm/transformers/multimodal/core.py +1 -1
  22. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  23. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  24. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  25. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  26. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  27. xinference/model/llm/transformers/utils.py +1 -33
  28. xinference/model/llm/utils.py +61 -7
  29. xinference/model/llm/vllm/core.py +38 -8
  30. xinference/model/rerank/__init__.py +66 -23
  31. xinference/model/rerank/cache_manager.py +35 -0
  32. xinference/model/rerank/core.py +84 -339
  33. xinference/model/rerank/custom.py +33 -8
  34. xinference/model/rerank/model_spec.json +251 -212
  35. xinference/model/rerank/rerank_family.py +137 -0
  36. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  37. xinference/model/rerank/sentence_transformers/core.py +337 -0
  38. xinference/model/rerank/vllm/__init__.py +13 -0
  39. xinference/model/rerank/vllm/core.py +106 -0
  40. xinference/model/utils.py +109 -0
  41. xinference/types.py +2 -0
  42. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  43. xinference/ui/web/ui/build/index.html +1 -1
  44. xinference/ui/web/ui/build/static/js/{main.b969199a.js → main.4918643a.js} +3 -3
  45. xinference/ui/web/ui/build/static/js/{main.b969199a.js.map → main.4918643a.js.map} +1 -1
  46. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  47. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  48. xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +1 -0
  49. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  50. xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +1 -0
  51. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  52. {xinference-1.8.1.dist-info → xinference-1.9.0.dist-info}/METADATA +6 -1
  53. {xinference-1.8.1.dist-info → xinference-1.9.0.dist-info}/RECORD +58 -50
  54. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  55. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  56. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  57. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  58. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  59. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  60. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.4918643a.js.LICENSE.txt} +0 -0
  61. {xinference-1.8.1.dist-info → xinference-1.9.0.dist-info}/WHEEL +0 -0
  62. {xinference-1.8.1.dist-info → xinference-1.9.0.dist-info}/entry_points.txt +0 -0
  63. {xinference-1.8.1.dist-info → xinference-1.9.0.dist-info}/licenses/LICENSE +0 -0
  64. {xinference-1.8.1.dist-info → xinference-1.9.0.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-08-03T23:41:33+0800",
+ "date": "2025-08-16T21:34:08+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "0e5b67be8403547b9c36e2c4f7ebba19d929e2e3",
- "version": "1.8.1"
+ "full-revisionid": "38e0401f83799f57d42ef948c57782466b8e4777",
+ "version": "1.9.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -2249,8 +2249,9 @@ class RESTfulAPI(CancelMixin):
             )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
+            is_sglang = await model.is_sglang_backend()
             if not (
-                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                ((is_vllm or is_sglang) and model_family in QWEN_TOOL_CALL_FAMILY)
                 or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
             ):
                 raise HTTPException(
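With the hunk above, streaming tool calls are no longer rejected for Qwen-family models served on the SGLang backend. A minimal sketch of such a request against xinference's OpenAI-compatible endpoint; the endpoint URL, model uid, and tool definition are illustrative assumptions, not values taken from this release:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:9997/v1", api_key="not-needed")  # assumed endpoint

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool for illustration
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# Before 1.9.0, stream=True together with tools was rejected for sglang-backed
# Qwen models; with this change the request is accepted like on vLLM.
stream = client.chat.completions.create(
    model="qwen2.5-instruct",  # assumed model uid of a launched Qwen model
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        print(delta.tool_calls)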
xinference/core/model.py CHANGED
@@ -365,6 +365,11 @@ class ModelActor(xo.StatelessActor, CancelMixin):
 
         return isinstance(self._model, VLLMModel)
 
+    def is_sglang_backend(self) -> bool:
+        from ..model.llm.sglang.core import SGLANGModel
+
+        return isinstance(self._model, SGLANGModel)
+
     async def load(self):
         try:
             # Change process title for model
xinference/core/supervisor.py CHANGED
@@ -476,7 +476,7 @@ class SupervisorActor(xo.StatelessActor):
     async def _to_rerank_model_reg(
        self, model_spec: "RerankModelFamilyV2", is_builtin: bool
     ) -> Dict[str, Any]:
-        from ..model.cache_manager import CacheManager
+        from ..model.rerank.cache_manager import RerankCacheManager as CacheManager
 
         instance_cnt = await self.get_instance_count(model_spec.model_name)
         version_cnt = await self.get_model_version_count(model_spec.model_name)
@@ -712,9 +712,8 @@ class SupervisorActor(xo.StatelessActor):
         from ..model.rerank import BUILTIN_RERANK_MODELS
         from ..model.rerank.custom import get_user_defined_reranks
 
-        for model_name, families in BUILTIN_RERANK_MODELS.items():
+        for model_name, family in BUILTIN_RERANK_MODELS.items():
             if detailed:
-                family = [x for x in families if x.model_hub == "huggingface"][0]
                 ret.append(await self._to_rerank_model_reg(family, is_builtin=True))
             else:
                 ret.append({"model_name": model_name, "is_builtin": True})
xinference/core/worker.py CHANGED
@@ -817,10 +817,7 @@ class WorkerActor(xo.StatelessActor):
         # we specify python_path explicitly
         # sometimes uv would find other versions.
         python_path = pathlib.Path(sys.executable)
-        kw = {}
-        if XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED:
-            kw["skip_installed"] = XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED
-        virtual_env_manager.create_env(python_path=python_path, **kw)
+        virtual_env_manager.create_env(python_path=python_path)
         return virtual_env_manager
 
     @classmethod
@@ -847,6 +844,8 @@ class WorkerActor(xo.StatelessActor):
             packages.extend(virtual_env_packages)
         conf.pop("packages", None)
         conf.pop("inherit_pip_config", None)
+        if XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED:
+            conf["skip_installed"] = XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED
 
         logger.info(
             "Installing packages %s in virtual env %s, with settings(%s)",
xinference/deploy/local.py CHANGED
@@ -152,6 +152,11 @@ def main(
     logging_conf: Optional[Dict] = None,
     auth_config_file: Optional[str] = None,
 ):
+    # force to set spawn,
+    # cuda may be inited in xoscar virtualenv
+    # which will raise error after sub pool is created
+    multiprocessing.set_start_method("spawn")
+
     supervisor_address = f"{host}:{get_next_port()}"
     local_cluster = run_in_subprocess(
         supervisor_address, metrics_exporter_host, metrics_exporter_port, logging_conf
xinference/deploy/worker.py CHANGED
@@ -14,6 +14,7 @@
 
 import asyncio
 import logging
+import multiprocessing
 import os
 from typing import Any, Optional
 
@@ -81,6 +82,11 @@ def main(
     metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[dict] = None,
 ):
+    # force to set spawn,
+    # cuda may be inited in xoscar virtualenv
+    # which will raise error after sub pool is created
+    multiprocessing.set_start_method("spawn")
+
     loop = asyncio.get_event_loop()
     task = loop.create_task(
         _start_worker(
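One note on the spawn change in both deploy entry points: multiprocessing.set_start_method raises RuntimeError if a start method has already been fixed, so code that may run after other initialization typically guards the call. The sketch below is not xinference code, only an illustration of the pattern and of why spawn matters once CUDA has been touched in the parent process:

import multiprocessing


def ensure_spawn_start_method() -> None:
    # CUDA contexts do not survive fork(); worker pools forked after the parent
    # initialized the GPU tend to crash, which is why spawn is forced up front.
    try:
        multiprocessing.set_start_method("spawn")
    except RuntimeError:
        # A start method was already chosen elsewhere (e.g. by an embedding
        # application); only accept it if it is already spawn.
        if multiprocessing.get_start_method() != "spawn":
            raise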
xinference/model/core.py CHANGED
@@ -81,6 +81,9 @@ def create_model_instance(
         return create_rerank_model_instance(
             model_uid,
             model_name,
+            model_engine,
+            model_format,
+            quantization,
             download_hub,
             model_path,
             **kwargs,
xinference/model/embedding/sentence_transformers/core.py CHANGED
@@ -19,8 +19,8 @@ from typing import List, Optional, Union, no_type_check
 import numpy as np
 import torch
 
-from ....device_utils import is_device_available
 from ....types import Embedding, EmbeddingData, EmbeddingUsage
+from ...utils import is_flash_attn_available
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
@@ -85,13 +85,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             )
         elif "qwen3" in self.model_family.model_name.lower():
             # qwen3 embedding
-            flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
             flash_attn_enabled = self._kwargs.get(
-                "enable_flash_attn", is_device_available("cuda")
+                "enable_flash_attn", is_flash_attn_available()
             )
             model_kwargs = {"device_map": "auto"}
             tokenizer_kwargs = {}
-            if flash_attn_installed and flash_attn_enabled:
+            if flash_attn_enabled:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_kwargs["torch_dtype"] = "bfloat16"
                 tokenizer_kwargs["padding_side"] = "left"
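For the Qwen3 embedding path above, enable_flash_attn now defaults to is_flash_attn_available() rather than just checking for a CUDA device, and it can still be overridden at launch time. A hedged sketch with the RESTful client; the endpoint and model name are assumptions for illustration:

from xinference.client import Client

client = Client("http://localhost:9997")  # assumed endpoint

# Extra kwargs are forwarded to the embedding implementation; passing
# enable_flash_attn=False keeps the default attention even when flash-attn
# happens to be installed.
model_uid = client.launch_model(
    model_name="Qwen3-Embedding-0.6B",  # assumed built-in model name
    model_type="embedding",
    enable_flash_attn=False,
)
embedding = client.get_model(model_uid).create_embedding("hello world")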
xinference/model/embedding/vllm/core.py CHANGED
@@ -17,6 +17,7 @@ import logging
 from typing import List, Union
 
 from ....types import Embedding, EmbeddingData, EmbeddingUsage
+from ...utils import cache_clean
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
@@ -42,13 +43,14 @@ class VLLMEmbeddingModel(EmbeddingModel):
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
 
-        self._model = LLM(model=self._model_path, task="embed")
+        self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
 
     @staticmethod
     def _get_detailed_instruct(task_description: str, query: str) -> str:
         return f"Instruct: {task_description}\nQuery:{query}"
 
+    @cache_clean
     def create_embedding(
         self,
         sentences: Union[str, List[str]],
@@ -60,8 +62,7 @@ class VLLMEmbeddingModel(EmbeddingModel):
         normalize_embedding = kwargs.get("normalize_embedding", True)
         if not normalize_embedding:
             raise ValueError(
-                "vllm embedding engine does not support "
-                "setting `normalize_embedding=False`"
+                "vllm embedding engine does not support setting `normalize_embedding=False`"
             )
 
         assert self._model is not None
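Because launch-time kwargs are now forwarded to vLLM's LLM(model=..., task="embed", **kwargs) constructor, engine options can be tuned per embedding model. A sketch under the assumption that embedding models accept a model_engine selector the way LLMs do; gpu_memory_utilization is a real vLLM engine option, while the endpoint and model name are illustrative:

from xinference.client import Client

client = Client("http://localhost:9997")  # assumed endpoint
model_uid = client.launch_model(
    model_name="Qwen3-Embedding-0.6B",   # assumed model name
    model_type="embedding",
    model_engine="vllm",                 # assumed engine selector
    gpu_memory_utilization=0.5,          # forwarded to vllm.LLM by this release
)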
xinference/model/image/model_spec.json CHANGED
@@ -175,6 +175,75 @@
       "no_build_isolation": true
     }
   },
+  {
+    "version": 2,
+    "model_name": "Qwen-Image",
+    "model_family": "stable_diffusion",
+    "model_ability": [
+      "text2image"
+    ],
+    "model_src": {
+      "huggingface": {
+        "model_id": "Qwen/Qwen-Image",
+        "model_revision": "4516c4d3058302ff35cd86c62ffa645d039fefad",
+        "gguf_model_id": "city96/Qwen-Image-gguf",
+        "gguf_quantizations": [
+          "F16",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf"
+      },
+      "modelscope": {
+        "model_id": "Qwen/Qwen-Image",
+        "model_revision": "master",
+        "gguf_model_id": "city96/Qwen-Image-gguf",
+        "gguf_quantizations": [
+          "F16",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf"
+      }
+    },
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder",
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 1.0
+    },
+    "virtualenv": {
+      "packages": [
+        "git+https://github.com/huggingface/diffusers.git",
+        "peft>=0.17.0",
+        "#system_torch#",
+        "#system_numpy#"
+      ],
+      "no_build_isolation": true
+    }
+  },
   {
     "version": 2,
     "model_name": "sd3-medium",
xinference/model/image/stable_diffusion/core.py CHANGED
@@ -254,6 +254,14 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             self._model = FluxKontextPipeline.from_pretrained(
                 self._model_path, **self._kwargs
             )
+        elif "qwen" in self._model_spec.model_name.lower():
+            # TODO: remove this branch when auto pipeline supports
+            # Qwen-Image
+            from diffusers import DiffusionPipeline
+
+            self._model = DiffusionPipeline.from_pretrained(
+                self._model_path, **self._kwargs
+            )
         else:
             raise
         self._load_to_device(self._model)
@@ -348,11 +356,19 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             return
 
         if not quantize_text_encoder:
+            logger.debug("No text encoder quantization")
             return
 
         quantization_method = self._kwargs.pop("text_encoder_quantize_method", "bnb")
         quantization = self._kwargs.pop("text_encoder_quantization", "8-bit")
 
+        logger.debug(
+            "Quantize text encoder %s with method %s, quantization %s",
+            quantize_text_encoder,
+            quantization_method,
+            quantization,
+        )
+
         torch_dtype = self._torch_dtype
         for text_encoder_name in quantize_text_encoder.split(","):
             quantization_kwargs: Dict[str, Any] = {}
@@ -389,8 +405,13 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
 
         if not quantization:
             # skip if no quantization specified
+            logger.debug("No transformer quantization")
             return
 
+        logger.debug(
+            "Quantize transformer with %s, quantization %s", method, quantization
+        )
+
         torch_dtype = self._torch_dtype
         transformer_cls = self._get_layer_cls("transformer")
         quantization_config = self._get_quantize_config(
@@ -409,6 +430,7 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
 
         # GGUF transformer
         torch_dtype = self._torch_dtype
+        logger.debug("Quantize transformer with gguf file %s", self._gguf_model_path)
         self._kwargs["transformer"] = self._get_layer_cls(
             "transformer"
         ).from_single_file(
xinference/model/llm/cache_manager.py CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 import os
 from typing import TYPE_CHECKING, Optional
@@ -81,7 +95,7 @@ class LLMCacheManager(CacheManager):
         if not IS_NEW_HUGGINGFACE_HUB:
             use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
 
-        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
+        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "bnb", "mlx"]:
             download_dir = retry_download(
                 huggingface_hub.snapshot_download,
                 self._model_name,
@@ -144,7 +158,7 @@ class LLMCacheManager(CacheManager):
         if self.get_cache_status():
             return cache_dir
 
-        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
+        if self._model_format in ["pytorch", "gptq", "awq", "bnb", "fp8", "bnb", "mlx"]:
             download_dir = retry_download(
                 snapshot_download,
                 self._model_name,
@@ -234,7 +248,7 @@ class LLMCacheManager(CacheManager):
         if self.get_cache_status():
             return cache_dir
 
-        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
+        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "bnb", "mlx"]:
             download_dir = retry_download(
                 snapshot_download,
                 self._model_name,
xinference/model/llm/harmony.py ADDED
@@ -0,0 +1,245 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, Union
+
+if TYPE_CHECKING:
+    from ...types import ChatCompletion, ChatCompletionChunk
+
+
+class HarmonyStreamParser:
+    def __init__(self):
+        # Current channel: either 'analysis', 'final', or None if not started yet
+        self.current_channel = None
+        # Buffer for accumulating text when looking for 'assistantfinal' marker
+        self.buffer = ""
+
+    def feed(self, text):
+        """
+        Feed a chunk of text into the parser and return parsed segments.
+
+        Each segment is a dict:
+            {
+                "channel": "analysis" | "final",
+                "content": <string>
+            }
+
+        The parser detects 'assistantfinal' markers inside reasoning text,
+        splits the reasoning and final content correctly, and switches the channel.
+        """
+        segments = []
+
+        # If we are currently in 'analysis' mode
+        if self.current_channel == "analysis":
+            # Add text to buffer and check for 'assistantfinal' marker
+            self.buffer += text
+            if "assistantfinal" in self.buffer:
+                # Split reasoning and final content
+                before, after = self.buffer.split("assistantfinal", 1)
+                if before:
+                    segments.append({"channel": "analysis", "content": before})
+                # Switch to final channel
+                self.current_channel = "final"
+                self.buffer = ""
+                if after:
+                    segments.append({"channel": "final", "content": after})
+                return segments
+            else:
+                # Check if buffer ends with partial 'assistantfinal'
+                if any(
+                    self.buffer.endswith("assistantfinal"[:i])
+                    for i in range(1, len("assistantfinal") + 1)
+                ):
+                    # Don't emit anything yet, wait for more text
+                    return segments
+                else:
+                    # Emit what we have so far and keep buffer for next time
+                    if self.buffer:
+                        segments.append({"channel": "analysis", "content": self.buffer})
+                        self.buffer = ""
+                    return segments
+
+        # If we are currently in 'final' mode
+        if self.current_channel == "final":
+            # Check if this is actually a new message starting with 'analysis'
+            if text.startswith("analysis"):
+                # Reset parser state for new message
+                self.current_channel = None
+                self.buffer = ""
+                # Re-process this text with the new state
+                return self.feed(text)
+            else:
+                segments.append({"channel": "final", "content": text})
+                return segments
+
+        # If no channel has been started yet
+        if text.startswith("analysis"):
+            self.current_channel = "analysis"
+            rest = text[len("analysis") :]
+            if "assistantfinal" in rest:
+                # Split immediately if marker is found in the first chunk
+                before, after = rest.split("assistantfinal", 1)
+                if before:
+                    segments.append({"channel": "analysis", "content": before})
+                self.current_channel = "final"
+                if after:
+                    segments.append({"channel": "final", "content": after})
+            else:
+                # Start buffering for potential 'assistantfinal' marker
+                self.buffer = rest
+                # Check if buffer ends with partial 'assistantfinal'
+                if any(
+                    self.buffer.endswith("assistantfinal"[:i])
+                    for i in range(1, len("assistantfinal") + 1)
+                ):
+                    # Don't emit anything yet, wait for more text
+                    pass
+                else:
+                    # Emit what we have so far
+                    if self.buffer:
+                        segments.append({"channel": "analysis", "content": self.buffer})
+                        self.buffer = ""
+        elif text.startswith("assistantfinal"):
+            self.current_channel = "final"
+            rest = text[len("assistantfinal") :]
+            if rest:
+                segments.append({"channel": "final", "content": rest})
+
+        return segments
+
+
+async def async_stream_harmony_chat_completion(
+    chunks: Union[
+        "ChatCompletion",
+        AsyncGenerator["ChatCompletionChunk", None],
+    ],
+) -> AsyncGenerator["ChatCompletion", None]:
+    """
+    Parse Harmony-formatted content from either a full ChatCompletion (non-streaming)
+    or an async stream of ChatCompletionChunk (streaming), using the HarmonyStreamParser defined in this file.
+
+    Yields parsed objects incrementally.
+    """
+
+    # --- Non-streaming: ChatCompletion ---
+    if isinstance(chunks, dict) and chunks.get("object") == "chat.completion":
+        out_data = deepcopy(chunks)
+
+        for choice in out_data["choices"]:
+            parser = HarmonyStreamParser()
+            msg = choice["message"]
+
+            # Backup original content & reasoning
+            original_content = msg.get("content") or ""
+            original_reasoning = msg.get("reasoning_content") or ""
+
+            # Reset fields before parsing
+            msg["content"] = ""
+            msg["reasoning_content"] = ""
+            msg.setdefault("tool_calls", [])
+
+            # Feed original content
+            for seg in parser.feed(original_content):
+                ch, c = seg["channel"], seg["content"]
+                if ch == "final":
+                    msg["content"] += c
+                elif ch == "analysis":
+                    msg["reasoning_content"] += c
+                elif ch == "tool":
+                    msg["tool_calls"].append(c)
+
+            # Feed original reasoning_content
+            for seg in parser.feed(original_reasoning):
+                if seg["channel"] == "analysis":
+                    msg["reasoning_content"] += seg["content"]
+                elif seg["channel"] == "tool":
+                    msg["tool_calls"].append(seg["content"])
+
+            # Clean up reasoning_content: set to None if no reasoning content was parsed
+            if not msg["reasoning_content"] and not original_reasoning:
+                msg["reasoning_content"] = None  # type: ignore
+
+        yield out_data
+
+    else:
+        # Streaming: handle async generator
+        parsers_per_choice = {}
+
+        async for chunk in chunks:  # type: ignore
+            out_chunk = {  # type: ignore
+                "id": chunk["id"],
+                "model": chunk["model"],
+                "object": chunk["object"],
+                "created": chunk["created"],
+                "choices": [],
+            }
+
+            for i, choice in enumerate(chunk["choices"]):
+                delta = choice.get("delta", {})
+                text = delta.get("content") or ""  # type: ignore
+
+                if i not in parsers_per_choice:
+                    parsers_per_choice[i] = HarmonyStreamParser()
+
+                # Feed text to parser and collect current delta only
+                curr_delta: Dict[str, object] = {
+                    "content": "",
+                    "reasoning_content": "",
+                    "tool_calls": [],
+                }
+
+                for seg in parsers_per_choice[i].feed(text):
+                    ch = seg["channel"]
+                    c = seg["content"]
+                    if ch == "final":
+                        curr_delta["content"] += c  # type: ignore
+                    elif ch == "analysis":
+                        curr_delta["reasoning_content"] += c  # type: ignore
+                    elif ch == "tool":
+                        curr_delta["tool_calls"].append(c)  # type: ignore
+
+                if curr_delta["reasoning_content"]:
+                    if not curr_delta["content"]:
+                        curr_delta["content"] = None
+
+                elif curr_delta["content"]:
+                    if not curr_delta["reasoning_content"]:
+                        curr_delta["reasoning_content"] = None
+
+                elif (
+                    choice.get("finish_reason") is not None
+                    and not curr_delta["reasoning_content"]
+                ):
+                    # For the final chunk, if there's no new reasoning content,
+                    # don't include empty reasoning_content to avoid clearing existing state
+                    curr_delta["reasoning_content"] = None
+
+                out_chunk["choices"].append(  # type: ignore
+                    {
+                        "index": i,
+                        "delta": curr_delta,
+                        "finish_reason": choice.get("finish_reason"),
+                    }
+                )
+
+            # Only yield if we have either content or reasoning_content
+            has_content = any(
+                choice["delta"].get("content")  # type: ignore
+                or choice["delta"].get("reasoning_content")  # type: ignore
+                or choice.get("finish_reason") is not None  # type: ignore
+                for choice in out_chunk["choices"]  # type: ignore
+            )
+            if has_content:
+                yield out_chunk  # type: ignore
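A small usage sketch of the parser added above. It assumes the module is importable as xinference.model.llm.harmony once 1.9.0 is installed, and shows how Harmony-style output whose special tokens have been reduced to the literal 'analysis' / 'assistantfinal' markers gets split into channels; the chunk boundaries are arbitrary:

from xinference.model.llm.harmony import HarmonyStreamParser

parser = HarmonyStreamParser()
chunks = [
    "analysisThe user greets us; reply politely.",
    "assistant",  # partial 'assistantfinal' marker, buffered until completed
    "finalHello! How can I help?",
]

for chunk in chunks:
    for seg in parser.feed(chunk):
        print(seg["channel"], repr(seg["content"]))

# Prints roughly:
#   analysis 'The user greets us; reply politely.'
#   final 'Hello! How can I help?'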