streamlit-octostar-utils 0.4.2.dev25__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/PKG-INFO +1 -1
  2. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/pyproject.toml +1 -1
  3. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/celery.py +287 -72
  4. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/fastapi.py +1 -97
  5. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/nifi.py +442 -81
  6. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/ner.py +214 -83
  7. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/inheritance.py +5 -5
  8. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/LICENSE +0 -0
  9. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/README.md +0 -0
  10. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/__init__.py +0 -0
  11. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/__init__.py +0 -0
  12. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/contents.py +0 -0
  13. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parallelism.py +0 -0
  14. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/__init__.py +0 -0
  15. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/combine_fields.py +0 -0
  16. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/entities_parser.py +0 -0
  17. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/generics.py +0 -0
  18. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/info.py +0 -0
  19. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/linkchart_functions.py +0 -0
  20. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/matches.py +0 -0
  21. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/parameters.py +0 -0
  22. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/rules.py +0 -0
  23. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/signals.py +0 -0
  24. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/__init__.py +0 -0
  25. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/dict.py +0 -0
  26. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/filetypes.py +0 -0
  27. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/threading/__init__.py +0 -0
  28. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/threading/key_queue.py +0 -0
  29. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/timestamp.py +0 -0
  30. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/__init__.py +0 -0
  31. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/custom_recognizers.py +0 -0
  32. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/language.py +0 -0
  33. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/__init__.py +0 -0
  34. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/client.py +0 -0
  35. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/context.py +0 -0
  36. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/permissions.py +0 -0
  37. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/__init__.py +0 -0
  38. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/relationships.py +0 -0
  39. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/validation.py +0 -0
  40. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/style/__init__.py +0 -0
  41. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/style/common.py +0 -0
  42. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/__init__.py +0 -0
  43. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/async_task_manager.py +0 -0
  44. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/session_callback_manager.py +0 -0
  45. {streamlit_octostar_utils-0.4.2.dev25 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/session_state_hot_swapper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: streamlit-octostar-utils
3
- Version: 0.4.2.dev25
3
+ Version: 0.5.0
4
4
  Summary:
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -5,7 +5,7 @@ include = '\.pyi?$'
5
5
 
6
6
  [tool.poetry]
7
7
  name = "streamlit-octostar-utils"
8
- version = "0.4.2-dev.25"
8
+ version = "0.5.0"
9
9
  description = ""
10
10
  license = "MIT"
11
11
  authors = ["Octostar"]
@@ -9,12 +9,10 @@ import subprocess
9
9
  from fastapi import Query
10
10
  import time
11
11
  import os
12
- import pickle
13
12
  import atexit
14
13
  import redis
15
14
  import uuid
16
15
  import json
17
- import hashlib
18
16
  import shutil
19
17
  import threading
20
18
  from pottery import Redlock
@@ -57,18 +55,40 @@ class CeleryQueueConfig:
57
55
  max_tasks_in_queue=None,
58
56
  max_tasks_per_child=None,
59
57
  max_memory_per_child=None,
60
- stall_timeout=1200,
61
58
  **options,
62
59
  ):
63
60
  self.n_workers = n_workers
64
61
  self.max_tasks_in_queue = max_tasks_in_queue
65
62
  self.max_tasks_per_child = max_tasks_per_child
66
63
  self.max_memory_per_child = max_memory_per_child # KiB
67
- self.stall_timeout = stall_timeout # seconds; None or 0 to disable
68
64
  self.options = options
69
65
 
70
66
 
67
+ class TaskResult:
68
+ """Wrapper for task results that include binary parts alongside JSON data.
69
+ Tasks returning binary data (e.g. images) should return a TaskResult
70
+ so that serialized_io writes them as multipart parts instead of attempting
71
+ JSON serialization on bytes."""
72
+
73
+ def __init__(self, data, part=None):
74
+ self.data = data
75
+ self.part = part
76
+
77
+
71
78
  class CelerySerialized:
