overload-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. overload/__init__.py +3 -0
  2. overload/__main__.py +5 -0
  3. overload/cli.py +393 -0
  4. overload/collection/__init__.py +1 -0
  5. overload/collection/environment.py +23 -0
  6. overload/collection/models.py +88 -0
  7. overload/collection/parser.py +220 -0
  8. overload/collection/variables.py +84 -0
  9. overload/config_file.py +73 -0
  10. overload/engine/__init__.py +1 -0
  11. overload/engine/assertions.py +151 -0
  12. overload/engine/auth.py +87 -0
  13. overload/engine/events.py +50 -0
  14. overload/engine/http_client.py +274 -0
  15. overload/engine/load_patterns.py +730 -0
  16. overload/engine/models.py +254 -0
  17. overload/engine/rate_limiter.py +124 -0
  18. overload/engine/runner.py +86 -0
  19. overload/report/__init__.py +1 -0
  20. overload/report/exporters.py +77 -0
  21. overload/report/generator.py +71 -0
  22. overload/report/templates/report.html +369 -0
  23. overload/utils/__init__.py +1 -0
  24. overload/utils/naming.py +26 -0
  25. overload/web/__init__.py +1 -0
  26. overload/web/app.py +38 -0
  27. overload/web/routes/__init__.py +1 -0
  28. overload/web/routes/api.py +461 -0
  29. overload/web/routes/ws.py +77 -0
  30. overload/web/static/css/app.css +242 -0
  31. overload/web/static/js/app.js +241 -0
  32. overload/web/static/js/charts.js +385 -0
  33. overload/web/static/js/collection.js +344 -0
  34. overload/web/static/js/runner.js +625 -0
  35. overload/web/templates/index.html +23 -0
  36. overload_cli-0.1.0.dist-info/METADATA +267 -0
  37. overload_cli-0.1.0.dist-info/RECORD +40 -0
  38. overload_cli-0.1.0.dist-info/WHEEL +4 -0
  39. overload_cli-0.1.0.dist-info/entry_points.txt +2 -0
  40. overload_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,220 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from collections.abc import Generator
6
+ from typing import Any
7
+
8
+ from overload.collection.models import (
9
+ AuthConfig,
10
+ CollectionVariable,
11
+ ParsedCollection,
12
+ ParsedRequest,
13
+ QueryParam,
14
+ RequestBody,
15
+ )
16
+
17
+ SUPPORTED_SCHEMA_VERSIONS = ("v2.1.0", "v2.0.0")
18
+
19
+
20
def parse_collection(source: str | Path | dict) -> ParsedCollection:
    """Load a Postman collection and flatten it into a ``ParsedCollection``.

    Args:
        source: Either an already-decoded collection mapping, or a path
            (``str`` or ``Path``) to a UTF-8 encoded JSON file.

    Returns:
        The parsed collection: name, description, flattened requests,
        collection-level variables and collection-level auth.

    Raises:
        FileNotFoundError: If ``source`` is a path that does not exist.
        ValueError: If schema validation rejects the document.
    """
    if not isinstance(source, dict):
        collection_file = Path(source)
        if not collection_file.exists():
            raise FileNotFoundError(f"Collection file not found: {collection_file}")
        with open(collection_file, encoding="utf-8") as handle:
            document = json.load(handle)
    else:
        document = source

    _validate_schema(document)

    meta = document.get("info", {})
    root_auth = _parse_auth(document.get("auth"))

    declared_vars = []
    for entry in document.get("variable", []):
        declared_vars.append(
            CollectionVariable(
                key=entry.get("key", ""),
                value=entry.get("value", ""),
                type=entry.get("type", "string"),
            )
        )

    # Folders are walked recursively; the collection-level auth is the
    # starting point for inheritance.
    flattened = list(
        _flatten_items(document.get("item", []), path=[], parent_auth=root_auth)
    )

    return ParsedCollection(
        name=meta.get("name", "Unnamed Collection"),
        description=meta.get("description", ""),
        requests=flattened,
        variables=declared_vars,
        auth=root_auth,
    )
