pico-ioc 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pico_ioc/config.py ADDED
@@ -0,0 +1,332 @@
1
+ # src/pico_ioc/config.py
2
+ from __future__ import annotations
3
+
4
+ import os, json, configparser, pathlib
5
+ from dataclasses import is_dataclass, fields, MISSING
6
+ from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Protocol
7
+
8
+ # ---- Flags & metadata on classes / fields ----
9
# Names of the attributes this module stamps onto decorated classes.
_CONFIG_FLAG = "_pico_is_config_component"  # set to True by @config_component
_CONFIG_PREFIX = "_pico_config_prefix"  # prefix string given to @config_component
_FIELD_META = "_pico_config_field_meta"  # dict: field name -> FieldSpec
12
+
13
+ # ---- Source protocol & implementations ----
14
+
15
class ConfigSource(Protocol):
    """Structural interface shared by every configuration source."""

    def keys(self) -> Iterable[str]:
        """Return the keys this source exposes."""
        ...

    def get(self, key: str) -> Optional[str]:
        """Return the raw value stored under *key* (``Optional[str]``)."""
        ...
18
+
19
class EnvSource:
    """Configuration source backed by the process environment."""

    def __init__(self, prefix: str = ""):
        # Any falsy prefix (None, "") collapses to the empty string.
        self.prefix = prefix if prefix else ""

    def get(self, key: str) -> Optional[str]:
        """Return ``PREFIX + key`` from the environment, else plain ``key``, else None."""
        prefixed = os.environ.get(self.prefix + key)
        if prefixed is None:
            return os.environ.get(key)
        return prefixed

    def keys(self) -> Iterable[str]:
        """Return the raw environment variable names (the prefix is not applied)."""
        return os.environ.keys()
