juniper-recurrence-client 0.1.0__tar.gz

@@ -0,0 +1,104 @@
1
+ Metadata-Version: 2.4
2
+ Name: juniper-recurrence-client
3
+ Version: 0.1.0
4
+ Summary: HTTP client for the juniper-recurrence service (the Δt-native LMU recurrence model + cross-validation API)
5
+ Author: Paul Calnon
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/pcalnon/juniper-recurrence
8
+ Project-URL: Repository, https://github.com/pcalnon/juniper-recurrence
9
+ Project-URL: Issues, https://github.com/pcalnon/juniper-recurrence/issues
10
+ Keywords: juniper,recurrence,lmu,http-client,rest,time-series,cross-validation
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Programming Language :: Python :: 3.14
20
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
22
+ Requires-Python: >=3.12
23
+ Description-Content-Type: text/markdown
24
+ Requires-Dist: requests>=2.28.0
25
+ Requires-Dist: urllib3>=2.0.0
26
+ Provides-Extra: test
27
+ Requires-Dist: pytest>=8.0; extra == "test"
28
+ Requires-Dist: pytest-cov>=5.0; extra == "test"
29
+ Requires-Dist: responses>=0.23; extra == "test"
30
+ Provides-Extra: observability
31
+ Requires-Dist: juniper-observability>=0.3.1; extra == "observability"
32
+
33
+ # juniper-recurrence-client
34
+
35
+ HTTP client for the **juniper-recurrence** service — the FastAPI app wrapping the Δt-native LMU
36
+ recurrence model and its cross-validation API. A lean `requests`-based client mirroring
37
+ [`juniper-data-client`](https://github.com/pcalnon/juniper-data-client) and
38
+ [`juniper-cascor-client`](https://github.com/pcalnon/juniper-cascor-client), so consumers
39
+ (notably **juniper-canopy**'s recurrence backend adapter) drive every Juniper backend the same way.
40
+
41
+ ## Install
42
+
43
+ ```bash
44
+ pip install juniper-recurrence-client # once published
45
+ pip install -e ".[test]" # local development
46
+ ```
47
+
48
+ The core depends only on `requests` and `urllib3`; `pip install "juniper-recurrence-client[observability]"` adds the
49
+ optional `juniper-observability` integration (X-Request-ID propagation and the `on_request` hook).
50
+
51
+ ## Quick start
52
+
53
+ ```python
54
+ from juniper_recurrence_client import JuniperRecurrenceClient
55
+
56
+ client = JuniperRecurrenceClient("http://localhost:8211", api_key="…")
57
+
58
+ # Train the LMU regressor on a dataset (by id / name / generator)
59
+ client.train(name="equities", d=16)
60
+
61
+ # Predict — inline X with Δt, or a dataset reference
62
+ client.predict(dataset_id="ds-1")
63
+
64
+ # Walk-forward cross-validation over the dataset's _full split
65
+ result = client.crossval(name="equities", n_folds=4, scheme="expanding", embargo=2)
66
+ print(result["eval_aggregate"])
67
+
68
+ # Inspect
69
+ client.get_model() # topology + metrics
70
+ client.training_status() # state + events
71
+ client.is_ready() # readiness probe
72
+ ```
73
+
74
+ ## API surface
75
+
76
+ | Method | Endpoint |
77
+ |--------|----------|
78
+ | `train(*, dataset_id / name / generator, params, split, d, theta, ridge)` | `POST /v1/train` |
79
+ | `training_status()` | `GET /v1/training/status` |
80
+ | `predict(*, X / dt / target_dt / seq_lengths, or a dataset ref)` | `POST /v1/predict` |
81
+ | `crossval(*, n_folds, scheme, embargo, min_train, dataset ref, d, theta, ridge)` | `POST /v1/crossval` |
82
+ | `crossval_status()` | `GET /v1/crossval/status` |
83
+ | `get_model()` | `GET /v1/model` |
84
+ | `get_dataset()` | `GET /v1/dataset` |
85
+ | `health_check()` / `is_ready()` / `wait_for_ready()` | `GET /v1/health[/ready]` |
86
+
87
+ ## Authentication
88
+
89
+ Pass `api_key=…`, or set `JUNIPER_RECURRENCE_API_KEY` (or the Docker-secret
90
+ `JUNIPER_RECURRENCE_API_KEY_FILE`, a path whose stripped contents are the key). The key is sent
91
+ as the `X-API-Key` header. Note the asymmetry: the **server** reads the *plural*
92
+ `JUNIPER_RECURRENCE_API_KEYS` (its accepted set); the **client** sends one key under the
93
+ *singular* env var.
94
+
95
+ ## Errors
96
+
97
+ All errors derive from `JuniperRecurrenceClientError`: `JuniperRecurrenceConnectionError`,
98
+ `JuniperRecurrenceTimeoutError`, `JuniperRecurrenceNotFoundError` (404),
99
+ `JuniperRecurrenceConflictError` (409 — a run already in progress, or no trained model yet),
100
+ `JuniperRecurrenceValidationError` (400/422), `JuniperRecurrenceConfigurationError`.
101
+
102
+ ## License
103
+
104
+ MIT — see [LICENSE](https://github.com/pcalnon/juniper-recurrence/blob/main/LICENSE).
@@ -0,0 +1,72 @@
1
+ # juniper-recurrence-client
2
+
3
+ HTTP client for the **juniper-recurrence** service — the FastAPI app wrapping the Δt-native LMU
4
+ recurrence model and its cross-validation API. A lean `requests`-based client mirroring
5
+ [`juniper-data-client`](https://github.com/pcalnon/juniper-data-client) and
6
+ [`juniper-cascor-client`](https://github.com/pcalnon/juniper-cascor-client), so consumers
7
+ (notably **juniper-canopy**'s recurrence backend adapter) drive every Juniper backend the same way.
8
+
9
+ ## Install
10
+
11
+ ```bash
12
+ pip install juniper-recurrence-client # once published
13
+ pip install -e ".[test]" # local development
14
+ ```
15
+
16
+ The core depends only on `requests` and `urllib3`; `pip install "juniper-recurrence-client[observability]"` adds the
16
+ optional `juniper-observability` integration (X-Request-ID propagation and the `on_request` hook).
18
+
19
+ ## Quick start
20
+
21
+ ```python
22
+ from juniper_recurrence_client import JuniperRecurrenceClient
23
+
24
+ client = JuniperRecurrenceClient("http://localhost:8211", api_key="…")
25
+
26
+ # Train the LMU regressor on a dataset (by id / name / generator)
27
+ client.train(name="equities", d=16)
28
+
29
+ # Predict — inline X with Δt, or a dataset reference
30
+ client.predict(dataset_id="ds-1")
31
+
32
+ # Walk-forward cross-validation over the dataset's _full split
33
+ result = client.crossval(name="equities", n_folds=4, scheme="expanding", embargo=2)
34
+ print(result["eval_aggregate"])
35
+
36
+ # Inspect
37
+ client.get_model() # topology + metrics
38
+ client.training_status() # state + events
39
+ client.is_ready() # readiness probe
40
+ ```
41
+
42
+ ## API surface
43
+
44
+ | Method | Endpoint |
45
+ |--------|----------|
46
+ | `train(*, dataset_id / name / generator, params, split, d, theta, ridge)` | `POST /v1/train` |
47
+ | `training_status()` | `GET /v1/training/status` |
48
+ | `predict(*, X / dt / target_dt / seq_lengths, or a dataset ref)` | `POST /v1/predict` |
49
+ | `crossval(*, n_folds, scheme, embargo, min_train, dataset ref, d, theta, ridge)` | `POST /v1/crossval` |
50
+ | `crossval_status()` | `GET /v1/crossval/status` |
51
+ | `get_model()` | `GET /v1/model` |
52
+ | `get_dataset()` | `GET /v1/dataset` |
53
+ | `health_check()` / `is_ready()` / `wait_for_ready()` | `GET /v1/health[/ready]` |
54
+
55
+ ## Authentication
56
+
57
+ Pass `api_key=…`, or set `JUNIPER_RECURRENCE_API_KEY` (or the Docker-secret
58
+ `JUNIPER_RECURRENCE_API_KEY_FILE`, a path whose stripped contents are the key). The key is sent
59
+ as the `X-API-Key` header. Note the asymmetry: the **server** reads the *plural*
60
+ `JUNIPER_RECURRENCE_API_KEYS` (its accepted set); the **client** sends one key under the
61
+ *singular* env var.
62
+
63
+ ## Errors
64
+
65
+ All errors derive from `JuniperRecurrenceClientError`: `JuniperRecurrenceConnectionError`,
66
+ `JuniperRecurrenceTimeoutError`, `JuniperRecurrenceNotFoundError` (404),
67
+ `JuniperRecurrenceConflictError` (409 — a run already in progress, or no trained model yet),
68
+ `JuniperRecurrenceValidationError` (400/422), `JuniperRecurrenceConfigurationError`.
69
+
70
+ ## License
71
+
72
+ MIT — see [LICENSE](https://github.com/pcalnon/juniper-recurrence/blob/main/LICENSE).
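
The key-resolution order described in the README's Authentication section can be sketched standalone. This is a minimal illustration that mirrors the logic of the package's `_resolve_api_key_from_env`, not the shipped code. The function name `resolve_api_key` and the injectable `env` mapping are assumptions added here so the example runs without the package installed.

```python
import tempfile
from pathlib import Path
from typing import Mapping, Optional

KEY_VAR = "JUNIPER_RECURRENCE_API_KEY"
KEY_FILE_VAR = f"{KEY_VAR}_FILE"


def resolve_api_key(explicit: Optional[str], env: Mapping[str, str]) -> Optional[str]:
    """Hypothetical helper: explicit argument, then the _FILE secret, then the plain variable."""
    if explicit:
        return explicit
    file_path = env.get(KEY_FILE_VAR)
    if file_path:
        try:
            content = Path(file_path).read_text(encoding="utf-8").strip()
        except OSError:
            content = ""  # unreadable secret file: fall through to the plain variable
        if content:
            return content
    return env.get(KEY_VAR)


with tempfile.TemporaryDirectory() as tmp:
    secret = Path(tmp) / "api_key"
    secret.write_text("from-file\n", encoding="utf-8")
    env = {KEY_FILE_VAR: str(secret), KEY_VAR: "from-env"}

    file_wins = resolve_api_key(None, env)            # secret file beats the plain variable
    explicit_wins = resolve_api_key("from-arg", env)  # an explicit argument beats both
    missing_file = resolve_api_key(
        None, {KEY_FILE_VAR: str(Path(tmp) / "absent"), KEY_VAR: "from-env"}
    )                                                 # unreadable file falls back to the variable
```

Passing the environment as a mapping keeps the sketch testable without mutating `os.environ`; the packaged client reads `os.environ` directly.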
@@ -0,0 +1,34 @@
1
+ """juniper-recurrence-client — HTTP client for the juniper-recurrence service.
2
+
3
+ A lean ``requests``-based client wrapping the juniper-recurrence FastAPI app's REST surface
4
+ (train / predict / cross-validate / inspect / health). Mirrors juniper-data-client and
5
+ juniper-cascor-client so consumers (notably juniper-canopy's recurrence backend adapter) drive
6
+ every Juniper backend the same way.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from juniper_recurrence_client._version import __version__
12
+ from juniper_recurrence_client.client import JuniperRecurrenceClient, RequestHook
13
+ from juniper_recurrence_client.exceptions import (
14
+ JuniperRecurrenceClientError,
15
+ JuniperRecurrenceConfigurationError,
16
+ JuniperRecurrenceConflictError,
17
+ JuniperRecurrenceConnectionError,
18
+ JuniperRecurrenceNotFoundError,
19
+ JuniperRecurrenceTimeoutError,
20
+ JuniperRecurrenceValidationError,
21
+ )
22
+
23
+ __all__ = [
24
+ "__version__",
25
+ "JuniperRecurrenceClient",
26
+ "RequestHook",
27
+ "JuniperRecurrenceClientError",
28
+ "JuniperRecurrenceConnectionError",
29
+ "JuniperRecurrenceTimeoutError",
30
+ "JuniperRecurrenceNotFoundError",
31
+ "JuniperRecurrenceConflictError",
32
+ "JuniperRecurrenceValidationError",
33
+ "JuniperRecurrenceConfigurationError",
34
+ ]
@@ -0,0 +1,7 @@
1
+ """Single source of truth for the juniper-recurrence-client version.
2
+
3
+ Kept import-free so setuptools can parse ``__version__`` statically at build time
4
+ (``[tool.setuptools.dynamic]`` in pyproject.toml) without importing requests.
5
+ """
6
+
7
+ __version__ = "0.1.0"
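
The docstring above says the version module is kept import-free so the version can be read statically. The sketch below illustrates that general idea with the standard-library `ast` module. It is not setuptools' actual implementation, and `read_version` and `VERSION_SOURCE` are hypothetical names used only for this example.

```python
import ast
from typing import Optional

# Stand-in for the contents of a version module like the one above.
VERSION_SOURCE = '''"""Single source of truth for the package version."""

__version__ = "0.1.0"
'''


def read_version(source: str) -> Optional[str]:
    """Return the literal assigned to ``__version__`` without executing ``source``."""
    for node in ast.parse(source).body:
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "__version__":
                    # literal_eval only accepts literals, so no code is run.
                    return ast.literal_eval(node.value)
    return None


version = read_version(VERSION_SOURCE)
no_version = read_version("import requests\nx = 1\n")  # parsed, never imported
```

Because the source is only parsed, a module that imports a third-party package at the top would still be readable this way, but keeping the file import-free avoids depending on that.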
@@ -0,0 +1,433 @@
1
+ """REST API client for the juniper-recurrence service.
2
+
3
+ A lean ``requests``-based client wrapping the juniper-recurrence FastAPI app's REST surface
4
+ (train / predict / cross-validate / inspect / health), for consumers such as juniper-canopy's
5
+ recurrence backend adapter. Mirrors juniper-data-client's transport machinery: an idempotent-only
6
+ retry policy, ``X-API-Key`` auth with ``_FILE`` Docker-secret indirection, typed exceptions, the
7
+ optional ``on_request`` instrumentation hook, and best-effort ``X-Request-ID`` propagation.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ import os
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Any, Callable, Optional
17
+ from urllib.parse import urlparse
18
+
19
+ import requests
20
+ from requests.adapters import HTTPAdapter
21
+ from urllib3.util.retry import Retry
22
+
23
+ from juniper_recurrence_client.constants import (
24
+ API_KEY_ENV_VAR,
25
+ API_KEY_FILE_ENV_VAR,
26
+ API_KEY_HEADER_NAME,
27
+ API_VERSION_PATH_SUFFIX,
28
+ DEFAULT_BACKOFF_FACTOR,
29
+ DEFAULT_BASE_URL,
30
+ DEFAULT_READY_POLL_INTERVAL,
31
+ DEFAULT_READY_TIMEOUT,
32
+ DEFAULT_RETRIES,
33
+ DEFAULT_TIMEOUT,
34
+ DEFAULT_URL_SCHEME_PREFIX,
35
+ ENDPOINT_CROSSVAL,
36
+ ENDPOINT_CROSSVAL_STATUS,
37
+ ENDPOINT_DATASET,
38
+ ENDPOINT_HEALTH,
39
+ ENDPOINT_HEALTH_READY,
40
+ ENDPOINT_MODEL,
41
+ ENDPOINT_PREDICT,
42
+ ENDPOINT_TRAIN,
43
+ ENDPOINT_TRAINING_STATUS,
44
+ HEALTH_READY_STATUS,
45
+ HTTP_400_BAD_REQUEST,
46
+ HTTP_404_NOT_FOUND,
47
+ HTTP_409_CONFLICT,
48
+ HTTP_422_UNPROCESSABLE_ENTITY,
49
+ HTTP_POOL_CONNECTIONS,
50
+ HTTP_POOL_MAXSIZE,
51
+ RETRY_ALLOWED_METHODS,
52
+ RETRYABLE_STATUS_CODES,
53
+ URL_SCHEME_PREFIXES,
54
+ )
55
+ from juniper_recurrence_client.exceptions import (
56
+ JuniperRecurrenceClientError,
57
+ JuniperRecurrenceConflictError,
58
+ JuniperRecurrenceConnectionError,
59
+ JuniperRecurrenceNotFoundError,
60
+ JuniperRecurrenceTimeoutError,
61
+ JuniperRecurrenceValidationError,
62
+ )
63
+
64
+ logger = logging.getLogger("juniper_recurrence_client.client")
65
+
66
+
67
+ def _resolve_api_key_from_env() -> Optional[str]:
68
+ """Resolve the juniper-recurrence API key from the environment.
69
+
70
+ Honors the Docker-secret ``JUNIPER_RECURRENCE_API_KEY_FILE`` indirection (a file whose
71
+ stripped contents are the key) before the plain ``JUNIPER_RECURRENCE_API_KEY`` env var, so a
72
+ consumer that mounts the key as a file and leaves ``api_key`` unset still authenticates.
73
+ """
74
+ file_path = os.environ.get(API_KEY_FILE_ENV_VAR)
75
+ if file_path:
76
+ try:
77
+ content = Path(file_path).read_text(encoding="utf-8").strip()
78
+ except OSError:
79
+ content = ""
80
+ if content:
81
+ return content
82
+ return os.environ.get(API_KEY_ENV_VAR)
83
+
84
+
85
+ #: Optional instrumentation hook, invoked once per HTTP call with
86
+ #: ``(method, url, status, duration_ms, error)``. ``error is None`` is the canonical success
87
+ #: signal (``status`` may be set even on the typed-error paths). Mirrors juniper-data-client so
88
+ #: canopy/cascor can pass the same Prometheus/structured-log closure they already use.
89
+ RequestHook = Callable[[str, str, Optional[int], float, Optional[BaseException]], None]
90
+
91
+
92
+ def _noop_request_hook(
93
+ method: str,
94
+ url: str,
95
+ status: Optional[int],
96
+ duration_ms: float,
97
+ error: Optional[BaseException],
98
+ ) -> None:
99
+ """Default :data:`RequestHook` — does nothing (named so the default is a real callable)."""
100
+
101
+
102
+ def _dataset_ref(
103
+ *,
104
+ dataset_id: Optional[str],
105
+ name: Optional[str],
106
+ generator: Optional[str],
107
+ params: Optional[dict[str, Any]],
108
+ split: str,
109
+ ) -> dict[str, Any]:
110
+ """Build the app's ``DatasetRef`` body from selection kwargs.
111
+
112
+ Exactly one of ``dataset_id`` / ``name`` / ``generator`` is expected; the server validates
113
+ that invariant and returns 422 otherwise.
114
+ """
115
+ ref: dict[str, Any] = {"split": split}
116
+ if dataset_id is not None:
117
+ ref["dataset_id"] = dataset_id
118
+ if name is not None:
119
+ ref["name"] = name
120
+ if generator is not None:
121
+ ref["generator"] = generator
122
+ if params is not None:
123
+ ref["params"] = params
124
+ return ref
125
+
126
+
127
+ class JuniperRecurrenceClient:
128
+ """Client for the juniper-recurrence REST API (train / predict / cross-validate / inspect).
129
+
130
+ Automatic retry (idempotent methods only), connection pooling, and ``X-API-Key`` auth.
131
+
132
+ Example:
133
+ >>> client = JuniperRecurrenceClient("http://localhost:8211")
134
+ >>> client.train(name="equities", d=16)
135
+ >>> preds = client.predict(dataset_id="ds-1")
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ base_url: str = DEFAULT_BASE_URL,
141
+ timeout: int = DEFAULT_TIMEOUT,
142
+ retries: int = DEFAULT_RETRIES,
143
+ backoff_factor: float = DEFAULT_BACKOFF_FACTOR,
144
+ api_key: Optional[str] = None,
145
+ on_request: Optional[RequestHook] = None,
146
+ ) -> None:
147
+ """Initialize the client.
148
+
149
+ Args:
150
+ base_url: Base URL of the juniper-recurrence app (default ``http://localhost:8211``).
151
+ timeout: Per-request timeout in seconds.
152
+ retries: Retry attempts for transient failures on idempotent methods.
153
+ backoff_factor: Exponential backoff factor for retries.
154
+ api_key: API key for ``X-API-Key`` auth. If unset, resolved from
155
+ ``JUNIPER_RECURRENCE_API_KEY`` (and its ``_FILE`` form).
156
+ on_request: Optional instrumentation hook (see :data:`RequestHook`); defaults to a
157
+ no-op. Hook exceptions are caught and logged so instrumentation never crashes a
158
+ request path.
159
+ """
160
+ self.base_url = self._normalize_url(base_url)
161
+ self.timeout = timeout
162
+ self.retries = retries
163
+ self.backoff_factor = backoff_factor
164
+ self.session = self._create_session()
165
+ self._on_request: RequestHook = on_request or _noop_request_hook
166
+
167
+ resolved_api_key = api_key or _resolve_api_key_from_env()
168
+ if resolved_api_key:
169
+ self.session.headers[API_KEY_HEADER_NAME] = resolved_api_key
170
+
171
+ def _normalize_url(self, url: str) -> str:
172
+ """Normalize the base URL: ensure a scheme, drop a trailing slash and any ``/v1`` suffix."""
173
+ url = url.strip()
174
+ if not url.startswith(URL_SCHEME_PREFIXES):
175
+ url = f"{DEFAULT_URL_SCHEME_PREFIX}{url}"
176
+ parsed = urlparse(url)
177
+ normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}".rstrip("/")
178
+ if normalized.endswith(API_VERSION_PATH_SUFFIX):
179
+ normalized = normalized[: -len(API_VERSION_PATH_SUFFIX)]
180
+ return normalized
181
+
182
+ def _create_session(self) -> requests.Session:
183
+ """Create a ``requests.Session`` with the idempotent-only retry policy + pooling."""
184
+ session = requests.Session()
185
+ retry_strategy = Retry(
186
+ total=self.retries,
187
+ backoff_factor=self.backoff_factor,
188
+ status_forcelist=RETRYABLE_STATUS_CODES,
189
+ allowed_methods=RETRY_ALLOWED_METHODS,
190
+ )
191
+ adapter = HTTPAdapter(
192
+ max_retries=retry_strategy,
193
+ pool_connections=HTTP_POOL_CONNECTIONS,
194
+ pool_maxsize=HTTP_POOL_MAXSIZE,
195
+ )
196
+ for scheme in URL_SCHEME_PREFIXES:
197
+ session.mount(scheme, adapter)
198
+ return session
199
+
200
+ def _request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response: # noqa: C901
201
+ """Make an HTTP request, mapping transport/HTTP errors to typed exceptions.
202
+
203
+ Raises:
204
+ JuniperRecurrenceConnectionError / JuniperRecurrenceTimeoutError: transport failures.
205
+ JuniperRecurrenceNotFoundError (404), JuniperRecurrenceConflictError (409),
206
+ JuniperRecurrenceValidationError (400/422), or JuniperRecurrenceClientError (other).
207
+ """
208
+ url = f"{self.base_url}{endpoint}"
209
+ kwargs.setdefault("timeout", self.timeout)
210
+
211
+ # Best-effort X-Request-ID propagation via juniper-observability (no-op if absent or
212
+ # unset); a caller-supplied X-Request-ID always wins.
213
+ headers = dict(kwargs.get("headers") or {})
214
+ if "X-Request-ID" not in headers:
215
+ try:
216
+ from juniper_observability import request_id_var # noqa: PLC0415
217
+
218
+ rid = request_id_var.get()
219
+ if rid:
220
+ headers["X-Request-ID"] = rid
221
+ kwargs["headers"] = headers
222
+ except (ImportError, LookupError):
223
+ pass
224
+
225
+ start = time.monotonic()
226
+ response: Optional[requests.Response] = None
227
+ outgoing_error: Optional[BaseException] = None
228
+ try:
229
+ try:
230
+ response = self.session.request(method, url, **kwargs)
231
+ except requests.exceptions.ConnectionError as e:
232
+ outgoing_error = JuniperRecurrenceConnectionError(f"Failed to connect to juniper-recurrence at {self.base_url}: {e}")
233
+ raise outgoing_error from e
234
+ except requests.exceptions.Timeout as e:
235
+ outgoing_error = JuniperRecurrenceTimeoutError(f"Request to {url} timed out after {self.timeout}s: {e}")
236
+ raise outgoing_error from e
237
+ except requests.exceptions.RequestException as e:
238
+ outgoing_error = JuniperRecurrenceClientError(f"Request failed: {e}")
239
+ raise outgoing_error from e
240
+
241
+ if response.ok:
242
+ return response
243
+
244
+ error_detail = response.text
245
+ try:
246
+ error_json = response.json()
247
+ if isinstance(error_json, dict) and "detail" in error_json:
248
+ error_detail = error_json["detail"]
249
+ except (ValueError, KeyError):
250
+ error_detail = response.text
251
+
252
+ if response.status_code == HTTP_404_NOT_FOUND:
253
+ outgoing_error = JuniperRecurrenceNotFoundError(f"Resource not found: {error_detail}")
254
+ raise outgoing_error
255
+ elif response.status_code == HTTP_409_CONFLICT:
256
+ outgoing_error = JuniperRecurrenceConflictError(f"Conflict: {error_detail}")
257
+ raise outgoing_error
258
+ elif response.status_code in (HTTP_400_BAD_REQUEST, HTTP_422_UNPROCESSABLE_ENTITY):
259
+ outgoing_error = JuniperRecurrenceValidationError(f"Validation error: {error_detail}")
260
+ raise outgoing_error
261
+ else:
262
+ outgoing_error = JuniperRecurrenceClientError(f"Request failed ({response.status_code}): {error_detail}")
263
+ raise outgoing_error
264
+ finally:
265
+ duration_ms = (time.monotonic() - start) * 1000.0
266
+ status = response.status_code if response is not None else None
267
+ try:
268
+ self._on_request(method, url, status, duration_ms, outgoing_error)
269
+ except Exception: # noqa: BLE001 — instrumentation must not crash production paths
270
+ logger.warning("on_request hook raised; suppressed to keep request path resilient", exc_info=True)
271
+
272
+ @staticmethod
273
+ def _parse_json(response: requests.Response) -> Any:
274
+ """Parse a response body as JSON, surfacing a typed error on a malformed body."""
275
+ try:
276
+ return response.json()
277
+ except ValueError as e:
278
+ preview = (response.text or "")[:200]
279
+ raise JuniperRecurrenceClientError(f"Malformed JSON response from {response.url}: {e}: {preview!r}") from e
280
+
281
+ # ─── Training ─────────────────────────────────────────────────────────────
282
+
283
+ def train(
284
+ self,
285
+ *,
286
+ dataset_id: Optional[str] = None,
287
+ name: Optional[str] = None,
288
+ generator: Optional[str] = None,
289
+ params: Optional[dict[str, Any]] = None,
290
+ split: str = "train",
291
+ d: Optional[int] = None,
292
+ theta: Optional[float] = None,
293
+ ridge: Optional[float] = None,
294
+ ) -> dict[str, Any]:
295
+ """``POST /v1/train`` — synchronously fit the LMU regressor on a dataset split.
296
+
297
+ Supply exactly one of ``dataset_id`` / ``name`` / ``generator``. Returns the
298
+ ``TrainResponse`` (``final_metrics``, ``n_epochs``, ``stopped_reason``, ``dataset``).
299
+ Raises :class:`JuniperRecurrenceConflictError` (409) if a run is already in progress.
300
+ """
301
+ body: dict[str, Any] = {"dataset": _dataset_ref(dataset_id=dataset_id, name=name, generator=generator, params=params, split=split)}
302
+ if d is not None:
303
+ body["d"] = d
304
+ if theta is not None:
305
+ body["theta"] = theta
306
+ if ridge is not None:
307
+ body["ridge"] = ridge
308
+ return self._parse_json(self._request("POST", ENDPOINT_TRAIN, json=body))
309
+
310
+ def training_status(self) -> dict[str, Any]:
311
+ """``GET /v1/training/status`` — current training state, metrics, and emitted events."""
312
+ return self._parse_json(self._request("GET", ENDPOINT_TRAINING_STATUS))
313
+
314
+ # ─── Prediction ───────────────────────────────────────────────────────────
315
+
316
+ def predict(
317
+ self,
318
+ *,
319
+ X: Optional[Any] = None,
320
+ dt: Optional[Any] = None,
321
+ target_dt: Optional[Any] = None,
322
+ seq_lengths: Optional[Any] = None,
323
+ dataset_id: Optional[str] = None,
324
+ name: Optional[str] = None,
325
+ generator: Optional[str] = None,
326
+ params: Optional[dict[str, Any]] = None,
327
+ split: str = "train",
328
+ ) -> dict[str, Any]:
329
+ """``POST /v1/predict`` — predictions from the trained model.
330
+
331
+ Supply exactly one of inline ``X`` (optionally with ``dt`` / ``target_dt`` /
332
+ ``seq_lengths``) or a dataset reference. Returns ``{"predictions": ..., "shape": ...}``.
333
+ """
334
+ body: dict[str, Any] = {}
335
+ if X is not None:
336
+ body["X"] = X
337
+ if dt is not None:
338
+ body["dt"] = dt
339
+ if target_dt is not None:
340
+ body["target_dt"] = target_dt
341
+ if seq_lengths is not None:
342
+ body["seq_lengths"] = seq_lengths
343
+ if dataset_id is not None or name is not None or generator is not None:
344
+ body["dataset"] = _dataset_ref(dataset_id=dataset_id, name=name, generator=generator, params=params, split=split)
345
+ return self._parse_json(self._request("POST", ENDPOINT_PREDICT, json=body))
346
+
347
+ # ─── Cross-validation ──────────────────────────────────────────────────────
348
+
349
+ def crossval(
350
+ self,
351
+ *,
352
+ n_folds: int,
353
+ dataset_id: Optional[str] = None,
354
+ name: Optional[str] = None,
355
+ generator: Optional[str] = None,
356
+ params: Optional[dict[str, Any]] = None,
357
+ split: str = "train",
358
+ scheme: str = "expanding",
359
+ embargo: int = 0,
360
+ min_train: Optional[int] = None,
361
+ d: Optional[int] = None,
362
+ theta: Optional[float] = None,
363
+ ridge: Optional[float] = None,
364
+ ) -> dict[str, Any]:
365
+ """``POST /v1/crossval`` — synchronous walk-forward cross-validation over the ``_full`` split.
366
+
367
+ Returns the ``CrossValResponse`` (per-fold ``folds`` + ``eval_aggregate`` / ``eval_std``).
368
+ ``scheme`` is ``"expanding"`` or ``"rolling"``. Raises 409 if a CV run is already running.
369
+ """
370
+ body: dict[str, Any] = {
371
+ "dataset": _dataset_ref(dataset_id=dataset_id, name=name, generator=generator, params=params, split=split),
372
+ "n_folds": n_folds,
373
+ "scheme": scheme,
374
+ "embargo": embargo,
375
+ }
376
+ if min_train is not None:
377
+ body["min_train"] = min_train
378
+ if d is not None:
379
+ body["d"] = d
380
+ if theta is not None:
381
+ body["theta"] = theta
382
+ if ridge is not None:
383
+ body["ridge"] = ridge
384
+ return self._parse_json(self._request("POST", ENDPOINT_CROSSVAL, json=body))
385
+
386
+ def crossval_status(self) -> dict[str, Any]:
387
+ """``GET /v1/crossval/status`` — the most recent cross-validation result, if any."""
388
+ return self._parse_json(self._request("GET", ENDPOINT_CROSSVAL_STATUS))
389
+
390
+ # ─── Inspection ────────────────────────────────────────────────────────────
391
+
392
+ def get_model(self) -> dict[str, Any]:
393
+ """``GET /v1/model`` — the trained model's topology + metrics (409 if none trained)."""
394
+ return self._parse_json(self._request("GET", ENDPOINT_MODEL))
395
+
396
+ def get_dataset(self) -> dict[str, Any]:
397
+ """``GET /v1/dataset`` — descriptor of the split the model was trained on (409 if none)."""
398
+ return self._parse_json(self._request("GET", ENDPOINT_DATASET))
399
+
400
+ # ─── Health / Readiness ────────────────────────────────────────────────────
401
+
402
+ def health_check(self) -> dict[str, Any]:
403
+ """``GET /v1/health`` — liveness."""
404
+ return self._parse_json(self._request("GET", ENDPOINT_HEALTH))
405
+
406
+ def is_ready(self) -> bool:
407
+ """``GET /v1/health/ready`` — ``True`` iff the service reports ready."""
408
+ try:
409
+ payload = self._parse_json(self._request("GET", ENDPOINT_HEALTH_READY))
410
+ except JuniperRecurrenceClientError:
411
+ return False
412
+ return isinstance(payload, dict) and payload.get("status") == HEALTH_READY_STATUS
413
+
414
+ def wait_for_ready(self, timeout: float = DEFAULT_READY_TIMEOUT, poll_interval: float = DEFAULT_READY_POLL_INTERVAL) -> bool:
415
+ """Poll ``/v1/health/ready`` until ready or ``timeout`` seconds elapse."""
416
+ deadline = time.monotonic() + timeout
417
+ while time.monotonic() < deadline:
418
+ if self.is_ready():
419
+ return True
420
+ time.sleep(poll_interval)
421
+ return self.is_ready()
422
+
423
+ # ─── Lifecycle ─────────────────────────────────────────────────────────────
424
+
425
+ def close(self) -> None:
426
+ """Close the underlying HTTP session."""
427
+ self.session.close()
428
+
429
+ def __enter__(self) -> "JuniperRecurrenceClient":
430
+ return self
431
+
432
+ def __exit__(self, *exc: object) -> None:
433
+ self.close()
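
The `on_request` hook in the client above is called once per HTTP call with `(method, url, status, duration_ms, error)`, and `error is None` marks success. A minimal sketch of a hook that tallies outcomes follows. The factory `make_counting_hook` is a hypothetical name, and the three calls at the bottom simulate what the client would pass rather than performing real requests.

```python
from collections import Counter
from typing import Callable, Optional

# Same shape as the RequestHook alias defined in the client module.
RequestHook = Callable[[str, str, Optional[int], float, Optional[BaseException]], None]


def make_counting_hook(counts: Counter) -> RequestHook:
    """Hypothetical factory: returns a hook that tallies outcomes per HTTP method."""

    def hook(method, url, status, duration_ms, error):
        # error is None is the success signal; status may be set on error paths too.
        outcome = "ok" if error is None else type(error).__name__
        counts[(method, outcome)] += 1

    return hook


counts: Counter = Counter()
hook = make_counting_hook(counts)

# Simulated invocations (success, HTTP error with a status, transport error without one).
hook("GET", "http://localhost:8211/v1/model", 200, 12.5, None)
hook("POST", "http://localhost:8211/v1/train", 409, 3.1, RuntimeError("Conflict"))
hook("GET", "http://localhost:8211/v1/health", None, 30000.0, TimeoutError("timed out"))
```

With the package installed, the same closure would be passed as `JuniperRecurrenceClient(on_request=hook)`. Since the client suppresses exceptions raised by the hook, a buggy hook degrades metrics rather than requests.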
@@ -0,0 +1,97 @@
1
+ """Protocol-level constants for the juniper-recurrence REST client.
2
+
3
+ Centralizes the literals used by ``client.py`` — base URL, endpoint paths, header names,
4
+ HTTP/retry configuration — mirroring juniper-data-client's constants module so the wire
5
+ contract is discoverable in one place.
6
+ """
7
+
8
+ from typing import List, Tuple
9
+
10
+ __all__ = [
11
+ "DEFAULT_BASE_URL",
12
+ "DEFAULT_TIMEOUT",
13
+ "DEFAULT_RETRIES",
14
+ "DEFAULT_BACKOFF_FACTOR",
15
+ "RETRYABLE_STATUS_CODES",
16
+ "RETRY_ALLOWED_METHODS",
17
+ "HTTP_POOL_CONNECTIONS",
18
+ "HTTP_POOL_MAXSIZE",
19
+ "URL_SCHEME_PREFIXES",
20
+ "DEFAULT_URL_SCHEME_PREFIX",
21
+ "API_VERSION_PATH_SUFFIX",
22
+ "DEFAULT_READY_TIMEOUT",
23
+ "DEFAULT_READY_POLL_INTERVAL",
24
+ "HEALTH_READY_STATUS",
25
+ "API_KEY_HEADER_NAME",
26
+ "API_KEY_ENV_VAR",
27
+ "API_KEY_FILE_ENV_VAR",
28
+ "ENDPOINT_HEALTH",
29
+ "ENDPOINT_HEALTH_READY",
30
+ "ENDPOINT_TRAIN",
31
+ "ENDPOINT_TRAINING_STATUS",
32
+ "ENDPOINT_PREDICT",
33
+ "ENDPOINT_MODEL",
34
+ "ENDPOINT_DATASET",
35
+ "ENDPOINT_CROSSVAL",
36
+ "ENDPOINT_CROSSVAL_STATUS",
37
+ ]
38
+
39
+ # ─── Service Configuration ───────────────────────────────────────────────────
40
+
41
+ # The juniper-recurrence app binds container port 8210; juniper-deploy maps host 8211 -> 8210,
42
+ # so the default host-facing base URL is 8211 (mirrors the deploy port map).
43
+ DEFAULT_BASE_URL: str = "http://localhost:8211"
44
+
45
+ # ─── HTTP Configuration ──────────────────────────────────────────────────────
46
+
47
+ DEFAULT_TIMEOUT: int = 30
48
+ DEFAULT_RETRIES: int = 3
49
+ DEFAULT_BACKOFF_FACTOR: float = 0.5
50
+ RETRYABLE_STATUS_CODES: List[int] = [429, 500, 502, 503, 504]
51
+ # Auto-retry is restricted to idempotent methods (RFC 9110 §9.2.2). The recurrence POSTs
52
+ # (train / predict / crossval) carry server-side state — train and crossval are lock-guarded —
53
+ # so a transient-5xx retry must not silently re-issue them. Only GET/HEAD auto-retry.
54
+ RETRY_ALLOWED_METHODS: List[str] = ["HEAD", "GET"]
55
+ HTTP_POOL_CONNECTIONS: int = 10
56
+ HTTP_POOL_MAXSIZE: int = 10
57
+
58
+ # ─── URL Normalization ───────────────────────────────────────────────────────
59
+
60
+ URL_SCHEME_PREFIXES: Tuple[str, ...] = ("http://", "https://")
61
+ DEFAULT_URL_SCHEME_PREFIX: str = "http://"
62
+ API_VERSION_PATH_SUFFIX: str = "/v1"
63
+
64
+ # ─── Readiness Polling ───────────────────────────────────────────────────────
65
+
66
+ DEFAULT_READY_TIMEOUT: float = 30.0
67
+ DEFAULT_READY_POLL_INTERVAL: float = 0.5
68
+ HEALTH_READY_STATUS: str = "ready"
69
+
70
+ # ─── Authentication ──────────────────────────────────────────────────────────
71
+
72
+ # The recurrence app enforces the X-API-Key header, reading its accepted keys from
73
+ # JUNIPER_RECURRENCE_API_KEYS (plural; CSV or JSON array, with _FILE indirection). The client
74
+ # sends a single key, resolved from the singular JUNIPER_RECURRENCE_API_KEY (and its _FILE
75
+ # Docker-secret form) — document the singular/plural asymmetry in AGENTS.md.
76
+ API_KEY_HEADER_NAME: str = "X-API-Key"
77
+ API_KEY_ENV_VAR: str = "JUNIPER_RECURRENCE_API_KEY"
78
+ API_KEY_FILE_ENV_VAR: str = f"{API_KEY_ENV_VAR}_FILE"
79
+
80
+ # ─── REST Endpoints (the juniper-recurrence app surface) ─────────────────────
81
+
82
+ ENDPOINT_HEALTH: str = "/v1/health"
83
+ ENDPOINT_HEALTH_READY: str = "/v1/health/ready"
84
+ ENDPOINT_TRAIN: str = "/v1/train"
85
+ ENDPOINT_TRAINING_STATUS: str = "/v1/training/status"
86
+ ENDPOINT_PREDICT: str = "/v1/predict"
87
+ ENDPOINT_MODEL: str = "/v1/model"
88
+ ENDPOINT_DATASET: str = "/v1/dataset"
89
+ ENDPOINT_CROSSVAL: str = "/v1/crossval"
90
+ ENDPOINT_CROSSVAL_STATUS: str = "/v1/crossval/status"
91
+
92
+ # ─── HTTP Status Codes ───────────────────────────────────────────────────────
93
+
94
+ HTTP_400_BAD_REQUEST: int = 400
95
+ HTTP_404_NOT_FOUND: int = 404
96
+ HTTP_409_CONFLICT: int = 409
97
+ HTTP_422_UNPROCESSABLE_ENTITY: int = 422
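
The URL-normalization constants above feed the client's `_normalize_url`. The standalone sketch below reproduces that logic so its effect on a few inputs is visible. The function name `normalize_base_url` and the host `recurrence.example` are assumptions for illustration only.

```python
from urllib.parse import urlparse

URL_SCHEME_PREFIXES = ("http://", "https://")
DEFAULT_URL_SCHEME_PREFIX = "http://"
API_VERSION_PATH_SUFFIX = "/v1"


def normalize_base_url(url: str) -> str:
    """Ensure a scheme, drop a trailing slash, and strip a trailing /v1 suffix."""
    url = url.strip()
    if not url.startswith(URL_SCHEME_PREFIXES):
        url = f"{DEFAULT_URL_SCHEME_PREFIX}{url}"
    parsed = urlparse(url)
    # Only scheme, host, and path survive; query strings and fragments are discarded.
    normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}".rstrip("/")
    if normalized.endswith(API_VERSION_PATH_SUFFIX):
        normalized = normalized[: -len(API_VERSION_PATH_SUFFIX)]
    return normalized


bare_host = normalize_base_url("localhost:8211")
versioned = normalize_base_url(" http://localhost:8211/v1/ ")
with_query = normalize_base_url("https://recurrence.example/api/v1?debug=1")
```

Stripping the `/v1` suffix matters because the endpoint constants already begin with `/v1`, so a base URL that kept it would produce paths such as `/v1/v1/train`.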
@@ -0,0 +1,36 @@
1
+ """Custom exceptions for the juniper-recurrence client library.
2
+
3
+ Mirrors juniper-data-client's flat hierarchy (one base + typed leaves), adding a
4
+ ``JuniperRecurrenceConflictError`` for the recurrence app's ``409`` responses (a training /
5
+ cross-validation run already in progress, or an operation that needs a trained model that does
6
+ not yet exist) — a status the data-client surface never returns.
7
+ """
8
+
9
+
10
+ class JuniperRecurrenceClientError(Exception):
11
+ """Base exception for all juniper-recurrence client errors."""
12
+
13
+
14
+ class JuniperRecurrenceConnectionError(JuniperRecurrenceClientError):
15
+ """Raised when the connection to the juniper-recurrence service fails."""
16
+
17
+
18
+ class JuniperRecurrenceTimeoutError(JuniperRecurrenceClientError):
19
+ """Raised when a request to the juniper-recurrence service times out."""
20
+
21
+
22
+ class JuniperRecurrenceNotFoundError(JuniperRecurrenceClientError):
23
+ """Raised when a requested resource is not found (404)."""
24
+
25
+
26
+ class JuniperRecurrenceValidationError(JuniperRecurrenceClientError):
27
+ """Raised when request parameters fail validation (400 / 422)."""
28
+
29
+
30
+ class JuniperRecurrenceConflictError(JuniperRecurrenceClientError):
31
+ """Raised on a 409 Conflict — a training/cross-validation run is already in progress, or
32
+ the operation requires a trained model/dataset that does not yet exist."""
33
+
34
+
35
+ class JuniperRecurrenceConfigurationError(JuniperRecurrenceClientError):
36
+ """Raised when juniper-recurrence client configuration is missing or invalid."""
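The flat hierarchy above lends itself to a simple status-code dispatch. The following is a minimal sketch, not the package's actual `client.py` logic: the helper name `error_for_status` and the message format are assumptions for illustration, and the exception classes are redeclared locally so the snippet runs standalone.

```python
from typing import Optional


# Local stand-ins for the package's exception classes (same names, same flat shape).
class JuniperRecurrenceClientError(Exception):
    """Base exception for all client errors."""


class JuniperRecurrenceNotFoundError(JuniperRecurrenceClientError):
    """404."""


class JuniperRecurrenceValidationError(JuniperRecurrenceClientError):
    """400 / 422."""


class JuniperRecurrenceConflictError(JuniperRecurrenceClientError):
    """409."""


# Status codes with a dedicated leaf; anything else >= 400 falls back to the base class.
_STATUS_TO_ERROR = {
    400: JuniperRecurrenceValidationError,
    404: JuniperRecurrenceNotFoundError,
    409: JuniperRecurrenceConflictError,
    422: JuniperRecurrenceValidationError,
}


def error_for_status(status_code: int, detail: str) -> Optional[JuniperRecurrenceClientError]:
    """Hypothetical helper: return the typed exception for an HTTP status, or None below 400."""
    if status_code < 400:
        return None
    exc_type = _STATUS_TO_ERROR.get(status_code, JuniperRecurrenceClientError)
    # Including the numeric status keeps generic failures (e.g. 500) identifiable in the message.
    return exc_type(f"HTTP {status_code}: {detail}")
```

Because every leaf derives from the one base class, callers can catch `JuniperRecurrenceClientError` as a catch-all, or catch `JuniperRecurrenceConflictError` separately to retry once a run finishes.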
@@ -0,0 +1,104 @@
1
+ Metadata-Version: 2.4
2
+ Name: juniper-recurrence-client
3
+ Version: 0.1.0
4
+ Summary: HTTP client for the juniper-recurrence service (the Δt-native LMU recurrence model + cross-validation API)
5
+ Author: Paul Calnon
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/pcalnon/juniper-recurrence
8
+ Project-URL: Repository, https://github.com/pcalnon/juniper-recurrence
9
+ Project-URL: Issues, https://github.com/pcalnon/juniper-recurrence/issues
10
+ Keywords: juniper,recurrence,lmu,http-client,rest,time-series,cross-validation
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Programming Language :: Python :: 3.14
20
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
22
+ Requires-Python: >=3.12
23
+ Description-Content-Type: text/markdown
24
+ Requires-Dist: requests>=2.28.0
25
+ Requires-Dist: urllib3>=2.0.0
26
+ Provides-Extra: test
27
+ Requires-Dist: pytest>=8.0; extra == "test"
28
+ Requires-Dist: pytest-cov>=5.0; extra == "test"
29
+ Requires-Dist: responses>=0.23; extra == "test"
30
+ Provides-Extra: observability
31
+ Requires-Dist: juniper-observability>=0.3.1; extra == "observability"
32
+
33
+ # juniper-recurrence-client
34
+
35
+ HTTP client for the **juniper-recurrence** service — the FastAPI app wrapping the Δt-native LMU
36
+ recurrence model and its cross-validation API. A lean `requests`-based client mirroring
37
+ [`juniper-data-client`](https://github.com/pcalnon/juniper-data-client) and
38
+ [`juniper-cascor-client`](https://github.com/pcalnon/juniper-cascor-client), so consumers
39
+ (notably **juniper-canopy**'s recurrence backend adapter) drive every Juniper backend the same way.
40
+
41
+ ## Install
42
+
43
+ ```bash
44
+ pip install juniper-recurrence-client # once published
45
+ pip install -e ".[test]" # local development
46
+ ```
47
+
48
+ The core depends only on `requests` and `urllib3`; `pip install "juniper-recurrence-client[observability]"` adds the
49
+ optional `juniper-observability` integration (X-Request-ID propagation and the `on_request` hook).
50
+
51
+ ## Quick start
52
+
53
+ ```python
54
+ from juniper_recurrence_client import JuniperRecurrenceClient
55
+
56
+ client = JuniperRecurrenceClient("http://localhost:8211", api_key="…")
57
+
58
+ # Train the LMU regressor on a dataset (by id / name / generator)
59
+ client.train(name="equities", d=16)
60
+
61
+ # Predict — inline X with Δt, or a dataset reference
62
+ client.predict(dataset_id="ds-1")
63
+
64
+ # Walk-forward cross-validation over the dataset's _full split
65
+ result = client.crossval(name="equities", n_folds=4, scheme="expanding", embargo=2)
66
+ print(result["eval_aggregate"])
67
+
68
+ # Inspect
69
+ client.get_model() # topology + metrics
70
+ client.training_status() # state + events
71
+ client.is_ready() # readiness probe
72
+ ```
73
+
74
+ ## API surface
75
+
76
+ | Method | Endpoint |
77
+ |--------|----------|
78
+ | `train(*, dataset_id / name / generator, params, split, d, theta, ridge)` | `POST /v1/train` |
79
+ | `training_status()` | `GET /v1/training/status` |
80
+ | `predict(*, X / dt / target_dt / seq_lengths, or a dataset ref)` | `POST /v1/predict` |
81
+ | `crossval(*, n_folds, scheme, embargo, min_train, dataset ref, d, theta, ridge)` | `POST /v1/crossval` |
82
+ | `crossval_status()` | `GET /v1/crossval/status` |
83
+ | `get_model()` | `GET /v1/model` |
84
+ | `get_dataset()` | `GET /v1/dataset` |
85
+ | `health_check()` / `is_ready()` / `wait_for_ready()` | `GET /v1/health[/ready]` |
86
+
87
+ ## Authentication
88
+
89
+ Pass `api_key=…`, or set `JUNIPER_RECURRENCE_API_KEY` (or the Docker-secret
90
+ `JUNIPER_RECURRENCE_API_KEY_FILE`, a path whose stripped contents are the key). The key is sent
91
+ as the `X-API-Key` header. Note the asymmetry: the **server** reads the *plural*
92
+ `JUNIPER_RECURRENCE_API_KEYS` (its accepted set); the **client** sends one key under the
93
+ *singular* env var.
94
+
95
+ ## Errors
96
+
97
+ All errors derive from `JuniperRecurrenceClientError`: `JuniperRecurrenceConnectionError`,
98
+ `JuniperRecurrenceTimeoutError`, `JuniperRecurrenceNotFoundError` (404),
99
+ `JuniperRecurrenceConflictError` (409 — a run already in progress, or no trained model yet),
100
+ `JuniperRecurrenceValidationError` (400/422), `JuniperRecurrenceConfigurationError`.
101
+
102
+ ## License
103
+
104
+ MIT — see [LICENSE](https://github.com/pcalnon/juniper-recurrence/blob/main/LICENSE).
@@ -0,0 +1,16 @@
1
+ README.md
2
+ pyproject.toml
3
+ juniper_recurrence_client/__init__.py
4
+ juniper_recurrence_client/_version.py
5
+ juniper_recurrence_client/client.py
6
+ juniper_recurrence_client/constants.py
7
+ juniper_recurrence_client/exceptions.py
8
+ juniper_recurrence_client.egg-info/PKG-INFO
9
+ juniper_recurrence_client.egg-info/SOURCES.txt
10
+ juniper_recurrence_client.egg-info/dependency_links.txt
11
+ juniper_recurrence_client.egg-info/requires.txt
12
+ juniper_recurrence_client.egg-info/top_level.txt
13
+ tests/test_auth.py
14
+ tests/test_client.py
15
+ tests/test_errors.py
16
+ tests/test_extra.py
@@ -0,0 +1,10 @@
1
+ requests>=2.28.0
2
+ urllib3>=2.0.0
3
+
4
+ [observability]
5
+ juniper-observability>=0.3.1
6
+
7
+ [test]
8
+ pytest>=8.0
9
+ pytest-cov>=5.0
10
+ responses>=0.23
@@ -0,0 +1 @@
1
+ juniper_recurrence_client
@@ -0,0 +1,75 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "juniper-recurrence-client"
7
+ dynamic = ["version"]
8
+ description = "HTTP client for the juniper-recurrence service (the Δt-native LMU recurrence model + cross-validation API)"
9
+ readme = "README.md"
10
+ requires-python = ">=3.12"
11
+ license = { text = "MIT" }
12
+ authors = [{ name = "Paul Calnon" }]
13
+ keywords = ["juniper", "recurrence", "lmu", "http-client", "rest", "time-series", "cross-validation"]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Intended Audience :: Developers",
17
+ "Intended Audience :: Science/Research",
18
+ "License :: OSI Approved :: MIT License",
19
+ "Operating System :: OS Independent",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.12",
22
+ "Programming Language :: Python :: 3.13",
23
+ "Programming Language :: Python :: 3.14",
24
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
25
+ "Topic :: Software Development :: Libraries :: Python Modules",
26
+ ]
27
+ dependencies = [
28
+ "requests>=2.28.0",
29
+ "urllib3>=2.0.0",
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ test = [
34
+ "pytest>=8.0",
35
+ "pytest-cov>=5.0",
36
+ "responses>=0.23",
37
+ ]
38
+ # Opt-in: the X-Request-ID propagation + the on_request hook integrate with
39
+ # juniper-observability, but the client never requires it (the import is guarded).
40
+ observability = [
41
+ "juniper-observability>=0.3.1",
42
+ ]
43
+
44
+ [project.urls]
45
+ Homepage = "https://github.com/pcalnon/juniper-recurrence"
46
+ Repository = "https://github.com/pcalnon/juniper-recurrence"
47
+ Issues = "https://github.com/pcalnon/juniper-recurrence/issues"
48
+
49
+ [tool.setuptools.dynamic]
50
+ version = { attr = "juniper_recurrence_client._version.__version__" }
51
+
52
+ [tool.setuptools.packages.find]
53
+ include = ["juniper_recurrence_client*"]
54
+ exclude = ["tests*"]
55
+
56
+ [tool.pytest.ini_options]
57
+ minversion = "8.0"
58
+ testpaths = ["tests"]
59
+ python_files = ["test_*.py"]
60
+ python_classes = ["Test*"]
61
+ python_functions = ["test_*"]
62
+ # Ecosystem-standard autoload SIGSEGV defense (dash/playwright plugins are not used here).
63
+ addopts = ["-ra", "--strict-markers", "--strict-config", "-p", "no:dash", "-p", "no:playwright"]
64
+
65
+ [tool.ruff]
66
+ line-length = 512
67
+ target-version = "py312"
68
+
69
+ [tool.ruff.lint]
70
+ select = ["E", "F", "W", "B", "I", "N"]
71
+
72
+ [tool.ruff.lint.per-file-ignores]
73
+ # X is the standard ML design-matrix argument name (juniper-model-core's array contract);
74
+ # pep8-naming N803's lowercase rule is a false positive on it (mirrors the model package).
75
+ "juniper_recurrence_client/client.py" = ["N803"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,50 @@
1
+ """API-key resolution tests (explicit arg, env var, ``_FILE`` Docker-secret, precedence)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+
7
+ from juniper_recurrence_client import JuniperRecurrenceClient
8
+ from juniper_recurrence_client.constants import API_KEY_ENV_VAR, API_KEY_FILE_ENV_VAR, API_KEY_HEADER_NAME
9
+
10
+ BASE_URL = "http://x:8211"
11
+
12
+
13
+ def test_explicit_api_key_sets_header() -> None:
14
+ client = JuniperRecurrenceClient(base_url=BASE_URL, api_key="secret-key")
15
+ assert client.session.headers[API_KEY_HEADER_NAME] == "secret-key"
16
+
17
+
18
+ def test_no_api_key_leaves_header_unset() -> None:
19
+ client = JuniperRecurrenceClient(base_url=BASE_URL)
20
+ assert API_KEY_HEADER_NAME not in client.session.headers
21
+
22
+
23
+ def test_api_key_from_env(monkeypatch: pytest.MonkeyPatch) -> None:
24
+ monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
25
+ client = JuniperRecurrenceClient(base_url=BASE_URL)
26
+ assert client.session.headers[API_KEY_HEADER_NAME] == "env-key"
27
+
28
+
29
+ def test_file_indirection_beats_plain_env(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
30
+ key_file = tmp_path / "key.txt"
31
+ key_file.write_text(" file-key\n", encoding="utf-8")
32
+ monkeypatch.setenv(API_KEY_FILE_ENV_VAR, str(key_file))
33
+ monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
34
+ client = JuniperRecurrenceClient(base_url=BASE_URL)
35
+ assert client.session.headers[API_KEY_HEADER_NAME] == "file-key"
36
+
37
+
38
+ def test_explicit_api_key_beats_env(monkeypatch: pytest.MonkeyPatch) -> None:
39
+ monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
40
+ client = JuniperRecurrenceClient(base_url=BASE_URL, api_key="explicit")
41
+ assert client.session.headers[API_KEY_HEADER_NAME] == "explicit"
42
+
43
+
44
+ def test_empty_file_falls_back_to_plain_env(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
45
+ empty = tmp_path / "empty.txt"
46
+ empty.write_text(" \n", encoding="utf-8")
47
+ monkeypatch.setenv(API_KEY_FILE_ENV_VAR, str(empty))
48
+ monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
49
+ client = JuniperRecurrenceClient(base_url=BASE_URL)
50
+ assert client.session.headers[API_KEY_HEADER_NAME] == "env-key"
@@ -0,0 +1,116 @@
1
+ """Unit tests for ``JuniperRecurrenceClient`` (HTTP mocked with ``responses``)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+
7
+ import pytest
8
+ import responses
9
+
10
+ from juniper_recurrence_client import JuniperRecurrenceClient
11
+
12
+ BASE_URL = "http://recurrence.test:8211"
13
+
14
+
15
+ def _client(**kwargs: object) -> JuniperRecurrenceClient:
16
+ kwargs.setdefault("retries", 0)
17
+ return JuniperRecurrenceClient(base_url=BASE_URL, **kwargs)
18
+
19
+
20
+ @pytest.mark.parametrize(
21
+ "raw,expected",
22
+ [
23
+ ("http://recurrence.test:8211", "http://recurrence.test:8211"),
24
+ ("http://recurrence.test:8211/", "http://recurrence.test:8211"),
25
+ ("http://recurrence.test:8211/v1", "http://recurrence.test:8211"),
26
+ ("recurrence.test:8211", "http://recurrence.test:8211"),
27
+ ],
28
+ )
29
+ def test_normalize_url(raw: str, expected: str) -> None:
30
+ assert JuniperRecurrenceClient(base_url=raw).base_url == expected
31
+
32
+
33
+ @responses.activate
34
+ def test_train_posts_dataset_ref_and_hyperparams() -> None:
35
+ responses.add(responses.POST, f"{BASE_URL}/v1/train", json={"final_metrics": {"r2": 0.9}, "n_epochs": 1, "stopped_reason": None, "dataset": {}}, status=200)
36
+ out = _client().train(name="equities", d=16, theta=2.0, ridge=0.1)
37
+ assert out["final_metrics"]["r2"] == 0.9
38
+ sent = json.loads(responses.calls[0].request.body)
39
+ assert sent["dataset"] == {"split": "train", "name": "equities"}
40
+ assert sent["d"] == 16 and sent["theta"] == 2.0 and sent["ridge"] == 0.1
41
+
42
+
43
+ @responses.activate
44
+ def test_training_status() -> None:
45
+ responses.add(responses.GET, f"{BASE_URL}/v1/training/status", json={"state": "trained", "final_metrics": {"r2": 0.9}, "stopped_reason": None, "events": []}, status=200)
46
+ assert _client().training_status()["state"] == "trained"
47
+
48
+
49
+ @responses.activate
50
+ def test_predict_inline_x() -> None:
51
+ responses.add(responses.POST, f"{BASE_URL}/v1/predict", json={"predictions": [[1.0]], "shape": [1, 1]}, status=200)
52
+ out = _client().predict(X=[[[1.0, 2.0]]], dt=[[0.0]])
53
+ assert out["shape"] == [1, 1]
54
+ sent = json.loads(responses.calls[0].request.body)
55
+ assert sent["X"] == [[[1.0, 2.0]]] and sent["dt"] == [[0.0]] and "dataset" not in sent
56
+
57
+
58
+ @responses.activate
59
+ def test_predict_by_dataset_ref() -> None:
60
+ responses.add(responses.POST, f"{BASE_URL}/v1/predict", json={"predictions": [], "shape": [0, 1]}, status=200)
61
+ _client().predict(dataset_id="ds-1")
62
+ sent = json.loads(responses.calls[0].request.body)
63
+ assert sent["dataset"] == {"split": "train", "dataset_id": "ds-1"} and "X" not in sent
64
+
65
+
66
+ @responses.activate
67
+ def test_crossval_passes_config() -> None:
68
+ responses.add(responses.POST, f"{BASE_URL}/v1/crossval", json={"task_type": "regression", "n_folds": 3, "folds": [], "eval_aggregate": {}, "eval_std": {}, "dataset": {}}, status=200)
69
+ out = _client().crossval(name="equities", n_folds=3, scheme="rolling", embargo=2, min_train=10)
70
+ assert out["n_folds"] == 3
71
+ sent = json.loads(responses.calls[0].request.body)
72
+ assert sent["n_folds"] == 3 and sent["scheme"] == "rolling" and sent["embargo"] == 2 and sent["min_train"] == 10
73
+
74
+
75
+ @responses.activate
76
+ def test_crossval_status_model_dataset() -> None:
77
+ responses.add(responses.GET, f"{BASE_URL}/v1/crossval/status", json={"state": "done", "result": None}, status=200)
78
+ responses.add(responses.GET, f"{BASE_URL}/v1/model", json={"topology": {"model_type": "lmu"}, "metrics": {}}, status=200)
79
+ responses.add(responses.GET, f"{BASE_URL}/v1/dataset", json={"dataset_id": "ds-1", "split": "train"}, status=200)
80
+ c = _client()
81
+ assert c.crossval_status()["state"] == "done"
82
+ assert c.get_model()["topology"]["model_type"] == "lmu"
83
+ assert c.get_dataset()["dataset_id"] == "ds-1"
84
+
85
+
86
+ @responses.activate
87
+ def test_health_and_is_ready() -> None:
88
+ responses.add(responses.GET, f"{BASE_URL}/v1/health", json={"status": "ok"}, status=200)
89
+ responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "ready"}, status=200)
90
+ c = _client()
91
+ assert c.health_check()["status"] == "ok"
92
+ assert c.is_ready() is True
93
+
94
+
95
+ @responses.activate
96
+ def test_is_ready_false_when_not_ready() -> None:
97
+ responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "starting"}, status=200)
98
+ assert _client().is_ready() is False
99
+
100
+
101
+ @responses.activate
102
+ def test_on_request_hook_fires_once() -> None:
103
+ seen: list[tuple[str, object, object, object]] = []
104
+
105
+ def hook(method: str, url: str, status: object, duration_ms: float, error: object) -> None:
106
+ seen.append((method, url, status, error))
107
+
108
+ responses.add(responses.GET, f"{BASE_URL}/v1/model", json={"topology": {}, "metrics": {}}, status=200)
109
+ _client(on_request=hook).get_model()
110
+ assert len(seen) == 1
111
+ assert seen[0][0] == "GET" and seen[0][2] == 200 and seen[0][3] is None
112
+
113
+
114
+ def test_context_manager_closes() -> None:
115
+ with _client() as client:
116
+ assert client.session is not None
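The four `test_normalize_url` cases in this file imply three normalization steps: add a default scheme, drop a trailing slash, and drop a trailing `/v1` prefix. A minimal sketch that satisfies those cases is shown below; the function name `normalize_base_url` and its `api_prefix` parameter are assumptions, and the client's real normalization may handle more inputs than this.

```python
def normalize_base_url(raw: str, api_prefix: str = "/v1") -> str:
    """Hypothetical normalizer covering the parametrized test cases only."""
    url = raw.strip()

    # Bare "host:port" input gets a default http:// scheme.
    if "://" not in url:
        url = f"http://{url}"

    # Endpoint constants already start with "/v1", so the base URL must not end with it.
    url = url.rstrip("/")
    if url.endswith(api_prefix):
        url = url[: -len(api_prefix)]

    return url.rstrip("/")
```

Stripping the version prefix from the base URL means a caller can pass either the service root or the versioned root without producing doubled paths such as `/v1/v1/model`.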
@@ -0,0 +1,63 @@
1
+ """HTTP/transport error-mapping tests (status code -> typed exception)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ import responses
7
+
8
+ from juniper_recurrence_client import (
9
+ JuniperRecurrenceClient,
10
+ JuniperRecurrenceClientError,
11
+ JuniperRecurrenceConflictError,
12
+ JuniperRecurrenceConnectionError,
13
+ JuniperRecurrenceNotFoundError,
14
+ JuniperRecurrenceValidationError,
15
+ )
16
+
17
+ BASE_URL = "http://recurrence.test:8211"
18
+
19
+
20
+ def _client() -> JuniperRecurrenceClient:
21
+ return JuniperRecurrenceClient(base_url=BASE_URL, retries=0)
22
+
23
+
24
+ @responses.activate
25
+ def test_404_maps_to_not_found() -> None:
26
+ responses.add(responses.GET, f"{BASE_URL}/v1/model", json={"detail": "no model"}, status=404)
27
+ with pytest.raises(JuniperRecurrenceNotFoundError, match="no model"):
28
+ _client().get_model()
29
+
30
+
31
+ @responses.activate
32
+ def test_409_maps_to_conflict() -> None:
33
+ responses.add(responses.POST, f"{BASE_URL}/v1/train", json={"detail": "training already in progress"}, status=409)
34
+ with pytest.raises(JuniperRecurrenceConflictError, match="in progress"):
35
+ _client().train(name="equities")
36
+
37
+
38
+ @responses.activate
39
+ def test_422_maps_to_validation() -> None:
40
+ responses.add(responses.POST, f"{BASE_URL}/v1/crossval", json={"detail": "n_folds must be >= 2"}, status=422)
41
+ with pytest.raises(JuniperRecurrenceValidationError, match="n_folds"):
42
+ _client().crossval(name="equities", n_folds=1)
43
+
44
+
45
+ @responses.activate
46
+ def test_500_maps_to_client_error() -> None:
47
+ responses.add(responses.GET, f"{BASE_URL}/v1/dataset", json={"detail": "boom"}, status=500)
48
+ with pytest.raises(JuniperRecurrenceClientError, match="500"):
49
+ _client().get_dataset()
50
+
51
+
52
+ @responses.activate
53
+ def test_connection_error_maps() -> None:
54
+ # No response registered for this URL -> responses raises a ConnectionError.
55
+ with pytest.raises(JuniperRecurrenceConnectionError):
56
+ _client().get_model()
57
+
58
+
59
+ @responses.activate
60
+ def test_malformed_json_raises_client_error() -> None:
61
+ responses.add(responses.GET, f"{BASE_URL}/v1/dataset", body="not-json", status=200, content_type="application/json")
62
+ with pytest.raises(JuniperRecurrenceClientError, match="Malformed JSON"):
63
+ _client().get_dataset()
@@ -0,0 +1,76 @@
1
+ """Extra coverage: timeout / generic-request-error mapping, full predict & crossval bodies,
2
+ readiness polling (success + timeout), and the api-key file-read-error fallback."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+
8
+ import pytest
9
+ import requests
10
+ import responses
11
+
12
+ from juniper_recurrence_client import (
13
+ JuniperRecurrenceClient,
14
+ JuniperRecurrenceClientError,
15
+ JuniperRecurrenceTimeoutError,
16
+ )
17
+ from juniper_recurrence_client.constants import API_KEY_ENV_VAR, API_KEY_FILE_ENV_VAR, API_KEY_HEADER_NAME
18
+
19
+ BASE_URL = "http://recurrence.test:8211"
20
+
21
+
22
+ def _client(**kwargs: object) -> JuniperRecurrenceClient:
23
+ kwargs.setdefault("retries", 0)
24
+ return JuniperRecurrenceClient(base_url=BASE_URL, **kwargs)
25
+
26
+
27
+ @responses.activate
28
+ def test_timeout_maps_to_timeout_error() -> None:
29
+ responses.add(responses.GET, f"{BASE_URL}/v1/model", body=requests.exceptions.Timeout("slow"))
30
+ with pytest.raises(JuniperRecurrenceTimeoutError):
31
+ _client().get_model()
32
+
33
+
34
+ @responses.activate
35
+ def test_generic_request_exception_maps_to_client_error() -> None:
36
+ responses.add(responses.GET, f"{BASE_URL}/v1/dataset", body=requests.exceptions.RequestException("boom"))
37
+ with pytest.raises(JuniperRecurrenceClientError):
38
+ _client().get_dataset()
39
+
40
+
41
+ @responses.activate
42
+ def test_predict_full_aux_body() -> None:
43
+ responses.add(responses.POST, f"{BASE_URL}/v1/predict", json={"predictions": [], "shape": [0, 1]}, status=200)
44
+ _client().predict(X=[[[1.0]]], dt=[[0.0]], target_dt=[1.0], seq_lengths=[1])
45
+ sent = json.loads(responses.calls[0].request.body)
46
+ assert sent["target_dt"] == [1.0] and sent["seq_lengths"] == [1]
47
+
48
+
49
+ @responses.activate
50
+ def test_crossval_passes_hyperparams() -> None:
51
+ responses.add(responses.POST, f"{BASE_URL}/v1/crossval", json={"task_type": "regression", "n_folds": 2, "folds": [], "eval_aggregate": {}, "eval_std": {}, "dataset": {}}, status=200)
52
+ _client().crossval(generator="equities_seq", n_folds=2, d=8, theta=1.5, ridge=0.2)
53
+ sent = json.loads(responses.calls[0].request.body)
54
+ assert sent["dataset"]["generator"] == "equities_seq"
55
+ assert sent["d"] == 8 and sent["theta"] == 1.5 and sent["ridge"] == 0.2
56
+
57
+
58
+ @responses.activate
59
+ def test_wait_for_ready_polls_until_ready() -> None:
60
+ responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "starting"}, status=200)
61
+ responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "ready"}, status=200)
62
+ assert _client().wait_for_ready(timeout=2.0, poll_interval=0.01) is True
63
+
64
+
65
+ @responses.activate
66
+ def test_wait_for_ready_times_out() -> None:
67
+ responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "starting"}, status=200)
68
+ assert _client().wait_for_ready(timeout=0.03, poll_interval=0.01) is False
69
+
70
+
71
+ def test_api_key_file_read_error_falls_back_to_env(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
72
+ # _FILE points at a missing path -> OSError on read -> fall back to the plain env var.
73
+ monkeypatch.setenv(API_KEY_FILE_ENV_VAR, str(tmp_path / "does-not-exist"))
74
+ monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
75
+ client = JuniperRecurrenceClient(base_url=BASE_URL)
76
+ assert client.session.headers[API_KEY_HEADER_NAME] == "env-key"
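The two `wait_for_ready` tests in this file describe a poll-until-deadline loop that returns `True` on the first ready probe and `False` once the timeout elapses. A minimal sketch of that pattern is given below. The function name `wait_until_ready`, the injected `probe` callable, and the 30-second default timeout are assumptions for illustration; only the 0.5-second poll interval comes from `DEFAULT_READY_POLL_INTERVAL` in the constants module.

```python
import time
from typing import Callable


def wait_until_ready(probe: Callable[[], bool], timeout: float = 30.0, poll_interval: float = 0.5) -> bool:
    """Hypothetical polling loop: call probe() until it returns True or the timeout elapses."""
    # A monotonic clock is immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while True:
        if probe():
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))
```

In the client, the probe would be something like `client.is_ready`; passing it in as a callable keeps the loop testable without any HTTP mocking.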