ray-embedding 0.10.12__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ray-embedding might be problematic; see the package's registry page for more details.

ray_embedding/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from ray_embedding.deploy import deploy_model
1
+ from ray_embedding.deploy import build_app
2
2
 
ray_embedding/deploy.py CHANGED
@@ -1,26 +1,22 @@
1
1
  from typing import Dict, Any, Optional
2
2
  from ray.serve import Application
3
+ from ray.serve.handle import DeploymentHandle
4
+
5
+ from ray_embedding.dto import AppConfig, ModelDeploymentConfig
3
6
  from ray_embedding.embedding_model import EmbeddingModel
4
7
  import torch
5
8
 
9
+ from ray_embedding.embedding_service import EmbeddingService
6
10
 
7
- def deploy_model(args: Dict[str, Any]) -> Application:
8
- """Builds and deploys a SentenceTransformer embedding model.
9
- :arg args: arguments for initializing a SentenceTransformer model
10
- :returns: a Ray Serve Application
11
- """
12
- assert args
13
- deployment_name: str = args.pop("deployment", "")
14
- assert deployment_name
15
-
16
- model: str = args.pop("model", "")
17
- assert model
18
11
 
19
- device: Optional[str] = args.pop("device", None)
20
- backend: Optional[str] = args.pop("backend", "torch")
21
- matryoshka_dim: Optional[int] = args.pop("matryoshka_dim", None)
22
- trust_remote_code: Optional[bool] = args.pop("trust_remote_code", False)
23
- model_kwargs: Dict[str, Any] = args.pop("model_kwargs", {})
12
+ def build_model(model_config: ModelDeploymentConfig) -> DeploymentHandle:
13
+ deployment_name = model_config.deployment_name
14
+ model = model_config.model
15
+ device = model_config.device
16
+ backend = model_config.backend or "torch"
17
+ matryoshka_dim = model_config.matryoshka_dim
18
+ trust_remote_code = model_config.trust_remote_code or False
19
+ model_kwargs = model_config.model_kwargs or {}
24
20
  if "torch_dtype" in model_kwargs:
25
21
  torch_dtype = model_kwargs["torch_dtype"].strip()
26
22
  if torch_dtype == "float16":
@@ -39,4 +35,12 @@ def deploy_model(args: Dict[str, Any]) -> Application:
39
35
  trust_remote_code=trust_remote_code,
40
36
  model_kwargs=model_kwargs
41
37
  )
42
- return deployment
38
+ return deployment
39
+
40
+ def build_app(args: AppConfig) -> Application:
41
+ model_router, models = args.model_router, args.models
42
+ assert model_router and models
43
+
44
+ served_models = {model_config.model: build_model(model_config) for model_config in models}
45
+ app = EmbeddingService.options(name=model_router.deployment).bind(served_models)
46
+ return app
ray_embedding/dto.py CHANGED
@@ -1,5 +1,7 @@
1
- from typing import Union, List, Optional
1
+ import dataclasses
2
+ from typing import Union, List, Optional, Dict, Any
2
3
  from pydantic import BaseModel
4
+ from ray.serve.handle import DeploymentHandle
3
5
 
4
6
 
5
7
  class EmbeddingRequest(BaseModel):
@@ -13,4 +15,34 @@ class EmbeddingResponse(BaseModel):
13
15
  """Schema of embedding response (compatible with OpenAI)"""
14
16
  object: str
15
17
  data: List[dict] # Embedding data including index and vector
16
- model: str # Model name used for embedding
18
+ model: str # Model name used for embedding
19
+
20
+
21
+ class ModelRouterConfig(BaseModel):
22
+ deployment: str
23
+
24
+
25
+ class ModelDeploymentConfig(BaseModel):
26
+ deployment: str
27
+ model: str
28
+ served_model_name: str
29
+ batch_size: Optional[int] = 8
30
+ num_retries: Optional[int] = 2
31
+ device: Optional[str] = None
32
+ backend: Optional[str] = None
33
+ matryoshka_dim: Optional[int] = 768
34
+ trust_remote_code: Optional[bool] = False
35
+ model_kwargs: Optional[Dict[str, Any]] = {}
36
+
37
+
38
+ class AppConfig(BaseModel):
39
+ model_router: ModelRouterConfig
40
+ models: List[ModelDeploymentConfig]
41
+
42
+
43
+ @dataclasses.dataclass
44
+ class DeployedModel:
45
+ model: str
46
+ deployment_handle: DeploymentHandle
47
+ batch_size: int
48
+ num_retries: Optional[int] = 2
@@ -1,34 +1,18 @@
1
1
  import logging