54
+
55
+
56
def _validate_schema(data: dict) -> None:
    """Reject documents that do not look like a Postman v2.x collection.

    A document is accepted when its ``info.schema`` string mentions one of
    ``SUPPORTED_SCHEMA_VERSIONS``, or, failing that, when ``info.name`` is
    present and truthy.

    Raises:
        ValueError: If neither condition holds.
    """
    schema = data.get("info", {}).get("schema", "")
    for version in SUPPORTED_SCHEMA_VERSIONS:
        if version in schema:
            return

    # Lenient fallback: an unrecognised schema is tolerated for named collections.
    if data.get("info", {}).get("name"):
        return

    raise ValueError(
        "Invalid collection format. Expected a Postman Collection v2.x JSON file."
    )
63
+
64
+
65
+ def _flatten_items(
66
+ items: list[dict],
67
+ path: list[str],
68
+ parent_auth: AuthConfig | None,
69
+ ) -> Generator[ParsedRequest, None, None]:
70
+ for item in items:
71
+ if "item" in item and "request" not in item:
72
+ folder_name = item.get("name", "Unnamed Folder")
73
+ folder_auth = _parse_auth(item.get("auth")) or parent_auth
74
+ yield from _flatten_items(
75
+ item["item"], path=path + [folder_name], parent_auth=folder_auth
76
+ )
77
+ elif "request" in item:
78
+ yield _parse_request(item, path, parent_auth)
79
+
80
+
81
def _parse_request(
    item: dict, folder_path: list[str], inherited_auth: AuthConfig | None
) -> ParsedRequest:
    """Convert one collection item into a ``ParsedRequest``.

    Args:
        item: Collection entry containing a ``request`` key. The request is
            either a bare URL string or a request object.
        folder_path: Names of the folders enclosing this request.
        inherited_auth: Auth from the nearest enclosing folder or the
            collection; used when the request defines none of its own.

    Returns:
        The parsed request. A bare URL string becomes a ``GET`` request.
    """
    req = item["request"]
    name = item.get("name", "Unnamed Request")

    if isinstance(req, str):
        # A string request cannot carry its own auth, so it must still pick
        # up whatever the enclosing folder/collection defines.
        return ParsedRequest(
            name=name,
            method="GET",
            url_raw=req,
            auth=inherited_auth,
            folder_path=folder_path,
        )

    # `or "GET"` also covers an explicit `"method": null`.
    method = (req.get("method") or "GET").upper()
    url_raw, query_params = _parse_url(req.get("url", ""))
    headers = _parse_headers(req.get("header", []))
    body = _parse_body(req.get("body"))
    auth = _parse_auth(req.get("auth")) or inherited_auth

    return ParsedRequest(
        name=name,
        method=method,
        url_raw=url_raw,
        headers=headers,
        body=body,
        auth=auth,
        query_params=query_params,
        folder_path=folder_path,
    )
