sheaf-serve 0.1.0 (sheaf_serve-0.1.0.tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sheaf_serve-0.1.0/.github/workflows/ci.yml +30 -0
- sheaf_serve-0.1.0/.github/workflows/publish.yml +18 -0
- sheaf_serve-0.1.0/.gitignore +13 -0
- sheaf_serve-0.1.0/PKG-INFO +52 -0
- sheaf_serve-0.1.0/README.md +7 -0
- sheaf_serve-0.1.0/pyproject.toml +75 -0
- sheaf_serve-0.1.0/src/sheaf/__init__.py +7 -0
- sheaf_serve-0.1.0/src/sheaf/api/__init__.py +10 -0
- sheaf_serve-0.1.0/src/sheaf/api/base.py +35 -0
- sheaf_serve-0.1.0/src/sheaf/api/time_series.py +80 -0
- sheaf_serve-0.1.0/src/sheaf/backends/__init__.py +4 -0
- sheaf_serve-0.1.0/src/sheaf/backends/base.py +41 -0
- sheaf_serve-0.1.0/src/sheaf/backends/chronos.py +143 -0
- sheaf_serve-0.1.0/src/sheaf/cache/__init__.py +0 -0
- sheaf_serve-0.1.0/src/sheaf/integrations/__init__.py +0 -0
- sheaf_serve-0.1.0/src/sheaf/registry.py +24 -0
- sheaf_serve-0.1.0/src/sheaf/scheduling/__init__.py +3 -0
- sheaf_serve-0.1.0/src/sheaf/scheduling/batch.py +17 -0
- sheaf_serve-0.1.0/src/sheaf/server.py +71 -0
- sheaf_serve-0.1.0/src/sheaf/spec.py +35 -0
- sheaf_serve-0.1.0/tests/__init__.py +0 -0
- sheaf_serve-0.1.0/tests/test_api.py +50 -0
- sheaf_serve-0.1.0/uv.lock +5036 -0

sheaf_serve-0.1.0/.github/workflows/ci.yml
@@ -0,0 +1,30 @@
+name: CI
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: astral-sh/setup-uv@v5
+      - run: uv sync --extra dev
+      - run: uv run ruff check src/ tests/
+      - run: uv run ruff format --check src/ tests/
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+    steps:
+      - uses: actions/checkout@v4
+      - uses: astral-sh/setup-uv@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - run: uv sync --extra dev
+      - run: uv run pytest tests/ -v

sheaf_serve-0.1.0/.github/workflows/publish.yml
@@ -0,0 +1,18 @@
+name: Publish to PyPI
+
+on:
+  push:
+    tags:
+      - "v*"
+
+jobs:
+  publish:
+    runs-on: ubuntu-latest
+    environment: pypi
+    permissions:
+      id-token: write  # required for trusted publishing
+    steps:
+      - uses: actions/checkout@v4
+      - uses: astral-sh/setup-uv@v5
+      - run: uv build
+      - uses: pypa/gh-action-pypi-publish@release/v1

sheaf_serve-0.1.0/PKG-INFO
@@ -0,0 +1,52 @@
+Metadata-Version: 2.4
+Name: sheaf-serve
+Version: 0.1.0
+Summary: Unified serving layer for non-text foundation models
+Author-email: Alex Korbonits <alexkorbonits@gmail.com>
+License: Apache-2.0
+Keywords: foundation-models,inference,mlops,ray,serving
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.10
+Requires-Dist: numpy>=1.24.0
+Requires-Dist: pydantic>=2.0.0
+Requires-Dist: ray[serve]>=2.10.0
+Provides-Extra: all
+Requires-Dist: chronos-forecasting>=1.0.0; extra == 'all'
+Requires-Dist: fair-esm>=2.0.0; extra == 'all'
+Requires-Dist: feast>=0.40.0; extra == 'all'
+Requires-Dist: pymilvus>=2.4.0; extra == 'all'
+Requires-Dist: tabpfn>=2.0.0; extra == 'all'
+Requires-Dist: timesfm>=1.0.0; extra == 'all'
+Provides-Extra: dev
+Requires-Dist: mypy>=1.9.0; extra == 'dev'
+Requires-Dist: pre-commit>=3.7.0; extra == 'dev'
+Requires-Dist: pytest-asyncio>=0.23.0; extra == 'dev'
+Requires-Dist: pytest>=8.0.0; extra == 'dev'
+Requires-Dist: ruff>=0.4.0; extra == 'dev'
+Provides-Extra: feast
+Requires-Dist: feast>=0.40.0; extra == 'feast'
+Provides-Extra: milvus
+Requires-Dist: pymilvus>=2.4.0; extra == 'milvus'
+Provides-Extra: molecular
+Requires-Dist: fair-esm>=2.0.0; extra == 'molecular'
+Provides-Extra: tabular
+Requires-Dist: tabpfn>=2.0.0; extra == 'tabular'
+Provides-Extra: time-series
+Requires-Dist: chronos-forecasting>=1.0.0; extra == 'time-series'
+Requires-Dist: timesfm>=1.0.0; extra == 'time-series'
+Description-Content-Type: text/markdown
+
+# Sheaf
+
+Unified serving layer for non-text foundation models.
+
+```bash
+pip install sheaf-serve
+```

sheaf_serve-0.1.0/pyproject.toml
@@ -0,0 +1,75 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "sheaf-serve"
+version = "0.1.0"
+description = "Unified serving layer for non-text foundation models"
+readme = "README.md"
+requires-python = ">=3.10"
+license = { text = "Apache-2.0" }
+authors = [{ name = "Alex Korbonits", email = "alexkorbonits@gmail.com" }]
+keywords = ["serving", "inference", "foundation-models", "ray", "mlops"]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: Apache Software License",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+dependencies = [
+    "ray[serve]>=2.10.0",
+    "pydantic>=2.0.0",
+    "numpy>=1.24.0",
+]
+
+[project.optional-dependencies]
+time-series = [
+    "chronos-forecasting>=1.0.0",
+    "timesfm>=1.0.0",
+]
+tabular = [
+    "tabpfn>=2.0.0",
+]
+molecular = [
+    "fair-esm>=2.0.0",
+]
+feast = [
+    "feast>=0.40.0",
+]
+milvus = [
+    "pymilvus>=2.4.0",
+]
+dev = [
+    "pytest>=8.0.0",
+    "pytest-asyncio>=0.23.0",
+    "ruff>=0.4.0",
+    "mypy>=1.9.0",
+    "pre-commit>=3.7.0",
+]
+all = [
+    "sheaf-serve[time-series,tabular,molecular,feast,milvus]",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/sheaf"]
+
+[tool.ruff]
+line-length = 88
+src = ["src"]
+
+[tool.ruff.lint]
+select = ["E", "F", "I", "UP"]
+
+[tool.mypy]
+python_version = "3.10"
+strict = true
+files = ["src/sheaf"]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["tests"]
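
The optional-dependency groups above map one-to-one onto the PKG-INFO extras earlier in this diff: for example, `pip install "sheaf-serve[time-series]"` pulls in chronos-forecasting and timesfm, while `pip install "sheaf-serve[all]"` expands, via the self-referential `sheaf-serve[...]` requirement, to every group except `dev`.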

sheaf_serve-0.1.0/src/sheaf/api/base.py
@@ -0,0 +1,35 @@
+"""Base request/response contracts and model type registry."""
+
+from enum import Enum
+from typing import Any
+from uuid import UUID, uuid4
+
+from pydantic import BaseModel, Field
+
+
+class ModelType(str, Enum):
+    TIME_SERIES = "time_series"
+    TABULAR = "tabular"
+    MOLECULAR = "molecular"
+    GEOSPATIAL = "geospatial"
+    DIFFUSION = "diffusion"
+    NEURAL_OPERATOR = "neural_operator"
+    AUDIO = "audio"
+    EMBEDDING = "embedding"
+
+
+class BaseRequest(BaseModel):
+    request_id: UUID = Field(default_factory=uuid4)
+    model_type: ModelType
+    model_name: str
+
+    model_config = {"arbitrary_types_allowed": True}
+
+
+class BaseResponse(BaseModel):
+    request_id: UUID
+    model_type: ModelType
+    model_name: str
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    model_config = {"arbitrary_types_allowed": True}
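
A quick sketch (not from the package) of what these base contracts give every request for free: `request_id` is auto-generated, while `model_type` and `model_name` are required; the model name used here is illustrative.

```python
from sheaf.api.base import BaseRequest, ModelType

# "tabpfn-v2" is an illustrative model name, not one shipped with the package.
req = BaseRequest(model_type=ModelType.TABULAR, model_name="tabpfn-v2")
print(req.request_id)  # auto-filled UUID from the default_factory
print(req.model_type)  # ModelType.TABULAR ("tabular" on the wire)
```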

sheaf_serve-0.1.0/src/sheaf/api/time_series.py
@@ -0,0 +1,80 @@
+"""API contract for time series foundation models (Chronos2, TimesFM, etc.)."""
+
+from enum import Enum
+from typing import Annotated, Any
+from uuid import UUID
+
+import numpy as np
+from pydantic import Field, model_validator
+
+from sheaf.api.base import BaseRequest, BaseResponse, ModelType
+
+
+class Frequency(str, Enum):
+    MINUTELY = "1min"
+    FIVE_MINUTELY = "5min"
+    FIFTEEN_MINUTELY = "15min"
+    HOURLY = "1h"
+    DAILY = "1d"
+    WEEKLY = "1W"
+    MONTHLY = "1M"
+
+
+class OutputMode(str, Enum):
+    MEAN = "mean"
+    QUANTILES = "quantiles"
+    SAMPLES = "samples"
+
+
+class TimeSeriesRequest(BaseRequest):
+    """Request contract for time series foundation models.
+
+    Either `history` (raw values) or `feature_ref` (Feast entity reference)
+    must be provided, not both.
+    """
+
+    model_type: ModelType = ModelType.TIME_SERIES
+
+    # Input: raw history or feature store reference
+    history: Annotated[list[float] | None, Field(default=None)]
+    feature_ref: Annotated[dict[str, str] | None, Field(default=None)]
+    # e.g. {"feature_view": "asset_prices", "entity_id": "AAPL"}
+
+    horizon: Annotated[int, Field(gt=0)]
+    frequency: Frequency
+    output_mode: OutputMode = OutputMode.MEAN
+    quantile_levels: list[float] = Field(
+        default=[0.1, 0.5, 0.9],
+        description="Quantile levels — only used when output_mode=quantiles",
+    )
+    num_samples: int = Field(
+        default=20,
+        description="Number of samples — only used when output_mode=samples",
+    )
+
+    @model_validator(mode="after")
+    def validate_input_source(self) -> "TimeSeriesRequest":
+        if self.history is None and self.feature_ref is None:
+            raise ValueError("One of `history` or `feature_ref` must be provided.")
+        if self.history is not None and self.feature_ref is not None:
+            raise ValueError("Provide either `history` or `feature_ref`, not both.")
+        return self
+
+
+class TimeSeriesResponse(BaseResponse):
+    """Response contract for time series foundation models."""
+
+    model_type: ModelType = ModelType.TIME_SERIES
+
+    # Mean forecast — always populated
+    mean: list[float]
+
+    # Populated when output_mode=quantiles
+    quantiles: dict[str, list[float]] | None = None
+    # e.g. {"0.1": [...], "0.5": [...], "0.9": [...]}
+
+    # Populated when output_mode=samples
+    samples: list[list[float]] | None = None
+
+    horizon: int
+    frequency: str
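
To make the either/or contract concrete, here is a small sketch (not from the package) that exercises `validate_input_source`; the model name and values are illustrative.

```python
from sheaf.api.time_series import Frequency, OutputMode, TimeSeriesRequest

# Valid: raw history only.
req = TimeSeriesRequest(
    model_name="chronos2-small",
    history=[101.2, 102.5, 101.9, 103.4],
    horizon=12,
    frequency=Frequency.DAILY,
    output_mode=OutputMode.QUANTILES,
)
assert req.quantile_levels == [0.1, 0.5, 0.9]  # default levels

# Invalid: both history and feature_ref. The after-validator raises a
# ValueError, which pydantic surfaces as a ValidationError.
try:
    TimeSeriesRequest(
        model_name="chronos2-small",
        history=[1.0, 2.0],
        feature_ref={"feature_view": "asset_prices", "entity_id": "AAPL"},
        horizon=12,
        frequency=Frequency.DAILY,
    )
except ValueError as exc:
    print(exc)
```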

sheaf_serve-0.1.0/src/sheaf/backends/base.py
@@ -0,0 +1,41 @@
+"""Abstract base class for model backends."""
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from sheaf.api.base import BaseRequest, BaseResponse
+
+
+class ModelBackend(ABC):
+    """Pluggable model backend.
+
+    Implement this to add a new model to Sheaf. The backend owns:
+    - Model loading and initialization
+    - Preprocessing raw inputs
+    - Running inference
+    - Postprocessing outputs into a typed response
+    """
+
+    @abstractmethod
+    def load(self) -> None:
+        """Load model weights and initialize any required state."""
+        ...
+
+    @abstractmethod
+    def predict(self, request: BaseRequest) -> BaseResponse:
+        """Run inference for a single request."""
+        ...
+
+    def batch_predict(self, requests: list[BaseRequest]) -> list[BaseResponse]:
+        """Run inference over a batch.
+
+        Default implementation runs requests sequentially.
+        Override to implement model-type-aware batching.
+        """
+        return [self.predict(r) for r in requests]
+
+    @property
+    @abstractmethod
+    def model_type(self) -> str:
+        """The ModelType this backend serves."""
+        ...
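
As a sketch of the extension point, here is a hypothetical minimal backend against this ABC; `EchoBackend` is not part of the package.

```python
from sheaf.api.base import BaseRequest, BaseResponse, ModelType
from sheaf.backends.base import ModelBackend


class EchoBackend(ModelBackend):
    """Hypothetical backend: returns a bare response, no real model."""

    def load(self) -> None:
        # A real backend would load weights here.
        self._loaded = True

    def predict(self, request: BaseRequest) -> BaseResponse:
        return BaseResponse(
            request_id=request.request_id,
            model_type=request.model_type,
            model_name=request.model_name,
            metadata={"echo": True},
        )

    @property
    def model_type(self) -> str:
        return ModelType.EMBEDDING


backend = EchoBackend()
backend.load()
# batch_predict falls back to the sequential default from the ABC.
```

The inherited `batch_predict` just loops over `predict`, so a new backend only needs to override it when the model benefits from true batched inference, as the Chronos backend below does.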

sheaf_serve-0.1.0/src/sheaf/backends/chronos.py
@@ -0,0 +1,143 @@
+"""Chronos2 backend for time series forecasting.
+
+Requires: pip install sheaf-serve[time-series]
+Supports: amazon/chronos-t5-{tiny,mini,small,base,large}
+          amazon/chronos-bolt-{tiny,mini,small,base}
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+from sheaf.api.base import BaseRequest, BaseResponse, ModelType
+from sheaf.api.time_series import OutputMode, TimeSeriesRequest, TimeSeriesResponse
+from sheaf.backends.base import ModelBackend
+from sheaf.registry import register_backend
+
+if TYPE_CHECKING:
+    import torch
+
+
+@register_backend("chronos2")
+class Chronos2Backend(ModelBackend):
+    """ModelBackend implementation for Chronos / Chronos-Bolt models.
+
+    Args:
+        model_id: HuggingFace model ID, e.g. "amazon/chronos-bolt-small"
+        device_map: "cpu", "cuda", "mps", or "auto"
+        torch_dtype: "bfloat16", "float32", etc. Passed to from_pretrained.
+        num_samples: Default number of samples for probabilistic output.
+    """
+
+    def __init__(
+        self,
+        model_id: str = "amazon/chronos-bolt-small",
+        device_map: str = "cpu",
+        torch_dtype: str = "bfloat16",
+        num_samples: int = 20,
+    ) -> None:
+        self._model_id = model_id
+        self._device_map = device_map
+        self._torch_dtype = torch_dtype
+        self._default_num_samples = num_samples
+        self._pipeline: Any = None
+
+    @property
+    def model_type(self) -> str:
+        return ModelType.TIME_SERIES
+
+    def load(self) -> None:
+        try:
+            import torch
+            from chronos import BaseChronosPipeline
+        except ImportError as e:
+            raise ImportError(
+                "chronos-forecasting is required for the Chronos2 backend. "
+                "Install it with: pip install sheaf-serve[time-series]"
+            ) from e
+
+        dtype_map = {
+            "bfloat16": torch.bfloat16,
+            "float32": torch.float32,
+            "float16": torch.float16,
+        }
+        torch_dtype = dtype_map.get(self._torch_dtype, torch.bfloat16)
+
+        self._pipeline = BaseChronosPipeline.from_pretrained(
+            self._model_id,
+            device_map=self._device_map,
+            torch_dtype=torch_dtype,
+        )
+
+    def predict(self, request: BaseRequest) -> BaseResponse:
+        if not isinstance(request, TimeSeriesRequest):
+            raise TypeError(f"Expected TimeSeriesRequest, got {type(request)}")
+        return self._run([request])[0]
+
+    def batch_predict(self, requests: list[BaseRequest]) -> list[BaseResponse]:
+        ts_requests = [r for r in requests if isinstance(r, TimeSeriesRequest)]
+        if len(ts_requests) != len(requests):
+            raise TypeError("All requests must be TimeSeriesRequest for Chronos2Backend")
+        return self._run(ts_requests)
+
+    def _run(self, requests: list[TimeSeriesRequest]) -> list[TimeSeriesResponse]:
+        import torch
+
+        if self._pipeline is None:
+            raise RuntimeError("Backend not loaded. Call load() first.")
+
+        # Build context tensors — pad to same length within batch
+        histories = [r.history or [] for r in requests]
+        max_len = max(len(h) for h in histories)
+        padded = [
+            [float("nan")] * (max_len - len(h)) + h
+            for h in histories
+        ]
+        context = torch.tensor(padded, dtype=torch.float32)
+
+        # Bucket by horizon: all requests in a batch must share a horizon
+        # (caller is responsible for bucketing — see BatchPolicy(bucket_by="horizon"))
+        horizon = requests[0].horizon
+        num_samples = requests[0].num_samples
+
+        # predict returns [batch, num_samples, horizon]
+        forecast = self._pipeline.predict(
+            context=context,
+            prediction_length=horizon,
+            num_samples=num_samples,
+        )
+        forecast_np: np.ndarray = forecast.numpy()
+
+        responses = []
+        for i, req in enumerate(requests):
+            samples = forecast_np[i]  # [num_samples, horizon]
+            responses.append(self._build_response(req, samples))
+        return responses
+
+    def _build_response(
+        self, req: TimeSeriesRequest, samples: np.ndarray
+    ) -> TimeSeriesResponse:
+        mean = samples.mean(axis=0).tolist()
+
+        quantiles = None
+        if req.output_mode == OutputMode.QUANTILES:
+            quantiles = {
+                str(q): np.quantile(samples, q, axis=0).tolist()
+                for q in req.quantile_levels
+            }
+
+        raw_samples = None
+        if req.output_mode == OutputMode.SAMPLES:
+            raw_samples = samples.tolist()
+
+        return TimeSeriesResponse(
+            request_id=req.request_id,
+            model_name=req.model_name,
+            horizon=req.horizon,
+            frequency=req.frequency.value,
+            mean=mean,
+            quantiles=quantiles,
+            samples=raw_samples,
+        )
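
A usage sketch for this backend, assuming the `time-series` extra is installed and the HuggingFace weights can be downloaded; values are illustrative.

```python
from sheaf.api.time_series import Frequency, OutputMode, TimeSeriesRequest
from sheaf.backends.chronos import Chronos2Backend

backend = Chronos2Backend(model_id="amazon/chronos-bolt-small", device_map="cpu")
backend.load()  # must run before predict(), otherwise _run() raises RuntimeError

request = TimeSeriesRequest(
    model_name="chronos2-small",
    history=[101.2, 102.5, 101.9, 103.4, 104.0, 103.1],
    horizon=8,
    frequency=Frequency.DAILY,
    output_mode=OutputMode.QUANTILES,
)
response = backend.predict(request)
print(len(response.mean))  # == 8, one point forecast per horizon step
print(response.quantiles)  # {"0.1": [...], "0.5": [...], "0.9": [...]}
```

Note the batching contract: `_run` reads `horizon` and `num_samples` from the first request only, so `batch_predict` expects a pre-bucketed batch, per the `BatchPolicy(bucket_by="horizon")` comment above.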

sheaf_serve-0.1.0/src/sheaf/cache/__init__.py (file without changes)
sheaf_serve-0.1.0/src/sheaf/integrations/__init__.py (file without changes)

sheaf_serve-0.1.0/src/sheaf/registry.py
@@ -0,0 +1,24 @@
+"""Backend registry — separated to avoid circular imports."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sheaf.backends.base import ModelBackend
+
+_BACKEND_REGISTRY: dict[str, type["ModelBackend"]] = {}
+
+
+def register_backend(name: str):
+    """Decorator to register a ModelBackend implementation by name.
+
+    Example:
+        @register_backend("chronos2")
+        class Chronos2Backend(ModelBackend):
+            ...
+    """
+    def decorator(cls: type["ModelBackend"]) -> type["ModelBackend"]:
+        _BACKEND_REGISTRY[name] = cls
+        return cls
+    return decorator
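
Registration is an import-time side effect, so a backend's module must be imported before its name resolves. A small sketch of the lookup side, using the Chronos backend shown above (this is also how `server.py` below resolves a spec's backend name):

```python
import sheaf.backends.chronos  # importing the module runs @register_backend("chronos2")

from sheaf.registry import _BACKEND_REGISTRY

backend_cls = _BACKEND_REGISTRY["chronos2"]
print(backend_cls.__name__)  # Chronos2Backend
```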

sheaf_serve-0.1.0/src/sheaf/scheduling/batch.py
@@ -0,0 +1,17 @@
+"""Batching policies for model-type-aware request scheduling."""
+
+from pydantic import BaseModel, Field
+
+
+class BatchPolicy(BaseModel):
+    """Controls how requests are batched before hitting the model backend.
+
+    max_batch_size: hard cap on requests per batch
+    timeout_ms: max time to wait for a full batch before flushing
+    bucket_by: field name to bucket on before batching (e.g. "horizon" for
+        time series, so variable-length forecasts don't get mixed)
+    """
+
+    max_batch_size: int = Field(default=32, gt=0)
+    timeout_ms: int = Field(default=50, gt=0)
+    bucket_by: str | None = None
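
For instance (values illustrative), a policy matching the bucketing assumption in the Chronos backend above could look like:

```python
from sheaf.scheduling.batch import BatchPolicy

# Flush a batch at 64 requests or after 50 ms, whichever comes first,
# and group requests by their `horizon` field before batching.
policy = BatchPolicy(max_batch_size=64, bucket_by="horizon")
assert policy.timeout_ms == 50  # default
```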

sheaf_serve-0.1.0/src/sheaf/server.py
@@ -0,0 +1,71 @@
+"""ModelServer — the Ray Serve entry point for Sheaf."""
+
+from typing import Any
+
+import ray
+from ray import serve
+
+from sheaf.api.base import BaseRequest, BaseResponse
+from sheaf.backends.base import ModelBackend
+from sheaf.registry import _BACKEND_REGISTRY, register_backend  # noqa: F401
+from sheaf.spec import ModelSpec
+
+
+@serve.deployment
+class _SheafDeployment:
+    def __init__(self, spec: ModelSpec) -> None:
+        backend_cls = _BACKEND_REGISTRY.get(spec.backend)
+        if backend_cls is None:
+            raise ValueError(
+                f"Unknown backend '{spec.backend}'. "
+                f"Registered backends: {list(_BACKEND_REGISTRY)}"
+            )
+        self._backend: ModelBackend = backend_cls(**spec.backend_kwargs)
+        self._backend.load()
+        self._spec = spec
+
+    async def __call__(self, request: BaseRequest) -> BaseResponse:
+        return self._backend.predict(request)
+
+
+class ModelServer:
+    """Top-level serving orchestrator.
+
+    Example:
+        server = ModelServer(
+            models=[chronos_spec, tabpfn_spec],
+        )
+        server.run()
+    """
+
+    def __init__(
+        self,
+        models: list[ModelSpec],
+        host: str = "0.0.0.0",
+        port: int = 8000,
+    ) -> None:
+        self._models = models
+        self._host = host
+        self._port = port
+        self._deployments: dict[str, Any] = {}
+
+    def run(self) -> None:
+        if not ray.is_initialized():
+            ray.init()
+
+        serve.start(http_options={"host": self._host, "port": self._port})
+
+        for spec in self._models:
+            deployment = _SheafDeployment.options(
+                name=spec.name,
+                num_replicas=spec.resources.replicas,
+                ray_actor_options={
+                    "num_cpus": spec.resources.num_cpus,
+                    "num_gpus": spec.resources.num_gpus,
+                },
+            ).bind(spec)
+            handle = serve.run(deployment, name=spec.name, route_prefix=f"/{spec.name}")
+            self._deployments[spec.name] = handle
+
+    def shutdown(self) -> None:
+        serve.shutdown()

sheaf_serve-0.1.0/src/sheaf/spec.py
@@ -0,0 +1,35 @@
+"""ModelSpec — declares what Sheaf should serve."""
+
+from pydantic import BaseModel, Field
+
+from sheaf.api.base import ModelType
+from sheaf.scheduling.batch import BatchPolicy
+
+
+class ResourceConfig(BaseModel):
+    num_cpus: float = 1.0
+    num_gpus: float = 0.0
+    memory_gb: float | None = None
+    replicas: int = 1
+
+
+class ModelSpec(BaseModel):
+    """Declares a model to be served by Sheaf.
+
+    Example:
+        spec = ModelSpec(
+            name="chronos2-small",
+            model_type=ModelType.TIME_SERIES,
+            backend="chronos2",
+            backend_kwargs={"model_size": "small"},
+            resources=ResourceConfig(num_gpus=1, replicas=2),
+            batch_policy=BatchPolicy(max_batch_size=64, bucket_by="horizon"),
+        )
+    """
+
+    name: str
+    model_type: ModelType
+    backend: str
+    backend_kwargs: dict = Field(default_factory=dict)
+    resources: ResourceConfig = Field(default_factory=ResourceConfig)
+    batch_policy: BatchPolicy = Field(default_factory=BatchPolicy)
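
Putting the pieces together, a hedged end-to-end sketch (not from the package; note that `Chronos2Backend` takes `model_id`, whereas the docstring example above passes a `model_size` kwarg):

```python
import sheaf.backends.chronos  # noqa: F401  (import registers the "chronos2" backend)

from sheaf.api.base import ModelType
from sheaf.scheduling.batch import BatchPolicy
from sheaf.server import ModelServer
from sheaf.spec import ModelSpec, ResourceConfig

spec = ModelSpec(
    name="chronos2-small",
    model_type=ModelType.TIME_SERIES,
    backend="chronos2",
    backend_kwargs={"model_id": "amazon/chronos-bolt-small", "device_map": "cpu"},
    resources=ResourceConfig(num_cpus=2),
    batch_policy=BatchPolicy(max_batch_size=64, bucket_by="horizon"),
)

server = ModelServer(models=[spec], port=8000)
server.run()       # deploys the model at route prefix /chronos2-small
# ... handle traffic ...
server.shutdown()
```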