prismadata 0.3.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {prismadata-0.3.2 → prismadata-0.4.0}/PKG-INFO +1 -1
  2. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/__init__.py +2 -0
  3. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_batch.py +20 -0
  4. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_constants.py +5 -0
  5. prismadata-0.4.0/prismadata/_prepare.py +119 -0
  6. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/async_client.py +45 -6
  7. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/client.py +43 -5
  8. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/exceptions.py +4 -0
  9. {prismadata-0.3.2 → prismadata-0.4.0}/pyproject.toml +1 -1
  10. {prismadata-0.3.2 → prismadata-0.4.0}/LICENSE +0 -0
  11. {prismadata-0.3.2 → prismadata-0.4.0}/README.md +0 -0
  12. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_async_auth.py +0 -0
  13. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_async_http.py +0 -0
  14. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_auth.py +0 -0
  15. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_cache.py +0 -0
  16. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_columns.py +0 -0
  17. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_enrich.py +0 -0
  18. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_http.py +0 -0
  19. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_progress.py +0 -0
  20. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_types.py +0 -0
  21. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/_validation.py +0 -0
  22. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/py.typed +0 -0
  23. {prismadata-0.3.2 → prismadata-0.4.0}/prismadata/sklearn.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: prismadata
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Python client for the PrismaData location intelligence API
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -37,6 +37,7 @@ from .client import Client
37
37
  from .exceptions import (
38
38
  AuthenticationError,
39
39
  BatchError,
40
+ BatchPrepareError,
40
41
  PrismaDataError,
41
42
  QuotaExhaustedError,
42
43
  RateLimitError,
@@ -62,6 +63,7 @@ __all__ = [
62
63
  "__version__",
63
64
  "AuthenticationError",
64
65
  "BatchError",
66
+ "BatchPrepareError",
65
67
  "PrismaDataError",
66
68
  "QuotaExhaustedError",
67
69
  "RateLimitError",
@@ -11,6 +11,26 @@ from .exceptions import BatchError
11
11
  logger = logging.getLogger("prismadata.batch")
12
12
 
13
13
 
14
+ def split_into_groups(
15
+ items: dict[str, Any], num_groups: int,
16
+ ) -> list[dict[str, Any]]:
17
+ """Split a dict into at most num_groups roughly-equal sub-dicts.
18
+
19
+ Args:
20
+ items: Dictionary to split.
21
+ num_groups: Number of groups (must be >= 1).
22
+
23
+ Returns:
24
+ List of dicts, each containing a subset of the items.
25
+ """
26
+ keys = list(items.keys())
27
+ group_size = math.ceil(len(keys) / num_groups)
28
+ return [
29
+ {k: items[k] for k in keys[i : i + group_size]}
30
+ for i in range(0, len(keys), group_size)
31
+ ]
32
+
33
+
14
34
  def _raise_if_partial(
15
35
  results: dict[str, Any],
16
36
  failed_keys: list[str],
@@ -22,6 +22,11 @@ USER_AGENT_PREFIX = "prismadata-python"
22
22
 
23
23
  DIRECTION_MAP = {"outgoing": "SAINDO", "incoming": "INDO"}
24
24
 
25
+ PREPARE_POLL_INTERVAL = 10
26
+ PREPARE_TIMEOUT = 300
27
+ DEFAULT_CHUNK_THRESHOLD = 10_000
28
+ DEFAULT_MAX_WORKERS = 2
29
+
25
30
  ENV_API_KEY = "PRISMADATA_APIKEY"
26
31
  ENV_USERNAME = "PRISMADATA_USERNAME"
27
32
  ENV_PASSWORD = "PRISMADATA_PASSWORD"
@@ -0,0 +1,119 @@
1
+ """Batch prepare lifecycle for auto-scaling large batch operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import time
8
+ from typing import Any, Awaitable, Callable
9
+
10
+ from ._constants import PREPARE_POLL_INTERVAL, PREPARE_TIMEOUT
11
+ from .exceptions import BatchPrepareError
12
+
13
+ logger = logging.getLogger("prismadata.prepare")
14
+
15
+
16
+ def batch_prepare(
17
+ post_fn: Callable[..., Any],
18
+ total_items: int,
19
+ ) -> dict[str, Any]:
20
+ """Signal the API that a large batch is starting.
21
+
22
+ Args:
23
+ post_fn: Callable that POSTs to an API path (client._post).
24
+ total_items: Total number of items in the batch.
25
+
26
+ Returns:
27
+ Response dict with ``session_id`` and ``max_workers``.
28
+ """
29
+ logger.info("Preparing batch for %d items", total_items)
30
+ return post_fn("/v1/batch/prepare", {"total_items": total_items})
31
+
32
+
33
+ def wait_until_ready(
34
+ get_fn: Callable[..., Any],
35
+ session_id: str,
36
+ timeout: int = PREPARE_TIMEOUT,
37
+ interval: int = PREPARE_POLL_INTERVAL,
38
+ ) -> None:
39
+ """Poll until the infrastructure is ready.
40
+
41
+ Args:
42
+ get_fn: Callable that GETs from an API path (client._get).
43
+ session_id: Session ID from batch_prepare response.
44
+ timeout: Max seconds to wait.
45
+ interval: Seconds between polls.
46
+
47
+ Raises:
48
+ BatchPrepareError: If not ready within timeout.
49
+ """
50
+ deadline = time.monotonic() + timeout
51
+ logger.info("Waiting for batch session %s to be ready (timeout=%ds)", session_id, timeout)
52
+ while True:
53
+ status = get_fn(f"/v1/batch/prepare/{session_id}/status")
54
+ if status.get("ready"):
55
+ logger.info("Batch session %s is ready", session_id)
56
+ return
57
+ if time.monotonic() >= deadline:
58
+ raise BatchPrepareError(
59
+ f"Batch prepare timed out after {timeout}s for session {session_id}"
60
+ )
61
+ time.sleep(interval)
62
+
63
+
64
+ def batch_complete(
65
+ post_fn: Callable[..., Any],
66
+ session_id: str,
67
+ ) -> None:
68
+ """Signal the API that the batch is finished (scale down).
69
+
70
+ Always called in a finally block — swallows errors to avoid
71
+ masking the original exception.
72
+ """
73
+ try:
74
+ post_fn(f"/v1/batch/complete/{session_id}", {})
75
+ logger.info("Batch session %s completed", session_id)
76
+ except Exception:
77
+ logger.warning("Failed to complete batch session %s", session_id, exc_info=True)
78
+
79
+
80
+ async def async_batch_prepare(
81
+ post_fn: Callable[..., Awaitable[Any]],
82
+ total_items: int,
83
+ ) -> dict[str, Any]:
84
+ """Async version of batch_prepare."""
85
+ logger.info("Preparing batch for %d items", total_items)
86
+ return await post_fn("/v1/batch/prepare", {"total_items": total_items})
87
+
88
+
89
+ async def async_wait_until_ready(
90
+ get_fn: Callable[..., Awaitable[Any]],
91
+ session_id: str,
92
+ timeout: int = PREPARE_TIMEOUT,
93
+ interval: int = PREPARE_POLL_INTERVAL,
94
+ ) -> None:
95
+ """Async version of wait_until_ready."""
96
+ deadline = time.monotonic() + timeout
97
+ logger.info("Waiting for batch session %s to be ready (timeout=%ds)", session_id, timeout)
98
+ while True:
99
+ status = await get_fn(f"/v1/batch/prepare/{session_id}/status")
100
+ if status.get("ready"):
101
+ logger.info("Batch session %s is ready", session_id)
102
+ return
103
+ if time.monotonic() >= deadline:
104
+ raise BatchPrepareError(
105
+ f"Batch prepare timed out after {timeout}s for session {session_id}"
106
+ )
107
+ await asyncio.sleep(interval)
108
+
109
+
110
+ async def async_batch_complete(
111
+ post_fn: Callable[..., Awaitable[Any]],
112
+ session_id: str,
113
+ ) -> None:
114
+ """Async version of batch_complete."""
115
+ try:
116
+ await post_fn(f"/v1/batch/complete/{session_id}", {})
117
+ logger.info("Batch session %s completed", session_id)
118
+ except Exception:
119
+ logger.warning("Failed to complete batch session %s", session_id, exc_info=True)
@@ -8,17 +8,20 @@ from typing import Any, Literal, TYPE_CHECKING
8
8
  from ._async_auth import AsyncAuthManager
9
9
  from ._async_http import AsyncHttpClient
10
10
  from ._types import QuotaInfo
11
- from ._batch import async_process_batch, async_process_routing_batch
11
+ from ._batch import async_process_batch, async_process_routing_batch, split_into_groups
12
12
  from ._cache import CacheManager
13
13
  from ._columns import clean_columns
14
14
  from ._constants import (
15
15
  BASE_URL,
16
16
  DEFAULT_CACHE_TTL,
17
+ DEFAULT_CHUNK_THRESHOLD,
18
+ DEFAULT_MAX_WORKERS,
17
19
  DEFAULT_TIMEOUT,
18
20
  DIRECTION_MAP,
19
21
  MAX_BATCH_SIZE,
20
22
  MAX_ROUTING_BATCH,
21
23
  )
24
+ from ._prepare import async_batch_complete, async_batch_prepare, async_wait_until_ready
22
25
  from ._progress import progress_bar
23
26
  from ._validation import validate_lat_lng, validate_profile, validate_route_points
24
27
  from .client import _resolve_credentials
@@ -657,9 +660,17 @@ class AsyncClient:
657
660
  on_error: Literal["raise", "skip"] = "raise",
658
661
  timeout: int | None = None,
659
662
  show_progress: bool | None = None,
663
+ auto_scale: bool = True,
664
+ max_workers: int = DEFAULT_MAX_WORKERS,
665
+ chunk_threshold: int = DEFAULT_CHUNK_THRESHOLD,
660
666
  **kwargs: Any,
661
667
  ) -> dict[str, Any]:
662
- """Batch geocode addresses and aggregate location APIs."""
668
+ """Batch geocode addresses and aggregate location APIs.
669
+
670
+ For large batches (at or above ``chunk_threshold``), the SDK signals the API
671
+ to scale up infrastructure, splits the work across parallel workers,
672
+ and signals scale-down when finished.
673
+ """
663
674
  params: dict[str, Any] = {}
664
675
  for svc in services:
665
676
  params[svc] = True
@@ -667,16 +678,44 @@ class AsyncClient:
667
678
 
668
679
  chunk_size = batch_size if batch_size is not None else MAX_BATCH_SIZE
669
680
  use_progress = show_progress if show_progress is not None else self._show_progress
681
+ total = len(addresses)
682
+ should_scale = auto_scale and total >= chunk_threshold
670
683
 
671
684
  async def _request(chunk: dict) -> dict[str, Any]:
672
685
  return await self._post("/location/batch/geocoder/aggregator", chunk, params=params, timeout=timeout)
673
686
 
674
- with progress_bar(len(addresses), desc="Geocode+Aggregating", enabled=use_progress) as bar:
687
+ if not should_scale:
688
+ with progress_bar(total, desc="Geocode+Aggregating", enabled=use_progress) as bar:
675
689
 
676
- def _on_progress(n: int) -> None:
677
- bar.update(n)
690
+ def _on_progress(n: int) -> None:
691
+ bar.update(n)
678
692
 
679
- result = await async_process_batch(addresses, _request, chunk_size, on_progress=_on_progress, on_error=on_error)
693
+ result = await async_process_batch(addresses, _request, chunk_size, on_progress=_on_progress, on_error=on_error)
694
+ else:
695
+ import asyncio
696
+
697
+ resp = await async_batch_prepare(self._post, total)
698
+ session_id = resp["session_id"]
699
+ num_workers = min(max_workers, resp.get("max_workers", max_workers))
700
+
701
+ try:
702
+ await async_wait_until_ready(self._get, session_id)
703
+ groups = split_into_groups(addresses, num_workers)
704
+
705
+ async def _process_group(group: dict) -> dict[str, Any]:
706
+ return await async_process_batch(group, _request, chunk_size, on_error=on_error)
707
+
708
+ group_results = await asyncio.gather(
709
+ *[_process_group(g) for g in groups],
710
+ return_exceptions=(on_error == "skip"),
711
+ )
712
+ result = {}
713
+ for gr in group_results:
714
+ if isinstance(gr, Exception):
715
+ continue
716
+ result.update(gr)
717
+ finally:
718
+ await async_batch_complete(self._post, session_id)
680
719
 
681
720
  if self._clean:
682
721
  return {k: clean_columns(v) if isinstance(v, dict) else v for k, v in result.items()}
@@ -7,13 +7,15 @@ import warnings
7
7
  from typing import Any, Literal, TYPE_CHECKING
8
8
 
9
9
  from ._auth import AuthManager
10
- from ._batch import process_batch, process_routing_batch
10
+ from ._batch import process_batch, process_routing_batch, split_into_groups
11
11
  from ._types import QuotaInfo
12
12
  from ._cache import CacheManager
13
13
  from ._columns import clean_columns
14
14
  from ._constants import (
15
15
  BASE_URL,
16
16
  DEFAULT_CACHE_TTL,
17
+ DEFAULT_CHUNK_THRESHOLD,
18
+ DEFAULT_MAX_WORKERS,
17
19
  DEFAULT_TIMEOUT,
18
20
  DIRECTION_MAP,
19
21
  ENV_API_KEY,
@@ -23,6 +25,7 @@ from ._constants import (
23
25
  MAX_ROUTING_BATCH,
24
26
  )
25
27
  from ._http import HttpClient
28
+ from ._prepare import batch_complete, batch_prepare, wait_until_ready
26
29
  from ._progress import progress_bar
27
30
  from ._validation import validate_lat_lng, validate_profile, validate_route_points
28
31
  from .exceptions import AuthenticationError
@@ -884,10 +887,17 @@ class Client:
884
887
  on_error: Literal["raise", "skip"] = "raise",
885
888
  timeout: int | None = None,
886
889
  show_progress: bool | None = None,
890
+ auto_scale: bool = True,
891
+ max_workers: int = DEFAULT_MAX_WORKERS,
892
+ chunk_threshold: int = DEFAULT_CHUNK_THRESHOLD,
887
893
  **kwargs: Any,
888
894
  ) -> dict[str, Any]:
889
895
  """Batch geocode addresses and aggregate location APIs.
890
896
 
897
+ For large batches (at or above ``chunk_threshold``), the SDK signals the API
898
+ to scale up infrastructure, splits the work across parallel workers,
899
+ and signals scale-down when finished.
900
+
891
901
  Args:
892
902
  addresses: Mapping of id to address dict.
893
903
  services: List of service names to enable.
@@ -895,6 +905,9 @@ class Client:
895
905
  on_error: ``"raise"`` (default) or ``"skip"`` to return partial results.
896
906
  timeout: Override default request timeout (seconds).
897
907
  show_progress: Override progress bar setting.
908
+ auto_scale: If True, call prepare/complete for large batches.
909
+ max_workers: Max parallel workers (server may return fewer).
910
+ chunk_threshold: Minimum items to trigger auto-scaling.
898
911
  **kwargs: Additional parameters for specific services.
899
912
  """
900
913
  params: dict[str, Any] = {}
@@ -904,16 +917,41 @@ class Client:
904
917
 
905
918
  chunk_size = batch_size if batch_size is not None else MAX_BATCH_SIZE
906
919
  use_progress = show_progress if show_progress is not None else self._show_progress
920
+ total = len(addresses)
921
+ should_scale = auto_scale and total >= chunk_threshold
907
922
 
908
923
  def _request(chunk: dict) -> dict[str, Any]:
909
924
  return self._post("/location/batch/geocoder/aggregator", chunk, params=params, timeout=timeout)
910
925
 
911
- with progress_bar(len(addresses), desc="Geocode+Aggregating", enabled=use_progress) as bar:
926
+ if not should_scale:
927
+ with progress_bar(total, desc="Geocode+Aggregating", enabled=use_progress) as bar:
912
928
 
913
- def _on_progress(n: int) -> None:
914
- bar.update(n)
929
+ def _on_progress(n: int) -> None:
930
+ bar.update(n)
931
+
932
+ result = process_batch(addresses, _request, chunk_size, on_progress=_on_progress, on_error=on_error)
933
+ else:
934
+ resp = batch_prepare(self._post, total)
935
+ session_id = resp["session_id"]
936
+ num_workers = min(max_workers, resp.get("max_workers", max_workers))
937
+
938
+ try:
939
+ wait_until_ready(self._get, session_id)
940
+ groups = split_into_groups(addresses, num_workers)
941
+
942
+ from concurrent.futures import ThreadPoolExecutor, as_completed
943
+
944
+ def _process_group(group: dict) -> dict[str, Any]:
945
+ return process_batch(group, _request, chunk_size, on_error=on_error)
915
946
 
916
- result = process_batch(addresses, _request, chunk_size, on_progress=_on_progress, on_error=on_error)
947
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
948
+ futures = [executor.submit(_process_group, g) for g in groups]
949
+ result = {}
950
+ for f in as_completed(futures):
951
+ chunk_result = f.result()
952
+ result.update(chunk_result)
953
+ finally:
954
+ batch_complete(self._post, session_id)
917
955
 
918
956
  if self._clean:
919
957
  return {k: clean_columns(v) if isinstance(v, dict) else v for k, v in result.items()}
@@ -38,6 +38,10 @@ class ValidationError(PrismaDataError):
38
38
  """Raised when the API returns a validation error (422)."""
39
39
 
40
40
 
41
+ class BatchPrepareError(PrismaDataError):
42
+ """Raised when batch prepare polling times out or fails."""
43
+
44
+
41
45
  class BatchError(PrismaDataError):
42
46
  """Raised when a batch operation has partial failures.
43
47
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "prismadata"
3
- version = "0.3.2"
3
+ version = "0.4.0"
4
4
  description = "Python client for the PrismaData location intelligence API"
5
5
  authors = ["PrismaData <contato@prismadata.io>"]
6
6
  license = "MIT"
File without changes
File without changes