31
+
32
class FileSource:
    """Configuration source that reads one file at construction time.

    The file text is tried as JSON, then INI, then dotenv (``KEY=VALUE`` per
    line) and finally YAML (when the ``yaml`` module can be imported). The
    first format that parses supplies the flattened key/value cache; dotenv
    is only accepted when it yields at least one entry.
    """

    def __init__(self, path: os.PathLike[str] | str, optional: bool = False):
        self.path = str(path)
        self.optional = bool(optional)
        self._cache: Dict[str, Any] = {}
        self._load_once()

    @staticmethod
    def _parse_json(text: str) -> Dict[str, Any]:
        """Flatten *text* parsed as JSON."""
        return _flatten_obj(json.loads(text))

    @staticmethod
    def _parse_ini(text: str) -> Dict[str, Any]:
        """Flatten *text* parsed as INI; DEFAULT entries become root-level keys."""
        parser = configparser.ConfigParser()
        parser.read_string(text)
        merged = {section: dict(parser.items(section)) for section in parser.sections()}
        merged.update(dict(parser.defaults()))
        return _flatten_obj(merged)

    @staticmethod
    def _parse_dotenv(text: str) -> Dict[str, Any]:
        """Flatten ``KEY=VALUE`` lines, skipping blanks, ``#`` comments and lines without ``=``."""
        pairs = {}
        for raw_line in text.splitlines():
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            name, _, value = entry.partition("=")
            pairs[name.strip()] = _strip_quotes(value.strip())
        return _flatten_obj(pairs)

    @staticmethod
    def _parse_yaml(text: str) -> Dict[str, Any]:
        """Flatten *text* parsed with ``yaml.safe_load`` (empty documents become ``{}``)."""
        import yaml  # type: ignore
        return _flatten_obj(yaml.safe_load(text) or {})

    def _load_once(self):
        """Populate ``self._cache`` from the file.

        Raises:
            FileNotFoundError: the file is missing and the source is not optional.
            ValueError: the YAML fallback failed and the source is not optional.
        """
        target = pathlib.Path(self.path)
        if not target.exists():
            if not self.optional:
                raise FileNotFoundError(self.path)
            self._cache = {}
            return
        text = target.read_text(encoding="utf-8")

        # JSON and INI: the first parser that succeeds wins outright.
        for parse in (self._parse_json, self._parse_ini):
            try:
                flattened = parse(text)
            except Exception:
                continue
            self._cache = flattened
            return

        # dotenv never fails on arbitrary text, so only accept a non-empty result.
        try:
            self._cache = self._parse_dotenv(text)
        except Exception:
            pass
        else:
            if self._cache:
                return

        # YAML is the last resort; a missing yaml module counts as a parse failure.
        try:
            self._cache = self._parse_yaml(text)
        except Exception:
            if not self.optional:
                raise ValueError(f"Unrecognized file format: {self.path}")
            self._cache = {}

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for *key* converted with ``str``, or None when absent."""
        raw = self._cache.get(key)
        return str(raw) if raw is not None else None

    def keys(self) -> Iterable[str]:
        """Return the flattened keys loaded from the file."""
        return self._cache.keys()
101
+
102
+ # ---- Field specs (overrides) ----
103
+
104
class FieldSpec:
    """Per-field override describing where a configuration value is looked up.

    Attributes (all keyword-only at construction):
        sources: ordered source kinds to consult.
        keys: candidate keys to try.
        default: fallback value.
        path_is_dot: true when keys are dotted paths for structured files.
    """

    __slots__ = ("sources", "keys", "default", "path_is_dot")

    def __init__(self, *, sources: Tuple[str, ...], keys: Tuple[str, ...], default: Any, path_is_dot: bool):
        self.sources, self.keys = sources, keys
        self.default, self.path_is_dot = default, path_is_dot
111
+
112
class _ValueFactory:
    """Deferred FieldSpec builder returned by ``Value[...]``."""

    def __init__(self, key: str, default: Any):
        self.key, self.default = key, default

    def __call__(self, *, sources: Tuple[str, ...] = ("env", "file")):
        """Build the FieldSpec; *sources* sets the lookup order (env before file by default)."""
        return FieldSpec(
            sources=tuple(sources),
            keys=(self.key,),
            default=self.default,
            path_is_dot=False,
        )


class _ValueSentinel:
    """Subscriptable entry point: ``Value["KEY"]`` or ``Value["KEY", default]``."""

    def __getitem__(self, key_default: str | Tuple[str, Any], /):
        if not isinstance(key_default, tuple):
            # Bare key: no default was supplied.
            return _ValueFactory(key_default, MISSING)
        key, default = key_default
        return _ValueFactory(key, default)


Value = _ValueSentinel()
128
+
129
class _EnvSentinel:
    """Subscriptable helper: ``Env["KEY"]`` or ``Env["KEY", default]`` builds an env-only FieldSpec."""

    def __getitem__(self, key_default: str | Tuple[str, Any], /):
        if isinstance(key_default, tuple):
            key, default = key_default
        else:
            key, default = key_default, MISSING
        return FieldSpec(sources=("env",), keys=(key,), default=default, path_is_dot=False)


Env = _EnvSentinel()
134
+
135
class _FileSentinel:
    """Subscriptable helper: ``File["KEY"]`` or ``File["KEY", default]`` builds a file-only FieldSpec."""

    def __getitem__(self, key_default: str | Tuple[str, Any], /):
        if isinstance(key_default, tuple):
            key, default = key_default
        else:
            key, default = key_default, MISSING
        return FieldSpec(sources=("file",), keys=(key,), default=default, path_is_dot=False)


File = _FileSentinel()
140
+
141
class _PathSentinel:
    """Namespace for dotted-path lookups; ``Path.file["a.b"]`` targets file sources."""

    class _FilePath:
        """Subscriptable helper building a file-only FieldSpec whose key is a dotted path."""

        def __getitem__(self, key_default: str | Tuple[str, Any], /):
            if isinstance(key_default, tuple):
                key, default = key_default
            else:
                key, default = key_default, MISSING
            return FieldSpec(sources=("file",), keys=(key,), default=default, path_is_dot=True)

    file = _FilePath()


Path = _PathSentinel()
148
+
149
+ # ---- Class decorator ----
150
+
151
def config_component(*, prefix: str = ""):
    """Return a class decorator that marks a class as a config component.

    The decorator sets the config flag, records *prefix* (falsy values become
    ``""``) and makes sure the class exposes a field-spec dict.
    """
    normalized = prefix if prefix else ""

    def dec(cls):
        for attr, value in ((_CONFIG_FLAG, True), (_CONFIG_PREFIX, normalized)):
            setattr(cls, attr, value)
        # hasattr also sees an inherited dict, in which case nothing is added.
        if not hasattr(cls, _FIELD_META):
            setattr(cls, _FIELD_META, {})
        return cls

    return dec
159
+
160
def is_config_component(cls: type) -> bool:
    """Tell whether *cls* carries a truthy config-component flag."""
    marker = getattr(cls, _CONFIG_FLAG, False)
    return True if marker else False
162
+
163
+ # ---- Registry / resolution ----
164
+
165
class ConfigRegistry:
    """Ordered collection of config sources with first-match key resolution."""

    def __init__(self, sources: Sequence[ConfigSource]):
        # A falsy argument (None, empty sequence) yields an empty registry.
        self.sources = tuple(sources) if sources else ()

    def resolve(self, keys: Iterable[str]) -> Optional[str]:
        """Return the first non-None value found, or None.

        Keys are tried in order; each key is offered to every source (in
        registration order) before the next key is considered.
        """
        for key in keys:
            answers = (src.get(key) for src in self.sources)
            hit = next((answer for answer in answers if answer is not None), None)
            if hit is not None:
                return hit
        return None
178
+
179
def register_field_spec(cls: type, name: str, spec: FieldSpec) -> None:
    """Record *spec* as the override for field *name* on *cls*.

    The metadata dict is kept per class. When *cls* only inherits a dict from
    a base class, the inherited entries are copied first, so registering on a
    subclass never alters the base class (or its other subclasses).
    """
    own = vars(cls).get(_FIELD_META)
    if own is None:
        # Copy instead of mutating whatever a base class owns.
        own = dict(getattr(cls, _FIELD_META, None) or {})
    own[name] = spec
    setattr(cls, _FIELD_META, own)
183
+
184
def build_component_instance(cls: type, registry: ConfigRegistry) -> Any:
    """Instantiate config component *cls* with values resolved from *registry*.

    For every dataclass field (or ``__init__`` parameter of a plain class) the
    value comes from the registered FieldSpec override when there is one,
    otherwise from ``PREFIX + NAME.upper()`` and then ``NAME.upper()`` across
    all registry sources. A value that resolves to None falls back to the
    declared default; the result is coerced to the annotated type.

    Raises:
        NameError: a field resolved to None and declares no default.
    """
    import inspect

    prefix = getattr(cls, _CONFIG_PREFIX, "")
    overrides: Dict[str, FieldSpec] = getattr(cls, _FIELD_META, {}) or {}

    def lookup(name: str) -> Any:
        # An explicit FieldSpec wins over the automatic PREFIX+NAME / NAME lookup.
        spec = overrides.get(name)
        if spec:
            return _resolve_with_spec(spec, registry)
        upper = name.upper()
        return registry.resolve((prefix + upper, upper))

    kwargs = {}

    if is_dataclass(cls):
        # Under postponed annotations ``f.type`` is a plain string; resolved
        # hints let coercion see the real type. Falls back to ``f.type``.
        hints = _get_type_hints_safe(cls)
        for f in fields(cls):
            if not f.init:
                # init=False fields are not constructor parameters.
                continue
            val = lookup(f.name)
            if val is None:
                if f.default is not MISSING:
                    val = f.default
                elif f.default_factory is not MISSING:  # type: ignore
                    val = f.default_factory()  # type: ignore
                else:
                    raise NameError(f"Missing config for field {cls.__name__}.{f.name}")
            kwargs[f.name] = _coerce_type(val, hints.get(f.name, f.type))
        return cls(**kwargs)

    # Non-dataclass: drive resolution from the __init__ signature.
    sig = inspect.signature(cls.__init__)
    hints = _get_type_hints_safe(cls.__init__, owner=cls)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    for pname, par in sig.parameters.items():
        if pname == "self" or par.kind in variadic:
            continue
        val = lookup(pname)
        if val is None:
            if par.default is inspect.Parameter.empty:
                raise NameError(f"Missing config for field {cls.__name__}.{pname}")
            val = par.default
        kwargs[pname] = _coerce_type(val, hints.get(pname, par.annotation))
    return cls(**kwargs)
227
+
228
+ # ---- helpers ----
229
+
230
def _resolve_with_spec(spec: FieldSpec, registry: ConfigRegistry) -> Any:
    """Resolve a value according to *spec*.

    Source kinds are consulted in ``spec.sources`` order and every key is
    tried for each kind. Unknown kinds contribute nothing. When nothing is
    found the spec default is returned, or None if the spec has no default.
    """

    def fetch(kind: str) -> Any:
        if kind == "env":
            return _resolve_from_sources(
                registry, spec.keys, predicate=lambda s: isinstance(s, EnvSource)
            )
        if kind != "file":
            return None
        if spec.path_is_dot:
            return _resolve_path_from_files(registry, spec.keys)
        return _resolve_from_sources(
            registry, spec.keys, predicate=lambda s: isinstance(s, FileSource)
        )

    for kind in spec.sources:
        found = fetch(kind)
        if found is not None:
            return found
    return spec.default if spec.default is not MISSING else None
245
+
246
def _resolve_from_sources(registry: ConfigRegistry, keys: Tuple[str, ...], predicate: Callable[[ConfigSource], bool]) -> Optional[str]:
    """Return the first non-None value for *keys* among sources accepted by *predicate*.

    Keys are tried in order; for each key the registry sources are scanned in
    order and only those for which ``predicate(src)`` is true are queried.
    """
    answers = (
        src.get(key)
        for key in keys
        for src in registry.sources
        if predicate(src)
    )
    return next((answer for answer in answers if answer is not None), None)
254
+
255
def _resolve_path_from_files(registry: ConfigRegistry, dotted_keys: Tuple[str, ...]) -> Optional[str]:
    """Return the first non-None value for *dotted_keys* among FileSource entries.

    FileSource already stores its data flattened under dotted keys, so each
    dotted key is looked up as-is; no path splitting is needed. Keys are
    tried in order and, for each key, file sources in registry order.
    """
    # The isinstance filter does not depend on the key, so compute it once.
    file_sources = [src for src in registry.sources if isinstance(src, FileSource)]
    for key in dotted_keys:
        for src in file_sources:
            value = src.get(key)
            if value is not None:
                return value
    return None
265
+
266
def _flatten_obj(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists/tuples into a single ``{dotted_key: leaf}`` dict.

    Dict keys and sequence indexes are joined with ``"."``. A leaf whose key
    has no dot (a top-level simple key) is additionally mirrored under its
    UPPERCASE name. A scalar passed with an empty *prefix* is stored under
    the empty key. Empty containers contribute no entries.
    """
    out: Dict[str, Any] = {}

    def walk(node: Any, path: str) -> None:
        if isinstance(node, dict):
            children = node.items()
        elif isinstance(node, (list, tuple)):
            children = enumerate(node)
        else:
            out[path] = node
            if path and "." not in path:
                # Convenience alias so env-style upper-case lookups hit file keys.
                out[path.upper()] = node
            return
        for part, child in children:
            walk(child, (path + "." + str(part)) if path else str(part))

    # Writing into one shared dict in traversal order keeps "later entry wins"
    # on key collisions without building a throwaway dict per node.
    walk(obj, prefix)
    return out