79
+ """Serializes task data to a boundary-delimited multipart file.
80
+
81
+ Format: metadata JSON part followed by optional binary/streamed parts,
82
+ separated by boundary markers (the task_id). Replaces pickle entirely.
83
+ """
84
+
85
+ BOUNDARY_PREFIX = b"--"
86
+ BOUNDARY_SUFFIX = b"\r\n"
87
+ BOUNDARY_END = b"--\r\n"
88
+ CONTENT_TYPE_JSON = b"Content-Type: application/json\r\n"
89
+ CONTENT_TYPE_BYTES = b"Content-Type: application/octet-stream\r\n"
90
+ HEADER_END = b"\r\n"
91
+
72
92
  def __init__(self, folder, redis_client, data=None):
73
93
  self.folder = folder
74
94
  self.data = data
@@ -77,18 +97,70 @@ class CelerySerialized:
77
97
  def set_task_id(self, task_id):
78
98
  self.task_id = task_id
79
99
 
80
- def dump(self):
100
+ def _boundary(self):
101
+ return self.task_id.encode()
102
+
103
+ def _write_boundary(self, f):
104
+ f.write(self.BOUNDARY_PREFIX + self._boundary() + self.BOUNDARY_SUFFIX)
105
+
106
+ def _write_end_boundary(self, f):
107
+ f.write(self.BOUNDARY_PREFIX + self._boundary() + self.BOUNDARY_END)
108
+
109
+ def dump(self, parts=None, part_is_list=False):
110
+ """Write metadata + optional parts in multipart format.
111
+ parts: optional list of bytes objects to write as additional parts.
112
+ part_is_list: whether the original part was a list (preserves type on read).
113
+ """
81
114
  assert self.task_id
115
+ if isinstance(self.data, dict):
116
+ metadata = self.data
117
+ else:
118
+ metadata = {"data": self.data}
119
+ metadata["part_count"] = len(parts) if parts else 0
120
+ metadata["part_is_list"] = part_is_list
82
121
  with RedisFileLock(self.redis_client, os.path.join(self.folder, self.task_id)):
83
- with open(os.path.join(self.folder, self.task_id), "wb") as target_file:
84
- pickle.dump(self.data, file=target_file, protocol=pickle.HIGHEST_PROTOCOL)
122
+ with open(os.path.join(self.folder, self.task_id), "wb") as f:
123
+ self._write_boundary(f)
124
+ f.write(self.CONTENT_TYPE_JSON)
125
+ f.write(self.HEADER_END)
126
+ f.write(json.dumps(metadata).encode())
127
+ f.write(b"\r\n")
128
+ if parts:
129
+ for part in parts:
130
+ self._write_boundary(f)
131
+ f.write(self.CONTENT_TYPE_BYTES)
132
+ f.write(self.HEADER_END)
133
+ f.write(part)
134
+ f.write(b"\r\n")
135
+ self._write_end_boundary(f)
85
136
 
86
137
  def load(self):
138
+ """Read multipart file. Returns (metadata_dict, list_of_bytes_parts)."""
87
139
  assert self.task_id
140
+ boundary = self.BOUNDARY_PREFIX + self._boundary()
141
+ end_boundary = self.BOUNDARY_PREFIX + self._boundary() + b"--"
88
142
  with RedisFileLock(self.redis_client, os.path.join(self.folder, self.task_id)):
89
- with open(os.path.join(self.folder, self.task_id), "rb") as source_file:
90
- data = pickle.load(source_file)
91
- return data
143
+ with open(os.path.join(self.folder, self.task_id), "rb") as f:
144
+ raw = f.read()
145
+ sections = raw.split(boundary)
146
+ metadata = None
147
+ parts = []
148
+ for section in sections:
149
+ section = section.strip(b"\r\n")
150
+ if not section or section == b"--":
151
+ continue
152
+ header_end = section.find(b"\r\n\r\n")
153
+ if header_end == -1:
154
+ continue
155
+ header = section[:header_end]
156
+ body = section[header_end + 4:]
157
+ if body.endswith(b"\r\n"):
158
+ body = body[:-2]
159
+ if b"application/json" in header:
160
+ metadata = json.loads(body)
161
+ else:
162
+ parts.append(body)
163
+ return metadata or {}, parts
92
164
 
