sheaf-serve 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ lint:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v4
14
+ - uses: astral-sh/setup-uv@v5
15
+ - run: uv sync --extra dev
16
+ - run: uv run ruff check src/ tests/
17
+ - run: uv run ruff format --check src/ tests/
18
+
19
+ test:
20
+ runs-on: ubuntu-latest
21
+ strategy:
22
+ matrix:
23
+ python-version: ["3.10", "3.11", "3.12"]
24
+ steps:
25
+ - uses: actions/checkout@v4
26
+ - uses: astral-sh/setup-uv@v5
27
+ with:
28
+ python-version: ${{ matrix.python-version }}
29
+ - run: uv sync --extra dev
30
+ - run: uv run pytest tests/ -v
@@ -0,0 +1,18 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ jobs:
9
+ publish:
10
+ runs-on: ubuntu-latest
11
+ environment: pypi
12
+ permissions:
13
+ id-token: write # required for trusted publishing
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - uses: astral-sh/setup-uv@v5
17
+ - run: uv build
18
+ - uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,13 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *.egg-info/
4
+ .eggs/
5
+ dist/
6
+ build/
7
+ .venv/
8
+ venv/
9
+ .mypy_cache/
10
+ .ruff_cache/
11
+ .pytest_cache/
12
+ *.so
13
+ .env
@@ -0,0 +1,52 @@
1
+ Metadata-Version: 2.4
2
+ Name: sheaf-serve
3
+ Version: 0.1.0
4
+ Summary: Unified serving layer for non-text foundation models
5
+ Author-email: Alex Korbonits <alexkorbonits@gmail.com>
6
+ License: Apache-2.0
7
+ Keywords: foundation-models,inference,mlops,ray,serving
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.10
17
+ Requires-Dist: numpy>=1.24.0
18
+ Requires-Dist: pydantic>=2.0.0
19
+ Requires-Dist: ray[serve]>=2.10.0
20
+ Provides-Extra: all
21
+ Requires-Dist: chronos-forecasting>=1.0.0; extra == 'all'
22
+ Requires-Dist: fair-esm>=2.0.0; extra == 'all'
23
+ Requires-Dist: feast>=0.40.0; extra == 'all'
24
+ Requires-Dist: pymilvus>=2.4.0; extra == 'all'
25
+ Requires-Dist: tabpfn>=2.0.0; extra == 'all'
26
+ Requires-Dist: timesfm>=1.0.0; extra == 'all'
27
+ Provides-Extra: dev
28
+ Requires-Dist: mypy>=1.9.0; extra == 'dev'
29
+ Requires-Dist: pre-commit>=3.7.0; extra == 'dev'
30
+ Requires-Dist: pytest-asyncio>=0.23.0; extra == 'dev'
31
+ Requires-Dist: pytest>=8.0.0; extra == 'dev'
32
+ Requires-Dist: ruff>=0.4.0; extra == 'dev'
33
+ Provides-Extra: feast
34
+ Requires-Dist: feast>=0.40.0; extra == 'feast'
35
+ Provides-Extra: milvus
36
+ Requires-Dist: pymilvus>=2.4.0; extra == 'milvus'
37
+ Provides-Extra: molecular
38
+ Requires-Dist: fair-esm>=2.0.0; extra == 'molecular'
39
+ Provides-Extra: tabular
40
+ Requires-Dist: tabpfn>=2.0.0; extra == 'tabular'
41
+ Provides-Extra: time-series
42
+ Requires-Dist: chronos-forecasting>=1.0.0; extra == 'time-series'
43
+ Requires-Dist: timesfm>=1.0.0; extra == 'time-series'
44
+ Description-Content-Type: text/markdown
45
+
46
+ # Sheaf
47
+
48
+ Unified serving layer for non-text foundation models.
49
+
50
+ ```bash
51
+ pip install sheaf-serve
52
+ ```
@@ -0,0 +1,7 @@
1
+ # Sheaf
2
+
3
+ Unified serving layer for non-text foundation models.
4
+
5
+ ```bash
6
+ pip install sheaf-serve
7
+ ```
@@ -0,0 +1,75 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "sheaf-serve"
7
+ version = "0.1.0"
8
+ description = "Unified serving layer for non-text foundation models"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "Apache-2.0" }
12
+ authors = [{ name = "Alex Korbonits", email = "alexkorbonits@gmail.com" }]
13
+ keywords = ["serving", "inference", "foundation-models", "ray", "mlops"]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Intended Audience :: Developers",
17
+ "Intended Audience :: Science/Research",
18
+ "License :: OSI Approved :: Apache Software License",
19
+ "Programming Language :: Python :: 3.10",
20
+ "Programming Language :: Python :: 3.11",
21
+ "Programming Language :: Python :: 3.12",
22
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
23
+ ]
24
+ dependencies = [
25
+ "ray[serve]>=2.10.0",
26
+ "pydantic>=2.0.0",
27
+ "numpy>=1.24.0",
28
+ ]
29
+
30
+ [project.optional-dependencies]
31
+ time-series = [
32
+ "chronos-forecasting>=1.0.0",
33
+ "timesfm>=1.0.0",
34
+ ]
35
+ tabular = [
36
+ "tabpfn>=2.0.0",
37
+ ]
38
+ molecular = [
39
+ "fair-esm>=2.0.0",
40
+ ]
41
+ feast = [
42
+ "feast>=0.40.0",
43
+ ]
44
+ milvus = [
45
+ "pymilvus>=2.4.0",
46
+ ]
47
+ dev = [
48
+ "pytest>=8.0.0",
49
+ "pytest-asyncio>=0.23.0",
50
+ "ruff>=0.4.0",
51
+ "mypy>=1.9.0",
52
+ "pre-commit>=3.7.0",
53
+ ]
54
+ all = [
55
+ "sheaf-serve[time-series,tabular,molecular,feast,milvus]",
56
+ ]
57
+
58
+ [tool.hatch.build.targets.wheel]
59
+ packages = ["src/sheaf"]
60
+
61
+ [tool.ruff]
62
+ line-length = 88
63
+ src = ["src"]
64
+
65
+ [tool.ruff.lint]
66
+ select = ["E", "F", "I", "UP"]
67
+
68
+ [tool.mypy]
69
+ python_version = "3.10"
70
+ strict = true
71
+ files = ["src/sheaf"]
72
+
73
+ [tool.pytest.ini_options]
74
+ asyncio_mode = "auto"
75
+ testpaths = ["tests"]
@@ -0,0 +1,7 @@
1
+ """Sheaf — unified serving layer for non-text foundation models."""
2
+
3
+ from sheaf.server import ModelServer
4
+ from sheaf.spec import ModelSpec
5
+
6
+ __version__ = "0.1.0"
7
+ __all__ = ["ModelServer", "ModelSpec"]
@@ -0,0 +1,10 @@
1
+ from sheaf.api.base import BaseRequest, BaseResponse, ModelType
2
+ from sheaf.api.time_series import TimeSeriesRequest, TimeSeriesResponse
3
+
4
+ __all__ = [
5
+ "BaseRequest",
6
+ "BaseResponse",
7
+ "ModelType",
8
+ "TimeSeriesRequest",
9
+ "TimeSeriesResponse",
10
+ ]
@@ -0,0 +1,35 @@
1
+ """Base request/response contracts and model type registry."""
2
+
3
+ from enum import Enum
4
+ from typing import Any
5
+ from uuid import UUID, uuid4
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class ModelType(str, Enum):
    """Closed set of model families Sheaf can serve.

    Values are the string identifiers carried on the wire in
    request/response payloads.
    """

    TIME_SERIES = "time_series"
    TABULAR = "tabular"
    MOLECULAR = "molecular"
    GEOSPATIAL = "geospatial"
    DIFFUSION = "diffusion"
    NEURAL_OPERATOR = "neural_operator"
    AUDIO = "audio"
    EMBEDDING = "embedding"