288
+
289
def _strip_quotes(s: str) -> str:
    """Remove one pair of matching surrounding quotes (``"`` or ``'``) from *s*.

    Strings shorter than two characters are returned unchanged: a lone quote
    character is both the start and the end of the string, and must not be
    mistaken for an empty quoted value.
    """
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
        return s[1:-1]
    return s
293
+
294
def _coerce_type(val: Any, ann: Any) -> Any:
    """Best-effort coercion of *val* to the basic type named by annotation *ann*.

    Handles ``int``, ``float`` and ``bool``, given either as the type itself,
    as its name in a string annotation (``"int"``), or wrapped in
    ``Optional[...]`` / ``X | None``. Any other annotation, and any value
    that fails to convert, is returned unchanged. None is returned as None.

    For ``bool`` the tokens 1/true/yes/y/on and 0/false/no/n/off are
    recognised case-insensitively; anything else falls back to ``bool(val)``.
    """
    if val is None:
        return None

    import types
    from typing import Union, get_args, get_origin

    # Postponed annotations arrive as plain strings; map the simple names.
    if isinstance(ann, str):
        ann = {"int": int, "float": float, "bool": bool}.get(ann.strip(), ann)

    # Optional[X] (or X | None): coerce to X, the single non-None member.
    origin = get_origin(ann)
    pep604_union = getattr(types, "UnionType", None)  # absent before Python 3.10
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        members = [arg for arg in get_args(ann) if arg is not type(None)]
        if len(members) == 1:
            ann = members[0]
            origin = get_origin(ann)

    target = origin or ann
    try:
        if target is int:
            return int(val)
        if target is float:
            return float(val)
        if target is bool:
            token = str(val).strip().lower()
            if token in ("1", "true", "yes", "y", "on"):
                return True
            if token in ("0", "false", "no", "n", "off"):
                return False
            return bool(val)
    except (TypeError, ValueError, OverflowError):
        # Deliberately best-effort: an unconvertible value is passed through.
        pass
    return val
