pico-ioc 1.5.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pico_ioc/container.py CHANGED
@@ -1,168 +1,305 @@
1
- from __future__ import annotations
2
-
1
+ # src/pico_ioc/container.py
3
2
  import inspect
4
- from typing import Any, Dict, get_origin, get_args, Annotated
5
- import typing as _t
3
+ import contextvars
4
+ from typing import Any, Dict, List, Optional, Tuple, overload, Union
5
+ from contextlib import contextmanager
6
+ from .constants import LOGGER, PICO_META
7
+ from .exceptions import CircularDependencyError, ComponentCreationError, ProviderNotFoundError
8
+ from .factory import ComponentFactory
9
+ from .locator import ComponentLocator
10
+ from .scope import ScopedCaches, ScopeManager
11
+ from .aop import UnifiedComponentProxy, ContainerObserver
6
12
 
7
- from .proxy import IoCProxy
8
- from .interceptors import MethodInterceptor, ContainerInterceptor, MethodCtx, ResolveCtx, CreateCtx, run_resolve_chain, run_create_chain
9
- from .decorators import QUALIFIERS_KEY
10
- from . import _state
13
# A component key: either a registered type or a string component name.
KeyT = Union[str, type]

# Keys currently being resolved in this context (outermost first). The
# container consults it before every uncached resolution to detect cycles.
_resolve_chain: contextvars.ContextVar[Tuple[KeyT, ...]] = contextvars.ContextVar(
    "pico_resolve_chain",
    default=(),
)
11
15
 