19
+
20
+
21
class BaseRequest(BaseModel):
    """Fields common to every Sheaf inference request.

    Attributes:
        request_id: Tracing ID; auto-generated via uuid4 when omitted.
        model_type: The model family the request targets.
        model_name: Name of the deployed model that should serve it.
    """

    request_id: UUID = Field(default_factory=uuid4)
    model_type: ModelType
    model_name: str

    # Subclasses may carry non-pydantic payload types (e.g. arrays).
    model_config = {"arbitrary_types_allowed": True}
27
+
28
+
29
class BaseResponse(BaseModel):
    """Fields common to every Sheaf inference response.

    Attributes:
        request_id: Echo of the originating request's ID.
        model_type: The model family that produced the response.
        model_name: Name of the model that produced the response.
        metadata: Free-form extras attached by the backend; empty by default.
    """

    request_id: UUID
    model_type: ModelType
    model_name: str
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Subclasses may carry non-pydantic payload types (e.g. arrays).
    model_config = {"arbitrary_types_allowed": True}
@@ -0,0 +1,80 @@
1
+ """API contract for time series foundation models (Chronos2, TimesFM, etc.)."""
2
+
3
+ from enum import Enum
4
+ from typing import Annotated, Any
5
+ from uuid import UUID
6
+
7
+ import numpy as np
8
+ from pydantic import Field, model_validator
9
+
10
+ from sheaf.api.base import BaseRequest, BaseResponse, ModelType
11
+
12
+
13
class Frequency(str, Enum):
    """Supported sampling frequencies, as short offset strings (e.g. "1h")."""

    MINUTELY = "1min"
    FIVE_MINUTELY = "5min"
    FIFTEEN_MINUTELY = "15min"
    HOURLY = "1h"
    DAILY = "1d"
    WEEKLY = "1W"
    MONTHLY = "1M"