313
+
314
def _get_type_hints_safe(fn, owner=None):
    """Return the resolved type hints of *fn*, or ``{}`` if resolution fails.

    Names are resolved against the globals of the module that defines *fn*
    and, when *owner* is given, against that object's attributes.
    """
    try:
        import inspect
        from typing import get_type_hints

        module = inspect.getmodule(fn)
        global_ns = getattr(module, "__dict__", {})
        local_ns = None if owner is None else vars(owner)
        return get_type_hints(fn, globalns=global_ns, localns=local_ns, include_extras=True)
    except Exception:
        return {}
324
+
325
+ # ---- Public API helpers to be imported by users ----
326
+
327
# Names re-exported as this module's public API.
__all__ = [
    "config_component",
    "EnvSource",
    "FileSource",
    "Env",
    "File",
    "Path",
    "Value",
    "ConfigRegistry",
    "register_field_spec",
    "is_config_component",
]
332
+
pico_ioc/container.py CHANGED
@@ -1,8 +1,10 @@
1
- # src/pico_ioc/container.py (Refactorizado)
1
+ # src/pico_ioc/container.py
2
2
  from __future__ import annotations
3
+
3
4
  import inspect
4
- from typing import Any, Dict, get_origin, get_args, Annotated, Sequence, Optional, Callable, Union, Tuple
5
+ from typing import Any, Dict, get_origin, get_args, Annotated
5
6
  import typing as _t
