streamlit-octostar-utils 0.4.2.dev24__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/PKG-INFO +1 -1
  2. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/pyproject.toml +1 -1
  3. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/celery.py +324 -100
  4. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/fastapi.py +1 -97
  5. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/nifi.py +442 -81
  6. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/ner.py +214 -83
  7. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/inheritance.py +5 -5
  8. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/LICENSE +0 -0
  9. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/README.md +0 -0
  10. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/__init__.py +0 -0
  11. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/__init__.py +0 -0
  12. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/contents.py +0 -0
  13. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parallelism.py +0 -0
  14. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/__init__.py +0 -0
  15. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/combine_fields.py +0 -0
  16. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/entities_parser.py +0 -0
  17. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/generics.py +0 -0
  18. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/info.py +0 -0
  19. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/linkchart_functions.py +0 -0
  20. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/matches.py +0 -0
  21. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/parameters.py +0 -0
  22. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/rules.py +0 -0
  23. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/api_crafter/parser/signals.py +0 -0
  24. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/__init__.py +0 -0
  25. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/dict.py +0 -0
  26. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/filetypes.py +0 -0
  27. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/threading/__init__.py +0 -0
  28. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/threading/key_queue.py +0 -0
  29. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/core/timestamp.py +0 -0
  30. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/__init__.py +0 -0
  31. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/custom_recognizers.py +0 -0
  32. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/nlp/language.py +0 -0
  33. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/__init__.py +0 -0
  34. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/client.py +0 -0
  35. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/context.py +0 -0
  36. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/octostar/permissions.py +0 -0
  37. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/__init__.py +0 -0
  38. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/relationships.py +0 -0
  39. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/ontology/validation.py +0 -0
  40. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/style/__init__.py +0 -0
  41. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/style/common.py +0 -0
  42. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/__init__.py +0 -0
  43. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/async_task_manager.py +0 -0
  44. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/session_callback_manager.py +0 -0
  45. {streamlit_octostar_utils-0.4.2.dev24 → streamlit_octostar_utils-0.5.0}/streamlit_octostar_utils/threading/session_state_hot_swapper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: streamlit-octostar-utils
3
- Version: 0.4.2.dev24
3
+ Version: 0.5.0
4
4
  Summary:
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -5,7 +5,7 @@ include = '\.pyi?$'
5
5
 
6
6
  [tool.poetry]
7
7
  name = "streamlit-octostar-utils"
8
- version = "0.4.2-dev.24"
8
+ version = "0.5.0"
9
9
  description = ""
10
10
  license = "MIT"
11
11
  authors = ["Octostar"]
@@ -9,12 +9,10 @@ import subprocess
9
9
  from fastapi import Query
10
10
  import time
11
11
  import os
12
- import pickle
13
12
  import atexit
14
13
  import redis
15
14
  import uuid
16
15
  import json
17
- import hashlib
18
16
  import shutil
19
17
  import threading
20
18
  from pottery import Redlock
@@ -57,18 +55,40 @@ class CeleryQueueConfig:
57
55
  max_tasks_in_queue=None,
58
56
  max_tasks_per_child=None,
59
57
  max_memory_per_child=None,
60
- stall_timeout=1200,
61
58
  **options,
62
59
  ):
63
60
  self.n_workers = n_workers
64
61
  self.max_tasks_in_queue = max_tasks_in_queue
65
62
  self.max_tasks_per_child = max_tasks_per_child
66
63
  self.max_memory_per_child = max_memory_per_child # KiB
67
- self.stall_timeout = stall_timeout # seconds; None or 0 to disable
68
64
  self.options = options
69
65
 
70
66
 
67
+ class TaskResult:
68
+ """Wrapper for task results that include binary parts alongside JSON data.
69
+ Tasks returning binary data (e.g. images) should return a TaskResult
70
+ so that serialized_io writes them as multipart parts instead of attempting
71
+ JSON serialization on bytes."""
72
+
73
+ def __init__(self, data, part=None):
74
+ self.data = data
75
+ self.part = part
76
+
77
+
71
78
  class CelerySerialized:
79
+ """Serializes task data to a boundary-delimited multipart file.
80
+
81
+ Format: metadata JSON part followed by optional binary/streamed parts,
82
+ separated by boundary markers (the task_id). Replaces pickle entirely.
83
+ """
84
+
85
+ BOUNDARY_PREFIX = b"--"
86
+ BOUNDARY_SUFFIX = b"\r\n"
87
+ BOUNDARY_END = b"--\r\n"
88
+ CONTENT_TYPE_JSON = b"Content-Type: application/json\r\n"
89
+ CONTENT_TYPE_BYTES = b"Content-Type: application/octet-stream\r\n"
90
+ HEADER_END = b"\r\n"
91
+
72
92
  def __init__(self, folder, redis_client, data=None):
73
93
  self.folder = folder
74
94
  self.data = data
@@ -77,18 +97,70 @@ class CelerySerialized:
77
97
  def set_task_id(self, task_id):
78
98
  self.task_id = task_id
79
99
 
80
- def dump(self):
100
+ def _boundary(self):
101
+ return self.task_id.encode()
102
+
103
+ def _write_boundary(self, f):
104
+ f.write(self.BOUNDARY_PREFIX + self._boundary() + self.BOUNDARY_SUFFIX)
105
+
106
+ def _write_end_boundary(self, f):
107
+ f.write(self.BOUNDARY_PREFIX + self._boundary() + self.BOUNDARY_END)
108
+
109
+ def dump(self, parts=None, part_is_list=False):
110
+ """Write metadata + optional parts in multipart format.
111
+ parts: optional list of bytes objects to write as additional parts.
112
+ part_is_list: whether the original part was a list (preserves type on read).
113
+ """
81
114
  assert self.task_id
115
+ if isinstance(self.data, dict):
116
+ metadata = self.data
117
+ else:
118
+ metadata = {"data": self.data}
119
+ metadata["part_count"] = len(parts) if parts else 0
120
+ metadata["part_is_list"] = part_is_list
82
121
  with RedisFileLock(self.redis_client, os.path.join(self.folder, self.task_id)):
83
- with open(os.path.join(self.folder, self.task_id), "wb") as target_file:
84
- pickle.dump(self.data, file=target_file, protocol=pickle.HIGHEST_PROTOCOL)
122
+ with open(os.path.join(self.folder, self.task_id), "wb") as f:
123
+ self._write_boundary(f)
124
+ f.write(self.CONTENT_TYPE_JSON)
125
+ f.write(self.HEADER_END)
126
+ f.write(json.dumps(metadata).encode())
127
+ f.write(b"\r\n")
128
+ if parts:
129
+ for part in parts:
130
+ self._write_boundary(f)
131
+ f.write(self.CONTENT_TYPE_BYTES)
132
+ f.write(self.HEADER_END)
133
+ f.write(part)
134
+ f.write(b"\r\n")
135
+ self._write_end_boundary(f)
85
136
 
86
137
  def load(self):
138
+ """Read multipart file. Returns (metadata_dict, list_of_bytes_parts)."""
87
139
  assert self.task_id
140
+ boundary = self.BOUNDARY_PREFIX + self._boundary()
141
+ end_boundary = self.BOUNDARY_PREFIX + self._boundary() + b"--"
88
142
  with RedisFileLock(self.redis_client, os.path.join(self.folder, self.task_id)):
89
- with open(os.path.join(self.folder, self.task_id), "rb") as source_file:
90
- data = pickle.load(source_file)
91
- return data
143
+ with open(os.path.join(self.folder, self.task_id), "rb") as f:
144
+ raw = f.read()
145
+ sections = raw.split(boundary)
146
+ metadata = None
147
+ parts = []
148
+ for section in sections:
149
+ section = section.strip(b"\r\n")
150
+ if not section or section == b"--":
151
+ continue
152
+ header_end = section.find(b"\r\n\r\n")
153
+ if header_end == -1:
154
+ continue
155
+ header = section[:header_end]
156
+ body = section[header_end + 4:]
157
+ if body.endswith(b"\r\n"):
158
+ body = body[:-2]
159
+ if b"application/json" in header:
160
+ metadata = json.loads(body)
161
+ else:
162
+ parts.append(body)
163
+ return metadata or {}, parts
92
164
 
93
165
 
94
166
  class CeleryExecutor(object):
@@ -129,7 +201,10 @@ class CeleryExecutor(object):
129
201
  self.get_thread_pool = None
130
202
  self.set_thread_pool = None
131
203
  self.io_thread_pool = None
132
- self.queue_threadlocks = {k: threading.Lock() for k in self.queue_config.keys()}
204
+ self.queue_semaphores = {
205
+ k: threading.Semaphore(v.max_tasks_in_queue) if v.max_tasks_in_queue else None
206
+ for k, v in self.queue_config.items()
207
+ }
133
208
 