21
+
22
+
23
class OutputMode(str, Enum):
    """Which forecast representation the caller wants back."""

    MEAN = "mean"
    QUANTILES = "quantiles"
    SAMPLES = "samples"
27
+
28
+
29
class TimeSeriesRequest(BaseRequest):
    """Request contract for time series foundation models.

    Either `history` (raw values) or `feature_ref` (Feast entity reference)
    must be provided, not both.

    Raises:
        pydantic.ValidationError: if neither or both input sources are set,
            if `horizon`/`num_samples` are not positive, or if any quantile
            level falls outside [0, 1].
    """

    model_type: ModelType = ModelType.TIME_SERIES

    # Input: raw history or feature store reference
    history: Annotated[list[float] | None, Field(default=None)]
    feature_ref: Annotated[dict[str, str] | None, Field(default=None)]
    # e.g. {"feature_view": "asset_prices", "entity_id": "AAPL"}

    horizon: Annotated[int, Field(gt=0)]
    frequency: Frequency
    output_mode: OutputMode = OutputMode.MEAN
    quantile_levels: list[float] = Field(
        # default_factory so each request gets its own list object rather
        # than a shared default literal.
        default_factory=lambda: [0.1, 0.5, 0.9],
        description="Quantile levels — only used when output_mode=quantiles",
    )
    num_samples: int = Field(
        default=20,
        gt=0,  # zero or negative sample counts are never meaningful
        description="Number of samples — only used when output_mode=samples",
    )

    @model_validator(mode="after")
    def validate_input_source(self) -> "TimeSeriesRequest":
        """Enforce exactly one of `history`/`feature_ref`, and reject
        out-of-range quantile levels at request time instead of deep
        inside a backend's np.quantile call."""
        if self.history is None and self.feature_ref is None:
            raise ValueError("One of `history` or `feature_ref` must be provided.")
        if self.history is not None and self.feature_ref is not None:
            raise ValueError("Provide either `history` or `feature_ref`, not both.")
        if any(not 0.0 <= q <= 1.0 for q in self.quantile_levels):
            raise ValueError("`quantile_levels` must all lie within [0, 1].")
        return self
62
+
63
+
64
class TimeSeriesResponse(BaseResponse):
    """Response contract for time series foundation models."""

    model_type: ModelType = ModelType.TIME_SERIES

    # Mean forecast — always populated
    mean: list[float]

    # Only set when the request asked for output_mode=quantiles,
    # keyed by stringified level, e.g. {"0.1": [...], "0.5": [...], "0.9": [...]}
    quantiles: dict[str, list[float]] | None = None

    # Only set when the request asked for output_mode=samples
    samples: list[list[float]] | None = None

    # Echoes of the request: horizon, and frequency as its raw string value
    horizon: int
    frequency: str