93
165
 
94
166
  class CeleryExecutor(object):
@@ -150,7 +222,7 @@ class CeleryExecutor(object):
150
222
  self.app = Celery(self.filename)
151
223
  self.app.conf.broker_url = f"redis://{self.redis_host}:{self.redis_port}/0"
152
224
  self.app.conf.result_backend = f"redis://{self.redis_host}:{self.redis_port}/0"
153
- self.app.conf.track_started = True
225
+ self.app.conf.track_started = False
154
226
  self.app.conf.task_serializer = "json"
155
227
  self.app.conf.result_serializer = "json"
156
228
  self.app.conf.accept_content = ["application/json"]
@@ -182,11 +254,11 @@ class CeleryExecutor(object):
182
254
  self.worker_info = {}
183
255
 
184
256
  # Queue stall detection
185
- self._queue_fingerprints = {}
186
- self._queue_fingerprint_changed_at = {}
187
257
  self._queue_stalled = {}
258
+ self._last_stall_check = 0
259
+ self._stall_check_interval = 60
188
260
 
189
- atexit.register(self.close)
261
+ self._is_owner = False
190
262
  self.set_cleanup_task()
191
263
  self.register_state_signals()
192
264
 
@@ -216,28 +288,44 @@ class CeleryExecutor(object):
216
288
  def set_started_state(self, task_id, task, *args, **kwargs):
217
289
  result = AsyncResult(task_id, app=self.app)
218
290
  result.backend.store_result(task_id, result=None, state=CeleryExecutor.STARTED)
291
+ queue = task.request.delivery_info.get(
292
+ "routing_key", self.app.conf.task_default_routing_key
293
+ ) if task else None
294
+ if queue:
295
+ self.redis_client.set(f"queue:last_started:{queue}", str(time.time()))
296
+ request_timelimit = getattr(getattr(task, "request", None), "timelimit", None) or (None, None)
297
+ time_limit = request_timelimit[0] or getattr(task, "time_limit", None) or 0
298
+ extended_ttl = int(time_limit) + int(self.app.conf.result_expires)
299
+ if extended_ttl > int(self.app.conf.result_expires):
300
+ self.redis_client.eval(
301
+ "for _, k in ipairs(KEYS) do redis.call('expire', k, ARGV[1]) end",
302
+ 2,
303
+ f"{CeleryExecutor.CELERY_BROKER_PREFIX}{task_id}",
304
+ f"task:queue:{task_id}",
305
+ extended_ttl,
306
+ )
219
307
 
220
308
  def register_worker_initialization(self):
221
309
  if self.preload_functions:
222
310
  celery_signals.worker_process_init.connect(self.preload_on_worker_init)
223
311
 
224
- def set_last_completed_time(self, sender=None, task_id=None, task=None, **kwargs):
312
+ def cleanup_task_keys(self, sender=None, task_id=None, **kwargs):
225
313
  try:
226
- queue = task.request.delivery_info.get(
227
- "routing_key", self.app.conf.task_default_routing_key
228
- ) if task else None
229
- if queue:
230
- self.redis_client.set(f"queue:last_completed:{queue}", str(time.time()))
314
+ if task_id:
315
+ self.redis_client.delete(f"task:queue:{task_id}")
231
316
  except Exception:
232
317
  pass
233
318
 
234
319
  def register_state_signals(self):
235
320
  celery_signals.before_task_publish.connect(self.set_awaiting_state)
236
321
  celery_signals.task_prerun.connect(self.set_started_state)
237
- celery_signals.task_postrun.connect(self.set_last_completed_time)
322
+ celery_signals.task_postrun.connect(self.cleanup_task_keys)
238
323
 