134
209
  # Folder setup
135
210
  self.base_folder = Path(base_folder).resolve()
@@ -147,7 +222,7 @@ class CeleryExecutor(object):
147
222
  self.app = Celery(self.filename)
148
223
  self.app.conf.broker_url = f"redis://{self.redis_host}:{self.redis_port}/0"
149
224
  self.app.conf.result_backend = f"redis://{self.redis_host}:{self.redis_port}/0"
150
- self.app.conf.track_started = True
225
+ self.app.conf.track_started = False
151
226
  self.app.conf.task_serializer = "json"
152
227
  self.app.conf.result_serializer = "json"
153
228
  self.app.conf.accept_content = ["application/json"]
@@ -179,11 +254,11 @@ class CeleryExecutor(object):
179
254
  self.worker_info = {}
180
255
 
181
256
  # Queue stall detection
182
- self._queue_fingerprints = {}
183
- self._queue_fingerprint_changed_at = {}
184
257
  self._queue_stalled = {}
258
+ self._last_stall_check = 0
259
+ self._stall_check_interval = 60
185
260
 
186
- atexit.register(self.close)
261
+ self._is_owner = False
187
262
  self.set_cleanup_task()
188
263
  self.register_state_signals()
189
264
 
@@ -213,28 +288,44 @@ class CeleryExecutor(object):
213
288
  def set_started_state(self, task_id, task, *args, **kwargs):
214
289
  result = AsyncResult(task_id, app=self.app)
215
290
  result.backend.store_result(task_id, result=None, state=CeleryExecutor.STARTED)
291
+ queue = task.request.delivery_info.get(
292
+ "routing_key", self.app.conf.task_default_routing_key
293
+ ) if task else None
294
+ if queue:
295
+ self.redis_client.set(f"queue:last_started:{queue}", str(time.time()))
296
+ request_timelimit = getattr(getattr(task, "request", None), "timelimit", None) or (None, None)
297
+ time_limit = request_timelimit[0] or getattr(task, "time_limit", None) or 0
298
+ extended_ttl = int(time_limit) + int(self.app.conf.result_expires)
299
+ if extended_ttl > int(self.app.conf.result_expires):
300
+ self.redis_client.eval(
301
+ "for _, k in ipairs(KEYS) do redis.call('expire', k, ARGV[1]) end",
302
+ 2,
303
+ f"{CeleryExecutor.CELERY_BROKER_PREFIX}{task_id}",
304
+ f"task:queue:{task_id}",
305
+ extended_ttl,
306
+ )
216
307
 
217
308
  def register_worker_initialization(self):
218
309
  if self.preload_functions:
219
310
  celery_signals.worker_process_init.connect(self.preload_on_worker_init)
220
311
 
221
- def set_last_completed_time(self, sender=None, task_id=None, task=None, **kwargs):
312
+ def cleanup_task_keys(self, sender=None, task_id=None, **kwargs):
222
313
  try:
223
- queue = task.request.delivery_info.get(
224
- "routing_key", self.app.conf.task_default_routing_key
225
- ) if task else None
226
- if queue:
227
- self.redis_client.set(f"queue:last_completed:{queue}", str(time.time()))
314
+ if task_id:
315
+ self.redis_client.delete(f"task:queue:{task_id}")
228
316
  except Exception:
229
317
  pass
230
318
 
231
319
  def register_state_signals(self):
232
320
  celery_signals.before_task_publish.connect(self.set_awaiting_state)
233
321
  celery_signals.task_prerun.connect(self.set_started_state)
234
- celery_signals.task_postrun.connect(self.set_last_completed_time)
322
+ celery_signals.task_postrun.connect(self.cleanup_task_keys)
235
323
 
236
324
  def cleanup_task_results(in_dir, out_dir, redis_host, redis_port, task_expires, result_expires):
237
325
  logger.info("Starting cleanup of expired task results...")
326
+ if not os.path.isdir(in_dir) or not os.path.isdir(out_dir):
327
+ logger.warning(f"Data directories missing (in={os.path.isdir(in_dir)}, out={os.path.isdir(out_dir)}), skipping cleanup")
328
+ return
238
329
  redis_client = redis.StrictRedis(host=redis_host, port=redis_port)
239
330
  for file_name in os.listdir(in_dir):