7
+
6
8
  from .proxy import IoCProxy
7
9
  from .interceptors import MethodInterceptor, ContainerInterceptor
8
10
  from .decorators import QUALIFIERS_KEY
@@ -10,7 +12,7 @@ from . import _state
10
12
 
11
13
 
12
14
  class Binder:
13
- def __init__(self, container: "PicoContainer"):
15
+ def __init__(self, container: PicoContainer):
14
16
  self._c = container
15
17
 
16
18
  def bind(self, key: Any, provider, *, lazy: bool, tags: tuple[str, ...] = ()):
@@ -24,14 +26,16 @@ class Binder:
24
26
 
25
27
 
26
28
  class PicoContainer:
27
- def __init__(self, providers: Dict[Any, Dict[str, Any]]):
28
- self._providers = providers
29
+ def __init__(self, providers: Dict[Any, Dict[str, Any]] | None = None):
30
+ self._providers = providers or {}
29
31
  self._singletons: Dict[Any, Any] = {}
30
32
  self._method_interceptors: tuple[MethodInterceptor, ...] = ()
31
33
  self._container_interceptors: tuple[ContainerInterceptor, ...] = ()
32
34
  self._active_profiles: tuple[str, ...] = ()
33
35
  self._seen_interceptor_types: set[type] = set()