239
324
  def cleanup_task_results(in_dir, out_dir, redis_host, redis_port, task_expires, result_expires):
240
325
  logger.info("Starting cleanup of expired task results...")
326
+ if not os.path.isdir(in_dir) or not os.path.isdir(out_dir):
327
+ logger.warning(f"Data directories missing (in={os.path.isdir(in_dir)}, out={os.path.isdir(out_dir)}), skipping cleanup")
328
+ return
241
329
  redis_client = redis.StrictRedis(host=redis_host, port=redis_port)
242
330
  for file_name in os.listdir(in_dir):
243
331
  file_path = os.path.join(in_dir, file_name)
@@ -307,6 +395,8 @@ class CeleryExecutor(object):
307
395
  }
308
396
 
309
397
  def start(self):
398
+ self._is_owner = True
399
+ atexit.register(self.close)
310
400
  logger.info("Initializing data folders...")
311
401
  shutil.rmtree(self.root_folder, ignore_errors=True)
312
402
  for folder in [self.in_folder, self.out_folder]:
@@ -333,6 +423,9 @@ class CeleryExecutor(object):
333
423
  if attempts_done == CeleryExecutor.MAX_STARTUP_CHECKS:
334
424
  raise TimeoutError("Redis not ready after a long wait!")
335
425
  self.redis_client.flushall()
426
+ boot_time = str(time.time())
427
+ for queue_name in self.queue_config:
428
+ self.redis_client.set(f"queue:last_started:{queue_name}", boot_time)
336
429
  for queue, queue_config in self.queue_config.items():
337
430
  for slot in range(queue_config.n_workers):
338
431
  worker_name = f"celery@{self.name}:{queue}:{slot}"
@@ -407,7 +500,10 @@ class CeleryExecutor(object):
407
500
  while not self.stop_event.is_set():
408
501
  try:
409
502
  self._restart_dead_processes()
410
- self._check_queue_stalls()
503
+ now = time.time()
504
+ if now - self._last_stall_check >= self._stall_check_interval:
505
+ self._last_stall_check = now
506
+ self._check_queue_stalls()
411
507
  time.sleep(5)
412
508
  except Exception as e:
413
509
  logger.error(f"Error in worker health check: {e}")
@@ -442,44 +538,56 @@ class CeleryExecutor(object):
442
538
  logger.info(f"Restarted beat process (PID: {self.beat_process.pid})")
443
539
 
444
540
  def _check_queue_stalls(self):
445
- for queue_name, queue_config in self.queue_config.items():
446
- if not queue_config.stall_timeout:
447
- continue
541
+ try:
542
+ inspector = self.app.control.inspect(timeout=5)
543
+ active_data = inspector.active() or {}
544
+ except Exception as e:
545
+ logger.error(f"Failed to inspect active workers: {e}")
546
+ return
547
+
548
+ now = time.time()
549
+ grace_period = 3 * self.app.conf.worker_proc_alive_timeout
550
+
551
+ active_queues = set()
552
+ for _worker_name, tasks in active_data.items():
553
+ for task_info in (tasks or []):
554
+ queue = (task_info.get("delivery_info") or {}).get(
555
+ "routing_key", self.app.conf.task_default_routing_key
556
+ )
557
+ active_queues.add(queue)
558
+
559
+ for queue_name in self.queue_config:
448
560
  try:
449
- queue_items = self.redis_client.lrange(queue_name, 0, -1)
450
- if len(queue_items) == 0:
561
+ pending_count = self.redis_client.llen(queue_name)
562
+ was_stalled = self._queue_stalled.get(queue_name, False)
563
+
564
+ if pending_count == 0:
565
+ self.redis_client.delete(f"queue:first_enqueued:{queue_name}")
566
+ if was_stalled:
567
+ logger.info(f"Queue '{queue_name}' has recovered from stall.")
451
568
  self._queue_stalled[queue_name] = False
