ray-embedding 0.13.9__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ray_embedding/deploy.py CHANGED
@@ -1,70 +1,87 @@
1
- import os
2
- from typing import Optional
3
-
4
- import torch
5
- from ray.serve import Application
6
- from ray.serve.handle import DeploymentHandle
7
-
8
- from ray_embedding.dto import AppConfig, ModelDeploymentConfig, DeployedModel
9
- from ray_embedding.embedding_model import EmbeddingModel
10
- from ray_embedding.model_router import ModelRouter
11
- from ray_embedding.node_health import NodeHealthTracker
12
-
13
- DEFAULT_NODE_HEALTH_CHECK_INTERVAL_S = 30
14
-
15
-
16
- def build_model(model_config: ModelDeploymentConfig, node_health_tracker: Optional[DeploymentHandle] = None) -> DeployedModel:
17
- deployment_name = model_config.deployment
18
- model = model_config.model
19
- served_model_name = model_config.served_model_name or os.path.basename(model)
20
- device = model_config.device
21
- backend = model_config.backend or "torch"
22
- matryoshka_dim = model_config.matryoshka_dim
23
- trust_remote_code = model_config.trust_remote_code or False
24
- model_kwargs = model_config.model_kwargs or {}
25
- cuda_memory_flush_threshold = model_config.cuda_memory_flush_threshold or 0.8
26
-
27
- if "torch_dtype" in model_kwargs:
28
- torch_dtype = model_kwargs["torch_dtype"].strip()
29
- if torch_dtype == "float16":
30
- model_kwargs["torch_dtype"] = torch.float16
31
- elif torch_dtype == "bfloat16":
32
- model_kwargs["torch_dtype"] = torch.bfloat16
33
- elif torch_dtype == "float32":
34
- model_kwargs["torch_dtype"] = torch.float32
35
- else:
36
- raise ValueError(f"Invalid torch_dtype: '{torch_dtype}'")
37
-
38
- deployment = EmbeddingModel.options(name=deployment_name).bind(model=model,
39
- served_model_name=served_model_name,
40
- device=device,
41
- backend=backend,
42
- matryoshka_dim=matryoshka_dim,
43
- trust_remote_code=trust_remote_code,
44
- model_kwargs=model_kwargs,
45
- cuda_memory_flush_threshold=cuda_memory_flush_threshold,
46
- node_health_tracker=node_health_tracker
47
- )
48
- return DeployedModel(model=served_model_name,
49
- deployment_handle=deployment,
50
- batch_size=model_config.batch_size,
51
- num_retries=model_config.num_retries
52
- )
53
-
54
-
55
- def build_app(args: AppConfig) -> Application:
56
- model_router, models = args.model_router, args.models
57
- assert model_router and models
58
- assert model_router.path_prefix
59
-
60
- node_health_check_interval_s = args.node_health_check_interval_s or DEFAULT_NODE_HEALTH_CHECK_INTERVAL_S
61
- tracked_model_deployments = [model_config.deployment for model_config in models]
62
- node_health_tracker = (NodeHealthTracker.options(health_check_period_s=node_health_check_interval_s)
63
- .bind(tracked_model_deployments=tracked_model_deployments))
64
- deployed_models = {model_config.served_model_name: build_model(model_config, node_health_tracker=node_health_tracker)
65
- for model_config in models}
66
- router = (ModelRouter.options(name=model_router.deployment)
67
- .bind(deployed_models=deployed_models,
68
- path_prefix=model_router.path_prefix,
69
- node_health_tracker=node_health_tracker))
70
- return router
1
+ import os
2
+ from typing import Any, Dict
3
+
4
+ import torch
5
+ from ray.serve import Application
6
+
7
+ from ray_embedding.dto import AppConfig, ModelDeploymentConfig, DeployedModel, NodeReaperConfig
8
+ from ray_embedding.embedding_model import EmbeddingModel
9
+ from ray_embedding.node_reaper import NodeReaper, NODE_REAPER_DEPLOYMENT_NAME
10
+ from ray_embedding.utils import get_head_node_id
11
+ from ray_embedding.model_router import ModelRouter
12
+
13
+
14
+ def build_model(model_config: ModelDeploymentConfig, node_reaper):
15
+ deployment_name = model_config.deployment
16
+ model = model_config.model
17
+ served_model_name = model_config.served_model_name or os.path.basename(model)
18
+ device = model_config.device
19
+ backend = model_config.backend or "torch"
20
+ matryoshka_dim = model_config.matryoshka_dim
21
+ trust_remote_code = model_config.trust_remote_code or False
22
+ model_kwargs = model_config.model_kwargs or {}
23
+ cuda_memory_flush_threshold = model_config.cuda_memory_flush_threshold or 0.8
24
+
25
+ if "torch_dtype" in model_kwargs:
26
+ torch_dtype = model_kwargs["torch_dtype"].strip()
27
+ if torch_dtype == "float16":
28
+ model_kwargs["torch_dtype"] = torch.float16
29
+ elif torch_dtype == "bfloat16":
30
+ model_kwargs["torch_dtype"] = torch.bfloat16
31
+ elif torch_dtype == "float32":
32
+ model_kwargs["torch_dtype"] = torch.float32
33
+ else:
34
+ raise ValueError(f"Invalid torch_dtype: '{torch_dtype}'")
35
+
36
+ deployment = EmbeddingModel.options(name=deployment_name).bind(model=model,
37
+ served_model_name=served_model_name,
38
+ device=device,
39
+ backend=backend,
40
+ matryoshka_dim=matryoshka_dim,
41
+ trust_remote_code=trust_remote_code,
42
+ model_kwargs=model_kwargs,
43
+ cuda_memory_flush_threshold=cuda_memory_flush_threshold,
44
+ node_reaper=node_reaper,
45
+ )
46
+ return DeployedModel(model=served_model_name,
47
+ deployment_handle=deployment,
48
+ batch_size=model_config.batch_size,
49
+ num_retries=model_config.num_retries
50
+ )
51
+
52
+
53
+ def build_app(args: AppConfig) -> Application:
54
+ model_router, models = args.model_router, args.models
55
+ assert model_router and models
56
+ assert model_router.path_prefix
57
+
58
+ node_reaper_config = args.node_reaper or NodeReaperConfig()
59
+
60
+ node_reaper_kwargs: Dict[str, Any] = {
61
+ "ssh_user": node_reaper_config.ssh_user,
62
+ "ssh_private_key": node_reaper_config.ssh_private_key,
63
+ }
64
+ if node_reaper_config.retention_seconds is not None:
65
+ node_reaper_kwargs["retention_seconds"] = node_reaper_config.retention_seconds
66
+ if node_reaper_config.reap_interval_seconds is not None:
67
+ node_reaper_kwargs["reap_interval_seconds"] = node_reaper_config.reap_interval_seconds
68
+
69
+ node_reaper = NodeReaper.options(
70
+ name=NODE_REAPER_DEPLOYMENT_NAME,
71
+ ray_actor_options={"num_cpus": 0.25, "resources": {"node_type:head": 1}},
72
+ autoscaling_config={"initial_replicas": 1, "min_replicas": 1, "max_replicas": 1}
73
+ ).bind(**node_reaper_kwargs)
74
+
75
+ deployed_models = {model_config.served_model_name: build_model(model_config, node_reaper) for model_config in models}
76
+ model_router_kwargs = {
77
+ "deployed_models": deployed_models,
78
+ "path_prefix": model_router.path_prefix,
79
+ "max_concurrency": model_router.max_concurrency,
80
+ "node_reaper": node_reaper
81
+ }
82
+ router = ModelRouter.options(
83
+ name=model_router.deployment,
84
+ ray_actor_options={"num_cpus": 0.25, "resources": {"node_type:worker": 1}}
85
+ ).bind(**model_router_kwargs)
86
+
87
+ return router
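Note: a minimal sketch of driving the new build_app entry point from Python, assuming a Ray cluster that defines the node_type:head / node_type:worker custom resources referenced above and has the SSH bootstrap key in place. The model and deployment names are placeholders, not part of the package.

    from ray import serve

    from ray_embedding.deploy import build_app
    from ray_embedding.dto import (AppConfig, ModelDeploymentConfig,
                                   ModelRouterConfig, NodeReaperConfig)

    app_config = AppConfig(
        model_router=ModelRouterConfig(deployment="model-router", path_prefix=["/v1"]),
        node_reaper=NodeReaperConfig(ssh_user="ubuntu",
                                     ssh_private_key="/home/ray/ray_bootstrap_key.pem"),
        models=[
            ModelDeploymentConfig(
                deployment="bge-base",                 # placeholder deployment name
                model="BAAI/bge-base-en-v1.5",         # placeholder model id
                served_model_name="bge-base-en-v1.5",
                device="cuda",
                batch_size=8,
            )
        ],
    )

    # build_app binds the NodeReaper, the per-model EmbeddingModel deployments and
    # the ModelRouter; serve.run deploys the resulting application graph.
    serve.run(build_app(app_config), name="ray-embedding")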
ray_embedding/dto.py CHANGED
@@ -1,52 +1,59 @@
1
- import dataclasses
2
- from typing import Union, List, Optional, Dict, Any
3
- from pydantic import BaseModel
4
- from ray.serve.handle import DeploymentHandle
5
-
6
-
7
- class EmbeddingRequest(BaseModel):
8
- """Schema of embedding requests (compatible with OpenAI)"""
9
- model: str # Model name (for compatibility; only one model is used here)
10
- input: Union[str, List[str]] # List of strings to embed
11
- dimensions: Optional[int] = None
12
-
13
-
14
- class EmbeddingResponse(BaseModel):
15
- """Schema of embedding response (compatible with OpenAI)"""
16
- object: str
17
- data: List[dict] # Embedding data including index and vector
18
- model: str # Model name used for embedding
19
-
20
-
21
- class ModelRouterConfig(BaseModel):
22
- deployment: str
23
- path_prefix: List[str] = []
24
- max_concurrency: int = 32
25
-
26
-
27
- class ModelDeploymentConfig(BaseModel):
28
- model: str
29
- served_model_name: str
30
- batch_size: Optional[int] = 8
31
- num_retries: Optional[int] = 2
32
- device: Optional[str] = None
33
- backend: Optional[str] = None
34
- matryoshka_dim: Optional[int] = 768
35
- trust_remote_code: Optional[bool] = False
36
- model_kwargs: Optional[Dict[str, Any]] = {}
37
- cuda_memory_flush_threshold: Optional[float] = 0.8
38
- deployment: str
39
-
40
-
41
- class AppConfig(BaseModel):
42
- model_router: ModelRouterConfig
43
- node_health_check_interval_s: Optional[int] = 30
44
- models: List[ModelDeploymentConfig]
45
-
46
-
47
- @dataclasses.dataclass
48
- class DeployedModel:
49
- model: str
50
- deployment_handle: DeploymentHandle
51
- batch_size: int
52
- num_retries: Optional[int] = 2
1
+ import dataclasses
2
+ from typing import Union, List, Optional, Dict, Any
3
+ from pydantic import BaseModel
4
+ from ray.serve.handle import DeploymentHandle
5
+
6
+
7
+ class EmbeddingRequest(BaseModel):
8
+ """Schema of embedding requests (compatible with OpenAI)"""
9
+ model: str # Model name (for compatibility; only one model is used here)
10
+ input: Union[str, List[str]] # List of strings to embed
11
+ dimensions: Optional[int] = None
12
+
13
+
14
+ class EmbeddingResponse(BaseModel):
15
+ """Schema of embedding response (compatible with OpenAI)"""
16
+ object: str
17
+ data: List[dict] # Embedding data including index and vector
18
+ model: str # Model name used for embedding
19
+
20
+
21
+ class ModelRouterConfig(BaseModel):
22
+ deployment: str
23
+ path_prefix: List[str] = []
24
+ max_concurrency: int = 32
25
+
26
+
27
+ class ModelDeploymentConfig(BaseModel):
28
+ model: str
29
+ served_model_name: str
30
+ batch_size: Optional[int] = 8
31
+ num_retries: Optional[int] = 2
32
+ device: Optional[str] = None
33
+ backend: Optional[str] = None
34
+ matryoshka_dim: Optional[int] = 768
35
+ trust_remote_code: Optional[bool] = False
36
+ model_kwargs: Optional[Dict[str, Any]] = {}
37
+ cuda_memory_flush_threshold: Optional[float] = 0.8
38
+ deployment: str
39
+
40
+
41
+ class NodeReaperConfig(BaseModel):
42
+ ssh_user: str = "ubuntu"
43
+ ssh_private_key: str = "/home/ray/ray_bootstrap_key.pem"
44
+ retention_seconds: Optional[int] = 900
45
+ reap_interval_seconds: Optional[int] = 60
46
+
47
+
48
+ class AppConfig(BaseModel):
49
+ model_router: ModelRouterConfig
50
+ node_reaper: Optional[NodeReaperConfig] = None
51
+ models: List[ModelDeploymentConfig]
52
+
53
+
54
+ @dataclasses.dataclass
55
+ class DeployedModel:
56
+ model: str
57
+ deployment_handle: DeploymentHandle
58
+ batch_size: int
59
+ num_retries: Optional[int] = 2
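Note: a short sketch of exercising the AppConfig schema above; the concrete values are illustrative only. When the optional node_reaper block is omitted, build_app falls back to a default NodeReaperConfig, so the defaults shown above apply.

    from ray_embedding.dto import AppConfig, NodeReaperConfig

    # Nested dicts are coerced into the pydantic models defined above.
    config = AppConfig(
        model_router={"deployment": "model-router", "path_prefix": ["/v1"]},
        models=[{"deployment": "embedder",                 # placeholder names
                 "model": "BAAI/bge-base-en-v1.5",
                 "served_model_name": "bge-base-en-v1.5"}],
    )

    reaper = config.node_reaper or NodeReaperConfig()
    assert reaper.ssh_user == "ubuntu" and reaper.retention_seconds == 900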
ray_embedding/embedding_model.py CHANGED
@@ -1,126 +1,112 @@
1
- import logging
2
- import os.path
3
- import time
4
- from typing import Optional, Dict, Any, List, Union
5
-
6
- import ray
7
- import torch
8
- from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
9
- from ray import serve
10
- from ray.util import get_node_ip_address
11
- from ray.serve.handle import DeploymentHandle
12
- from sentence_transformers import SentenceTransformer
13
-
14
-
15
- @serve.deployment
16
- class EmbeddingModel:
17
- def __init__(self, model: str, served_model_name: Optional[str] = None,
18
- device: Optional[str] = None, backend: Optional[str] = "torch",
19
- matryoshka_dim: Optional[int] = None, trust_remote_code: Optional[bool] = False,
20
- model_kwargs: Dict[str, Any] = None, cuda_memory_flush_threshold: Optional[float] = 0.8,
21
- node_health_tracker: Optional[DeploymentHandle] = None):
22
- logging.basicConfig(level=logging.INFO)
23
- self.logger = logging.getLogger(self.__class__.__name__)
24
- self.model = model
25
- self.served_model_name = served_model_name or os.path.basename(self.model)
26
- self.init_device = device
27
- self.cuda_memory_flush_threshold = cuda_memory_flush_threshold
28
- if self.init_device is None or self.init_device == "auto":
29
- self.init_device = "cuda" if torch.cuda.is_available() else "cpu"
30
- if self.init_device == "cuda":
31
- self.wait_for_cuda()
32
- self.torch_device = torch.device(self.init_device)
33
- self.backend = backend or "torch"
34
- self.matryoshka_dim = matryoshka_dim
35
- self.trust_remote_code = trust_remote_code or False
36
- self.model_kwargs = model_kwargs or {}
37
-
38
- self.logger.info(f"Initializing embedding model: {self.model}")
39
- self.embedding_model = SentenceTransformer(self.model, device=self.init_device, backend=self.backend,
40
- trust_remote_code=self.trust_remote_code,
41
- model_kwargs=self.model_kwargs)
42
-
43
- self.node_health_tracker = node_health_tracker
44
- replica_context = serve.get_replica_context()
45
- self.deployment_name = replica_context.deployment
46
- self.replica_actor_name = replica_context.replica_id.to_full_id_str()
47
- self.node_ip = get_node_ip_address()
48
- self.logger.info(f"Successfully initialized model {self.model} using device {self.torch_device}. "
49
- f"Deployment name: {self.deployment_name}, Replica actor name: {self.replica_actor_name}, Node IP: {self.node_ip}")
50
-
51
- async def __call__(self, text: Union[str, List[str]], dimensions: Optional[int] = None) -> List[List[float]]:
52
- """Compute embeddings for the input text using the current model."""
53
- if not text or (isinstance(text, list) and not all(text)):
54
- raise ValueError("Input text is empty or invalid")
55
-
56
- text = [text] if isinstance(text, str) else text
57
- truncate_dim = dimensions or self.matryoshka_dim
58
-
59
- # Compute embeddings in PyTorch format
60
- embeddings = self.embedding_model.encode(
61
- text, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False,
62
- ).to(self.torch_device)
63
-
64
- if truncate_dim is not None:
65
- # Truncate and re-normalize the embeddings
66
- embeddings = embeddings[:, :truncate_dim]
67
- embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
68
-
69
- # Move all embeddings to CPU at once before conversion
70
- embeddings_list = embeddings.cpu().tolist()
71
-
72
- # don't wait for GC
73
- del embeddings
74
-
75
- return embeddings_list
76
-
77
- def wait_for_cuda(self, wait: int = 10):
78
- if self.init_device == "cuda" and not torch.cuda.is_available():
79
- time.sleep(wait)
80
- self.check_cuda()
81
-
82
- def check_cuda(self) -> Any:
83
- if self.init_device != "cuda":
84
- return None
85
- try:
86
- # Even though CUDA was available at init time,
87
- # CUDA can become unavailable - this is a known problem in AWS EC2+Docker
88
- # https://github.com/ray-project/ray/issues/49594
89
- nvmlInit()
90
- count = nvmlDeviceGetCount()
91
- assert count >= 1, "No CUDA devices found"
92
-
93
- # replicas only have access to GPU 0
94
- handle = nvmlDeviceGetHandleByIndex(0)
95
- return handle
96
- except Exception as e:
97
- error_msg = f"CUDA health check failed for deployment: " \
98
- f"{self.deployment_name}, replica: {self.replica_actor_name}, node: {self.node_ip}.\n{e}"
99
- self.logger.error(error_msg)
100
- if self.node_health_tracker:
101
- self.node_health_tracker.report_bad_gpu_node.remote(self.node_ip, self.deployment_name, self.replica_actor_name)
102
- raise RuntimeError(error_msg)
103
-
104
- async def check_health(self):
105
- if self.node_health_tracker:
106
- if await self.node_health_tracker.is_bad_gpu_node.remote(self.node_ip):
107
- raise RuntimeError(f"The node {self.node_ip} is marked bad.")
108
-
109
- handle = self.check_cuda() # Raises an exception if CUDA is unavailable
110
- mem_info = nvmlDeviceGetMemoryInfo(handle)
111
- reserved = torch.cuda.memory_reserved() # bytes currently reserved by CUDA cache
112
- threshold_bytes = self.cuda_memory_flush_threshold * mem_info.total
113
-
114
- if reserved > threshold_bytes:
115
- # flush only when cache exceeds the percentage threshold
116
- torch.cuda.empty_cache()
117
-
118
- def __del__(self):
119
- # Clean up and free any remaining GPU memory
120
- try:
121
- if hasattr(self, 'embedding_model'):
122
- del self.embedding_model
123
- if torch.cuda.is_available():
124
- torch.cuda.empty_cache()
125
- except Exception as e:
126
- self.logger.warning(f"Error during cleanup: {e}")
1
+ import logging
2
+ import os.path
3
+ import time
4
+ from typing import Optional, Dict, Any, List, Union
5
+
6
+ import torch
7
+ from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
8
+ from ray import serve
9
+ from ray.serve.handle import DeploymentHandle
10
+ from sentence_transformers import SentenceTransformer
11
+
12
+ from ray_embedding.utils import report_unhealthy_replica
13
+
14
+
15
+ @serve.deployment
16
+ class EmbeddingModel:
17
+ def __init__(self, model: str, served_model_name: Optional[str] = None,
18
+ device: Optional[str] = None, backend: Optional[str] = "torch",
19
+ matryoshka_dim: Optional[int] = None, trust_remote_code: Optional[bool] = False,
20
+ model_kwargs: Dict[str, Any] = None, cuda_memory_flush_threshold: Optional[float] = 0.8,
21
+ node_reaper: Optional[DeploymentHandle] = None):
22
+ logging.basicConfig(level=logging.INFO)
23
+ self.logger = logging.getLogger(self.__class__.__name__)
24
+ self.model = model
25
+ self.served_model_name = served_model_name or os.path.basename(self.model)
26
+ self.init_device = device
27
+ self.cuda_memory_flush_threshold = cuda_memory_flush_threshold
28
+ if self.init_device is None or self.init_device == "auto":
29
+ self.init_device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ if self.init_device == "cuda":
31
+ self.wait_for_cuda()
32
+ self.torch_device = torch.device(self.init_device)
33
+ self.backend = backend or "torch"
34
+ self.matryoshka_dim = matryoshka_dim
35
+ self.trust_remote_code = trust_remote_code or False
36
+ self.model_kwargs = model_kwargs or {}
37
+ self.node_reaper = node_reaper
38
+
39
+ self.logger.info(f"Initializing embedding model: {self.model}")
40
+ self.embedding_model = SentenceTransformer(self.model, device=self.init_device, backend=self.backend,
41
+ trust_remote_code=self.trust_remote_code,
42
+ model_kwargs=self.model_kwargs)
43
+
44
+ self.logger.info(f"Successfully initialized model {self.model} using device {self.torch_device}")
45
+
46
+ async def __call__(self, text: Union[str, List[str]], dimensions: Optional[int] = None) -> List[List[float]]:
47
+ """Compute embeddings for the input text using the current model."""
48
+ if not text or (isinstance(text, list) and not all(text)):
49
+ raise ValueError("Input text is empty or invalid")
50
+
51
+ text = [text] if isinstance(text, str) else text
52
+ truncate_dim = dimensions or self.matryoshka_dim
53
+
54
+ # Compute embeddings in PyTorch format
55
+ embeddings = self.embedding_model.encode(
56
+ text, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False,
57
+ ).to(self.torch_device)
58
+
59
+ if truncate_dim is not None:
60
+ # Truncate and re-normalize the embeddings
61
+ embeddings = embeddings[:, :truncate_dim]
62
+ embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
63
+
64
+ # Move all embeddings to CPU at once before conversion
65
+ embeddings_list = embeddings.cpu().tolist()
66
+
67
+ # don't wait for GC
68
+ del embeddings
69
+
70
+ return embeddings_list
71
+
72
+ def wait_for_cuda(self, wait: int = 10):
73
+ if self.init_device == "cuda" and not torch.cuda.is_available():
74
+ time.sleep(wait)
75
+ self.check_health()
76
+
77
+ def check_health(self):
78
+ if self.init_device != "cuda":
79
+ return
80
+
81
+ try:
82
+ # Even though CUDA was available at init time,
83
+ # CUDA can become unavailable - this is a known problem in AWS EC2+Docker
84
+ # https://github.com/ray-project/ray/issues/49594
85
+ nvmlInit()
86
+ count = nvmlDeviceGetCount()
87
+ assert count >= 1, "No CUDA devices found"
88
+
89
+ # replicas only have access to GPU 0
90
+ handle = nvmlDeviceGetHandleByIndex(0)
91
+ mem_info = nvmlDeviceGetMemoryInfo(handle)
92
+ except Exception as e:
93
+ error_message = f"CUDA health check failed: {e}"
94
+ report_unhealthy_replica(error=error_message, node_reaper=self.node_reaper)
95
+ raise RuntimeError(error_message)
96
+
97
+ reserved = torch.cuda.memory_reserved() # bytes currently reserved by CUDA cache
98
+ threshold_bytes = self.cuda_memory_flush_threshold * mem_info.total
99
+
100
+ if reserved > threshold_bytes:
101
+ # flush only when cache exceeds the percentage threshold
102
+ torch.cuda.empty_cache()
103
+
104
+ def __del__(self):
105
+ # Clean up and free any remaining GPU memory
106
+ try:
107
+ if hasattr(self, 'embedding_model'):
108
+ del self.embedding_model
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
111
+ except Exception as e:
112
+ self.logger.warning(f"Error during cleanup: {e}")
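Note: the dimensions / matryoshka_dim handling above truncates the embedding matrix and re-normalizes each row; the same arithmetic on a dummy tensor, for reference:

    import torch

    # Dummy batch of 4 embeddings with 768 dims, already L2-normalized.
    embeddings = torch.nn.functional.normalize(torch.randn(4, 768), dim=1)

    truncate_dim = 256  # the request's "dimensions" or the configured matryoshka_dim
    truncated = embeddings[:, :truncate_dim]
    # Truncation breaks unit length, so re-normalize before returning the vectors.
    truncated = truncated / torch.norm(truncated, dim=1, keepdim=True)

    assert torch.allclose(torch.norm(truncated, dim=1), torch.ones(4), atol=1e-5)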
ray_embedding/model_router.py CHANGED
@@ -4,20 +4,20 @@ import time
4
4
  from typing import Optional, Dict, List, Tuple
5
5
 
6
6
  from fastapi import FastAPI, HTTPException
7
+ import ray
7
8
  from ray import serve
8
9
  from ray.serve.handle import DeploymentHandle
9
- from ray.util import get_node_ip_address
10
10
 
11
11
  from ray_embedding.dto import DeployedModel, EmbeddingRequest, EmbeddingResponse
12
+ from ray_embedding.utils import get_current_node_ip
12
13
 
13
14
  web_api = FastAPI(title="Ray Embeddings - OpenAI-compatible API")
14
15
 
15
16
  @serve.deployment
16
17
  @serve.ingress(web_api)
17
18
  class ModelRouter:
18
- def __init__(self, deployed_models: Dict[str, DeployedModel],
19
- path_prefix: List[str], max_concurrency: Optional[int] = 32,
20
- node_health_tracker: Optional[DeploymentHandle] = None):
19
+ def __init__(self, deployed_models: Dict[str, DeployedModel], path_prefix: List[str],
20
+ max_concurrency: Optional[int] = 32, node_reaper: Optional[DeploymentHandle] = None):
21
21
  assert deployed_models, "models cannot be empty"
22
22
  assert path_prefix, "path_prefix cannot be empty"
23
23
 
@@ -35,13 +35,7 @@ class ModelRouter:
35
35
  "permission": []} for item in self.deployed_models.keys()
36
36
  ]
37
37
  self.logger.info(f"Successfully registered models: {self.available_models}")
38
- self.node_health_tracker = node_health_tracker
39
- replica_context = serve.get_replica_context()
40
- self.deployment_name = replica_context.deployment
41
- self.replica_actor_name = replica_context.replica_id.to_full_id_str()
42
- self.node_ip = get_node_ip_address()
43
- self.logger.info(f"Successfully initialized model router. "
44
- f"Deployment name: {self.deployment_name}, Replica actor name: {self.replica_actor_name}, Node IP: {self.node_ip}")
38
+ self.node_reaper = node_reaper
45
39
 
46
40
  async def _compute_embeddings_from_resized_batches(self, model: str, inputs: List[str], dimensions: Optional[int] = None):
47
41
  deployed_model = self.deployed_models[model]
@@ -122,7 +116,19 @@ class ModelRouter:
122
116
  raise HTTPException(status_code=400, detail=f"The API path prefix specified is invalid: '{path_prefix}'")
123
117
  return {"object": "list", "data": self.available_models}
124
118
 
125
- async def check_health(self):
126
- if self.node_health_tracker:
127
- if await self.node_health_tracker.is_bad_gpu_or_no_model_replica_on_node.remote(self.node_ip):
128
- raise RuntimeError(f"The node {self.node_ip} is marked bad, or no model replica running on the node.")
119
+ def check_health(self):
120
+ if not self.node_reaper:
121
+ return
122
+
123
+ try:
124
+ unhealthy_node_ips = ray.get(self.node_reaper.get_unhealthy_node_ips.remote())
125
+ except Exception as exc:
126
+ self.logger.warning(f"Unable to fetch node reaper data: {exc}")
127
+ return
128
+
129
+ if not unhealthy_node_ips:
130
+ return
131
+
132
+ node_ip = get_current_node_ip()
133
+ if node_ip and node_ip in unhealthy_node_ips:
134
+ raise RuntimeError("Model router replica is colocated with an unhealthy embedding replica node.")
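Note: a hedged example of calling the router once the app is deployed. The exact route depends on the configured path_prefix (assumed "/v1" here), and each item in "data" is assumed to follow the OpenAI embeddings convention that the response schema above mirrors.

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/embeddings",              # assumed host and prefix
        json={"model": "bge-base-en-v1.5",                   # must match a served_model_name
              "input": ["hello world", "ray serve"],
              "dimensions": 256},                            # optional matryoshka truncation
        timeout=60,
    )
    resp.raise_for_status()
    payload = resp.json()        # {"object": "list", "data": [...], "model": ...}
    vectors = [item["embedding"] for item in payload["data"]]  # assumes OpenAI-style "embedding" field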
ray_embedding/node_reaper.py ADDED
@@ -0,0 +1,124 @@
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ from pathlib import Path
5
+ from typing import Dict, Any, List, Optional, Set
6
+
7
+ from ray import serve
8
+
9
+
10
+ NODE_REAPER_DEPLOYMENT_NAME = "NodeReaper"
11
+
12
+
13
+ @serve.deployment
14
+ class NodeReaper:
15
+ def __init__(
16
+ self,
17
+ ssh_user: str,
18
+ ssh_private_key: str,
19
+ retention_seconds: int = 900,
20
+ reap_interval_seconds: int = 60,
21
+ ):
22
+ logging.basicConfig(level=logging.INFO)
23
+ self.logger = logging.getLogger(self.__class__.__name__)
24
+ self.ssh_user = ssh_user
25
+ key_path = Path(ssh_private_key).expanduser()
26
+ if not key_path.exists():
27
+ raise FileNotFoundError(f"SSH private key not found: {key_path}")
28
+ self.ssh_private_key = key_path.as_posix()
29
+ self.retention_seconds = retention_seconds
30
+ self.reap_interval_seconds = max(30, reap_interval_seconds)
31
+
32
+ self._unhealthy_replicas: Dict[str, Dict[str, Any]] = {}
33
+ self._nodes_marked_for_reap: Dict[str, float] = {}
34
+ self._nodes_inflight: Set[str] = set()
35
+
36
+ loop = asyncio.get_event_loop()
37
+ self._reaper_task = loop.create_task(self._reap_loop())
38
+ self.logger.info("NodeReaper initialized; monitoring unhealthy nodes for recycling")
39
+
40
+ def __del__(self):
41
+ if hasattr(self, "_reaper_task") and self._reaper_task and not self._reaper_task.done():
42
+ self._reaper_task.cancel()
43
+
44
+ def report_failure(self, replica_id: str, node_ip: str, error: Optional[str] = None):
45
+ self._unhealthy_replicas[replica_id] = {
46
+ "node_ip": node_ip,
47
+ "error": error,
48
+ "timestamp": time.time(),
49
+ }
50
+ self._nodes_marked_for_reap[node_ip] = self._nodes_marked_for_reap.get(node_ip, time.time())
51
+ self.logger.warning(f"Replica {replica_id} on {node_ip} marked for reaping: {error}")
52
+ self._purge_stale()
53
+
54
+ def get_unhealthy_node_ips(self) -> List[str]:
55
+ self._purge_stale()
56
+ return list(self._nodes_marked_for_reap.keys())
57
+
58
+ async def _reap_loop(self):
59
+ while True:
60
+ try:
61
+ await asyncio.sleep(self.reap_interval_seconds)
62
+ await self._reap_pending_nodes()
63
+ except asyncio.CancelledError:
64
+ break
65
+ except Exception as exc:
66
+ self.logger.warning(f"Unexpected error in reap loop: {exc}")
67
+
68
+ async def _reap_pending_nodes(self):
69
+ nodes = self.get_unhealthy_node_ips()
70
+ for node_ip in nodes:
71
+ if node_ip in self._nodes_inflight:
72
+ continue
73
+ self._nodes_inflight.add(node_ip)
74
+ try:
75
+ await self._reap_node(node_ip)
76
+ self._clear_node(node_ip)
77
+ self.logger.info(f"Successfully reaped node {node_ip}")
78
+ except Exception as exc:
79
+ self.logger.error(f"Failed to reap node {node_ip}: {exc}")
80
+ finally:
81
+ self._nodes_inflight.discard(node_ip)
82
+
83
+ async def _reap_node(self, node_ip: str):
84
+ ssh_command = [
85
+ "ssh",
86
+ "-i",
87
+ self.ssh_private_key,
88
+ "-o",
89
+ "StrictHostKeyChecking=no",
90
+ f"{self.ssh_user}@{node_ip}",
91
+ "docker stop ray_container",
92
+ ]
93
+
94
+ self.logger.info(f"Reaping node {node_ip} via SSH")
95
+ process = await asyncio.create_subprocess_exec(
96
+ *ssh_command,
97
+ stdout=asyncio.subprocess.PIPE,
98
+ stderr=asyncio.subprocess.PIPE,
99
+ )
100
+ stdout, stderr = await process.communicate()
101
+ if process.returncode != 0:
102
+ stdout_text = stdout.decode().strip()
103
+ stderr_text = stderr.decode().strip()
104
+ raise RuntimeError(
105
+ f"SSH command failed with code {process.returncode}. stdout={stdout_text} stderr={stderr_text}"
106
+ )
107
+
108
+ def _clear_node(self, node_ip: str):
109
+ to_delete = [replica for replica, data in self._unhealthy_replicas.items() if data.get("node_ip") == node_ip]
110
+ for replica in to_delete:
111
+ self._unhealthy_replicas.pop(replica, None)
112
+ self._nodes_marked_for_reap.pop(node_ip, None)
113
+
114
+ def _purge_stale(self):
115
+ if not self.retention_seconds:
116
+ return
117
+ cutoff = time.time() - self.retention_seconds
118
+ replica_ids = [replica_id for replica_id, data in self._unhealthy_replicas.items()
119
+ if data.get("timestamp", 0) < cutoff]
120
+ for replica_id in replica_ids:
121
+ node_ip = self._unhealthy_replicas[replica_id]["node_ip"]
122
+ self._unhealthy_replicas.pop(replica_id, None)
123
+ if node_ip in self._nodes_marked_for_reap and self._nodes_marked_for_reap[node_ip] < cutoff:
124
+ self._nodes_marked_for_reap.pop(node_ip, None)
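Note: a rough sketch of the reporting path the deployments above rely on, assuming the application built in deploy.py is already running so the NodeReaper deployment can be resolved by name. The replica id and node IP below are placeholders.

    import ray

    from ray_embedding.utils import get_node_reaper_handle

    node_reaper = get_node_reaper_handle()

    # What an unhealthy replica reports (normally via report_unhealthy_replica);
    # the reaper's background loop later SSHes into the node and stops ray_container.
    node_reaper.report_failure.remote("EmbeddingModel#abc123", "10.0.0.12",
                                      "CUDA health check failed")

    # What ModelRouter.check_health polls: node IPs currently marked for reaping.
    unhealthy_ips = ray.get(node_reaper.get_unhealthy_node_ips.remote())
    print(unhealthy_ips)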
ray_embedding/utils.py ADDED
@@ -0,0 +1,60 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import ray
4
+ from ray import serve
5
+ from ray.serve.handle import DeploymentHandle
6
+ from ray.util import get_node_ip_address, state
7
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, NotIn
8
+
9
+ from ray_embedding.node_reaper import NODE_REAPER_DEPLOYMENT_NAME
10
+
11
+
12
+
13
+ def get_head_node_id() -> Tuple[str, str]:
14
+ try:
15
+ nodes = state.list_nodes(filters=[("is_head_node", "=", True)])
16
+ if not nodes:
17
+ raise RuntimeError("Unable to locate head node for NodeReaper deployment.")
18
+ head_node = nodes[0]
19
+ return head_node["node_id"], head_node["node_ip"]
20
+ except Exception as exc:
21
+ raise RuntimeError("Unable to locate the head node ID for NodeReaper deployment.") from exc
22
+
23
+
24
+ def get_node_reaper_handle() -> DeploymentHandle:
25
+ try:
26
+ return serve.context.get_deployment_handle(NODE_REAPER_DEPLOYMENT_NAME)
27
+ except Exception:
28
+ return serve.get_deployment(NODE_REAPER_DEPLOYMENT_NAME).get_handle(sync=False)
29
+
30
+
31
+ def get_current_replica_tag() -> Optional[str]:
32
+ try:
33
+ context = serve.context.get_current_replica_context()
34
+ except Exception:
35
+ context = None
36
+ if context is None:
37
+ return None
38
+ return getattr(context, "replica_tag", None)
39
+
40
+
41
+ def get_current_node_ip() -> Optional[str]:
42
+ try:
43
+ return get_node_ip_address()
44
+ except Exception:
45
+ return None
46
+
47
+
48
+ def report_unhealthy_replica(error: Optional[str] = None,
49
+ node_reaper: Optional[DeploymentHandle] = None) -> None:
50
+ replica_id = get_current_replica_tag()
51
+ node_ip = get_current_node_ip()
52
+ if not (replica_id and node_ip):
53
+ return
54
+ handle = node_reaper
55
+ if handle is None:
56
+ try:
57
+ handle = get_node_reaper_handle()
58
+ except Exception:
59
+ return
60
+ handle.report_failure.remote(replica_id, node_ip, error)
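Note: report_unhealthy_replica is not tied to the embedding deployments; a sketch of reusing it from another Serve deployment's health check (the deployment below is purely illustrative and not part of the package):

    from ray import serve

    from ray_embedding.utils import report_unhealthy_replica

    @serve.deployment
    class IllustrativeWorker:
        def check_health(self):
            probe_ok = True  # placeholder for a real local health probe
            if not probe_ok:
                # Mark this replica's node for reaping, then fail the health check
                # so Ray Serve restarts the replica.
                report_unhealthy_replica(error="local probe failed")
                raise RuntimeError("local probe failed")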
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ray-embedding
3
- Version: 0.13.9
3
+ Version: 0.14.0
4
4
  Summary: Deploy SentenceTransformers embedding models to a ray cluster
5
5
  Author: Crispin Almodovar
6
6
  Author-email:
@@ -31,6 +31,6 @@ to see how this library is used.
31
31
  - onnx-gpu
32
32
  - onnx-cpu
33
33
  - openvino-cpu
34
-
34
+ - fastembed-onnx-cpu
35
35
 
36
36
 
@@ -0,0 +1,11 @@
1
+ ray_embedding/__init__.py,sha256=YS5LAZfRIwwVvE3C9g7hsauvjgIkqKtHyxkwMFFfAGY,46
2
+ ray_embedding/deploy.py,sha256=NYpGDGF8y1rh3Thts-NC4nb8anXQJDC0dFZC18_R2f8,4170
3
+ ray_embedding/dto.py,sha256=6JuAcD6pLfzUL48HfyPnZI7Hb-o66KFM5UtYZOOgwc8,1739
4
+ ray_embedding/embedding_model.py,sha256=Zr5lxVuy60y8-JgsOmKDD44FZlbTL1tiiY-3_72sTR4,4905
5
+ ray_embedding/model_router.py,sha256=W2c0hvqwDe1iCfNx4ee2UT7wKduywMP8dY0Ggb8xBvU,6658
6
+ ray_embedding/node_reaper.py,sha256=ISwSHnQs22B_f3PihND3KYTLkJSDbg1JWIAaKS-qCm0,4800
7
+ ray_embedding/utils.py,sha256=cbdI7q6xSvbl31ZthdM8mz55VrN8pubkoD6RqKGYLUc,1898
8
+ ray_embedding-0.14.0.dist-info/METADATA,sha256=uJlMttPN4bYVZbvLi5g37dCQRyzEzHrj7eqK1fYiv_w,1094
9
+ ray_embedding-0.14.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ ray_embedding-0.14.0.dist-info/top_level.txt,sha256=ziCblpJq1YsrryshFqxTRuRMgNuO1_tgvAAkGShATNA,14
11
+ ray_embedding-0.14.0.dist-info/RECORD,,
ray_embedding/node_health.py DELETED
@@ -1,94 +0,0 @@
1
- import logging
2
- import threading
3
- from typing import Set, List
4
-
5
- import ray
6
- from ray import serve
7
- from ray._private.services import get_node_ip_address
8
- from ray.util.state import list_actors
9
-
10
-
11
- @serve.deployment(autoscaling_config=dict(min_replicas=0, max_replicas=1),
12
- ray_actor_options=dict(num_cpus=0.1))
13
- class NodeHealthTracker:
14
- """Maintains a list of bad nodes, as reported by replicas that call the report_bad_node func.
15
- Bad nodes are those that fail GPU/CUDA health check.
16
- What's the purpose? Because when an embedding model replica becomes unhealthy
17
- (due to GPU/CUDA issues), we want Ray to kill all replicas running on the node.
18
- When Ray detects that there are no running replicas on a node, the node is stopped
19
- and replaced with a new one.
20
- """
21
- def __init__(self, tracked_model_deployments: List[str] = None):
22
- logging.basicConfig(level=logging.INFO)
23
- self.logger = logging.getLogger(self.__class__.__name__)
24
- self.tracked_model_deployments = tracked_model_deployments or []
25
- self.bad_gpu_node_ips: Set[str] = set()
26
- self.lock = threading.RLock()
27
- replica_context = serve.get_replica_context()
28
- self.app_name = replica_context.app_name
29
- self.deployment_name = replica_context.deployment
30
- self.replica_actor_name = replica_context.replica_id.to_full_id_str()
31
- self.node_ip = get_node_ip_address()
32
- self.logger.info(f"Successfully initialized NodeHealthTracker. Tracked model deployments: {self.tracked_model_deployments}")
33
-
34
- async def report_bad_gpu_node(self, node_ip: str, deployment_name: str, replica_actor_name: str):
35
- with self.lock:
36
- if node_ip not in self.bad_gpu_node_ips:
37
- self.bad_gpu_node_ips.add(node_ip)
38
- self.logger.warning(
39
- f"[Bad GPU node reported] Deployment: {deployment_name}, Replica: {replica_actor_name}, Node IP: {node_ip}"
40
- )
41
-
42
- async def is_bad_gpu_node(self, node_ip: str) -> bool:
43
- self.logger.info(f"Checking if node {node_ip} is marked bad.")
44
- with self.lock:
45
- is_bad_gpu_node = node_ip in self.bad_gpu_node_ips
46
- self.logger.info(f"Node {node_ip} is marked bad: {is_bad_gpu_node}")
47
- return is_bad_gpu_node
48
-
49
- async def is_bad_gpu_or_no_model_replica_on_node(self, node_ip: str):
50
- self.logger.info(f"Checking if node {node_ip} is marked bad or no model replica running on the node.")
51
- is_bad_gpu_node = await self.is_bad_gpu_node(node_ip)
52
- is_no_model_replica_running_on_node = not await self.is_model_replica_running_on_node(node_ip)
53
- return is_bad_gpu_node or is_no_model_replica_running_on_node
54
-
55
- async def check_health(self):
56
- """Called periodically by Ray Serve. Used here to clean up stale node IDs."""
57
- try:
58
- current_node_ips = {node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]}
59
- with self.lock:
60
- stale_nodes = self.bad_gpu_node_ips - current_node_ips
61
- if stale_nodes:
62
- self.logger.info(f"Removing stale bad node_ips: {stale_nodes}")
63
- self.bad_gpu_node_ips.intersection_update(current_node_ips)
64
- self.logger.info(f"Current nodes: {current_node_ips}. Bad GPU nodes: {self.bad_gpu_node_ips}.")
65
- except Exception as e:
66
- raise RuntimeError(f"An error occurred in check_health during bad node cleanup: {e}")
67
-
68
- async def is_model_replica_running_on_node(self, node_ip: str) -> bool:
69
- """
70
- Return True if there is at least one replica of the self.tracked_model_deployments
71
- running on the specified node_ip.
72
- """
73
- try:
74
- self.logger.info(f"Checking if there is at least one replica of tracked_deployments={self.tracked_model_deployments} "
75
- f"running on node {node_ip}.")
76
- target_node_id = next(node["NodeID"] for node in ray.nodes() if node["Alive"] and node["NodeManagerAddress"] == node_ip)
77
- assert target_node_id, f"No node found with IP {node_ip}"
78
- prefixes = tuple(f"SERVE_REPLICA::{self.app_name}#{d}" for d in self.tracked_model_deployments)
79
- self.logger.info(f"Checking actors with prefixes: {prefixes} in node IP {node_ip}, ID {target_node_id}")
80
-
81
- for actor in list_actors(detail=False, filters=[("node_id", "=", target_node_id)]):
82
- self.logger.info(f"Checking actor: {actor}")
83
- if actor.state in ["DEPENDENCIES_UNREADY", 'PENDING_CREATION', 'ALIVE', 'RESTARTING']:
84
- for prefix in prefixes:
85
- if actor.name.startswith(prefix):
86
- self.logger.info(f"Found a replica {actor.name} of "
87
- f"tracked_deployments={self.tracked_model_deployments} "
88
- f"running in node IP {node_ip}, node ID {target_node_id}.")
89
- return True
90
- self.logger.info(f"No replicas of tracked deployments={self.tracked_model_deployments} running on node: {node_ip}.")
91
- return False
92
- except Exception as e:
93
- self.logger.error(f"An error occurred while checking replicas on node {node_ip}: {e}")
94
- return False
@@ -1,10 +0,0 @@
1
- ray_embedding/__init__.py,sha256=YS5LAZfRIwwVvE3C9g7hsauvjgIkqKtHyxkwMFFfAGY,46
2
- ray_embedding/deploy.py,sha256=2R7bQ7aPc9G8H9KVoemxum6-9YxmlXQogWbhFhuslko,3762
3
- ray_embedding/dto.py,sha256=lk_LuVQPq3MLIMTMddqHviYXILY6V5dvbzDJuD_D_qc,1573
4
- ray_embedding/embedding_model.py,sha256=P2xyXCznxXmdQBK6zodOJEMvxGVRMA8Ra3O5Qi7RCh0,6013
5
- ray_embedding/model_router.py,sha256=fmaeXzaAJeCemzL9nUoXfdCrU-ZaCe_29fx5ayDCTC0,6845
6
- ray_embedding/node_health.py,sha256=bKRoFHS6cVRQBOYTcv0dRA61VDeiJjmIPT8tA0hbRIU,5350
7
- ray_embedding-0.13.9.dist-info/METADATA,sha256=O1ObZ9JwO7eI-6Vke5hwxmyLy4VvRWIb0IgiTD6GZzQ,1074
8
- ray_embedding-0.13.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- ray_embedding-0.13.9.dist-info/top_level.txt,sha256=ziCblpJq1YsrryshFqxTRuRMgNuO1_tgvAAkGShATNA,14
10
- ray_embedding-0.13.9.dist-info/RECORD,,