2
2
  import os.path
3
3
  import time
4
- from typing import Optional, Dict, Any, List
4
+ from typing import Optional, Dict, Any, List, Union
5
5
 
6
6
  import torch
7
- from fastapi import FastAPI, HTTPException
7
+ from pynvml import nvmlInit, nvmlDeviceGetCount
8
8
  from ray import serve
9
9
  from sentence_transformers import SentenceTransformer
10
10
 
11
- from ray_embedding.dto import EmbeddingResponse, EmbeddingRequest
12
11
 
13
- web_api = FastAPI(title=f"Ray Embeddings - OpenAI-compatible API")
14
-
15
-
16
- @serve.deployment(
17
- num_replicas="auto",
18
- ray_actor_options={
19
- "num_cpus": 1,
20
- "num_gpus": 0
21
- },
22
- autoscaling_config={
23
- "target_ongoing_requests": 2,
24
- "min_replicas": 0,
25
- "initial_replicas": 1,
26
- "max_replicas": 1,
27
- }
28
- )
29
- @serve.ingress(web_api)
12
+ @serve.deployment
30
13
  class EmbeddingModel:
31
- def __init__(self, model: str, device: Optional[str] = None, backend: Optional[str] = "torch",
14
+ def __init__(self, model: str, served_model_name: Optional[str] = None,
15
+ device: Optional[str] = None, backend: Optional[str] = "torch",
32
16
  matryoshka_dim: Optional[int] = None, trust_remote_code: Optional[bool] = False,
33
17
  model_kwargs: Dict[str, Any] = None):
34
18
  logging.basicConfig(level=logging.INFO)
@@ -49,56 +33,29 @@ class EmbeddingModel:
49
33
  trust_remote_code=self.trust_remote_code,
50
34
  model_kwargs=self.model_kwargs)
51
35
 
52
- self.served_model_name = os.path.basename(self.model)
53
- self.available_models = [
54
- {"id": self.served_model_name,
55
- "object": "model",
56
- "created": int(time.time()),
57
- "owned_by": "openai",
58
- "permission": []}
59
- ]
36
+ self.served_model_name = served_model_name or os.path.basename(self.model)
60
37
  self.logger.info(f"Successfully initialized embedding model {self.model} using device {self.torch_device}")
61
38
 
62
- @web_api.post("/v1/embeddings", response_model=EmbeddingResponse)
63
- async def create_embeddings(self, request: EmbeddingRequest):
64
- """Generate embeddings for the input text using the specified model."""
65
- try:
66
- assert request.model == self.served_model_name, (
67
- f"Model '{request.model}' is not supported. Use '{self.served_model_name}' instead."
68
- )
69
- if isinstance(request.input, str):
70
- request.input = [request.input]
71
-
72
- truncate_dim = request.dimensions or self.matryoshka_dim
73
-
74
- # Compute embeddings and convert to a PyTorch tensor on the GPU
75
- embeddings = self.embedding_model.encode(
76
- request.input, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False,
77
- ).to(self.torch_device)
39
+ async def __call__(self, text: Union[str, List[str]], dimensions: Optional[int] = None):
40
+ """Compute embeddings for the input text using the loaded model."""
41
+ if isinstance(text, str):
42
+ text = [text]
43
+ truncate_dim = dimensions or self.matryoshka_dim
78
44
 
79
- if truncate_dim is not None:
80
- # Truncate and re-normalize the embeddings
81
- embeddings = embeddings[:, :truncate_dim]
82
- embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
45
+ # Compute embeddings and convert to a PyTorch tensor on the GPU
46
+ embeddings = self.embedding_model.encode(
47
+ text, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False,
48
+ ).to(self.torch_device)
83
49
 