12
- class Binder:
13
- def __init__(self, container: PicoContainer):
14
- self._c = container
16
class PicoContainer:
    """Resolve, cache and tear down components.

    Every instance registers itself in a class-wide registry keyed by
    ``container_id`` (removed again by :meth:`shutdown`). A container can be
    made the "current" one for the running context with :meth:`activate` or
    :meth:`as_current`.
    """

    # Id of the container that is current in the running context, if any.
    _container_id_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
        "pico_container_id", default=None
    )
    # All containers created and not yet shut down, keyed by container id.
    _container_registry: Dict[str, "PicoContainer"] = {}

    class _Ctx:
        """Per-container metadata and counters reported by ``stats()``."""

        def __init__(self, container_id: str, profiles: Tuple[str, ...], created_at: float) -> None:
            self.container_id = container_id
            self.profiles = profiles
            self.created_at = created_at
            self.resolve_count = 0
            self.cache_hit_count = 0

    def __init__(
        self,
        component_factory: ComponentFactory,
        caches: ScopedCaches,
        scopes: ScopeManager,
        observers: Optional[List["ContainerObserver"]] = None,
        container_id: Optional[str] = None,
        profiles: Tuple[str, ...] = (),
    ) -> None:
        """Create the container and register it in the class-wide registry.

        Args:
            component_factory: Source of providers (``has`` / ``get`` / ``bind``).
            caches: Per-scope instance caches.
            scopes: Scope manager used to pick the active cache for a scope.
            observers: Notified on cache hits and completed resolutions.
            container_id: Explicit id; a fresh one is generated when omitted.
            profiles: Active profiles, recorded for ``stats()``.
        """
        import time as _t

        self._factory = component_factory
        self._caches = caches
        self.scopes = scopes
        self._locator: Optional[ComponentLocator] = None
        self._observers = list(observers or [])
        self.container_id = container_id or self._generate_container_id()
        self.context = PicoContainer._Ctx(container_id=self.container_id, profiles=profiles, created_at=_t.time())
        PicoContainer._container_registry[self.container_id] = self

    @staticmethod
    def _generate_container_id() -> str:
        """Return ``"c"`` + hex nanosecond timestamp + 4 hex digits of randomness."""
        import random as _r
        import time as _t

        return f"c{_t.time_ns():x}{_r.randrange(1 << 16):04x}"

    @classmethod
    def get_current(cls) -> Optional["PicoContainer"]:
        """Return the container active in this context, or None."""
        cid = cls._container_id_var.get()
        return cls._container_registry.get(cid) if cid else None

    @classmethod
    def get_current_id(cls) -> Optional[str]:
        """Return the id of the container active in this context, or None."""
        return cls._container_id_var.get()

    @classmethod
    def all_containers(cls) -> Dict[str, "PicoContainer"]:
        """Return a snapshot copy of the registry of live containers."""
        return dict(cls._container_registry)

    def activate(self) -> contextvars.Token:
        """Make this container current; pass the returned token to :meth:`deactivate`."""
        return PicoContainer._container_id_var.set(self.container_id)

    def deactivate(self, token: contextvars.Token) -> None:
        """Restore whichever container was current before the matching :meth:`activate`."""
        PicoContainer._container_id_var.reset(token)

    @contextmanager
    def as_current(self):
        """Context manager that makes this container current for the ``with`` body."""
        token = self.activate()
        try:
            yield self
        finally:
            self.deactivate(token)

    def attach_locator(self, locator: ComponentLocator) -> None:
        """Attach the locator whose metadata drives scopes, aliases and cleanup."""
        self._locator = locator

    def _cache_for(self, key: KeyT):
        """Return the cache for the scope declared in ``key``'s metadata (default ``"singleton"``)."""
        md = self._locator._metadata.get(key) if self._locator else None
        sc = md.scope if md else "singleton"
        return self._caches.for_scope(self.scopes, sc)

    def has(self, key: KeyT) -> bool:
        """Return True if ``key`` is cached in its scope or has a provider."""
        cache = self._cache_for(key)
        return cache.get(key) is not None or self._factory.has(key)

    # ------------------------------------------------------------------
    # Resolution helpers shared by get() and aget()
    # ------------------------------------------------------------------

    def _lookup_cached(self, key: KeyT, cache) -> Any:
        """Return the cached instance for ``key`` (recording a hit), or None."""
        cached = cache.get(key)
        if cached is not None:
            self.context.cache_hit_count += 1
            for o in self._observers:
                o.on_cache_hit(key)
        return cached

    def _enter_resolution(self, key: KeyT) -> Tuple[contextvars.Token, contextvars.Token]:
        """Push ``key`` on the resolve chain and activate this container.

        Raises:
            CircularDependencyError: if ``key`` is already being resolved.
        """
        chain = _resolve_chain.get()
        if key in chain:
            raise CircularDependencyError(list(chain), key)
        return _resolve_chain.set(tuple(chain) + (key,)), self.activate()

    def _exit_resolution(self, tokens: Tuple[contextvars.Token, contextvars.Token]) -> None:
        """Undo :meth:`_enter_resolution`."""
        token_chain, token_container = tokens
        _resolve_chain.reset(token_chain)
        self.deactivate(token_container)

    def _find_alias(self, key: KeyT) -> Any:
        """Find a registered key that can stand in for an unbound ``key``.

        A type resolves to a registered subtype (see :meth:`_resolve_type_key`);
        a string resolves to the component whose metadata ``pico_name`` equals it.
        """
        if isinstance(key, type):
            return self._resolve_type_key(key)
        if isinstance(key, str) and self._locator:
            for k, md in self._locator._metadata.items():
                if md.pico_name == key:
                    return k
        return None

    def _bind_alias_if_missing(self, key: KeyT) -> None:
        """If ``key`` has no provider, bind one that delegates to its alias."""
        if self._factory.has(key):
            return
        alt = self._find_alias(key)
        if alt is None:
            return

        def _alias_provider(a=alt):
            return self.get(a)

        # Lets aget() recognise the alias and resolve the target asynchronously.
        _alias_provider._pico_alias_of = alt
        self._factory.bind(key, _alias_provider)

    def _finish_resolution(self, key: KeyT, cache, instance: Any, started: float) -> Any:
        """Wrap, cache, count and report a freshly created instance; return it."""
        import time as _tm

        instance = self._maybe_wrap_with_aspects(key, instance)
        cache.put(key, instance)
        self.context.resolve_count += 1
        took_ms = (_tm.perf_counter() - started) * 1000
        for o in self._observers:
            o.on_resolve(key, took_ms)
        return instance

    @overload
    def get(self, key: type) -> Any: ...
    @overload
    def get(self, key: str) -> Any: ...
    def get(self, key: KeyT) -> Any:
        """Resolve ``key`` synchronously, returning the scope-cached instance when present.

        Raises:
            CircularDependencyError: ``key`` is already being resolved in this context.
            ProviderNotFoundError: propagated unchanged.
            ComponentCreationError: wraps any other error raised by the provider.
        """
        import time as _tm

        cache = self._cache_for(key)
        cached = self._lookup_cached(key, cache)
        if cached is not None:
            return cached
        started = _tm.perf_counter()
        tokens = self._enter_resolution(key)
        try:
            self._bind_alias_if_missing(key)
            provider = self._factory.get(key)
            try:
                instance = provider()
            except ProviderNotFoundError:
                raise
            except Exception as e:
                raise ComponentCreationError(key, e) from e
            return self._finish_resolution(key, cache, instance, started)
        finally:
            self._exit_resolution(tokens)

    async def aget(self, key: KeyT) -> Any:
        """Resolve ``key``, awaiting the provider result when it is awaitable.

        Raises the same errors as :meth:`get`.
        """
        import time as _tm

        cache = self._cache_for(key)
        cached = self._lookup_cached(key, cache)
        if cached is not None:
            return cached
        started = _tm.perf_counter()
        tokens = self._enter_resolution(key)
        try:
            self._bind_alias_if_missing(key)
            provider = self._factory.get(key)
            alias = getattr(provider, "_pico_alias_of", None)
            try:
                if alias is not None:
                    # Calling the alias provider would go through the sync get(),
                    # which caches an un-awaited coroutine for async targets.
                    instance = await self.aget(alias)
                else:
                    instance = provider()
                    if inspect.isawaitable(instance):
                        instance = await instance
            except ProviderNotFoundError:
                raise
            except Exception as e:
                raise ComponentCreationError(key, e) from e
            return self._finish_resolution(key, cache, instance, started)
        finally:
            self._exit_resolution(tokens)

    def _resolve_type_key(self, key: type):
        """Return the registered key whose provided type is a strict subclass of ``key``.

        Candidates marked ``primary`` win; otherwise the first match in metadata
        order is used. Returns None when there is no locator or no candidate.
        """
        if not self._locator:
            return None
        cands: List[Tuple[bool, Any]] = []
        for k, md in self._locator._metadata.items():
            typ = md.provided_type or md.concrete_class
            if not isinstance(typ, type):
                continue
            try:
                if typ is not key and issubclass(typ, key):
                    cands.append((md.primary, k))
            except Exception:
                continue
        if not cands:
            return None
        prim = [k for is_p, k in cands if is_p]
        return prim[0] if prim else cands[0][1]

    def _maybe_wrap_with_aspects(self, key, instance: Any) -> Any:
        """Wrap ``instance`` in a proxy if any of its class's functions carry interceptors."""
        if isinstance(instance, UnifiedComponentProxy):
            return instance
        cls = type(instance)
        for _, fn in inspect.getmembers(
            cls,
            predicate=lambda m: inspect.isfunction(m) or inspect.ismethod(m) or inspect.iscoroutinefunction(m),
        ):
            if getattr(fn, "_pico_interceptors_", None):
                return UnifiedComponentProxy(container=self, target=instance)
        return instance

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    def _call_with_resolved_args(self, method) -> Any:
        """Call ``method`` with arguments resolved from this container."""
        from .api import _resolve_args

        return method(**_resolve_args(method, self))

    @staticmethod
    def _cleanup_methods(obj: Any) -> List[Any]:
        """Return the bound methods of ``obj`` whose PICO_META has ``cleanup`` set."""
        return [
            m
            for _, m in inspect.getmembers(obj, predicate=inspect.ismethod)
            if getattr(m, PICO_META, {}).get("cleanup", False)
        ]

    def _cleanup_targets(self):
        """Yield each cached object, then each factory-class instance, at most once.

        One instance can be cached under several keys (an alias key and its
        concrete key), and a factory class resolved through the container is
        both cached and listed in the locator metadata. The identity check
        keeps cleanup hooks from running more than once per object.
        """
        # id -> object: holding the object keeps its id from being reused.
        visited: Dict[int, Any] = {}
        for _, obj in self._caches.all_items():
            if id(obj) not in visited:
                visited[id(obj)] = obj
                yield obj
        if not self._locator:
            return
        seen_factories = set()
        for md in self._locator._metadata.values():
            fc = md.factory_class
            if not fc or fc in seen_factories:
                continue
            seen_factories.add(fc)
            inst = self.get(fc) if self._factory.has(fc) else fc()
            if id(inst) not in visited:
                visited[id(inst)] = inst
                yield inst

    def cleanup_all(self) -> None:
        """Run every ``cleanup``-marked method on cached objects and factory instances."""
        for obj in self._cleanup_targets():
            for m in self._cleanup_methods(obj):
                self._call_with_resolved_args(m)

    def activate_scope(self, name: str, scope_id: Any):
        """Activate scope ``name`` with ``scope_id``; returns the scope manager's token."""
        return self.scopes.activate(name, scope_id)

    def deactivate_scope(self, name: str, token: Optional[contextvars.Token]) -> None:
        """Deactivate scope ``name`` using the token from :meth:`activate_scope`."""
        self.scopes.deactivate(name, token)

    def info(self, msg: str) -> None:
        """Log ``msg`` at INFO, prefixed with the first 8 chars of the container id."""
        LOGGER.info(f"[{self.container_id[:8]}] {msg}")

    @contextmanager
    def scope(self, name: str, scope_id: Any):
        """Context manager that activates scope ``name`` for the ``with`` body."""
        tok = self.activate_scope(name, scope_id)
        try:
            yield self
        finally:
            self.deactivate_scope(name, tok)

    def health_check(self) -> Dict[str, bool]:
        """Call every ``health_check``-marked callable on cached objects.

        Returns a mapping ``"<key name>.<member>" -> bool``; a check that
        raises is reported as False.
        """
        out: Dict[str, bool] = {}
        for k, obj in self._caches.all_items():
            for name, m in inspect.getmembers(obj, predicate=callable):
                if getattr(m, PICO_META, {}).get("health_check", False):
                    try:
                        out[f"{getattr(k,'__name__',k)}.{name}"] = bool(m())
                    except Exception:
                        out[f"{getattr(k,'__name__',k)}.{name}"] = False
        return out

    async def cleanup_all_async(self) -> None:
        """Async variant of :meth:`cleanup_all`; also closes cached EventBus instances.

        Cleanup methods that return awaitables are awaited. Failures while
        closing event buses are logged rather than raised (best effort).
        """
        for obj in self._cleanup_targets():
            for m in self._cleanup_methods(obj):
                res = self._call_with_resolved_args(m)
                if inspect.isawaitable(res):
                    await res
        try:
            from .event_bus import EventBus

            closed = set()
            for _, obj in self._caches.all_items():
                if isinstance(obj, EventBus) and id(obj) not in closed:
                    closed.add(id(obj))
                    await obj.aclose()
        except Exception:
            LOGGER.warning("EventBus shutdown during async cleanup failed", exc_info=True)

    def stats(self) -> Dict[str, Any]:
        """Return id, profiles, uptime, resolve/cache-hit counts and component count."""
        import time as _t

        resolves = self.context.resolve_count
        hits = self.context.cache_hit_count
        total = resolves + hits
        return {
            "container_id": self.container_id,
            "profiles": self.context.profiles,
            "uptime_seconds": _t.time() - self.context.created_at,
            "total_resolves": resolves,
            "cache_hits": hits,
            "cache_hit_rate": (hits / total) if total > 0 else 0.0,
            "registered_components": len(self._locator._metadata) if self._locator else 0,
        }

    def shutdown(self) -> None:
        """Run synchronous cleanup and drop this container from the registry."""
        self.cleanup_all()
        PicoContainer._container_registry.pop(self.container_id, None)