452
- self._queue_fingerprints.pop(queue_name, None)
453
- self._queue_fingerprint_changed_at.pop(queue_name, None)
454
569
  continue
455
570
 
456
- fingerprint = hashlib.md5(b"".join(sorted(queue_items))).hexdigest()
457
- now_time = time.time()
458
- prev_fingerprint = self._queue_fingerprints.get(queue_name)
459
-
460
- if fingerprint != prev_fingerprint:
461
- self._queue_fingerprints[queue_name] = fingerprint
462
- self._queue_fingerprint_changed_at[queue_name] = now_time
571
+ if queue_name in active_queues:
572
+ if was_stalled:
573
+ logger.info(f"Queue '{queue_name}' has recovered from stall.")
463
574
  self._queue_stalled[queue_name] = False
464
575
  continue
465
576
 
466
- fingerprint_age = now_time - self._queue_fingerprint_changed_at.get(queue_name, now_time)
467
-
468
- last_completed_raw = self.redis_client.get(f"queue:last_completed:{queue_name}")
469
- last_completed = float(last_completed_raw) if last_completed_raw else 0
470
- time_since_completion = (now_time - last_completed) if last_completed else float("inf")
577
+ last_started_raw = self.redis_client.get(f"queue:last_started:{queue_name}")
578
+ first_enqueued_raw = self.redis_client.get(f"queue:first_enqueued:{queue_name}")
579
+ last_started = float(last_started_raw) if last_started_raw else 0
580
+ first_enqueued = float(first_enqueued_raw) if first_enqueued_raw else 0
581
+ last_activity = max(last_started, first_enqueued)
582
+ time_since_activity = (now - last_activity) if last_activity else float("inf")
471
583
 
472
- was_stalled = self._queue_stalled.get(queue_name, False)
473
- is_stalled = (
474
- fingerprint_age >= queue_config.stall_timeout
475
- and time_since_completion >= queue_config.stall_timeout
476
- )
584
+ is_stalled = time_since_activity >= grace_period
477
585
  self._queue_stalled[queue_name] = is_stalled
478
586
 
479
587
  if is_stalled and not was_stalled:
480
588
  logger.error(
481
- f"Queue '{queue_name}' is STALLED: {len(queue_items)} task(s) stuck for "
482
- f"{fingerprint_age:.0f}s with no completions in {time_since_completion:.0f}s. "
589
+ f"Queue '{queue_name}' is STALLED: {pending_count} task(s) pending, "
590
+ f"no active workers, no task started in {time_since_activity:.0f}s. "
483
591
  f"New requests will receive 503."
484
592
  )
485
593
  elif not is_stalled and was_stalled:
@@ -488,6 +596,9 @@ class CeleryExecutor(object):
488
596
  logger.error(f"Error checking stall for queue '{queue_name}': {e}")
489
597
 
490
598
  def close(self):
599
+ if not self._is_owner:
600
+ return
601
+ self._is_owner = False
491
602
  self.stop_event.set()
492
603
  if self.worker_health_check_thread and self.worker_health_check_thread.is_alive():
493
604
  self.worker_health_check_thread.join(timeout=2)
@@ -523,9 +634,16 @@ class CeleryExecutor(object):
523
634
  task_id = task.request.id
524
635
  serialized_data = CelerySerialized(folder=self.in_folder, redis_client=self.redis_client)
525
636
  serialized_data.set_task_id(task_id)
526
- data = serialized_data.load()
637
+ metadata, parts = serialized_data.load()
527
638
  del serialized_data
528
- args, kwargs = data.get("args", []), data.get("kwargs", {})
639
+ args, kwargs = metadata.get("args", []), metadata.get("kwargs", {})
640
+
641
+ part_count = metadata.get("part_count", 0)
642
+ if part_count > 0:
643
+ if metadata.get("part_is_list", part_count > 1):
644
+ args = [parts] + args
645
+ else:
646
+ args = [parts[0]] + args
529
647
 