240
331
  file_path = os.path.join(in_dir, file_name)
@@ -304,6 +395,8 @@ class CeleryExecutor(object):
304
395
  }
305
396
 
306
397
  def start(self):
398
+ self._is_owner = True
399
+ atexit.register(self.close)
307
400
  logger.info("Initializing data folders...")
308
401
  shutil.rmtree(self.root_folder, ignore_errors=True)
309
402
  for folder in [self.in_folder, self.out_folder]:
@@ -330,6 +423,9 @@ class CeleryExecutor(object):
330
423
  if attempts_done == CeleryExecutor.MAX_STARTUP_CHECKS:
331
424
  raise TimeoutError("Redis not ready after a long wait!")
332
425
  self.redis_client.flushall()
426
+ boot_time = str(time.time())
427
+ for queue_name in self.queue_config:
428
+ self.redis_client.set(f"queue:last_started:{queue_name}", boot_time)
333
429
  for queue, queue_config in self.queue_config.items():
334
430
  for slot in range(queue_config.n_workers):
335
431
  worker_name = f"celery@{self.name}:{queue}:{slot}"
@@ -404,7 +500,10 @@ class CeleryExecutor(object):
404
500
  while not self.stop_event.is_set():
405
501
  try:
406
502
  self._restart_dead_processes()
407
- self._check_queue_stalls()
503
+ now = time.time()
504
+ if now - self._last_stall_check >= self._stall_check_interval:
505
+ self._last_stall_check = now
506
+ self._check_queue_stalls()
408
507
  time.sleep(5)
409
508
  except Exception as e:
410
509
  logger.error(f"Error in worker health check: {e}")
@@ -439,44 +538,56 @@ class CeleryExecutor(object):
439
538
  logger.info(f"Restarted beat process (PID: {self.beat_process.pid})")
440
539
 
441
540
  def _check_queue_stalls(self):
442
- for queue_name, queue_config in self.queue_config.items():
443
- if not queue_config.stall_timeout:
444
- continue
541
+ try:
542
+ inspector = self.app.control.inspect(timeout=5)
543
+ active_data = inspector.active() or {}
544
+ except Exception as e:
545
+ logger.error(f"Failed to inspect active workers: {e}")
546
+ return
547
+
548
+ now = time.time()
549
+ grace_period = 3 * self.app.conf.worker_proc_alive_timeout
550
+
551
+ active_queues = set()
552
+ for _worker_name, tasks in active_data.items():
553
+ for task_info in (tasks or []):
554
+ queue = (task_info.get("delivery_info") or {}).get(
555
+ "routing_key", self.app.conf.task_default_routing_key
556
+ )
557
+ active_queues.add(queue)
558
+
559
+ for queue_name in self.queue_config:
445
560
  try:
446
- queue_items = self.redis_client.lrange(queue_name, 0, -1)
447
- if len(queue_items) == 0:
561
+ pending_count = self.redis_client.llen(queue_name)
562
+ was_stalled = self._queue_stalled.get(queue_name, False)
563
+
564
+ if pending_count == 0:
565
+ self.redis_client.delete(f"queue:first_enqueued:{queue_name}")
566
+ if was_stalled:
567
+ logger.info(f"Queue '{queue_name}' has recovered from stall.")
448
568
  self._queue_stalled[queue_name] = False
449
- self._queue_fingerprints.pop(queue_name, None)
450
- self._queue_fingerprint_changed_at.pop(queue_name, None)
451
569
  continue
452
570
 
453
- fingerprint = hashlib.md5(b"".join(sorted(queue_items))).hexdigest()
454
- now_time = time.time()
455
- prev_fingerprint = self._queue_fingerprints.get(queue_name)
456
-
457
- if fingerprint != prev_fingerprint:
458
- self._queue_fingerprints[queue_name] = fingerprint
459
- self._queue_fingerprint_changed_at[queue_name] = now_time
571
+ if queue_name in active_queues:
572
+ if was_stalled:
573
+ logger.info(f"Queue '{queue_name}' has recovered from stall.")
460
574
  self._queue_stalled[queue_name] = False
461
575
  continue
462
576
 
