ojin-client 0.1.7.dev8__tar.gz → 0.1.7.dev10__tar.gz

This diff shows the contents of two publicly available package versions, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ojin-client
3
- Version: 0.1.7.dev8
3
+ Version: 0.1.7.dev10
4
4
  Summary: Ojin platform services
5
5
  Author: Journee
6
6
  License: Apache-2.0
@@ -1,4 +1,4 @@
1
- """WebSocket client for OJIN Persona service."""
1
+ """WebSocket client for OJIN Persona service with optimized cancellation handling."""
2
2
 
3
3
  import asyncio
4
4
  import contextlib
@@ -87,9 +87,12 @@ class OjinPersonaClient(IOjinPersonaClient):
87
87
  self._receive_task: Optional[asyncio.Task] = None
88
88
  self._inference_server_ready: bool = False
89
89
  self._cancelled: bool = False
90
- self.active_interaction_id: str | None = None
90
+ self._active_interaction_id: str | None = None
91
91
  self._split_audio_task: Optional[asyncio.Task] = None
92
92
  self._audio_queue: asyncio.Queue[OjinPersonaInteractionInputMessage] = asyncio.Queue()
93
+
94
+ # Add cancellation event for immediate stopping
95
+ self._cancel_event = asyncio.Event()
93
96
 
94
97
  async def connect(self) -> None:
95
98
  """Establish WebSocket connection and authenticate with the service."""
@@ -135,7 +138,8 @@ class OjinPersonaClient(IOjinPersonaClient):
135
138
  pass
136
139
 
137
140
  self._running = False
138
- self.active_interaction_id = None
141
+ self._active_interaction_id = None
142
+ self._cancel_event.set() # Signal cancellation to all tasks
139
143
 
140
144
  if self._ws:
141
145
  try:
@@ -156,7 +160,6 @@ class OjinPersonaClient(IOjinPersonaClient):
156
160
  await self._receive_task
157
161
  self._receive_task = None
158
162
 
159
-
160
163
  logger.info("Disconnected from OJIN Persona service")
161
164
 
162
165
  async def _receive_messages(self) -> None:
@@ -200,8 +203,9 @@ class OjinPersonaClient(IOjinPersonaClient):
200
203
  )
201
204
  logger.debug("Received InteractionResponse for id %s", interaction_response.interaction_id)
202
205
 
203
- if interaction_response.interaction_id == self._cancelled_interaction_id:
204
- logger.warning("Message From old interaction")
206
+ # TODO: Possibly want to delete
207
+ if interaction_response.interaction_id != self._active_interaction_id:
208
+ logger.warning("Message From other interaction")
205
209
  return
206
210
  await self._message_queue.put(interaction_response)
207
211
  return
@@ -282,73 +286,55 @@ class OjinPersonaClient(IOjinPersonaClient):
282
286
  raise ConnectionError("Infernece Server is not ready to receive messsages")
283
287
 
284
288
  if isinstance(message, OjinPersonaCancelInteractionMessage):
285
- logger.info("Interrupt")
289
+ logger.info("Interrupt - Processing cancellation immediately")
286
290
 
291
+ # Set cancellation flag and event immediately
287
292
  self._cancelled = True
293
+ self._cancel_event.set()
294
+
295
+ # Send cancellation message with high priority
288
296
  cancel_input = CancelInteractionMessage(
289
297
  payload=message.to_proxy_message()
290
298
  )
291
299
 
292
- await self._ws.send(cancel_input.model_dump_json())
293
-
294
- logger.info(f"Message sent {message.interaction_id}")
300
+ # Send immediately without waiting
301
+ try:
302
+ await self._ws.send(cancel_input.model_dump_json())
303
+ logger.info(f"Cancellation message sent immediately for {message.interaction_id}")
304
+ except Exception as e:
305
+ logger.error(f"Failed to send cancellation message: {e}")
295
306
 
296
- while not self._message_queue.empty():
297
- try:
298
- self._message_queue.get_nowait()
299
- except asyncio.QueueEmpty:
300
- break
307
+ # Clear queues quickly without blocking
308
+ self._clear_queues_non_blocking()
301
309
 
310
+ # Reset cancellation state
302
311
  self._cancelled = False
312
+ self._cancel_event.clear()
303
313
 
304
314
  return
305
315
 
306
316
  if isinstance(message, StartInteractionMessage):
307
317
  interaction_id = str(uuid.uuid4())
308
- self.active_interaction_id = interaction_id
318
+ self._active_interaction_id = interaction_id
309
319
  logger.info("Generate UUID %s", interaction_id)
310
320
  interaction_response = StartInteractionResponseMessage(
311
321
  interaction_id=interaction_id
312
322
  )
313
- while not self._message_queue.empty():
314
- await self._message_queue.get()
323
+ # Clear queues non-blocking
324
+ self._clear_queues_non_blocking()
315
325
  self._message_queue.put_nowait(interaction_response)
316
326
  return
317
327
 
318
328
  if isinstance(message, OjinPersonaInteractionInputMessage):
319
329
  logger.info("InteractionMessage")
320
330
  logger.info(f"Message sent {message.interaction_id}")
321
- if message.interaction_id != self.active_interaction_id:
331
+ if message.interaction_id != self._active_interaction_id:
322
332
  return
323
333
 
324
334
  if not message.audio_int16_bytes:
325
335
  raise ValueError("Audio cannot be empty")
326
336
 
327
337
  await self._audio_queue.put(message)
