mrok 0.4.6__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. mrok/agent/devtools/inspector/app.py +2 -2
  2. mrok/agent/sidecar/app.py +61 -35
  3. mrok/agent/sidecar/main.py +35 -9
  4. mrok/agent/ziticorn.py +9 -3
  5. mrok/cli/commands/__init__.py +2 -2
  6. mrok/cli/commands/admin/bootstrap.py +3 -2
  7. mrok/cli/commands/admin/utils.py +2 -2
  8. mrok/cli/commands/agent/run/sidecar.py +59 -1
  9. mrok/cli/commands/{proxy → frontend}/__init__.py +1 -1
  10. mrok/cli/commands/frontend/run.py +91 -0
  11. mrok/constants.py +0 -2
  12. mrok/controller/openapi/examples.py +13 -0
  13. mrok/controller/schemas.py +2 -2
  14. mrok/frontend/__init__.py +3 -0
  15. mrok/frontend/app.py +75 -0
  16. mrok/{proxy → frontend}/main.py +12 -10
  17. mrok/proxy/__init__.py +0 -3
  18. mrok/proxy/app.py +158 -83
  19. mrok/proxy/asgi.py +96 -0
  20. mrok/proxy/backend.py +45 -0
  21. mrok/proxy/event_publisher.py +66 -0
  22. mrok/proxy/exceptions.py +22 -0
  23. mrok/{master.py → proxy/master.py} +36 -81
  24. mrok/{metrics.py → proxy/metrics.py} +38 -50
  25. mrok/{http/middlewares.py → proxy/middleware.py} +17 -26
  26. mrok/{datastructures.py → proxy/models.py} +43 -10
  27. mrok/proxy/stream.py +68 -0
  28. mrok/{http → proxy}/utils.py +1 -1
  29. mrok/proxy/worker.py +64 -0
  30. mrok/{http/config.py → proxy/ziticorn.py} +29 -6
  31. mrok/types/proxy.py +20 -0
  32. mrok/types/ziti.py +1 -0
  33. mrok/ziti/api.py +15 -18
  34. mrok/ziti/bootstrap.py +3 -2
  35. mrok/ziti/identities.py +5 -4
  36. mrok/ziti/services.py +3 -2
  37. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/METADATA +2 -5
  38. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/RECORD +43 -39
  39. mrok/cli/commands/proxy/run.py +0 -49
  40. mrok/http/forwarder.py +0 -354
  41. mrok/http/lifespan.py +0 -39
  42. mrok/http/pool.py +0 -239
  43. mrok/http/protocol.py +0 -11
  44. mrok/http/server.py +0 -14
  45. mrok/http/types.py +0 -18
  46. /mrok/{http → proxy}/constants.py +0 -0
  47. /mrok/{http → types}/__init__.py +0 -0
  48. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/WHEEL +0 -0
  49. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/entry_points.txt +0 -0
  50. {mrok-0.4.6.dist-info → mrok-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -8,7 +8,7 @@ import time
8
8
  import psutil
9
9
  from hdrh.histogram import HdrHistogram
10
10
 
11
- from mrok.datastructures import (
11
+ from mrok.proxy.models import (
12
12
  DataTransferMetrics,
13
13
  ProcessMetrics,
14
14
  RequestsMetrics,
@@ -20,11 +20,7 @@ logger = logging.getLogger("mrok.proxy")
20
20
 
21
21
 
22
22
  def _collect_process_usage(interval: float) -> ProcessMetrics:
23
- try:
24
- proc = psutil.Process(os.getpid())
25
- except psutil.NoSuchProcess:
26
- return ProcessMetrics(cpu=0.0, mem=0.0)
27
-
23
+ proc = psutil.Process(os.getpid())
28
24
  total_cpu = 0.0
29
25
  total_mem = 0.0
30
26
 
@@ -33,29 +29,28 @@ def _collect_process_usage(interval: float) -> ProcessMetrics:
33
29
  except Exception:
34
30
  total_cpu = 0.0
35
31
 
36
- if interval and interval > 0:
32
+ if interval and interval > 0: # pragma: no branch
37
33
  time.sleep(interval)
38
34
 
39
35
  try:
40
36
  total_cpu = proc.cpu_percent(None)
41
- except Exception:
37
+ except Exception: # pragma: no cover
42
38
  total_cpu = 0.0
43
39
 
44
40
  try:
45
41
  total_mem = proc.memory_percent()
46
- except Exception:
42
+ except Exception: # pragma: no cover
47
43
  total_mem = 0.0
48
44
 
49
45
  return ProcessMetrics(cpu=total_cpu, mem=total_mem)
50
46
 
51
47
 
52
- async def get_process_and_children_usage(interval: float = 0.1) -> ProcessMetrics:
48
+ async def get_process_metrics(interval: float = 0.1) -> ProcessMetrics:
53
49
  return await asyncio.to_thread(_collect_process_usage, interval)
54
50
 
55
51
 
56
- class WorkerMetricsCollector:
52
+ class MetricsCollector:
57
53
  def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3):
58
- # Request-level counters
59
54
  self.worker_id = worker_id
60
55
  self.total_requests = 0
61
56
  self.successful_requests = 0
@@ -63,14 +58,11 @@ class WorkerMetricsCollector:
63
58
  self.bytes_in = 0
64
59
  self.bytes_out = 0
65
60
 
66
- # RPS
67
61
  self._tick_last = time.time()
68
62
  self._tick_requests = 0
69
63
 
70
- # latency histogram
71
64
  self.hist = HdrHistogram(lowest, highest, sigfigs)
72
65
 
73
- # async lock
74
66
  self._lock = asyncio.Lock()
75
67
 
76
68
  async def on_request_start(self, scope):
@@ -102,38 +94,34 @@ class WorkerMetricsCollector:
102
94
  self.hist.record_value(elapsed_ms)
103
95
 
104
96
  async def snapshot(self) -> WorkerMetrics:
105
- try:
106
- async with self._lock:
107
- now = time.time()
108
- delta = now - self._tick_last
109
- rps = int(self._tick_requests / delta) if delta > 0 else 0
110
- data = WorkerMetrics(
111
- worker_id=self.worker_id,
112
- process=await get_process_and_children_usage(),
113
- requests=RequestsMetrics(
114
- rps=rps,
115
- total=self.total_requests,
116
- successful=self.successful_requests,
117
- failed=self.failed_requests,
118
- ),
119
- data_transfer=DataTransferMetrics(
120
- bytes_in=self.bytes_in,
121
- bytes_out=self.bytes_out,
122
- ),
123
- response_time=ResponseTimeMetrics(
124
- avg=self.hist.get_mean_value(),
125
- min=self.hist.get_min_value(),
126
- max=self.hist.get_max_value(),
127
- p50=self.hist.get_value_at_percentile(50),
128
- p90=self.hist.get_value_at_percentile(90),
129
- p99=self.hist.get_value_at_percentile(99),
130
- ),
131
- )
132
-
133
- self._tick_last = now
134
- self._tick_requests = 0
135
-
136
- return data
137
- except Exception:
138
- logger.exception("Exception calculating snapshot")
139
- raise
97
+ async with self._lock:
98
+ now = time.time()
99
+ delta = now - self._tick_last
100
+ rps = int(self._tick_requests / delta) if delta > 0 else 0
101
+ data = WorkerMetrics(
102
+ worker_id=self.worker_id,
103
+ process=await get_process_metrics(),
104
+ requests=RequestsMetrics(
105
+ rps=rps,
106
+ total=self.total_requests,
107
+ successful=self.successful_requests,
108
+ failed=self.failed_requests,
109
+ ),
110
+ data_transfer=DataTransferMetrics(
111
+ bytes_in=self.bytes_in,
112
+ bytes_out=self.bytes_out,
113
+ ),
114
+ response_time=ResponseTimeMetrics(
115
+ avg=self.hist.get_mean_value(),
116
+ min=self.hist.get_min_value(),
117
+ max=self.hist.get_max_value(),
118
+ p50=self.hist.get_value_at_percentile(50),
119
+ p90=self.hist.get_value_at_percentile(90),
120
+ p99=self.hist.get_value_at_percentile(99),
121
+ ),
122
+ )
123
+
124
+ self._tick_last = now
125
+ self._tick_requests = 0
126
+
127
+ return data
@@ -1,20 +1,22 @@
1
1
  import asyncio
2
- import inspect
3
2
  import logging
4
3
  import time
5
- from collections.abc import Callable, Coroutine
6
- from typing import Any
7
4
 
8
- from mrok.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
9
- from mrok.http.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
10
- from mrok.http.types import ASGIApp, ASGIReceive, ASGISend, Message, Scope
11
- from mrok.http.utils import must_capture_request, must_capture_response
12
- from mrok.metrics import WorkerMetricsCollector
5
+ from mrok.proxy.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
6
+ from mrok.proxy.metrics import MetricsCollector
7
+ from mrok.proxy.models import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse
8
+ from mrok.proxy.utils import must_capture_request, must_capture_response
9
+ from mrok.types.proxy import (
10
+ ASGIApp,
11
+ ASGIReceive,
12
+ ASGISend,
13
+ Message,
14
+ ResponseCompleteCallback,
15
+ Scope,
16
+ )
13
17
 
14
18
  logger = logging.getLogger("mrok.proxy")
15
19
 
16
- ResponseCompleteCallback = Callable[[HTTPResponse], Coroutine[Any, Any, None] | None]
17
-
18
20
 
19
21
  class CaptureMiddleware:
20
22
  def __init__(
@@ -92,22 +94,11 @@ class CaptureMiddleware:
92
94
  body=resp_buf.getvalue() if state["capture_resp_body"] else None,
93
95
  body_truncated=resp_buf.overflow if state["capture_resp_body"] else None,
94
96
  )
95
- await asyncio.create_task(self.handle_callback(response))
96
-
97
- async def handle_callback(self, response: HTTPResponse):
98
- try:
99
- if inspect.iscoroutinefunction(self._on_response_complete):
100
- await self._on_response_complete(response)
101
- else:
102
- await asyncio.get_running_loop().run_in_executor(
103
- None, self._on_response_complete, response
104
- )
105
- except Exception:
106
- logger.exception("Error invoking callback")
97
+ asyncio.create_task(self._on_response_complete(response))
107
98
 
108
99
 
109
100
  class MetricsMiddleware:
110
- def __init__(self, app, metrics: WorkerMetricsCollector):
101
+ def __init__(self, app: ASGIApp, metrics: MetricsCollector):
111
102
  self.app = app
112
103
  self.metrics = metrics
113
104
 
@@ -116,11 +107,11 @@ class MetricsMiddleware:
116
107
  return await self.app(scope, receive, send)
117
108
 
118
109
  start_time = await self.metrics.on_request_start(scope)
119
- status_code = 500 # default if app errors early
110
+ status_code = 500
120
111
 
121
112
  async def wrapped_receive():
122
113
  msg = await receive()
123
- if msg["type"] == "http.request" and msg.get("body"):
114
+ if msg["type"] == "http.request" and msg.get("body"): # pragma: no branch
124
115
  await self.metrics.on_request_body(len(msg["body"]))
125
116
  return msg
126
117
 
@@ -131,7 +122,7 @@ class MetricsMiddleware:
131
122
  status_code = msg["status"]
132
123
  await self.metrics.on_response_start(status_code)
133
124
 
134
- elif msg["type"] == "http.response.body":
125
+ elif msg["type"] == "http.response.body": # pragma: no branch
135
126
  body = msg.get("body", b"")
136
127
  await self.metrics.on_response_chunk(len(body))
137
128
 
@@ -1,11 +1,52 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
4
+ from pathlib import Path
3
5
  from typing import Literal
4
6
 
5
- from pydantic import BaseModel, Field
7
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
6
8
  from pydantic_core import core_schema
7
9
 
8
10
 
11
class X509Credentials(BaseModel):
    """X.509 material (``key``, ``cert``, ``ca``) taken from an identity file.

    Each value may carry a leading ``pem:`` marker, which is removed before
    field validation runs.
    """

    key: str
    cert: str
    ca: str

    @field_validator("key", "cert", "ca", mode="before")
    @classmethod
    def strip_pem_prefix(cls, value: str) -> str:
        """Drop a leading ``pem:`` marker; non-string inputs pass through untouched."""
        if not isinstance(value, str):
            return value
        return value.removeprefix("pem:")
22
+
23
+
24
class ServiceMetadata(BaseModel):
    """mrok service metadata carried by an identity (``Identity.mrok``).

    Also used as the ``meta`` payload of ``Status`` messages. Keys not declared
    below are ignored during validation.
    """

    model_config = ConfigDict(extra="ignore")
    identity: str
    extension: str
    instance: str
    # Optional; the previous ``Meta`` model required it.
    domain: str | None = None
    # Same value shape as ``mrok.types.ziti.Tags``.
    tags: dict[str, str | bool | None] | None = None
31
+
32
+
33
class Identity(BaseModel):
    """Model of an identity JSON file.

    Holds the ``ztAPI*`` settings, the X.509 credentials stored under ``id``,
    and optional mrok service metadata. Keys not declared below are ignored.
    """

    model_config = ConfigDict(extra="ignore")
    # Camel-case keys in the JSON file are mapped through validation aliases.
    zt_api: str = Field(validation_alias="ztAPI")
    id: X509Credentials
    zt_apis: str | None = Field(default=None, validation_alias="ztAPIs")
    config_types: str | None = Field(default=None, validation_alias="configTypes")
    enable_ha: bool = Field(default=False, validation_alias="enableHa")
    # mrok-specific metadata; None when the file has no "mrok" section.
    mrok: ServiceMetadata | None = None

    @classmethod
    def load_from_file(cls, path: str | Path) -> Identity:
        """Read the UTF-8 JSON file at *path* and validate it into an instance.

        A classmethod (rather than a staticmethod hard-coding ``Identity``) so
        that subclasses get instances of themselves; ``Identity.load_from_file``
        calls are unaffected.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file does not contain valid JSON.
            pydantic.ValidationError: if the data does not match the model.
        """
        with Path(path).open("r", encoding="utf-8") as fh:
            data = json.load(fh)
        return cls.model_validate(data)
48
+
49
+
9
50
  class FixedSizeByteBuffer:
10
51
  def __init__(self, max_size: int):
11
52
  self._max_size = max_size
@@ -140,17 +181,9 @@ class WorkerMetrics(BaseModel):
140
181
  process: ProcessMetrics
141
182
 
142
183
 
143
- class Meta(BaseModel):
144
- identity: str
145
- extension: str
146
- instance: str
147
- domain: str
148
- tags: dict[str, str] | None = None
149
-
150
-
151
184
  class Status(BaseModel):
152
185
  type: Literal["status"] = "status"
153
- meta: Meta
186
+ meta: ServiceMetadata
154
187
  metrics: WorkerMetrics
155
188
 
156
189
 
mrok/proxy/stream.py ADDED
@@ -0,0 +1,68 @@
1
+ import asyncio
2
+ import select
3
+ import sys
4
+ from typing import Any
5
+
6
+ from httpcore import AsyncNetworkStream
7
+
8
+ from mrok.types.proxy import ASGIReceive
9
+
10
+
11
def is_readable(sock):  # pragma: no cover
    """Return whether *sock* is readable right now, using a zero-timeout check.

    Adapted from trio:
    https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471C1-L478C31
    """
    if sys.platform != "win32":
        # select.poll does not exist on Windows; use it everywhere else.
        poller = select.poll()
        poller.register(sock, select.POLLIN)
        return bool(poller.poll(0))
    readable, _writable, _errored = select.select([sock], [], [], 0)
    return bool(readable)
22
+
23
+
24
class AIONetworkStream(AsyncNetworkStream):
    """Adapt an asyncio StreamReader/StreamWriter pair to httpcore's AsyncNetworkStream."""

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        self._reader = reader
        self._writer = writer

    async def read(self, n: int, timeout: float | None = None) -> bytes:
        """Read up to *n* bytes, bounded by *timeout* seconds (None waits indefinitely)."""
        pending = self._reader.read(n)
        return await asyncio.wait_for(pending, timeout)

    async def write(self, data: bytes, timeout: float | None = None) -> None:
        """Buffer *data* on the writer, then wait (bounded by *timeout*) for it to drain."""
        writer = self._writer
        writer.write(data)
        await asyncio.wait_for(writer.drain(), timeout)

    async def aclose(self) -> None:
        """Close the writer and wait for the underlying transport to finish closing."""
        writer = self._writer
        writer.close()
        await writer.wait_closed()

    def get_extra_info(self, info: str) -> Any:
        """Proxy to the transport's extra info; ``"is_readable"`` is computed from its socket."""
        transport = self._writer.transport
        if info != "is_readable":
            return transport.get_extra_info(info)
        return is_readable(transport.get_extra_info("socket"))
46
+
47
+
48
class ASGIRequestBodyStream:
    """Async iterator over the body of an ASGI HTTP request.

    Pulls ``http.request`` messages from the ASGI ``receive`` callable and
    yields each ``body`` chunk (``b""`` when absent). Iteration stops after
    the message whose ``more_body`` is false or missing.
    """

    def __init__(self, receive: "ASGIReceive"):
        self._receive = receive
        self._more_body = True

    def __aiter__(self) -> "ASGIRequestBodyStream":
        return self

    async def __anext__(self) -> bytes:
        """Return the next body chunk.

        Raises:
            StopAsyncIteration: once the final chunk has been returned.
            ConnectionError: if an ``http.disconnect`` message arrives before
                the body is complete.
            RuntimeError: for any other ASGI message type.
        """
        if not self._more_body:
            raise StopAsyncIteration

        msg = await self._receive()
        msg_type = msg["type"]
        if msg_type == "http.request":
            self._more_body = msg.get("more_body", False)
            return msg.get("body", b"")
        if msg_type == "http.disconnect":
            # Specific Exception subclasses let callers distinguish a client
            # disconnect from a protocol error; `except Exception` still catches both.
            raise ConnectionError("Client disconnected.")
        raise RuntimeError("Unexpected asgi message.")
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Mapping
2
2
 
3
- from mrok.http.constants import (
3
+ from mrok.proxy.constants import (
4
4
  BINARY_CONTENT_TYPES,
5
5
  BINARY_PREFIXES,
6
6
  MAX_REQUEST_BODY_BYTES,
mrok/proxy/worker.py ADDED
@@ -0,0 +1,64 @@
1
+ import asyncio
2
+ import contextlib
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ from uvicorn.importer import import_from_string
7
+
8
+ from mrok.conf import get_settings
9
+ from mrok.logging import setup_logging
10
+ from mrok.proxy.asgi import ASGIAppWrapper
11
+ from mrok.proxy.event_publisher import EventPublisher
12
+ from mrok.proxy.models import Identity
13
+ from mrok.proxy.ziticorn import BackendConfig, Server
14
+ from mrok.types.proxy import ASGIApp
15
+
16
# Uses the "mrok.proxy" logger name (not __name__), matching the other proxy modules.
# NOTE(review): not referenced anywhere else in this module.
logger = logging.getLogger("mrok.proxy")
17
+
18
+
19
class Worker:
    """Run an ASGI app behind mrok's ziticorn ``Server``, optionally publishing events."""

    def __init__(
        self,
        worker_id: str,
        app: ASGIApp | str,
        identity_file: str | Path,
        *,
        events_enabled: bool = True,
        event_publisher_port: int = 50000,
        metrics_interval: float = 5.0,
    ):
        """Load the identity file and, when events are enabled, build the publisher.

        Args:
            worker_id: identifier for this worker, forwarded to EventPublisher.
            app: an ASGI callable, or an import string resolved in setup_app().
            identity_file: identity JSON path; parsed here and passed again to
                BackendConfig in run().
            events_enabled: when False no EventPublisher is created, and the
                app is wrapped without a lifespan or event middleware.
            event_publisher_port: port forwarded to EventPublisher.
            metrics_interval: forwarded to EventPublisher (presumably seconds).
        """
        self._worker_id = worker_id
        self._identity_file = identity_file
        self._identity = Identity.load_from_file(identity_file)
        self._app = app
        self._events_enabled = events_enabled

        publisher = None
        if events_enabled:
            publisher = EventPublisher(
                worker_id=worker_id,
                meta=self._identity.mrok,
                event_publisher_port=event_publisher_port,
                metrics_interval=metrics_interval,
            )
        self._event_publisher = publisher

    def setup_app(self):
        """Wrap the configured app in an ASGIAppWrapper and return the wrapper.

        Import strings are resolved with ``import_from_string``. With events
        enabled, the publisher's lifespan is attached and its middleware is
        installed on the wrapper.
        """
        target = import_from_string(self._app) if isinstance(self._app, str) else self._app
        lifespan = self._event_publisher.lifespan if self._events_enabled else None
        wrapped = ASGIAppWrapper(target, lifespan=lifespan)
        if self._events_enabled:
            self._event_publisher.setup_middleware(wrapped)
        return wrapped

    def run(self):
        """Set up logging, build the app and run the server.

        KeyboardInterrupt and asyncio.CancelledError escaping ``server.run()``
        are suppressed so the worker exits quietly.
        """
        setup_logging(get_settings())
        server = Server(BackendConfig(self.setup_app(), self._identity_file))
        with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError):
            server.run()
@@ -6,17 +6,40 @@ from pathlib import Path
6
6
  from typing import Any
7
7
 
8
8
  import openziti
9
- from uvicorn import config
9
+ from uvicorn import config, server
10
+ from uvicorn.lifespan.on import LifespanOn
11
+ from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol as UvHttpToolsProtocol
10
12
 
11
- from mrok.http.protocol import MrokHttpToolsProtocol
12
- from mrok.http.types import ASGIApp
13
+ from mrok.types.proxy import ASGIApp
13
14
 
14
15
  logger = logging.getLogger("mrok.proxy")
15
16
 
16
- config.LIFESPAN["auto"] = "mrok.http.lifespan:MrokLifespan"
17
+ config.LIFESPAN["auto"] = "mrok.proxy.ziticorn:Lifespan"
17
18
 
18
19
 
19
- class MrokBackendConfig(config.Config):
20
+ class Lifespan(LifespanOn):
21
+ def __init__(self, lf_config: config.Config) -> None:
22
+ super().__init__(lf_config)
23
+ self.logger = logging.getLogger("mrok.proxy")
24
+
25
+
26
+ class HttpToolsProtocol(UvHttpToolsProtocol):
27
+ def __init__(self, *args, **kwargs):
28
+ super().__init__(*args, **kwargs)
29
+ self.logger = logging.getLogger("mrok.proxy")
30
+ self.access_logger = logging.getLogger("mrok.access")
31
+ self.access_log = self.access_logger.hasHandlers()
32
+
33
+
34
+ class Server(server.Server):
35
+ async def serve(self, sockets: list[socket.socket] | None = None) -> None:
36
+ if not sockets:
37
+ sockets = [self.config.bind_socket()]
38
+ with self.capture_signals():
39
+ await self._serve(sockets)
40
+
41
+
42
+ class BackendConfig(config.Config):
20
43
  def __init__(
21
44
  self,
22
45
  app: ASGIApp | Callable[..., Any] | str,
@@ -32,7 +55,7 @@ class MrokBackendConfig(config.Config):
32
55
  super().__init__(
33
56
  app,
34
57
  loop="asyncio",
35
- http=MrokHttpToolsProtocol,
58
+ http=HttpToolsProtocol,
36
59
  backlog=backlog,
37
60
  )
38
61
 
mrok/types/proxy.py ADDED
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Awaitable, Callable, Coroutine, Mapping, MutableMapping
4
+ from contextlib import AbstractAsyncContextManager
5
+ from typing import Any, Never
6
+
7
+ from mrok.proxy.models import HTTPResponse
8
+
9
# Raw ASGI connection scope and event message dictionaries.
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]

# ASGI callables: the receive/send channels and the application itself.
ASGIReceive = Callable[[], Awaitable[Message]]
ASGISend = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, ASGIReceive, ASGISend], Awaitable[None]]

# Lifespan factories: given the app, return an async context manager that
# yields either nothing (stateless) or a mapping of state (stateful).
StatelessLifespan = Callable[[ASGIApp], AbstractAsyncContextManager[None]]
StatefulLifespan = Callable[[ASGIApp], AbstractAsyncContextManager[Mapping[str, Any]]]
Lifespan = StatelessLifespan | StatefulLifespan

LifespanCallback = Callable[[], Awaitable[None]]

# Coroutine function called with each captured HTTPResponse. Its coroutine
# returns None: a `Never` return type would claim it can never complete
# normally, which makes an ordinary `async def cb(resp) -> None` incompatible.
ResponseCompleteCallback = Callable[[HTTPResponse], Coroutine[Any, Any, None]]
mrok/types/ziti.py ADDED
@@ -0,0 +1 @@
1
# Tag mapping sent as the "tags" payload of Ziti management API objects;
# values may be strings, booleans, or None.
Tags = dict[str, str | bool | None]
mrok/ziti/api.py CHANGED
@@ -11,12 +11,11 @@ from typing import Any, Literal
11
11
  import httpx
12
12
 
13
13
  from mrok.conf import Settings
14
+ from mrok.types.ziti import Tags
14
15
  from mrok.ziti.constants import MROK_VERSION_TAG, MROK_VERSION_TAG_NAME
15
16
 
16
17
  logger = logging.getLogger(__name__)
17
18
 
18
- TagsType = dict[str, str | bool | None]
19
-
20
19
 
21
20
  class ZitiAPIError(Exception):
22
21
  pass
@@ -70,7 +69,7 @@ class BaseZitiAPI(ABC):
70
69
  ),
