prismadata 0.3.2__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {prismadata-0.3.2 → prismadata-0.4.0}/PKG-INFO +1 -1
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/__init__.py +2 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_batch.py +20 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_constants.py +5 -0
- prismadata-0.4.0/prismadata/_prepare.py +119 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/async_client.py +45 -6
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/client.py +43 -5
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/exceptions.py +4 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/pyproject.toml +1 -1
- {prismadata-0.3.2 → prismadata-0.4.0}/LICENSE +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/README.md +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_async_auth.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_async_http.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_auth.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_cache.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_columns.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_enrich.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_http.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_progress.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_types.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_validation.py +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/py.typed +0 -0
- {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/sklearn.py +0 -0
|
@@ -37,6 +37,7 @@ from .client import Client
|
|
|
37
37
|
from .exceptions import (
|
|
38
38
|
AuthenticationError,
|
|
39
39
|
BatchError,
|
|
40
|
+
BatchPrepareError,
|
|
40
41
|
PrismaDataError,
|
|
41
42
|
QuotaExhaustedError,
|
|
42
43
|
RateLimitError,
|
|
@@ -62,6 +63,7 @@ __all__ = [
|
|
|
62
63
|
"__version__",
|
|
63
64
|
"AuthenticationError",
|
|
64
65
|
"BatchError",
|
|
66
|
+
"BatchPrepareError",
|
|
65
67
|
"PrismaDataError",
|
|
66
68
|
"QuotaExhaustedError",
|
|
67
69
|
"RateLimitError",
|
|
@@ -11,6 +11,26 @@ from .exceptions import BatchError
|
|
|
11
11
|
logger = logging.getLogger("prismadata.batch")
|
|
12
12
|
|
|
13
13
|
|
|
14
|
+
def split_into_groups(
    items: dict[str, Any], num_groups: int,
) -> list[dict[str, Any]]:
    """Split a dict into at most ``num_groups`` roughly-equal sub-dicts.

    Keys are taken in insertion order and cut into consecutive runs of
    ``ceil(len(items) / num_groups)`` keys, so the last group may be smaller
    than the others and fewer than ``num_groups`` groups may be returned.

    Args:
        items: Dictionary to split.
        num_groups: Number of groups (must be >= 1).

    Returns:
        List of non-empty dicts, each containing a subset of the items.
        An empty ``items`` yields an empty list.

    Raises:
        ValueError: If ``num_groups`` is less than 1.
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")
    keys = list(items)
    if not keys:
        # A group size of zero would make range() reject its step argument.
        return []
    # Integer ceiling division: exact for any size, no float rounding.
    group_size = -(-len(keys) // num_groups)
    return [
        {k: items[k] for k in keys[i : i + group_size]}
        for i in range(0, len(keys), group_size)
    ]
|
|
32
|
+
|
|
33
|
+
|
|
14
34
|
def _raise_if_partial(
|
|
15
35
|
results: dict[str, Any],
|
|
16
36
|
failed_keys: list[str],
|
|
@@ -22,6 +22,11 @@ USER_AGENT_PREFIX = "prismadata-python"
|
|
|
22
22
|
|
|
23
23
|
DIRECTION_MAP = {"outgoing": "SAINDO", "incoming": "INDO"}
|
|
24
24
|
|
|
25
|
+
# Seconds between readiness polls; default ``interval`` of wait_until_ready.
PREPARE_POLL_INTERVAL = 10
|
|
26
|
+
# Max seconds to wait for a batch session to become ready; default ``timeout`` of wait_until_ready.
PREPARE_TIMEOUT = 300
|
|
27
|
+
# Minimum number of items that triggers auto-scaling; default ``chunk_threshold`` of the clients.
DEFAULT_CHUNK_THRESHOLD = 10_000
|
|
28
|
+
# Max parallel workers for a scaled batch (the server may return fewer); default ``max_workers``.
DEFAULT_MAX_WORKERS = 2
|
|
29
|
+
|
|
25
30
|
ENV_API_KEY = "PRISMADATA_APIKEY"
|
|
26
31
|
ENV_USERNAME = "PRISMADATA_USERNAME"
|
|
27
32
|
ENV_PASSWORD = "PRISMADATA_PASSWORD"
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Batch prepare lifecycle for auto-scaling large batch operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any, Awaitable, Callable
|
|
9
|
+
|
|
10
|
+
from ._constants import PREPARE_POLL_INTERVAL, PREPARE_TIMEOUT
|
|
11
|
+
from .exceptions import BatchPrepareError
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("prismadata.prepare")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def batch_prepare(
    post_fn: Callable[..., Any],
    total_items: int,
) -> dict[str, Any]:
    """Announce an upcoming large batch to the API.

    Args:
        post_fn: Callable that POSTs to an API path (client._post).
        total_items: Total number of items in the batch.

    Returns:
        The value returned by ``post_fn``: a response dict with
        ``session_id`` and ``max_workers``.
    """
    payload = {"total_items": total_items}
    logger.info("Preparing batch for %d items", total_items)
    response = post_fn("/v1/batch/prepare", payload)
    return response
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def wait_until_ready(
    get_fn: Callable[..., Any],
    session_id: str,
    timeout: int = PREPARE_TIMEOUT,
    interval: int = PREPARE_POLL_INTERVAL,
) -> None:
    """Poll until the infrastructure is ready.

    The status endpoint is queried at least once, even when ``timeout`` is
    zero. Between polls the function sleeps for ``interval`` seconds, but
    never past the deadline, so the total wait stays close to ``timeout``.

    Args:
        get_fn: Callable that GETs from an API path (client._get).
        session_id: Session ID from batch_prepare response.
        timeout: Max seconds to wait.
        interval: Seconds between polls.

    Raises:
        BatchPrepareError: If not ready within timeout.
    """
    deadline = time.monotonic() + timeout
    logger.info("Waiting for batch session %s to be ready (timeout=%ds)", session_id, timeout)
    while True:
        status = get_fn(f"/v1/batch/prepare/{session_id}/status")
        if status.get("ready"):
            logger.info("Batch session %s is ready", session_id)
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise BatchPrepareError(
                f"Batch prepare timed out after {timeout}s for session {session_id}"
            )
        # Clamp the sleep so a long interval cannot overshoot the deadline;
        # the next iteration still performs one last status check.
        time.sleep(min(interval, remaining))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def batch_complete(
    post_fn: Callable[..., Any],
    session_id: str,
) -> None:
    """Tell the API the batch is done so it can scale down.

    Best-effort by design: this runs from a ``finally`` block, so any
    failure is logged as a warning and swallowed instead of masking the
    exception that may already be propagating.

    Args:
        post_fn: Callable that POSTs to an API path.
        session_id: Session ID of the batch session to close.
    """
    try:
        endpoint = f"/v1/batch/complete/{session_id}"
        post_fn(endpoint, {})
        logger.info("Batch session %s completed", session_id)
    except Exception:
        # Deliberately broad: cleanup must never raise.
        logger.warning("Failed to complete batch session %s", session_id, exc_info=True)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
async def async_batch_prepare(
    post_fn: Callable[..., Awaitable[Any]],
    total_items: int,
) -> dict[str, Any]:
    """Async version of batch_prepare.

    Args:
        post_fn: Awaitable-returning callable that POSTs to an API path.
        total_items: Total number of items in the batch.

    Returns:
        The awaited result of ``post_fn``.
    """
    payload = {"total_items": total_items}
    logger.info("Preparing batch for %d items", total_items)
    response = await post_fn("/v1/batch/prepare", payload)
    return response
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
async def async_wait_until_ready(
    get_fn: Callable[..., Awaitable[Any]],
    session_id: str,
    timeout: int = PREPARE_TIMEOUT,
    interval: int = PREPARE_POLL_INTERVAL,
) -> None:
    """Async version of wait_until_ready.

    The status endpoint is queried at least once, even when ``timeout`` is
    zero. Between polls the coroutine sleeps for ``interval`` seconds, but
    never past the deadline, so the total wait stays close to ``timeout``.

    Args:
        get_fn: Awaitable-returning callable that GETs from an API path.
        session_id: Session ID from the batch prepare response.
        timeout: Max seconds to wait.
        interval: Seconds between polls.

    Raises:
        BatchPrepareError: If not ready within timeout.
    """
    deadline = time.monotonic() + timeout
    logger.info("Waiting for batch session %s to be ready (timeout=%ds)", session_id, timeout)
    while True:
        status = await get_fn(f"/v1/batch/prepare/{session_id}/status")
        if status.get("ready"):
            logger.info("Batch session %s is ready", session_id)
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise BatchPrepareError(
                f"Batch prepare timed out after {timeout}s for session {session_id}"
            )
        # Clamp the sleep so a long interval cannot overshoot the deadline;
        # the next iteration still performs one last status check.
        await asyncio.sleep(min(interval, remaining))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
async def async_batch_complete(
    post_fn: Callable[..., Awaitable[Any]],
    session_id: str,
) -> None:
    """Async version of batch_complete.

    Best-effort: any failure is logged as a warning and swallowed.

    Args:
        post_fn: Awaitable-returning callable that POSTs to an API path.
        session_id: Session ID of the batch session to close.
    """
    try:
        endpoint = f"/v1/batch/complete/{session_id}"
        await post_fn(endpoint, {})
        logger.info("Batch session %s completed", session_id)
    except Exception:
        # Deliberately broad: cleanup must never raise.
        logger.warning("Failed to complete batch session %s", session_id, exc_info=True)
|
|
@@ -8,17 +8,20 @@ from typing import Any, Literal, TYPE_CHECKING
|
|
|
8
8
|
from ._async_auth import AsyncAuthManager
|
|
9
9
|
from ._async_http import AsyncHttpClient
|
|
10
10
|
from ._types import QuotaInfo
|
|
11
|
-
from ._batch import async_process_batch, async_process_routing_batch
|
|
11
|
+
from ._batch import async_process_batch, async_process_routing_batch, split_into_groups
|
|
12
12
|
from ._cache import CacheManager
|
|
13
13
|
from ._columns import clean_columns
|
|
14
14
|
from ._constants import (
|
|
15
15
|
BASE_URL,
|
|
16
16
|
DEFAULT_CACHE_TTL,
|
|
17
|
+
DEFAULT_CHUNK_THRESHOLD,
|
|
18
|
+
DEFAULT_MAX_WORKERS,
|
|
17
19
|
DEFAULT_TIMEOUT,
|
|
18
20
|
DIRECTION_MAP,
|
|
19
21
|
MAX_BATCH_SIZE,
|
|
20
22
|
MAX_ROUTING_BATCH,
|
|
21
23
|
)
|
|
24
|
+
from ._prepare import async_batch_complete, async_batch_prepare, async_wait_until_ready
|
|
22
25
|
from ._progress import progress_bar
|
|
23
26
|
from ._validation import validate_lat_lng, validate_profile, validate_route_points
|
|
24
27
|
from .client import _resolve_credentials
|
|
@@ -657,9 +660,17 @@ class AsyncClient:
|
|
|
657
660
|
on_error: Literal["raise", "skip"] = "raise",
|
|
658
661
|
timeout: int | None = None,
|
|
659
662
|
show_progress: bool | None = None,
|
|
663
|
+
auto_scale: bool = True,
|
|
664
|
+
max_workers: int = DEFAULT_MAX_WORKERS,
|
|
665
|
+
chunk_threshold: int = DEFAULT_CHUNK_THRESHOLD,
|
|
660
666
|
**kwargs: Any,
|
|
661
667
|
) -> dict[str, Any]:
|
|
662
|
-
"""Batch geocode addresses and aggregate location APIs.
|
|
668
|
+
"""Batch geocode addresses and aggregate location APIs.
|
|
669
|
+
|
|
670
|
+
For large batches (above ``chunk_threshold``), the SDK signals the API
|
|
671
|
+
to scale up infrastructure, splits the work across parallel workers,
|
|
672
|
+
and signals scale-down when finished.
|
|
673
|
+
"""
|
|
663
674
|
params: dict[str, Any] = {}
|
|
664
675
|
for svc in services:
|
|
665
676
|
params[svc] = True
|
|
@@ -667,16 +678,44 @@ class AsyncClient:
|
|
|
667
678
|
|
|
668
679
|
chunk_size = batch_size if batch_size is not None else MAX_BATCH_SIZE
|
|
669
680
|
use_progress = show_progress if show_progress is not None else self._show_progress
|
|
681
|
+
total = len(addresses)
|
|
682
|
+
should_scale = auto_scale and total >= chunk_threshold
|
|
670
683
|
|
|
671
684
|
async def _request(chunk: dict) -> dict[str, Any]:
|
|
672
685
|
return await self._post("/location/batch/geocoder/aggregator", chunk, params=params, timeout=timeout)
|
|
673
686
|
|
|
674
|
-
|
|
687
|
+
if not should_scale:
|
|
688
|
+
with progress_bar(total, desc="Geocode+Aggregating", enabled=use_progress) as bar:
|
|
675
689
|
|
|
676
|
-
|
|
677
|
-
|
|
690
|
+
def _on_progress(n: int) -> None:
|
|
691
|
+
bar.update(n)
|
|
678
692
|
|
|
679
|
-
|
|
693
|
+
result = await async_process_batch(addresses, _request, chunk_size, on_progress=_on_progress, on_error=on_error)
|
|
694
|
+
else:
|
|
695
|
+
import asyncio
|
|
696
|
+
|
|
697
|
+
resp = await async_batch_prepare(self._post, total)
|
|
698
|
+
session_id = resp["session_id"]
|
|
699
|
+
num_workers = min(max_workers, resp.get("max_workers", max_workers))
|
|
700
|
+
|
|
701
|
+
try:
|
|
702
|
+
await async_wait_until_ready(self._get, session_id)
|
|
703
|
+
groups = split_into_groups(addresses, num_workers)
|
|
704
|
+
|
|
705
|
+
async def _process_group(group: dict) -> dict[str, Any]:
|
|
706
|
+
return await async_process_batch(group, _request, chunk_size, on_error=on_error)
|
|
707
|
+
|
|
708
|
+
group_results = await asyncio.gather(
|
|
709
|
+
*[_process_group(g) for g in groups],
|
|
710
|
+
return_exceptions=(on_error == "skip"),
|
|
711
|
+
)
|
|
712
|
+
result = {}
|
|
713
|
+
for gr in group_results:
|
|
714
|
+
if isinstance(gr, Exception):
|
|
715
|
+
continue
|
|
716
|
+
result.update(gr)
|
|
717
|
+
finally:
|
|
718
|
+
await async_batch_complete(self._post, session_id)
|
|
680
719
|
|
|
681
720
|
if self._clean:
|
|
682
721
|
return {k: clean_columns(v) if isinstance(v, dict) else v for k, v in result.items()}
|
|
@@ -7,13 +7,15 @@ import warnings
|
|
|
7
7
|
from typing import Any, Literal, TYPE_CHECKING
|
|
8
8
|
|
|
9
9
|
from ._auth import AuthManager
|
|
10
|
-
from ._batch import process_batch, process_routing_batch
|
|
10
|
+
from ._batch import process_batch, process_routing_batch, split_into_groups
|
|
11
11
|
from ._types import QuotaInfo
|
|
12
12
|
from ._cache import CacheManager
|
|
13
13
|
from ._columns import clean_columns
|
|
14
14
|
from ._constants import (
|
|
15
15
|
BASE_URL,
|
|
16
16
|
DEFAULT_CACHE_TTL,
|
|
17
|
+
DEFAULT_CHUNK_THRESHOLD,
|
|
18
|
+
DEFAULT_MAX_WORKERS,
|
|
17
19
|
DEFAULT_TIMEOUT,
|
|
18
20
|
DIRECTION_MAP,
|
|
19
21
|
ENV_API_KEY,
|
|
@@ -23,6 +25,7 @@ from ._constants import (
|
|
|
23
25
|
MAX_ROUTING_BATCH,
|
|
24
26
|
)
|
|
25
27
|
from ._http import HttpClient
|
|
28
|
+
from ._prepare import batch_complete, batch_prepare, wait_until_ready
|
|
26
29
|
from ._progress import progress_bar
|
|
27
30
|
from ._validation import validate_lat_lng, validate_profile, validate_route_points
|
|
28
31
|
from .exceptions import AuthenticationError
|
|
@@ -884,10 +887,17 @@ class Client:
|
|
|
884
887
|
on_error: Literal["raise", "skip"] = "raise",
|
|
885
888
|
timeout: int | None = None,
|
|
886
889
|
show_progress: bool | None = None,
|
|
890
|
+
auto_scale: bool = True,
|
|
891
|
+
max_workers: int = DEFAULT_MAX_WORKERS,
|
|
892
|
+
chunk_threshold: int = DEFAULT_CHUNK_THRESHOLD,
|
|
887
893
|
**kwargs: Any,
|
|
888
894
|
) -> dict[str, Any]:
|
|
889
895
|
"""Batch geocode addresses and aggregate location APIs.
|
|
890
896
|
|
|
897
|
+
For large batches (above ``chunk_threshold``), the SDK signals the API
|
|
898
|
+
to scale up infrastructure, splits the work across parallel workers,
|
|
899
|
+
and signals scale-down when finished.
|
|
900
|
+
|
|
891
901
|
Args:
|
|
892
902
|
addresses: Mapping of id to address dict.
|
|
893
903
|
services: List of service names to enable.
|
|
@@ -895,6 +905,9 @@ class Client:
|
|
|
895
905
|
on_error: ``"raise"`` (default) or ``"skip"`` to return partial results.
|
|
896
906
|
timeout: Override default request timeout (seconds).
|
|
897
907
|
show_progress: Override progress bar setting.
|
|
908
|
+
auto_scale: If True, call prepare/complete for large batches.
|
|
909
|
+
max_workers: Max parallel workers (server may return fewer).
|
|
910
|
+
chunk_threshold: Minimum items to trigger auto-scaling.
|
|
898
911
|
**kwargs: Additional parameters for specific services.
|
|
899
912
|
"""
|
|
900
913
|
params: dict[str, Any] = {}
|
|
@@ -904,16 +917,41 @@ class Client:
|
|
|
904
917
|
|
|
905
918
|
chunk_size = batch_size if batch_size is not None else MAX_BATCH_SIZE
|
|
906
919
|
use_progress = show_progress if show_progress is not None else self._show_progress
|
|
920
|
+
total = len(addresses)
|
|
921
|
+
should_scale = auto_scale and total >= chunk_threshold
|
|
907
922
|
|
|
908
923
|
def _request(chunk: dict) -> dict[str, Any]:
|
|
909
924
|
return self._post("/location/batch/geocoder/aggregator", chunk, params=params, timeout=timeout)
|
|
910
925
|
|
|
911
|
-
|
|
926
|
+
if not should_scale:
|
|
927
|
+
with progress_bar(total, desc="Geocode+Aggregating", enabled=use_progress) as bar:
|
|
912
928
|
|
|
913
|
-
|
|
914
|
-
|
|
929
|
+
def _on_progress(n: int) -> None:
|
|
930
|
+
bar.update(n)
|
|
931
|
+
|
|
932
|
+
result = process_batch(addresses, _request, chunk_size, on_progress=_on_progress, on_error=on_error)
|
|
933
|
+
else:
|
|
934
|
+
resp = batch_prepare(self._post, total)
|
|
935
|
+
session_id = resp["session_id"]
|
|
936
|
+
num_workers = min(max_workers, resp.get("max_workers", max_workers))
|
|
937
|
+
|
|
938
|
+
try:
|
|
939
|
+
wait_until_ready(self._get, session_id)
|
|
940
|
+
groups = split_into_groups(addresses, num_workers)
|
|
941
|
+
|
|
942
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
943
|
+
|
|
944
|
+
def _process_group(group: dict) -> dict[str, Any]:
|
|
945
|
+
return process_batch(group, _request, chunk_size, on_error=on_error)
|
|
915
946
|
|
|
916
|
-
|
|
947
|
+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
|
948
|
+
futures = [executor.submit(_process_group, g) for g in groups]
|
|
949
|
+
result = {}
|
|
950
|
+
for f in as_completed(futures):
|
|
951
|
+
chunk_result = f.result()
|
|
952
|
+
result.update(chunk_result)
|
|
953
|
+
finally:
|
|
954
|
+
batch_complete(self._post, session_id)
|
|
917
955
|
|
|
918
956
|
if self._clean:
|
|
919
957
|
return {k: clean_columns(v) if isinstance(v, dict) else v for k, v in result.items()}
|
|
@@ -38,6 +38,10 @@ class ValidationError(PrismaDataError):
|
|
|
38
38
|
"""Raised when the API returns a validation error (422)."""
|
|
39
39
|
|
|
40
40
|
|
|
41
|
+
class BatchPrepareError(PrismaDataError):
    """Raised when batch prepare polling times out or fails.

    ``wait_until_ready`` and ``async_wait_until_ready`` raise it when the
    batch session does not report ready before the timeout elapses.
    """
|
|
43
|
+
|
|
44
|
+
|
|
41
45
|
class BatchError(PrismaDataError):
|
|
42
46
|
"""Raised when a batch operation has partial failures.
|
|
43
47
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|