earthscope-sdk 1.0.0b1__py3-none-any.whl → 1.2.0b0__py3-none-any.whl

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (36)
  1. earthscope_sdk/__init__.py +1 -1
  2. earthscope_sdk/auth/auth_flow.py +8 -6
  3. earthscope_sdk/auth/client_credentials_flow.py +3 -13
  4. earthscope_sdk/auth/device_code_flow.py +19 -10
  5. earthscope_sdk/client/_client.py +47 -0
  6. earthscope_sdk/client/data_access/__init__.py +0 -0
  7. earthscope_sdk/client/data_access/_arrow/__init__.py +0 -0
  8. earthscope_sdk/client/data_access/_arrow/_common.py +94 -0
  9. earthscope_sdk/client/data_access/_arrow/_gnss.py +116 -0
  10. earthscope_sdk/client/data_access/_base.py +85 -0
  11. earthscope_sdk/client/data_access/_query_plan/__init__.py +0 -0
  12. earthscope_sdk/client/data_access/_query_plan/_gnss_observations.py +295 -0
  13. earthscope_sdk/client/data_access/_query_plan/_query_plan.py +259 -0
  14. earthscope_sdk/client/data_access/_query_plan/_request_set.py +133 -0
  15. earthscope_sdk/client/data_access/_service.py +114 -0
  16. earthscope_sdk/client/discovery/__init__.py +0 -0
  17. earthscope_sdk/client/discovery/_base.py +303 -0
  18. earthscope_sdk/client/discovery/_service.py +209 -0
  19. earthscope_sdk/client/discovery/models.py +144 -0
  20. earthscope_sdk/common/context.py +73 -1
  21. earthscope_sdk/common/service.py +10 -8
  22. earthscope_sdk/config/_bootstrap.py +42 -0
  23. earthscope_sdk/config/models.py +54 -21
  24. earthscope_sdk/config/settings.py +11 -0
  25. earthscope_sdk/model/secret.py +29 -0
  26. earthscope_sdk/util/__init__.py +0 -0
  27. earthscope_sdk/util/_concurrency.py +64 -0
  28. earthscope_sdk/util/_itertools.py +57 -0
  29. earthscope_sdk/util/_time.py +57 -0
  30. earthscope_sdk/util/_types.py +5 -0
  31. {earthscope_sdk-1.0.0b1.dist-info → earthscope_sdk-1.2.0b0.dist-info}/METADATA +15 -4
  32. earthscope_sdk-1.2.0b0.dist-info/RECORD +49 -0
  33. {earthscope_sdk-1.0.0b1.dist-info → earthscope_sdk-1.2.0b0.dist-info}/WHEEL +1 -1
  34. earthscope_sdk-1.0.0b1.dist-info/RECORD +0 -28
  35. {earthscope_sdk-1.0.0b1.dist-info → earthscope_sdk-1.2.0b0.dist-info}/licenses/LICENSE +0 -0
  36. {earthscope_sdk-1.0.0b1.dist-info → earthscope_sdk-1.2.0b0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
+ from datetime import timedelta
+ from enum import Enum
+ from typing import Annotated, Any, Generic, Iterable, Optional, TypeVar, Union
+
+ from pydantic import BaseModel, BeforeValidator, TypeAdapter
+
+ from earthscope_sdk.util._itertools import to_set
+
+
+ def _coerce_timedelta_ms(v: Union[int, float, timedelta, str]) -> timedelta:
+     if isinstance(v, (int, float)):
+         return timedelta(milliseconds=v)
+
+     # fallback to Pydantic's timedelta parser
+     return v
+
+
+ P = TypeVar("P")
+
+
+ class Page(BaseModel, Generic[P]):
+     has_next: bool
+     offset: int
+     limit: int
+     items: list[P]
+     total: Optional[int] = None
+
+
+ class DatasourceBaseModel(BaseModel):
+     edid: str
+     names: dict[str, str]
+     description: Optional[str] = None
+
+     def to_arrow_columns(
+         self,
+         *,
+         fields: Union[list[str], str] = ["edid", "names"],
+         namespaces: Union[list[str], str] = [],
+     ) -> dict[str, Any]:
+         """
+         Convert the datasource model to a dictionary suitable for use in an Arrow table.
+         """
+         result = {}
+         namespaces = to_set(namespaces)
+         fields = to_set(fields)
+
+         # Add names to fields if namespaces are requested
+         if namespaces:
+             fields.add("names")
+
+         for field in fields:
+             if field != "names":
+                 result[field] = getattr(self, field)
+                 continue
+
+             # Explode names to own columns
+             if not namespaces:
+                 names = {k.lower(): v for k, v in self.names.items()}
+             else:
+                 names = {
+                     k_lower: v
+                     for k, v in self.names.items()
+                     if (k_lower := k.lower()) in namespaces
+                 }
+
+             result.update(names)
+
+         return result
+
+
+ class NetworkDatasource(DatasourceBaseModel): ...
+
+
+ ListNetworkDatasourcesResult = TypeAdapter(Union[Page[str], Page[NetworkDatasource]])
+
+
+ class StationDatasource(DatasourceBaseModel):
+     network_edids: Optional[list[str]] = None
+     networks: Optional[list[NetworkDatasource]] = None
+
+
+ ListStationDatasourcesResult = TypeAdapter(Union[Page[str], Page[StationDatasource]])
+
+
+ class _StationDatasourceMember(DatasourceBaseModel):
+     station_edid: Optional[str] = None
+     station: Optional[StationDatasource] = None
+
+     def to_arrow_columns(
+         self,
+         *,
+         fields: list[str] = ["edid", "names"],
+         namespaces: Optional[list[str]] = None,
+     ) -> dict[str, Any]:
+         result = super().to_arrow_columns(fields=fields, namespaces=namespaces)
+         if self.station:
+             parent_columns = self.station.to_arrow_columns(
+                 fields=["names"],
+                 namespaces=namespaces,
+             )
+             result.update(parent_columns)
+
+         return result
+
+
+ class SessionDatasource(_StationDatasourceMember):
+     sample_interval: Annotated[timedelta, BeforeValidator(_coerce_timedelta_ms)]
+     """
+     Session sample interval.
+     """
+
+     roll: timedelta  # already in seconds
+     """
+     Session file roll cadence.
+     """
+
+
+ ListSessionDatasourcesResult = TypeAdapter(Union[Page[str], Page[SessionDatasource]])
+
+
+ class StreamType(Enum):
+     GNSS_RAW = "gnss_raw"
+     GNSS_PPP = "gnss_ppp"
+
+
+ class StreamDatasource(_StationDatasourceMember):
+     stream_type: StreamType
+     facility: str
+     software: str
+     label: str
+     sample_interval: Annotated[timedelta, BeforeValidator(_coerce_timedelta_ms)]
+     """
+     Stream sample interval.
+     """
+
+     def to_arrow_columns(
+         self,
+         *,
+         fields: Iterable[str] = ["edid", "names", "facility", "software", "label"],
+     ) -> dict[str, Any]:
+         return super().to_arrow_columns(fields=fields)
+
+
+ ListStreamDatasourcesResult = TypeAdapter(Union[Page[str], Page[StreamDatasource]])
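The hunk above adds the new discovery models (per the file list, earthscope_sdk/client/discovery/models.py). A minimal sketch of how `to_arrow_columns` flattens the `names` mapping into per-namespace columns; the import path is assumed from the file list, and the `edid` and namespace keys are made up for illustration:

```py
from earthscope_sdk.client.discovery.models import StationDatasource  # path assumed from the file list

# Hypothetical datasource; "IGS" and "UNAVCO" are made-up namespace keys
station = StationDatasource(
    edid="st-0001",
    names={"IGS": "P123", "UNAVCO": "P123A"},
)

# Only the requested namespaces are exploded into their own (lower-cased) columns
print(station.to_arrow_columns(fields=["edid"], namespaces=["igs"]))
# -> {"edid": "st-0001", "igs": "P123"}

# With no namespace filter, every name becomes a column
print(station.to_arrow_columns(fields=["edid", "names"]))
# -> {"edid": "st-0001", "igs": "P123", "unavco": "P123A"}
```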
@@ -1,5 +1,6 @@
+ import asyncio
  import sys
- from functools import cached_property
+ from functools import cached_property, partial
  from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, cast

  if sys.version_info >= (3, 10):
@@ -9,6 +10,8 @@ else:


  if TYPE_CHECKING:
+     from concurrent.futures import ThreadPoolExecutor
+
      from httpx import AsyncClient

      from earthscope_sdk.common._sync_runner import SyncRunner
@@ -64,6 +67,19 @@ class SdkContext:

          return DeviceCodeFlow(ctx=self)

+     @cached_property
+     def executor(self):
+         """
+         Thread pool executor for running sync functions in the background
+         """
+         import concurrent.futures
+
+         self._executor = concurrent.futures.ThreadPoolExecutor(
+             max_workers=self.settings.thread_pool_max_workers
+         )
+
+         return self._executor
+
      @cached_property
      def httpx_client(self):
          """
@@ -78,6 +94,8 @@ class SdkContext:
          self._httpx_client = httpx.AsyncClient(
              auth=self.auth_flow,
              headers={
+                 **self.settings.http.extra_headers,
+                 # override anything specified via extra_headers
                  "user-agent": self.settings.http.user_agent,
              },
              limits=self.settings.http.limits,
@@ -103,6 +121,18 @@ class SdkContext:
          """
          return self._settings

+     @cached_property
+     def _rate_limit(self):
+         """
+         Rate limit semaphore for rate limiting HTTP requests
+         """
+         from earthscope_sdk.util._concurrency import RateLimitSemaphore
+
+         return RateLimitSemaphore(
+             max_concurrent=self.settings.http.rate_limit.max_concurrent,
+             max_per_second=self.settings.http.rate_limit.max_per_second,
+         )
+
      def __init__(
          self,
          settings: Optional["SdkSettings"] = None,
@@ -113,6 +143,7 @@ class SdkContext:
          from earthscope_sdk.config.settings import SdkSettings

          # Local state
+         self._executor: Optional["ThreadPoolExecutor"] = None
          self._httpx_client: Optional["AsyncClient"] = None
          self._runner: Optional["SyncRunner"] = runner
          self._settings = settings or SdkSettings()
@@ -127,6 +158,9 @@ class SdkContext:
          if self._runner:
              self._runner.stop()

+         if self._executor:
+             self._executor.shutdown()
+
      def close(self):
          """
          Close this SdkContext to release underlying resources (e.g. connection pools)
@@ -141,6 +175,44 @@ class SdkContext:
          if self._runner:
              self._runner.stop()

+         if self._executor:
+             self._executor.shutdown()
+
+     async def run_in_executor(self, fn: Callable[..., Any], *args, **kwargs):
+         """
+         Run a function in the executor
+         """
+         loop = asyncio.get_running_loop()
+         return await loop.run_in_executor(self.executor, partial(fn, *args, **kwargs))
+
+     def asyncify(
+         self,
+         fn: Callable[T_ParamSpec, T_Retval],
+     ) -> Callable[T_ParamSpec, Coroutine[Any, Any, T_Retval]]:
+         """
+         Decorator that wraps a sync function to run in the executor as a coroutine.
+
+         Usage:
+             @ctx.asyncify
+             def some_sync_function(arg1, arg2):
+                 # sync code here
+                 return result
+
+             # Now some_sync_function is async and runs in executor
+             result = await some_sync_function(arg1, arg2)
+         """
+
+         async def wrapper(
+             *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
+         ) -> T_Retval:
+             loop = asyncio.get_running_loop()
+             return await loop.run_in_executor(
+                 self.executor,
+                 partial(fn, *args, **kwargs),
+             )
+
+         return wrapper
+
      def syncify(
          self,
          async_function: Callable[T_ParamSpec, Coroutine[Any, Any, T_Retval]],
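The hunks above extend `SdkContext` (earthscope_sdk/common/context.py per the file list) with a lazy thread pool plus `run_in_executor`/`asyncify` helpers. A rough usage sketch based on the `asyncify` docstring, assuming `SdkContext` can be constructed with default settings:

```py
import asyncio

from earthscope_sdk.common.context import SdkContext  # module path per the file list

ctx = SdkContext()

@ctx.asyncify
def slow_sum(n: int) -> int:
    # Blocking work; runs on ctx.executor instead of the event loop
    return sum(range(n))

async def main():
    # The wrapped function is awaitable and executes in the thread pool
    print(await slow_sum(1_000_000))

asyncio.run(main())
ctx.close()  # shuts down the executor along with other resources
```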
@@ -36,17 +36,19 @@ class SdkService:

          Performs common response handling.
          """
-         resp = await self.ctx.httpx_client.send(request=request)
+         # Global rate limiting
+         async with self._ctx._rate_limit:
+             resp = await self.ctx.httpx_client.send(request=request)

-         # Throw specific errors for certain status codes
+             # Throw specific errors for certain status codes

-         if resp.status_code == 401:
-             await resp.aread()  # must read body before using .text prop
-             raise UnauthenticatedError(resp.text)
+             if resp.status_code == 401:
+                 await resp.aread()  # must read body before using .text prop
+                 raise UnauthenticatedError(resp.text)

-         if resp.status_code == 403:
-             await resp.aread()  # must read body before using .text prop
-             raise UnauthorizedError(resp.text)
+             if resp.status_code == 403:
+                 await resp.aread()  # must read body before using .text prop
+                 raise UnauthorizedError(resp.text)

          # Raise HTTP errors
          resp.raise_for_status()
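The change above routes every request through a shared rate-limit semaphore before mapping 401/403 responses to typed errors. A standalone sketch of that pattern using plain `httpx` with `asyncio.Semaphore` standing in for the SDK's `RateLimitSemaphore`; the error classes here are local stand-ins, not the SDK's own:

```py
import asyncio
import httpx

class UnauthenticatedError(Exception): ...  # stand-in for the SDK's 401 error type
class UnauthorizedError(Exception): ...     # stand-in for the SDK's 403 error type

async def send(
    client: httpx.AsyncClient,
    request: httpx.Request,
    rate_limit: asyncio.Semaphore,
) -> httpx.Response:
    # Global rate limiting: only N requests may be in flight at once
    async with rate_limit:
        resp = await client.send(request)

        # Map auth failures to typed errors
        if resp.status_code == 401:
            await resp.aread()  # body must be read before .text
            raise UnauthenticatedError(resp.text)

        if resp.status_code == 403:
            await resp.aread()
            raise UnauthorizedError(resp.text)

    # Raise for any other HTTP error status
    resp.raise_for_status()
    return resp
```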
@@ -0,0 +1,42 @@
+ """
+ This module facilitates bootstrapping SDK settings from a JSON-encoded environment variable.
+ """
+
+ import json
+ import logging
+ import os
+
+ from pydantic_settings import PydanticBaseSettingsSource
+
+ logger = logging.getLogger(__name__)
+
+
+ class BootstrapEnvironmentSettingsSource(PydanticBaseSettingsSource):
+     """
+     This SettingsSource facilitates bootstrapping the SDK from a special environment variable.
+
+     The environment variable should be a JSON string of the expected SDK settings and structure.
+     """
+
+     def __init__(self, settings_cls, env_var: str):
+         super().__init__(settings_cls)
+         self._env_var = env_var
+
+     def __call__(self):
+         try:
+             bootstrap_settings = os.environ[self._env_var]
+         except KeyError:
+             return {}
+
+         try:
+             return json.loads(bootstrap_settings)
+         except json.JSONDecodeError:
+             logger.warning(
+                 f"Found bootstrap environment variable '{self._env_var}', but unable to decode content as JSON"
+             )
+             return {}
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(env_var='{self._env_var}')"
+
+     def get_field_value(self, *args, **kwargs): ...  # unused abstract method
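This new settings source reads an entire settings payload from one JSON-encoded environment variable; the variable name is registered later in the config/settings.py hunk as `ES_BOOTSTRAP_SETTINGS`. A rough sketch of how that could be exercised; the nested payload shown is hypothetical but mirrors the `http.extra_headers` field visible elsewhere in this diff:

```py
import json
import os

from earthscope_sdk.config.settings import SdkSettings  # path per the file list

# Hypothetical bootstrap payload; the nested shape mirrors the SDK settings model
os.environ["ES_BOOTSTRAP_SETTINGS"] = json.dumps(
    {"http": {"extra_headers": {"x-example-header": "demo"}}}
)

settings = SdkSettings()
print(settings.http.extra_headers)  # the bootstrapped value should surface here
```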
@@ -1,13 +1,14 @@
  import base64
  import binascii
  import datetime as dt
+ import fnmatch
  import functools
  from contextlib import suppress
  from enum import Enum
  from functools import cached_property
  from typing import Annotated, Any, Optional, Type, Union

- from annotated_types import Ge, Gt
+ from annotated_types import Ge, Gt, Interval
  from pydantic import (
      AliasChoices,
      BaseModel,
@@ -15,14 +16,12 @@ from pydantic import (
      ConfigDict,
      Field,
      HttpUrl,
-     SecretStr,
-     SerializationInfo,
      ValidationError,
-     field_serializer,
      model_validator,
  )

  from earthscope_sdk import __version__
+ from earthscope_sdk.model.secret import SecretStr


  def _try_float(v: Any):
@@ -94,23 +93,6 @@ class Tokens(BaseModel):

          raise ValueError("Unable to decode access token body")

-     @field_serializer("access_token", "id_token", "refresh_token", when_used="json")
-     def dump_secret_json(self, secret: Optional[SecretStr], info: SerializationInfo):
-         """
-         A special field serializer to dump the actual secret value when writing to JSON.
-
-         Only writes secret in plaintext when `info.context == "plaintext".
-
-         See [Pydantic docs](https://docs.pydantic.dev/latest/concepts/serialization/#serialization-context)
-         """
-         if secret is None:
-             return None
-
-         if info.context == "plaintext":
-             return secret.get_secret_value()
-
-         return str(secret)
-
      @model_validator(mode="after")
      def ensure_one_of(self):
          # allow all fields to be optional in subclasses
@@ -207,6 +189,12 @@ class AuthFlowSettings(Tokens):
      scope: str = "offline_access"
      client_secret: Optional[SecretStr] = None

+     # Only inject bearer token for requests to these hosts
+     allowed_hosts: set[str] = {
+         "earthscope.org",
+         "*.earthscope.org",
+     }
+
      # Auth exchange retries
      retry: HttpRetrySettings = HttpRetrySettings(
          attempts=5,
@@ -222,6 +210,46 @@ class AuthFlowSettings(Tokens):

          return AuthFlowType.DeviceCode

+     @cached_property
+     def allowed_host_patterns(self) -> set[str]:
+         """
+         The subset of allowed hosts that are glob patterns.
+
+         Use `is_host_allowed` to check if a host is allowed by any of these patterns.
+         """
+         return {h for h in self.allowed_hosts if "*" in h or "?" in h}
+
+     def is_host_allowed(self, host: str) -> bool:
+         """
+         Check if a host matches any pattern in the allowed hosts set.
+
+         Supports glob patterns with '?' and '*' characters (e.g., *.earthscope.org).
+
+         Args:
+             host: The hostname to check
+
+         Returns:
+             True if the host matches any allowed pattern, False otherwise
+         """
+         if host in self.allowed_hosts:
+             return True
+
+         for allowed_pattern in self.allowed_host_patterns:
+             if fnmatch.fnmatch(host, allowed_pattern):
+                 self.allowed_hosts.add(host)
+                 return True
+
+         return False
+
+
+ class RateLimitSettings(BaseModel):
+     """
+     Rate limit settings
+     """
+
+     max_concurrent: Annotated[int, Interval(ge=1, le=200)] = 100
+     max_per_second: Annotated[float, Interval(ge=1, le=200)] = 150.0
+

  class HttpSettings(BaseModel):
      """
@@ -240,8 +268,12 @@ class HttpSettings(BaseModel):
      # automatically retry requests
      retry: HttpRetrySettings = HttpRetrySettings()

+     # rate limit outgoing requests
+     rate_limit: RateLimitSettings = RateLimitSettings()
+
      # Other
      user_agent: str = f"earthscope-sdk py/{__version__}"
+     extra_headers: dict[str, str] = {}

      @cached_property
      def limits(self):
@@ -289,3 +321,4 @@ class SdkBaseSettings(BaseModel):
      http: HttpSettings = HttpSettings()
      oauth2: AuthFlowSettings = AuthFlowSettings()
      resources: ResourceRefs = ResourceRefs()
+     thread_pool_max_workers: Optional[int] = None
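The new `allowed_hosts` / `is_host_allowed` machinery gates bearer-token injection by hostname, checking exact matches before glob patterns. A self-contained sketch that mirrors that logic with `fnmatch`, rather than constructing `AuthFlowSettings` directly (which may require other fields):

```py
import fnmatch

# Default allow-list from AuthFlowSettings in this diff
allowed_hosts = {"earthscope.org", "*.earthscope.org"}

def is_host_allowed(host: str) -> bool:
    # Exact match first, then any glob pattern ('*' or '?')
    if host in allowed_hosts:
        return True
    patterns = (h for h in allowed_hosts if "*" in h or "?" in h)
    return any(fnmatch.fnmatch(host, p) for p in patterns)

print(is_host_allowed("api.earthscope.org"))  # True, via *.earthscope.org
print(is_host_allowed("earthscope.org"))      # True, exact match
print(is_host_allowed("example.com"))         # False
```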
@@ -11,11 +11,15 @@ from pydantic_settings import (
      TomlConfigSettingsSource,
  )

+ from earthscope_sdk.config._bootstrap import BootstrapEnvironmentSettingsSource
  from earthscope_sdk.config._compat import LegacyEarthScopeCLISettingsSource
  from earthscope_sdk.config._util import deep_merge, get_config_dir, slugify
  from earthscope_sdk.config.error import ProfileDoesNotExistError
  from earthscope_sdk.config.models import SdkBaseSettings, Tokens

+ _BOOTSTRAP_ENV_VAR = "ES_BOOTSTRAP_SETTINGS"
+ """Environment variable for bootstrapping the SDK"""
+
  _DEFAULT_PROFILE = "default"
  """Default profile name"""

@@ -269,6 +273,12 @@ class SdkSettings(SdkBaseSettings, BaseSettings):
          alias = SdkSettings.model_fields["profile_name"].validation_alias
          global_settings = _GlobalSettingsSource(settings_cls, "profile_name", alias)

+         # Check for bootstrapping configuration
+         bootstrap_settings = BootstrapEnvironmentSettingsSource(
+             settings_cls,
+             _BOOTSTRAP_ENV_VAR,
+         )
+
          # Compatibility with earthscope-cli v0.x.x state:
          # If we find this file, we only care about the access and refresh tokens
          keep_keys = {"access_token", "refresh_token"}
@@ -281,4 +291,5 @@ class SdkSettings(SdkBaseSettings, BaseSettings):
              dotenv_settings,
              global_settings,
              legacy_settings,
+             bootstrap_settings,
          )
@@ -0,0 +1,29 @@
+ from typing import Annotated
+
+ from pydantic import PlainSerializer, SerializationInfo
+ from pydantic import SecretStr as _SecretStr
+
+
+ def _dump_secret_plaintext(secret: _SecretStr, info: SerializationInfo):
+     """
+     A special field serializer to dump the actual secret value.
+
+     Only writes secret in plaintext when `info.context == "plaintext".
+
+     See [Pydantic docs](https://docs.pydantic.dev/latest/concepts/serialization/#serialization-context)
+     """
+
+     if info.context == "plaintext":
+         return secret.get_secret_value()
+
+     return str(secret)
+
+
+ SecretStr = Annotated[
+     _SecretStr,
+     PlainSerializer(
+         _dump_secret_plaintext,
+         return_type=str,
+         when_used="json-unless-none",
+     ),
+ ]
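The new `SecretStr` alias keeps values masked in JSON output unless the serialization context is `"plaintext"`. A small sketch, assuming a recent Pydantic v2 that accepts `context=` on `model_dump_json`:

```py
from pydantic import BaseModel

from earthscope_sdk.model.secret import SecretStr  # new module in this diff

class Example(BaseModel):
    token: SecretStr

ex = Example(token="hunter2")
print(ex.model_dump_json())                     # {"token":"**********"}  (masked)
print(ex.model_dump_json(context="plaintext"))  # {"token":"hunter2"}
```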
@@ -0,0 +1,64 @@
+ """
+ Rate-limited semaphore for controlling concurrent operations.
+ """
+
+ import asyncio
+ import time
+
+
+ class RateLimitSemaphore(asyncio.Semaphore):
+     """
+     A semaphore that limits both the maximum number of concurrent acquisitions
+     and the maximum rate at which acquisitions can occur.
+     """
+
+     @property
+     def max_concurrent(self) -> int:
+         """Maximum number of concurrent acquisitions."""
+         return self._max_concurrent
+
+     @property
+     def max_per_second(self) -> float:
+         """Maximum acquisitions per second."""
+         return self._max_per_second
+
+     def __init__(self, max_concurrent: int, max_per_second: float):
+         """
+         Initialize the rate-limited semaphore.
+
+         Args:
+             max_concurrent: Maximum number of concurrent acquisitions.
+             max_per_second: Maximum acquisitions per second.
+         """
+         super().__init__(max_concurrent)
+         self._max_concurrent = max_concurrent
+         self._max_per_second = max_per_second
+
+         self._acquire_interval = 1.0 / max_per_second
+         self._next_acquire_time = 0
+         self._rate_lock = asyncio.Lock()
+
+     async def acquire(self):
+         """
+         Acquire a semaphore with rate limiting.
+
+         In addition to traditional semaphore behavior, this will also rate limit
+         the number of acquisitions per second by blocking until the next
+         acquisition interval.
+         """
+         # Rate limit acquisitions (serialized to prevent race conditions)
+         async with self._rate_lock:
+             now = time.monotonic()
+             wait = self._next_acquire_time - now
+
+             if wait > 0:
+                 # Sleep until next acquisition interval
+                 await asyncio.sleep(wait)
+                 self._next_acquire_time += self._acquire_interval
+
+             else:
+                 # no wait, set next acquisition time to next interval
+                 self._next_acquire_time = now + self._acquire_interval
+
+         # Acquire semaphore for concurrency limiting
+         return await super().acquire()
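`RateLimitSemaphore` combines a concurrency cap with a minimum spacing between acquisitions. A usage sketch, assuming the module path from the file list; since it subclasses `asyncio.Semaphore`, it works with `async with`:

```py
import asyncio

from earthscope_sdk.util._concurrency import RateLimitSemaphore  # path per the file list

async def fetch(i: int, sem: RateLimitSemaphore) -> int:
    async with sem:  # at most 5 in flight, at most 10 acquisitions per second
        await asyncio.sleep(0.1)  # stand-in for an HTTP request
        return i

async def main():
    sem = RateLimitSemaphore(max_concurrent=5, max_per_second=10)
    results = await asyncio.gather(*(fetch(i, sem) for i in range(25)))
    print(results)

asyncio.run(main())
```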
@@ -0,0 +1,57 @@
+ from itertools import islice
+ from typing import Generator, Iterable, TypeVar, Union
+
+ T = TypeVar("T")
+
+
+ def batched(
+     iterable: Iterable[T],
+     n: int = 10,
+ ) -> Generator[tuple[T, ...], None, None]:
+     """Process an iterable as batches of size `n`
+
+     Args:
+         iterable (Iterable[T]): anything iterable
+         n (Optional[int]): the size of tuples to yield
+
+     Yields (tuple[T, ...]) tuple of size `n` elements from `iterable`
+
+     Example:
+
+     ```py
+     my_list = [1,2,3,4,5]
+     print(list(batched(my_list, 3)))  # prints [(1,2,3), (4,5)]
+     ```
+     """
+     if n < 1:
+         raise ValueError("n must be at least 1")
+
+     iterator = iter(iterable)
+     while batch := tuple(islice(iterator, n)):
+         yield batch
+
+
+ def to_list(maybe_list: Union[T, list[T], set[T]]) -> list[T]:
+     """
+     Coerce the argument into a list if it is not already
+     """
+     if isinstance(maybe_list, list):
+         return maybe_list
+
+     if isinstance(maybe_list, set):
+         return list(maybe_list)
+
+     return [maybe_list]
+
+
+ def to_set(maybe_set: Union[T, list[T], set[T]]) -> set[T]:
+     """
+     Coerce the argument into a set if it is not already
+     """
+     if isinstance(maybe_set, set):
+         return maybe_set
+
+     if isinstance(maybe_set, list):
+         return set(maybe_set)
+
+     return {maybe_set}