530
648
  if self.app.conf.task_always_eager:
531
649
  queue = task.request.delivery_info.get("routing_key", self.app.conf.task_default_routing_key)
@@ -536,9 +654,25 @@ class CeleryExecutor(object):
536
654
  queue = task.request.delivery_info.get("routing_key", self.app.conf.task_default_routing_key)
537
655
  task.request.resources = (self.resource_registry or {}).get(queue, {})
538
656
  out_data = task_fn(task, *args, **kwargs)
539
- serialized_data = CelerySerialized(folder=self.out_folder, data=out_data, redis_client=self.redis_client)
657
+ if isinstance(out_data, TaskResult):
658
+ out_part_is_list = isinstance(out_data.part, list)
659
+ if out_data.part is None:
660
+ out_parts = None
661
+ elif out_part_is_list:
662
+ out_parts = out_data.part
663
+ else:
664
+ out_parts = [out_data.part]
665
+ serialized_data = CelerySerialized(
666
+ folder=self.out_folder, data=out_data.data, redis_client=self.redis_client
667
+ )
668
+ else:
669
+ out_parts = None
670
+ out_part_is_list = False
671
+ serialized_data = CelerySerialized(
672
+ folder=self.out_folder, data=out_data, redis_client=self.redis_client
673
+ )
540
674
  serialized_data.set_task_id(task_id)
541
- serialized_data.dump()
675
+ serialized_data.dump(parts=out_parts, part_is_list=out_part_is_list)
542
676
  del serialized_data
543
677
  if os.path.isfile(os.path.join(self.in_folder, task_id)):
544
678
  with RedisFileLock(self.redis_client, os.path.join(self.in_folder, task_id)):
@@ -579,7 +713,7 @@ class CeleryExecutor(object):
579
713
 
580
714
  return decorator
581
715
 
582
- async def send_task(self, task_fn, args=None, kwargs=None, **options) -> str:
716
+ async def send_task(self, task_fn, args=None, kwargs=None, part=None, **options) -> str:
583
717
  args = args if args is not None else []
584
718
  kwargs = kwargs if kwargs is not None else {}
585
719
  if self.app.conf.task_always_eager and "dev_preload" not in self.app.conf:
@@ -610,10 +744,10 @@ class CeleryExecutor(object):
610
744
  task_fn.apply_async(task_id=task_id, **options)
611
745
 
612
746
  def _store_task_queue_mapping(task_id, queue_name):
613
- self.redis_client.set(
614
- f"task:queue:{task_id}", queue_name,
615
- ex=self.app.conf.result_expires,
616
- )
747
+ pipe = self.redis_client.pipeline()
748
+ pipe.set(f"task:queue:{task_id}", queue_name, ex=self.app.conf.result_expires)
749
+ pipe.set(f"queue:first_enqueued:{queue_name}", str(time.time()), nx=True)
750
+ pipe.execute()
617
751
 
618
752
  task_id = str(uuid.uuid4())
619
753
  queue_name = self.app.conf.task_default_routing_key
@@ -635,14 +769,21 @@ class CeleryExecutor(object):
635
769
  await asyncio.get_running_loop().run_in_executor(
636
770
  self.set_thread_pool, _check_queue_llen, queue_name
637
771
  )
638
- await asyncio.get_running_loop().run_in_executor(
639
- self.io_thread_pool,
640
- _write_task_data,
641
- self.in_folder,
642
- args,
643
- kwargs,
644
- task_id,
645
- )
772
+
773
+ if part is not None:
774
+ await self._write_task_data_with_part(
775
+ task_id, args, kwargs, part
776
+ )
777
+ else:
778
+ await asyncio.get_running_loop().run_in_executor(
779
+ self.io_thread_pool,
780
+ _write_task_data,
781
+ self.in_folder,
782
+ args,
783
+ kwargs,
784
+ task_id,
785
+ )
786
+
646
787
  await asyncio.get_running_loop().run_in_executor(
647
788
  self.set_thread_pool, _send_task, task_fn, task_id, options
648
789
  )