84
- # Move all embeddings to CPU at once before conversion
85
- embeddings = embeddings.cpu().tolist()
50
+ if truncate_dim is not None:
51
+ # Truncate and re-normalize the embeddings
52
+ embeddings = embeddings[:, :truncate_dim]
53
+ embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
86
54
 
87
- # Convert embeddings to list format for response
88
- response_data = [
89
- {"index": idx, "embedding": emb}
90
- for idx, emb in enumerate(embeddings)
91
- ]
92
- return EmbeddingResponse(object="list", data=response_data, model=request.model)
55
+ # Move all embeddings to CPU at once before conversion
56
+ embeddings = embeddings.cpu().tolist()
57
+ return embeddings
93
58
 
94
- except Exception as e:
95
- self.logger.error(e)
96
- raise HTTPException(status_code=500, detail=str(e))
97
-
98
- @web_api.get("/v1/models")
99
- async def list_models(self):
100
- """Returns the list of available models in OpenAI-compatible format."""
101
- return {"object": "list", "data": self.available_models}
102
59
 
103
60
  def wait_for_cuda(self, wait: int = 10):
104
61
  if self.init_device == "cuda" and not torch.cuda.is_available():
@@ -106,8 +63,12 @@ class EmbeddingModel:
106
63
  self.check_health()
107
64
 
108
65
  def check_health(self):
109
- if self.init_device == "cuda" and not torch.cuda.is_available():
66
+ if self.init_device == "cuda":
110
67
  # Even though CUDA was available at init time,
111
68
  # CUDA can become unavailable - this is a known problem in AWS EC2
112
69
  # https://github.com/ray-project/ray/issues/49594
113
- raise RuntimeError("CUDA device is not available")
70
+ try:
71
+ nvmlInit()
72
+ assert nvmlDeviceGetCount() >= 1
73
+ except:
74
+ raise RuntimeError("CUDA device is not available")
@@ -0,0 +1,77 @@
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ from typing import Optional, Dict, List
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from ray import serve
8
+
9
+ from ray_embedding.dto import EmbeddingResponse, EmbeddingRequest, DeployedModel
10
+
11
+ web_api = FastAPI(title="Ray Embeddings - OpenAI-compatible API")
12
+
13
+ @serve.deployment
14
+ @serve.ingress(web_api)
15
+ class EmbeddingService:
16
+ def __init__(self, served_models: Dict[str, DeployedModel]):
17
+ self.logger = logging.getLogger(self.__class__.__name__)
18
+ assert served_models, "models cannot be empty"
19
+ self.served_models = served_models
20
+ self.available_models = [
21
+ {"id": str(item),
22
+ "object": "model",
23
+ "created": int(time.time()),
24
+ "owned_by": "openai",
25
+ "permission": []} for item in self.served_models.keys()
26
+ ]
27
+ self.logger.info(f"Successfully registered models: {self.available_models}")
28
+
29
+ async def _compute_embeddings_from_resized_batches(self, model: str, inputs: List[str], dimensions: Optional[int] = None):
30
+ assert model in self.served_models
31
+ model_handle = self.served_models[model].deployment_handle
32
+ batch_size = self.served_models[model].batch_size
33
+ num_retries = self.served_models[model].num_retries
34
+
35
+ # Resize the inputs into batch_size items, and dispatch in parallel
36
+ batches = [inputs[i:i+batch_size] for i in range(0, len(inputs), batch_size)]
37
+ tasks = [model_handle.remote(batch, dimensions) for batch in batches]
38
+ all_results = await asyncio.gather(*tasks, return_exceptions=True)
39
+
40
+ # Retry any failed model calls
41
+ for i, result in enumerate(all_results):
42
+ if isinstance(result, Exception):
43
+ retries = 0
44
+ while retries < num_retries:
45
+ try:
46
+ all_results[i] = await model_handle.remote(batches[i], dimensions)
47
+ except Exception as e:
48
+ self.logger.warning(e)
49
+ finally:
50
+ retries += 1
51
+ if not isinstance(all_results[i], Exception):
52
+ break
53
+
54
+ if retries >= num_retries and isinstance(all_results[i], Exception):
55
+ raise all_results[i]
56
+
57
+ # Flatten the results because all_results is a list of lists
58
+ return [emb for result in all_results for emb in result]
59
+
60
+ @web_api.post("/v1/embeddings", response_model=EmbeddingResponse)
61
+ async def compute_embeddings(self, request: EmbeddingRequest):
62
+ try:
63
+ inputs = request.input if isinstance(request.input, list) else [request.input]
64
+ embeddings = await self._compute_embeddings_from_resized_batches(request.model, inputs, request.dimensions)
65
+ response_data = [
66
+ {"index": idx, "embedding": emb}
67
+ for idx, emb in enumerate(embeddings)
68
+ ]
69
+ return EmbeddingResponse(object="list", data=response_data, model=request.model)
70
+ except Exception as e:
71
+ self.logger.error(f"Failed to create embeddings: {e}")
72
+ raise HTTPException(status_code=500, detail=str(e))
73
+
74
+ @web_api.get("/v1/models")
75
+ async def list_models(self):
76
+ """Returns the list of available models in OpenAI-compatible format."""
77
+ return {"object": "list", "data": self.available_models}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ray-embedding
3
- Version: 0.10.12
3
+ Version: 0.11.0
4
4
  Summary: Deploy SentenceTransformers embedding models to a ray cluster