34
36
 
37
+ # --- interceptors ---
38
+
35
39
  def add_method_interceptor(self, it: MethodInterceptor) -> None:
36
40
  t = type(it)
37
41
  if t in self._seen_interceptor_types:
@@ -45,9 +49,10 @@ class PicoContainer:
45
49
  return
46
50
  self._seen_interceptor_types.add(t)
47
51
  self._container_interceptors = self._container_interceptors + (it,)
48
-
52
+
53
+ # --- binding ---
54
+
49
55
  def binder(self) -> Binder:
50
- """Returns a binder for this container."""
51
56
  return Binder(self)
52
57
 
53
58
  def bind(self, key: Any, provider, *, lazy: bool, tags: tuple[str, ...] = ()):
@@ -61,21 +66,27 @@ class PicoContainer:
61
66
  meta["tags"] = tuple(tags) if tags else ()
62
67
  self._providers[key] = meta
63
68
 
69
+ # --- resolution ---
70
+
64
71
  def has(self, key: Any) -> bool:
65
72
  return key in self._providers
66
73
 
67
74
  def get(self, key: Any):
68
75
  if _state._scanning.get() and not _state._resolving.get():
69
76
  raise RuntimeError("re-entrant container access during scan")
77
+
70
78
  prov = self._providers.get(key)
71
79
  if prov is None:
72
80
  raise NameError(f"No provider found for key {key!r}")
81
+
73
82
  if key in self._singletons:
74
83
  return self._singletons[key]
75
84
 
76
85
  for ci in self._container_interceptors:
77
- try: ci.on_before_create(key)
78
- except Exception: pass
86
+ try:
87
+ ci.on_before_create(key)
88
+ except Exception:
89
+ pass
79
90
 
80
91
  tok = _state._resolving.set(True)
81
92
  try:
@@ -83,8 +94,10 @@ class PicoContainer:
83
94
  instance = prov["factory"]()
84
95
  except BaseException as exc:
85
96
  for ci in self._container_interceptors:
86
- try: ci.on_exception(key, exc)
87
- except Exception: pass
97
+ try:
98
+ ci.on_exception(key, exc)
99
+ except Exception:
100
+ pass
88
101
  raise
89
102
  finally:
90
103
  _state._resolving.reset(tok)
@@ -103,11 +116,15 @@ class PicoContainer:
103
116
  self._singletons[key] = instance
104
117
  return instance
105
118
 
119
+ # --- lifecycle ---
120
+
106
121
  def eager_instantiate_all(self):
107
122
  for key, prov in list(self._providers.items()):
108
123
  if not prov["lazy"]:
109
124
  self.get(key)
110
125
 
126
+ # --- helpers for multiples ---
127
+
111
128
  def get_all(self, base_type: Any):
112
129
  return tuple(self._resolve_all_for_base(base_type, qualifiers=()))
113
130
 
@@ -133,6 +150,8 @@ class PicoContainer:
133
150
  return self._providers.copy()
134
151
 
135
152
 
153
+ # --- compatibility helpers ---
154
+
136
155
  def _is_protocol(t) -> bool:
137
156
  return getattr(t, "_is_protocol", False) is True
138
157
 
pico_ioc/decorators.py CHANGED
@@ -1,8 +1,12 @@
1
- # pico_ioc/decorators.py
1
+ # src/pico_ioc/decorators.py
2
2
  from __future__ import annotations
3
+
3
4
  import functools
4
5
  from typing import Any, Iterable, Optional, Callable, Tuple, Literal
5
6
 