@@ -664,12 +805,57 @@ class CeleryExecutor(object):
664
805
  sem.release()
665
806
  return task_id
666
807
 
808
+ async def _write_task_data_with_part(self, task_id, args, kwargs, part):
809
+ """Write task data with a streamed part to the multipart file.
810
+ The part becomes the first arg on the worker side.
811
+ """
812
+ boundary = CelerySerialized.BOUNDARY_PREFIX + task_id.encode()
813
+ boundary_line = boundary + CelerySerialized.BOUNDARY_SUFFIX
814
+ end_boundary_line = boundary + CelerySerialized.BOUNDARY_END
815
+
816
+ is_list = isinstance(part, list)
817
+ items = part if is_list else [part]
818
+ part_count = len(items)
819
+
820
+ metadata = {"args": args, "kwargs": kwargs, "part_count": part_count, "part_is_list": is_list}
821
+ metadata_bytes = json.dumps(metadata).encode()
822
+
823
+ file_path = os.path.join(self.in_folder, task_id)
824
+ with open(file_path, "wb") as f:
825
+ f.write(boundary_line)
826
+ f.write(CelerySerialized.CONTENT_TYPE_JSON)
827
+ f.write(CelerySerialized.HEADER_END)
828
+ f.write(metadata_bytes)
829
+ f.write(b"\r\n")
830
+
831
+ for item in items:
832
+ f.write(boundary_line)
833
+ f.write(CelerySerialized.CONTENT_TYPE_BYTES)
834
+ f.write(CelerySerialized.HEADER_END)
835
+ if hasattr(item, "__aiter__"):
836
+ async for chunk in item:
837
+ f.write(chunk if isinstance(chunk, bytes) else chunk.encode())
838
+ elif hasattr(item, "read"):
839
+ while True:
840
+ chunk = await item.read(65536)
841
+ if not chunk:
842
+ break
843
+ f.write(chunk if isinstance(chunk, bytes) else chunk.encode())
844
+ elif isinstance(item, bytes):
845
+ f.write(item)
846
+ else:
847
+ raise TypeError(f"Unsupported part item type: {type(item)}")
848
+ f.write(b"\r\n")
849
+
850
+ f.write(end_boundary_line)
851
+
667
852
  async def terminate_task(self, task_id):
668
853
  def _terminate_task(celery_app, task_id):
669
854
  celery_app.control.revoke(task_id, terminate=True)
670
855
 
671
856
  def _remove_task_data(celery_app, in_folder, out_folder, task_id):
672
857
  celery_app.AsyncResult(task_id).forget()
858
+ self.redis_client.delete(f"task:queue:{task_id}")
673
859
  if os.path.isfile(os.path.join(in_folder, task_id)):
674
860
  with RedisFileLock(self.redis_client, os.path.join(self.in_folder, task_id)):
675
861
  os.remove(os.path.join(in_folder, task_id))
@@ -705,6 +891,13 @@ class CeleryExecutor(object):
705
891
  self.get_thread_pool, _poll_task_state, self.app, task_id
706
892
  )
707
893
 
894
+ async def get_task_info(self, task_id):
895
+ def _get_info(celery_app, task_id):
896
+ return celery_app.AsyncResult(task_id).info
897
+ return await asyncio.get_running_loop().run_in_executor(
898
+ self.get_thread_pool, _get_info, self.app, task_id,
899
+ )
900
+
708
901
  async def get_task_result(self, task_id, remove=False):
709
902
  def _try_get_task_data(celery_app, task_id):
710
903
  async_result = celery_app.AsyncResult(task_id)
@@ -717,11 +910,19 @@ class CeleryExecutor(object):
717
910
  def _read_task_data(out_folder, task_id):
718
911
  serialized_data = CelerySerialized(folder=out_folder, redis_client=self.redis_client)