@@ -0,0 +1,4 @@
1
+ from sheaf.backends.base import ModelBackend
2
+ from sheaf.backends.chronos import Chronos2Backend
3
+
4
+ __all__ = ["ModelBackend", "Chronos2Backend"]
@@ -0,0 +1,41 @@
1
+ """Abstract base class for model backends."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from sheaf.api.base import BaseRequest, BaseResponse
7
+
8
+
9
class ModelBackend(ABC):
    """Pluggable model backend — implement this to add a model to Sheaf.

    A backend is responsible for:
    - loading and initializing the model,
    - preprocessing raw inputs,
    - running inference,
    - postprocessing outputs into a typed response.
    """

    @abstractmethod
    def load(self) -> None:
        """Load model weights and initialize any required state."""
        ...

    @abstractmethod
    def predict(self, request: BaseRequest) -> BaseResponse:
        """Run inference for a single request."""
        ...

    def batch_predict(self, requests: list[BaseRequest]) -> list[BaseResponse]:
        """Run inference over a batch of requests.

        The default implementation simply calls predict() once per request;
        override to implement model-type-aware batching.
        """
        return [self.predict(single) for single in requests]

    @property
    @abstractmethod
    def model_type(self) -> str:
        """The ModelType this backend serves."""
        ...
@@ -0,0 +1,143 @@
1
+ """Chronos2 backend for time series forecasting.
2
+
3
+ Requires: pip install sheaf-serve[time-series]
4
+ Supports: amazon/chronos-t5-{tiny,mini,small,base,large}
5
+ amazon/chronos-bolt-{tiny,mini,small,base}
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ import numpy as np
13
+
14
+ from sheaf.api.base import BaseRequest, BaseResponse, ModelType
15
+ from sheaf.api.time_series import OutputMode, TimeSeriesRequest, TimeSeriesResponse
16
+ from sheaf.backends.base import ModelBackend
17
+ from sheaf.registry import register_backend
18
+
19
+ if TYPE_CHECKING:
20
+ import torch
21
+
22
+
23
@register_backend("chronos2")
class Chronos2Backend(ModelBackend):
    """ModelBackend implementation for Chronos / Chronos-Bolt models.

    Args:
        model_id: HuggingFace model ID, e.g. "amazon/chronos-bolt-small"
        device_map: "cpu", "cuda", "mps", or "auto"
        torch_dtype: "bfloat16", "float32", etc. Passed to from_pretrained.
        num_samples: Default number of samples for probabilistic output.
    """

    def __init__(
        self,
        model_id: str = "amazon/chronos-bolt-small",
        device_map: str = "cpu",
        torch_dtype: str = "bfloat16",
        num_samples: int = 20,
    ) -> None:
        self._model_id = model_id
        self._device_map = device_map
        self._torch_dtype = torch_dtype
        self._default_num_samples = num_samples
        self._pipeline: Any = None  # populated by load()

    @property
    def model_type(self) -> str:
        return ModelType.TIME_SERIES

    def load(self) -> None:
        """Load the Chronos pipeline.

        Raises:
            ImportError: if chronos-forecasting is not installed, with an
                install hint for the `time-series` extra.
        """
        try:
            import torch
            from chronos import BaseChronosPipeline
        except ImportError as e:
            raise ImportError(
                "chronos-forecasting is required for the Chronos2 backend. "
                "Install it with: pip install sheaf-serve[time-series]"
            ) from e

        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
            "float16": torch.float16,
        }
        # Unrecognized dtype strings deliberately fall back to bfloat16.
        torch_dtype = dtype_map.get(self._torch_dtype, torch.bfloat16)

        self._pipeline = BaseChronosPipeline.from_pretrained(
            self._model_id,
            device_map=self._device_map,
            torch_dtype=torch_dtype,
        )

    def predict(self, request: BaseRequest) -> BaseResponse:
        """Run inference for a single TimeSeriesRequest."""
        if not isinstance(request, TimeSeriesRequest):
            raise TypeError(f"Expected TimeSeriesRequest, got {type(request)}")
        return self._run([request])[0]

    def batch_predict(self, requests: list[BaseRequest]) -> list[BaseResponse]:
        """Run inference over a homogeneous batch of TimeSeriesRequests."""
        ts_requests = [r for r in requests if isinstance(r, TimeSeriesRequest)]
        if len(ts_requests) != len(requests):
            raise TypeError("All requests must be TimeSeriesRequest for Chronos2Backend")
        return self._run(ts_requests)

    def _run(self, requests: list[TimeSeriesRequest]) -> list[TimeSeriesResponse]:
        """Batched forecast path shared by predict() and batch_predict().

        Raises:
            RuntimeError: if load() has not been called.
            ValueError: if any request lacks raw `history`, or if the batch
                mixes horizons.
        """
        import torch

        if self._pipeline is None:
            raise RuntimeError("Backend not loaded. Call load() first.")
        if not requests:
            # Avoid an opaque max()-on-empty crash below.
            return []

        # feature_ref resolution is not implemented here; treating such
        # requests as empty history would silently forecast from nothing.
        if any(r.history is None for r in requests):
            raise ValueError(
                "Chronos2Backend requires `history` on every request; "
                "`feature_ref` resolution is not supported by this backend."
            )

        # All requests in a batch must share a horizon. Enforce the contract
        # instead of silently applying the first request's value
        # (caller buckets via BatchPolicy(bucket_by="horizon")).
        horizon = requests[0].horizon
        if any(r.horizon != horizon for r in requests[1:]):
            raise ValueError(
                "All requests in a batch must share the same horizon; "
                "bucket requests first (see BatchPolicy(bucket_by='horizon'))."
            )
        num_samples = requests[0].num_samples

        # Build context tensors — left-pad with NaN to a common length.
        histories: list[list[float]] = [r.history for r in requests]  # type: ignore[misc]
        max_len = max(len(h) for h in histories)
        padded = [
            [float("nan")] * (max_len - len(h)) + h
            for h in histories
        ]
        context = torch.tensor(padded, dtype=torch.float32)

        # predict returns [batch, num_samples, horizon]
        forecast = self._pipeline.predict(
            context=context,
            prediction_length=horizon,
            num_samples=num_samples,
        )
        # detach().cpu() so conversion also works when the pipeline runs on
        # an accelerator device.
        forecast_np: np.ndarray = forecast.detach().cpu().numpy()

        responses = []
        for i, req in enumerate(requests):
            samples = forecast_np[i]  # [num_samples, horizon]
            responses.append(self._build_response(req, samples))
        return responses

    def _build_response(
        self, req: TimeSeriesRequest, samples: np.ndarray
    ) -> TimeSeriesResponse:
        """Reduce per-request sample paths into the requested output mode."""
        mean = samples.mean(axis=0).tolist()

        quantiles = None
        if req.output_mode == OutputMode.QUANTILES:
            quantiles = {
                str(q): np.quantile(samples, q, axis=0).tolist()
                for q in req.quantile_levels
            }

        raw_samples = None
        if req.output_mode == OutputMode.SAMPLES:
            raw_samples = samples.tolist()

        return TimeSeriesResponse(
            request_id=req.request_id,
            model_name=req.model_name,
            horizon=req.horizon,
            frequency=req.frequency.value,
            mean=mean,
            quantiles=quantiles,
            samples=raw_samples,
        )
