polygres 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
polygres/client.py ADDED
@@ -0,0 +1,696 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import random
5
+ import re
6
+ import time
7
+ from dataclasses import dataclass, field
8
+ from typing import Any
9
+ from urllib.parse import urlparse
10
+
11
+ import httpx
12
+
13
+ from polygres.errors import (
14
+ PolygresAPIError,
15
+ PolygresAuthError,
16
+ PolygresNotFoundError,
17
+ PolygresPermissionError,
18
+ PolygresRateLimitError,
19
+ PolygresRuntimeError,
20
+ PolygresValidationError,
21
+ )
22
+ from polygres.models import (
23
+ ConnectionInfo,
24
+ GraphConnectionResponse,
25
+ GraphPathResponse,
26
+ GraphResult,
27
+ HybridResult,
28
+ Page,
29
+ RetrievalReadiness,
30
+ TextResult,
31
+ VectorResult,
32
+ )
33
+
34
# Shape of an accepted API key: the literal prefix "poly_live_" followed by
# exactly 32 lowercase hexadecimal characters.
API_KEY_RE = re.compile(r"^poly_live_[0-9a-f]{32}$")
# Shape of an accepted project id: "p" followed by 23 lowercase alphanumerics.
PROJECT_RE = re.compile(r"^p[a-z0-9]{23}$")
# HTTP status codes that are retried while retry budget remains.
RETRY_STATUSES = {408, 429, 500, 502, 503, 504}
# SDK version string; reported in the User-Agent header.
VERSION = "0.1.0"
38
+
39
+
40
class Polygres:
    """Synchronous client for the Polygres runtime HTTP API.

    The constructor only validates and stores configuration; no network I/O
    happens until one of the request helpers is called. Every request opens a
    short-lived ``httpx.Client`` that is shared by all retry attempts of that
    request.
    """

    def __init__(
        self,
        *,
        api_key: str,
        runtime_url: str | None = None,
        base_url: str | None = None,
        timeout: float | httpx.Timeout = 30.0,
        connect_timeout: float = 10.0,
        max_retries: int = 2,
        headers: dict[str, str] | None = None,
    ) -> None:
        """Validate and store the client configuration.

        Args:
            api_key: Key matching ``poly_live_`` plus 32 lowercase hex digits.
            runtime_url: Root URL of the runtime. ``base_url`` is an
                alternative spelling; when both are given they must be equal
                (trailing slashes are ignored).
            base_url: See ``runtime_url``.
            timeout: Timeout in seconds, or a ready-made ``httpx.Timeout``.
                A number is combined with ``connect_timeout``.
            connect_timeout: Connect timeout in seconds. Always validated, but
                only applied when ``timeout`` is a number.
            max_retries: Default number of retries per request (0..5).
            headers: Extra headers sent with every request.

        Raises:
            PolygresValidationError: If any argument fails validation.
        """
        if not API_KEY_RE.match(api_key):
            raise PolygresValidationError("API key must match poly_live_[32hex]")
        selected_url = _select_runtime_url(runtime_url=runtime_url, base_url=base_url)
        _validate_base_url(selected_url)
        _validate_positive_timeout(connect_timeout, "connect_timeout")
        if isinstance(timeout, (int, float)):
            _validate_positive_timeout(float(timeout), "timeout")
            timeout_config: float | httpx.Timeout = httpx.Timeout(
                float(timeout), connect=connect_timeout
            )
        elif isinstance(timeout, httpx.Timeout):
            timeout_config = timeout
        else:
            raise PolygresValidationError("timeout must be a positive number or httpx.Timeout")
        if max_retries < 0 or max_retries > 5:
            raise PolygresValidationError("max_retries must be between 0 and 5")
        if headers is not None and not all(
            isinstance(key, str) and isinstance(value, str)
            for key, value in headers.items()
        ):
            raise PolygresValidationError("headers must contain string keys and values")

        self._api_key = api_key
        self._base_url = selected_url.rstrip("/")
        self._timeout = timeout_config
        self._max_retries = max_retries
        # Copy so later mutation of the caller's dict cannot change (or bypass
        # the validation of) the headers this client sends.
        self._headers = dict(headers) if headers else {}

    def project(self, project_id: str | None = None) -> Project:
        """Return a ``Project`` handle bound to this client.

        Raises:
            PolygresValidationError: If ``project_id`` is given and does not
                match ``^p[a-z0-9]{23}$``.
        """
        if project_id is not None and not PROJECT_RE.match(project_id):
            raise PolygresValidationError("project id must match ^p[a-z0-9]{23}$")
        return Project(self, project_id)

    def _headers_for(self) -> dict[str, str]:
        """Build the headers for one request.

        Caller-supplied headers are kept, except ``Authorization`` and
        ``User-Agent``, which the SDK always sets itself. Header names are
        case-insensitive on the wire, so the reserved names are matched
        case-insensitively; otherwise a caller-supplied ``authorization`` key
        would be sent next to the SDK's ``Authorization`` header.
        """
        reserved = {"authorization", "user-agent"}
        merged = {
            name: value
            for name, value in self._headers.items()
            if name.lower() not in reserved
        }
        merged["Authorization"] = f"Bearer {self._api_key}"
        merged["User-Agent"] = f"polygres-python/{VERSION}"
        return merged

    def _get(
        self,
        path: str,
        *,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> dict[str, Any]:
        """Send a GET request to ``path`` and return the decoded JSON body."""
        return self._request(
            "GET", path, timeout=timeout, max_retries=max_retries
        )

    def _post(
        self,
        path: str,
        payload: dict[str, Any],
        *,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> dict[str, Any]:
        """Send ``payload`` as a JSON POST to ``path`` and return the decoded body."""
        return self._request(
            "POST",
            path,
            json=payload,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _request(
        self,
        method: str,
        path: str,
        *,
        json: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> dict[str, Any]:
        """Perform one logical request, retrying transient failures.

        Timeouts, network errors and responses whose status is in
        ``RETRY_STATUSES`` are retried until the retry budget is used up.

        Args:
            method: HTTP method name.
            path: Path appended verbatim to the configured base URL.
            json: Optional JSON request body.
            timeout: Per-call override of the client timeout.
            max_retries: Per-call override of the retry budget (0..5).

        Returns:
            The decoded JSON response body.

        Raises:
            PolygresValidationError: If an override is out of range.
            PolygresRuntimeError: On timeout or network failure after the last
                attempt, or when a successful response is not valid JSON.
            PolygresAPIError: (or a subclass) for an error status response.
        """
        retry_budget = self._max_retries if max_retries is None else max_retries
        if retry_budget < 0 or retry_budget > 5:
            raise PolygresValidationError("max_retries must be between 0 and 5")
        timeout_config = self._timeout if timeout is None else timeout
        if isinstance(timeout_config, (int, float)):
            _validate_positive_timeout(float(timeout_config), "timeout")
        url = f"{self._base_url}{path}"
        request_headers = self._headers_for()
        # One client for all attempts of this request, so retries do not pay
        # for building a fresh client each time.
        with httpx.Client(timeout=timeout_config) as client:
            for attempt in range(retry_budget + 1):
                can_retry = attempt < retry_budget
                try:
                    response = client.request(
                        method,
                        url,
                        headers=request_headers,
                        json=json,
                    )
                except httpx.TimeoutException as exc:
                    if can_retry:
                        _sleep_before_retry(attempt, None)
                        continue
                    raise PolygresRuntimeError(
                        "Polygres request timed out", status_code=None
                    ) from exc
                except httpx.NetworkError as exc:
                    if can_retry:
                        _sleep_before_retry(attempt, None)
                        continue
                    raise PolygresRuntimeError(
                        "Polygres network request failed", status_code=None
                    ) from exc

                if response.status_code in RETRY_STATUSES and can_retry:
                    _sleep_before_retry(attempt, response.headers.get("Retry-After"))
                    continue
                if response.is_error:
                    raise _api_error(response)
                try:
                    return response.json()
                except ValueError as exc:
                    # An empty or non-JSON success body would otherwise leak a
                    # raw JSON decoding error to the caller.
                    raise PolygresRuntimeError(
                        "Polygres returned a response that is not valid JSON",
                        status_code=response.status_code,
                    ) from exc
        # Defensive: every loop iteration above returns, raises or continues.
        raise PolygresRuntimeError("Polygres request failed")
169
+
170
+
171
@dataclass
class Project:
    """Handle that groups the retrieval namespaces of one project.

    The namespace attributes are created in ``__post_init__`` and each holds a
    reference back to this object. They are therefore excluded from the
    generated ``__eq__`` and ``__repr__``: comparing two distinct but equal
    projects would otherwise recurse project -> namespace -> project until
    ``RecursionError``.
    """

    # Client used to perform every HTTP request.
    _client: Polygres
    # Optional project id; stored here but not used in any request path built
    # by this class.
    project_id: str | None
    graph: GraphNamespace = field(init=False, repr=False, compare=False)
    vector: VectorNamespace = field(init=False, repr=False, compare=False)
    text: TextNamespace = field(init=False, repr=False, compare=False)
    hybrid: HybridNamespace = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Create the four query namespaces bound to this project."""
        self.graph = GraphNamespace(self)
        self.vector = VectorNamespace(self)
        self.text = TextNamespace(self)
        self.hybrid = HybridNamespace(self)

    def readiness(self) -> RetrievalReadiness:
        """GET ``/retrieval/readiness`` and parse it into ``RetrievalReadiness``."""
        payload = self._client._get("/retrieval/readiness")
        return RetrievalReadiness.from_api(payload)

    def connection_info(self) -> ConnectionInfo:
        """GET ``/connection-info`` and parse it into ``ConnectionInfo``."""
        payload = self._client._get("/connection-info")
        return ConnectionInfo.from_api(payload)

    def _post_page(
        self,
        path: str,
        payload: dict[str, Any],
        parser: Any,
        *,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> Page[Any]:
        """POST a query and wrap the response in a ``Page``.

        Args:
            path: Endpoint path.
            payload: Request body; entries whose value is ``None`` are dropped
                before sending.
            parser: Callable handed to ``Page.from_api`` to parse the response.
            timeout: Optional per-call timeout override.
            max_retries: Optional per-call retry budget override.

        Returns:
            A ``Page`` whose continuation callback re-issues the same query
            with the ``cursor`` entry replaced.
        """
        response = self._client._post(
            path,
            _compact(payload),
            timeout=timeout,
            max_retries=max_retries,
        )

        def fetch_next(cursor: str) -> Page[Any]:
            # Same query and per-call options; only the cursor changes.
            return self._post_page(
                path,
                {**payload, "cursor": cursor},
                parser,
                timeout=timeout,
                max_retries=max_retries,
            )

        return Page.from_api(response, parser, fetch_next)
220
+
221
+
222
@dataclass
class GraphNamespace:
    """Graph queries (``/graph/*`` endpoints) issued through a project.

    Every method accepts optional ``timeout`` and ``max_retries`` overrides
    that are forwarded to the underlying request.
    """

    _project: Project

    def expand(
        self,
        start: dict[str, Any] | list[dict[str, Any]],
        *,
        max_depth: int = 5,
        relationship_types: list[str] | None = None,
        direction: str = "out",
        filters: dict[str, Any] | None = None,
        limit: int = 50,
        cursor: str | None = None,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> Page[GraphResult]:
        """POST a traversal to ``/graph/expand``.

        Raises:
            PolygresValidationError: If ``start`` is empty, ``max_depth`` is
                outside 1..20, ``limit`` is outside 1..1000, or ``direction``
                is not one of out/in/any/both.
        """
        _validate_required(start, "start")
        _validate_range(max_depth, "max_depth", 1, 20)
        _validate_range(limit, "limit", 1, 1000)
        return self._traverse(
            "/graph/expand",
            start=start,
            max_depth=max_depth,
            relationship_types=relationship_types,
            direction=direction,
            filters=filters,
            limit=limit,
            cursor=cursor,
            timeout=timeout,
            max_retries=max_retries,
        )

    def neighborhood(
        self,
        start: dict[str, Any],
        *,
        radius: int = 2,
        relationship_types: list[str] | None = None,
        direction: str = "any",
        filters: dict[str, Any] | None = None,
        limit: int = 100,
        cursor: str | None = None,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> Page[GraphResult]:
        """POST to ``/graph/neighborhood``; ``radius`` is sent as ``max_depth``.

        Raises:
            PolygresValidationError: If ``start`` is empty, ``radius`` is
                outside 1..20, ``limit`` is outside 1..1000, or ``direction``
                is invalid.
        """
        # Same precondition as expand(): an empty start is rejected locally.
        _validate_required(start, "start")
        _validate_range(radius, "radius", 1, 20)
        _validate_range(limit, "limit", 1, 1000)
        return self._traverse(
            "/graph/neighborhood",
            start=start,
            max_depth=radius,
            relationship_types=relationship_types,
            direction=direction,
            filters=filters,
            limit=limit,
            cursor=cursor,
            timeout=timeout,
            max_retries=max_retries,
        )

    def related(
        self,
        start: dict[str, Any],
        *,
        relationship_types: list[str] | None = None,
        direction: str = "any",
        filters: dict[str, Any] | None = None,
        limit: int = 20,
        cursor: str | None = None,
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> Page[GraphResult]:
        """POST to ``/graph/related`` with ``max_depth`` fixed at 1.

        Raises:
            PolygresValidationError: If ``start`` is empty, ``limit`` is
                outside 1..1000, or ``direction`` is invalid.
        """
        _validate_required(start, "start")
        _validate_range(limit, "limit", 1, 1000)
        return self._traverse(
            "/graph/related",
            start=start,
            max_depth=1,
            relationship_types=relationship_types,
            direction=direction,
            filters=filters,
            limit=limit,
            cursor=cursor,
            timeout=timeout,
            max_retries=max_retries,
        )

    def path(
        self,
        source: dict[str, Any],
        target: dict[str, Any],
        *,
        max_depth: int = 5,
        relationship_types: list[str] | None = None,
        direction: str = "any",
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> GraphPathResponse:
        """POST ``source``/``target`` to ``/graph/path`` and parse the reply.

        Raises:
            PolygresValidationError: If ``source`` or ``target`` is empty,
                ``max_depth`` is outside 1..20, or ``direction`` is invalid.
        """
        _validate_required(source, "source")
        _validate_required(target, "target")
        _validate_range(max_depth, "max_depth", 1, 20)
        payload = _compact(
            {
                "source": source,
                "target": target,
                "max_depth": max_depth,
                "relationship_types": relationship_types,
                "direction": _sdk_direction(direction),
            }
        )
        response = self._project._client._post(
            "/graph/path",
            payload,
            timeout=timeout,
            max_retries=max_retries,
        )
        return GraphPathResponse.from_api(response)

    def connection(
        self,
        entities: list[dict[str, Any]],
        *,
        max_depth: int = 5,
        relationship_types: list[str] | None = None,
        direction: str = "any",
        timeout: float | httpx.Timeout | None = None,
        max_retries: int | None = None,
    ) -> GraphConnectionResponse:
        """POST 2..10 ``entities`` to ``/graph/connection`` and parse the reply.

        Raises:
            PolygresValidationError: If the number of entities is outside
                2..10, ``max_depth`` is outside 1..20, or ``direction`` is
                invalid.
        """
        if len(entities) < 2 or len(entities) > 10:
            raise PolygresValidationError("entities must contain 2..10 items")
        _validate_range(max_depth, "max_depth", 1, 20)
        payload = _compact(
            {
                "entities": entities,
                "max_depth": max_depth,
                "relationship_types": relationship_types,
                "direction": _sdk_direction(direction),
            }
        )
        response = self._project._client._post(
            "/graph/connection",
            payload,
            timeout=timeout,
            max_retries=max_retries,
        )
        return GraphConnectionResponse.from_api(response)

    def _traverse(
        self,
        path: str,
        *,
        start: dict[str, Any] | list[dict[str, Any]],
        max_depth: int,
        relationship_types: list[str] | None,
        direction: str,
        filters: dict[str, Any] | None,
        limit: int,
        cursor: str | None,
        timeout: float | httpx.Timeout | None,
        max_retries: int | None,
    ) -> Page[GraphResult]:
        """Build the payload shared by the paged traversal endpoints and POST it."""
        payload = {
            "start": start,
            "max_depth": max_depth,
            "relationship_types": relationship_types,
            # Validates the value and sends "both" as "any".
            "direction": _sdk_direction(direction),
            "filters": filters or {},
            "limit": limit,
            "cursor": cursor,
        }
        return self._project._post_page(
            path,
            payload,
            GraphResult.from_api,
            timeout=timeout,
            max_retries=max_retries,
        )
362
+
363
+
364
@dataclass
class VectorNamespace:
    """Vector queries (``/vector/*`` endpoints) issued through a project."""

    _project: Project

    def search(
        self,
        embedding: list[float],
        *,
        config: str | None = None,
        limit: int | None = None,
        filters: dict[str, Any] | None = None,
        max_distance: float | None = None,
        min_similarity: float | None = None,
        include_values: bool = False,
        cursor: str | None = None,
    ) -> Page[VectorResult]:
        """POST an embedding query to ``/vector/search``.

        The embedding and the limit/threshold options are validated before
        the request is built.
        """
        _validate_embedding(embedding)
        _validate_vector_options(limit, max_distance, min_similarity)
        body: dict[str, Any] = {"embedding": embedding}
        body.update(
            config=config,
            limit=limit,
            filters=filters or {},
            max_distance=max_distance,
            min_similarity=min_similarity,
            include_values=include_values,
            cursor=cursor,
        )
        return self._project._post_page("/vector/search", body, VectorResult.from_api)

    def similar_to(
        self,
        row_id: str,
        *,
        config: str | None = None,
        limit: int | None = None,
        filters: dict[str, Any] | None = None,
        max_distance: float | None = None,
        min_similarity: float | None = None,
        include_values: bool = False,
        cursor: str | None = None,
    ) -> Page[VectorResult]:
        """POST a ``row_id`` query to ``/vector/similar-to``.

        Raises:
            PolygresValidationError: If ``row_id`` is empty or the
                limit/threshold options are invalid.
        """
        if not row_id:
            raise PolygresValidationError("row_id is required")
        _validate_vector_options(limit, max_distance, min_similarity)
        body: dict[str, Any] = {"row_id": row_id}
        body.update(
            config=config,
            limit=limit,
            filters=filters or {},
            max_distance=max_distance,
            min_similarity=min_similarity,
            include_values=include_values,
            cursor=cursor,
        )
        return self._project._post_page("/vector/similar-to", body, VectorResult.from_api)
426
+
427
+
428
@dataclass
class TextNamespace:
    """Text queries (``/text/*`` endpoints) issued through a project."""

    _project: Project

    def tsvector(
        self,
        query: str,
        *,
        config: str,
        limit: int = 10,
        filters: dict[str, Any] | None = None,
        cursor: str | None = None,
    ) -> Page[TextResult]:
        """POST ``query`` to ``/text/tsvector``.

        Raises:
            PolygresValidationError: If ``query`` is blank or ``limit`` is
                outside 1..1000.
        """
        _validate_text_query(query)
        _validate_range(limit, "limit", 1, 1000)
        request = dict(
            query=query,
            config=config,
            limit=limit,
            filters=filters or {},
            cursor=cursor,
        )
        return self._text_page("tsvector", request)

    def fuzzy(
        self,
        query: str,
        *,
        config: str,
        limit: int = 10,
        filters: dict[str, Any] | None = None,
        cursor: str | None = None,
    ) -> Page[TextResult]:
        """POST ``query`` to ``/text/fuzzy``.

        Raises:
            PolygresValidationError: If ``query`` is blank or ``limit`` is
                outside 1..1000.
        """
        _validate_text_query(query)
        _validate_range(limit, "limit", 1, 1000)
        request = dict(
            query=query,
            config=config,
            limit=limit,
            filters=filters or {},
            cursor=cursor,
        )
        return self._text_page("fuzzy", request)

    def _text_page(self, mode: str, payload: dict[str, Any]) -> Page[TextResult]:
        """POST ``payload`` to ``/text/<mode>`` and parse rows as ``TextResult``."""
        endpoint = f"/text/{mode}"
        return self._project._post_page(endpoint, payload, TextResult.from_api)
482
+
483
+
484
@dataclass
class HybridNamespace:
    """Combined graph and vector queries (``/hybrid/*`` endpoints)."""

    _project: Project

    def graph_first(
        self,
        start: dict[str, Any],
        embedding: list[float],
        *,
        config: str | None = None,
        max_depth: int = 2,
        relationship_types: list[str] | None = None,
        direction: str = "any",
        filters: dict[str, Any] | None = None,
        limit: int = 10,
        cursor: str | None = None,
    ) -> Page[HybridResult]:
        """POST a ``start`` + ``embedding`` query to ``/hybrid/graph-first``.

        Raises:
            PolygresValidationError: If ``start`` is empty, the embedding is
                invalid, ``max_depth`` is outside 1..20, ``limit`` is outside
                1..1000, or ``direction`` is invalid.
        """
        # Same precondition the graph namespace applies to a traversal start.
        _validate_required(start, "start")
        _validate_embedding(embedding)
        _validate_range(max_depth, "max_depth", 1, 20)
        _validate_range(limit, "limit", 1, 1000)
        return self._hybrid_page(
            "graph-first",
            {
                "start": start,
                "embedding": embedding,
                "config": config,
                "max_depth": max_depth,
                "relationship_types": relationship_types,
                "direction": _sdk_direction(direction),
                "filters": filters or {},
                "limit": limit,
                "cursor": cursor,
            },
        )

    def vector_first(
        self,
        embedding: list[float],
        *,
        config: str | None = None,
        vector_limit: int = 20,
        max_depth: int = 1,
        relationship_types: list[str] | None = None,
        direction: str = "any",
        filters: dict[str, Any] | None = None,
        limit: int = 10,
        cursor: str | None = None,
    ) -> Page[HybridResult]:
        """POST an ``embedding`` query to ``/hybrid/vector-first``.

        Raises:
            PolygresValidationError: If the embedding is invalid,
                ``vector_limit`` or ``limit`` is outside 1..1000,
                ``max_depth`` is outside 1..20, or ``direction`` is invalid.
        """
        _validate_embedding(embedding)
        _validate_range(vector_limit, "vector_limit", 1, 1000)
        _validate_range(max_depth, "max_depth", 1, 20)
        _validate_range(limit, "limit", 1, 1000)
        return self._hybrid_page(
            "vector-first",
            {
                "embedding": embedding,
                "config": config,
                "vector_limit": vector_limit,
                "max_depth": max_depth,
                "relationship_types": relationship_types,
                "direction": _sdk_direction(direction),
                "filters": filters or {},
                "limit": limit,
                "cursor": cursor,
            },
        )

    def joint(
        self,
        embedding: list[float],
        start: dict[str, Any],
        *,
        config: str | None = None,
        vector_weight: float = 0.7,
        graph_weight: float = 0.3,
        max_depth: int = 2,
        limit: int = 10,
        cursor: str | None = None,
    ) -> Page[HybridResult]:
        """POST a weighted query to ``/hybrid/joint``.

        The two weights are sent as ``{"vector": ..., "graph": ...}`` under
        the ``weights`` key.

        Raises:
            PolygresValidationError: If the embedding is invalid, ``start`` is
                empty, a weight is not a finite number, ``max_depth`` is
                outside 1..20, or ``limit`` is outside 1..1000.
        """
        _validate_embedding(embedding)
        _validate_required(start, "start")
        for label, weight in (
            ("vector_weight", vector_weight),
            ("graph_weight", graph_weight),
        ):
            # NaN and infinity have no valid JSON representation, so reject
            # them here instead of failing while encoding the request.
            if (
                isinstance(weight, bool)
                or not isinstance(weight, (int, float))
                or not math.isfinite(weight)
            ):
                raise PolygresValidationError(f"{label} must be a finite number")
        _validate_range(max_depth, "max_depth", 1, 20)
        _validate_range(limit, "limit", 1, 1000)
        return self._hybrid_page(
            "joint",
            {
                "embedding": embedding,
                "start": start,
                "config": config,
                "weights": {"vector": vector_weight, "graph": graph_weight},
                "max_depth": max_depth,
                "limit": limit,
                "cursor": cursor,
            },
        )

    def _hybrid_page(self, mode: str, payload: dict[str, Any]) -> Page[HybridResult]:
        """POST ``payload`` to ``/hybrid/<mode>`` and parse rows as ``HybridResult``."""
        return self._project._post_page(
            f"/hybrid/{mode}",
            payload,
            HybridResult.from_api,
        )
585
+
586
+
587
def _select_runtime_url(*, runtime_url: str | None, base_url: str | None) -> str:
    """Pick the runtime root URL from the two accepted spellings.

    Trailing slashes are stripped from both values. ``runtime_url`` wins when
    it is non-empty; otherwise ``base_url`` is used.

    Raises:
        PolygresValidationError: If both are non-empty but differ, or if
            neither yields a value.
    """
    runtime = None if runtime_url is None else runtime_url.rstrip("/")
    base = None if base_url is None else base_url.rstrip("/")
    if runtime and base:
        if runtime != base:
            raise PolygresValidationError(
                "runtime_url and base_url must match when both are provided"
            )
    chosen = runtime or base
    if chosen is not None:
        return chosen
    raise PolygresValidationError("runtime_url is required")
600
+
601
+
602
def _validate_base_url(base_url: str) -> None:
    """Require HTTPS, or plain HTTP only for a loopback development host.

    Accepted: any ``https`` URL that has a host, and ``http`` URLs whose host
    is ``localhost``, ``127.0.0.1`` or the IPv6 loopback ``::1``.

    Raises:
        PolygresValidationError: For any other URL, including ones that
            ``urlparse`` rejects as malformed.
    """
    try:
        parsed = urlparse(base_url)
        # .hostname is None when the authority has no host (e.g. "https://:443"),
        # which a bare netloc truthiness check would let through.
        host = parsed.hostname
    except ValueError:
        # urlparse raises ValueError for e.g. an unterminated IPv6 literal.
        host = None
        parsed = None
    if parsed is not None and host:
        if parsed.scheme == "https":
            return
        if parsed.scheme == "http" and host in {"localhost", "127.0.0.1", "::1"}:
            return
    raise PolygresValidationError("base_url must be HTTPS except localhost development")
609
+
610
+
611
def _validate_positive_timeout(value: float, name: str) -> None:
    """Raise ``PolygresValidationError`` unless ``value`` is positive and finite.

    ``name`` is the argument name used in the error message.
    """
    acceptable = value > 0 and math.isfinite(value)
    if acceptable:
        return
    raise PolygresValidationError(f"{name} must be positive")
614
+
615
+
616
def _validate_required(value: Any, name: str) -> None:
    """Raise ``PolygresValidationError`` when ``value`` is falsy.

    ``name`` is the argument name used in the error message.
    """
    if value:
        return
    raise PolygresValidationError(f"{name} is required")
619
+
620
+
621
def _validate_range(value: int, name: str, minimum: int, maximum: int) -> None:
    """Require ``value`` to be an integer within ``minimum..maximum`` inclusive.

    Non-integers are rejected explicitly: NaN would pass a plain pair of
    comparisons unnoticed, ``None`` or a string would surface as a raw
    ``TypeError``, and ``True``/``False`` would be sent as JSON booleans.

    Raises:
        PolygresValidationError: If ``value`` is not an ``int`` (booleans are
            not accepted) or lies outside the range.
    """
    if isinstance(value, bool) or not isinstance(value, int):
        raise PolygresValidationError(
            f"{name} must be an integer between {minimum} and {maximum}"
        )
    if not minimum <= value <= maximum:
        raise PolygresValidationError(f"{name} must be between {minimum} and {maximum}")
624
+
625
+
626
def _validate_embedding(embedding: list[float]) -> None:
    """Require a non-empty sequence of finite numbers.

    Booleans are rejected even though ``bool`` is a subclass of ``int``: they
    would be encoded as JSON ``true``/``false`` rather than as numbers.

    Raises:
        PolygresValidationError: If ``embedding`` is empty or contains a value
            that is not a finite ``int``/``float``.
    """
    if not embedding:
        raise PolygresValidationError("embedding must be non-empty")
    for value in embedding:
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        if is_number:
            try:
                is_number = math.isfinite(value)
            except OverflowError:
                # An int too large to convert to float counts as not finite.
                is_number = False
        if not is_number:
            raise PolygresValidationError("embedding values must be finite numbers")
632
+
633
+
634
def _validate_text_query(query: str) -> None:
    """Raise ``PolygresValidationError`` unless ``query`` is a non-blank string."""
    has_text = isinstance(query, str) and bool(query.strip())
    if has_text:
        return
    raise PolygresValidationError("query must be non-empty")
637
+
638
+
639
def _validate_vector_options(
    limit: int | None, max_distance: float | None, min_similarity: float | None
) -> None:
    """Validate the optional limit and threshold arguments of vector queries.

    Raises:
        PolygresValidationError: If ``limit`` is given and outside 1..1000, if
            both thresholds are given, or if a given threshold is not a finite
            number.
    """
    if limit is not None:
        _validate_range(limit, "limit", 1, 1000)
    if max_distance is not None and min_similarity is not None:
        raise PolygresValidationError("max_distance and min_similarity are mutually exclusive")
    for label, threshold in (
        ("max_distance", max_distance),
        ("min_similarity", min_similarity),
    ):
        if threshold is None:
            continue
        # NaN and infinity have no valid JSON representation, so reject them
        # before the request body is built.
        if (
            isinstance(threshold, bool)
            or not isinstance(threshold, (int, float))
            or not math.isfinite(threshold)
        ):
            raise PolygresValidationError(f"{label} must be a finite number")
646
+
647
+
648
def _sdk_direction(direction: str) -> str:
    """Validate a direction and return the value to send.

    ``out``, ``in`` and ``any`` are returned unchanged; ``both`` is returned
    as ``any``.

    Raises:
        PolygresValidationError: For any other value.
    """
    accepted = {"out", "in", "any", "both"}
    if direction in accepted:
        if direction == "both":
            return "any"
        return direction
    raise PolygresValidationError("direction must be out, in, any, or both")
652
+
653
+
654
def _compact(payload: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``payload`` without the entries whose value is ``None``.

    Other falsy values (``0``, ``False``, ``""``, empty containers) are kept.
    """
    kept: dict[str, Any] = {}
    for name, entry in payload.items():
        if entry is None:
            continue
        kept[name] = entry
    return kept
656
+
657
+
658
def _sleep_before_retry(
    attempt: int, retry_after: str | None, *, max_delay: float = 60.0
) -> None:
    """Block before the next retry attempt.

    The default delay is an exponential backoff (25 ms doubled per attempt)
    plus up to 5 ms of random jitter. A ``Retry-After`` value given as a
    non-negative, finite number of seconds replaces the backoff, capped at
    ``max_delay`` so a server cannot stall the caller indefinitely.

    Args:
        attempt: Zero-based index of the attempt that just failed.
        retry_after: Raw ``Retry-After`` header value, if any.
        max_delay: Upper bound in seconds for a server-requested delay.
    """
    delay = 0.025 * (2**attempt) + random.uniform(0, 0.005)
    if retry_after:
        try:
            requested = float(retry_after)
        except ValueError:
            # Best effort: an HTTP-date or otherwise unparseable value falls
            # back to the computed backoff.
            requested = None
        # NaN fails the comparison; infinity would make time.sleep() raise
        # OverflowError, so both are ignored.
        if requested is not None and math.isfinite(requested) and requested >= 0:
            delay = min(requested, max_delay)
    time.sleep(delay)
668
+
669
+
670
def _api_error(response: httpx.Response) -> PolygresAPIError:
    """Build (without raising) the exception that matches an error response.

    The body is expected to look like
    ``{"error": {"message", "code", "details"}, "request_id"}``, but every
    part is optional: a non-JSON body, a non-object body, or an ``error``
    member that is not an object all fall back to defaults instead of raising
    ``AttributeError`` while the real error is being reported.

    Returns:
        ``PolygresAuthError`` (401), ``PolygresPermissionError`` (403),
        ``PolygresNotFoundError`` (404), ``PolygresRateLimitError`` (429),
        ``PolygresRuntimeError`` (408/500/502/503/504), otherwise
        ``PolygresAPIError``.
    """
    try:
        body = response.json()
    except ValueError:
        body = {}
    if not isinstance(body, dict):
        body = {}

    raw_error = body.get("error")
    error = raw_error if isinstance(raw_error, dict) else {}
    status = response.status_code

    fallback_message = f"Polygres API error {status}"
    if isinstance(raw_error, str) and raw_error.strip():
        # Some error bodies carry the message directly as a string.
        fallback_message = raw_error
    message = str(error.get("message") or fallback_message)

    details = error.get("details")
    kwargs = {
        "status_code": status,
        "request_id": body.get("request_id"),
        "code": error.get("code"),
        "details": details if isinstance(details, dict) else {},
    }

    if status == 401:
        return PolygresAuthError(message, **kwargs)
    if status == 403:
        return PolygresPermissionError(message, **kwargs)
    if status == 404:
        return PolygresNotFoundError(message, **kwargs)
    if status == 429:
        return PolygresRateLimitError(message, **kwargs)
    if status in {408, 500, 502, 503, 504}:
        return PolygresRuntimeError(message, **kwargs)
    return PolygresAPIError(message, **kwargs)