cohesion-sdk 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cohesion/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright 2026 COHESION AUTH LLC
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # SPDX-License-Identifier: Apache-2.0
7
+ """Official Python SDK for the COHESION Judgment Independence Score API.
8
+
9
+ The client surface is locked by the v1.1.0 API pilot-readiness spec.
10
+ Import what you need from the top level::
11
+
12
+ from cohesion import (
13
+ Client,
14
+ wrap_openai,
15
+ wrap_anthropic,
16
+ wrap_azure_openai,
17
+ LangChainCallback,
18
+ )
19
+ from cohesion import (
20
+ CohesionError,
21
+ AuthenticationError,
22
+ RateLimitError,
23
+ ValidationError,
24
+ ServerError,
25
+ NetworkError,
26
+ )
27
+ """
28
+
29
+ from ._version import __version__
30
+ from .client import Client
31
+ from .exceptions import (
32
+ AuthenticationError,
33
+ CohesionError,
34
+ NetworkError,
35
+ RateLimitError,
36
+ ServerError,
37
+ ValidationError,
38
+ )
39
+ from .langchain import LangChainCallback
40
+ from .wrappers import (
41
+ DecisionReport,
42
+ WrapContext,
43
+ WrappedCompletion,
44
+ wrap_anthropic,
45
+ wrap_azure_openai,
46
+ wrap_openai,
47
+ )
48
+
49
# Public names re-exported from the package root; ``from cohesion import *``
# exposes exactly these, in this order.
__all__ = [
    "__version__",
    # Synchronous API client (cohesion.client).
    "Client",
    # Exception types (cohesion.exceptions).
    "CohesionError",
    "AuthenticationError",
    "RateLimitError",
    "ValidationError",
    "ServerError",
    "NetworkError",
    # Provider / framework integrations (cohesion.wrappers, cohesion.langchain).
    "wrap_openai",
    "wrap_anthropic",
    "wrap_azure_openai",
    "LangChainCallback",
    # Helper types re-exported from cohesion.wrappers.
    "WrapContext",
    "DecisionReport",
    "WrappedCompletion",
]
cohesion/_version.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright 2026 COHESION AUTH LLC
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # SPDX-License-Identifier: Apache-2.0
7
+
8
# SDK version string. Re-exported as ``cohesion.__version__`` by the package
# root, printed by the CLI's ``--version`` flag, and embedded in the Client's
# ``User-Agent`` header.
__version__ = "1.0.0"
cohesion/cli.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2026 COHESION AUTH LLC
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # SPDX-License-Identifier: Apache-2.0
7
+ """``cohesion`` command-line entrypoint.
8
+
9
+ Usage::
10
+
11
+ cohesion score --operator-id X --from-file y.json
12
+ cohesion score --operator-id X --from-file - # read from stdin
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Any
23
+
24
+ from ._version import __version__
25
+ from .client import Client
26
+ from .exceptions import CohesionError
27
+
28
+ __all__ = ["main", "build_parser"]
29
+
30
+
31
def build_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for the ``cohesion`` command.

    Global options are ``--version`` and ``--base-url``. The ``--base-url``
    default is read from ``$COHESION_BASE_URL`` at the moment the parser is
    built, falling back to the production host. A subcommand is mandatory;
    the only one defined is ``score``.

    Returns:
        A ready-to-use :class:`argparse.ArgumentParser`.
    """
    root = argparse.ArgumentParser(
        prog="cohesion",
        description="COHESION Judgment Independence Score API command-line tool.",
    )
    version_banner = f"cohesion-sdk-python {__version__}"
    root.add_argument("--version", action="version", version=version_banner)

    env_base_url = os.environ.get("COHESION_BASE_URL", "https://api.cohesionauth.com")
    root.add_argument(
        "--base-url",
        default=env_base_url,
        help="API base URL. Defaults to $COHESION_BASE_URL or the production host.",
    )

    commands = root.add_subparsers(dest="command", required=True)

    score_cmd = commands.add_parser("score", help="Submit a scoring payload.")
    score_cmd.add_argument("--operator-id", required=True, help="Operator identifier.")
    score_cmd.add_argument(
        "--from-file",
        required=True,
        help="Path to a JSON file containing the scoring payload, or '-' for stdin.",
    )

    # Optional overrides; a value of None leaves the payload file's field intact.
    optional_overrides = (
        ("--session-id", "Override session_id. Default: file contents."),
        ("--domain", "Override domain. Default: file contents."),
    )
    for flag, help_text in optional_overrides:
        score_cmd.add_argument(flag, default=None, help=help_text)

    return root