463
- fingerprint_age = now_time - self._queue_fingerprint_changed_at.get(queue_name, now_time)
577
+ last_started_raw = self.redis_client.get(f"queue:last_started:{queue_name}")
578
+ first_enqueued_raw = self.redis_client.get(f"queue:first_enqueued:{queue_name}")
579
+ last_started = float(last_started_raw) if last_started_raw else 0
580
+ first_enqueued = float(first_enqueued_raw) if first_enqueued_raw else 0
581
+ last_activity = max(last_started, first_enqueued)
582
+ time_since_activity = (now - last_activity) if last_activity else float("inf")
464
583
 
465
- last_completed_raw = self.redis_client.get(f"queue:last_completed:{queue_name}")
466
- last_completed = float(last_completed_raw) if last_completed_raw else 0
467
- time_since_completion = (now_time - last_completed) if last_completed else float("inf")
468
-
469
- was_stalled = self._queue_stalled.get(queue_name, False)
470
- is_stalled = (
471
- fingerprint_age >= queue_config.stall_timeout
472
- and time_since_completion >= queue_config.stall_timeout
473
- )
584
+ is_stalled = time_since_activity >= grace_period
474
585
  self._queue_stalled[queue_name] = is_stalled
475
586
 
476
587
  if is_stalled and not was_stalled:
477
588
  logger.error(
478
- f"Queue '{queue_name}' is STALLED: {len(queue_items)} task(s) stuck for "
479
- f"{fingerprint_age:.0f}s with no completions in {time_since_completion:.0f}s. "
589
+ f"Queue '{queue_name}' is STALLED: {pending_count} task(s) pending, "
590
+ f"no active workers, no task started in {time_since_activity:.0f}s. "
480
591
  f"New requests will receive 503."
481
592
  )
482
593
  elif not is_stalled and was_stalled:
@@ -485,6 +596,9 @@ class CeleryExecutor(object):
485
596
  logger.error(f"Error checking stall for queue '{queue_name}': {e}")
486
597
 
487
598
  def close(self):
599
+ if not self._is_owner:
600
+ return
601
+ self._is_owner = False
488
602
  self.stop_event.set()
489
603
  if self.worker_health_check_thread and self.worker_health_check_thread.is_alive():
490
604
  self.worker_health_check_thread.join(timeout=2)
@@ -520,9 +634,16 @@ class CeleryExecutor(object):
520
634
  task_id = task.request.id
521
635
  serialized_data = CelerySerialized(folder=self.in_folder, redis_client=self.redis_client)
522
636
  serialized_data.set_task_id(task_id)
523
- data = serialized_data.load()
637
+ metadata, parts = serialized_data.load()
524
638
  del serialized_data
525
- args, kwargs = data.get("args", []), data.get("kwargs", {})
639
+ args, kwargs = metadata.get("args", []), metadata.get("kwargs", {})
640
+
641
+ part_count = metadata.get("part_count", 0)
642
+ if part_count > 0:
643
+ if metadata.get("part_is_list", part_count > 1):
644
+ args = [parts] + args
645
+ else:
646
+ args = [parts[0]] + args
526
647
 
527
648
  if self.app.conf.task_always_eager:
528
649
  queue = task.request.delivery_info.get("routing_key", self.app.conf.task_default_routing_key)
@@ -533,9 +654,25 @@ class CeleryExecutor(object):
533
654
  queue = task.request.delivery_info.get("routing_key", self.app.conf.task_default_routing_key)
534
655
  task.request.resources = (self.resource_registry or {}).get(queue, {})
535
656
  out_data = task_fn(task, *args, **kwargs)
536
- serialized_data = CelerySerialized(folder=self.out_folder, data=out_data, redis_client=self.redis_client)
657
+ if isinstance(out_data, TaskResult):
658
+ out_part_is_list = isinstance(out_data.part, list)
659
+ if out_data.part is None:
660
+ out_parts = None
661
+ elif out_part_is_list:
662
+ out_parts = out_data.part
663
+ else:
664
+ out_parts = [out_data.part]
665
+ serialized_data = CelerySerialized(
666
+ folder=self.out_folder, data=out_data.data, redis_client=self.redis_client
667
+ )
668
+ else:
669
+ out_parts = None
670
+ out_part_is_list = False
671
+ serialized_data = CelerySerialized(
672
+ folder=self.out_folder, data=out_data, redis_client=self.redis_client
673
+ )
537
674
  serialized_data.set_task_id(task_id)
538
- serialized_data.dump()
675
+ serialized_data.dump(parts=out_parts, part_is_list=out_part_is_list)
539
676
  del serialized_data
540
677
  if os.path.isfile(os.path.join(self.in_folder, task_id)):