7
+
8
+ # ---- marker attributes (read by scanner/policy) ----
9
+
6
10
  COMPONENT_FLAG = "_is_component"
7
11
  COMPONENT_KEY = "_component_key"
8
12
  COMPONENT_LAZY = "_component_lazy"
@@ -24,12 +28,16 @@ CONDITIONAL_META = "_pico_conditional"
24
28
  INTERCEPTOR_META = "__pico_interceptor__"
25
29
 
26
30
 
31
+ # ---- core decorators ----
32
+
27
33
  def factory_component(cls):
34
+ """Mark a class as a factory component (its methods can @provides)."""
28
35
  setattr(cls, FACTORY_FLAG, True)
29
36
  return cls
30
37
 
31
38
 
32
39
  def component(cls=None, *, name: Any = None, lazy: bool = False, tags: Iterable[str] = ()):
40
+ """Mark a class as a component. Optional: custom key, lazy instantiation, tags."""
33
41
  def dec(c):
34
42
  setattr(c, COMPONENT_FLAG, True)
35
43
  setattr(c, COMPONENT_KEY, name if name is not None else c)
@@ -40,6 +48,7 @@ def component(cls=None, *, name: Any = None, lazy: bool = False, tags: Iterable[
40
48
 
41
49
 
42
50
  def provides(key: Any, *, lazy: bool = False, tags: Iterable[str] = ()):
51
+ """Declare a factory method that provides a binding for `key`."""
43
52
  def dec(fn):
44
53
  @functools.wraps(fn)
45
54
  def w(*a, **k):
@@ -52,15 +61,20 @@ def provides(key: Any, *, lazy: bool = False, tags: Iterable[str] = ()):
52
61
 
53
62
 
54
63
  def plugin(cls):
64
+ """Mark a class as a Pico plugin (scanner lifecycle)."""
55
65
  setattr(cls, PLUGIN_FLAG, True)
56
66
  return cls
57
67
 
58
68
 
69
+ # ---- qualifiers ----
70
+
59
71
  class Qualifier(str):
72
+ """String qualifier type used with Annotated[T, 'q1', ...]."""
60
73
  __slots__ = ()
61
74
 
62
75
 
63
76
  def qualifier(*qs: Qualifier):
77
+ """Attach one or more qualifiers to a component class key."""
64
78
  def dec(cls):
65
79
  current: Iterable[Qualifier] = getattr(cls, QUALIFIERS_KEY, ())
66
80
  seen = set(current)
@@ -74,7 +88,10 @@ def qualifier(*qs: Qualifier):
74
88
  return dec
75
89
 
76
90
 
91
+ # ---- defaults / selection ----
92
+
77
93
  def on_missing(selector: object, *, priority: int = 0):
94
+ """Declare this target as a default for `selector` when no binding exists."""
78
95
  def dec(obj):
79
96
  setattr(obj, ON_MISSING_META, {"selector": selector, "priority": int(priority)})
80
97
  return obj
@@ -82,16 +99,18 @@ def on_missing(selector: object, *, priority: int = 0):
82
99
 
83
100
 
84
101
  def primary(obj):
102
+ """Hint this candidate should be preferred among equals."""
85
103
  setattr(obj, PRIMARY_FLAG, True)
86
104
  return obj
87
105
 
88
106
 
89
107
  def conditional(
90
108
  *,
91
- profiles: tuple[str, ...] = (),
92
- require_env: tuple[str, ...] = (),
109
+ profiles: Tuple[str, ...] = (),
110
+ require_env: Tuple[str, ...] = (),
93
111
  predicate: Optional[Callable[[], bool]] = None,
94
112
  ):
113
+ """Activate only when profiles/env/predicate conditions pass."""
95
114
  def dec(obj):
96
115
  setattr(obj, CONDITIONAL_META, {
97
116
  "profiles": tuple(profiles),
@@ -102,6 +121,8 @@ def conditional(
102
121
  return dec
103
122
 
104
123
 
124
+ # ---- interceptors ----
125
+
105
126
  def interceptor(
106
127
  _obj=None,
107
128
  *,
@@ -111,6 +132,7 @@ def interceptor(
111
132
  require_env: Tuple[str, ...] = (),
112
133
  predicate: Callable[[], bool] | None = None,
113
134
  ):
135
+ """Declare an interceptor (method or container) with optional activation metadata."""
114
136
  def dec(obj):
115
137
  setattr(obj, INTERCEPTOR_META, {
116
138
  "kind": kind,
@@ -122,13 +144,15 @@ def interceptor(
122
144
  return obj
123
145
  return dec if _obj is None else dec(_obj)
124
146
 
147
+
125
148
  __all__ = [
126
- "component", "factory_component", "provides", "plugin", "qualifier", "Qualifier",
149
+ "component", "factory_component", "provides", "plugin",
150
+ "Qualifier", "qualifier",
151
+ "on_missing", "primary", "conditional", "interceptor",
127
152
  "COMPONENT_FLAG", "COMPONENT_KEY", "COMPONENT_LAZY",
128
153
  "FACTORY_FLAG", "PROVIDES_KEY", "PROVIDES_LAZY",
129
154
  "PLUGIN_FLAG", "QUALIFIERS_KEY", "COMPONENT_TAGS", "PROVIDES_TAGS",
130
- "on_missing", "primary", "conditional",
131
155
  "ON_MISSING_META", "PRIMARY_FLAG", "CONDITIONAL_META",
132
- "interceptor", "INTERCEPTOR_META",
156
+ "INTERCEPTOR_META",
133
157
  ]
134
158
 
pico_ioc/interceptors.py CHANGED
@@ -1,7 +1,9 @@
1
- # pico_ioc/interceptors.py
1
+ # src/pico_ioc/interceptors.py
2
2
  from __future__ import annotations
3
- from typing import Any, Callable, Protocol, Sequence
3
+
4
4
  import inspect
5
+ from typing import Any, Callable, Protocol, Sequence
6
+
5
7
 
6
8
  class Invocation:
7
9
  __slots__ = ("target", "method_name", "call", "args", "kwargs", "is_async")
@@ -15,31 +17,35 @@ class Invocation:
15
17
  self.kwargs = kwargs
16
18
  self.is_async = inspect.iscoroutinefunction(call)
17
19
 
20
+
18
21
  class MethodInterceptor(Protocol):
19
22
  def __call__(self, inv: Invocation, proceed: Callable[[], Any]) -> Any: ...
20
23
 
24
+
21
25
  async def _chain_async(interceptors: Sequence[MethodInterceptor], inv: Invocation, i: int = 0):
22
26
  if i >= len(interceptors):
23
27
  return await inv.call(*inv.args, **inv.kwargs)
24
28
  cur = interceptors[i]
29
+
25
30
  async def next_step():
26
31
  return await _chain_async(interceptors, inv, i + 1)
32
+
27
33
  res = cur(inv, next_step)
28
34
  return await res if inspect.isawaitable(res) else res
29
35
 
36
+
30
37
  def _chain_sync(interceptors: Sequence[MethodInterceptor], inv: Invocation, i: int = 0):
31
38
  if i >= len(interceptors):
32
39
  return inv.call(*inv.args, **inv.kwargs)
33
40
  cur = interceptors[i]
34
41
  return cur(inv, lambda: _chain_sync(interceptors, inv, i + 1))
35
42
 
43
+
36
44
  def dispatch(interceptors: Sequence[MethodInterceptor], inv: Invocation):
45
+ """Dispatch invocation through a chain of interceptors."""
37
46
  if inv.is_async:
38
- # return a coroutine that the caller will await
39
- return _chain_async(interceptors, inv, 0)
40
- # return the final value directly for sync methods
41
- res = _chain_sync(interceptors, inv, 0)
42
- return res
47
+ return _chain_async(interceptors, inv, 0) # coroutine
48
+ return _chain_sync(interceptors, inv, 0) # value
43
49
 
44
50
 
45
51
  class ContainerInterceptor(Protocol):