719
912
  serialized_data.set_task_id(task_id)
720
- result = serialized_data.load()
721
- return result
913
+ metadata, parts = serialized_data.load()
914
+ data = metadata.get("data", metadata)
915
+ part_count = metadata.get("part_count", 0)
916
+ if part_count > 0:
917
+ if metadata.get("part_is_list", part_count > 1):
918
+ return TaskResult(data=data, part=parts)
919
+ else:
920
+ return TaskResult(data=data, part=parts[0])
921
+ return data
722
922
 
723
923
  def _remove_task_data(celery_app, in_folder, out_folder, task_id):
724
924
  celery_app.AsyncResult(task_id).forget()
925
+ self.redis_client.delete(f"task:queue:{task_id}")
725
926
  if os.path.isfile(os.path.join(in_folder, task_id)):
726
927
  with RedisFileLock(self.redis_client, os.path.join(self.in_folder, task_id)):
727
928
  os.remove(os.path.join(in_folder, task_id))
@@ -744,10 +945,10 @@ class CeleryExecutor(object):
744
945
  )
745
946
  return result
746
947
 
747
- async def send_and_wait_task(self, task_fn, args=None, kwargs=None, timeout=60, **options):
948
+ async def send_and_wait_task(self, task_fn, args=None, kwargs=None, part=None, timeout=60, **options):
748
949
  args = args if args is not None else []
749
950
  kwargs = kwargs if kwargs is not None else {}
750
- task_id = await self.send_task(task_fn, args, kwargs, **options)
951
+ task_id = await self.send_task(task_fn, args, kwargs, part=part, **options)
751
952
  ready = False
752
953
  state = None
753
954
  start_time = time.time()
@@ -782,6 +983,11 @@ class FastAPICeleryTaskRoute(Route):
782
983
  path="/task/{task_id}",
783
984
  methods=["DELETE"],
784
985
  summary="Cancel a queued or running task.",
986
+ description=(
987
+ "Terminate a task by its ID. If the task is still queued "
988
+ "(AWAITING) it is revoked; if it is running (STARTED) the "
989
+ "worker process is interrupted."
990
+ ),
785
991
  status_code=200,
786
992
  responses=DefaultErrorRoute.error_responses,
787
993
  )
@@ -794,7 +1000,14 @@ class FastAPICeleryTaskRoute(Route):
794
1000
  self,
795
1001
  path="/task/{task_id}",
796
1002
  methods=["GET"],
797
- summary="Get task status (and result if available).",
1003
+ summary="Get task status, progress, and result.",
1004
+ description=(
1005
+ "Poll a running or completed task. Returns task_state "
1006
+ "(AWAITING, STARTED, SUCCESS, FAILURE, etc.), and when "
1007
+ "STARTED may include a 'progress' dict with task-specific "
1008
+ "metrics (e.g. nodes_done / nodes_total). When SUCCESS, "
1009
+ "the 'data' field contains the task result."
1010
+ ),
798
1011
  status_code=200,
799
1012
  responses=DefaultErrorRoute.error_responses,
800
1013
  )
@@ -824,10 +1037,12 @@ class FastAPICeleryTaskRoute(Route):
824
1037
  data = {"task_state": "UNKNOWN", "task_id": task_id}
825
1038
  elif state in ["AWAITING", "STARTED"]:
826
1039
  data = {"task_state": state, "task_id": task_id}
1040
+ if state == "STARTED":
1041
+ info = await self.celery_executor.get_task_info(task_id)
1042
+ if info and isinstance(info, dict):
1043
+ data["progress"] = info
827
1044
  elif state == "SUCCESS":
828
1045
  data = {"task_state": state, "task_id": task_id, "data": result}
829
- elif state == "STARTED":
830
- data = {"task_status": state, "task_id": task_id}
831
1046
  else:
832
1047
  raise ValueError(f"Unknown task state {state}!")
833
1048
  return CommonModels.DataResponseModel(data=data)