71
70
  )
72
71
 
73
- async def create(self, endpoint: str, payload: dict[str, Any], tags: TagsType | None) -> str:
72
+ async def create(self, endpoint: str, payload: dict[str, Any], tags: Tags | None) -> str:
74
73
  payload["tags"] = self._merge_tags(tags)
75
74
  response: httpx.Response = await self.httpx_client.post(
76
75
  endpoint,
@@ -156,8 +155,8 @@ class BaseZitiAPI(ABC):
156
155
  ) -> None:
157
156
  return await self.httpx_client.__aexit__(exc_type, exc_val, exc_tb)
158
157
 
159
- def _merge_tags(self, tags: TagsType | None) -> TagsType:
160
- prepared_tags: TagsType = tags or {}
158
+ def _merge_tags(self, tags: Tags | None) -> Tags:
159
+ prepared_tags: Tags = tags or {}
161
160
  prepared_tags.update(MROK_VERSION_TAG)
162
161
  return prepared_tags
163
162
 
@@ -281,9 +280,7 @@ class ZitiManagementAPI(BaseZitiAPI):
281
280
  async def search_config(self, id_or_name) -> dict[str, Any] | None:
282
281
  return await self.search_by_id_or_name("/configs", id_or_name)
283
282
 
284
- async def create_config(
285
- self, name: str, config_type_id: str, tags: TagsType | None = None
286
- ) -> str:
283
+ async def create_config(self, name: str, config_type_id: str, tags: Tags | None = None) -> str:
287
284
  return await self.create(
288
285
  "/configs",
289
286
  {
@@ -302,7 +299,7 @@ class ZitiManagementAPI(BaseZitiAPI):
302
299
  async def delete_config(self, config_id: str) -> None:
303
300
  return await self.delete("/configs", config_id)
304
301
 
305
- async def create_config_type(self, name: str, tags: TagsType | None = None) -> str:
302
+ async def create_config_type(self, name: str, tags: Tags | None = None) -> str:
306
303
  return await self.create(
307
304
  "/config-types",
308
305
  {
@@ -316,7 +313,7 @@ class ZitiManagementAPI(BaseZitiAPI):
316
313
  self,
317
314
  name: str,
318
315
  config_id: str,
319
- tags: TagsType | None = None,
316
+ tags: Tags | None = None,
320
317
  ) -> str:
321
318
  return await self.create(
322
319
  "/services",
@@ -332,7 +329,7 @@ class ZitiManagementAPI(BaseZitiAPI):
332
329
  self,
333
330
  name: str,
334
331
  service_id: str,
335
- tags: TagsType | None = None,
332
+ tags: Tags | None = None,
336
333
  ) -> str:
337
334
  return await self.create(
338
335
  "/service-edge-router-policies",
@@ -351,7 +348,7 @@ class ZitiManagementAPI(BaseZitiAPI):
351
348
  self,
352
349
  name: str,
353
350
  identity_id: str,
354
- tags: TagsType | None = None,
351
+ tags: Tags | None = None,
355
352
  ) -> str:
356
353
  return await self.create(
357
354
  "/edge-router-policies",
@@ -385,10 +382,10 @@ class ZitiManagementAPI(BaseZitiAPI):
385
382
  async def delete_service(self, service_id: str) -> None:
386
383
  return await self.delete("/services", service_id)
387
384
 
388
- async def create_user_identity(self, name: str, tags: TagsType | None = None) -> str:
385
+ async def create_user_identity(self, name: str, tags: Tags | None = None) -> str:
389
386
  return await self._create_identity(name, "User", tags=tags)
390
387
 
391
- async def create_device_identity(self, name: str, tags: TagsType | None = None) -> str:
388
+ async def create_device_identity(self, name: str, tags: Tags | None = None) -> str:
392
389
  return await self._create_identity(name, "Device", tags=tags)
393
390
 
394
391
  async def search_identity(self, id_or_name: str) -> dict[str, Any] | None:
@@ -412,12 +409,12 @@ class ZitiManagementAPI(BaseZitiAPI):
412
409
  return response.text
413
410
 
414
411
  async def create_dial_service_policy(
415
- self, name: str, service_id: str, identity_id: str, tags: TagsType | None = None
412
+ self, name: str, service_id: str, identity_id: str, tags: Tags | None = None
416
413
  ) -> str:
417
414
  return await self._create_service_policy("Dial", name, service_id, identity_id, tags)
418
415
 
419
416
  async def create_bind_service_policy(
420
- self, name: str, service_id: str, identity_id: str, tags: TagsType | None = None
417
+ self, name: str, service_id: str, identity_id: str, tags: Tags | None = None
421
418
  ) -> str:
422
419
  return await self._create_service_policy("Bind", name, service_id, identity_id, tags)
423
420
 
@@ -433,7 +430,7 @@ class ZitiManagementAPI(BaseZitiAPI):
433
430
  name: str,
434
431
  service_id: str,
435
432
  identity_id: str,
436
- tags: TagsType | None = None,
433
+ tags: Tags | None = None,
437
434
  ) -> str:
438
435
  return await self.create(
439
436
  "/service-policies",
@@ -451,7 +448,7 @@ class ZitiManagementAPI(BaseZitiAPI):
451
448
  self,
452
449
  name: str,
453
450
  type: Literal["User", "Device", "Default"],
454
- tags: TagsType | None = None,
451
+ tags: Tags | None = None,
455
452
  ) -> str:
456
453
  return await self.create(
457
454
  "/identities",
mrok/ziti/bootstrap.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import logging
2
2
  from typing import Any
3
3
 
4
- from mrok.ziti.api import TagsType, ZitiClientAPI, ZitiManagementAPI
4
+ from mrok.types.ziti import Tags
5
+ from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI
5
6
  from mrok.ziti.identities import enroll_proxy_identity
6
7
 
7
8
  logger = logging.getLogger(__name__)
@@ -13,7 +14,7 @@ async def bootstrap_identity(
13
14
  identity_name: str,
14
15
  mode: str,
15
16
  forced: bool,
16
- tags: TagsType | None,
17
+ tags: Tags | None,
17
18
  ) -> tuple[str, dict[str, Any] | None]:
18
19
  logger.info(f"Bootstrapping '{identity_name}' identity...")
19
20