328
- # Split audio bytes into chunks of max 3200 samples
329
- # max_chunk_size = 3200 * 2
330
- # audio_chunks = [
331
- # message.audio_int16_bytes[i : i + max_chunk_size]
332
- # for i in range(0, len(message.audio_int16_bytes), max_chunk_size)
333
- # ]
334
- # logger.info(
335
- # "Split audio into %d chunks of max %d bytes",
336
- # len(audio_chunks), max_chunk_size
337
- # )
338
-
339
- # for i, chunk in enumerate(audio_chunks):
340
- # is_last = i == len(audio_chunks) - 1 and message.is_last_input
341
- #
342
- # interaction_input = InteractionInput(
343
- # interaction_id=message.interaction_id,
344
- # is_final_input=is_last,
345
- # payload_type="audio",
346
- # payload=chunk,
347
- # timestamp=int(time.monotonic() * 1000),
348
- # params=message.params if i == 0 else None,
349
- # )
350
- # proxy_message = InteractionInputMessage(payload=interaction_input)
351
- # await self._ws.send(proxy_message.to_bytes())
352
338
  return
353
339
 
354
340
  logger.error("The message %s is Unknown", message)
@@ -364,18 +350,73 @@ class OjinPersonaClient(IOjinPersonaClient):
364
350
  )
365
351
  raise Exception(error)
366
352
 
367
- async def _split_audio(self) -> None:
353
+ def _clear_queues_non_blocking(self) -> None:
354
+ """Clear all queues without blocking."""
355
+ # Clear message queue
368
356
  while True:
369
- message_audio: OjinPersonaInteractionInputMessage| None = None
370
- if self._cancelled:
371
- continue
357
+ try:
358
+ self._message_queue.get_nowait()
359
+ except asyncio.QueueEmpty:
360
+ break
361
+
362
+ # Clear audio queue
363
+ while True:
364
+ try:
365
+ self._audio_queue.get_nowait()
366
+ except asyncio.QueueEmpty:
367
+ break
372
368
 
369
+ async def _split_audio(self) -> None:
370
+ """Split audio into chunks and send them, with cancellation support."""
371
+ while self._running:
372
+ message_audio: OjinPersonaInteractionInputMessage | None = None
373
+
373
374
  try:
374
- message_audio = self._audio_queue.get_nowait()
375
+ # Use wait_for with cancellation event to make this interruptible
376
+ wait_tasks = [
377
+ asyncio.create_task(self._audio_queue.get()),
378
+ asyncio.create_task(self._cancel_event.wait())
379
+ ]
380
+
381
+ done, pending = await asyncio.wait(
382
+ wait_tasks,
383
+ return_when=asyncio.FIRST_COMPLETED,
384
+ timeout=0.1 # Short timeout to check cancellation frequently
385
+ )
386
+
387
+ # Cancel pending tasks
388
+ for task in pending:
389
+ task.cancel()
390
+ with contextlib.suppress(asyncio.CancelledError):
391
+ await task
392
+
393
+ # Check if cancellation was triggered
394
+ if self._cancelled or self._cancel_event.is_set():
395
+ logger.info("Audio splitting cancelled")
396
+ continue
397
+
398
+ # Check if we got a message
399
+ if done:
400
+ completed_task = done.pop()
401
+ if completed_task == wait_tasks[0]: # Audio queue task completed
402
+ message_audio = completed_task.result()
403
+ else: # Cancellation event was set
404
+ continue
405
+ else:
406
+ # Timeout occurred, continue loop
407
+ continue
408
+
375
409
  except asyncio.QueueEmpty:
376
410
  await asyncio.sleep(0.01)
377
411
  continue
412
+ except Exception as e:
413
+ logger.error(f"Error getting audio message: {e}")
414
+ continue
378
415
 
416
+ if not message_audio:
417
+ continue
418
+
419
+ # Process audio chunks with cancellation checks
379
420
  max_chunk_size = 3200 * 2
380
421
  audio_chunks = [
381
422
  message_audio.audio_int16_bytes[i : i + max_chunk_size]
@@ -387,6 +428,11 @@ class OjinPersonaClient(IOjinPersonaClient):
387
428
  )
388
429
 
389
430
  for i, chunk in enumerate(audio_chunks):
431
+ # Check for cancellation before each chunk
432
+ if self._cancelled or self._cancel_event.is_set():
433
+ logger.info("Audio chunk sending cancelled")
434
+ break
435
+
390
436
  is_last = i == len(audio_chunks) - 1 and message_audio.is_last_input
391
437
 
392
438
  interaction_input = InteractionInput(
@@ -399,8 +445,11 @@ class OjinPersonaClient(IOjinPersonaClient):
399
445
  )
400
446
  proxy_message = InteractionInputMessage(payload=interaction_input)
401
447
 
402
- await self._ws.send(proxy_message.to_bytes())
403
-
448
+ try:
449
+ await self._ws.send(proxy_message.to_bytes())
450
+ except Exception as e:
451
+ logger.error(f"Failed to send audio chunk: {e}")
452
+ break
404
453
 
405
454
  async def receive_message(self) -> BaseModel | None:
406
455
  """Receive the next message from the OJIN Persona service.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ojin-client
3
- Version: 0.1.7.dev8
3
+ Version: 0.1.7.dev10
4
4
  Summary: Ojin platform services
5
5
  Author: Journee
6
6
  License: Apache-2.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ojin-client"
3
- version = "0.1.7dev8"
3
+ version = "0.1.7dev10"
4
4
  description = "Ojin platform services"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"