168
305
 
pico_ioc/event_bus.py ADDED
@@ -0,0 +1,224 @@
1
+ # src/pico_ioc/event_bus.py
2
+ import asyncio
3
+ import inspect
4
+ import logging
5
+ import threading
6
+ from dataclasses import dataclass, field
7
+ from enum import Enum, auto
8
+ from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Tuple, Type
9
+ from .api import factory, provides, configure, cleanup, primary
10
+ from .exceptions import EventBusClosedError, EventBusError, EventBusQueueFullError, EventBusHandlerError
11
+
12
# Module logger; EventBus._handle_error reports handler failures through it.
log = logging.getLogger(name=__name__)
13
+
14
class ExecPolicy(Enum):
    """How ``EventBus.publish`` runs a subscriber.

    INLINE: call / await the handler directly.
    THREADPOOL: run a synchronous handler in the loop's default executor.
    TASK: schedule a coroutine handler as a task and await it after the others.
    """

    INLINE = 1
    THREADPOOL = 2
    TASK = 3
18
+
19
class ErrorPolicy(Enum):
    """What ``EventBus`` does with a handler error: log it, or re-raise it."""

    LOG = 1
    RAISE = 2
22
+
23
class Event:
    """Marker base class for events dispatched through ``EventBus``."""
24
+
25
@dataclass(order=True)
class _Subscriber:
    """One registered handler; sorting yields highest priority first.

    Only ``sort_index`` participates in comparisons. Because ``list.sort`` is
    stable, equal priorities keep registration order. NOTE: equality also
    compares only ``sort_index``, so two different subscribers with the same
    priority compare equal.
    """

    # Derived in __post_init__; the sole comparison key.
    sort_index: int = field(init=False, repr=False, compare=True)
    priority: int = field(compare=False)
    callback: Callable[[Event], Any] | Callable[[Event], Awaitable[Any]] = field(compare=False)
    policy: ExecPolicy = field(compare=False)
    once: bool = field(compare=False)

    def __post_init__(self):
        # Negated so that ascending sort order means descending priority.
        as_int = int(self.priority)
        self.sort_index = -as_int
