rappel 0.10.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

rappel/bridge.py ADDED
@@ -0,0 +1,414 @@
+ import asyncio
+ import os
+ import shlex
+ import shutil
+ import subprocess
+ import tempfile
+ import time
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass
+ from pathlib import Path
+ from threading import Lock, RLock
+ from typing import AsyncIterator, NoReturn, Optional
+
+ import grpc
+ from grpc import aio  # type: ignore[attr-defined]
+
+ from proto import messages_pb2 as pb2
+ from proto import messages_pb2_grpc as pb2_grpc
+ from rappel.logger import configure as configure_logger
+
+ from .actions import serialize_error_payload, serialize_result_payload
+ from .workflow_runtime import execute_action
+
+ DEFAULT_HOST = "127.0.0.1"
+ LOGGER = configure_logger("rappel.bridge")
+
+ _PORT_LOCK = RLock()
+ _CACHED_GRPC_PORT: Optional[int] = None
+ _GRPC_TARGET: Optional[str] = None
+ _GRPC_CHANNEL: Optional[aio.Channel] = None
+ _GRPC_STUB: Optional[pb2_grpc.WorkflowServiceStub] = None
+ _GRPC_LOOP: Optional[asyncio.AbstractEventLoop] = None
+ _BOOT_MUTEX = Lock()
+ _ASYNC_BOOT_LOCK: asyncio.Lock = asyncio.Lock()
+
+
+ @dataclass
+ class RunInstanceResult:
+     workflow_version_id: str
+     workflow_instance_id: str
+
+
+ @dataclass
+ class RunBatchResult:
+     workflow_version_id: str
+     workflow_instance_ids: list[str]
+     queued: int
+
+
+ def _boot_command() -> list[str]:
+     override = os.environ.get("RAPPEL_BOOT_COMMAND")
+     if override:
+         LOGGER.debug("Using RAPPEL_BOOT_COMMAND=%s", override)
+         return shlex.split(override)
+     binary = os.environ.get("RAPPEL_BOOT_BINARY", "boot-rappel-singleton")
+     LOGGER.debug("Using RAPPEL_BOOT_BINARY=%s", binary)
+     return [binary]
+
+
+ def _repo_root() -> Path:
+     return Path(__file__).resolve().parents[3]
+
+
+ def _resolve_boot_binary(binary: str) -> str:
+     if Path(binary).is_absolute():
+         return binary
+     resolved = shutil.which(binary)
+     if resolved:
+         return resolved
+     repo_root = _repo_root()
+     for profile in ("debug", "release"):
+         candidate = repo_root / "target" / profile / binary
+         if candidate.exists():
+             return str(candidate)
+     return binary
+
+
+ def _ensure_boot_binary(binary: str) -> str:
+     resolved = _resolve_boot_binary(binary)
+     if Path(resolved).exists():
+         return resolved
+     repo_root = _repo_root()
+     cargo_toml = repo_root / "Cargo.toml"
+     if cargo_toml.exists():
+         LOGGER.info("boot binary %s not found; building via cargo", binary)
+         subprocess.run(
+             [
+                 "cargo",
+                 "build",
+                 "--bin",
+                 "boot-rappel-singleton",
+                 "--bin",
+                 "rappel-bridge",
+             ],
+             cwd=repo_root,
+             check=True,
+         )
+         resolved = _resolve_boot_binary(binary)
+     return resolved
+
+
+ def _remember_grpc_port(port: int) -> int:
+     global _CACHED_GRPC_PORT
+     with _PORT_LOCK:
+         _CACHED_GRPC_PORT = port
+     return port
+
+
+ def _cached_grpc_port() -> Optional[int]:
+     with _PORT_LOCK:
+         return _CACHED_GRPC_PORT
+
+
+ def _env_grpc_port_override() -> Optional[int]:
+     """Check for explicit gRPC port override via environment."""
+     override = os.environ.get("RAPPEL_BRIDGE_GRPC_PORT")
+     if not override:
+         return None
+     try:
+         return int(override)
+     except ValueError as exc:  # pragma: no cover
+         raise RuntimeError(f"invalid RAPPEL_BRIDGE_GRPC_PORT value: {override}") from exc
+
+
+ def _boot_singleton_blocking() -> int:
+     """Boot the singleton and return the gRPC port."""
+     command = _boot_command()
+     if os.environ.get("RAPPEL_BOOT_COMMAND") is None:
+         command[0] = _ensure_boot_binary(command[0])
+     with tempfile.NamedTemporaryFile(mode="w+", suffix=".txt") as f:
+         output_file = Path(f.name)
+
+     command.extend(["--output-file", str(output_file)])
+     LOGGER.info("Booting rappel singleton via: %s", " ".join(command))
+
+     try:
+         subprocess.run(
+             command,
+             check=True,
+             timeout=10,
+         )
+     except subprocess.TimeoutExpired as exc:  # pragma: no cover
+         LOGGER.error("boot command timed out after %s seconds", exc.timeout)
+         raise RuntimeError("unable to boot rappel server") from exc
+     except subprocess.CalledProcessError as exc:  # pragma: no cover
+         LOGGER.error("boot command failed: %s", exc)
+         raise RuntimeError("unable to boot rappel server") from exc
+     except OSError as exc:  # pragma: no cover
+         LOGGER.error("unable to spawn boot command: %s", exc)
+         raise RuntimeError("unable to boot rappel server") from exc
+
+     try:
+         # We use a file as a message passer because passing a PIPE to the singleton launcher
+         # would block our code indefinitely.
+         # The singleton launches the webserver subprocess to inherit the stdin/stdout that the
+         # singleton launcher receives, which means that in the case of a PIPE it would pass that
+         # pipe to the subprocess and therefore never correctly close the file descriptor and signal
+         # exit process status to Python.
+         port_str = output_file.read_text().strip()
+         grpc_port = int(port_str)
+         LOGGER.info("boot command reported singleton gRPC port %s", grpc_port)
+         return grpc_port
+     except (ValueError, FileNotFoundError) as exc:  # pragma: no cover
+         raise RuntimeError(f"unable to read port from output file: {exc}") from exc
+
+
+ def _resolve_grpc_port() -> int:
+     """Resolve the gRPC port, booting singleton if necessary."""
+     cached = _cached_grpc_port()
+     if cached is not None:
+         return cached
+     env_port = _env_grpc_port_override()
+     if env_port is not None:
+         return _remember_grpc_port(env_port)
+     with _BOOT_MUTEX:
+         cached = _cached_grpc_port()
+         if cached is not None:
+             return cached
+         port = _boot_singleton_blocking()
+         return _remember_grpc_port(port)
+
+
+ async def _ensure_grpc_port_async() -> int:
+     """Ensure we have a gRPC port, booting singleton if necessary."""
+     cached = _cached_grpc_port()
+     if cached is not None:
+         return cached
+     env_port = _env_grpc_port_override()
+     if env_port is not None:
+         return _remember_grpc_port(env_port)
+     async with _ASYNC_BOOT_LOCK:
+         cached = _cached_grpc_port()
+         if cached is not None:
+             return cached
+         loop = asyncio.get_running_loop()
+         LOGGER.info("No cached singleton found, booting new instance")
+         port = await loop.run_in_executor(None, _boot_singleton_blocking)
+         LOGGER.info("Singleton ready on gRPC port %s", port)
+         return _remember_grpc_port(port)
+
+
+ @asynccontextmanager
+ async def ensure_singleton() -> AsyncIterator[int]:
+     """Yield the gRPC port for the singleton server, booting it exactly once."""
+     port = await _ensure_grpc_port_async()
+     yield port
+
+
+ def _grpc_target() -> str:
+     """Get the gRPC target address for the bridge server."""
+     # Check for explicit full address override
+     explicit = os.environ.get("RAPPEL_BRIDGE_GRPC_ADDR")
+     if explicit:
+         return explicit
+
+     # Otherwise, use host + port
+     host = os.environ.get("RAPPEL_BRIDGE_GRPC_HOST", DEFAULT_HOST)
+     port = _resolve_grpc_port()
+     return f"{host}:{port}"
+
+
+ def assert_never(value: object) -> NoReturn:
+     raise AssertionError(f"Unhandled value: {value!r}")
+
+
+ async def _workflow_stub() -> pb2_grpc.WorkflowServiceStub:
+     global _GRPC_TARGET, _GRPC_CHANNEL, _GRPC_STUB, _GRPC_LOOP
+     target = _grpc_target()
+     loop = asyncio.get_running_loop()
+     channel_to_wait: Optional[aio.Channel] = None
+     with _PORT_LOCK:
+         if (
+             _GRPC_STUB is not None
+             and _GRPC_TARGET == target
+             and _GRPC_LOOP is loop
+             and not loop.is_closed()
+         ):
+             return _GRPC_STUB
+         channel = aio.insecure_channel(target)
+         stub = pb2_grpc.WorkflowServiceStub(channel)
+         _GRPC_CHANNEL = channel
+         _GRPC_STUB = stub
+         _GRPC_TARGET = target
+         _GRPC_LOOP = loop
+         channel_to_wait = channel
+     if channel_to_wait is not None:
+         await channel_to_wait.channel_ready()
+     return _GRPC_STUB  # type: ignore[return-value]
+
+
+ async def run_instance(payload: bytes) -> RunInstanceResult:
+     """Register a workflow definition and start an instance over the gRPC bridge."""
+     async with ensure_singleton():
+         stub = await _workflow_stub()
+         registration = pb2.WorkflowRegistration()
+         registration.ParseFromString(payload)
+         request = pb2.RegisterWorkflowRequest(
+             registration=registration,
+         )
+         try:
+             response = await stub.RegisterWorkflow(request, timeout=30.0)
+         except aio.AioRpcError as exc:  # pragma: no cover
+             raise RuntimeError(f"register_workflow failed: {exc}") from exc
+         return RunInstanceResult(
+             workflow_version_id=response.workflow_version_id,
+             workflow_instance_id=response.workflow_instance_id,
+         )
+
+
+ async def run_instances_batch(
+     payload: bytes,
+     *,
+     count: int = 1,
+     inputs: Optional[pb2.WorkflowArguments] = None,
+     inputs_list: Optional[list[pb2.WorkflowArguments]] = None,
+     batch_size: int = 500,
+     include_instance_ids: bool = False,
+ ) -> RunBatchResult:
+     """Register a workflow definition and start multiple instances over the gRPC bridge."""
+     if count < 1 and not inputs_list:
+         raise ValueError("count must be >= 1 when inputs_list is empty")
+     if batch_size < 1:
+         raise ValueError("batch_size must be >= 1")
+
+     async with ensure_singleton():
+         stub = await _workflow_stub()
+         registration = pb2.WorkflowRegistration()
+         registration.ParseFromString(payload)
+         request = pb2.RegisterWorkflowBatchRequest(
+             registration=registration,
+             count=count,
+             batch_size=batch_size,
+             include_instance_ids=include_instance_ids,
+         )
+         if inputs is not None:
+             request.inputs.CopyFrom(inputs)
+         if inputs_list:
+             request.inputs_list.extend(inputs_list)
+         try:
+             response = await stub.RegisterWorkflowBatch(request, timeout=30.0)
+         except aio.AioRpcError as exc:  # pragma: no cover
+             raise RuntimeError(f"register_workflow_batch failed: {exc}") from exc
+         return RunBatchResult(
+             workflow_version_id=response.workflow_version_id,
+             workflow_instance_ids=list(response.workflow_instance_ids),
+             queued=response.queued,
+         )
+
+
+ async def execute_workflow(payload: bytes) -> bytes:
+     """Execute a workflow via the in-memory workflow streaming API."""
+     os.environ.setdefault("RAPPEL_BRIDGE_IN_MEMORY", "1")
+     async with ensure_singleton():
+         stub = await _workflow_stub()
+
+         registration = pb2.WorkflowRegistration()
+         registration.ParseFromString(payload)
+         LOGGER.debug(
+             "pytest stream start: workflow=%s ir_hash=%s",
+             registration.workflow_name,
+             registration.ir_hash,
+         )
+
+         queue: asyncio.Queue[Optional[pb2.WorkflowStreamRequest]] = asyncio.Queue()
+         skip_sleep = bool(os.environ.get("PYTEST_CURRENT_TEST"))
+         await queue.put(pb2.WorkflowStreamRequest(registration=registration, skip_sleep=skip_sleep))
+
+         async def request_stream() -> AsyncIterator[pb2.WorkflowStreamRequest]:
+             while True:
+                 item = await queue.get()
+                 if item is None:
+                     return
+                 yield item
+
+         call = stub.ExecuteWorkflow(request_stream(), timeout=300.0)
+         result_payload: Optional[bytes] = None
+
+         async for response in call:
+             kind = response.WhichOneof("kind")
+             match kind:
+                 case "action_dispatch":
+                     dispatch = response.action_dispatch
+                     LOGGER.debug(
+                         "pytest stream dispatch: action_id=%s module=%s action=%s",
+                         dispatch.action_id,
+                         dispatch.module_name,
+                         dispatch.action_name,
+                     )
+                     start_ns = time.monotonic_ns()
+                     execution = await execute_action(dispatch)
+                     end_ns = time.monotonic_ns()
+                     action_result = pb2.ActionResult(
+                         action_id=dispatch.action_id,
+                         success=execution.exception is None,
+                         payload=(
+                             serialize_result_payload(execution.result)
+                             if execution.exception is None
+                             else serialize_error_payload(dispatch.action_name, execution.exception)
+                         ),
+                         worker_start_ns=start_ns,
+                         worker_end_ns=end_ns,
+                         error_type=(
+                             type(execution.exception).__name__
+                             if execution.exception is not None
+                             else ""
+                         ),
+                         error_message=str(execution.exception)
+                         if execution.exception is not None
+                         else "",
+                     )
+                     LOGGER.debug(
+                         "pytest stream result: action_id=%s success=%s",
+                         dispatch.action_id,
+                         execution.exception is None,
+                     )
+                     await queue.put(pb2.WorkflowStreamRequest(action_result=action_result))
+                 case "workflow_result":
+                     result_payload = response.workflow_result.payload
+                     LOGGER.debug(
+                         "pytest stream complete: workflow=%s payload_bytes=%s",
+                         registration.workflow_name,
+                         len(result_payload),
+                     )
+                     await queue.put(None)
+                     break
+                 case None:
+                     continue
+                 case _:
+                     assert_never(kind)
+
+     if result_payload is None:
+         raise RuntimeError("workflow stream ended without a result")
+     return result_payload
+
+
+ async def wait_for_instance(
+     instance_id: str,
+     poll_interval_secs: float = 1.0,
+ ) -> Optional[bytes]:
+     """Block until the workflow daemon produces the requested instance payload."""
+     async with ensure_singleton():
+         stub = await _workflow_stub()
+         request = pb2.WaitForInstanceRequest(
+             instance_id=instance_id,
+             poll_interval_secs=poll_interval_secs,
+         )
+         try:
+             response = await stub.WaitForInstance(request, timeout=None)
+         except aio.AioRpcError as exc:  # pragma: no cover
+             status_fn = exc.code
+             if callable(status_fn) and status_fn() == grpc.StatusCode.NOT_FOUND:
+                 return None
+             raise RuntimeError(f"wait_for_instance failed: {exc}") from exc
+         return bytes(response.payload)
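
The bridge module above wraps the singleton gRPC server in a small async API: ensure_singleton() boots the server at most once, run_instance() registers a workflow definition and starts an instance, and wait_for_instance() blocks until the daemon reports that instance's payload. A minimal usage sketch, not part of the package diff; it assumes a running rappel server and a WorkflowRegistration populated by the rest of the SDK:

    import asyncio

    from proto import messages_pb2 as pb2
    from rappel import bridge

    async def main() -> None:
        # Assumed: a WorkflowRegistration message; in practice its fields are filled in
        # by the rappel SDK rather than constructed by hand.
        registration = pb2.WorkflowRegistration()
        started = await bridge.run_instance(registration.SerializeToString())
        print("version:", started.workflow_version_id)
        # Returns the instance payload bytes, or None if the instance id is unknown.
        result = await bridge.wait_for_instance(started.workflow_instance_id)
        print("result bytes:", None if result is None else len(result))

    asyncio.run(main())
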
rappel/dependencies.py ADDED
@@ -0,0 +1,149 @@
+ from __future__ import annotations
+
+ import inspect
+ from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
+ from dataclasses import dataclass
+ from typing import Annotated, Any, AsyncIterator, Callable, Optional, get_args, get_origin
+
+
+ @dataclass(frozen=True)
+ class DependMarker:
+     """Internal marker for dependency injection."""
+
+     dependency: Optional[Callable[..., Any]] = None
+     use_cache: bool = True
+
+
+ def Depend(  # noqa: N802
+     dependency: Optional[Callable[..., Any]] = None,
+     *,
+     use_cache: bool = True,
+ ) -> Any:
+     """Marker for dependency injection, mirroring FastAPI's Depends syntax.
+
+     Returns Any to allow usage as a default parameter value:
+         def my_func(service: MyService = Depend(get_service)):
+             ...
+     """
+     return DependMarker(dependency=dependency, use_cache=use_cache)
+
+
+ def _depend_from_annotation(annotation: Any) -> DependMarker | None:
+     origin = get_origin(annotation)
+     if origin is not Annotated:
+         return None
+     metadata = get_args(annotation)[1:]
+     for meta in metadata:
+         if isinstance(meta, DependMarker):
+             return meta
+     return None
+
+
+ def _dependency_marker(parameter: inspect.Parameter) -> DependMarker | None:
+     if isinstance(parameter.default, DependMarker):
+         return parameter.default
+     return _depend_from_annotation(parameter.annotation)
+
+
+ class _DependencyResolver:
+     """Resolve dependency graphs for a callable, including context manager lifetimes."""
+
+     def __init__(self, initial_kwargs: Optional[dict[str, Any]] = None) -> None:
+         self._context: dict[str, Any] = dict(initial_kwargs or {})
+         self._cache: dict[Callable[..., Any], Any] = {}
+         self._active: set[Callable[..., Any]] = set()
+         self._stack = AsyncExitStack()
+
+     async def close(self) -> None:
+         await self._stack.aclose()
+
+     async def build_call_kwargs(self, func: Callable[..., Any]) -> dict[str, Any]:
+         call_kwargs: dict[str, Any] = {}
+         signature = inspect.signature(func)
+         func_name = func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
+         for name, parameter in signature.parameters.items():
+             if parameter.kind in (
+                 inspect.Parameter.VAR_POSITIONAL,
+                 inspect.Parameter.VAR_KEYWORD,
+             ):
+                 continue
+             if name in self._context:
+                 call_kwargs[name] = self._context[name]
+                 continue
+             marker = _dependency_marker(parameter)
+             if marker is not None:
+                 value = await self._resolve_dependency(marker)
+                 self._context[name] = value
+                 call_kwargs[name] = value
+                 continue
+             if parameter.default is not inspect.Parameter.empty:
+                 call_kwargs[name] = parameter.default
+                 self._context.setdefault(name, parameter.default)
+                 continue
+             raise TypeError(f"Missing required parameter '{name}' for {func_name}")
+         return call_kwargs
+
+     async def _resolve_dependency(self, marker: DependMarker) -> Any:
+         dependency = marker.dependency
+         if dependency is None:
+             raise TypeError("Depend requires a dependency callable")
+         if marker.use_cache and dependency in self._cache:
+             return self._cache[dependency]
+         if dependency in self._active:
+             name = (
+                 dependency.__name__
+                 if hasattr(dependency, "__name__")
+                 else dependency.__class__.__name__
+             )
+             raise RuntimeError(f"Circular dependency detected for {name}")
+         self._active.add(dependency)
+         try:
+             kwargs = await self.build_call_kwargs(dependency)
+             value = await self._call_dependency(dependency, kwargs)
+             if marker.use_cache:
+                 self._cache[dependency] = value
+             return value
+         finally:
+             self._active.discard(dependency)
+
+     async def _call_dependency(
+         self,
+         dependency: Callable[..., Any],
+         kwargs: dict[str, Any],
+     ) -> Any:
+         if inspect.isasyncgenfunction(dependency):
+             context_manager = asynccontextmanager(dependency)(**kwargs)
+             return await self._stack.enter_async_context(context_manager)
+         if inspect.isgeneratorfunction(dependency):
+             context_manager = contextmanager(dependency)(**kwargs)
+             return self._stack.enter_context(context_manager)
+         result = dependency(**kwargs)
+         resolved = await self._await_if_needed(result)
+         return await self._enter_context_if_needed(resolved)
+
+     async def _await_if_needed(self, value: Any) -> Any:
+         if inspect.isawaitable(value):
+             return await value
+         return value
+
+     async def _enter_context_if_needed(self, value: Any) -> Any:
+         if hasattr(value, "__aenter__") and hasattr(value, "__aexit__"):
+             return await self._stack.enter_async_context(value)  # type: ignore[arg-type]
+         if hasattr(value, "__enter__") and hasattr(value, "__exit__"):
+             return self._stack.enter_context(value)  # type: ignore[arg-type]
+         return value
+
+
+ @asynccontextmanager
+ async def provide_dependencies(
+     func: Callable[..., Any],
+     kwargs: Optional[dict[str, Any]] = None,
+ ) -> AsyncIterator[dict[str, Any]]:
+     """Resolve dependencies for ``func`` and manage their lifetimes."""
+
+     resolver = _DependencyResolver(kwargs)
+     try:
+         call_kwargs = await resolver.build_call_kwargs(func)
+         yield call_kwargs
+     finally:
+         await resolver.close()
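
provide_dependencies resolves a callable's Depend(...) parameters, including generator and async-generator dependencies that are entered as context managers, and tears everything down when the block exits. A minimal sketch of how it might be used, not part of the package diff; the example functions below are illustrative only:

    import asyncio

    from rappel.dependencies import Depend, provide_dependencies

    def get_settings() -> dict:
        # Plain dependency: its return value is cached for the resolution by default.
        return {"region": "us-east-1"}

    async def get_client(settings: dict = Depend(get_settings)) -> dict:
        # Async dependencies are awaited before being injected.
        return {"endpoint": f"https://{settings['region']}.example.com"}

    async def handler(client: dict = Depend(get_client)) -> str:
        return client["endpoint"]

    async def main() -> None:
        async with provide_dependencies(handler) as call_kwargs:
            print(await handler(**call_kwargs))

    asyncio.run(main())
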
rappel/exceptions.py ADDED
@@ -0,0 +1,18 @@
+ """Custom exception types raised by rappel workflows."""
+
+
+ class ExhaustedRetriesError(Exception):
+     """Raised when an action exhausts its allotted retry attempts."""
+
+     def __init__(self, message: str | None = None) -> None:
+         super().__init__(message or "action exhausted retries")
+
+
+ ExhaustedRetries = ExhaustedRetriesError
+
+
+ class ScheduleAlreadyExistsError(Exception):
+     """Raised when a schedule name is already registered."""
+
+     def __init__(self, message: str | None = None) -> None:
+         super().__init__(message or "schedule already exists")
rappel/formatter.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ import re
+ import sys
+ from typing import List, Sequence, TextIO
+
+ RESET = "\033[0m"
+
+ STYLE_CODES = {
+     "bold": "\033[1m",
+     "dim": "\033[2m",
+     "red": "\033[31m",
+     "green": "\033[32m",
+     "yellow": "\033[33m",
+     "blue": "\033[34m",
+     "magenta": "\033[35m",
+     "cyan": "\033[36m",
+     "white": "\033[37m",
+ }
+
+ _TAG_PATTERN = re.compile(r"\[(/?)([a-zA-Z]+)?\]")
+
+
+ def supports_color(stream: TextIO) -> bool:
+     """Return True if the provided stream likely supports ANSI colors."""
+
+     if os.environ.get("NO_COLOR"):
+         return False
+     if os.environ.get("FORCE_COLOR"):
+         return True
+     if stream.isatty():
+         if sys.platform != "win32":
+             return True
+         return bool(
+             os.environ.get("ANSICON")
+             or os.environ.get("WT_SESSION")
+             or os.environ.get("TERM_PROGRAM")
+         )
+     return False
+
+
+ class Formatter:
+     """
+     Very small markup formatter inspired by Rich's tag syntax. We want to minimize our python client
+     dependencies to just grpc+standard library.
+
+     """
+
+     def __init__(self, enable_colors: bool) -> None:
+         self._enable_colors = enable_colors
+
+     @property
+     def enable_colors(self) -> bool:
+         return self._enable_colors
+
+     def format(self, text: str) -> str:
+         if not text:
+             return text
+         return _apply_markup(text, self._enable_colors)
+
+     def apply_styles(self, text: str, styles: Sequence[str]) -> str:
+         """Wrap text with markup tags for the provided style sequence."""
+
+         if not styles:
+             return text
+         opening = "".join(f"[{style}]" for style in styles)
+         closing = "".join(f"[/{style}]" for style in reversed(styles))
+         return f"{opening}{text}{closing}"
+
+
+ def _apply_markup(text: str, enable_colors: bool) -> str:
+     if not enable_colors:
+         return _TAG_PATTERN.sub("", text)
+     result: List[str] = []
+     stack: List[str] = []
+     index = 0
+     for match in _TAG_PATTERN.finditer(text):
+         result.append(text[index : match.start()])
+         is_closing = match.group(1) == "/"
+         tag_name = match.group(2)
+         if tag_name is None:
+             if is_closing:
+                 if stack:
+                     stack.clear()
+                     result.append(RESET)
+             else:
+                 result.append(match.group(0))
+             index = match.end()
+             continue
+         if is_closing:
+             if tag_name in stack:
+                 while stack:
+                     name = stack.pop()
+                     result.append(RESET)
+                     if name == tag_name:
+                         break
+                 if stack:
+                     result.append("".join(STYLE_CODES[name] for name in stack))
+             index = match.end()
+             continue
+         code = STYLE_CODES.get(tag_name)
+         if code is None:
+             result.append(match.group(0))
+         else:
+             stack.append(tag_name)
+             result.append(code)
+         index = match.end()
+     result.append(text[index:])
+     if stack:
+         result.append(RESET)
+     return "".join(result)