xinference 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

xinference/_compat.py CHANGED
@@ -60,6 +60,10 @@ from openai.types.chat.chat_completion_stream_options_param import (
     ChatCompletionStreamOptionsParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.shared_params.response_format_json_object import (
+    ResponseFormatJSONObject,
+)
+from openai.types.shared_params.response_format_text import ResponseFormatText
 
 OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
     ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@ OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
 )
 
 
+class JSONSchema(BaseModel):
+    name: str
+    description: Optional[str] = None
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = None
+
+
+class ResponseFormatJSONSchema(BaseModel):
+    json_schema: JSONSchema
+    type: Literal["json_schema"]
+
+
+ResponseFormat = Union[
+    ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
+]
+
+
 class CreateChatCompletionOpenAI(BaseModel):
     """
     Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
     n: Optional[int]
     parallel_tool_calls: Optional[bool]
     presence_penalty: Optional[float]
-    # we do not support this
-    # response_format: ResponseFormat
+    response_format: Optional[ResponseFormat]
     seed: Optional[int]
     service_tier: Optional[Literal["auto", "default"]]
     stop: Union[Optional[str], List[str]]
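With response_format now accepted on CreateChatCompletionOpenAI, a chat request can ask for structured output. A minimal request sketch in Python, assuming a local Xinference server on port 9997 and an already-launched chat model with the uid "qwen2.5-instruct" (both assumptions; only the endpoint and field names come from the diff):

    import requests

    resp = requests.post(
        "http://localhost:9997/v1/chat/completions",
        json={
            "model": "qwen2.5-instruct",  # assumed model uid
            "messages": [{"role": "user", "content": "Reply with a JSON object containing a `city` key."}],
            # validated against the new ResponseFormat union (text / json_object / json_schema)
            "response_format": {"type": "json_object"},
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])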
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-15T17:33:11+0800",
+ "date": "2024-11-29T16:57:04+0800",
 "dirty": false,
 "error": null,
- "full-revisionid": "4c96475b8f90e354aa1b47856fda4db098b62b65",
- "version": "1.0.0"
+ "full-revisionid": "eb8ddd431f5c5fcb2216e25e0d43745f8455d9b9",
+ "version": "1.0.1"
 }
 ''' # END VERSION_JSON
 
@@ -489,6 +489,16 @@ class RESTfulAPI(CancelMixin):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/convert_ids_to_tokens",
+            self.convert_ids_to_tokens,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/rerank",
             self.rerank,
@@ -1219,6 +1229,9 @@ class RESTfulAPI(CancelMixin):
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1264,6 +1277,8 @@ class RESTfulAPI(CancelMixin):
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -1312,6 +1327,41 @@ class RESTfulAPI(CancelMixin):
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def convert_ids_to_tokens(self, request: Request) -> Response:
+        payload = await request.json()
+        body = CreateEmbeddingRequest.parse_obj(payload)
+        model_uid = body.model
+        exclude = {
+            "model",
+            "input",
+            "user",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
+            return Response(decoded_texts, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
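The new /v1/convert_ids_to_tokens route reuses CreateEmbeddingRequest for its payload, so the body carries a model uid and an input list of token IDs. A hedged call sketch (host, port and the "bge-m3" uid are assumptions):

    import requests

    resp = requests.post(
        "http://localhost:9997/v1/convert_ids_to_tokens",
        json={"model": "bge-m3", "input": [101, 2023, 2003, 102]},
    )
    resp.raise_for_status()
    print(resp.json())  # decoded tokens, returned as JSON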
@@ -1495,8 +1545,16 @@ class RESTfulAPI(CancelMixin):
             **parsed_kwargs,
         )
         if body.stream:
+
+            async def stream_results():
+                try:
+                    async for item in out:
+                        yield item
+                finally:
+                    await model.decrease_serve_count()
+
             return EventSourceResponse(
-                media_type="application/octet-stream", content=out
+                media_type="application/octet-stream", content=stream_results()
             )
         else:
             return Response(media_type="application/octet-stream", content=out)
@@ -1916,9 +1974,13 @@ class RESTfulAPI(CancelMixin):
             "logit_bias_type",
             "user",
         }
+
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -2027,6 +2089,8 @@ class RESTfulAPI(CancelMixin):
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -2279,6 +2343,27 @@ class RESTfulAPI(CancelMixin):
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    @staticmethod
+    def extract_guided_params(raw_body: dict) -> dict:
+        kwargs = {}
+        if raw_body.get("guided_json") is not None:
+            kwargs["guided_json"] = raw_body.get("guided_json")
+        if raw_body.get("guided_regex") is not None:
+            kwargs["guided_regex"] = raw_body.get("guided_regex")
+        if raw_body.get("guided_choice") is not None:
+            kwargs["guided_choice"] = raw_body.get("guided_choice")
+        if raw_body.get("guided_grammar") is not None:
+            kwargs["guided_grammar"] = raw_body.get("guided_grammar")
+        if raw_body.get("guided_json_object") is not None:
+            kwargs["guided_json_object"] = raw_body.get("guided_json_object")
+        if raw_body.get("guided_decoding_backend") is not None:
+            kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
+        if raw_body.get("guided_whitespace_pattern") is not None:
+            kwargs["guided_whitespace_pattern"] = raw_body.get(
+                "guided_whitespace_pattern"
+            )
+        return kwargs
+
 
 def run(
     supervisor_address: str,
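extract_guided_params copies vLLM-style guided-decoding keys straight from the raw request body into the kwargs forwarded to the model, so clients can attach them to /v1/completions or /v1/chat/completions. A sketch of such a request; whether a given key is honored depends on the launched engine, and the host, port, model uid and backend name are assumptions:

    import requests

    resp = requests.post(
        "http://localhost:9997/v1/chat/completions",
        json={
            "model": "qwen2.5-instruct",  # assumed model uid
            "messages": [{"role": "user", "content": "Pick red, green or blue."}],
            # picked up by extract_guided_params and passed through as kwargs
            "guided_choice": ["red", "green", "blue"],
            "guided_decoding_backend": "outlines",  # assumed backend name
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])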
@@ -126,6 +126,43 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], **kwargs
+    ) -> List[str]:
+        """
+        Convert token IDs to human readable tokens via RESTful APIs.
+
+        Parameters
+        ----------
+        input: Union[List, List[List]]
+            Input token IDs to convert, can be a single list of token IDs or a list of token ID lists.
+            To convert multiple sequences in a single request, pass a list of token ID lists.
+
+        Returns
+        -------
+        list
+            A list of decoded tokens in human readable format.
+
+        Raises
+        ------
+        RuntimeError
+            Report the failure of token conversion and provide the error message.
+
+        """
+        url = f"{self._base_url}/v1/convert_ids_to_tokens"
+        request_body = {
+            "model": self._model_uid,
+            "input": input,
+        }
+        request_body.update(kwargs)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to decode token ids, detail: {_get_error_string(response)}"
+            )
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulRerankModelHandle(RESTfulModelHandle):
     def rerank(
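On the client side the method lives on RESTfulEmbeddingModelHandle, so it is reached through an embedding model handle. A minimal usage sketch, assuming a server on localhost:9997 and an embedding model already launched under the uid "bge-m3" (both assumptions):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://localhost:9997")
    model = client.get_model("bge-m3")  # embedding model handle
    print(model.convert_ids_to_tokens([101, 2023, 2003, 102]))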
@@ -704,6 +741,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             The speed of the generated audio.
         stream: bool
             Use stream or not.
+        prompt_speech: bytes
+            The audio bytes to be provided to the model.
 
         Returns
         -------
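prompt_speech carries raw reference audio bytes for the TTS model. A sketch of passing it through the client speech() call; the model uid "FishSpeech-1.4", the file names and the mp3 output format are assumptions:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://localhost:9997")
    model = client.get_model("FishSpeech-1.4")  # assumed audio model uid
    with open("reference.wav", "rb") as f:
        prompt_speech = f.read()  # reference voice to imitate
    audio = model.speech("Hello from the cloned voice.", prompt_speech=prompt_speech)
    with open("output.mp3", "wb") as f:
        f.write(audio)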
xinference/core/model.py CHANGED
@@ -91,21 +91,26 @@ def request_limit(fn):
         logger.debug(
             f"Request {fn.__name__}, current serve request count: {self._serve_count}, request limit: {self._request_limits} for the model {self.model_uid()}"
         )
-        if self._request_limits is not None:
-            if 1 + self._serve_count <= self._request_limits:
-                self._serve_count += 1
-            else:
-                raise RuntimeError(
-                    f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
-                )
+        if 1 + self._serve_count <= self._request_limits:
+            self._serve_count += 1
+        else:
+            raise RuntimeError(
+                f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
+            )
+        ret = None
         try:
             ret = await fn(self, *args, **kwargs)
         finally:
-            if self._request_limits is not None:
+            if ret is not None and (
+                inspect.isasyncgen(ret) or inspect.isgenerator(ret)
+            ):
+                # stream case, let client call model_ref to decrease self._serve_count
+                pass
+            else:
                 self._serve_count -= 1
-                logger.debug(
-                    f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
-                )
+            logger.debug(
+                f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
+            )
         return ret
 
     return wrapped_func
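The reworked decorator no longer decrements the serve count when the wrapped call returns a generator or async generator; the caller is expected to release the slot via the new decrease_serve_count() once the stream is drained, which is exactly what the finally blocks added in the REST layer do. An illustration of that contract (the relay_stream helper is hypothetical; only decrease_serve_count comes from the diff):

    async def relay_stream(model_ref, agen):
        """Forward chunks from a streaming model call, then release the slot."""
        try:
            async for chunk in agen:
                yield chunk  # pass each chunk on to the HTTP client
        finally:
            # request_limit skipped the decrement because a generator was returned
            await model_ref.decrease_serve_count()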
@@ -215,7 +220,9 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         self._model_description = (
             model_description.to_dict() if model_description else {}
         )
-        self._request_limits = request_limits
+        self._request_limits = (
+            float("inf") if request_limits is None else request_limits
+        )
         self._pending_requests: asyncio.Queue = asyncio.Queue()
         self._handle_pending_requests_task = None
         self._lock = (
@@ -268,6 +275,9 @@ class ModelActor(xo.StatelessActor, CancelMixin):
     def __repr__(self) -> str:
         return f"ModelActor({self._replica_model_uid})"
 
+    def decrease_serve_count(self):
+        self._serve_count -= 1
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -794,6 +804,19 @@ class ModelActor(xo.StatelessActor, CancelMixin):
                 f"Model {self._model.model_spec} is not for creating embedding."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], *args, **kwargs
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "convert_ids_to_tokens"):
+            return await self._call_wrapper_json(
+                self._model.convert_ids_to_tokens, input, *args, **kwargs
+            )
+
+        raise AttributeError(f"Model {self._model.model_spec} can convert token id.")
+
     @request_limit
     @log_async(logger=logger)
     async def rerank(
@@ -15,6 +15,8 @@
 import codecs
 import json
 import os
+import platform
+import sys
 import warnings
 from typing import Any, Dict
 
@@ -55,6 +57,14 @@ def register_custom_model():
                 warnings.warn(f"{user_defined_audio_dir}/{f} has error, {e}")
 
 
+def _need_filter(spec: dict):
+    if (sys.platform != "darwin" or platform.processor() != "arm") and spec.get(
+        "engine", ""
+    ).upper() == "MLX":
+        return True
+    return False
+
+
 def _install():
     _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
     _model_spec_modelscope_json = os.path.join(
@@ -64,6 +74,7 @@ def _install():
         dict(
             (spec["model_name"], AudioModelFamilyV1(**spec))
             for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+            if not _need_filter(spec)
        )
     )
     for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
@@ -75,6 +86,7 @@ def _install():
             for spec in json.load(
                 codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
             )
+            if not _need_filter(spec)
         )
     )
     for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
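_need_filter hides MLX-engine audio specs everywhere except macOS on Apple silicon, so those entries simply never register on other platforms. A small self-contained illustration of the gate (the spec dict is hypothetical; the stdlib calls mirror the filter above):

    import platform
    import sys

    spec = {"model_name": "whisper-large-v3-mlx", "engine": "MLX"}  # hypothetical spec
    apple_silicon = sys.platform == "darwin" and platform.processor() == "arm"
    skip = (not apple_silicon) and spec.get("engine", "").upper() == "MLX"
    print("skip registration:", skip)  # True everywhere except Apple-silicon macOS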
@@ -24,6 +24,7 @@ from .cosyvoice import CosyVoiceModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
+from .whisper_mlx import WhisperMLXModel
 
 logger = logging.getLogger(__name__)
 
@@ -43,11 +44,12 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_family: str
     model_name: str
     model_id: str
-    model_revision: str
+    model_revision: Optional[str]
     multilingual: bool
     model_ability: Optional[str]
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]
+    engine: Optional[str]
 
 
 class AudioModelDescription(ModelDescription):
@@ -160,17 +162,32 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
+    Union[
+        WhisperModel,
+        WhisperMLXModel,
+        FunASRModel,
+        ChatTTSModel,
+        CosyVoiceModel,
+        FishSpeechModel,
+    ],
     AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
     model: Union[
-        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+        WhisperModel,
+        WhisperMLXModel,
+        FunASRModel,
+        ChatTTSModel,
+        CosyVoiceModel,
+        FishSpeechModel,
     ]
     if model_spec.model_family == "whisper":
-        model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+        if not model_spec.engine:
+            model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+        else:
+            model = WhisperMLXModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
         model = FunASRModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
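create_audio_model_instance now dispatches on the new engine field: whisper specs without an engine keep WhisperModel, while specs that declare one (e.g. MLX) load through WhisperMLXModel. A launch sketch via the client; the model name "whisper-large-v3-mlx" is hypothetical and, per the filter above, such a spec only registers on Apple-silicon macOS:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://localhost:9997")
    uid = client.launch_model(model_name="whisper-large-v3-mlx", model_type="audio")
    model = client.get_model(uid)
    with open("speech.wav", "rb") as f:
        print(model.transcriptions(f.read()))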
@@ -81,12 +81,14 @@ class FishSpeechModel:
         if not is_device_available(self._device):
             raise ValueError(f"Device {self._device} is not available!")
 
-        logger.info("Loading Llama model...")
+        enable_compile = self._kwargs.get("compile", False)
+        precision = self._kwargs.get("precision", torch.bfloat16)
+        logger.info("Loading Llama model, compile=%s...", enable_compile)
         self._llama_queue = launch_thread_safe_queue(
             checkpoint_path=self._model_path,
             device=self._device,
-            precision=torch.bfloat16,
-            compile=False,
+            precision=precision,
+            compile=enable_compile,
         )
         logger.info("Llama model loaded, loading VQ-GAN model...")
 
@@ -112,9 +114,10 @@ class FishSpeechModel:
         top_p,
         repetition_penalty,
         temperature,
+        seed="0",
         streaming=False,
     ):
-        from fish_speech.utils import autocast_exclude_mps
+        from fish_speech.utils import autocast_exclude_mps, set_seed
         from tools.api import decode_vq_tokens, encode_reference
         from tools.llama.generate import (
             GenerateRequest,
@@ -122,6 +125,11 @@ class FishSpeechModel:
             WrappedGenerateResponse,
         )
 
+        seed = int(seed)
+        if seed != 0:
+            set_seed(seed)
+            logger.warning(f"set seed: {seed}")
+
         # Parse reference audio aka prompt
         prompt_tokens = encode_reference(
             decoder_model=self._model,
@@ -137,7 +145,7 @@ class FishSpeechModel:
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             temperature=temperature,
-            compile=False,
+            compile=self._kwargs.get("compile", False),
             iterative_prompt=chunk_length > 0,
             chunk_length=chunk_length,
             max_length=2048,
@@ -153,22 +161,20 @@ class FishSpeechModel:
             )
         )
 
-        if streaming:
-            yield wav_chunk_header(), None, None
-
         segments = []
 
         while True:
-            result: WrappedGenerateResponse = response_queue.get()  # type: ignore
+            result: WrappedGenerateResponse = response_queue.get()
             if result.status == "error":
-                raise Exception(str(result.response))
+                raise result.response
 
-            result: GenerateResponse = result.response  # type: ignore
+            result: GenerateResponse = result.response
             if result.action == "next":
                 break
 
             with autocast_exclude_mps(
-                device_type=self._model.device.type, dtype=torch.bfloat16
+                device_type=self._model.device.type,
+                dtype=self._kwargs.get("precision", torch.bfloat16),
             ):
                 fake_audios = decode_vq_tokens(
                     decoder_model=self._model,
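Loading and sampling options for FishSpeech (compile, precision, and the new seed) are now read from the model kwargs instead of being hard-coded. A launch sketch, assuming extra kwargs passed at launch time end up in the model's self._kwargs (the key names come from the diff; the model name and the kwargs plumbing are assumptions):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://localhost:9997")
    uid = client.launch_model(
        model_name="FishSpeech-1.4",  # assumed builtin name
        model_type="audio",
        compile=True,  # forwarded into launch_thread_safe_queue(compile=...)
    )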
@@ -179,7 +185,7 @@ class FishSpeechModel:
             segments.append(fake_audios)
 
             if streaming:
-                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+                yield fake_audios, None, None
 
         if len(segments) == 0:
             raise Exception("No audio generated, please check the input text.")
@@ -204,29 +210,58 @@ class FishSpeechModel:
             logger.warning("Fish speech does not support setting voice: %s.", voice)
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
-        if stream is True:
-            logger.warning("stream mode is not implemented.")
         import torchaudio
 
-        result = list(
-            self._inference(
-                text=input,
-                enable_reference_audio=False,
-                reference_audio=None,
-                reference_text=kwargs.get("reference_text", ""),
-                max_new_tokens=kwargs.get("max_new_tokens", 1024),
-                chunk_length=kwargs.get("chunk_length", 200),
-                top_p=kwargs.get("top_p", 0.7),
-                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
-                temperature=kwargs.get("temperature", 0.7),
-            )
+        prompt_speech = kwargs.get("prompt_speech")
+        result = self._inference(
+            text=input,
+            enable_reference_audio=kwargs.get(
+                "enable_reference_audio", prompt_speech is not None
+            ),
+            reference_audio=prompt_speech,
+            reference_text=kwargs.get("reference_text", ""),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
+            chunk_length=kwargs.get("chunk_length", 200),
+            top_p=kwargs.get("top_p", 0.7),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+            temperature=kwargs.get("temperature", 0.7),
+            streaming=stream,
         )
-        sample_rate, audio = result[0][1]
-        audio = np.array([audio])
 
-        # Save the generated audio
-        with BytesIO() as out:
-            torchaudio.save(
-                out, torch.from_numpy(audio), sample_rate, format=response_format
-            )
-            return out.getvalue()
+        if stream:
+
+            def _stream_generator():
+                with BytesIO() as out:
+                    writer = torchaudio.io.StreamWriter(out, format=response_format)
+                    writer.add_audio_stream(
+                        sample_rate=self._model.spec_transform.sample_rate,
+                        num_channels=1,
+                    )
+                    i = 0
+                    last_pos = 0
+                    with writer.open():
+                        for chunk in result:
+                            chunk = chunk[0]
+                            if chunk is not None:
+                                chunk = chunk.reshape((chunk.shape[0], 1))
+                                trans_chunk = torch.from_numpy(chunk)
+                                writer.write_audio_chunk(i, trans_chunk)
+                                new_last_pos = out.tell()
+                                if new_last_pos != last_pos:
+                                    out.seek(last_pos)
+                                    encoded_bytes = out.read()
+                                    yield encoded_bytes
+                                    last_pos = new_last_pos
+
+            return _stream_generator()
+        else:
+            result = list(result)
+            sample_rate, audio = result[0][1]
+            audio = np.array([audio])
+
+            # Save the generated audio
+            with BytesIO() as out:
+                torchaudio.save(
+                    out, torch.from_numpy(audio), sample_rate, format=response_format
+                )
+                return out.getvalue()
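With stream=True, speech() now returns a generator of container-encoded audio chunks produced incrementally by torchaudio.io.StreamWriter rather than one final buffer. A consumption sketch; the model uid is assumed, and the client yielding byte chunks for stream=True is inferred from the octet-stream streaming path added in the REST layer above:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://localhost:9997")
    model = client.get_model("FishSpeech-1.4")  # assumed audio model uid
    with open("streamed.mp3", "wb") as f:
        for chunk in model.speech("Streaming text to speech.", stream=True):
            f.write(chunk)  # each chunk is already encoded by StreamWriter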