110
+
111
+
112
+ def _parse_url(url: Any) -> tuple[str, list[QueryParam]]:
113
+ if isinstance(url, str):
114
+ return url, []
115
+
116
+ if isinstance(url, dict):
117
+ raw = url.get("raw", "")
118
+ query_params = [
119
+ QueryParam(
120
+ key=q.get("key", ""),
121
+ value=q.get("value", ""),
122
+ disabled=q.get("disabled", False),
123
+ )
124
+ for q in url.get("query", [])
125
+ ]
126
+
127
+ if not raw:
128
+ protocol = url.get("protocol", "https")
129
+ host = ".".join(url.get("host", []))
130
+ path = "/".join(url.get("path", []))
131
+ raw = f"{protocol}://{host}"
132
+ if path:
133
+ raw += f"/{path}"
134
+
135
+ return raw, query_params
136
+
137
+ return str(url), []
138
+
139
+
140
+ def _parse_headers(headers: list[dict] | None) -> dict[str, str]:
141
+ if not headers:
142
+ return {}
143
+ return {
144
+ h["key"]: h.get("value", "")
145
+ for h in headers
146
+ if not h.get("disabled", False) and "key" in h
147
+ }
148
+
149
+
150
def _parse_body(body: dict | None) -> RequestBody:
    """Translate a Postman request body object into a ``RequestBody``.

    Supported modes are ``raw``, ``formdata``, ``urlencoded``, ``graphql``
    and ``file``. A missing/empty body or an unrecognised mode yields a
    ``RequestBody`` with mode ``"none"``. Disabled form fields are skipped.
    """
    if not body:
        return RequestBody(mode="none")

    kind = body.get("mode", "none")

    if kind == "file":
        return RequestBody(mode="file", content=body.get("file", {}).get("src", ""))

    if kind == "graphql":
        payload = body.get("graphql", {})
        document = {
            "query": payload.get("query", ""),
            "variables": payload.get("variables", "{}"),
        }
        return RequestBody(
            mode="graphql",
            content=document,
            content_type="application/json",
        )

    if kind == "urlencoded":
        pairs = []
        for field in body.get("urlencoded", []):
            if field.get("disabled", False):
                continue
            pairs.append({"key": field["key"], "value": field.get("value", "")})
        return RequestBody(
            mode="urlencoded",
            content=pairs,
            content_type="application/x-www-form-urlencoded",
        )

    if kind == "formdata":
        parts = []
        for field in body.get("formdata", []):
            if field.get("disabled", False):
                continue
            parts.append(
                {
                    "key": field["key"],
                    "value": field.get("value", ""),
                    "type": field.get("type", "text"),
                }
            )
        return RequestBody(mode="formdata", content=parts)

    if kind == "raw":
        text = body.get("raw", "")
        # The language hint defaults to JSON when no options are given.
        language = body.get("options", {}).get("raw", {}).get("language", "json")
        mime_by_language = {
            "json": "application/json",
            "xml": "application/xml",
            "text": "text/plain",
            "html": "text/html",
            "javascript": "application/javascript",
        }
        return RequestBody(
            mode="raw",
            content=text,
            content_type=mime_by_language.get(language, "text/plain"),
        )

    return RequestBody(mode="none")
205
+
206
+
207
+ def _parse_auth(auth: dict | None) -> AuthConfig | None:
208
+ if not auth:
209
+ return None
210
+
211
+ auth_type = auth.get("type", "noauth")
212
+ if auth_type == "noauth":
213
+ return None
214
+
215
+ params: dict[str, str] = {}
216
+ for item in auth.get(auth_type, []):
217
+ if isinstance(item, dict) and "key" in item:
218
+ params[item["key"]] = item.get("value", "")
219
+
220
+ return AuthConfig(type=auth_type, params=params)
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import random
5
+ import re
6
+ import time
7
+ import uuid
8
+
9
+ from overload.collection.models import CollectionVariable
10
+
11
logger = logging.getLogger(__name__)

# Matches ``{{name}}`` placeholders; group 1 is the text between the braces.
VARIABLE_PATTERN = re.compile(r"\{\{([^}]+)\}\}")

# Built-in generators for ``{{$...}}`` placeholders. Each is called afresh on
# every substitution, so repeated placeholders get independent values.
DYNAMIC_VARIABLES: dict[str, callable] = {
    "$randomInt": lambda: str(random.randint(0, 1000)),
    "$timestamp": lambda: str(int(time.time())),
    "$guid": lambda: str(uuid.uuid4()),
    "$randomBoolean": lambda: random.choice(["true", "false"]),
    "$randomColor": lambda: random.choice(["red", "blue", "green", "yellow", "purple"]),
    "$randomFirstName": lambda: random.choice(["John", "Jane", "Alice", "Bob", "Charlie"]),
    "$randomEmail": lambda: f"user{random.randint(1, 9999)}@example.com",
}