541
678
  with RedisFileLock(self.redis_client, os.path.join(self.in_folder, task_id)):
@@ -576,36 +713,23 @@ class CeleryExecutor(object):
576
713
 
577
714
  return decorator
578
715
 
579
- async def send_task(self, task_fn, args=[], kwargs={}, **options) -> str:
716
+ async def send_task(self, task_fn, args=None, kwargs=None, part=None, **options) -> str:
717
+ args = args if args is not None else []
718
+ kwargs = kwargs if kwargs is not None else {}
580
719
  if self.app.conf.task_always_eager and "dev_preload" not in self.app.conf:
581
720
  self.preload_on_worker_init()
582
721
  self.app.conf.dev_preload = True
583
722
 
584
- def _reserve_queue_slot(queue_name):
723
+ def _check_queue_llen(queue_name):
585
724
  if self._queue_stalled.get(queue_name, False):
586
725
  raise CeleryExecutor.QueueStalledException(
587
726
  f"Queue '{queue_name}' is stalled. Service temporarily unavailable."
588
727
  )
589
- limit = self.queue_config[queue_name].max_tasks_in_queue
590
- if limit:
591
- reservation_key = f"queue:reserved:{queue_name}"
592
- with self.queue_threadlocks[queue_name]:
593
- queue_count = self.redis_client.llen(queue_name)
594
- reserved_count = int(self.redis_client.get(reservation_key) or 0)
595
- total_count = queue_count + reserved_count
596
- if total_count >= limit:
597
- raise CeleryExecutor.QueueFullException(
598
- f"Queue '{queue_name}' has reached its limit of {limit} tasks!"
599
- )
600
- self.redis_client.incr(reservation_key)
601
- return True
602
- return False
603
-
604
- def _release_queue_slot(queue_name):
605
- limit = self.queue_config[queue_name].max_tasks_in_queue
606
- if limit:
607
- reservation_key = f"queue:reserved:{queue_name}"
608
- self.redis_client.decr(reservation_key)
728
+ if self.redis_client.llen(queue_name) >= self.queue_config[queue_name].max_tasks_in_queue:
729
+ raise CeleryExecutor.QueueFullException(
730
+ f"Queue '{queue_name}' has reached its limit of "
731
+ f"{self.queue_config[queue_name].max_tasks_in_queue} tasks!"
732
+ )
609
733
 
610
734
  def _write_task_data(in_folder, task_args, task_kwargs, task_id):
