prismadata 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prismadata/__init__.py +69 -0
- prismadata/_async_auth.py +112 -0
- prismadata/_async_http.py +143 -0
- prismadata/_auth.py +126 -0
- prismadata/_batch.py +196 -0
- prismadata/_cache.py +52 -0
- prismadata/_columns.py +32 -0
- prismadata/_constants.py +27 -0
- prismadata/_enrich.py +96 -0
- prismadata/_http.py +248 -0
- prismadata/_progress.py +34 -0
- prismadata/_types.py +117 -0
- prismadata/_validation.py +29 -0
- prismadata/async_client.py +606 -0
- prismadata/client.py +795 -0
- prismadata/exceptions.py +61 -0
- prismadata/py.typed +0 -0
- prismadata/sklearn.py +146 -0
- prismadata-0.1.0.dist-info/METADATA +247 -0
- prismadata-0.1.0.dist-info/RECORD +22 -0
- prismadata-0.1.0.dist-info/WHEEL +4 -0
- prismadata-0.1.0.dist-info/licenses/LICENSE +21 -0
prismadata/__init__.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""PrismaData - Python client for location intelligence API.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from prismadata import Client
|
|
6
|
+
|
|
7
|
+
client = Client(api_key="your-key")
|
|
8
|
+
result = client.geocode(full_address="Av Paulista 1000, Sao Paulo")
|
|
9
|
+
|
|
10
|
+
Async usage::
|
|
11
|
+
|
|
12
|
+
from prismadata import AsyncClient
|
|
13
|
+
|
|
14
|
+
async with await AsyncClient.create(api_key="your-key") as client:
|
|
15
|
+
result = await client.geocode(full_address="Av Paulista 1000, Sao Paulo")
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from ._types import (
|
|
19
|
+
BorderResult,
|
|
20
|
+
GeocodeResult,
|
|
21
|
+
IncomePdfResult,
|
|
22
|
+
IncomeStaticResult,
|
|
23
|
+
InfoscResult,
|
|
24
|
+
IsochroneResult,
|
|
25
|
+
PrecatoryResult,
|
|
26
|
+
PrisonResult,
|
|
27
|
+
ReverseGeocodeResult,
|
|
28
|
+
RouteResult,
|
|
29
|
+
RoutingBatchResult,
|
|
30
|
+
SlumResult,
|
|
31
|
+
UserInfo,
|
|
32
|
+
)
|
|
33
|
+
from .async_client import AsyncClient
|
|
34
|
+
from .client import Client
|
|
35
|
+
from .exceptions import (
|
|
36
|
+
AuthenticationError,
|
|
37
|
+
BatchError,
|
|
38
|
+
PrismaDataError,
|
|
39
|
+
QuotaExhaustedError,
|
|
40
|
+
RateLimitError,
|
|
41
|
+
ValidationError,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
__version__ = "0.1.0"

# Public API of the package. Built up by category; the resulting list order
# is: clients and version, exception classes, then typed result containers.
__all__ = [
    "AsyncClient",
    "Client",
    "__version__",
]

# Exception classes (imported from .exceptions).
__all__ += [
    "AuthenticationError",
    "BatchError",
    "PrismaDataError",
    "QuotaExhaustedError",
    "RateLimitError",
    "ValidationError",
]

# Result types (imported from ._types).
__all__ += [
    "BorderResult",
    "GeocodeResult",
    "IncomePdfResult",
    "IncomeStaticResult",
    "InfoscResult",
    "IsochroneResult",
    "PrecatoryResult",
    "PrisonResult",
    "ReverseGeocodeResult",
    "RouteResult",
    "RoutingBatchResult",
    "SlumResult",
    "UserInfo",
]
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Async JWT authentication manager for the PrismaData API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from ._auth import _decode_jwt_claims
|
|
13
|
+
from ._constants import TOKEN_RENEW_MARGIN
|
|
14
|
+
from .exceptions import AuthenticationError
|
|
15
|
+
|
|
16
|
+
# Child of the "prismadata" logger, so applications can tune all SDK logging
# by configuring that parent logger.
logger = logging.getLogger("prismadata.async_auth")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AsyncAuthManager:
    """Manages JWT token lifecycle asynchronously.

    Token fetches are serialized with an ``asyncio.Lock``. They can reuse a
    shared ``httpx.AsyncClient`` attached via :meth:`set_http_client`.

    Two modes, selected by which credentials are supplied:

    * API key: requests authenticate with an ``X-Apikey`` header. A JWT is
      still fetched on a best-effort basis, only to read its ``rate_limit``
      and ``sandbox`` claims.
    * Username/password: requests send ``Authorization: Bearer <jwt>``. The
      token is renewed once fewer than ``TOKEN_RENEW_MARGIN`` seconds remain
      before its ``exp`` claim.
    """

    def __init__(
        self,
        base_url: str,
        timeout: int,
        *,
        api_key: str | None = None,
        username: str | None = None,
        password: str | None = None,
    ) -> None:
        self._base_url = base_url
        self._timeout = timeout
        self._api_key = api_key
        self._username = username
        self._password = password
        self._token: str | None = None
        # Unix timestamp (seconds) at which the token expires; 0.0 = no token yet.
        self._token_exp: float = 0.0
        # Claims copied from the most recently fetched token.
        self._rate_limit: float | None = None
        self._sandbox: bool = False
        self._lock = asyncio.Lock()
        # Optional shared client; when None, each fetch opens a short-lived one.
        self._http_client: httpx.AsyncClient | None = None

    def set_http_client(self, client: httpx.AsyncClient) -> None:
        """Attach a shared httpx.AsyncClient for token refresh requests."""
        self._http_client = client

    @property
    def rate_limit(self) -> float | None:
        """``rate_limit`` claim of the latest token, or None if absent."""
        return self._rate_limit

    @property
    def is_sandbox(self) -> bool:
        """``sandbox`` claim of the latest token (False if absent)."""
        return self._sandbox

    async def authenticate(self) -> None:
        """Force token fetch immediately.

        Raises:
            AuthenticationError: If the token cannot be obtained.
        """
        async with self._lock:
            await self._fetch_token()

    async def ensure_valid_token(self) -> None:
        """Ensure a valid token is available, refreshing if needed."""
        # Cheap check without the lock for the common "token still fresh" case.
        if self._token and time.time() < (self._token_exp - TOKEN_RENEW_MARGIN):
            return
        async with self._lock:
            # Re-check: another coroutine may have refreshed while we waited.
            if self._token and time.time() < (self._token_exp - TOKEN_RENEW_MARGIN):
                return
            await self._fetch_token()

    def get_headers(self) -> dict[str, str]:
        """Return auth headers (sync — call ensure_valid_token first).

        In API-key mode the key itself is sent. Otherwise the current bearer
        token is used, which is ``None`` if no token has been fetched yet.
        """
        if self._api_key:
            return {"X-Apikey": self._api_key}
        return {"Authorization": f"Bearer {self._token}"}

    async def _try_refresh_claims(self) -> None:
        """Best-effort claim refresh in API key mode.

        The JWT is only needed for its claims here, so failures are logged as
        warnings instead of failing the request.
        """
        try:
            await self.ensure_valid_token()
        except Exception as exc:
            logger.warning("Claim refresh failed (API key mode): %s", exc)

    async def _fetch_token(self) -> None:
        """POST credentials to ``/auth/token`` and store the token and its claims.

        Called with ``self._lock`` held.

        Raises:
            AuthenticationError: On a connection error, a 401/403 response, or
                a response body without a usable string ``access_token``.
            httpx.HTTPStatusError: On any other non-2xx status.
        """
        mode = "api_key" if self._api_key else "credentials"
        logger.debug("Fetching token (mode=%s)", mode)

        if self._api_key:
            data = {"api_key": self._api_key}
        else:
            data = {"username": self._username, "password": self._password}

        url = f"{self._base_url}/auth/token"
        try:
            if self._http_client is not None:
                resp = await self._http_client.post(url, data=data)
            else:
                async with httpx.AsyncClient(timeout=self._timeout) as http:
                    resp = await http.post(url, data=data)
        except httpx.HTTPError as exc:
            raise AuthenticationError(f"Failed to connect to auth endpoint: {exc}") from exc

        if resp.status_code in (401, 403):
            raise AuthenticationError(f"Authentication failed: {resp.text}")
        resp.raise_for_status()

        # A 2xx with a non-JSON body, a non-object body, or no access_token
        # would otherwise leak ValueError / TypeError / KeyError to callers.
        try:
            token = resp.json()["access_token"]
        except (ValueError, KeyError, TypeError) as exc:
            raise AuthenticationError(f"Malformed token response: {exc!r}") from exc
        if not isinstance(token, str) or not token:
            raise AuthenticationError("Malformed token response: access_token is missing or not a string")

        self._token = token
        claims = _decode_jwt_claims(token)
        # Without an exp claim, assume a one-hour lifetime.
        self._token_exp = claims.get("exp", time.time() + 3600)
        self._rate_limit = claims.get("rate_limit")
        self._sandbox = claims.get("sandbox", False)
        expires_in = self._token_exp - time.time()
        logger.debug("Token obtained, expires in %.0fs, rate_limit=%s", expires_in, self._rate_limit)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Async HTTP transport layer with retry, throttle, and error handling."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
from tenacity import (
|
|
12
|
+
RetryCallState,
|
|
13
|
+
retry,
|
|
14
|
+
retry_if_result,
|
|
15
|
+
stop_after_attempt,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from ._async_auth import AsyncAuthManager
|
|
19
|
+
from ._constants import (
|
|
20
|
+
DEFAULT_TIMEOUT,
|
|
21
|
+
RETRYABLE_STATUS_CODES,
|
|
22
|
+
RETRY_MAX_ATTEMPTS,
|
|
23
|
+
USER_AGENT_PREFIX,
|
|
24
|
+
)
|
|
25
|
+
from ._http import _handle_response, _is_retryable, _wait_for_retry
|
|
26
|
+
|
|
27
|
+
# Child of the "prismadata" logger, so applications can tune all SDK logging
# by configuring that parent logger.
logger = logging.getLogger("prismadata.async_http")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AsyncHttpClient:
    """Async HTTP client with automatic auth, retry, and throttle.

    Every request goes through three stages:

    1. ``_throttle``: client-side pacing. It uses the last seen
       ``x-ratelimit-*`` response headers when present, and otherwise the
       ``rate_limit`` claim exposed by the auth manager.
    2. Auth: a best-effort claim refresh in API-key mode, a mandatory token
       refresh otherwise. Headers come from ``auth.get_headers()``.
    3. Retry: ``_do_request`` is wrapped by tenacity. A response for which
       ``_is_retryable`` returns true is retried, up to ``RETRY_MAX_ATTEMPTS``
       attempts in total.
    """

    def __init__(
        self,
        auth: AsyncAuthManager,
        timeout: int = DEFAULT_TIMEOUT,
        version: str = "0.1.0",
        app_name: str | None = None,
    ) -> None:
        self._auth = auth
        self._timeout = timeout
        # Wall-clock time (time.time()) of the last claim-throttled request.
        self._last_request_time: float = 0.0
        # Parsed from the most recent x-ratelimit-* response headers.
        self._rl_remaining: int | None = None
        self._rl_reset: float | None = None
        headers: dict[str, str] = {
            "User-Agent": f"{USER_AGENT_PREFIX}/{version}",
            "X-Client": f"{USER_AGENT_PREFIX}/{version}",
        }
        if app_name:
            headers["X-App"] = app_name
        self._client = httpx.AsyncClient(timeout=timeout, headers=headers)

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    async def get(self, path: str, params: dict[str, Any] | None = None) -> Any:
        """Send a GET to *path*; ``None``-valued params are dropped."""
        return await self._request("GET", path, params=params)

    async def post(self, path: str, json_body: Any = None, params: dict[str, Any] | None = None) -> Any:
        """Send a POST with *json_body* to *path*; ``None``-valued params are dropped."""
        return await self._request("POST", path, json_body=json_body, params=params)

    async def _request(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        json_body: Any = None,
    ) -> Any:
        """Perform the request (with retries) and pass the final response to ``_handle_response``."""
        response = await self._do_request(method, path, params, json_body)
        return _handle_response(response)

    @retry(
        retry=retry_if_result(_is_retryable),
        stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
        wait=_wait_for_retry,
        # Once attempts run out, return the last response rather than raising
        # tenacity.RetryError, so _request still receives a Response.
        retry_error_callback=lambda state: state.outcome.result(),
    )
    async def _do_request(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        json_body: Any = None,
    ) -> httpx.Response:
        """Send one HTTP attempt and return the raw response."""
        await self._throttle()

        if self._auth._api_key:
            # The API key authenticates the request; the JWT only supplies
            # claims, so a refresh failure must not block the call.
            await self._auth._try_refresh_claims()
        else:
            await self._auth.ensure_valid_token()
        headers = self._auth.get_headers()

        url = f"{self._auth._base_url}{path}"

        if params:
            params = {k: v for k, v in params.items() if v is not None}

        logger.debug("%s %s", method, path)
        t0 = time.monotonic()
        response = await self._client.request(
            method, url, params=params, json=json_body, headers=headers
        )
        elapsed = time.monotonic() - t0
        self._update_rate_limit(response)
        status = response.status_code
        if status in RETRYABLE_STATUS_CODES:
            logger.warning("Retryable %d on %s %s, will retry", status, method, path)
        else:
            logger.debug("%s %s -> %d (%.1fs)", method, path, status, elapsed)
        return response

    def _update_rate_limit(self, response: httpx.Response) -> None:
        """Record ``x-ratelimit-*`` headers from *response* for the next throttle.

        Absent or malformed values leave the stored state unchanged. This runs
        after the request already succeeded, so raising here would discard a
        valid response.
        """
        remaining = self._parse_rate_limit_header(response, "x-ratelimit-remaining")
        reset = self._parse_rate_limit_header(response, "x-ratelimit-reset")
        if remaining is not None:
            # Accept fractional values such as "10.0" that int() would reject.
            self._rl_remaining = int(remaining)
        if reset is not None:
            self._rl_reset = reset

    @staticmethod
    def _parse_rate_limit_header(response: httpx.Response, name: str) -> float | None:
        """Return header *name* as a finite float, or None if absent or unparsable."""
        import math

        raw = response.headers.get(name)
        if raw is None:
            return None
        try:
            value = float(raw)
        except ValueError:
            logger.debug("Ignoring malformed %s header: %r", name, raw)
            return None
        if not math.isfinite(value):
            # "inf"/"nan" would make _throttle sleep forever or misbehave.
            logger.debug("Ignoring non-finite %s header: %r", name, raw)
            return None
        return value

    async def _throttle(self) -> None:
        """Sleep as needed before sending the next request.

        Header data takes precedence. If the last response reported
        ``x-ratelimit-remaining <= 0`` and a reset value, the method sleeps
        until that reset and then clears both stored values. The reset is
        compared against ``time.time()``, i.e. treated as a Unix timestamp.
        While any remaining count is stored, claim-based pacing is skipped.
        Otherwise, with a positive ``rate_limit`` claim, requests are spaced
        at least ``1 / rate_limit`` seconds apart.
        """
        if self._rl_remaining is not None and self._rl_remaining <= 0:
            if self._rl_reset:
                wait = self._rl_reset - time.time()
                if wait > 0:
                    logger.debug("Rate limit: sleeping %.2fs (headers)", wait)
                    await asyncio.sleep(wait)
                self._rl_remaining = None
                self._rl_reset = None
                return

        if self._rl_remaining is not None:
            return

        rate_limit = self._auth.rate_limit
        if not rate_limit or rate_limit <= 0:
            return
        min_interval = 1.0 / rate_limit
        elapsed = time.time() - self._last_request_time
        if elapsed < min_interval:
            wait = min_interval - elapsed
            logger.debug("Rate limit: sleeping %.2fs (jwt claim)", wait)
            await asyncio.sleep(wait)
        self._last_request_time = time.time()
|
prismadata/_auth.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""JWT authentication manager for the PrismaData API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import threading
|
|
9
|
+
import time
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from ._constants import TOKEN_RENEW_MARGIN
|
|
15
|
+
from .exceptions import AuthenticationError
|
|
16
|
+
|
|
17
|
+
# Child of the "prismadata" logger, so applications can tune all SDK logging
# by configuring that parent logger.
logger = logging.getLogger("prismadata.auth")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AuthManager:
    """Manages JWT token lifecycle: obtain, cache, and auto-renew.

    Two modes, selected by which credentials are supplied:

    * API key: requests authenticate with an ``X-Apikey`` header. A JWT is
      still fetched on a best-effort basis, only to read its ``rate_limit``
      and ``sandbox`` claims.
    * Username/password: requests send ``Authorization: Bearer <jwt>``. The
      token is renewed once fewer than ``TOKEN_RENEW_MARGIN`` seconds remain
      before its ``exp`` claim.

    Token fetches are serialized with a ``threading.Lock``.
    """

    def __init__(
        self,
        base_url: str,
        timeout: int,
        *,
        api_key: str | None = None,
        username: str | None = None,
        password: str | None = None,
    ) -> None:
        self._base_url = base_url
        self._timeout = timeout
        self._api_key = api_key
        self._username = username
        self._password = password
        self._token: str | None = None
        # Unix timestamp (seconds) at which the token expires; 0.0 = no token yet.
        self._token_exp: float = 0.0
        # Claims copied from the most recently fetched token.
        self._rate_limit: float | None = None
        self._sandbox: bool = False
        self._lock = threading.Lock()

    @property
    def rate_limit(self) -> float | None:
        """``rate_limit`` claim of the latest token, or None if absent."""
        return self._rate_limit

    @property
    def is_sandbox(self) -> bool:
        """``sandbox`` claim of the latest token (False if absent)."""
        return self._sandbox

    def authenticate(self) -> None:
        """Force token fetch immediately. Raises AuthenticationError on failure."""
        with self._lock:
            self._fetch_token()

    def get_headers(self) -> dict[str, str]:
        """Return auth headers for the next request, refreshing the token if due.

        In API-key mode a failed refresh is only logged. In credentials mode,
        fetch errors propagate to the caller.
        """
        if self._api_key:
            self._try_refresh_claims()
            return {"X-Apikey": self._api_key}
        self._ensure_token()
        return {"Authorization": f"Bearer {self._token}"}

    def _try_refresh_claims(self) -> None:
        """Best-effort claim refresh in API key mode.

        Uses token exp as cadence — adapts to whatever validity
        the server sets. Failures are logged as warnings.
        """
        try:
            self._ensure_token()
        except Exception as exc:
            logger.warning("Claim refresh failed (API key mode): %s", exc)

    def _ensure_token(self) -> None:
        """Fetch a new token unless the current one is outside the renew margin."""
        # Cheap check without the lock for the common "token still fresh" case.
        if self._token and time.time() < (self._token_exp - TOKEN_RENEW_MARGIN):
            return
        with self._lock:
            # Re-check: another thread may have refreshed while we waited.
            if self._token and time.time() < (self._token_exp - TOKEN_RENEW_MARGIN):
                return
            self._fetch_token()

    def _fetch_token(self) -> None:
        """POST credentials to ``/auth/token`` and store the token and its claims.

        Called with ``self._lock`` held.

        Raises:
            AuthenticationError: On a connection error, a 401/403 response, or
                a response body without a usable string ``access_token``.
            httpx.HTTPStatusError: On any other non-2xx status.
        """
        mode = "api_key" if self._api_key else "credentials"
        logger.debug("Fetching token (mode=%s)", mode)

        if self._api_key:
            data = {"api_key": self._api_key}
        else:
            data = {"username": self._username, "password": self._password}

        try:
            resp = httpx.post(
                f"{self._base_url}/auth/token",
                data=data,
                timeout=self._timeout,
            )
        except httpx.HTTPError as exc:
            raise AuthenticationError(f"Failed to connect to auth endpoint: {exc}") from exc

        if resp.status_code in (401, 403):
            raise AuthenticationError(f"Authentication failed: {resp.text}")
        resp.raise_for_status()

        # A 2xx with a non-JSON body, a non-object body, or no access_token
        # would otherwise leak ValueError / TypeError / KeyError to callers.
        try:
            token = resp.json()["access_token"]
        except (ValueError, KeyError, TypeError) as exc:
            raise AuthenticationError(f"Malformed token response: {exc!r}") from exc
        if not isinstance(token, str) or not token:
            raise AuthenticationError("Malformed token response: access_token is missing or not a string")

        self._token = token
        claims = _decode_jwt_claims(token)
        # Without an exp claim, assume a one-hour lifetime.
        self._token_exp = claims.get("exp", time.time() + 3600)
        self._rate_limit = claims.get("rate_limit")
        self._sandbox = claims.get("sandbox", False)
        expires_in = self._token_exp - time.time()
        logger.debug("Token obtained, expires in %.0fs, rate_limit=%s", expires_in, self._rate_limit)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _decode_jwt_claims(token: str) -> dict[str, Any]:
    """Decode JWT payload without signature verification.

    Args:
        token: A compact-serialized JWT (``header.payload.signature``).

    Returns:
        The payload claims as a dict. Returns ``{}`` if *token* does not have
        exactly three dot-separated parts, the payload is not decodable
        base64url JSON, or the JSON is not an object. Callers rely on
        ``.get()``, so a non-dict must never be returned.
    """
    parts = token.split(".")
    if len(parts) != 3:
        return {}
    payload = parts[1]
    # JWT segments are unpadded base64url; restore "=" to a multiple of 4.
    payload += "=" * (-len(payload) % 4)
    try:
        claims = json.loads(base64.urlsafe_b64decode(payload))
    except (ValueError, RecursionError):
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError all
        # subclass ValueError; RecursionError covers pathologically nested JSON.
        return {}
    return claims if isinstance(claims, dict) else {}
|
prismadata/_batch.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""Batch processing with automatic chunking."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import math
|
|
7
|
+
from typing import Any, Awaitable, Callable
|
|
8
|
+
|
|
9
|
+
from .exceptions import BatchError
|
|
10
|
+
|
|
11
|
+
# Child of the "prismadata" logger, so applications can tune all SDK logging
# by configuring that parent logger.
logger = logging.getLogger("prismadata.batch")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _raise_if_partial(
    results: dict[str, Any],
    failed_keys: list[str],
    errors: list[Exception],
) -> None:
    """Raise BatchError if any chunks failed.

    Args:
        results: Results merged from the chunks that succeeded; attached to
            the error as ``partial_results``.
        failed_keys: Keys (or stringified item indexes) of every item that
            belonged to a failed chunk.
        errors: The exceptions raised by the failed chunks.

    Raises:
        BatchError: When *errors* is non-empty. The first chunk exception is
            chained as ``__cause__`` so its traceback is not lost.
    """
    if not errors:
        return
    msg = f"Batch completed with {len(errors)} chunk failure(s), {len(failed_keys)} keys failed"
    raise BatchError(
        msg,
        partial_results=results,
        failed_keys=failed_keys,
    ) from errors[0]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def process_batch(
    items: dict[str, list[float]],
    request_fn: Callable[[dict[str, list[float]]], dict[str, Any]],
    max_size: int,
    on_progress: Callable[[int], None] | None = None,
) -> dict[str, Any]:
    """Split a dict of {id: [lat, lng]} into chunks and merge results.

    Chunks follow the insertion order of *items* and are sent one at a time.
    A failing chunk does not abort the run: its keys are recorded and the
    remaining chunks are still sent.

    Args:
        items: Mapping of point_id to [lat, lng].
        request_fn: Function that posts a batch and returns results dict.
        max_size: Maximum items per request; must be positive.
        on_progress: Callback invoked after every chunk, successful or not,
            with the number of items in that chunk.

    Returns:
        Merged results dict (each chunk's result applied with ``dict.update``).

    Raises:
        ValueError: If *max_size* is not positive.
        BatchError: If one or more chunks fail. Contains partial_results
            from successful chunks and failed_keys from failed ones.
    """
    if max_size <= 0:
        # A zero step breaks range()/division and a negative one would
        # silently yield no chunks, returning an empty result.
        raise ValueError(f"max_size must be a positive integer, got {max_size!r}")

    keys = list(items)
    total = len(keys)
    num_chunks = math.ceil(total / max_size)
    logger.debug("Processing %d items in %d chunks (max_size=%d)", total, num_chunks, max_size)
    results: dict[str, Any] = {}
    failed_keys: list[str] = []
    errors: list[Exception] = []

    for chunk_idx, start in enumerate(range(0, total, max_size), start=1):
        chunk_keys = keys[start : start + max_size]
        chunk = {k: items[k] for k in chunk_keys}
        try:
            chunk_result = request_fn(chunk)
            results.update(chunk_result)
            logger.debug("Chunk %d/%d completed (%d items)", chunk_idx, num_chunks, len(chunk_keys))
        except Exception as exc:
            # Deliberate best effort: record the failure and keep going.
            failed_keys.extend(chunk_keys)
            errors.append(exc)
            logger.warning("Chunk %d/%d failed (%d items): %s", chunk_idx, num_chunks, len(chunk_keys), exc)
        if on_progress:
            on_progress(len(chunk_keys))

    _raise_if_partial(results, failed_keys, errors)
    return results
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def process_routing_batch(
    items: list[dict[str, Any]],
    request_fn: Callable[[list[dict[str, Any]]], dict[str, Any]],
    max_size: int,
    on_progress: Callable[[int], None] | None = None,
) -> dict[str, Any]:
    """Split a list of routing items into chunks and merge results.

    Chunks are sent one at a time, in list order. A failing chunk does not
    abort the run: the indexes of its items are recorded and the remaining
    chunks are still sent.

    Args:
        items: List of routing request items.
        request_fn: Function that posts a batch and returns results dict.
        max_size: Maximum items per request; must be positive.
        on_progress: Callback invoked after every chunk, successful or not,
            with the number of items in that chunk.

    Returns:
        Merged results dict with TOTAL, SUCESSOS, FALHAS (summed across
        chunks) and RESULTADOS (concatenated in chunk order).

    Raises:
        ValueError: If *max_size* is not positive.
        BatchError: If one or more chunks fail; ``failed_keys`` holds the
            stringified positions in *items* of the failed items.
    """
    if max_size <= 0:
        # A zero step breaks range()/division and a negative one would
        # silently yield no chunks, returning all-zero totals.
        raise ValueError(f"max_size must be a positive integer, got {max_size!r}")

    total = len(items)
    num_chunks = math.ceil(total / max_size)
    logger.debug("Processing %d items in %d chunks (max_size=%d)", total, num_chunks, max_size)
    merged: dict[str, Any] = {"TOTAL": 0, "SUCESSOS": 0, "FALHAS": 0, "RESULTADOS": []}
    failed_keys: list[str] = []
    errors: list[Exception] = []

    for chunk_idx, start in enumerate(range(0, total, max_size), start=1):
        chunk = items[start : start + max_size]
        try:
            chunk_result = request_fn(chunk)
            merged["TOTAL"] += chunk_result.get("TOTAL", 0)
            merged["SUCESSOS"] += chunk_result.get("SUCESSOS", 0)
            merged["FALHAS"] += chunk_result.get("FALHAS", 0)
            merged["RESULTADOS"].extend(chunk_result.get("RESULTADOS", []))
            logger.debug("Chunk %d/%d completed (%d items)", chunk_idx, num_chunks, len(chunk))
        except Exception as exc:
            # Deliberate best effort: record the failure and keep going.
            failed_keys.extend(str(start + i) for i in range(len(chunk)))
            errors.append(exc)
            logger.warning("Chunk %d/%d failed (%d items): %s", chunk_idx, num_chunks, len(chunk), exc)
        if on_progress:
            on_progress(len(chunk))

    _raise_if_partial(merged, failed_keys, errors)
    return merged
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
async def async_process_batch(
    items: dict[str, list[float]],
    request_fn: Callable[[dict[str, list[float]]], Awaitable[dict[str, Any]]],
    max_size: int,
    on_progress: Callable[[int], None] | None = None,
) -> dict[str, Any]:
    """Async version of process_batch — sequential chunks to respect rate limits.

    Each chunk is awaited before the next one is sent. A failing chunk does
    not abort the run: its keys are recorded and the remaining chunks are
    still sent.

    Args:
        items: Mapping of point_id to [lat, lng].
        request_fn: Coroutine function that posts a batch and returns results dict.
        max_size: Maximum items per request; must be positive.
        on_progress: Callback invoked after every chunk, successful or not,
            with the number of items in that chunk.

    Returns:
        Merged results dict (each chunk's result applied with ``dict.update``).

    Raises:
        ValueError: If *max_size* is not positive.
        BatchError: If one or more chunks fail.
    """
    if max_size <= 0:
        # A zero step breaks range()/division and a negative one would
        # silently yield no chunks, returning an empty result.
        raise ValueError(f"max_size must be a positive integer, got {max_size!r}")

    keys = list(items)
    total = len(keys)
    num_chunks = math.ceil(total / max_size)
    logger.debug("Async processing %d items in %d chunks (max_size=%d)", total, num_chunks, max_size)
    results: dict[str, Any] = {}
    failed_keys: list[str] = []
    errors: list[Exception] = []

    for chunk_idx, start in enumerate(range(0, total, max_size), start=1):
        chunk_keys = keys[start : start + max_size]
        chunk = {k: items[k] for k in chunk_keys}
        try:
            chunk_result = await request_fn(chunk)
            results.update(chunk_result)
            logger.debug("Chunk %d/%d completed (%d items)", chunk_idx, num_chunks, len(chunk_keys))
        except Exception as exc:
            # Deliberate best effort: record the failure and keep going.
            failed_keys.extend(chunk_keys)
            errors.append(exc)
            logger.warning("Chunk %d/%d failed (%d items): %s", chunk_idx, num_chunks, len(chunk_keys), exc)
        if on_progress:
            on_progress(len(chunk_keys))

    _raise_if_partial(results, failed_keys, errors)
    return results
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
async def async_process_routing_batch(
    items: list[dict[str, Any]],
    request_fn: Callable[[list[dict[str, Any]]], Awaitable[dict[str, Any]]],
    max_size: int,
    on_progress: Callable[[int], None] | None = None,
) -> dict[str, Any]:
    """Async version of process_routing_batch — sequential chunks to respect rate limits.

    Each chunk is awaited before the next one is sent. A failing chunk does
    not abort the run: the indexes of its items are recorded and the
    remaining chunks are still sent.

    Args:
        items: List of routing request items.
        request_fn: Coroutine function that posts a batch and returns results dict.
        max_size: Maximum items per request; must be positive.
        on_progress: Callback invoked after every chunk, successful or not,
            with the number of items in that chunk.

    Returns:
        Merged results dict with TOTAL, SUCESSOS, FALHAS (summed across
        chunks) and RESULTADOS (concatenated in chunk order).

    Raises:
        ValueError: If *max_size* is not positive.
        BatchError: If one or more chunks fail; ``failed_keys`` holds the
            stringified positions in *items* of the failed items.
    """
    if max_size <= 0:
        # A zero step breaks range()/division and a negative one would
        # silently yield no chunks, returning all-zero totals.
        raise ValueError(f"max_size must be a positive integer, got {max_size!r}")

    total = len(items)
    num_chunks = math.ceil(total / max_size)
    logger.debug("Async processing %d items in %d chunks (max_size=%d)", total, num_chunks, max_size)
    merged: dict[str, Any] = {"TOTAL": 0, "SUCESSOS": 0, "FALHAS": 0, "RESULTADOS": []}
    failed_keys: list[str] = []
    errors: list[Exception] = []

    for chunk_idx, start in enumerate(range(0, total, max_size), start=1):
        chunk = items[start : start + max_size]
        try:
            chunk_result = await request_fn(chunk)
            merged["TOTAL"] += chunk_result.get("TOTAL", 0)
            merged["SUCESSOS"] += chunk_result.get("SUCESSOS", 0)
            merged["FALHAS"] += chunk_result.get("FALHAS", 0)
            merged["RESULTADOS"].extend(chunk_result.get("RESULTADOS", []))
            logger.debug("Chunk %d/%d completed (%d items)", chunk_idx, num_chunks, len(chunk))
        except Exception as exc:
            # Deliberate best effort: record the failure and keep going.
            failed_keys.extend(str(start + i) for i in range(len(chunk)))
            errors.append(exc)
            logger.warning("Chunk %d/%d failed (%d items): %s", chunk_idx, num_chunks, len(chunk), exc)
        if on_progress:
            on_progress(len(chunk))

    _raise_if_partial(merged, failed_keys, errors)
    return merged
|
prismadata/_cache.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Optional disk cache wrapper using diskcache."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CacheManager:
    """Optional on-disk response cache backed by ``diskcache``.

    If caching is disabled, or ``diskcache`` cannot be imported, every lookup
    misses and every write is a no-op, so callers never branch on availability.
    """

    def __init__(self, enabled: bool = False, ttl: int = 86400, directory: str | None = None) -> None:
        """Create the cache manager.

        Args:
            enabled: Turn caching on (ignored if diskcache is not installed).
            ttl: Passed as ``expire`` to diskcache for every stored entry.
            directory: Cache location; defaults to ``~/.prismadata/cache``.
        """
        self._enabled = enabled
        self._ttl = ttl
        self._cache: Any = None

        if enabled:
            try:
                import diskcache

                self._cache = diskcache.Cache(
                    directory if directory else os.path.expanduser("~/.prismadata/cache")
                )
            except ImportError:
                # Degrade to a permanently disabled cache.
                self._enabled = False

    @property
    def enabled(self) -> bool:
        """True only when caching was requested and a backend is open."""
        return self._enabled and self._cache is not None

    def get(self, endpoint: str, params: dict[str, Any]) -> Any | None:
        """Return the cached value for *endpoint*/*params*, or None on a miss or when disabled."""
        return self._cache.get(_make_key(endpoint, params)) if self.enabled else None

    def set(self, endpoint: str, params: dict[str, Any], value: Any) -> None:
        """Store *value* under *endpoint*/*params* with the configured TTL (no-op when disabled)."""
        if self.enabled:
            self._cache.set(_make_key(endpoint, params), value, expire=self._ttl)

    def close(self) -> None:
        """Close the diskcache backend, if one was opened."""
        backend = self._cache
        if backend is not None:
            backend.close()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _make_key(endpoint: str, params: dict[str, Any]) -> str:
    """Build a deterministic cache key for an endpoint/params pair.

    The pair is serialized to JSON with sorted keys, so dict insertion order
    does not change the key. The SHA-256 hex digest of that JSON is returned.
    """
    canonical = json.dumps({"e": endpoint, "p": params}, sort_keys=True)
    hasher = hashlib.sha256()
    hasher.update(canonical.encode())
    return hasher.hexdigest()
|