5
5
  Author: Crispin Almodovar
6
6
  Author-email:
@@ -0,0 +1,9 @@
1
+ ray_embedding/__init__.py,sha256=YS5LAZfRIwwVvE3C9g7hsauvjgIkqKtHyxkwMFFfAGY,46
2
+ ray_embedding/deploy.py,sha256=rqVJ8GVh3hGyjojJ9wzHosiPe6gRNVv5-8zkENKKqMc,2195
3
+ ray_embedding/dto.py,sha256=QlduDoqkFHaeF_KgsFeUKq2XWiPMmrgRPy_QjCTSCRE,1399
4
+ ray_embedding/embedding_model.py,sha256=d_gqqcKK3B2nF8qMDk7NKW8gHBbtPwvPFArKTBKNdFA,3372
5
+ ray_embedding/embedding_service.py,sha256=HwIaVQFegKYF3sp6ST5G2vf0oW0Sl3W_SDp6eq79Tng,3412
6
+ ray_embedding-0.11.0.dist-info/METADATA,sha256=VdFyMJDMJcaWinYq7hAhztPkCzjGtSmtD7UNibfdlPY,1094
7
+ ray_embedding-0.11.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ ray_embedding-0.11.0.dist-info/top_level.txt,sha256=ziCblpJq1YsrryshFqxTRuRMgNuO1_tgvAAkGShATNA,14
9
+ ray_embedding-0.11.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,8 +0,0 @@
1
- ray_embedding/__init__.py,sha256=OYJT0rVaaGzY613JqgfktsCgroDnBkGOHxR2FE9UtRU,49
2
- ray_embedding/deploy.py,sha256=ZGxcG4589WcRtaM6H84YJarw0m1XqHNgfOf3PLAhM5M,1995
3
- ray_embedding/dto.py,sha256=e91ejZbM_NB9WTjF1YnfuV71cajYIh0vOX8oV_g2OwM,595
4
- ray_embedding/embedding_model.py,sha256=JfNt0rJYXGlbNjg5xmY14k4jQmNZKHnTQZnTi5SbNSc,4829
5
- ray_embedding-0.10.12.dist-info/METADATA,sha256=7EjNTy3P2zGYP6JzjrigmpbAujjaroJA-e1slz_Ra1E,1095
6
- ray_embedding-0.10.12.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
7
- ray_embedding-0.10.12.dist-info/top_level.txt,sha256=ziCblpJq1YsrryshFqxTRuRMgNuO1_tgvAAkGShATNA,14
8
- ray_embedding-0.10.12.dist-info/RECORD,,