611
735
  serialized_data = CelerySerialized(
@@ -620,28 +744,46 @@ class CeleryExecutor(object):
620
744
  task_fn.apply_async(task_id=task_id, **options)
621
745
 
622
746
  def _store_task_queue_mapping(task_id, queue_name):
623
- self.redis_client.set(
624
- f"task:queue:{task_id}", queue_name,
625
- ex=self.app.conf.result_expires,
626
- )
747
+ pipe = self.redis_client.pipeline()
748
+ pipe.set(f"task:queue:{task_id}", queue_name, ex=self.app.conf.result_expires)
749
+ pipe.set(f"queue:first_enqueued:{queue_name}", str(time.time()), nx=True)
750
+ pipe.execute()
627
751
 
628
752
  task_id = str(uuid.uuid4())
629
753
  queue_name = self.app.conf.task_default_routing_key
630
754
  queue_name = getattr(task_fn, "queue", queue_name)
631
755
  queue_name = options.get("queue", queue_name)
632
- reserved = False
756
+
757
+ sem = self.queue_semaphores.get(queue_name)
758
+ acquired = False
759
+ if sem is not None:
760
+ if not sem.acquire(blocking=False):
761
+ raise CeleryExecutor.QueueFullException(
762
+ f"Queue '{queue_name}' has reached its limit of "
763
+ f"{self.queue_config[queue_name].max_tasks_in_queue} tasks!"
764
+ )
765
+ acquired = True
766
+
633
767
  try:
634
- reserved = await asyncio.get_running_loop().run_in_executor(
635
- self.set_thread_pool, _reserve_queue_slot, queue_name
636
- )
637
- await asyncio.get_running_loop().run_in_executor(
638
- self.io_thread_pool,
639
- _write_task_data,
640
- self.in_folder,
641
- args,
642
- kwargs,
643
- task_id,
644
- )
768
+ if acquired:
769
+ await asyncio.get_running_loop().run_in_executor(
770
+ self.set_thread_pool, _check_queue_llen, queue_name
771
+ )
772
+
773
+ if part is not None:
774
+ await self._write_task_data_with_part(
775
+ task_id, args, kwargs, part
776
+ )
777
+ else:
778
+ await asyncio.get_running_loop().run_in_executor(
779
+ self.io_thread_pool,
780
+ _write_task_data,
781
+ self.in_folder,
782
+ args,
783
+ kwargs,
784
+ task_id,
785
+ )
786
+
645
787
  await asyncio.get_running_loop().run_in_executor(
646
788
  self.set_thread_pool, _send_task, task_fn, task_id, options
647
789
  )
@@ -652,17 +794,68 @@ class CeleryExecutor(object):
652
794
  logger.info(f"Cancelling task {task_id} due to disconnect!")
653
795
  await self.terminate_task(task_id)
654
796
  raise
797
+ except Exception:
798
+ try:
799
+ await self.terminate_task(task_id)
800
+ except Exception:
801
+ pass
802
+ raise
655
803
  finally:
656
- if reserved:
657
- await asyncio.get_running_loop().run_in_executor(self.set_thread_pool, _release_queue_slot, queue_name)
804
+ if acquired:
805
+ sem.release()
658
806
  return task_id
659
807
 
808
+ async def _write_task_data_with_part(self, task_id, args, kwargs, part):
809
+ """Write task data with a streamed part to the multipart file.
810
+ The part becomes the first arg on the worker side.
811
+ """
812
+ boundary = CelerySerialized.BOUNDARY_PREFIX + task_id.encode()
813
+ boundary_line = boundary + CelerySerialized.BOUNDARY_SUFFIX
814
+ end_boundary_line = boundary + CelerySerialized.BOUNDARY_END
815
+
816
+ is_list = isinstance(part, list)
817
+ items = part if is_list else [part]
818
+ part_count = len(items)
819
+
820
+ metadata = {"args": args, "kwargs": kwargs, "part_count": part_count, "part_is_list": is_list}
821
+ metadata_bytes = json.dumps(metadata).encode()
822
+
823
+ file_path = os.path.join(self.in_folder, task_id)
824
+ with open(file_path, "wb") as f:
825
+ f.write(boundary_line)
826
+ f.write(CelerySerialized.CONTENT_TYPE_JSON)
827
+ f.write(CelerySerialized.HEADER_END)
828
+ f.write(metadata_bytes)
829
+ f.write(b"\r\n")
830
+
831
+ for item in items:
832
+ f.write(boundary_line)
833
+ f.write(CelerySerialized.CONTENT_TYPE_BYTES)
834
+ f.write(CelerySerialized.HEADER_END)
835
+ if hasattr(item, "__aiter__"):
836
+ async for chunk in item:
837
+ f.write(chunk if isinstance(chunk, bytes) else chunk.encode())
838
+ elif hasattr(item, "read"):
839
+ while True:
840
+ chunk = await item.read(65536)
841
+ if not chunk:
842
+ break
843
+ f.write(chunk if isinstance(chunk, bytes) else chunk.encode())
844
+ elif isinstance(item, bytes):
845
+ f.write(item)
846
+ else:
847
+ raise TypeError(f"Unsupported part item type: {type(item)}")
848
+ f.write(b"\r\n")
849
+
850
+ f.write(end_boundary_line)
851
+
660
852
  async def terminate_task(self, task_id):
661
853
  def _terminate_task(celery_app, task_id):
662
854
  celery_app.control.revoke(task_id, terminate=True)
663
855
 
664
856
  def _remove_task_data(celery_app, in_folder, out_folder, task_id):
665
857
  celery_app.AsyncResult(task_id).forget()
858
+ self.redis_client.delete(f"task:queue:{task_id}")
666
859
  if os.path.isfile(os.path.join(in_folder, task_id)):
667
860
  with RedisFileLock(self.redis_client, os.path.join(self.in_folder, task_id)):
668
861
  os.remove(os.path.join(in_folder, task_id))
@@ -698,6 +891,13 @@ class CeleryExecutor(object):
698
891
  self.get_thread_pool, _poll_task_state, self.app, task_id
699
892
  )
700
893
 
894
+ async def get_task_info(self, task_id):
895
+ def _get_info(celery_app, task_id):
896
+ return celery_app.AsyncResult(task_id).info
897
+ return await asyncio.get_running_loop().run_in_executor(
898
+ self.get_thread_pool, _get_info, self.app, task_id,
899
+ )
900
+
701
901
  async def get_task_result(self, task_id, remove=False):
