reflectapi-runtime 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,165 @@
1
+ """Batch operations support for ReflectAPI Python clients."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from typing import TYPE_CHECKING, Any, Generic
7
+ # Sentinel object for batch operations where no HTTP response exists
8
+ class _BatchNoResponse:
9
+ """Sentinel object representing the absence of an HTTP response in batch operations."""
10
+
11
+ def __repr__(self) -> str:
12
+ return "<BatchNoResponse>"
13
+
14
+ # Singleton instance
15
+ BATCH_NO_RESPONSE = _BatchNoResponse()
16
+
17
import httpx

# NOTE(review): ApiError, ApiResponse, TransportMetadata and T are used at
# runtime in this module (isinstance checks, object construction, Generic[T]),
# so importing them only under TYPE_CHECKING raised NameError as soon as the
# code executed. They are imported unconditionally instead. If the guard
# existed to break an import cycle, that cycle needs a different fix —
# confirm against .exceptions / .response / .types.
from .exceptions import ApiError
from .response import ApiResponse, TransportMetadata
from .types import BatchResult, T

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable
25
+
26
+
27
class BatchClient(Generic[T]):
    """Client for executing batch operations with bounded concurrency.

    Tasks are registered with :meth:`add_task` and executed concurrently by
    :meth:`gather`, with at most ``max_concurrent`` running at any one time.
    """

    def __init__(self, max_concurrent: int = 10) -> None:
        """Initialize the batch client.

        Args:
            max_concurrent: Maximum number of tasks allowed in flight at once.
        """
        self.max_concurrent = max_concurrent
        self._tasks: list[Callable[[], Awaitable[Any]]] = []
        # One shared semaphore bounds concurrency across a gather() call.
        self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrent)

    def add_task(self, coro_func: Callable[[], Awaitable[T]]) -> None:
        """Add a coroutine function to be executed in the batch."""
        self._tasks.append(coro_func)

    async def gather(self) -> list[BatchResult[Any]]:
        """Execute all tasks concurrently and return results.

        Returns a list where each item is either an ApiResponse (success)
        or an ApiError (failure). This allows handling partial failures
        in batch operations.

        Note: tasks are NOT consumed by execution; call :meth:`clear` to
        reuse this client for a new batch.
        """
        if not self._tasks:
            return []

        async def execute_with_semaphore(
            task: Callable[[], Awaitable[T]],
        ) -> BatchResult[T]:
            """Execute a single task under the concurrency semaphore."""
            async with self._semaphore:
                # Keep the try body minimal: only the awaited task can
                # legitimately raise here.
                try:
                    result = await task()
                except ApiError as err:
                    # Domain errors are returned (not raised) so callers
                    # can inspect partial failures.
                    return err
                except Exception as err:
                    # Wrap anything else so one failing task cannot abort
                    # the whole batch.
                    return ApiError(
                        f"Unexpected error in batch operation: {err}", cause=err
                    )
                # If result is already an ApiResponse or ApiError, return as-is.
                # Tuple form works on all supported Python versions (the
                # `X | Y` isinstance form requires 3.10+).
                if isinstance(result, (ApiResponse, ApiError)):
                    return result
                # Otherwise wrap in an ApiResponse with synthetic metadata;
                # no HTTP exchange occurred, so raw_response is a sentinel.
                metadata = TransportMetadata(
                    status_code=200,
                    headers=httpx.Headers({}),
                    timing=0.0,
                    raw_response=BATCH_NO_RESPONSE,  # Sentinel for batch operations
                )
                return ApiResponse(result, metadata)

        # Exceptions are handled inside execute_with_semaphore, so plain
        # gather (return_exceptions=False) is safe here.
        return await asyncio.gather(
            *(execute_with_semaphore(task) for task in self._tasks)
        )

    async def gather_successful(self) -> list[ApiResponse[Any]]:
        """Execute all tasks and return only successful results.

        Failed operations are silently ignored. Use gather() if you need
        to handle failures explicitly.
        """
        results = await self.gather()
        return [result for result in results if isinstance(result, ApiResponse)]

    async def gather_with_errors(self) -> tuple[list[ApiResponse[Any]], list[ApiError]]:
        """Execute all tasks and return successes and failures separately.

        Returns:
            A ``(successes, failures)`` tuple partitioning gather()'s results.
        """
        results = await self.gather()

        successes: list[ApiResponse[Any]] = []
        failures: list[ApiError] = []

        for result in results:
            if isinstance(result, ApiResponse):
                successes.append(result)
            else:
                failures.append(result)

        return successes, failures

    def __len__(self) -> int:
        """Return the number of tasks in the batch."""
        return len(self._tasks)

    def clear(self) -> None:
        """Clear all tasks from the batch."""
        self._tasks.clear()
115
+
116
+
117
class BatchContextManager:
    """Async context manager that turns a client's requests into batch tasks."""

    def __init__(self, client: Any, max_concurrent: int = 10) -> None:
        self.client = client
        self.batch_client = BatchClient[Any](max_concurrent)
        self._original_make_request: Callable[..., Any] | None = None

    async def __aenter__(self) -> BatchClient[Any]:
        """Enter the batch context, patching the client to defer requests."""
        # Keep a reference to the real implementation for execution/restore.
        self._original_make_request = self.client._make_request

        def deferred_request(*args: Any, **kwargs: Any) -> Callable[[], Awaitable[Any]]:
            """Queue the call as a batch task instead of performing it now."""

            async def run_deferred() -> Any:
                # Delegate to the saved original method when it exists.
                if self._original_make_request:
                    return await self._original_make_request(*args, **kwargs)
                raise RuntimeError("Original request method not available")

            self.batch_client.add_task(run_deferred)
            # Hand the task back so a caller may also await it individually.
            return run_deferred

        self.client._make_request = deferred_request
        return self.batch_client

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit the batch context, undoing the request-method patch."""
        if self._original_make_request:
            self.client._make_request = self._original_make_request
151
+
152
+
153
# Mixin granting batch support to any client exposing _make_request.
class BatchMixin:
    """Mixin that adds batch operation support to clients."""

    def batch(self, max_concurrent: int = 10) -> BatchContextManager:
        """Create a batch context for concurrent operations.

        Usage:
            async with client.batch(max_concurrent=10) as batch:
                tasks = [client.create_pet(Pet(name=f"Pet_{i}")) for i in range(100)]
                results = await batch.gather()
        """
        return BatchContextManager(client=self, max_concurrent=max_concurrent)