34
+
35
class EventBus:
    """In-process publish/subscribe bus keyed by the exact event type.

    Subscribers of ``type(event)`` run in descending priority order. Events can
    be delivered immediately with :meth:`publish` / :meth:`publish_sync`, or
    queued with :meth:`post` for the background worker started by
    :meth:`start_worker`.
    """

    def __init__(
        self,
        *,
        default_exec_policy: ExecPolicy = ExecPolicy.INLINE,
        error_policy: ErrorPolicy = ErrorPolicy.LOG,
        max_queue_size: int = 0,
    ):
        """Create a bus.

        Args:
            default_exec_policy: Policy for subscribers registered without one.
            error_policy: Whether handler errors are logged or re-raised.
            max_queue_size: Capacity of the worker queue (0 = unbounded). A
                negative value defers queue creation to :meth:`start_worker`.
        """
        self._subs: Dict[Type[Event], List[_Subscriber]] = {}
        self._default_policy = default_exec_policy
        self._error_policy = error_policy
        self._queue: Optional[asyncio.Queue[Event]] = asyncio.Queue(max_queue_size) if max_queue_size >= 0 else None
        self._worker_task: Optional[asyncio.Task] = None
        self._worker_loop: Optional[asyncio.AbstractEventLoop] = None
        self._closed = False
        self._lock = threading.RLock()
        # Strong references to tasks scheduled by publish_sync(); the event loop
        # keeps only weak references, so unreferenced tasks may be collected.
        self._background_tasks: set = set()

    def subscribe(
        self,
        event_type: Type[Event],
        fn: Callable[[Event], Any] | Callable[[Event], Awaitable[Any]],
        *,
        priority: int = 0,
        policy: Optional[ExecPolicy] = None,
        once: bool = False,
    ) -> None:
        """Register ``fn`` for ``event_type``; registering the same callable again is a no-op.

        Raises:
            EventBusClosedError: if the bus has been closed.
        """
        with self._lock:
            if self._closed:
                raise EventBusClosedError()
            sub = _Subscriber(priority=priority, callback=fn, policy=policy or self._default_policy, once=once)
            lst = self._subs.setdefault(event_type, [])
            # `==`, not `is`: every attribute access creates a new bound-method
            # object, so identity never matches the same bound method twice.
            if any(s.callback == fn for s in lst):
                return
            lst.append(sub)
            lst.sort()

    def unsubscribe(self, event_type: Type[Event], fn: Callable[[Event], Any] | Callable[[Event], Awaitable[Any]]) -> None:
        """Remove every registration of ``fn`` (compared with ``==``) for ``event_type``."""
        with self._lock:
            lst = self._subs.get(event_type, [])
            self._subs[event_type] = [s for s in lst if s.callback != fn]

    def publish_sync(self, event: Event) -> None:
        """Publish from synchronous code.

        With no running loop in this thread, the event is published to
        completion via ``asyncio.run``. Inside a running loop it is scheduled
        as a task and this call returns immediately.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(self.publish(event))
            return
        task = loop.create_task(self.publish(event))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def publish(self, event: Event) -> None:
        """Deliver ``event`` to the subscribers of ``type(event)``.

        Coroutine handlers with the TASK policy run concurrently and are awaited
        after the others. A synchronous handler that returns an awaitable has it
        awaited. ``once`` subscribers that completed without error are removed.

        Raises:
            EventBusClosedError: if the bus has been closed.
            EventBusError: only under ``ErrorPolicy.RAISE``.
        """
        if self._closed:
            raise EventBusClosedError()
        event_type = type(event)
        with self._lock:
            subs = list(self._subs.get(event_type, []))
        to_remove: List[_Subscriber] = []
        pending: List[asyncio.Task] = []
        for sub in subs:
            try:
                cb = sub.callback
                if inspect.iscoroutinefunction(cb):
                    if sub.policy is ExecPolicy.TASK:
                        pending.append(asyncio.create_task(cb(event)))
                    else:
                        await cb(event)
                else:
                    if sub.policy is ExecPolicy.THREADPOOL:
                        loop = asyncio.get_running_loop()
                        result = await loop.run_in_executor(None, cb, event)
                    else:
                        result = cb(event)
                    # Async callables that iscoroutinefunction() misses (e.g.
                    # objects with an async __call__) return an awaitable here.
                    if inspect.isawaitable(result):
                        await result
                if sub.once:
                    to_remove.append(sub)
            except Exception as ex:
                self._handle_error(
                    EventBusHandlerError(event_type.__name__, getattr(sub.callback, "__name__", "<callback>"), ex)
                )
        if pending:
            # Collect every task outcome instead of surfacing only the first failure.
            results = await asyncio.gather(*pending, return_exceptions=True)
            for res in results:
                if isinstance(res, Exception):
                    self._handle_error(EventBusError(f"Unhandled error awaiting event tasks: {res}"), exc_info=res)
        if to_remove:
            # Compare by identity: _Subscriber equality looks only at sort_index,
            # so `in` would also drop other subscribers with the same priority.
            removed_ids = {id(s) for s in to_remove}
            with self._lock:
                lst = self._subs.get(event_type, [])
                self._subs[event_type] = [s for s in lst if id(s) not in removed_ids]

    async def start_worker(self) -> None:
        """Start the background task that publishes events queued by :meth:`post`.

        Does nothing if a worker is already running.

        Raises:
            EventBusClosedError: if the bus has been closed.
        """
        if self._closed:
            raise EventBusClosedError()
        if self._worker_task:
            return
        if self._queue is None:
            self._queue = asyncio.Queue()
        queue = self._queue
        self._worker_loop = asyncio.get_running_loop()

        async def _worker():
            while True:
                evt = await queue.get()
                if evt is None:  # sentinel enqueued by stop_worker()
                    queue.task_done()
                    break
                try:
                    await self.publish(evt)
                except Exception:
                    # Keep draining: a dead worker would leave stop_worker()
                    # waiting forever on queue.join().
                    log.exception("EventBus worker failed to publish %r", evt)
                finally:
                    queue.task_done()

        self._worker_task = asyncio.create_task(_worker())

    async def stop_worker(self) -> None:
        """Drain the queue, stop the worker and wait for it. Safe to call repeatedly."""
        task, queue = self._worker_task, self._queue
        if not (task and queue and self._worker_loop):
            return
        # Detach first so a second or concurrent call (e.g. two aclose()) does
        # not enqueue another sentinel that no worker will ever consume. post()
        # now refuses new events that would land behind the sentinel.
        self._worker_task = None
        self._worker_loop = None
        if not task.done():
            await queue.put(None)
            await queue.join()
        await task

    def post(self, event: Event) -> None:
        """Enqueue ``event`` for the worker; callable from any thread.

        Raises:
            EventBusClosedError: if the bus has been closed.
            EventBusQueueFullError: if the bounded queue is full.
            EventBusError: if no worker loop is running.
        """
        if self._closed:
            raise EventBusClosedError()
        if self._queue is None:
            raise EventBusError("Worker queue not initialized. Call start_worker().")
        loop = self._worker_loop
        if not (loop and loop.is_running()):
            raise EventBusError("Worker queue not initialized or loop not running. Call start_worker().")
        try:
            on_worker_loop = asyncio.get_running_loop() is loop
        except RuntimeError:
            on_worker_loop = False
        if on_worker_loop:
            try:
                self._queue.put_nowait(event)
            except asyncio.QueueFull:
                raise EventBusQueueFullError()
            return
        # From another thread the put runs later on the worker loop, where a
        # QueueFull could never reach this caller; check capacity up-front.
        # Best effort: the worker may change the size in between.
        if self._queue.full():
            raise EventBusQueueFullError()
        loop.call_soon_threadsafe(self._queue.put_nowait, event)

    async def aclose(self) -> None:
        """Stop the worker, mark the bus closed and drop all subscriptions."""
        await self.stop_worker()
        with self._lock:
            self._closed = True
            self._subs.clear()

    def _handle_error(self, ex: EventBusError, exc_info: Any = True) -> None:
        """Raise ``ex`` or log it, according to the error policy.

        ``exc_info`` is forwarded to the logger. The default ``True`` logs the
        exception currently being handled, like ``log.exception``.
        """
        if self._error_policy is ErrorPolicy.RAISE:
            raise ex
        if self._error_policy is ErrorPolicy.LOG:
            log.error("%s", ex, exc_info=exc_info)
189
+
190
def subscribe(event_type: Type[Event], *, priority: int = 0, policy: ExecPolicy = ExecPolicy.INLINE, once: bool = False):
    """Decorator that records an event subscription on the decorated callable.

    Each application appends ``(event_type, int(priority), policy, bool(once))``
    to the callable's ``_pico_subscriptions_`` tuple and returns the callable
    unchanged. ``AutoSubscriberMixin`` reads that tuple to register handlers.
    """

    def _decorate(fn: Callable[[Event], Any] | Callable[[Event], Awaitable[Any]]):
        previous: Iterable[Tuple[Type[Event], int, ExecPolicy, bool]] = getattr(fn, "_pico_subscriptions_", ())
        entry = (event_type, int(priority), policy, bool(once))
        setattr(fn, "_pico_subscriptions_", tuple(previous) + (entry,))
        return fn

    return _decorate
198
+
199
class AutoSubscriberMixin:
    """Mixin that subscribes ``@subscribe``-decorated members to the event bus.

    The ``@configure`` hook receives an ``EventBus`` and registers every callable
    member that carries ``_pico_subscriptions_`` entries.
    """

    @configure
    def _pico_autosubscribe(self, event_bus: EventBus) -> None:
        members = inspect.getmembers(self, predicate=callable)
        for _member_name, handler in members:
            entries: Iterable[Tuple[Type[Event], int, ExecPolicy, bool]] = getattr(handler, "_pico_subscriptions_", ())
            for event_type, prio, exec_policy, one_shot in entries:
                event_bus.subscribe(event_type, handler, priority=prio, policy=exec_policy, once=one_shot)
206
+
207
@factory
@primary
class PicoEventBusProvider:
    """Factory providing the default ``EventBus`` and closing it on container cleanup."""

    # Strong references to close tasks scheduled from inside a running loop.
    # Kept at class level because the factory instance may be discarded right
    # after cleanup, and the event loop holds only weak references to tasks.
    _pending_close_tasks = set()

    @provides(EventBus)
    def build(self) -> EventBus:
        """Create a bus with default policies."""
        return EventBus()

    @cleanup
    def shutdown(self, event_bus: EventBus) -> None:
        """Close ``event_bus``.

        With no running loop in this thread the bus is closed synchronously.
        Inside a running loop this cannot block, so the close is scheduled as a
        task and this method returns before the task completes.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(event_bus.aclose())
            return
        # get_running_loop() only returns a running loop, so no else-branch is needed.
        task = loop.create_task(event_bus.aclose())
        pending = PicoEventBusProvider._pending_close_tasks
        pending.add(task)
        task.add_done_callback(pending.discard)
224
+