data-hub-watcher 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,237 @@
1
+ from __future__ import annotations
2
+ import logging
3
+ import os
4
+ from datetime import datetime, timezone
5
+ from typing import Any
6
+
7
+ import requests
8
+
9
+ from data_hub_watcher.models import (
10
+ ApiErrorDetail,
11
+ ConfigChecksumResponse,
12
+ EventsResponse,
13
+ FileResponse,
14
+ HeartbeatResponse,
15
+ InstrumentDetailResponse,
16
+ InstrumentResponse,
17
+ PresignedUploadResponse,
18
+ RegisterWatcherResponse,
19
+ RunDetailResponse,
20
+ RunResponse,
21
+ UploadQueueResponse,
22
+ WatcherUpdateInfoResponse,
23
+ )
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ApiError(Exception):
    """Error raised for a failed Data Hub API call.

    Raised when the API answers with a non-2xx status, and also used by the
    client to wrap transport failures (in which case ``status_code`` keeps
    its default of ``0`` because no HTTP response was received).

    Attributes:
        message: Human-readable description; also the exception's only
            positional ``args`` entry, so ``str(exc)`` yields it.
        status_code: HTTP status of the failed response, ``0`` by default.
        detail: Parsed structured error body, or ``None`` when absent.
    """

    def __init__(
        self,
        message: str,
        status_code: int = 0,
        detail: ApiErrorDetail | None = None,
    ) -> None:
        # Attribute assignment and base-class initialisation are independent
        # of each other; the base class only receives the message.
        self.detail = detail
        self.status_code = status_code
        self.message = message
        super().__init__(message)
41
+
42
+
43
DEFAULT_TIMEOUT: tuple[float, float] = (5, 30)  # (connect, read) seconds