63
+
64
+
65
+ def _read_payload(path: str) -> dict[str, Any]:
66
+ raw = sys.stdin.read() if path == "-" else Path(path).read_text(encoding="utf-8")
67
+ try:
68
+ data = json.loads(raw)
69
+ except json.JSONDecodeError as exc:
70
+ raise SystemExit(f"Invalid JSON in {path}: {exc}") from exc
71
+ if not isinstance(data, dict):
72
+ raise SystemExit("Payload must be a JSON object.")
73
+ return data
74
+
75
+
76
+ def _resolve_api_key() -> str:
77
+ key = os.environ.get("COHESION_API_KEY")
78
+ if not key:
79
+ raise SystemExit(
80
+ "COHESION_API_KEY environment variable is required. "
81
+ "Get your key at cohesionauth.com/dashboard/api-keys."
82
+ )
83
+ return key
84
+
85
+
86
def main(argv: list[str] | None = None) -> int:
    """Run the ``cohesion`` CLI and return its process exit code.

    Args:
        argv: Arguments excluding the program name. ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        ``0`` on success, after writing the JSON response to stdout.
        ``1`` when the API call fails or the payload is rejected by
        client-side validation; details go to stderr. argparse itself exits
        with status 2 on usage errors.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command == "score":
        payload = _read_payload(args.from_file)
        # Command-line overrides win over the payload file's values.
        if args.session_id:
            payload["session_id"] = args.session_id
        if args.domain:
            payload["domain"] = args.domain
        payload["operator_id"] = args.operator_id

        # The context manager closes the client's HTTP connection on every path.
        with Client(api_key=_resolve_api_key(), base_url=args.base_url) as client:
            try:
                response = client.score(**payload)
            except CohesionError as err:
                sys.stderr.write(f"error: {err}\n")
                if err.next_step:
                    sys.stderr.write(f"next_step: {err.next_step}\n")
                if err.request_id:
                    sys.stderr.write(f"request_id: {err.request_id}\n")
                return 1
            except ValueError as err:
                # Client.score validates the payload with ScoreRequest.model_validate
                # before sending. NOTE(review): assumes a pydantic model, whose
                # ValidationError subclasses ValueError and is not a CohesionError.
                # Report it as a user error instead of crashing with a traceback.
                sys.stderr.write(f"error: invalid payload: {err}\n")
                return 1
        sys.stdout.write(response.model_dump_json(indent=2) + "\n")
        return 0

    # Unreachable while "score" is the only (required) subcommand;
    # parser.error() exits with status 2.
    parser.error(f"Unknown command: {args.command}")
    return 2
115
+
116
+
117
if __name__ == "__main__":  # pragma: no cover
    # When this module is executed directly, exit with main()'s return code.
    sys.exit(main())
cohesion/client.py ADDED
@@ -0,0 +1,349 @@
1
+ # Copyright 2026 COHESION AUTH LLC
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # SPDX-License-Identifier: Apache-2.0
7
+ """Cohesion client.
8
+
9
+ Method surface locked by spec §13.1. Retry with full-jitter exponential
10
+ backoff on 429 + 5xx + NetworkError, up to 3 retries (4 attempts total) by default, honors
11
+ ``Retry-After`` (integer seconds or HTTP-date). ``Idempotency-Key``
12
+ auto-generated as UUIDv4 on every POST.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ import re
19
+ import time
20
+ import uuid
21
+ from typing import Any
22
+
23
+ import httpx
24
+
25
+ from ._version import __version__
26
+ from .exceptions import (
27
+ AuthenticationError,
28
+ CohesionError,
29
+ NetworkError,
30
+ RateLimitError,
31
+ ServerError,
32
+ ValidationError,
33
+ error_from_response,
34
+ sanitize_message,
35
+ )
36
+ from .logging import get_logger
37
+ from .models import (
38
+ AuditEvent,
39
+ AuditLogResponse,
40
+ BatchScoreRequest,
41
+ BatchScoreResponse,
42
+ ComplianceReport,
43
+ Interaction,
44
+ KeyRevocationResponse,
45
+ KeyRotationResponse,
46
+ MaintenanceRecommendation,
47
+ OperatorProfile,
48
+ OrganizationDashboard,
49
+ ScoreRequest,
50
+ ScoreResponse,
51
+ )
52
+ from .retry import DEFAULT_RETRY_POLICY, RetryPolicy, compute_backoff_ms, is_retryable_status
53
+ from .telemetry import maybe_init_telemetry, report_sdk_error
54
+
55
+ __all__ = ["Client"]
56
+
57
# Production API root used when no base_url is supplied.
_DEFAULT_BASE_URL = "https://api.cohesionauth.com"
# Allowed operator ids: 1-256 ASCII letters, digits, '-' or '_'. Operator ids
# are interpolated into URL paths, so the pattern is anchored with \A...\Z.
# A '$' anchor would also match just before a trailing newline, letting
# "op\n" pass re.match.
_OPERATOR_ID_RE = re.compile(r"\A[A-Za-z0-9_-]{1,256}\Z")
59
+
60
+
61
class Client:
    """Synchronous client for the COHESION scoring API.

    Every request carries the API key in ``X-API-Key``. Transport failures
    and statuses the retry policy marks as retryable are retried up to
    ``max_retries`` extra times with backoff. The client owns its
    :class:`httpx.Client` unless one is injected. Use it as a context manager
    to guarantee the owned HTTP client is closed.
    """

    # Remediation hint attached to every NetworkError.
    _NETWORK_NEXT_STEP = (
        "Check connectivity to api.cohesionauth.com and ensure TLS 1.3 "
        "egress is permitted from your environment."
    )

    def __init__(
        self,
        api_key: str,
        base_url: str = _DEFAULT_BASE_URL,
        timeout: float = 30.0,
        max_retries: int = 3,
        logger: logging.Logger | None = None,
        enable_telemetry: bool = False,
        *,
        http_client: httpx.Client | None = None,
        retry_policy: RetryPolicy | None = None,
        sleep: Any | None = None,
    ) -> None:
        """Create a client.

        Args:
            api_key: COHESION API key, sent as ``X-API-Key`` on every request.
            base_url: API root. A trailing ``/`` is stripped.
            timeout: Per-request timeout in seconds.
            max_retries: Retries after the first attempt, so a request makes at
                most ``max_retries + 1`` attempts. Must be >= 0.
            logger: Optional logger, resolved through ``get_logger``.
            enable_telemetry: When true, calls ``maybe_init_telemetry(True)``.
            http_client: Optional pre-built ``httpx.Client``. When supplied,
                :meth:`close` leaves it open for the caller to manage.
            retry_policy: Optional policy supplying backoff delays and
                retryable statuses. Its ``max_retries`` is replaced by the
                ``max_retries`` argument.
            sleep: Optional ``sleep(seconds)`` callable used between retries.
                Defaults to :func:`time.sleep`.

        Raises:
            ValueError: If ``api_key`` is empty or not a string, or
                ``max_retries`` is negative.
        """
        if not api_key or not isinstance(api_key, str):
            raise ValueError("api_key is required")
        if max_retries < 0:
            # A negative count would make the request loop run zero times and
            # fail every call; reject it up front instead.
            raise ValueError("max_retries must be >= 0")
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._max_retries = max_retries
        self._user_agent = f"cohesion-sdk-python/{__version__}"
        self._logger = get_logger(logger)

        base_policy = retry_policy if retry_policy is not None else DEFAULT_RETRY_POLICY
        # Rebuild the policy so its max_retries agrees with this client's.
        # NOTE(review): a caller-supplied policy's own max_retries is
        # overridden by the max_retries argument (default 3).
        self._retry_policy = base_policy.__class__(
            max_retries=max_retries,
            base_delay_ms=base_policy.base_delay_ms,
            max_delay_ms=base_policy.max_delay_ms,
            retryable_statuses=base_policy.retryable_statuses,
        )
        self._sleep = sleep if sleep is not None else time.sleep
        self._owns_http_client = http_client is None
        self._http = http_client if http_client is not None else httpx.Client(timeout=timeout)

        if enable_telemetry:
            maybe_init_telemetry(True)

    # ───── Lifecycle ──────────────────────────────────────────────────────

    def close(self) -> None:
        """Close the underlying HTTP client if this instance created it."""
        if self._owns_http_client:
            self._http.close()

    def __enter__(self) -> Client:
        return self

    def __exit__(self, *_exc: Any) -> None:
        self.close()

    # ───── Scoring endpoints ──────────────────────────────────────────────

    def score(self, **kwargs: Any) -> ScoreResponse:
        """Score a single interaction via ``POST /v1/score``.

        ``kwargs`` are validated against :class:`ScoreRequest` before
        anything is sent. Errors raised by that validation propagate to the
        caller.
        """
        req = ScoreRequest.model_validate(kwargs)
        data = self._request("POST", "/v1/score", json_body=req.model_dump())
        return ScoreResponse.model_validate(data)

    def score_batch(
        self,
        interactions: list[Any] | dict[str, Any],
        *,
        session_id: str = "",
        operator_id: str = "",
        domain: str = "general",
    ) -> BatchScoreResponse:
        """Score up to 100 interactions in a single call (``POST /v1/score/batch``).

        ``interactions`` accepts either:

        * a full :class:`BatchScoreRequest` payload as a dict (with
          ``session_id``, ``operator_id``, ``domain``, ``interactions``).
          The keyword arguments are ignored in this form.
        * a non-empty list of interaction dicts / :class:`Interaction`
          objects. The request envelope is then built from the
          ``session_id``, ``operator_id`` and ``domain`` keyword arguments.
          Pass real ids; the empty-string defaults are kept only so existing
          list-form calls behave as before.

        Raises:
            ValidationError: If ``interactions`` is an empty list, or is
                neither a dict nor a list.
        """
        if isinstance(interactions, dict):
            req = BatchScoreRequest.model_validate(interactions)
        elif isinstance(interactions, list):
            if not interactions:
                raise ValidationError(
                    "interactions list must be non-empty",
                    field="interactions",
                )
            items = [
                item if isinstance(item, Interaction) else Interaction.model_validate(item)
                for item in interactions
            ]
            req = BatchScoreRequest.model_validate(
                {
                    "session_id": session_id,
                    "operator_id": operator_id,
                    "domain": domain,
                    "interactions": [item.model_dump() for item in items],
                }
            )
        else:
            raise ValidationError(
                "interactions must be a dict (BatchScoreRequest) or a list",
                field="interactions",
            )
        data = self._request("POST", "/v1/score/batch", json_body=req.model_dump())
        return BatchScoreResponse.model_validate(data)

    def operator_profile(self, operator_id: str) -> OperatorProfile:
        """Fetch ``GET /v1/operator/{operator_id}/profile``.

        ``operator_id`` is validated before it is placed in the URL path.
        """
        self._validate_operator_id(operator_id)
        data = self._request("GET", f"/v1/operator/{operator_id}/profile")
        return OperatorProfile.model_validate(data)

    def organization_dashboard(self) -> OrganizationDashboard:
        """Fetch ``GET /v1/organization/dashboard``."""
        data = self._request("GET", "/v1/organization/dashboard")
        return OrganizationDashboard.model_validate(data)

    def maintenance_recommend(self, **kwargs: Any) -> MaintenanceRecommendation:
        """Request ``POST /v1/maintenance/recommend``.

        ``operator_id`` is required and validated. All keyword arguments whose
        value is not ``None`` are sent as the JSON body.

        Raises:
            ValidationError: If ``operator_id`` is missing or malformed.
        """
        operator_id = kwargs.get("operator_id")
        if not operator_id or not isinstance(operator_id, str):
            raise ValidationError("operator_id is required", field="operator_id")
        self._validate_operator_id(operator_id)
        body = {k: v for k, v in kwargs.items() if v is not None}
        data = self._request("POST", "/v1/maintenance/recommend", json_body=body)
        return MaintenanceRecommendation.model_validate(data)

    def compliance_report(self) -> ComplianceReport:
        """Fetch ``GET /v1/compliance/report``."""
        data = self._request("GET", "/v1/compliance/report")
        return ComplianceReport.model_validate(data)

    # ───── Admin endpoints ───────────────────────────────────────────────

    def admin_key_rotate(self) -> KeyRotationResponse:
        """Call ``POST /v1/admin/key/rotate`` with an empty JSON body."""
        data = self._request("POST", "/v1/admin/key/rotate", json_body={})
        return KeyRotationResponse.model_validate(data)

    def admin_key_revoke(self) -> KeyRevocationResponse:
        """Call ``POST /v1/admin/key/revoke`` with an empty JSON body."""
        data = self._request("POST", "/v1/admin/key/revoke", json_body={})
        return KeyRevocationResponse.model_validate(data)

    def admin_audit_log(
        self,
        event_type: str | None = None,
        since: str | None = None,
        until: str | None = None,
        limit: int = 100,
    ) -> list[AuditEvent]:
        """Fetch audit events from ``GET /v1/admin/audit-log``.

        Args:
            event_type: Optional filter, sent as a query parameter when given.
            since: Optional lower bound, passed through verbatim.
            until: Optional upper bound, passed through verbatim.
            limit: Maximum number of events, 1-500 inclusive.

        Returns:
            The ``events`` list of the parsed :class:`AuditLogResponse`.

        Raises:
            ValidationError: If ``limit`` is outside 1-500.
        """
        if not 1 <= limit <= 500:
            raise ValidationError(
                "limit must be between 1 and 500",
                field="limit",
            )
        params: dict[str, Any] = {"limit": limit}
        if event_type is not None:
            params["event_type"] = event_type
        if since is not None:
            params["since"] = since
        if until is not None:
            params["until"] = until
        data = self._request("GET", "/v1/admin/audit-log", params=params)
        resp = AuditLogResponse.model_validate(data)
        return resp.events

    # ───── Internals ─────────────────────────────────────────────────────

    @staticmethod
    def _validate_operator_id(operator_id: str) -> None:
        """Reject operator ids that do not match ``_OPERATOR_ID_RE`` in full.

        ``fullmatch`` guarantees the entire string matches however the
        pattern is anchored. For example, ``re.match`` with a ``$`` anchor
        would accept a trailing newline that then ends up in a URL path.
        """
        if not isinstance(operator_id, str) or not _OPERATOR_ID_RE.fullmatch(operator_id):
            raise ValidationError(
                "operator_id must be 1-256 characters, alphanumeric plus - and _",
                field="operator_id",
            )

    def _build_headers(self, method: str, idempotency_key: str | None) -> dict[str, str]:
        """Build request headers.

        POSTs also get a JSON content type and an ``Idempotency-Key``: the
        caller's key, or a fresh UUIDv4 when none is given.
        """
        headers: dict[str, str] = {
            "X-API-Key": self._api_key,
            "User-Agent": self._user_agent,
            "Accept": "application/json",
        }
        if method == "POST":
            headers["Content-Type"] = "application/json"
            headers["Idempotency-Key"] = idempotency_key or str(uuid.uuid4())
        return headers

    def _network_error(self, exc: httpx.HTTPError, *, timed_out: bool) -> NetworkError:
        """Wrap an httpx transport exception in a sanitized :class:`NetworkError`."""
        if timed_out:
            detail = f"Request timed out after {self._timeout}s: {exc}"
        else:
            detail = f"Network error: {exc}"
        return NetworkError(sanitize_message(detail), next_step=self._NETWORK_NEXT_STEP)

    @staticmethod
    def _parse_envelope(response: httpx.Response) -> dict[str, Any] | None:
        """Return the error response body if it is a JSON object, else ``None``."""
        try:
            envelope = response.json()
        except ValueError:
            return None
        return envelope if isinstance(envelope, dict) else None

    def _request(
        self,
        method: str,
        path: str,
        *,
        json_body: Any | None = None,
        params: dict[str, Any] | None = None,
        idempotency_key: str | None = None,
    ) -> Any:
        """Send a request with retries and return the decoded JSON body.

        Transport errors and statuses accepted by ``is_retryable_status`` are
        retried up to ``self._max_retries`` times. Delays come from
        ``compute_backoff_ms``, which is given the error's ``retry_after``
        value when one is present.

        Raises:
            NetworkError: A transport failure on the final attempt.
            ServerError: A 2xx response whose body is not JSON.
            CohesionError: The error built by ``error_from_response`` for a
                non-2xx response that is not retried (or has exhausted retries).
        """
        url = f"{self._base_url}{path}"
        # Built once, outside the loop, so every retry of this logical request
        # reuses the same Idempotency-Key.
        headers = self._build_headers(method, idempotency_key)

        # total attempts = max_retries + 1 (max_retries >= 0 is enforced in __init__)
        for attempt in range(self._max_retries + 1):
            can_retry = attempt < self._max_retries
            try:
                self._logger.debug("HTTP %s %s", method, path, extra={"attempt": attempt})
                response = self._http.request(
                    method,
                    url,
                    headers=headers,
                    params=params,
                    json=json_body if method == "POST" else None,
                    timeout=self._timeout,
                )
            except httpx.HTTPError as exc:
                # httpx.TimeoutException is a subclass of httpx.HTTPError; it
                # only changes the message, not the retry handling.
                timed_out = isinstance(exc, httpx.TimeoutException)
                err = self._network_error(exc, timed_out=timed_out)
                if can_retry:
                    delay_ms = compute_backoff_ms(attempt, self._retry_policy)
                    self._logger.warning(
                        "Transport timeout. Sleeping %dms." if timed_out else "Transport error. Sleeping %dms.",
                        delay_ms,
                        extra={"attempt": attempt},
                    )
                    self._sleep(delay_ms / 1000.0)
                    continue
                report_sdk_error(err)
                raise err from exc

            if 200 <= response.status_code < 300:
                try:
                    return response.json()
                except ValueError as exc:
                    err = ServerError(
                        "Server returned non-JSON response.",
                        status_code=response.status_code,
                    )
                    report_sdk_error(err)
                    raise err from exc

            err = error_from_response(
                response.status_code,
                self._parse_envelope(response),
                response.headers.get("Retry-After"),
            )

            if can_retry and is_retryable_status(response.status_code, self._retry_policy):
                # Not every error type carries retry_after, and it may be None
                # when the server sent no Retry-After header. Only convert a
                # real value; otherwise let compute_backoff_ms use its own delay.
                retry_after = getattr(err, "retry_after", None)
                retry_after_ms = retry_after * 1000 if retry_after is not None else None
                delay_ms = compute_backoff_ms(attempt, self._retry_policy, retry_after_ms)
                self._logger.warning(
                    "Retryable %d. Sleeping %dms.",
                    response.status_code,
                    delay_ms,
                    extra={"attempt": attempt, "request_id": err.request_id},
                )
                self._sleep(delay_ms / 1000.0)
                continue

            # Every non-retried failure is reported and raised as-is; no type
            # dispatch is needed.
            report_sdk_error(err)
            raise err

        # Unreachable: the final iteration (can_retry is False) always returns
        # or raises.
        raise AssertionError("retry loop exited without returning or raising")