netgreener-sdk 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ Metadata-Version: 2.4
2
+ Name: netgreener_sdk
3
+ Version: 0.1.0
4
+ Summary: Thin Python client for the NetGreener HTTP API (auth, run sessions, surrogate training).
5
+ Author-email: Saeed Khazaee <saeed@netgreener.com>
6
+ Requires-Python: >=3.10
7
+ Requires-Dist: requests<3,>=2.28.0
8
+ Provides-Extra: dev
9
+ Requires-Dist: pytest>=7.0; extra == 'dev'
@@ -0,0 +1,82 @@
1
+ # netgreener_sdk
2
+
3
+ Thin Python client for the NetGreener HTTP API (`/api/v1`).
4
+
5
+ > **Full handbook (CLI + VS Code + SDK + CI):** [`../docs/PROGRAMMER_BOOKLET.md`](../docs/PROGRAMMER_BOOKLET.md). The SDK follows the same contracts as the `api_server` routers; its paths mirror `netgreener_cli/netgreener/api_client.py`, with additional coverage of the **`/features`** endpoints.
6
+
7
+ ## Install
8
+
9
+ From the repo (editable):
10
+
11
+ ```bash
12
+ pip install -e ./netgreener_sdk
13
+ ```
14
+
15
+ Requires Python **3.10+** and `requests`.
16
+
17
+ ## Quick start
18
+
19
+ ```python
20
+ from netgreener_sdk import NetGreenerClient, APIError
21
+
22
+ client = NetGreenerClient(base_url="http://localhost:8000/api/v1")
23
+ client.login(email="you@example.com", password="…")
24
+ me = client.get_me()
25
+ print(me["email"])
26
+
27
+ # Project feature metadata (Parquet / path pointer)
28
+ row = client.post_project_feature_metadata(
29
+ 1,
30
+ {"parquet_path": "/data/features.parquet", "feature_version": 1},
31
+ )
32
+ ```
33
+
34
+ ## Base URL
35
+
36
+ Pass the API root **including** `/api/v1`, e.g. `https://your-host/api/v1`.
37
+
38
+ ## Auth
39
+
40
+ - **`login(email, password)`** — sets the bearer token on the client; the response includes `user`, `access_token`, and `expires_in`.
41
+ - **`set_token(token)`** — use an existing JWT.
42
+ - **`get_access_token()`** — returns the current token, or `None`.
43
+ - **`get_me()`** — `GET /auth/me` (requires a valid token).
44
+
45
+ On the server, feature-metadata **writes** require **developer** or **admin** access; **GET** requests require access to the project.
46
+
47
+ ## See also
48
+
49
+ - `docs/CODEANALYZER_TO_API_SURFACE.md` — which endpoints exist and how they map to CodeAnalyzer outputs.
50
+ - `docs/PLAN_INDEX.md` — the full product/engineering plan checklist.
51
+
52
+ ## Tests
53
+
54
+ From repo root:
55
+
56
+ ```bash
57
+ python -m pytest netgreener_sdk/tests -q
58
+ ```
59
+
60
+ ## Publish (CI)
61
+
62
+ **Same PyPI token as the CLI:** the secret `PYPI_API_TOKEN` in the Azure Library group `Package_credential` must be allowed to upload both `netgreener` and `netgreener_sdk`.
63
+
64
+ **Different pipeline and tag** (not the CLI pipeline):
65
+
66
+ | Package | Pipeline | Git tag | PyPI name |
67
+ |---------|----------|---------|-----------|
68
+ | CLI | `azure-pipelines-cli-publish.yml` | `cli-v*.*.*` | `netgreener` |
69
+ | SDK | `azure-pipelines-sdk-publish.yml` | `sdk-v*.*.*` | `netgreener_sdk` |
70
+
71
+ Example SDK release: `git tag sdk-v0.1.0 && git push origin sdk-v0.1.0`
72
+
73
+ ## Pre-publish local build (recommended)
74
+
75
+ Before pushing an `sdk-v*` tag, verify that the wheel and sdist build cleanly:
76
+
77
+ ```bash
78
+ python -m pip install --upgrade pip
79
+ python -m pip install build twine
80
+ python -m build netgreener_sdk
81
+ python -m twine check netgreener_sdk/dist/*
82
+ ```
@@ -0,0 +1,5 @@
1
+ """Thin NetGreener API client (Phase 2 SDK skeleton)."""
2
+
3
+ from netgreener_sdk.client import APIError, NetGreenerClient
4
+
5
# Public names re-exported at package level.
__all__ = [
    "APIError",
    "NetGreenerClient",
]
@@ -0,0 +1,520 @@
1
+ """
2
+ Minimal HTTP client for NetGreener API v1.
3
+
4
+ Mirrors the paths used by the CLI (`netgreener.api_client`) without importing the CLI.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Mapping, MutableMapping, Optional
10
+ from urllib.parse import quote
11
+
12
+ import requests
13
+
14
+
15
class APIError(Exception):
    """
    Raised by :class:`NetGreenerClient` when the API answers with a non-2xx status.

    :ivar status_code: HTTP status code of the failed response.
    :ivar detail: Error detail reported for the failure, as a string.
    """

    def __init__(self, status_code: int, detail: str):
        self.status_code = status_code
        self.detail = detail
        super().__init__(f"HTTP {status_code}: {detail}")

    def __reduce__(self):
        # Default exception pickling replays ``self.args`` (only the formatted
        # message) into ``__init__``, which needs two arguments and would fail.
        # Rebuild from the real fields instead, and carry ``__dict__`` along so
        # extra attributes (e.g. notes) survive pickle / multiprocessing.
        return (type(self), (self.status_code, self.detail), self.__dict__)
20
+
21
+
22
class _AuthPreservingSession(requests.Session):
    """
    Session that preserves the HTTP method on redirects and keeps the
    ``Authorization`` header for same-host redirects (e.g. http→https).

    The bearer token is still dropped when a redirect points at a different
    hostname or downgrades from https to http. This keeps the token from being
    forwarded to a third party or sent over plaintext.
    """

    def rebuild_auth(self, prepared_request, response):
        # Local import: the module only imports ``quote`` from urllib.parse.
        from urllib.parse import urlparse

        previous = urlparse(response.request.url)
        target = urlparse(prepared_request.url)
        same_host = previous.hostname == target.hostname
        downgraded = previous.scheme == "https" and target.scheme != "https"
        if not same_host or downgraded:
            prepared_request.headers.pop("Authorization", None)

    def rebuild_method(self, _prepared_request, _response):
        # Intentionally a no-op: keep the original method (e.g. POST) on redirect.
        return
30
+
31
+
32
class NetGreenerClient:
    """
    Thin HTTP client for the NetGreener API v1.

    Every endpoint method sends one request through a shared session, raises
    :class:`APIError` for any non-2xx response, and returns the decoded JSON
    body (``{}`` for a ``204 No Content`` response).

    :param token: Optional bearer token, sent as ``Authorization: Bearer <token>``.
    :param base_url: API root including ``/api/v1`` (e.g. ``https://host/api/v1``);
        a trailing ``/`` is removed.
    :param org_id: Optional organization id, sent as ``X-Organization-ID``.
    :param environment: Optional environment name, sent as ``X-Environment``.

    The client can be used as a context manager, or closed explicitly with
    :meth:`close`, to release the session's connections.
    """

    def __init__(
        self,
        token: Optional[str] = None,
        base_url: str = "http://localhost:8000/api/v1",
        org_id: Optional[int] = None,
        environment: Optional[str] = None,
    ):
        self.base_url = base_url.rstrip("/")
        self._token = token
        self._org_id = org_id
        self._environment = environment
        self._session = _AuthPreservingSession()

    # --- Lifecycle ---

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def __enter__(self) -> "NetGreenerClient":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    # --- Request plumbing ---

    def _headers(self) -> Mapping[str, str]:
        """Build the JSON, auth and org-context headers from the current client state."""
        h: MutableMapping[str, str] = {"Content-Type": "application/json"}
        if self._token:
            h["Authorization"] = f"Bearer {self._token}"
        if self._org_id is not None:
            h["X-Organization-ID"] = str(self._org_id)
        if self._environment:
            h["X-Environment"] = self._environment
        return h

    def _raise(self, resp: requests.Response) -> None:
        """
        Raise :class:`APIError` unless ``resp`` is successful.

        The error detail is the JSON ``detail`` field when the body is a JSON
        object with a non-null ``detail``; otherwise it is the raw response text.
        Non-string details (e.g. lists of validation errors) are converted with ``str``.
        """
        if resp.ok:
            return
        detail: Any = resp.text
        try:
            body = resp.json()
        except ValueError:
            # Not JSON (requests' JSONDecodeError is a ValueError subclass).
            body = None
        if isinstance(body, dict) and body.get("detail") is not None:
            detail = body["detail"]
        if not isinstance(detail, str):
            detail = str(detail)
        raise APIError(resp.status_code, detail)

    def _send(
        self,
        method: str,
        path: str,
        *,
        timeout: float,
        body: Any = None,
        params: Optional[Mapping[str, Any]] = None,
        with_headers: bool = True,
    ) -> Any:
        """
        Send ``method`` to ``base_url + path`` and return the decoded JSON body.

        :param body: JSON payload; forwarded only when not ``None``.
        :param params: Query parameters; omitted when empty or ``None``.
        :param with_headers: When ``False``, no client headers are sent at all
            (used by ``login``).
        :returns: Decoded JSON, or ``{}`` for ``204 No Content``.
        :raises APIError: On a non-2xx response.
        """
        kwargs: dict[str, Any] = {"timeout": timeout}
        if with_headers:
            kwargs["headers"] = self._headers()
        if body is not None:
            kwargs["json"] = body
        if params:
            kwargs["params"] = params
        # Dispatch through the verb-named session method (get/post/put/delete).
        resp = getattr(self._session, method.lower())(f"{self.base_url}{path}", **kwargs)
        self._raise(resp)
        if resp.status_code == 204:
            # No body to decode; resp.json() would fail on an empty payload.
            return {}
        return resp.json()

    # --- Auth and request context ---

    def login(self, email: str, password: str) -> dict[str, Any]:
        """
        POST ``/auth/login`` and store the returned ``access_token`` on the client.

        Sent without client headers. Returns the decoded response body.
        """
        data = self._send(
            "POST",
            "/auth/login",
            body={"email": email, "password": password},
            timeout=30,
            with_headers=False,
        )
        self._token = data.get("access_token")
        return data

    def set_token(self, token: Optional[str]) -> None:
        """Use ``token`` as the bearer token (``None`` stops sending one)."""
        self._token = token

    def set_org_context(self, org_id: Optional[int], environment: Optional[str] = None) -> None:
        """
        Set org-scoped request context used by API authz and billing routes.

        Both values are replaced; omitting ``environment`` clears it.
        """
        self._org_id = org_id
        self._environment = environment

    def get_access_token(self) -> Optional[str]:
        """Bearer token set by ``login`` or ``set_token``."""
        return self._token

    def get_org_context(self) -> dict[str, Any]:
        """Current ``org_id`` (int or ``None``) and ``environment`` (str or ``None``)."""
        return {
            "org_id": self._org_id,
            "environment": self._environment,
        }

    def create_service_token(
        self,
        org_id: int,
        scopes: list[str],
        *,
        environment: str = "dev",
        expires_in_hours: int = 24 * 30,
    ) -> dict[str, Any]:
        """POST ``/auth/service-token`` for ``org_id`` with the given scopes."""
        payload = {
            "org_id": org_id,
            "environment": environment,
            "scopes": scopes,
            "expires_in_hours": expires_in_hours,
        }
        return self._send("POST", "/auth/service-token", body=payload, timeout=30)

    def get_me(self) -> dict[str, Any]:
        """GET ``/auth/me``: current user profile (requires a valid token)."""
        return self._send("GET", "/auth/me", timeout=30)

    # --- Run sessions / surrogate training ---

    def post_run_session(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/runsessions``; same contract as the CLI run upload."""
        return self._send("POST", "/runsessions", body=payload, timeout=60)

    def get_run_session(self, run_id: int) -> dict[str, Any]:
        """GET ``/runsessions/{run_id}``."""
        return self._send("GET", f"/runsessions/{run_id}", timeout=30)

    def get_run_sessions(
        self,
        *,
        project_id: Optional[int] = None,
    ) -> list[dict[str, Any]]:
        """GET ``/runsessions`` with an optional ``project_id`` query."""
        params: dict[str, int] = {}
        if project_id is not None:
            params["project_id"] = project_id
        return self._send("GET", "/runsessions", params=params, timeout=60)

    def post_surrogate_training(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/surrogate-training/``."""
        return self._send("POST", "/surrogate-training/", body=payload, timeout=60)

    # --- Features API (`/api/v1/features`); see `docs/CODEANALYZER_TO_API_SURFACE.md` ---

    def post_feature_collection_input(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/features/inputs``; requires an existing ``feature_name`` in ``featuredefinition``."""
        return self._send("POST", "/features/inputs", body=payload, timeout=60)

    def get_feature_collection_inputs(
        self,
        *,
        project_id: Optional[int] = None,
        feature_name: Optional[str] = None,
        entered_by_user_id: Optional[int] = None,
    ) -> list[dict[str, Any]]:
        """GET ``/features/inputs``; only the filters that are not ``None`` are sent."""
        filters = {
            "project_id": project_id,
            "feature_name": feature_name,
            "entered_by_user_id": entered_by_user_id,
        }
        params = {key: value for key, value in filters.items() if value is not None}
        return self._send("GET", "/features/inputs", params=params, timeout=60)

    def get_feature_collection_input(self, input_id: int) -> dict[str, Any]:
        """GET ``/features/inputs/{input_id}``."""
        return self._send("GET", f"/features/inputs/{input_id}", timeout=60)

    def update_feature_collection_input(
        self, input_id: int, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """PUT ``/features/inputs/{input_id}`` (developer/admin)."""
        return self._send("PUT", f"/features/inputs/{input_id}", body=payload, timeout=60)

    def delete_feature_collection_input(self, input_id: int) -> dict[str, Any]:
        """DELETE ``/features/inputs/{input_id}``."""
        return self._send("DELETE", f"/features/inputs/{input_id}", timeout=60)

    def list_project_feature_collection_inputs(self, project_id: int) -> list[dict[str, Any]]:
        """GET ``/features/project/{project_id}/inputs``."""
        return self._send("GET", f"/features/project/{project_id}/inputs", timeout=60)

    def list_feature_definitions(self) -> list[dict[str, Any]]:
        """GET ``/features/definitions``."""
        return self._send("GET", "/features/definitions", timeout=60)

    def post_feature_definition(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/features/definitions``."""
        return self._send("POST", "/features/definitions", body=payload, timeout=60)

    def get_feature_definition(self, feature_name: str) -> dict[str, Any]:
        """GET ``/features/definitions/{feature_name}`` (name is fully percent-encoded)."""
        safe_name = quote(str(feature_name), safe="")
        return self._send("GET", f"/features/definitions/{safe_name}", timeout=60)

    def update_feature_definition(
        self, feature_name: str, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """PUT ``/features/definitions/{feature_name}`` (name is fully percent-encoded)."""
        safe_name = quote(str(feature_name), safe="")
        return self._send(
            "PUT", f"/features/definitions/{safe_name}", body=payload, timeout=60
        )

    def delete_feature_definition(self, feature_name: str) -> dict[str, Any]:
        """DELETE ``/features/definitions/{feature_name}`` (name is fully percent-encoded)."""
        safe_name = quote(str(feature_name), safe="")
        return self._send("DELETE", f"/features/definitions/{safe_name}", timeout=60)

    # --- Project feature metadata (Parquet / on-disk path) ---

    def get_project_feature_metadata(self, project_id: int) -> dict[str, Any]:
        """GET ``/features/project/{project_id}/feature-metadata``."""
        return self._send(
            "GET", f"/features/project/{project_id}/feature-metadata", timeout=60
        )

    def post_project_feature_metadata(
        self, project_id: int, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """POST ``/features/project/{project_id}/feature-metadata`` (body ``project_id`` is forced to match the path)."""
        data = {**payload, "project_id": project_id}
        return self._send(
            "POST", f"/features/project/{project_id}/feature-metadata", body=data, timeout=60
        )

    def put_project_feature_metadata(
        self, project_id: int, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """PUT ``/features/project/{project_id}/feature-metadata``."""
        return self._send(
            "PUT", f"/features/project/{project_id}/feature-metadata", body=payload, timeout=60
        )

    def delete_project_feature_metadata(self, project_id: int) -> dict[str, Any]:
        """DELETE ``/features/project/{project_id}/feature-metadata``."""
        return self._send(
            "DELETE", f"/features/project/{project_id}/feature-metadata", timeout=60
        )

    # --- Projects ---

    def get_projects(self, *, is_active: Optional[bool] = None) -> list[dict[str, Any]]:
        """GET ``/projects`` with an optional ``is_active`` filter (``False`` is sent too)."""
        params: dict[str, Any] = {}
        if is_active is not None:
            params["is_active"] = is_active
        return self._send("GET", "/projects", params=params, timeout=30)

    # --- Credit quote ---

    def quote_credits(self, org_id: int, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/billing/organizations/{org_id}/credit-quote``."""
        return self._send(
            "POST", f"/billing/organizations/{org_id}/credit-quote", body=payload, timeout=30
        )

    # --- KPI estimation (Phase 7) ---

    def get_project_estimate_feature_status(self, project_id: int) -> dict[str, Any]:
        """GET ``/runs/projects/{project_id}/estimate-feature-status``."""
        return self._send(
            "GET", f"/runs/projects/{project_id}/estimate-feature-status", timeout=30
        )

    def estimate_run(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/runs/estimate``: general estimation from explicit ``feature_values``."""
        return self._send("POST", "/runs/estimate", body=payload, timeout=60)

    def estimate_project(self, project_id: int, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``/runs/projects/{project_id}/estimate``: feature values resolved server-side."""
        return self._send(
            "POST", f"/runs/projects/{project_id}/estimate", body=payload, timeout=60
        )

    def get_run_estimate(self, run_id: int) -> dict[str, Any]:
        """GET ``/runs/{run_id}/estimate``: fetch the persisted estimate for a run."""
        return self._send("GET", f"/runs/{run_id}/estimate", timeout=30)

    # --- Analyze / findings (CLI parity) ---

    def analyze_project_files(
        self, project_id: int, per_file: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        POST ``/analyze/project`` and return its findings (persist them with ``post_findings_batch``).

        Entries of ``per_file`` that are not dicts with both ``rel_path`` and
        ``features`` are skipped. If no entries remain, no request is sent and
        ``[]`` is returned. A response without a ``findings`` list also yields ``[]``.
        """
        files_payload = [
            {"path": entry["rel_path"], "features": entry["features"]}
            for entry in per_file
            if isinstance(entry, dict) and "rel_path" in entry and "features" in entry
        ]
        if not files_payload:
            return []
        data = self._send(
            "POST",
            "/analyze/project",
            body={"project_id": project_id, "files": files_payload},
            timeout=60,
        )
        findings = data.get("findings", []) if isinstance(data, dict) else []
        return findings if isinstance(findings, list) else []

    def post_findings_batch(self, findings: list[dict[str, Any]]) -> dict[str, Any]:
        """POST ``/findings/batch``."""
        return self._send("POST", "/findings/batch", body={"findings": findings}, timeout=30)

    def get_project_findings(self, project_id: int) -> list[dict[str, Any]]:
        """GET ``/findings/project/{project_id}``; a non-list response yields ``[]``."""
        data = self._send("GET", f"/findings/project/{project_id}", timeout=30)
        return data if isinstance(data, list) else []

    def get_surrogate_training_for_project(self, project_id: int) -> list[dict[str, Any]]:
        """GET ``/surrogate-training/project/{project_id}``; a non-list response yields ``[]``."""
        data = self._send("GET", f"/surrogate-training/project/{project_id}", timeout=30)
        return data if isinstance(data, list) else []

    def submit_finding_feedback(
        self,
        finding_id: int,
        *,
        verdict: str,
        note: Optional[str] = None,
        severity: Optional[str] = None,
    ) -> dict[str, Any]:
        """POST ``/findings/{finding_id}/feedback``; empty ``note``/``severity`` are omitted."""
        payload: dict[str, Any] = {"verdict": verdict}
        if note:
            payload["note"] = note
        if severity:
            payload["severity"] = severity
        return self._send(
            "POST", f"/findings/{finding_id}/feedback", body=payload, timeout=15
        )
@@ -0,0 +1,16 @@
1
[project]
name = "netgreener_sdk"
version = "0.1.0"
description = "Thin Python client for the NetGreener HTTP API (auth, run sessions, surrogate training)."
# Ship the README as the package long description (shown on PyPI, checked by `twine check`).
readme = "README.md"
authors = [{ name = "Saeed Khazaee", email = "saeed@netgreener.com" }]
requires-python = ">=3.10"
dependencies = [
    "requests>=2.28.0,<3",
]

[project.optional-dependencies]
dev = ["pytest>=7.0"]

[build-system]
requires = ["hatchling>=1.21.0"]
build-backend = "hatchling.build"
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+
7
# When pytest is launched from the monorepo root
# (`python -m pytest netgreener_sdk/tests -q`), make sure the parent of this
# `tests` directory is on sys.path so `netgreener_sdk` imports resolve.
# The membership check keeps the insertion idempotent.
repo_pkg_root = Path(__file__).resolve().parents[1]
_root_entry = str(repo_pkg_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
@@ -0,0 +1,332 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import pytest
6
+
7
+ from netgreener_sdk.client import APIError, NetGreenerClient
8
+
9
+
10
class DummyResponse:
    """Minimal stand-in for ``requests.Response`` used by the client under test."""

    def __init__(self, status_code: int, payload: Any = None, text: str = "") -> None:
        self.status_code, self._payload, self.text = status_code, payload, text

    @property
    def ok(self) -> bool:
        # Success means a 2xx status code.
        return 300 > self.status_code >= 200

    def json(self):
        # An Exception payload simulates a body that cannot be decoded.
        stored = self._payload
        if not isinstance(stored, Exception):
            return stored
        raise stored
24
+
25
+
26
class DummySession:
    """Fake HTTP session: records every call and replays queued responses first-in, first-out."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, str, dict[str, Any]]] = []
        self.responses: list[DummyResponse] = []

    def enqueue(self, response: DummyResponse) -> None:
        self.responses += [response]

    def _pop(self, method: str, url: str, **kwargs):
        # Record the call before checking the queue, so failing tests still show it.
        self.calls.append((method, url, kwargs))
        assert self.responses, "No queued response for request"
        head = self.responses[0]
        del self.responses[0]
        return head

    def get(self, url, **kwargs):
        return self._pop("GET", url, **kwargs)

    def post(self, url, **kwargs):
        return self._pop("POST", url, **kwargs)

    def put(self, url, **kwargs):
        return self._pop("PUT", url, **kwargs)

    def delete(self, url, **kwargs):
        return self._pop("DELETE", url, **kwargs)
50
+
51
+
52
def _client_and_session() -> tuple[NetGreenerClient, DummySession]:
    """Return a local-API client whose HTTP session is swapped for a ``DummySession``."""
    fake_session = DummySession()
    api_client = NetGreenerClient(base_url="http://localhost:8000/api/v1")
    api_client._session = fake_session  # type: ignore[attr-defined]
    return api_client, fake_session
57
+
58
+
59
def test_login_sets_token():
    """``login`` POSTs credentials to /auth/login and stores the returned token."""
    api, fake = _client_and_session()
    login_body = {"access_token": "abc", "user": {"user_id": 1}, "token_type": "bearer"}
    fake.enqueue(DummyResponse(200, login_body))

    result = api.login("a@b.com", "pw")

    assert result["access_token"] == "abc"
    assert api.get_access_token() == "abc"
    verb, target, sent = fake.calls[0]
    assert verb == "POST"
    assert target.endswith("/auth/login")
    assert sent["json"]["email"] == "a@b.com"
76
+
77
+
78
def test_get_me_uses_auth_header():
    """``get_me`` GETs /auth/me and sends the stored token as a bearer header."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(200, {"user_id": 7, "email": "x@y.com"}))

    profile = api.get_me()

    assert profile["user_id"] == 7
    verb, target, sent = fake.calls[0]
    assert verb == "GET"
    assert target.endswith("/auth/me")
    assert sent["headers"]["Authorization"] == "Bearer tok"
90
+
91
+
92
def test_quote_credits_uses_billing_quote_path():
    """``quote_credits`` POSTs the payload to the organization's credit-quote endpoint."""
    api, fake = _client_and_session()
    fake.enqueue(DummyResponse(200, {"action": "run_estimate", "total_credits": 12}))

    quoted = api.quote_credits(3, {"action": "run_estimate", "reported_input_tokens": 1000})

    assert quoted["total_credits"] == 12
    verb, target, sent = fake.calls[0]
    assert verb == "POST"
    assert target.endswith("/billing/organizations/3/credit-quote")
    assert sent["json"]["action"] == "run_estimate"
103
+
104
+
105
def test_feature_definitions_crud_paths_and_encoding():
    """Feature-definition CRUD uses the right verb and path, and percent-encodes ``/`` in names."""
    client, session = _client_and_session()
    client.set_token("tok")
    session.enqueue(DummyResponse(200, [{"feature_name": "f1"}]))  # list
    session.enqueue(DummyResponse(200, {"feature_name": "f/a"}))  # create
    session.enqueue(DummyResponse(200, {"feature_name": "f/a"}))  # get
    session.enqueue(DummyResponse(200, {"feature_name": "f/a"}))  # update
    session.enqueue(DummyResponse(200, {"success": True}))  # delete

    client.list_feature_definitions()
    client.post_feature_definition({"feature_name": "f/a"})
    client.get_feature_definition("f/a")
    client.update_feature_definition("f/a", {"description": "x"})
    client.delete_feature_definition("f/a")

    # Check the HTTP verb as well as the path: a CRUD test that ignores the
    # method would not catch e.g. update being sent as POST.
    expected = [
        ("GET", "/features/definitions"),
        ("POST", "/features/definitions"),
        ("GET", "/features/definitions/f%2Fa"),
        ("PUT", "/features/definitions/f%2Fa"),
        ("DELETE", "/features/definitions/f%2Fa"),
    ]
    assert len(session.calls) == len(expected)
    for (method, url, _), (want_method, want_suffix) in zip(session.calls, expected):
        assert method == want_method
        assert url.endswith(want_suffix)
125
+
126
+
127
def test_project_feature_metadata_post_injects_project_id():
    """The path ``project_id`` is copied into the POSTed metadata body."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(200, {"project_id": 2}))

    api.post_project_feature_metadata(2, {"parquet_path": "/tmp/f.parquet"})

    verb, target, sent = fake.calls[0]
    body = sent["json"]
    assert verb == "POST"
    assert target.endswith("/features/project/2/feature-metadata")
    assert body["project_id"] == 2
    assert body["parquet_path"] == "/tmp/f.parquet"
139
+
140
+
141
def test_api_error_uses_detail_from_json():
    """A JSON ``detail`` field is surfaced as ``APIError.detail``."""
    api, fake = _client_and_session()
    fake.enqueue(DummyResponse(404, {"detail": "Feature not found"}))

    with pytest.raises(APIError) as caught:
        api.get_feature_definition("missing")

    error = caught.value
    assert error.status_code == 404
    assert "Feature not found" in error.detail
149
+
150
+
151
def test_api_error_falls_back_to_text_when_json_unavailable():
    """If the error body is not JSON, the raw response text becomes the detail."""
    api, fake = _client_and_session()
    undecodable = DummyResponse(500, payload=ValueError("bad json"), text="server exploded")
    fake.enqueue(undecodable)

    with pytest.raises(APIError) as caught:
        api.get_feature_definition("anything")

    error = caught.value
    assert error.status_code == 500
    assert "server exploded" in error.detail
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Run sessions
163
+ # ---------------------------------------------------------------------------
164
+
165
+
166
def test_get_run_session_uses_correct_path_and_token():
    """``get_run_session`` GETs /runsessions/{id} with the bearer header."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(200, {"run_id": 42, "project_id": 1}))

    run = api.get_run_session(42)

    assert run["run_id"] == 42
    verb, target, sent = fake.calls[0]
    assert verb == "GET"
    assert target.endswith("/runsessions/42")
    assert sent["headers"]["Authorization"] == "Bearer tok"
178
+
179
+
180
def test_get_run_sessions_no_filter():
    """Without filters, ``get_run_sessions`` sends no query parameters."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(200, [{"run_id": 1}, {"run_id": 2}]))

    runs = api.get_run_sessions()

    assert len(runs) == 2
    verb, target, sent = fake.calls[0]
    assert verb == "GET"
    assert target.endswith("/runsessions")
    assert sent.get("params") is None
192
+
193
+
194
def test_get_run_sessions_project_filter():
    """A ``project_id`` filter is forwarded as a query parameter."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(200, [{"run_id": 5}]))

    api.get_run_sessions(project_id=7)

    sent = fake.calls[0][2]
    assert sent["params"]["project_id"] == 7
203
+
204
+
205
def test_post_run_session_sends_payload():
    """``post_run_session`` POSTs the given payload as JSON to /runsessions."""
    api, fake = _client_and_session()
    api.set_token("tok")
    run_payload = {
        "project_id": 3,
        "start_time": "2026-01-01T00:00:00Z",
        "end_time": "2026-01-01T01:00:00Z",
    }
    fake.enqueue(DummyResponse(200, {"run_id": 99}))

    created = api.post_run_session(run_payload)

    assert created["run_id"] == 99
    verb, target, sent = fake.calls[0]
    assert verb == "POST"
    assert target.endswith("/runsessions")
    assert sent["json"]["project_id"] == 3
218
+
219
+
220
+ # ---------------------------------------------------------------------------
221
+ # Projects
222
+ # ---------------------------------------------------------------------------
223
+
224
+
225
def test_get_projects_no_filter():
    """Without ``is_active``, ``get_projects`` sends no query parameters."""
    api, fake = _client_and_session()
    api.set_token("tok")
    listing = [{"project_id": 1, "name": "P1"}, {"project_id": 2, "name": "P2"}]
    fake.enqueue(DummyResponse(200, listing))

    projects = api.get_projects()

    assert len(projects) == 2
    verb, target, sent = fake.calls[0]
    assert verb == "GET"
    assert target.endswith("/projects")
    assert sent.get("params") is None
237
+
238
+
239
def test_get_projects_active_filter():
    """``is_active`` is forwarded as a boolean query parameter."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(200, [{"project_id": 1}]))

    api.get_projects(is_active=True)

    sent = fake.calls[0][2]
    assert sent["params"]["is_active"] is True
248
+
249
+
250
+ # ---------------------------------------------------------------------------
251
+ # KPI estimation (Phase 7)
252
+ # ---------------------------------------------------------------------------
253
+
254
+
255
def test_get_project_estimate_feature_status():
    """The estimate-feature-status endpoint is called with the bearer header."""
    api, fake = _client_and_session()
    api.set_token("tok")
    status_body = {
        "project_id": 3,
        "feature_count": 12,
        "feature_source": "surrogate_training_data",
        "feature_keys": ["k1", "k2"],
    }
    fake.enqueue(DummyResponse(200, status_body))

    status = api.get_project_estimate_feature_status(3)

    assert status["feature_count"] == 12
    verb, target, sent = fake.calls[0]
    assert verb == "GET"
    assert target.endswith("/runs/projects/3/estimate-feature-status")
    assert sent["headers"]["Authorization"] == "Bearer tok"
267
+
268
+
269
def test_estimate_run_posts_feature_values():
    """``estimate_run`` POSTs explicit feature values to /runs/estimate."""
    api, fake = _client_and_session()
    api.set_token("tok")
    request_body = {"feature_values": {"num_layers": 4}, "use_llm": False}
    stub_estimate = {
        "estimates": {"energy_kwh": {"low": 0.1, "high": 0.5}},
        "confidence": "low",
        "estimate_mode": "stub",
    }
    fake.enqueue(DummyResponse(200, stub_estimate))

    estimate = api.estimate_run(request_body)

    assert "estimates" in estimate
    verb, target, sent = fake.calls[0]
    assert verb == "POST"
    assert target.endswith("/runs/estimate")
    assert sent["json"]["use_llm"] is False
282
+
283
+
284
def test_estimate_project_uses_project_path():
    """``estimate_project`` POSTs to the project-scoped estimate endpoint."""
    api, fake = _client_and_session()
    api.set_token("tok")
    llm_estimate = {"estimates": {}, "confidence": "medium", "estimate_mode": "llm", "model": "gpt-4o"}
    fake.enqueue(DummyResponse(200, llm_estimate))

    estimate = api.estimate_project(5, {"use_llm": True, "feature_values": {}})

    assert estimate["confidence"] == "medium"
    verb, target, sent = fake.calls[0]
    assert verb == "POST"
    assert target.endswith("/runs/projects/5/estimate")
    assert sent["json"]["use_llm"] is True
296
+
297
+
298
def test_get_run_estimate_uses_correct_path():
    """``get_run_estimate`` GETs /runs/{run_id}/estimate."""
    api, fake = _client_and_session()
    api.set_token("tok")
    stored = {"run_id": 7, "project_id": 2, "run_estimate_v0": {"confidence": "high"}}
    fake.enqueue(DummyResponse(200, stored))

    estimate = api.get_run_estimate(7)

    assert estimate["run_id"] == 7
    verb, target = fake.calls[0][:2]
    assert verb == "GET"
    assert target.endswith("/runs/7/estimate")
309
+
310
+
311
def test_estimate_project_raises_on_422():
    """A 422 from the estimate endpoint surfaces as ``APIError`` carrying the server detail."""
    api, fake = _client_and_session()
    api.set_token("tok")
    fake.enqueue(DummyResponse(422, {"detail": "No extracted feature values found"}))

    with pytest.raises(APIError) as caught:
        api.estimate_project(9, {"feature_values": {}, "use_llm": True})

    error = caught.value
    assert error.status_code == 422
    assert "No extracted feature values found" in error.detail
320
+
321
+
322
def test_org_context_headers_forwarded():
    """Org id and environment set via ``set_org_context`` are sent as headers."""
    api, fake = _client_and_session()
    api.set_token("tok")
    api.set_org_context(org_id=42, environment="prod")
    fake.enqueue(DummyResponse(200, {}))

    api.get_projects()

    sent_headers = fake.calls[0][2]["headers"]
    assert sent_headers["X-Organization-ID"] == "42"
    assert sent_headers["X-Environment"] == "prod"