class DataHubClient:
    """HTTP client for the Data Hub API.

    Every endpoint method sends one request through a shared
    ``requests.Session`` and validates the JSON response into the matching
    response model. Failures surface as ``ApiError``.

    The client can be used as a context manager so the underlying session
    (and its pooled connections) is released deterministically::

        with DataHubClient(base_url) as client:
            client.list_instruments()
    """

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        timeout: tuple[float, float] = DEFAULT_TIMEOUT,
    ) -> None:
        """Create a client bound to *base_url*.

        Args:
            base_url: Root URL of the API; trailing slashes are stripped.
            api_key: Bearer token. When falsy, the ``DATA_HUB_API_KEY``
                environment variable is used instead; if that is also empty
                no ``Authorization`` header is sent.
            timeout: ``(connect, read)`` timeout in seconds applied to every
                request.
        """
        self.base_url = base_url.rstrip("/")
        self._timeout = timeout
        # A persistent session reuses TCP connections across requests, which
        # matters when the watcher is long-running and chatting with the API
        # every heartbeat interval.
        self._session = requests.Session()

        # Allow the API key to be passed explicitly (e.g. during `init`) or
        # fall back to the environment variable for normal operation.
        key = api_key or os.environ.get("DATA_HUB_API_KEY", "")
        if key:
            self._session.headers["Authorization"] = f"Bearer {key}"

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self._session.close()

    def __enter__(self) -> DataHubClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Returns None, so exceptions raised inside the ``with`` body
        # propagate unchanged.
        self.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _url(self, path: str) -> str:
        """Join *path* (expected to start with ``/``) onto the base URL."""
        return f"{self.base_url}{path}"

    def _handle_error(self, resp: requests.Response) -> None:
        """Parse an error body and raise `ApiError`.

        If the JSON body has an ``"error"`` entry it is validated into an
        ``ApiErrorDetail`` and its message is used; otherwise the raw
        response text becomes the message. Always raises.
        """
        detail: ApiErrorDetail | None = None
        try:
            body = resp.json()
            if "error" in body:
                detail = ApiErrorDetail.model_validate(body["error"])
                msg = detail.message
            else:
                msg = resp.text
        except Exception:
            # Deliberately broad: a non-JSON body, an unexpected JSON shape
            # or a validation failure must never mask the original HTTP
            # error, so fall back to the raw text.
            msg = resp.text
        raise ApiError(msg, status_code=resp.status_code, detail=detail)

    def _request(
        self,
        method: str,
        path: str,
        *,
        json: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
    ) -> requests.Response:
        """Send one request and return the response if it is 2xx/3xx-ok.

        Raises:
            ApiError: On a non-ok response (with ``status_code`` set), or on
                any transport-level failure raised by ``requests`` (with
                ``status_code`` left at its default).
        """
        try:
            resp = self._session.request(
                method, self._url(path), json=json, params=params, timeout=self._timeout
            )
        except requests.Timeout as exc:
            # Checked before ConnectionError: ConnectTimeout inherits from
            # both, and "timed out" is the more precise description.
            raise ApiError(f"Request timed out: {exc}") from exc
        except requests.ConnectionError as exc:
            raise ApiError(f"Connection error: {exc}") from exc
        except requests.RequestException as exc:
            # Anything else requests can raise (redirect loops, invalid URL,
            # broken chunked encoding, ...) is still a failed API call, so
            # callers only ever need to catch ApiError.
            raise ApiError(f"Request failed: {exc}") from exc

        if not resp.ok:
            self._handle_error(resp)
        return resp

    @staticmethod
    def _build_upload_payload(
        filename: str,
        content_type: str | None,
        size_bytes: int | None,
        file_created_at_ts: float | None,
    ) -> dict[str, Any]:
        """Build the JSON body for a presigned-upload request.

        Optional fields are omitted when not supplied. ``size_bytes`` and
        ``file_created_at_ts`` are compared against ``None`` so that the
        legitimate values ``0`` and ``0.0`` (the epoch) are still sent.
        The timestamp is serialised as a UTC ISO-8601 string.
        """
        payload: dict[str, Any] = {"filename": filename}
        if content_type:
            payload["content_type"] = content_type
        if size_bytes is not None:
            payload["size_bytes"] = size_bytes
        if file_created_at_ts is not None:
            payload["file_created_at"] = datetime.fromtimestamp(
                file_created_at_ts, tz=timezone.utc
            ).isoformat()
        return payload

    # ------------------------------------------------------------------
    # Instruments
    # ------------------------------------------------------------------

    def list_instruments(self) -> list[InstrumentResponse]:
        """GET ``/instruments`` and validate each item of the JSON array."""
        resp = self._request("GET", "/instruments")
        return [InstrumentResponse.model_validate(item) for item in resp.json()]

    # NOTE: the parameter name ``id`` shadows the builtin but is part of the
    # public signature, so it is kept for caller compatibility.
    def create_instrument(self, id: str, display_name: str | None = None) -> InstrumentResponse:
        """POST ``/instruments``; ``display_name`` is sent only when truthy."""
        payload: dict[str, Any] = {"id": id}
        if display_name:
            payload["display_name"] = display_name
        resp = self._request("POST", "/instruments", json=payload)
        return InstrumentResponse.model_validate(resp.json())

    def get_instrument(self, instrument_id: str) -> InstrumentDetailResponse:
        """GET ``/instruments/{instrument_id}``."""
        resp = self._request("GET", f"/instruments/{instrument_id}")
        return InstrumentDetailResponse.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Watchers
    # ------------------------------------------------------------------

    def register_watcher(
        self,
        instrument_id: str,
        hostname: str | None = None,
        os_info: str | None = None,
    ) -> RegisterWatcherResponse:
        """POST ``/watchers/register``.

        ``hostname`` and ``os_info`` are included only when truthy.
        """
        payload: dict[str, Any] = {"instrument_id": instrument_id}
        if hostname:
            payload["hostname"] = hostname
        if os_info:
            payload["os_info"] = os_info
        resp = self._request("POST", "/watchers/register", json=payload)
        return RegisterWatcherResponse.model_validate(resp.json())

    def push_config(
        self, watcher_id: str, config_yaml: str, checksum: str
    ) -> ConfigChecksumResponse:
        """PUT the watcher's config YAML together with its checksum."""
        resp = self._request(
            "PUT",
            f"/watchers/{watcher_id}/config",
            json={"config_yaml": config_yaml, "config_checksum": checksum},
        )
        return ConfigChecksumResponse.model_validate(resp.json())

    def get_config_checksum(self, watcher_id: str) -> ConfigChecksumResponse | None:
        """Return the remote checksum, or `None` if no config has been pushed.

        A 404 is expected for newly registered watchers that haven't pushed
        config yet — it is not an error condition. Any other ``ApiError`` is
        re-raised.
        """
        try:
            resp = self._request("GET", f"/watchers/{watcher_id}/config-checksum")
            return ConfigChecksumResponse.model_validate(resp.json())
        except ApiError as exc:
            if exc.status_code == 404:
                return None
            raise

    def send_heartbeat(self, watcher_id: str, payload: dict[str, Any]) -> HeartbeatResponse:
        """POST *payload* as-is to ``/watchers/{watcher_id}/heartbeat``."""
        resp = self._request("POST", f"/watchers/{watcher_id}/heartbeat", json=payload)
        return HeartbeatResponse.model_validate(resp.json())

    def send_events(self, watcher_id: str, events: list[dict[str, Any]]) -> EventsResponse:
        """POST *events* wrapped as ``{"events": [...]}``."""
        resp = self._request("POST", f"/watchers/{watcher_id}/events", json={"events": events})
        return EventsResponse.model_validate(resp.json())

    def get_update_info(self, watcher_id: str) -> WatcherUpdateInfoResponse:
        """Fetch server-reported watcher release metadata.

        Used by `self-update` and the in-process updater to decide whether
        the running watcher should upgrade itself.
        """
        resp = self._request("GET", f"/watchers/{watcher_id}/update-check")
        return WatcherUpdateInfoResponse.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Runs
    # ------------------------------------------------------------------

    def report_run(self, instrument_id: str, run_data: dict[str, Any]) -> RunResponse:
        """POST *run_data* as-is to ``/instruments/{instrument_id}/runs``."""
        resp = self._request("POST", f"/instruments/{instrument_id}/runs", json=run_data)
        return RunResponse.model_validate(resp.json())

    def update_run(
        self, instrument_id: str, run_id: str, data: dict[str, Any]
    ) -> RunDetailResponse:
        """PATCH ``/instruments/{instrument_id}/runs/{run_id}`` with *data*."""
        resp = self._request("PATCH", f"/instruments/{instrument_id}/runs/{run_id}", json=data)
        return RunDetailResponse.model_validate(resp.json())

    # ------------------------------------------------------------------
    # Upload queue / files
    # ------------------------------------------------------------------

    def get_upload_queue(self, watcher_id: str) -> UploadQueueResponse:
        """GET ``/watchers/{watcher_id}/upload-queue``."""
        resp = self._request("GET", f"/watchers/{watcher_id}/upload-queue")
        return UploadQueueResponse.model_validate(resp.json())

    def request_upload_url(
        self,
        instrument_id: str,
        run_id: str,
        filename: str,
        content_type: str | None = None,
        size_bytes: int | None = None,
        file_created_at_ts: float | None = None,
    ) -> PresignedUploadResponse:
        """Request a presigned upload URL for one file of a run.

        Args:
            instrument_id: Instrument owning the run.
            run_id: Run the file belongs to.
            filename: Name of the file to upload (always sent).
            content_type: Sent only when truthy.
            size_bytes: Sent whenever it is not ``None`` (``0`` included).
            file_created_at_ts: POSIX timestamp; sent as a UTC ISO-8601
                string whenever it is not ``None`` (``0.0`` included).
        """
        payload = self._build_upload_payload(
            filename, content_type, size_bytes, file_created_at_ts
        )
        resp = self._request(
            "POST",
            f"/instruments/{instrument_id}/runs/{run_id}/request-upload-url",
            json=payload,
        )
        return PresignedUploadResponse.model_validate(resp.json())

    def mark_file_uploaded(self, file_id: int, s3_info: dict[str, Any]) -> FileResponse:
        """PATCH ``/files/{file_id}`` with *s3_info* as the JSON body."""
        resp = self._request("PATCH", f"/files/{file_id}", json=s3_info)
        return FileResponse.model_validate(resp.json())