702
902
  def _try_get_task_data(celery_app, task_id):
703
903
  async_result = celery_app.AsyncResult(task_id)
@@ -710,11 +910,19 @@ class CeleryExecutor(object):
710
910
  def _read_task_data(out_folder, task_id):
711
911
  serialized_data = CelerySerialized(folder=out_folder, redis_client=self.redis_client)
712
912
  serialized_data.set_task_id(task_id)
713
- result = serialized_data.load()
714
- return result
913
+ metadata, parts = serialized_data.load()
914
+ data = metadata.get("data", metadata)
915
+ part_count = metadata.get("part_count", 0)
916
+ if part_count > 0:
917
+ if metadata.get("part_is_list", part_count > 1):
918
+ return TaskResult(data=data, part=parts)
919
+ else:
920
+ return TaskResult(data=data, part=parts[0])
921
+ return data
715
922
 
716
923
  def _remove_task_data(celery_app, in_folder, out_folder, task_id):
717
924
  celery_app.AsyncResult(task_id).forget()
925
+ self.redis_client.delete(f"task:queue:{task_id}")
718
926
  if os.path.isfile(os.path.join(in_folder, task_id)):
719
927
  with RedisFileLock(self.redis_client, os.path.join(self.in_folder, task_id)):
720
928
  os.remove(os.path.join(in_folder, task_id))
@@ -737,8 +945,10 @@ class CeleryExecutor(object):
737
945
  )
738
946
  return result
739
947
 
740
- async def send_and_wait_task(self, task_fn, args=[], kwargs={}, timeout=60, **options):
741
- task_id = await self.send_task(task_fn, args, kwargs, **options)
948
+ async def send_and_wait_task(self, task_fn, args=None, kwargs=None, part=None, timeout=60, **options):
949
+ args = args if args is not None else []
950
+ kwargs = kwargs if kwargs is not None else {}
951
+ task_id = await self.send_task(task_fn, args, kwargs, part=part, **options)
742
952
  ready = False
743
953
  state = None
744
954
  start_time = time.time()
@@ -773,6 +983,11 @@ class FastAPICeleryTaskRoute(Route):
773
983
  path="/task/{task_id}",
774
984
  methods=["DELETE"],
775
985
  summary="Cancel a queued or running task.",
986
+ description=(
987
+ "Terminate a task by its ID. If the task is still queued "
988
+ "(AWAITING) it is revoked; if it is running (STARTED) the "
989
+ "worker process is interrupted."
990
+ ),
776
991
  status_code=200,
777
992
  responses=DefaultErrorRoute.error_responses,
778
993
  )
@@ -785,7 +1000,14 @@ class FastAPICeleryTaskRoute(Route):
785
1000
  self,
786
1001
  path="/task/{task_id}",
787
1002
  methods=["GET"],
788
- summary="Get task status (and result if available).",
1003
+ summary="Get task status, progress, and result.",
1004
+ description=(
1005
+ "Poll a running or completed task. Returns task_state "
1006
+ "(AWAITING, STARTED, SUCCESS, FAILURE, etc.), and when "
1007
+ "STARTED may include a 'progress' dict with task-specific "
1008
+ "metrics (e.g. nodes_done / nodes_total). When SUCCESS, "
1009
+ "the 'data' field contains the task result."
1010
+ ),
789
1011
  status_code=200,
790
1012
  responses=DefaultErrorRoute.error_responses,
791
1013
  )
@@ -815,10 +1037,12 @@ class FastAPICeleryTaskRoute(Route):
815
1037
  data = {"task_state": "UNKNOWN", "task_id": task_id}
816
1038
  elif state in ["AWAITING", "STARTED"]:
817
1039
  data = {"task_state": state, "task_id": task_id}
1040
+ if state == "STARTED":
1041
+ info = await self.celery_executor.get_task_info(task_id)
1042
+ if info and isinstance(info, dict):
1043
+ data["progress"] = info
818
1044
  elif state == "SUCCESS":
819
1045
  data = {"task_state": state, "task_id": task_id, "data": result}
820
- elif state == "STARTED":
821
- data = {"task_status": state, "task_id": task_id}
822
1046
  else:
823
1047
  raise ValueError(f"Unknown task state {state}!")
824
1048
  return CommonModels.DataResponseModel(data=data)