File without changes
File without changes
@@ -0,0 +1,24 @@
1
+ """Backend registry — separated to avoid circular imports."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from sheaf.backends.base import ModelBackend
9
+
10
+ _BACKEND_REGISTRY: dict[str, type["ModelBackend"]] = {}
11
+
12
+
13
+ def register_backend(name: str):
14
+ """Decorator to register a ModelBackend implementation by name.
15
+
16
+ Example:
17
+ @register_backend("chronos2")
18
+ class Chronos2Backend(ModelBackend):
19
+ ...
20
+ """
21
+ def decorator(cls: type["ModelBackend"]) -> type["ModelBackend"]:
22
+ _BACKEND_REGISTRY[name] = cls
23
+ return cls
24
+ return decorator
@@ -0,0 +1,3 @@
1
+ from sheaf.scheduling.batch import BatchPolicy
2
+
3
+ __all__ = ["BatchPolicy"]
@@ -0,0 +1,17 @@
1
+ """Batching policies for model-type-aware request scheduling."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class BatchPolicy(BaseModel):
    """How requests are grouped before they reach the model backend.

    Attributes:
        max_batch_size: Hard cap on requests per batch.
        timeout_ms: Max time to wait for a full batch before flushing.
        bucket_by: Field name to bucket on before batching (e.g. "horizon"
            for time series, so variable-length forecasts don't get mixed);
            None disables bucketing.
    """

    max_batch_size: int = Field(default=32, gt=0)
    timeout_ms: int = Field(default=50, gt=0)
    bucket_by: str | None = None