class VariableContext:
    """Layered variable store with ``{{placeholder}}`` substitution.

    Lookups consult three scopes in priority order: runtime variables,
    environment variables, then collection variables.
    """

    def __init__(
        self,
        collection_vars: list[CollectionVariable] | None = None,
        environment_vars: dict[str, str] | None = None,
        runtime_vars: dict[str, str] | None = None,
    ) -> None:
        # Highest-priority scope first; index 0 is the writable runtime scope.
        self._scopes: list[dict[str, str]] = [
            runtime_vars or {},
            environment_vars or {},
            {v.key: v.value for v in (collection_vars or [])},
        ]
        self._unresolved: set[str] = set()

    @property
    def unresolved(self) -> set[str]:
        """Return a copy of the placeholder names that could not be resolved."""
        return self._unresolved.copy()

    def set_variable(self, key: str, value: str) -> None:
        """Set ``key`` in the runtime (highest-priority) scope."""
        self._scopes[0][key] = value

    def get_variable(self, key: str) -> str | None:
        """Return the highest-priority value for ``key``, or ``None`` if unset."""
        for scope in self._scopes:
            if key in scope:
                return scope[key]
        return None

    def get_all_variables(self) -> dict[str, str]:
        """Return all variables merged into one dict, higher scopes winning."""
        merged: dict[str, str] = {}
        for scope in reversed(self._scopes):
            merged.update(scope)
        return merged

    def resolve(self, template: str) -> str:
        """Substitute every ``{{name}}`` placeholder in ``template``.

        Dynamic variables take precedence over scoped ones. Values that
        themselves contain placeholders are resolved recursively. Unknown
        names, and names that refer back to themselves (directly or through
        other variables), are left in place, logged and recorded in
        ``unresolved``. Non-string values are stringified; ``None`` becomes
        an empty string.
        """
        if not template or "{{" not in template:
            return template
        return self._resolve(template, frozenset())

    def _resolve(self, template: str, active: frozenset) -> str:
        """Resolve ``template``; ``active`` holds the names being expanded."""

        def _replacer(match: re.Match) -> str:
            var_name = match.group(1).strip()

            if var_name in DYNAMIC_VARIABLES:
                return DYNAMIC_VARIABLES[var_name]()

            for scope in self._scopes:
                if var_name in scope:
                    value = scope[var_name]
                    # re.sub requires a str replacement; values loaded from
                    # JSON may be numbers, booleans or null.
                    text = "" if value is None else str(value)
                    if "{{" not in text:
                        return text
                    if var_name in active:
                        # Expanding again would recurse forever.
                        self._unresolved.add(var_name)
                        logger.warning(
                            "Circular variable reference: {{%s}}", var_name
                        )
                        return match.group(0)
                    return self._resolve(text, active | {var_name})

            self._unresolved.add(var_name)
            logger.warning("Unresolved variable: {{%s}}", var_name)
            return match.group(0)

        return VARIABLE_PATTERN.sub(_replacer, template)

    def resolve_dict(self, d: dict[str, str]) -> dict[str, str]:
        """Return a copy of ``d`` with both keys and values resolved."""
        return {self.resolve(k): self.resolve(v) for k, v in d.items()}

    def resolve_url(self, url: str) -> str:
        """Resolve placeholders in ``url``; identical to :meth:`resolve`."""
        return self.resolve(url)
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import yaml
8
+
9
+ from overload.engine.models import Threshold
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ CONFIG_FILENAME = "overload.config.yaml"
14
+
15
+
16
def load_config(path: str | Path) -> dict[str, Any]:
    """Read a YAML config file and return its top-level mapping.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the document's top level is not a mapping.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path, encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)

    if isinstance(loaded, dict):
        return loaded

    raise ValueError(
        f"Config file must contain a YAML mapping, got {type(loaded).__name__}"
    )
28
+
29
+
30
def save_config(
    path: str | Path,
    test_type: str,
    config: dict[str, Any],
    thresholds: list[Threshold] | None = None,
) -> str:
    """Write a test configuration to ``path`` as YAML.

    The document contains ``test_type`` and ``config``, plus a ``thresholds``
    list (``metric`` / ``operator`` / ``value``) when any thresholds are given.

    Returns:
        The path written, as a string.
    """
    payload: dict[str, Any] = {"test_type": test_type, "config": config}

    if thresholds:
        serialised = []
        for item in thresholds:
            serialised.append(
                {"metric": item.metric, "operator": item.operator, "value": item.value}
            )
        payload["thresholds"] = serialised

    target = Path(path)
    with target.open("w", encoding="utf-8") as stream:
        # sort_keys=False keeps the insertion order of the keys above.
        yaml.safe_dump(payload, stream, default_flow_style=False, sort_keys=False)

    logger.info("Config saved to %s", target)
    return str(target)
53
+
54
+
55
def extract_thresholds(raw: dict[str, Any]) -> list[Threshold]:
    """Build ``Threshold`` objects from the ``thresholds`` list of a config.

    Args:
        raw: Parsed config mapping.

    Returns:
        One ``Threshold`` per mapping entry; non-mapping entries are skipped.
        A missing or empty ``thresholds`` key yields an empty list.

    Raises:
        KeyError: If an entry lacks ``metric``, ``operator`` or ``value``.
        ValueError: If an entry's ``value`` is not numeric.
    """
    # A bare `thresholds:` key in YAML loads as None, not as an empty list.
    entries = raw.get("thresholds") or []
    return [
        Threshold(
            metric=entry["metric"],
            operator=entry["operator"],
            value=float(entry["value"]),
        )
        for entry in entries
        if isinstance(entry, dict)
    ]
66
+
67
+
68
def extract_config(raw: dict[str, Any]) -> dict[str, Any]:
    """Return the ``config`` section of a parsed config mapping.

    A missing section, or one left empty in the YAML (which loads as
    ``None``), yields an empty dict so callers always receive a mapping.
    """
    return raw.get("config") or {}
70
+
71
+
72
def extract_test_type(raw: dict[str, Any]) -> str | None:
    """Return the ``test_type`` entry of a parsed config, or ``None`` if absent."""
    if "test_type" in raw:
        return raw["test_type"]
    return None
@@ -0,0 +1 @@
1
+ from __future__ import annotations
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import operator as op
5
+ import xml.etree.ElementTree as ET
6
+ from typing import Callable
7
+
8
+ from overload.engine.models import AssertionResult, Threshold, Verdict
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
# Comparison operators accepted in thresholds, keyed by their textual symbol.
_OPERATORS: dict[str, Callable[[float, float], bool]] = {
    "<": op.lt,
    "<=": op.le,
    ">": op.gt,
    ">=": op.ge,
    "==": op.eq,
}

# Maps a threshold metric name to a function that pulls its value out of the
# computed stats dict. The `.get`-based extractors return None when the
# underlying field is absent.
_METRIC_EXTRACTORS: dict[str, Callable[[dict], float | None]] = {
    "p50_latency_ms": lambda s: s.get("latency", {}).get("median"),
    "p95_latency_ms": lambda s: s.get("latency", {}).get("p95"),
    "p99_latency_ms": lambda s: s.get("latency", {}).get("p99"),
    "max_latency_ms": lambda s: s.get("latency", {}).get("max"),
    "mean_latency_ms": lambda s: s.get("latency", {}).get("mean"),
    # Percentages are 0.0 when "total" is missing or not positive.
    # NOTE(review): "errors" is indexed directly and raises KeyError when it
    # is absent while "total" > 0 -- confirm the stats dict always has it.
    "error_rate_pct": lambda s: (
        (s["errors"] / s["total"]) * 100 if s.get("total", 0) > 0 else 0.0
    ),
    # NOTE(review): same direct indexing for "ok" -- confirm as above.
    "success_rate_pct": lambda s: (
        (s["ok"] / s["total"]) * 100 if s.get("total", 0) > 0 else 0.0
    ),
    "avg_rps": lambda s: s.get("avg_rps"),
    "total_requests": lambda s: float(s["total"]) if "total" in s else None,
    # Defaults to 0 when "rate_limited" is absent.
    "rate_limited_count": lambda s: float(s.get("rate_limited", 0)),
}
36
+
37
+
38
def evaluate(computed: dict, thresholds: list[Threshold]) -> Verdict:
    """Check every threshold against the computed stats.

    Args:
        computed: Stats dict handed to the metric extractors.
        thresholds: Assertions to evaluate, in order.

    Returns:
        A ``Verdict`` with one ``AssertionResult`` per threshold. The overall
        verdict passes only if every assertion passed. An unknown metric, an
        unavailable metric value or an invalid operator counts as a failed
        assertion; the first two report an actual value of NaN.
    """
    outcomes: list[AssertionResult] = []

    for rule in thresholds:
        # Pessimistic defaults, overwritten as each lookup succeeds.
        observed = float("nan")
        ok = False

        extract = _METRIC_EXTRACTORS.get(rule.metric)
        if extract is None:
            logger.warning("Unknown metric: %s — skipping assertion", rule.metric)
        else:
            measured = extract(computed)
            if measured is None:
                logger.warning("Metric %s not available in stats", rule.metric)
            else:
                observed = measured
                compare = _OPERATORS.get(rule.operator)
                if compare is None:
                    logger.error("Invalid operator: %s", rule.operator)
                else:
                    ok = compare(measured, rule.value)

        outcomes.append(
            AssertionResult(
                metric=rule.metric,
                operator=rule.operator,
                expected=rule.value,
                actual=observed,
                passed=ok,
            )
        )

    return Verdict(
        passed=all(outcome.passed for outcome in outcomes),
        results=outcomes,
    )
91
+
92
+
93
def parse_threshold(expr: str) -> Threshold:
    """Parse an expression such as ``"p95_latency_ms < 500"`` into a ``Threshold``.

    Two-character operators are tried before ``<`` and ``>`` so that ``<=``
    and ``>=`` are not mistaken for their one-character prefixes.

    Raises:
        ValueError: If no operator is present, or the right-hand side is not
            a number.
    """
    for symbol in ("<=", ">=", "==", "<", ">"):
        left, found, right = expr.partition(symbol)
        if not found:
            continue
        return Threshold(
            metric=left.strip(),
            operator=symbol,
            value=float(right.strip()),
        )
    raise ValueError(f"Invalid threshold expression: {expr!r}")
103
+
104
+
105
+ def _format_value(metric: str, value: float) -> str:
106
+ if "latency" in metric:
107
+ return f"{value:.1f}ms"
108
+ if "pct" in metric:
109
+ return f"{value:.1f}%"
110
+ if "rps" in metric:
111
+ return f"{value:.1f}/s"
112
+ if value == int(value):
113
+ return str(int(value))
114
+ return f"{value:.1f}"
115
+
116
+
117
def print_verdict(verdict: Verdict) -> None:
    """Print each assertion and the overall verdict to stdout, ANSI-coloured."""
    print("\n Assertions:")

    for outcome in verdict.results:
        if outcome.passed:
            mark = "\033[32m✓\033[0m"
        else:
            mark = "\033[31m✗\033[0m"
        shown_actual = _format_value(outcome.metric, outcome.actual)
        shown_expected = _format_value(outcome.metric, outcome.expected)
        print(f" {mark} {outcome.metric} {shown_actual} {outcome.operator} {shown_expected}")

    if verdict.passed:
        label = "\033[32mPASS\033[0m"
    else:
        label = "\033[31mFAIL\033[0m"
    print(f"\n Verdict: {label}")
127
+
128
+
129
def write_junit_xml(verdict: Verdict, path: str, test_name: str = "overload") -> None:
    """Write the verdict to ``path`` as a JUnit-style XML test suite.

    Each assertion becomes a ``testcase``; failed assertions carry a
    ``failure`` child whose message states the expected and actual values.
    """
    outcomes = verdict.results
    failed_count = sum(1 for outcome in outcomes if not outcome.passed)

    suite = ET.Element(
        "testsuite",
        {
            "name": test_name,
            "tests": str(len(outcomes)),
            "failures": str(failed_count),
        },
    )

    for outcome in outcomes:
        expected_text = _format_value(outcome.metric, outcome.expected)
        case = ET.SubElement(
            suite,
            "testcase",
            {
                "name": f"{outcome.metric} {outcome.operator} {expected_text}",
                "classname": "overload.assertions",
            },
        )
        if outcome.passed:
            continue

        actual_text = _format_value(outcome.metric, outcome.actual)
        ET.SubElement(
            case,
            "failure",
            {
                "message": (
                    f"{outcome.metric}: expected {outcome.operator} {expected_text}, "
                    f"got {actual_text}"
                ),
            },
        )

    document = ET.ElementTree(suite)
    ET.indent(document)
    document.write(path, encoding="unicode", xml_declaration=True)
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import time
5
+ from dataclasses import dataclass
6
+
7
+ import httpx
8
+
9
+ from overload.collection.models import AuthConfig
10
+ from overload.collection.variables import VariableContext
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
@dataclass
class _CachedToken:
    """An OAuth2 access token together with its absolute expiry time."""

    # The bearer token string returned by the token endpoint.
    token: str
    # Expiry as seconds since the epoch (compared against time.time()).
    expires_at: float


# Process-wide token cache, shared by every caller in this module.
# Keys are built by fetch_oauth2_token from the client id and token URL.
_TOKEN_CACHE: dict[str, _CachedToken] = {}
22
+
23
+
24
async def fetch_oauth2_token(
    auth: AuthConfig,
    ctx: VariableContext,
    client: httpx.AsyncClient,
) -> str:
    """Obtain an OAuth2 access token using the client-credentials grant.

    Tokens are cached per client id, token URL and scope, and reused until
    30 seconds before they expire.

    Args:
        auth: Auth configuration whose ``params`` hold the token URL, client
            id, client secret and optional scope.
        ctx: Variable context used to resolve placeholders in those params.
        client: HTTP client used for the token request.

    Returns:
        The access token string.

    Raises:
        ValueError: If no token URL is configured.
        RuntimeError: If the token request fails or the response carries no
            usable ``access_token``.
    """
    params = auth.params

    # Postman uses "accessTokenUrl"; some environments expose it as "access_token_url"
    token_url = ctx.resolve(
        params.get("accessTokenUrl")
        or params.get("access_token_url")
        or params.get("tokenUrl")
        or ""
    )
    client_id = ctx.resolve(params.get("clientId") or params.get("client_id") or "")
    client_secret = ctx.resolve(params.get("clientSecret") or params.get("client_secret") or "")
    scope = ctx.resolve(params.get("scope") or "")

    if not token_url:
        raise ValueError("OAuth2 auth: accessTokenUrl is required but missing")

    # Scope is part of the key: a token issued for one scope must not be
    # served to a request that asked for a different one.
    cache_key = f"{client_id}:{token_url}:{scope}"
    cached = _TOKEN_CACHE.get(cache_key)
    if cached and cached.expires_at > time.time() + 30:
        logger.debug("OAuth2: using cached token for client_id=%s", client_id)
        return cached.token

    logger.info("OAuth2: fetching token from %s (client_id=%s)", token_url, client_id)

    post_data: dict[str, str] = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }
    if scope:
        post_data["scope"] = scope

    try:
        response = await client.post(token_url, data=post_data)
        response.raise_for_status()
        body = response.json()
    except httpx.HTTPStatusError as exc:
        raise RuntimeError(
            f"OAuth2 token request failed: {exc.response.status_code} — {exc.response.text}"
        ) from exc
    except Exception as exc:
        raise RuntimeError(f"OAuth2 token request failed: {exc}") from exc

    if not isinstance(body, dict):
        raise RuntimeError(f"OAuth2 response is not a JSON object: {body!r}")

    access_token: str | None = body.get("access_token")
    if not access_token:
        raise RuntimeError(f"OAuth2 response missing 'access_token' field: {body}")

    # "expires_in" is optional and may be null or malformed; fall back to
    # one hour rather than failing after a token was successfully issued.
    raw_expiry = body.get("expires_in")
    try:
        expires_in = int(raw_expiry) if raw_expiry is not None else 3600
    except (TypeError, ValueError):
        logger.warning("OAuth2: invalid expires_in %r, assuming 3600s", raw_expiry)
        expires_in = 3600

    _TOKEN_CACHE[cache_key] = _CachedToken(
        token=access_token,
        expires_at=time.time() + expires_in,
    )

    logger.info("OAuth2: token acquired, expires in %ds", expires_in)
    return access_token
84
+
85
+
86
def clear_token_cache() -> None:
    """Remove every entry from the module-level OAuth2 token cache."""
    _TOKEN_CACHE.clear()
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from collections import defaultdict
6
+ from collections.abc import Callable
7
+ from typing import Any
8
+
9
logger = logging.getLogger(__name__)


class EventBus:
    """In-process publish/subscribe hub keyed by event name.

    Subscribers are either callables (sync or async) invoked on each emit,
    or ``asyncio.Queue`` objects that receive ``{"event", "data"}`` dicts.
    """

    def __init__(self) -> None:
        self._handlers: dict[str, list[Callable]] = defaultdict(list)
        self._queues: dict[str, list[asyncio.Queue]] = defaultdict(list)

    def subscribe(self, event: str, handler: Callable) -> None:
        """Register ``handler`` to be called with the data of each ``event``."""
        self._handlers[event].append(handler)

    def unsubscribe(self, event: str, handler: Callable) -> None:
        """Remove ``handler`` from ``event``; a no-op if it is not registered."""
        handlers = self._handlers.get(event, [])
        if handler in handlers:
            handlers.remove(handler)

    def create_queue(self, event: str) -> asyncio.Queue:
        """Create, register and return an unbounded queue fed by ``event``."""
        queue: asyncio.Queue = asyncio.Queue()
        self._queues[event].append(queue)
        return queue

    def remove_queue(self, event: str, queue: asyncio.Queue) -> None:
        """Stop feeding ``queue`` from ``event``; a no-op if it is not registered."""
        queues = self._queues.get(event, [])
        if queue in queues:
            queues.remove(queue)

    async def emit(self, event: str, data: Any = None) -> None:
        """Deliver ``data`` to every handler and queue subscribed to ``event``.

        Handlers run in subscription order; coroutine results are awaited.
        An exception raised by one handler is logged and does not stop the
        remaining deliveries.
        """
        logger.debug("Event emitted: %s", event)

        # Iterate over snapshots: a handler may subscribe/unsubscribe during
        # dispatch, and mutating the live list mid-iteration would make the
        # loop skip the next subscriber.
        for handler in list(self._handlers.get(event, ())):
            try:
                result = handler(data)
                if asyncio.iscoroutine(result):
                    await result
            except Exception:
                logger.exception("Error in event handler for %s", event)

        for queue in list(self._queues.get(event, ())):
            try:
                queue.put_nowait({"event": event, "data": data})
            except asyncio.QueueFull:
                logger.warning("Event queue full for %s, dropping event", event)