juniper-recurrence-client 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- juniper_recurrence_client-0.1.0/PKG-INFO +104 -0
- juniper_recurrence_client-0.1.0/README.md +72 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client/__init__.py +34 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client/_version.py +7 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client/client.py +433 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client/constants.py +97 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client/exceptions.py +36 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client.egg-info/PKG-INFO +104 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client.egg-info/SOURCES.txt +16 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client.egg-info/dependency_links.txt +1 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client.egg-info/requires.txt +10 -0
- juniper_recurrence_client-0.1.0/juniper_recurrence_client.egg-info/top_level.txt +1 -0
- juniper_recurrence_client-0.1.0/pyproject.toml +75 -0
- juniper_recurrence_client-0.1.0/setup.cfg +4 -0
- juniper_recurrence_client-0.1.0/tests/test_auth.py +50 -0
- juniper_recurrence_client-0.1.0/tests/test_client.py +116 -0
- juniper_recurrence_client-0.1.0/tests/test_errors.py +63 -0
- juniper_recurrence_client-0.1.0/tests/test_extra.py +76 -0
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: juniper-recurrence-client
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: HTTP client for the juniper-recurrence service (the Δt-native LMU recurrence model + cross-validation API)
|
|
5
|
+
Author: Paul Calnon
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/pcalnon/juniper-recurrence
|
|
8
|
+
Project-URL: Repository, https://github.com/pcalnon/juniper-recurrence
|
|
9
|
+
Project-URL: Issues, https://github.com/pcalnon/juniper-recurrence/issues
|
|
10
|
+
Keywords: juniper,recurrence,lmu,http-client,rest,time-series,cross-validation
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: OS Independent
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
|
+
Requires-Python: >=3.12
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
Requires-Dist: requests>=2.28.0
|
|
25
|
+
Requires-Dist: urllib3>=2.0.0
|
|
26
|
+
Provides-Extra: test
|
|
27
|
+
Requires-Dist: pytest>=8.0; extra == "test"
|
|
28
|
+
Requires-Dist: pytest-cov>=5.0; extra == "test"
|
|
29
|
+
Requires-Dist: responses>=0.23; extra == "test"
|
|
30
|
+
Provides-Extra: observability
|
|
31
|
+
Requires-Dist: juniper-observability>=0.3.1; extra == "observability"
|
|
32
|
+
|
|
33
|
+
# juniper-recurrence-client
|
|
34
|
+
|
|
35
|
+
HTTP client for the **juniper-recurrence** service — the FastAPI app wrapping the Δt-native LMU
|
|
36
|
+
recurrence model and its cross-validation API. A lean `requests`-based client mirroring
|
|
37
|
+
[`juniper-data-client`](https://github.com/pcalnon/juniper-data-client) and
|
|
38
|
+
[`juniper-cascor-client`](https://github.com/pcalnon/juniper-cascor-client), so consumers
|
|
39
|
+
(notably **juniper-canopy**'s recurrence backend adapter) drive every Juniper backend the same way.
|
|
40
|
+
|
|
41
|
+
## Install
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install juniper-recurrence-client # once published
|
|
45
|
+
pip install -e ".[test]" # local development
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
`requests`-only at the core; `pip install juniper-recurrence-client[observability]` adds the
|
|
49
|
+
optional `juniper-observability` integration (X-Request-ID propagation + the `on_request` hook).
|
|
50
|
+
|
|
51
|
+
## Quick start
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
from juniper_recurrence_client import JuniperRecurrenceClient
|
|
55
|
+
|
|
56
|
+
client = JuniperRecurrenceClient("http://localhost:8211", api_key="…")
|
|
57
|
+
|
|
58
|
+
# Train the LMU regressor on a dataset (by id / name / generator)
|
|
59
|
+
client.train(name="equities", d=16)
|
|
60
|
+
|
|
61
|
+
# Predict — inline X with Δt, or a dataset reference
|
|
62
|
+
client.predict(dataset_id="ds-1")
|
|
63
|
+
|
|
64
|
+
# Walk-forward cross-validation over the dataset's _full split
|
|
65
|
+
result = client.crossval(name="equities", n_folds=4, scheme="expanding", embargo=2)
|
|
66
|
+
print(result["eval_aggregate"])
|
|
67
|
+
|
|
68
|
+
# Inspect
|
|
69
|
+
client.get_model() # topology + metrics
|
|
70
|
+
client.training_status() # state + events
|
|
71
|
+
client.is_ready() # readiness probe
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## API surface
|
|
75
|
+
|
|
76
|
+
| Method | Endpoint |
|
|
77
|
+
|--------|----------|
|
|
78
|
+
| `train(*, dataset_id / name / generator, params, split, d, theta, ridge)` | `POST /v1/train` |
|
|
79
|
+
| `training_status()` | `GET /v1/training/status` |
|
|
80
|
+
| `predict(*, X / dt / target_dt / seq_lengths, or a dataset ref)` | `POST /v1/predict` |
|
|
81
|
+
| `crossval(*, n_folds, scheme, embargo, min_train, dataset ref, d, theta, ridge)` | `POST /v1/crossval` |
|
|
82
|
+
| `crossval_status()` | `GET /v1/crossval/status` |
|
|
83
|
+
| `get_model()` | `GET /v1/model` |
|
|
84
|
+
| `get_dataset()` | `GET /v1/dataset` |
|
|
85
|
+
| `health_check()` / `is_ready()` / `wait_for_ready()` | `GET /v1/health[/ready]` |
|
|
86
|
+
|
|
87
|
+
## Authentication
|
|
88
|
+
|
|
89
|
+
Pass `api_key=…`, or set `JUNIPER_RECURRENCE_API_KEY` (or the Docker-secret
|
|
90
|
+
`JUNIPER_RECURRENCE_API_KEY_FILE`, a path whose stripped contents are the key). The key is sent
|
|
91
|
+
as the `X-API-Key` header. Note the asymmetry: the **server** reads the *plural*
|
|
92
|
+
`JUNIPER_RECURRENCE_API_KEYS` (its accepted set); the **client** sends one key under the
|
|
93
|
+
*singular* env var.
|
|
94
|
+
|
|
95
|
+
## Errors
|
|
96
|
+
|
|
97
|
+
All errors derive from `JuniperRecurrenceClientError`: `JuniperRecurrenceConnectionError`,
|
|
98
|
+
`JuniperRecurrenceTimeoutError`, `JuniperRecurrenceNotFoundError` (404),
|
|
99
|
+
`JuniperRecurrenceConflictError` (409 — a run already in progress, or no trained model yet),
|
|
100
|
+
`JuniperRecurrenceValidationError` (400/422), `JuniperRecurrenceConfigurationError`.
|
|
101
|
+
|
|
102
|
+
## License
|
|
103
|
+
|
|
104
|
+
MIT — see [LICENSE](https://github.com/pcalnon/juniper-recurrence/blob/main/LICENSE).
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# juniper-recurrence-client
|
|
2
|
+
|
|
3
|
+
HTTP client for the **juniper-recurrence** service — the FastAPI app wrapping the Δt-native LMU
|
|
4
|
+
recurrence model and its cross-validation API. A lean `requests`-based client mirroring
|
|
5
|
+
[`juniper-data-client`](https://github.com/pcalnon/juniper-data-client) and
|
|
6
|
+
[`juniper-cascor-client`](https://github.com/pcalnon/juniper-cascor-client), so consumers
|
|
7
|
+
(notably **juniper-canopy**'s recurrence backend adapter) drive every Juniper backend the same way.
|
|
8
|
+
|
|
9
|
+
## Install
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
pip install juniper-recurrence-client # once published
|
|
13
|
+
pip install -e ".[test]" # local development
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
`requests`-only at the core; `pip install juniper-recurrence-client[observability]` adds the
|
|
17
|
+
optional `juniper-observability` integration (X-Request-ID propagation + the `on_request` hook).
|
|
18
|
+
|
|
19
|
+
## Quick start
|
|
20
|
+
|
|
21
|
+
```python
|
|
22
|
+
from juniper_recurrence_client import JuniperRecurrenceClient
|
|
23
|
+
|
|
24
|
+
client = JuniperRecurrenceClient("http://localhost:8211", api_key="…")
|
|
25
|
+
|
|
26
|
+
# Train the LMU regressor on a dataset (by id / name / generator)
|
|
27
|
+
client.train(name="equities", d=16)
|
|
28
|
+
|
|
29
|
+
# Predict — inline X with Δt, or a dataset reference
|
|
30
|
+
client.predict(dataset_id="ds-1")
|
|
31
|
+
|
|
32
|
+
# Walk-forward cross-validation over the dataset's _full split
|
|
33
|
+
result = client.crossval(name="equities", n_folds=4, scheme="expanding", embargo=2)
|
|
34
|
+
print(result["eval_aggregate"])
|
|
35
|
+
|
|
36
|
+
# Inspect
|
|
37
|
+
client.get_model() # topology + metrics
|
|
38
|
+
client.training_status() # state + events
|
|
39
|
+
client.is_ready() # readiness probe
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
## API surface
|
|
43
|
+
|
|
44
|
+
| Method | Endpoint |
|
|
45
|
+
|--------|----------|
|
|
46
|
+
| `train(*, dataset_id / name / generator, params, split, d, theta, ridge)` | `POST /v1/train` |
|
|
47
|
+
| `training_status()` | `GET /v1/training/status` |
|
|
48
|
+
| `predict(*, X / dt / target_dt / seq_lengths, or a dataset ref)` | `POST /v1/predict` |
|
|
49
|
+
| `crossval(*, n_folds, scheme, embargo, min_train, dataset ref, d, theta, ridge)` | `POST /v1/crossval` |
|
|
50
|
+
| `crossval_status()` | `GET /v1/crossval/status` |
|
|
51
|
+
| `get_model()` | `GET /v1/model` |
|
|
52
|
+
| `get_dataset()` | `GET /v1/dataset` |
|
|
53
|
+
| `health_check()` / `is_ready()` / `wait_for_ready()` | `GET /v1/health[/ready]` |
|
|
54
|
+
|
|
55
|
+
## Authentication
|
|
56
|
+
|
|
57
|
+
Pass `api_key=…`, or set `JUNIPER_RECURRENCE_API_KEY` (or the Docker-secret
|
|
58
|
+
`JUNIPER_RECURRENCE_API_KEY_FILE`, a path whose stripped contents are the key). The key is sent
|
|
59
|
+
as the `X-API-Key` header. Note the asymmetry: the **server** reads the *plural*
|
|
60
|
+
`JUNIPER_RECURRENCE_API_KEYS` (its accepted set); the **client** sends one key under the
|
|
61
|
+
*singular* env var.
|
|
62
|
+
|
|
63
|
+
## Errors
|
|
64
|
+
|
|
65
|
+
All errors derive from `JuniperRecurrenceClientError`: `JuniperRecurrenceConnectionError`,
|
|
66
|
+
`JuniperRecurrenceTimeoutError`, `JuniperRecurrenceNotFoundError` (404),
|
|
67
|
+
`JuniperRecurrenceConflictError` (409 — a run already in progress, or no trained model yet),
|
|
68
|
+
`JuniperRecurrenceValidationError` (400/422), `JuniperRecurrenceConfigurationError`.
|
|
69
|
+
|
|
70
|
+
## License
|
|
71
|
+
|
|
72
|
+
MIT — see [LICENSE](https://github.com/pcalnon/juniper-recurrence/blob/main/LICENSE).
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""juniper-recurrence-client — HTTP client for the juniper-recurrence service.
|
|
2
|
+
|
|
3
|
+
A lean ``requests``-based client wrapping the juniper-recurrence FastAPI app's REST surface
|
|
4
|
+
(train / predict / cross-validate / inspect / health). Mirrors juniper-data-client and
|
|
5
|
+
juniper-cascor-client so consumers (notably juniper-canopy's recurrence backend adapter) drive
|
|
6
|
+
every Juniper backend the same way.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from juniper_recurrence_client._version import __version__
|
|
12
|
+
from juniper_recurrence_client.client import JuniperRecurrenceClient, RequestHook
|
|
13
|
+
from juniper_recurrence_client.exceptions import (
|
|
14
|
+
JuniperRecurrenceClientError,
|
|
15
|
+
JuniperRecurrenceConfigurationError,
|
|
16
|
+
JuniperRecurrenceConflictError,
|
|
17
|
+
JuniperRecurrenceConnectionError,
|
|
18
|
+
JuniperRecurrenceNotFoundError,
|
|
19
|
+
JuniperRecurrenceTimeoutError,
|
|
20
|
+
JuniperRecurrenceValidationError,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"__version__",
|
|
25
|
+
"JuniperRecurrenceClient",
|
|
26
|
+
"RequestHook",
|
|
27
|
+
"JuniperRecurrenceClientError",
|
|
28
|
+
"JuniperRecurrenceConnectionError",
|
|
29
|
+
"JuniperRecurrenceTimeoutError",
|
|
30
|
+
"JuniperRecurrenceNotFoundError",
|
|
31
|
+
"JuniperRecurrenceConflictError",
|
|
32
|
+
"JuniperRecurrenceValidationError",
|
|
33
|
+
"JuniperRecurrenceConfigurationError",
|
|
34
|
+
]
|
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
"""REST API client for the juniper-recurrence service.
|
|
2
|
+
|
|
3
|
+
A lean ``requests``-based client wrapping the juniper-recurrence FastAPI app's REST surface
|
|
4
|
+
(train / predict / cross-validate / inspect / health), for consumers such as juniper-canopy's
|
|
5
|
+
recurrence backend adapter. Mirrors juniper-data-client's transport machinery: an idempotent-only
|
|
6
|
+
retry policy, ``X-API-Key`` auth with ``_FILE`` Docker-secret indirection, typed exceptions, the
|
|
7
|
+
optional ``on_request`` instrumentation hook, and best-effort ``X-Request-ID`` propagation.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any, Callable, Optional
|
|
17
|
+
from urllib.parse import urlparse
|
|
18
|
+
|
|
19
|
+
import requests
|
|
20
|
+
from requests.adapters import HTTPAdapter
|
|
21
|
+
from urllib3.util.retry import Retry
|
|
22
|
+
|
|
23
|
+
from juniper_recurrence_client.constants import (
|
|
24
|
+
API_KEY_ENV_VAR,
|
|
25
|
+
API_KEY_FILE_ENV_VAR,
|
|
26
|
+
API_KEY_HEADER_NAME,
|
|
27
|
+
API_VERSION_PATH_SUFFIX,
|
|
28
|
+
DEFAULT_BACKOFF_FACTOR,
|
|
29
|
+
DEFAULT_BASE_URL,
|
|
30
|
+
DEFAULT_READY_POLL_INTERVAL,
|
|
31
|
+
DEFAULT_READY_TIMEOUT,
|
|
32
|
+
DEFAULT_RETRIES,
|
|
33
|
+
DEFAULT_TIMEOUT,
|
|
34
|
+
DEFAULT_URL_SCHEME_PREFIX,
|
|
35
|
+
ENDPOINT_CROSSVAL,
|
|
36
|
+
ENDPOINT_CROSSVAL_STATUS,
|
|
37
|
+
ENDPOINT_DATASET,
|
|
38
|
+
ENDPOINT_HEALTH,
|
|
39
|
+
ENDPOINT_HEALTH_READY,
|
|
40
|
+
ENDPOINT_MODEL,
|
|
41
|
+
ENDPOINT_PREDICT,
|
|
42
|
+
ENDPOINT_TRAIN,
|
|
43
|
+
ENDPOINT_TRAINING_STATUS,
|
|
44
|
+
HEALTH_READY_STATUS,
|
|
45
|
+
HTTP_400_BAD_REQUEST,
|
|
46
|
+
HTTP_404_NOT_FOUND,
|
|
47
|
+
HTTP_409_CONFLICT,
|
|
48
|
+
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
49
|
+
HTTP_POOL_CONNECTIONS,
|
|
50
|
+
HTTP_POOL_MAXSIZE,
|
|
51
|
+
RETRY_ALLOWED_METHODS,
|
|
52
|
+
RETRYABLE_STATUS_CODES,
|
|
53
|
+
URL_SCHEME_PREFIXES,
|
|
54
|
+
)
|
|
55
|
+
from juniper_recurrence_client.exceptions import (
|
|
56
|
+
JuniperRecurrenceClientError,
|
|
57
|
+
JuniperRecurrenceConflictError,
|
|
58
|
+
JuniperRecurrenceConnectionError,
|
|
59
|
+
JuniperRecurrenceNotFoundError,
|
|
60
|
+
JuniperRecurrenceTimeoutError,
|
|
61
|
+
JuniperRecurrenceValidationError,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
logger = logging.getLogger("juniper_recurrence_client.client")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _resolve_api_key_from_env() -> Optional[str]:
    """Resolve the juniper-recurrence API key from the environment.

    Honors the Docker-secret ``JUNIPER_RECURRENCE_API_KEY_FILE`` indirection (a file whose
    stripped contents are the key) before the plain ``JUNIPER_RECURRENCE_API_KEY`` env var, so a
    consumer that mounts the key as a file and leaves ``api_key`` unset still authenticates.

    The file lookup is best-effort: a path that cannot be read, cannot be decoded as UTF-8, or
    holds only whitespace is treated as "no key in the file" and resolution falls through to
    the plain env var.

    Returns:
        The resolved key, or ``None`` when neither source provides one.
    """
    file_path = os.environ.get(API_KEY_FILE_ENV_VAR)
    if file_path:
        try:
            content = Path(file_path).read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError):
            # Missing/unreadable file, or bytes that are not valid UTF-8: fall back to the
            # plain env var instead of failing client construction.
            content = ""
        if content:
            return content
    return os.environ.get(API_KEY_ENV_VAR)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
#: Optional instrumentation hook, invoked once per HTTP call with
#: ``(method, url, status, duration_ms, error)``:
#:
#: * ``method`` / ``url`` — the HTTP verb and the fully-qualified request URL.
#: * ``status`` — the response status code, or ``None`` when no response was received.
#: * ``duration_ms`` — elapsed time of the call in milliseconds (monotonic clock).
#: * ``error`` — the typed client exception being raised for this call, or ``None``.
#:
#: ``error is None`` is the canonical success signal (``status`` may be set even on the
#: typed-error paths). Mirrors juniper-data-client so canopy/cascor can pass the same
#: Prometheus/structured-log closure they already use.
RequestHook = Callable[[str, str, Optional[int], float, Optional[BaseException]], None]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _noop_request_hook(
|
|
93
|
+
method: str,
|
|
94
|
+
url: str,
|
|
95
|
+
status: Optional[int],
|
|
96
|
+
duration_ms: float,
|
|
97
|
+
error: Optional[BaseException],
|
|
98
|
+
) -> None:
|
|
99
|
+
"""Default :data:`RequestHook` — does nothing (named so the default is a real callable)."""
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _dataset_ref(
|
|
103
|
+
*,
|
|
104
|
+
dataset_id: Optional[str],
|
|
105
|
+
name: Optional[str],
|
|
106
|
+
generator: Optional[str],
|
|
107
|
+
params: Optional[dict[str, Any]],
|
|
108
|
+
split: str,
|
|
109
|
+
) -> dict[str, Any]:
|
|
110
|
+
"""Build the app's ``DatasetRef`` body from selection kwargs.
|
|
111
|
+
|
|
112
|
+
Exactly one of ``dataset_id`` / ``name`` / ``generator`` is expected; the server validates
|
|
113
|
+
that invariant and returns 422 otherwise.
|
|
114
|
+
"""
|
|
115
|
+
ref: dict[str, Any] = {"split": split}
|
|
116
|
+
if dataset_id is not None:
|
|
117
|
+
ref["dataset_id"] = dataset_id
|
|
118
|
+
if name is not None:
|
|
119
|
+
ref["name"] = name
|
|
120
|
+
if generator is not None:
|
|
121
|
+
ref["generator"] = generator
|
|
122
|
+
if params is not None:
|
|
123
|
+
ref["params"] = params
|
|
124
|
+
return ref
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class JuniperRecurrenceClient:
|
|
128
|
+
"""Client for the juniper-recurrence REST API (train / predict / cross-validate / inspect).
|
|
129
|
+
|
|
130
|
+
Automatic retry (idempotent methods only), connection pooling, and ``X-API-Key`` auth.
|
|
131
|
+
|
|
132
|
+
Example:
|
|
133
|
+
>>> client = JuniperRecurrenceClient("http://localhost:8211")
|
|
134
|
+
>>> client.train(name="equities", d=16)
|
|
135
|
+
>>> preds = client.predict(dataset_id="ds-1")
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
def __init__(
|
|
139
|
+
self,
|
|
140
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
141
|
+
timeout: int = DEFAULT_TIMEOUT,
|
|
142
|
+
retries: int = DEFAULT_RETRIES,
|
|
143
|
+
backoff_factor: float = DEFAULT_BACKOFF_FACTOR,
|
|
144
|
+
api_key: Optional[str] = None,
|
|
145
|
+
on_request: Optional[RequestHook] = None,
|
|
146
|
+
) -> None:
|
|
147
|
+
"""Initialize the client.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
base_url: Base URL of the juniper-recurrence app (default ``http://localhost:8211``).
|
|
151
|
+
timeout: Per-request timeout in seconds.
|
|
152
|
+
retries: Retry attempts for transient failures on idempotent methods.
|
|
153
|
+
backoff_factor: Exponential backoff factor for retries.
|
|
154
|
+
api_key: API key for ``X-API-Key`` auth. If unset, resolved from
|
|
155
|
+
``JUNIPER_RECURRENCE_API_KEY`` (and its ``_FILE`` form).
|
|
156
|
+
on_request: Optional instrumentation hook (see :data:`RequestHook`); defaults to a
|
|
157
|
+
no-op. Hook exceptions are caught and logged so instrumentation never crashes a
|
|
158
|
+
request path.
|
|
159
|
+
"""
|
|
160
|
+
self.base_url = self._normalize_url(base_url)
|
|
161
|
+
self.timeout = timeout
|
|
162
|
+
self.retries = retries
|
|
163
|
+
self.backoff_factor = backoff_factor
|
|
164
|
+
self.session = self._create_session()
|
|
165
|
+
self._on_request: RequestHook = on_request or _noop_request_hook
|
|
166
|
+
|
|
167
|
+
resolved_api_key = api_key or _resolve_api_key_from_env()
|
|
168
|
+
if resolved_api_key:
|
|
169
|
+
self.session.headers[API_KEY_HEADER_NAME] = resolved_api_key
|
|
170
|
+
|
|
171
|
+
def _normalize_url(self, url: str) -> str:
|
|
172
|
+
"""Normalize the base URL: ensure a scheme, drop a trailing slash and any ``/v1`` suffix."""
|
|
173
|
+
url = url.strip()
|
|
174
|
+
if not url.startswith(URL_SCHEME_PREFIXES):
|
|
175
|
+
url = f"{DEFAULT_URL_SCHEME_PREFIX}{url}"
|
|
176
|
+
parsed = urlparse(url)
|
|
177
|
+
normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}".rstrip("/")
|
|
178
|
+
if normalized.endswith(API_VERSION_PATH_SUFFIX):
|
|
179
|
+
normalized = normalized[: -len(API_VERSION_PATH_SUFFIX)]
|
|
180
|
+
return normalized
|
|
181
|
+
|
|
182
|
+
def _create_session(self) -> requests.Session:
|
|
183
|
+
"""Create a ``requests.Session`` with the idempotent-only retry policy + pooling."""
|
|
184
|
+
session = requests.Session()
|
|
185
|
+
retry_strategy = Retry(
|
|
186
|
+
total=self.retries,
|
|
187
|
+
backoff_factor=self.backoff_factor,
|
|
188
|
+
status_forcelist=RETRYABLE_STATUS_CODES,
|
|
189
|
+
allowed_methods=RETRY_ALLOWED_METHODS,
|
|
190
|
+
)
|
|
191
|
+
adapter = HTTPAdapter(
|
|
192
|
+
max_retries=retry_strategy,
|
|
193
|
+
pool_connections=HTTP_POOL_CONNECTIONS,
|
|
194
|
+
pool_maxsize=HTTP_POOL_MAXSIZE,
|
|
195
|
+
)
|
|
196
|
+
for scheme in URL_SCHEME_PREFIXES:
|
|
197
|
+
session.mount(scheme, adapter)
|
|
198
|
+
return session
|
|
199
|
+
|
|
200
|
+
def _request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response: # noqa: C901
|
|
201
|
+
"""Make an HTTP request, mapping transport/HTTP errors to typed exceptions.
|
|
202
|
+
|
|
203
|
+
Raises:
|
|
204
|
+
JuniperRecurrenceConnectionError / JuniperRecurrenceTimeoutError: transport failures.
|
|
205
|
+
JuniperRecurrenceNotFoundError (404), JuniperRecurrenceConflictError (409),
|
|
206
|
+
JuniperRecurrenceValidationError (400/422), or JuniperRecurrenceClientError (other).
|
|
207
|
+
"""
|
|
208
|
+
url = f"{self.base_url}{endpoint}"
|
|
209
|
+
kwargs.setdefault("timeout", self.timeout)
|
|
210
|
+
|
|
211
|
+
# Best-effort X-Request-ID propagation via juniper-observability (no-op if absent or
|
|
212
|
+
# unset); a caller-supplied X-Request-ID always wins.
|
|
213
|
+
headers = dict(kwargs.get("headers") or {})
|
|
214
|
+
if "X-Request-ID" not in headers:
|
|
215
|
+
try:
|
|
216
|
+
from juniper_observability import request_id_var # noqa: PLC0415
|
|
217
|
+
|
|
218
|
+
rid = request_id_var.get()
|
|
219
|
+
if rid:
|
|
220
|
+
headers["X-Request-ID"] = rid
|
|
221
|
+
kwargs["headers"] = headers
|
|
222
|
+
except (ImportError, LookupError):
|
|
223
|
+
pass
|
|
224
|
+
|
|
225
|
+
start = time.monotonic()
|
|
226
|
+
response: Optional[requests.Response] = None
|
|
227
|
+
outgoing_error: Optional[BaseException] = None
|
|
228
|
+
try:
|
|
229
|
+
try:
|
|
230
|
+
response = self.session.request(method, url, **kwargs)
|
|
231
|
+
except requests.exceptions.ConnectionError as e:
|
|
232
|
+
outgoing_error = JuniperRecurrenceConnectionError(f"Failed to connect to juniper-recurrence at {self.base_url}: {e}")
|
|
233
|
+
raise outgoing_error from e
|
|
234
|
+
except requests.exceptions.Timeout as e:
|
|
235
|
+
outgoing_error = JuniperRecurrenceTimeoutError(f"Request to {url} timed out after {self.timeout}s: {e}")
|
|
236
|
+
raise outgoing_error from e
|
|
237
|
+
except requests.exceptions.RequestException as e:
|
|
238
|
+
outgoing_error = JuniperRecurrenceClientError(f"Request failed: {e}")
|
|
239
|
+
raise outgoing_error from e
|
|
240
|
+
|
|
241
|
+
if response.ok:
|
|
242
|
+
return response
|
|
243
|
+
|
|
244
|
+
error_detail = response.text
|
|
245
|
+
try:
|
|
246
|
+
error_json = response.json()
|
|
247
|
+
if "detail" in error_json:
|
|
248
|
+
error_detail = error_json["detail"]
|
|
249
|
+
except (ValueError, KeyError):
|
|
250
|
+
error_detail = response.text
|
|
251
|
+
|
|
252
|
+
if response.status_code == HTTP_404_NOT_FOUND:
|
|
253
|
+
outgoing_error = JuniperRecurrenceNotFoundError(f"Resource not found: {error_detail}")
|
|
254
|
+
raise outgoing_error
|
|
255
|
+
elif response.status_code == HTTP_409_CONFLICT:
|
|
256
|
+
outgoing_error = JuniperRecurrenceConflictError(f"Conflict: {error_detail}")
|
|
257
|
+
raise outgoing_error
|
|
258
|
+
elif response.status_code in (HTTP_400_BAD_REQUEST, HTTP_422_UNPROCESSABLE_ENTITY):
|
|
259
|
+
outgoing_error = JuniperRecurrenceValidationError(f"Validation error: {error_detail}")
|
|
260
|
+
raise outgoing_error
|
|
261
|
+
else:
|
|
262
|
+
outgoing_error = JuniperRecurrenceClientError(f"Request failed ({response.status_code}): {error_detail}")
|
|
263
|
+
raise outgoing_error
|
|
264
|
+
finally:
|
|
265
|
+
duration_ms = (time.monotonic() - start) * 1000.0
|
|
266
|
+
status = response.status_code if response is not None else None
|
|
267
|
+
try:
|
|
268
|
+
self._on_request(method, url, status, duration_ms, outgoing_error)
|
|
269
|
+
except Exception: # noqa: BLE001 — instrumentation must not crash production paths
|
|
270
|
+
logger.warning("on_request hook raised; suppressed to keep request path resilient", exc_info=True)
|
|
271
|
+
|
|
272
|
+
@staticmethod
|
|
273
|
+
def _parse_json(response: requests.Response) -> Any:
|
|
274
|
+
"""Parse a response body as JSON, surfacing a typed error on a malformed body."""
|
|
275
|
+
try:
|
|
276
|
+
return response.json()
|
|
277
|
+
except ValueError as e:
|
|
278
|
+
preview = (response.text or "")[:200]
|
|
279
|
+
raise JuniperRecurrenceClientError(f"Malformed JSON response from {response.url}: {e}: {preview!r}") from e
|
|
280
|
+
|
|
281
|
+
# ─── Training ─────────────────────────────────────────────────────────────
|
|
282
|
+
|
|
283
|
+
def train(
|
|
284
|
+
self,
|
|
285
|
+
*,
|
|
286
|
+
dataset_id: Optional[str] = None,
|
|
287
|
+
name: Optional[str] = None,
|
|
288
|
+
generator: Optional[str] = None,
|
|
289
|
+
params: Optional[dict[str, Any]] = None,
|
|
290
|
+
split: str = "train",
|
|
291
|
+
d: Optional[int] = None,
|
|
292
|
+
theta: Optional[float] = None,
|
|
293
|
+
ridge: Optional[float] = None,
|
|
294
|
+
) -> dict[str, Any]:
|
|
295
|
+
"""``POST /v1/train`` — synchronously fit the LMU regressor on a dataset split.
|
|
296
|
+
|
|
297
|
+
Supply exactly one of ``dataset_id`` / ``name`` / ``generator``. Returns the
|
|
298
|
+
``TrainResponse`` (``final_metrics``, ``n_epochs``, ``stopped_reason``, ``dataset``).
|
|
299
|
+
Raises :class:`JuniperRecurrenceConflictError` (409) if a run is already in progress.
|
|
300
|
+
"""
|
|
301
|
+
body: dict[str, Any] = {"dataset": _dataset_ref(dataset_id=dataset_id, name=name, generator=generator, params=params, split=split)}
|
|
302
|
+
if d is not None:
|
|
303
|
+
body["d"] = d
|
|
304
|
+
if theta is not None:
|
|
305
|
+
body["theta"] = theta
|
|
306
|
+
if ridge is not None:
|
|
307
|
+
body["ridge"] = ridge
|
|
308
|
+
return self._parse_json(self._request("POST", ENDPOINT_TRAIN, json=body))
|
|
309
|
+
|
|
310
|
+
def training_status(self) -> dict[str, Any]:
|
|
311
|
+
"""``GET /v1/training/status`` — current training state, metrics, and emitted events."""
|
|
312
|
+
return self._parse_json(self._request("GET", ENDPOINT_TRAINING_STATUS))
|
|
313
|
+
|
|
314
|
+
# ─── Prediction ───────────────────────────────────────────────────────────
|
|
315
|
+
|
|
316
|
+
def predict(
|
|
317
|
+
self,
|
|
318
|
+
*,
|
|
319
|
+
X: Optional[Any] = None,
|
|
320
|
+
dt: Optional[Any] = None,
|
|
321
|
+
target_dt: Optional[Any] = None,
|
|
322
|
+
seq_lengths: Optional[Any] = None,
|
|
323
|
+
dataset_id: Optional[str] = None,
|
|
324
|
+
name: Optional[str] = None,
|
|
325
|
+
generator: Optional[str] = None,
|
|
326
|
+
params: Optional[dict[str, Any]] = None,
|
|
327
|
+
split: str = "train",
|
|
328
|
+
) -> dict[str, Any]:
|
|
329
|
+
"""``POST /v1/predict`` — predictions from the trained model.
|
|
330
|
+
|
|
331
|
+
Supply exactly one of inline ``X`` (optionally with ``dt`` / ``target_dt`` /
|
|
332
|
+
``seq_lengths``) or a dataset reference. Returns ``{"predictions": ..., "shape": ...}``.
|
|
333
|
+
"""
|
|
334
|
+
body: dict[str, Any] = {}
|
|
335
|
+
if X is not None:
|
|
336
|
+
body["X"] = X
|
|
337
|
+
if dt is not None:
|
|
338
|
+
body["dt"] = dt
|
|
339
|
+
if target_dt is not None:
|
|
340
|
+
body["target_dt"] = target_dt
|
|
341
|
+
if seq_lengths is not None:
|
|
342
|
+
body["seq_lengths"] = seq_lengths
|
|
343
|
+
if dataset_id is not None or name is not None or generator is not None:
|
|
344
|
+
body["dataset"] = _dataset_ref(dataset_id=dataset_id, name=name, generator=generator, params=params, split=split)
|
|
345
|
+
return self._parse_json(self._request("POST", ENDPOINT_PREDICT, json=body))
|
|
346
|
+
|
|
347
|
+
# ─── Cross-validation ──────────────────────────────────────────────────────
|
|
348
|
+
|
|
349
|
+
def crossval(
|
|
350
|
+
self,
|
|
351
|
+
*,
|
|
352
|
+
n_folds: int,
|
|
353
|
+
dataset_id: Optional[str] = None,
|
|
354
|
+
name: Optional[str] = None,
|
|
355
|
+
generator: Optional[str] = None,
|
|
356
|
+
params: Optional[dict[str, Any]] = None,
|
|
357
|
+
split: str = "train",
|
|
358
|
+
scheme: str = "expanding",
|
|
359
|
+
embargo: int = 0,
|
|
360
|
+
min_train: Optional[int] = None,
|
|
361
|
+
d: Optional[int] = None,
|
|
362
|
+
theta: Optional[float] = None,
|
|
363
|
+
ridge: Optional[float] = None,
|
|
364
|
+
) -> dict[str, Any]:
|
|
365
|
+
"""``POST /v1/crossval`` — synchronous walk-forward cross-validation over the ``_full`` split.
|
|
366
|
+
|
|
367
|
+
Returns the ``CrossValResponse`` (per-fold ``folds`` + ``eval_aggregate`` / ``eval_std``).
|
|
368
|
+
``scheme`` is ``"expanding"`` or ``"rolling"``. Raises 409 if a CV run is already running.
|
|
369
|
+
"""
|
|
370
|
+
body: dict[str, Any] = {
|
|
371
|
+
"dataset": _dataset_ref(dataset_id=dataset_id, name=name, generator=generator, params=params, split=split),
|
|
372
|
+
"n_folds": n_folds,
|
|
373
|
+
"scheme": scheme,
|
|
374
|
+
"embargo": embargo,
|
|
375
|
+
}
|
|
376
|
+
if min_train is not None:
|
|
377
|
+
body["min_train"] = min_train
|
|
378
|
+
if d is not None:
|
|
379
|
+
body["d"] = d
|
|
380
|
+
if theta is not None:
|
|
381
|
+
body["theta"] = theta
|
|
382
|
+
if ridge is not None:
|
|
383
|
+
body["ridge"] = ridge
|
|
384
|
+
return self._parse_json(self._request("POST", ENDPOINT_CROSSVAL, json=body))
|
|
385
|
+
|
|
386
|
+
def crossval_status(self) -> dict[str, Any]:
|
|
387
|
+
"""``GET /v1/crossval/status`` — the most recent cross-validation result, if any."""
|
|
388
|
+
return self._parse_json(self._request("GET", ENDPOINT_CROSSVAL_STATUS))
|
|
389
|
+
|
|
390
|
+
# ─── Inspection ────────────────────────────────────────────────────────────
|
|
391
|
+
|
|
392
|
+
def get_model(self) -> dict[str, Any]:
|
|
393
|
+
"""``GET /v1/model`` — the trained model's topology + metrics (409 if none trained)."""
|
|
394
|
+
return self._parse_json(self._request("GET", ENDPOINT_MODEL))
|
|
395
|
+
|
|
396
|
+
def get_dataset(self) -> dict[str, Any]:
|
|
397
|
+
"""``GET /v1/dataset`` — descriptor of the split the model was trained on (409 if none)."""
|
|
398
|
+
return self._parse_json(self._request("GET", ENDPOINT_DATASET))
|
|
399
|
+
|
|
400
|
+
# ─── Health / Readiness ────────────────────────────────────────────────────
|
|
401
|
+
|
|
402
|
+
def health_check(self) -> dict[str, Any]:
|
|
403
|
+
"""``GET /v1/health`` — liveness."""
|
|
404
|
+
return self._parse_json(self._request("GET", ENDPOINT_HEALTH))
|
|
405
|
+
|
|
406
|
+
def is_ready(self) -> bool:
|
|
407
|
+
"""``GET /v1/health/ready`` — ``True`` iff the service reports ready."""
|
|
408
|
+
try:
|
|
409
|
+
payload = self._parse_json(self._request("GET", ENDPOINT_HEALTH_READY))
|
|
410
|
+
except JuniperRecurrenceClientError:
|
|
411
|
+
return False
|
|
412
|
+
return payload.get("status") == HEALTH_READY_STATUS
|
|
413
|
+
|
|
414
|
+
def wait_for_ready(self, timeout: float = DEFAULT_READY_TIMEOUT, poll_interval: float = DEFAULT_READY_POLL_INTERVAL) -> bool:
|
|
415
|
+
"""Poll ``/v1/health/ready`` until ready or ``timeout`` seconds elapse."""
|
|
416
|
+
deadline = time.monotonic() + timeout
|
|
417
|
+
while time.monotonic() < deadline:
|
|
418
|
+
if self.is_ready():
|
|
419
|
+
return True
|
|
420
|
+
time.sleep(poll_interval)
|
|
421
|
+
return self.is_ready()
|
|
422
|
+
|
|
423
|
+
# ─── Lifecycle ─────────────────────────────────────────────────────────────
|
|
424
|
+
|
|
425
|
+
def close(self) -> None:
    """Release the underlying HTTP session by closing it."""
    http_session = self.session
    http_session.close()
|
|
428
|
+
|
|
429
|
+
def __enter__(self) -> "JuniperRecurrenceClient":
|
|
430
|
+
return self
|
|
431
|
+
|
|
432
|
+
def __exit__(self, *exc: object) -> None:
|
|
433
|
+
self.close()
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Protocol-level constants for the juniper-recurrence REST client.
|
|
2
|
+
|
|
3
|
+
Centralizes the literals used by ``client.py`` — base URL, endpoint paths, header names,
|
|
4
|
+
HTTP/retry configuration — mirroring juniper-data-client's constants module so the wire
|
|
5
|
+
contract is discoverable in one place.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import List, Tuple
|
|
9
|
+
|
|
10
|
+
# Public names of this module. Every constant defined below is listed here, including
# the HTTP status codes, so ``from ...constants import *`` exposes the whole wire contract.
__all__ = [
    "DEFAULT_BASE_URL",
    "DEFAULT_TIMEOUT",
    "DEFAULT_RETRIES",
    "DEFAULT_BACKOFF_FACTOR",
    "RETRYABLE_STATUS_CODES",
    "RETRY_ALLOWED_METHODS",
    "HTTP_POOL_CONNECTIONS",
    "HTTP_POOL_MAXSIZE",
    "URL_SCHEME_PREFIXES",
    "DEFAULT_URL_SCHEME_PREFIX",
    "API_VERSION_PATH_SUFFIX",
    "DEFAULT_READY_TIMEOUT",
    "DEFAULT_READY_POLL_INTERVAL",
    "HEALTH_READY_STATUS",
    "API_KEY_HEADER_NAME",
    "API_KEY_ENV_VAR",
    "API_KEY_FILE_ENV_VAR",
    "ENDPOINT_HEALTH",
    "ENDPOINT_HEALTH_READY",
    "ENDPOINT_TRAIN",
    "ENDPOINT_TRAINING_STATUS",
    "ENDPOINT_PREDICT",
    "ENDPOINT_MODEL",
    "ENDPOINT_DATASET",
    "ENDPOINT_CROSSVAL",
    "ENDPOINT_CROSSVAL_STATUS",
    "HTTP_400_BAD_REQUEST",
    "HTTP_404_NOT_FOUND",
    "HTTP_409_CONFLICT",
    "HTTP_422_UNPROCESSABLE_ENTITY",
]

# ─── Service Configuration ───────────────────────────────────────────────────

# The juniper-recurrence app binds container port 8210; juniper-deploy maps host 8211 -> 8210,
# so the default host-facing base URL is 8211 (mirrors the deploy port map).
DEFAULT_BASE_URL: str = "http://localhost:8211"

# ─── HTTP Configuration ──────────────────────────────────────────────────────

DEFAULT_TIMEOUT: int = 30
DEFAULT_RETRIES: int = 3
DEFAULT_BACKOFF_FACTOR: float = 0.5
RETRYABLE_STATUS_CODES: List[int] = [429, 500, 502, 503, 504]
# Auto-retry is restricted to idempotent methods (RFC 9110 §9.2.2). The recurrence POSTs
# (train / predict / crossval) carry server-side state — train and crossval are lock-guarded —
# so a transient-5xx retry must not silently re-issue them. Only GET/HEAD auto-retry.
RETRY_ALLOWED_METHODS: List[str] = ["HEAD", "GET"]
HTTP_POOL_CONNECTIONS: int = 10
HTTP_POOL_MAXSIZE: int = 10

# ─── URL Normalization ───────────────────────────────────────────────────────

URL_SCHEME_PREFIXES: Tuple[str, ...] = ("http://", "https://")
DEFAULT_URL_SCHEME_PREFIX: str = "http://"
API_VERSION_PATH_SUFFIX: str = "/v1"

# ─── Readiness Polling ───────────────────────────────────────────────────────

DEFAULT_READY_TIMEOUT: float = 30.0
DEFAULT_READY_POLL_INTERVAL: float = 0.5
HEALTH_READY_STATUS: str = "ready"

# ─── Authentication ──────────────────────────────────────────────────────────

# The recurrence app enforces the X-API-Key header, reading its accepted keys from
# JUNIPER_RECURRENCE_API_KEYS (plural; CSV or JSON array, with _FILE indirection). The client
# sends a single key, resolved from the singular JUNIPER_RECURRENCE_API_KEY (and its _FILE
# Docker-secret form) — document the singular/plural asymmetry in AGENTS.md.
API_KEY_HEADER_NAME: str = "X-API-Key"
API_KEY_ENV_VAR: str = "JUNIPER_RECURRENCE_API_KEY"
API_KEY_FILE_ENV_VAR: str = f"{API_KEY_ENV_VAR}_FILE"

# ─── REST Endpoints (the juniper-recurrence app surface) ─────────────────────

ENDPOINT_HEALTH: str = "/v1/health"
ENDPOINT_HEALTH_READY: str = "/v1/health/ready"
ENDPOINT_TRAIN: str = "/v1/train"
ENDPOINT_TRAINING_STATUS: str = "/v1/training/status"
ENDPOINT_PREDICT: str = "/v1/predict"
ENDPOINT_MODEL: str = "/v1/model"
ENDPOINT_DATASET: str = "/v1/dataset"
ENDPOINT_CROSSVAL: str = "/v1/crossval"
ENDPOINT_CROSSVAL_STATUS: str = "/v1/crossval/status"

# ─── HTTP Status Codes ───────────────────────────────────────────────────────

HTTP_400_BAD_REQUEST: int = 400
HTTP_404_NOT_FOUND: int = 404
HTTP_409_CONFLICT: int = 409
HTTP_422_UNPROCESSABLE_ENTITY: int = 422
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Custom exceptions for the juniper-recurrence client library.
|
|
2
|
+
|
|
3
|
+
Mirrors juniper-data-client's flat hierarchy (one base + typed leaves), adding a
|
|
4
|
+
``JuniperRecurrenceConflictError`` for the recurrence app's ``409`` responses (a training /
|
|
5
|
+
cross-validation run already in progress, or an operation that needs a trained model that does
|
|
6
|
+
not yet exist) — a status the data-client surface never returns.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class JuniperRecurrenceClientError(Exception):
    """Root of the juniper-recurrence client exception hierarchy; every client error derives from it."""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class JuniperRecurrenceConnectionError(JuniperRecurrenceClientError):
    """Signals that a connection to the juniper-recurrence service could not be established."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class JuniperRecurrenceTimeoutError(JuniperRecurrenceClientError):
    """Signals that a request to the juniper-recurrence service timed out."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class JuniperRecurrenceNotFoundError(JuniperRecurrenceClientError):
    """Signals a 404 response: the requested resource does not exist."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class JuniperRecurrenceValidationError(JuniperRecurrenceClientError):
    """Signals a 400 / 422 response: the request parameters failed validation."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class JuniperRecurrenceConflictError(JuniperRecurrenceClientError):
    """Signals a 409 Conflict response.

    Either a training / cross-validation run is already in progress, or the operation
    needs a trained model/dataset that does not exist yet.
    """
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class JuniperRecurrenceConfigurationError(JuniperRecurrenceClientError):
    """Signals that the juniper-recurrence client configuration is missing or invalid."""
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: juniper-recurrence-client
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: HTTP client for the juniper-recurrence service (the Δt-native LMU recurrence model + cross-validation API)
|
|
5
|
+
Author: Paul Calnon
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/pcalnon/juniper-recurrence
|
|
8
|
+
Project-URL: Repository, https://github.com/pcalnon/juniper-recurrence
|
|
9
|
+
Project-URL: Issues, https://github.com/pcalnon/juniper-recurrence/issues
|
|
10
|
+
Keywords: juniper,recurrence,lmu,http-client,rest,time-series,cross-validation
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: OS Independent
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
|
+
Requires-Python: >=3.12
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
Requires-Dist: requests>=2.28.0
|
|
25
|
+
Requires-Dist: urllib3>=2.0.0
|
|
26
|
+
Provides-Extra: test
|
|
27
|
+
Requires-Dist: pytest>=8.0; extra == "test"
|
|
28
|
+
Requires-Dist: pytest-cov>=5.0; extra == "test"
|
|
29
|
+
Requires-Dist: responses>=0.23; extra == "test"
|
|
30
|
+
Provides-Extra: observability
|
|
31
|
+
Requires-Dist: juniper-observability>=0.3.1; extra == "observability"
|
|
32
|
+
|
|
33
|
+
# juniper-recurrence-client
|
|
34
|
+
|
|
35
|
+
HTTP client for the **juniper-recurrence** service — the FastAPI app wrapping the Δt-native LMU
|
|
36
|
+
recurrence model and its cross-validation API. A lean `requests`-based client mirroring
|
|
37
|
+
[`juniper-data-client`](https://github.com/pcalnon/juniper-data-client) and
|
|
38
|
+
[`juniper-cascor-client`](https://github.com/pcalnon/juniper-cascor-client), so consumers
|
|
39
|
+
(notably **juniper-canopy**'s recurrence backend adapter) drive every Juniper backend the same way.
|
|
40
|
+
|
|
41
|
+
## Install
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install juniper-recurrence-client # once published
|
|
45
|
+
pip install -e ".[test]" # local development
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
`requests`-only at the core; `pip install juniper-recurrence-client[observability]` adds the
|
|
49
|
+
optional `juniper-observability` integration (X-Request-ID propagation + the `on_request` hook).
|
|
50
|
+
|
|
51
|
+
## Quick start
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
from juniper_recurrence_client import JuniperRecurrenceClient
|
|
55
|
+
|
|
56
|
+
client = JuniperRecurrenceClient("http://localhost:8211", api_key="…")
|
|
57
|
+
|
|
58
|
+
# Train the LMU regressor on a dataset (by id / name / generator)
|
|
59
|
+
client.train(name="equities", d=16)
|
|
60
|
+
|
|
61
|
+
# Predict — inline X with Δt, or a dataset reference
|
|
62
|
+
client.predict(dataset_id="ds-1")
|
|
63
|
+
|
|
64
|
+
# Walk-forward cross-validation over the dataset's _full split
|
|
65
|
+
result = client.crossval(name="equities", n_folds=4, scheme="expanding", embargo=2)
|
|
66
|
+
print(result["eval_aggregate"])
|
|
67
|
+
|
|
68
|
+
# Inspect
|
|
69
|
+
client.get_model() # topology + metrics
|
|
70
|
+
client.training_status() # state + events
|
|
71
|
+
client.is_ready() # readiness probe
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## API surface
|
|
75
|
+
|
|
76
|
+
| Method | Endpoint |
|
|
77
|
+
|--------|----------|
|
|
78
|
+
| `train(*, dataset_id / name / generator, params, split, d, theta, ridge)` | `POST /v1/train` |
|
|
79
|
+
| `training_status()` | `GET /v1/training/status` |
|
|
80
|
+
| `predict(*, X / dt / target_dt / seq_lengths, or a dataset ref)` | `POST /v1/predict` |
|
|
81
|
+
| `crossval(*, n_folds, scheme, embargo, min_train, dataset ref, d, theta, ridge)` | `POST /v1/crossval` |
|
|
82
|
+
| `crossval_status()` | `GET /v1/crossval/status` |
|
|
83
|
+
| `get_model()` | `GET /v1/model` |
|
|
84
|
+
| `get_dataset()` | `GET /v1/dataset` |
|
|
85
|
+
| `health_check()` / `is_ready()` / `wait_for_ready()` | `GET /v1/health[/ready]` |
|
|
86
|
+
|
|
87
|
+
## Authentication
|
|
88
|
+
|
|
89
|
+
Pass `api_key=…`, or set `JUNIPER_RECURRENCE_API_KEY` (or the Docker-secret
|
|
90
|
+
`JUNIPER_RECURRENCE_API_KEY_FILE`, a path whose stripped contents are the key). The key is sent
|
|
91
|
+
as the `X-API-Key` header. Note the asymmetry: the **server** reads the *plural*
|
|
92
|
+
`JUNIPER_RECURRENCE_API_KEYS` (its accepted set); the **client** sends one key under the
|
|
93
|
+
*singular* env var.
|
|
94
|
+
|
|
95
|
+
## Errors
|
|
96
|
+
|
|
97
|
+
All errors derive from `JuniperRecurrenceClientError`: `JuniperRecurrenceConnectionError`,
|
|
98
|
+
`JuniperRecurrenceTimeoutError`, `JuniperRecurrenceNotFoundError` (404),
|
|
99
|
+
`JuniperRecurrenceConflictError` (409 — a run already in progress, or no trained model yet),
|
|
100
|
+
`JuniperRecurrenceValidationError` (400/422), `JuniperRecurrenceConfigurationError`.
|
|
101
|
+
|
|
102
|
+
## License
|
|
103
|
+
|
|
104
|
+
MIT — see [LICENSE](https://github.com/pcalnon/juniper-recurrence/blob/main/LICENSE).
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
juniper_recurrence_client/__init__.py
|
|
4
|
+
juniper_recurrence_client/_version.py
|
|
5
|
+
juniper_recurrence_client/client.py
|
|
6
|
+
juniper_recurrence_client/constants.py
|
|
7
|
+
juniper_recurrence_client/exceptions.py
|
|
8
|
+
juniper_recurrence_client.egg-info/PKG-INFO
|
|
9
|
+
juniper_recurrence_client.egg-info/SOURCES.txt
|
|
10
|
+
juniper_recurrence_client.egg-info/dependency_links.txt
|
|
11
|
+
juniper_recurrence_client.egg-info/requires.txt
|
|
12
|
+
juniper_recurrence_client.egg-info/top_level.txt
|
|
13
|
+
tests/test_auth.py
|
|
14
|
+
tests/test_client.py
|
|
15
|
+
tests/test_errors.py
|
|
16
|
+
tests/test_extra.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
juniper_recurrence_client
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "juniper-recurrence-client"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "HTTP client for the juniper-recurrence service (the Δt-native LMU recurrence model + cross-validation API)"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [{ name = "Paul Calnon" }]
|
|
13
|
+
keywords = ["juniper", "recurrence", "lmu", "http-client", "rest", "time-series", "cross-validation"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 3 - Alpha",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"License :: OSI Approved :: MIT License",
|
|
19
|
+
"Operating System :: OS Independent",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.12",
|
|
22
|
+
"Programming Language :: Python :: 3.13",
|
|
23
|
+
"Programming Language :: Python :: 3.14",
|
|
24
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
25
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
26
|
+
]
|
|
27
|
+
dependencies = [
|
|
28
|
+
"requests>=2.28.0",
|
|
29
|
+
"urllib3>=2.0.0",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
[project.optional-dependencies]
|
|
33
|
+
test = [
|
|
34
|
+
"pytest>=8.0",
|
|
35
|
+
"pytest-cov>=5.0",
|
|
36
|
+
"responses>=0.23",
|
|
37
|
+
]
|
|
38
|
+
# Opt-in: the X-Request-ID propagation + the on_request hook integrate with
|
|
39
|
+
# juniper-observability, but the client never requires it (the import is guarded).
|
|
40
|
+
observability = [
|
|
41
|
+
"juniper-observability>=0.3.1",
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
[project.urls]
|
|
45
|
+
Homepage = "https://github.com/pcalnon/juniper-recurrence"
|
|
46
|
+
Repository = "https://github.com/pcalnon/juniper-recurrence"
|
|
47
|
+
Issues = "https://github.com/pcalnon/juniper-recurrence/issues"
|
|
48
|
+
|
|
49
|
+
[tool.setuptools.dynamic]
|
|
50
|
+
version = { attr = "juniper_recurrence_client._version.__version__" }
|
|
51
|
+
|
|
52
|
+
[tool.setuptools.packages.find]
|
|
53
|
+
include = ["juniper_recurrence_client*"]
|
|
54
|
+
exclude = ["tests*"]
|
|
55
|
+
|
|
56
|
+
[tool.pytest.ini_options]
|
|
57
|
+
minversion = "8.0"
|
|
58
|
+
testpaths = ["tests"]
|
|
59
|
+
python_files = ["test_*.py"]
|
|
60
|
+
python_classes = ["Test*"]
|
|
61
|
+
python_functions = ["test_*"]
|
|
62
|
+
# Ecosystem-standard autoload SIGSEGV defense (dash/playwright plugins are not used here).
|
|
63
|
+
addopts = ["-ra", "--strict-markers", "--strict-config", "-p", "no:dash", "-p", "no:playwright"]
|
|
64
|
+
|
|
65
|
+
[tool.ruff]
|
|
66
|
+
line-length = 512
|
|
67
|
+
target-version = "py312"
|
|
68
|
+
|
|
69
|
+
[tool.ruff.lint]
|
|
70
|
+
select = ["E", "F", "W", "B", "I", "N"]
|
|
71
|
+
|
|
72
|
+
[tool.ruff.lint.per-file-ignores]
|
|
73
|
+
# X is the standard ML design-matrix argument name (juniper-model-core's array contract);
|
|
74
|
+
# pep8-naming N803's lowercase rule is a false positive on it (mirrors the model package).
|
|
75
|
+
"juniper_recurrence_client/client.py" = ["N803"]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""API-key resolution tests (explicit arg, env var, ``_FILE`` Docker-secret, precedence)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from juniper_recurrence_client import JuniperRecurrenceClient
|
|
8
|
+
from juniper_recurrence_client.constants import API_KEY_ENV_VAR, API_KEY_FILE_ENV_VAR, API_KEY_HEADER_NAME
|
|
9
|
+
|
|
10
|
+
BASE_URL = "http://x:8211"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_explicit_api_key_sets_header() -> None:
    """An ``api_key`` passed to the constructor ends up in the session's API-key header."""
    headers = JuniperRecurrenceClient(base_url=BASE_URL, api_key="secret-key").session.headers
    assert headers[API_KEY_HEADER_NAME] == "secret-key"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_no_api_key_leaves_header_unset(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no explicit key and no key env vars, the API-key header is not set at all."""
    # Clear both env vars so a key exported in the developer's/CI environment cannot
    # leak into the client and make this test fail spuriously.
    monkeypatch.delenv(API_KEY_ENV_VAR, raising=False)
    monkeypatch.delenv(API_KEY_FILE_ENV_VAR, raising=False)
    client = JuniperRecurrenceClient(base_url=BASE_URL)
    assert API_KEY_HEADER_NAME not in client.session.headers
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_api_key_from_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """The plain env var supplies the key when no explicit key and no ``_FILE`` variable is present."""
    # The ``_FILE`` form takes precedence over the plain variable, so an ambient value
    # would change the outcome; remove it to keep the test hermetic.
    monkeypatch.delenv(API_KEY_FILE_ENV_VAR, raising=False)
    monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
    client = JuniperRecurrenceClient(base_url=BASE_URL)
    assert client.session.headers[API_KEY_HEADER_NAME] == "env-key"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_file_indirection_beats_plain_env(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
    """When both are set, the stripped contents of the ``_FILE`` path win over the plain env var."""
    secret_path = tmp_path / "key.txt"
    secret_path.write_text(" file-key\n", encoding="utf-8")
    monkeypatch.setenv(API_KEY_FILE_ENV_VAR, str(secret_path))
    monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
    headers = JuniperRecurrenceClient(base_url=BASE_URL).session.headers
    assert headers[API_KEY_HEADER_NAME] == "file-key"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_explicit_api_key_beats_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """An explicit ``api_key`` argument overrides the value from the env var."""
    monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
    headers = JuniperRecurrenceClient(base_url=BASE_URL, api_key="explicit").session.headers
    assert headers[API_KEY_HEADER_NAME] == "explicit"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_empty_file_falls_back_to_plain_env(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
    """A ``_FILE`` path holding only whitespace is ignored and the plain env var is used."""
    blank_path = tmp_path / "empty.txt"
    blank_path.write_text(" \n", encoding="utf-8")
    monkeypatch.setenv(API_KEY_FILE_ENV_VAR, str(blank_path))
    monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
    headers = JuniperRecurrenceClient(base_url=BASE_URL).session.headers
    assert headers[API_KEY_HEADER_NAME] == "env-key"
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Unit tests for ``JuniperRecurrenceClient`` (HTTP mocked with ``responses``)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
import responses
|
|
9
|
+
|
|
10
|
+
from juniper_recurrence_client import JuniperRecurrenceClient
|
|
11
|
+
|
|
12
|
+
BASE_URL = "http://recurrence.test:8211"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _client(**kwargs: object) -> JuniperRecurrenceClient:
    """Build a client against ``BASE_URL`` with retries disabled unless the caller overrides them."""
    options = {"retries": 0, **kwargs}
    return JuniperRecurrenceClient(base_url=BASE_URL, **options)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.mark.parametrize(
    "raw,expected",
    [
        ("http://recurrence.test:8211", "http://recurrence.test:8211"),
        ("http://recurrence.test:8211/", "http://recurrence.test:8211"),
        ("http://recurrence.test:8211/v1", "http://recurrence.test:8211"),
        ("recurrence.test:8211", "http://recurrence.test:8211"),
    ],
)
def test_normalize_url(raw: str, expected: str) -> None:
    """The constructor normalizes each base-URL variant to the same canonical form."""
    normalized = JuniperRecurrenceClient(base_url=raw).base_url
    assert normalized == expected
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@responses.activate
def test_train_posts_dataset_ref_and_hyperparams() -> None:
    """``train`` posts the dataset reference plus hyperparameters and returns the parsed reply."""
    reply = {"final_metrics": {"r2": 0.9}, "n_epochs": 1, "stopped_reason": None, "dataset": {}}
    responses.add(responses.POST, f"{BASE_URL}/v1/train", json=reply, status=200)
    result = _client().train(name="equities", d=16, theta=2.0, ridge=0.1)
    assert result["final_metrics"]["r2"] == 0.9
    body = json.loads(responses.calls[0].request.body)
    assert body["dataset"] == {"split": "train", "name": "equities"}
    assert body["d"] == 16
    assert body["theta"] == 2.0
    assert body["ridge"] == 0.1
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@responses.activate
def test_training_status() -> None:
    """``training_status`` returns the parsed body of ``GET /v1/training/status``."""
    reply = {"state": "trained", "final_metrics": {"r2": 0.9}, "stopped_reason": None, "events": []}
    responses.add(responses.GET, f"{BASE_URL}/v1/training/status", json=reply, status=200)
    status = _client().training_status()
    assert status["state"] == "trained"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@responses.activate
def test_predict_inline_x() -> None:
    """``predict`` with inline ``X``/``dt`` sends them verbatim and omits any dataset reference."""
    reply = {"predictions": [[1.0]], "shape": [1, 1]}
    responses.add(responses.POST, f"{BASE_URL}/v1/predict", json=reply, status=200)
    result = _client().predict(X=[[[1.0, 2.0]]], dt=[[0.0]])
    assert result["shape"] == [1, 1]
    body = json.loads(responses.calls[0].request.body)
    assert body["X"] == [[[1.0, 2.0]]]
    assert body["dt"] == [[0.0]]
    assert "dataset" not in body
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@responses.activate
def test_predict_by_dataset_ref() -> None:
    """``predict`` by ``dataset_id`` sends a dataset reference and no inline ``X``."""
    reply = {"predictions": [], "shape": [0, 1]}
    responses.add(responses.POST, f"{BASE_URL}/v1/predict", json=reply, status=200)
    _client().predict(dataset_id="ds-1")
    body = json.loads(responses.calls[0].request.body)
    assert body["dataset"] == {"split": "train", "dataset_id": "ds-1"}
    assert "X" not in body
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@responses.activate
def test_crossval_passes_config() -> None:
    """``crossval`` forwards the fold configuration in the request body."""
    reply = {"task_type": "regression", "n_folds": 3, "folds": [], "eval_aggregate": {}, "eval_std": {}, "dataset": {}}
    responses.add(responses.POST, f"{BASE_URL}/v1/crossval", json=reply, status=200)
    result = _client().crossval(name="equities", n_folds=3, scheme="rolling", embargo=2, min_train=10)
    assert result["n_folds"] == 3
    body = json.loads(responses.calls[0].request.body)
    assert body["n_folds"] == 3
    assert body["scheme"] == "rolling"
    assert body["embargo"] == 2
    assert body["min_train"] == 10
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@responses.activate
def test_crossval_status_model_dataset() -> None:
    """The three inspection GETs each return their endpoint's parsed JSON body."""
    status_reply = {"state": "done", "result": None}
    model_reply = {"topology": {"model_type": "lmu"}, "metrics": {}}
    dataset_reply = {"dataset_id": "ds-1", "split": "train"}
    responses.add(responses.GET, f"{BASE_URL}/v1/crossval/status", json=status_reply, status=200)
    responses.add(responses.GET, f"{BASE_URL}/v1/model", json=model_reply, status=200)
    responses.add(responses.GET, f"{BASE_URL}/v1/dataset", json=dataset_reply, status=200)
    client = _client()
    assert client.crossval_status()["state"] == "done"
    assert client.get_model()["topology"]["model_type"] == "lmu"
    assert client.get_dataset()["dataset_id"] == "ds-1"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@responses.activate
def test_health_and_is_ready() -> None:
    """``health_check`` returns the liveness body and ``is_ready`` is true for status ``ready``."""
    responses.add(responses.GET, f"{BASE_URL}/v1/health", json={"status": "ok"}, status=200)
    responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "ready"}, status=200)
    client = _client()
    liveness = client.health_check()
    assert liveness["status"] == "ok"
    readiness = client.is_ready()
    assert readiness is True
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@responses.activate
def test_is_ready_false_when_not_ready() -> None:
    """``is_ready`` is false when the readiness endpoint reports a status other than ``ready``."""
    responses.add(responses.GET, f"{BASE_URL}/v1/health/ready", json={"status": "starting"}, status=200)
    readiness = _client().is_ready()
    assert readiness is False
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@responses.activate
def test_on_request_hook_fires_once() -> None:
    """The ``on_request`` hook is invoked exactly once for one successful request."""
    recorded: list[tuple[str, object, object, object]] = []

    # Parameter names are kept as-is in case the client passes them by keyword.
    def record(method: str, url: str, status: object, duration_ms: float, error: object) -> None:
        recorded.append((method, url, status, error))

    responses.add(responses.GET, f"{BASE_URL}/v1/model", json={"topology": {}, "metrics": {}}, status=200)
    _client(on_request=record).get_model()
    assert len(recorded) == 1
    first = recorded[0]
    assert first[0] == "GET"
    assert first[2] == 200
    assert first[3] is None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_context_manager_closes() -> None:
    """Leaving the ``with`` block closes the underlying session exactly once.

    The session's ``close`` is wrapped with a recorder so the test verifies the close
    actually happens on exit, not just that a session exists inside the block.
    """
    client = _client()
    close_calls: list[bool] = []
    real_close = client.session.close

    def tracking_close() -> None:
        close_calls.append(True)
        real_close()

    client.session.close = tracking_close  # type: ignore[method-assign]
    with client as entered:
        assert entered is client
        assert entered.session is not None
        # Nothing may be closed while the context is still active.
        assert close_calls == []
    assert close_calls == [True]
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""HTTP/transport error-mapping tests (status code -> typed exception)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import responses
|
|
7
|
+
|
|
8
|
+
from juniper_recurrence_client import (
|
|
9
|
+
JuniperRecurrenceClient,
|
|
10
|
+
JuniperRecurrenceClientError,
|
|
11
|
+
JuniperRecurrenceConflictError,
|
|
12
|
+
JuniperRecurrenceConnectionError,
|
|
13
|
+
JuniperRecurrenceNotFoundError,
|
|
14
|
+
JuniperRecurrenceValidationError,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
BASE_URL = "http://recurrence.test:8211"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _client() -> JuniperRecurrenceClient:
    """Build a client against ``BASE_URL`` with automatic retries turned off."""
    client = JuniperRecurrenceClient(base_url=BASE_URL, retries=0)
    return client
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@responses.activate
def test_404_maps_to_not_found() -> None:
    """A 404 response is raised as ``JuniperRecurrenceNotFoundError`` carrying the server detail."""
    error_body = {"detail": "no model"}
    responses.add(responses.GET, f"{BASE_URL}/v1/model", json=error_body, status=404)
    with pytest.raises(JuniperRecurrenceNotFoundError, match="no model"):
        _client().get_model()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@responses.activate
def test_409_maps_to_conflict() -> None:
    """A 409 response is raised as ``JuniperRecurrenceConflictError`` carrying the server detail."""
    error_body = {"detail": "training already in progress"}
    responses.add(responses.POST, f"{BASE_URL}/v1/train", json=error_body, status=409)
    with pytest.raises(JuniperRecurrenceConflictError, match="in progress"):
        _client().train(name="equities")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@responses.activate
def test_422_maps_to_validation() -> None:
    """A 422 response is raised as ``JuniperRecurrenceValidationError`` carrying the server detail."""
    error_body = {"detail": "n_folds must be >= 2"}
    responses.add(responses.POST, f"{BASE_URL}/v1/crossval", json=error_body, status=422)
    with pytest.raises(JuniperRecurrenceValidationError, match="n_folds"):
        _client().crossval(name="equities", n_folds=1)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@responses.activate
def test_500_maps_to_client_error() -> None:
    """A 500 response is raised as the base ``JuniperRecurrenceClientError`` mentioning the status."""
    error_body = {"detail": "boom"}
    responses.add(responses.GET, f"{BASE_URL}/v1/dataset", json=error_body, status=500)
    with pytest.raises(JuniperRecurrenceClientError, match="500"):
        _client().get_dataset()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@responses.activate
def test_connection_error_maps() -> None:
    """A transport-level connection failure is raised as ``JuniperRecurrenceConnectionError``."""
    # Nothing is registered for this URL, so ``responses`` raises a ConnectionError.
    expected = JuniperRecurrenceConnectionError
    with pytest.raises(expected):
        _client().get_model()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@responses.activate
def test_malformed_json_raises_client_error() -> None:
    """A 200 response whose body is not valid JSON is raised as a client error."""
    endpoint = f"{BASE_URL}/v1/dataset"
    responses.add(responses.GET, endpoint, body="not-json", status=200, content_type="application/json")
    with pytest.raises(JuniperRecurrenceClientError, match="Malformed JSON"):
        _client().get_dataset()
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Extra coverage: timeout / generic-request-error mapping, full predict & crossval bodies,
|
|
2
|
+
readiness polling (success + timeout), and the api-key file-read-error fallback."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import requests
|
|
10
|
+
import responses
|
|
11
|
+
|
|
12
|
+
from juniper_recurrence_client import (
|
|
13
|
+
JuniperRecurrenceClient,
|
|
14
|
+
JuniperRecurrenceClientError,
|
|
15
|
+
JuniperRecurrenceTimeoutError,
|
|
16
|
+
)
|
|
17
|
+
from juniper_recurrence_client.constants import API_KEY_ENV_VAR, API_KEY_FILE_ENV_VAR, API_KEY_HEADER_NAME
|
|
18
|
+
|
|
19
|
+
BASE_URL = "http://recurrence.test:8211"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _client(**kwargs: object) -> JuniperRecurrenceClient:
    """Build a client against ``BASE_URL`` with retries disabled unless the caller overrides them."""
    options = {"retries": 0, **kwargs}
    return JuniperRecurrenceClient(base_url=BASE_URL, **options)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@responses.activate
def test_timeout_maps_to_timeout_error() -> None:
    """A ``requests`` timeout is raised as ``JuniperRecurrenceTimeoutError``."""
    failure = requests.exceptions.Timeout("slow")
    responses.add(responses.GET, f"{BASE_URL}/v1/model", body=failure)
    with pytest.raises(JuniperRecurrenceTimeoutError):
        _client().get_model()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@responses.activate
def test_generic_request_exception_maps_to_client_error() -> None:
    """Any other ``requests`` exception is raised as the base ``JuniperRecurrenceClientError``."""
    failure = requests.exceptions.RequestException("boom")
    responses.add(responses.GET, f"{BASE_URL}/v1/dataset", body=failure)
    with pytest.raises(JuniperRecurrenceClientError):
        _client().get_dataset()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@responses.activate
def test_predict_full_aux_body() -> None:
    """``predict`` forwards the optional ``target_dt`` and ``seq_lengths`` fields in the body."""
    reply = {"predictions": [], "shape": [0, 1]}
    responses.add(responses.POST, f"{BASE_URL}/v1/predict", json=reply, status=200)
    _client().predict(X=[[[1.0]]], dt=[[0.0]], target_dt=[1.0], seq_lengths=[1])
    body = json.loads(responses.calls[0].request.body)
    assert body["target_dt"] == [1.0]
    assert body["seq_lengths"] == [1]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@responses.activate
def test_crossval_passes_hyperparams() -> None:
    """``crossval`` forwards the generator reference and the model hyperparameters."""
    reply = {"task_type": "regression", "n_folds": 2, "folds": [], "eval_aggregate": {}, "eval_std": {}, "dataset": {}}
    responses.add(responses.POST, f"{BASE_URL}/v1/crossval", json=reply, status=200)
    _client().crossval(generator="equities_seq", n_folds=2, d=8, theta=1.5, ridge=0.2)
    body = json.loads(responses.calls[0].request.body)
    assert body["dataset"]["generator"] == "equities_seq"
    assert body["d"] == 8
    assert body["theta"] == 1.5
    assert body["ridge"] == 0.2
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@responses.activate
def test_wait_for_ready_polls_until_ready() -> None:
    """``wait_for_ready`` keeps polling past a not-ready reply and returns true once ready."""
    ready_url = f"{BASE_URL}/v1/health/ready"
    responses.add(responses.GET, ready_url, json={"status": "starting"}, status=200)
    responses.add(responses.GET, ready_url, json={"status": "ready"}, status=200)
    became_ready = _client().wait_for_ready(timeout=2.0, poll_interval=0.01)
    assert became_ready is True
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@responses.activate
def test_wait_for_ready_times_out() -> None:
    """``wait_for_ready`` returns false when the service never reports ready before the timeout."""
    ready_url = f"{BASE_URL}/v1/health/ready"
    responses.add(responses.GET, ready_url, json={"status": "starting"}, status=200)
    became_ready = _client().wait_for_ready(timeout=0.03, poll_interval=0.01)
    assert became_ready is False
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_api_key_file_read_error_falls_back_to_env(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
    """An unreadable ``_FILE`` path does not break key resolution; the plain env var is used."""
    # _FILE names a path that does not exist -> OSError on read -> plain env var is used.
    missing_path = tmp_path / "does-not-exist"
    monkeypatch.setenv(API_KEY_FILE_ENV_VAR, str(missing_path))
    monkeypatch.setenv(API_KEY_ENV_VAR, "env-key")
    headers = JuniperRecurrenceClient(base_url=BASE_URL).session.headers
    assert headers[API_KEY_HEADER_NAME] == "env-key"
|