@@ -0,0 +1,71 @@
1
+ """ModelServer — the Ray Serve entry point for Sheaf."""
2
+
3
+ from typing import Any
4
+
5
+ import ray
6
+ from ray import serve
7
+
8
+ from sheaf.api.base import BaseRequest, BaseResponse
9
+ from sheaf.backends.base import ModelBackend
10
+ from sheaf.registry import _BACKEND_REGISTRY, register_backend # noqa: F401
11
+ from sheaf.spec import ModelSpec
12
+
13
+
14
@serve.deployment
class _SheafDeployment:
    """Ray Serve deployment wrapping one ModelBackend instance per replica."""

    def __init__(self, spec: ModelSpec) -> None:
        # Resolve the backend class registered under the spec's name.
        backend_cls = _BACKEND_REGISTRY.get(spec.backend)
        if backend_cls is None:
            raise ValueError(
                f"Unknown backend '{spec.backend}'. "
                f"Registered backends: {list(_BACKEND_REGISTRY)}"
            )
        self._backend: ModelBackend = backend_cls(**spec.backend_kwargs)
        self._backend.load()
        self._spec = spec

    async def __call__(self, request: BaseRequest) -> BaseResponse:
        # NOTE(review): predict() is synchronous, so inference blocks this
        # replica's event loop while it runs — confirm that is acceptable.
        return self._backend.predict(request)
29
+
30
+
31
class ModelServer:
    """Top-level serving orchestrator.

    Example:
        server = ModelServer(
            models=[chronos_spec, tabpfn_spec],
        )
        server.run()
    """

    def __init__(
        self,
        models: list[ModelSpec],
        host: str = "0.0.0.0",
        port: int = 8000,
    ) -> None:
        self._models = models
        self._host = host
        self._port = port
        self._deployments: dict[str, Any] = {}

    def run(self) -> None:
        """Start Ray if needed, bring up the HTTP proxy, and deploy each spec."""
        if not ray.is_initialized():
            ray.init()

        serve.start(http_options={"host": self._host, "port": self._port})

        for model_spec in self._models:
            resources = model_spec.resources
            bound = _SheafDeployment.options(
                name=model_spec.name,
                num_replicas=resources.replicas,
                ray_actor_options={
                    "num_cpus": resources.num_cpus,
                    "num_gpus": resources.num_gpus,
                },
            ).bind(model_spec)
            # Each model becomes its own Serve application routed at /<name>.
            handle = serve.run(
                bound, name=model_spec.name, route_prefix=f"/{model_spec.name}"
            )
            self._deployments[model_spec.name] = handle

    def shutdown(self) -> None:
        """Tear down all Serve applications."""
        serve.shutdown()
@@ -0,0 +1,35 @@
1
+ """ModelSpec — declares what Sheaf should serve."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from sheaf.api.base import ModelType
6
+ from sheaf.scheduling.batch import BatchPolicy
7
+
8
+
9
class ResourceConfig(BaseModel):
    """Per-replica resource requests, forwarded to Ray actor options."""

    num_cpus: float = 1.0
    num_gpus: float = 0.0
    # NOTE(review): memory_gb is not currently forwarded by ModelServer.run,
    # which only passes num_cpus/num_gpus/replicas — confirm intent.
    memory_gb: float | None = None
    replicas: int = 1
14
+
15
+
16
class ModelSpec(BaseModel):
    """Declares one model to be served by Sheaf.

    `backend` names a class registered via @register_backend;
    `backend_kwargs` is passed verbatim to that class's constructor.

    Example:
        spec = ModelSpec(
            name="chronos2-small",
            model_type=ModelType.TIME_SERIES,
            backend="chronos2",
            backend_kwargs={"model_size": "small"},
            resources=ResourceConfig(num_gpus=1, replicas=2),
            batch_policy=BatchPolicy(max_batch_size=64, bucket_by="horizon"),
        )
    """

    name: str
    model_type: ModelType
    backend: str
    backend_kwargs: dict = Field(default_factory=dict)
    resources: ResourceConfig = Field(default_factory=ResourceConfig)
    batch_policy: BatchPolicy = Field(default_factory=BatchPolicy)