ray-embedding 0.12.5__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of ray-embedding might be problematic.

ray_embedding/deploy.py CHANGED
@@ -1,55 +1,79 @@
- import os
-
- import torch
- from ray.serve import Application
-
- from ray_embedding.dto import AppConfig, ModelDeploymentConfig, DeployedModel
- from ray_embedding.embedding_model import EmbeddingModel
- from ray_embedding.model_router import ModelRouter
-
-
- def build_model(model_config: ModelDeploymentConfig) -> DeployedModel:
-     deployment_name = model_config.deployment
-     model = model_config.model
-     served_model_name = model_config.served_model_name or os.path.basename(model)
-     device = model_config.device
-     backend = model_config.backend or "torch"
-     matryoshka_dim = model_config.matryoshka_dim
-     trust_remote_code = model_config.trust_remote_code or False
-     model_kwargs = model_config.model_kwargs or {}
-     cuda_memory_flush_threshold = model_config.cuda_memory_flush_threshold or 0.8
-
-     if "torch_dtype" in model_kwargs:
-         torch_dtype = model_kwargs["torch_dtype"].strip()
-         if torch_dtype == "float16":
-             model_kwargs["torch_dtype"] = torch.float16
-         elif torch_dtype == "bfloat16":
-             model_kwargs["torch_dtype"] = torch.bfloat16
-         elif torch_dtype == "float32":
-             model_kwargs["torch_dtype"] = torch.float32
-         else:
-             raise ValueError(f"Invalid torch_dtype: '{torch_dtype}'")
-
-     deployment = EmbeddingModel.options(name=deployment_name).bind(model=model,
-                                                                    served_model_name=served_model_name,
-                                                                    device=device,
-                                                                    backend=backend,
-                                                                    matryoshka_dim=matryoshka_dim,
-                                                                    trust_remote_code=trust_remote_code,
-                                                                    model_kwargs=model_kwargs
-                                                                    )
-     return DeployedModel(model=served_model_name,
-                          deployment_handle=deployment,
-                          batch_size=model_config.batch_size,
-                          num_retries=model_config.num_retries
-                          )
-
-
- def build_app(args: AppConfig) -> Application:
-     model_router, models = args.model_router, args.models
-     assert model_router and models
-     assert model_router.path_prefix
-
-     deployed_models = {model_config.served_model_name: build_model(model_config) for model_config in models}
-     router = ModelRouter.options(name=model_router.deployment).bind(deployed_models, model_router.path_prefix)
-     return router
+ import os
+
+ import torch
+ from ray.serve import Application
+
+ from ray_embedding.dto import AppConfig, ModelDeploymentConfig, DeployedModel, NodeReaperConfig
+ from ray_embedding.embedding_model import EmbeddingModel
+ from ray_embedding.node_reaper import NodeReaper, NODE_REAPER_DEPLOYMENT_NAME
+ from ray_embedding.utils import node_affinity_for_head, node_affinity_for_worker
+ from ray_embedding.model_router import ModelRouter
+
+
+ def build_model(model_config: ModelDeploymentConfig, node_reaper):
+     deployment_name = model_config.deployment
+     model = model_config.model
+     served_model_name = model_config.served_model_name or os.path.basename(model)
+     device = model_config.device
+     backend = model_config.backend or "torch"
+     matryoshka_dim = model_config.matryoshka_dim
+     trust_remote_code = model_config.trust_remote_code or False
+     model_kwargs = model_config.model_kwargs or {}
+     cuda_memory_flush_threshold = model_config.cuda_memory_flush_threshold or 0.8
+
+     if "torch_dtype" in model_kwargs:
+         torch_dtype = model_kwargs["torch_dtype"].strip()
+         if torch_dtype == "float16":
+             model_kwargs["torch_dtype"] = torch.float16
+         elif torch_dtype == "bfloat16":
+             model_kwargs["torch_dtype"] = torch.bfloat16
+         elif torch_dtype == "float32":
+             model_kwargs["torch_dtype"] = torch.float32
+         else:
+             raise ValueError(f"Invalid torch_dtype: '{torch_dtype}'")
+
+     deployment = EmbeddingModel.options(name=deployment_name).bind(model=model,
+                                                                    served_model_name=served_model_name,
+                                                                    device=device,
+                                                                    backend=backend,
+                                                                    matryoshka_dim=matryoshka_dim,
+                                                                    trust_remote_code=trust_remote_code,
+                                                                    model_kwargs=model_kwargs,
+                                                                    cuda_memory_flush_threshold=cuda_memory_flush_threshold,
+                                                                    node_reaper=node_reaper,
+                                                                    )
+     return DeployedModel(model=served_model_name,
+                          deployment_handle=deployment,
+                          batch_size=model_config.batch_size,
+                          num_retries=model_config.num_retries
+                          )
+
+
+ def build_app(args: AppConfig) -> Application:
+     model_router, models = args.model_router, args.models
+     assert model_router and models
+     assert model_router.path_prefix
+
+     node_reaper_config = args.node_reaper or NodeReaperConfig()
+
+     node_reaper_kwargs = {
+         "ssh_user": node_reaper_config.ssh_user,
+         "ssh_private_key": node_reaper_config.ssh_private_key,
+     }
+     if node_reaper_config.retention_seconds is not None:
+         node_reaper_kwargs["retention_seconds"] = node_reaper_config.retention_seconds
+     if node_reaper_config.reap_interval_seconds is not None:
+         node_reaper_kwargs["reap_interval_seconds"] = node_reaper_config.reap_interval_seconds
+
+     node_reaper = NodeReaper.options(
+         name=NODE_REAPER_DEPLOYMENT_NAME,
+         ray_actor_options={"scheduling_strategy": node_affinity_for_head()},
+     ).bind(**node_reaper_kwargs)
+
+     deployed_models = {model_config.served_model_name: build_model(model_config, node_reaper) for model_config in models}
+     router = ModelRouter.options(
+         name=model_router.deployment,
+         ray_actor_options={"scheduling_strategy": node_affinity_for_worker()}
+     ).bind(deployed_models, model_router.path_prefix, node_reaper)
+
+     return router
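
For orientation, here is a minimal sketch of driving the new builder from Python. It uses only names defined in this diff; the deployment names, model, and path prefix are illustrative, and the NodeReaperConfig defaults assume a Ray cluster whose bootstrap SSH key exists at /home/ray/ray_bootstrap_key.pem:

from ray import serve

from ray_embedding.deploy import build_app
from ray_embedding.dto import (AppConfig, ModelDeploymentConfig,
                               ModelRouterConfig, NodeReaperConfig)

# Every value below is illustrative, not part of the package.
app_config = AppConfig(
    model_router=ModelRouterConfig(deployment="router", path_prefix=["/v1"]),
    node_reaper=NodeReaperConfig(),  # ubuntu + ray_bootstrap_key.pem defaults
    models=[
        ModelDeploymentConfig(
            deployment="minilm",
            model="sentence-transformers/all-MiniLM-L6-v2",
            served_model_name="all-MiniLM-L6-v2",
        )
    ],
)

serve.run(build_app(app_config))  # deploys the router, the models, and the NodeReaper
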
ray_embedding/dto.py CHANGED
@@ -1,51 +1,59 @@
- import dataclasses
- from typing import Union, List, Optional, Dict, Any
- from pydantic import BaseModel
- from ray.serve.handle import DeploymentHandle
-
-
- class EmbeddingRequest(BaseModel):
-     """Schema of embedding requests (compatible with OpenAI)"""
-     model: str # Model name (for compatibility; only one model is used here)
-     input: Union[str, List[str]] # List of strings to embed
-     dimensions: Optional[int] = None
-
-
- class EmbeddingResponse(BaseModel):
-     """Schema of embedding response (compatible with OpenAI)"""
-     object: str
-     data: List[dict] # Embedding data including index and vector
-     model: str # Model name used for embedding
-
-
- class ModelRouterConfig(BaseModel):
-     deployment: str
-     path_prefix: List[str] = []
-     max_concurrency: int = 32
-
-
- class ModelDeploymentConfig(BaseModel):
-     model: str
-     served_model_name: str
-     batch_size: Optional[int] = 8
-     num_retries: Optional[int] = 2
-     device: Optional[str] = None
-     backend: Optional[str] = None
-     matryoshka_dim: Optional[int] = 768
-     trust_remote_code: Optional[bool] = False
-     model_kwargs: Optional[Dict[str, Any]] = {}
-     cuda_memory_flush_threshold: Optional[float] = 0.8
-     deployment: str
-
-
- class AppConfig(BaseModel):
-     model_router: ModelRouterConfig
-     models: List[ModelDeploymentConfig]
-
-
- @dataclasses.dataclass
- class DeployedModel:
-     model: str
-     deployment_handle: DeploymentHandle
-     batch_size: int
-     num_retries: Optional[int] = 2
+ import dataclasses
+ from typing import Union, List, Optional, Dict, Any
+ from pydantic import BaseModel
+ from ray.serve.handle import DeploymentHandle
+
+
+ class EmbeddingRequest(BaseModel):
+     """Schema of embedding requests (compatible with OpenAI)"""
+     model: str # Model name (for compatibility; only one model is used here)
+     input: Union[str, List[str]] # List of strings to embed
+     dimensions: Optional[int] = None
+
+
+ class EmbeddingResponse(BaseModel):
+     """Schema of embedding response (compatible with OpenAI)"""
+     object: str
+     data: List[dict] # Embedding data including index and vector
+     model: str # Model name used for embedding
+
+
+ class ModelRouterConfig(BaseModel):
+     deployment: str
+     path_prefix: List[str] = []
+     max_concurrency: int = 32
+
+
+ class ModelDeploymentConfig(BaseModel):
+     model: str
+     served_model_name: str
+     batch_size: Optional[int] = 8
+     num_retries: Optional[int] = 2
+     device: Optional[str] = None
+     backend: Optional[str] = None
+     matryoshka_dim: Optional[int] = 768
+     trust_remote_code: Optional[bool] = False
+     model_kwargs: Optional[Dict[str, Any]] = {}
+     cuda_memory_flush_threshold: Optional[float] = 0.8
+     deployment: str
+
+
+ class NodeReaperConfig(BaseModel):
+     ssh_user: str = "ubuntu"
+     ssh_private_key: str = "/home/ray/ray_bootstrap_key.pem"
+     retention_seconds: Optional[int] = 900
+     reap_interval_seconds: Optional[int] = 60
+
+
+ class AppConfig(BaseModel):
+     model_router: ModelRouterConfig
+     node_reaper: NodeReaperConfig
+     models: List[ModelDeploymentConfig]
+
+
+ @dataclasses.dataclass
+ class DeployedModel:
+     model: str
+     deployment_handle: DeploymentHandle
+     batch_size: int
+     num_retries: Optional[int] = 2
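
Because these are Pydantic models, the same structure validates from a plain dict, for example the args block of a Serve config file. A small sketch with illustrative values:

from ray_embedding.dto import AppConfig

raw_args = {
    "model_router": {"deployment": "router", "path_prefix": ["/v1"]},
    "node_reaper": {"ssh_user": "ubuntu", "retention_seconds": 600},
    "models": [{
        "deployment": "minilm",
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "served_model_name": "all-MiniLM-L6-v2",
    }],
}

config = AppConfig(**raw_args)  # raises pydantic.ValidationError on bad input
assert config.node_reaper.reap_interval_seconds == 60  # unset fields keep their defaults
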
ray_embedding/embedding_model.py CHANGED
@@ -1,104 +1,112 @@
- import logging
- import os.path
- import time
- from typing import Optional, Dict, Any, List, Union
-
- import torch
- from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
- from ray import serve
- from sentence_transformers import SentenceTransformer
-
-
- @serve.deployment
- class EmbeddingModel:
-     def __init__(self, model: str, served_model_name: Optional[str] = None,
-                  device: Optional[str] = None, backend: Optional[str] = "torch",
-                  matryoshka_dim: Optional[int] = None, trust_remote_code: Optional[bool] = False,
-                  model_kwargs: Dict[str, Any] = None, cuda_memory_flush_threshold: Optional[float] = 0.8):
-         logging.basicConfig(level=logging.INFO)
-         self.logger = logging.getLogger(self.__class__.__name__)
-         self.model = model
-         self.served_model_name = served_model_name or os.path.basename(self.model)
-         self.init_device = device
-         if self.init_device is None or self.init_device == "auto":
-             self.init_device = "cuda" if torch.cuda.is_available() else "cpu"
-         if self.init_device == "cuda":
-             self.wait_for_cuda()
-         self.torch_device = torch.device(self.init_device)
-         self.backend = backend or "torch"
-         self.matryoshka_dim = matryoshka_dim
-         self.trust_remote_code = trust_remote_code or False
-         self.model_kwargs = model_kwargs or {}
-         self.cuda_memory_flush_threshold = cuda_memory_flush_threshold
-         self.logger.info(f"Initializing embedding model: {self.model}")
-         self.embedding_model = SentenceTransformer(self.model, device=self.init_device, backend=self.backend,
-                                                    trust_remote_code=self.trust_remote_code,
-                                                    model_kwargs=self.model_kwargs)
-
-         self.logger.info(f"Successfully initialized model {self.model} using device {self.torch_device}")
-
-     async def __call__(self, text: Union[str, List[str]], dimensions: Optional[int] = None) -> List[List[float]]:
-         """Compute embeddings for the input text using the current model."""
-         if not text or (isinstance(text, list) and not all(text)):
-             raise ValueError("Input text is empty or invalid")
-
-         text = [text] if isinstance(text, str) else text
-         truncate_dim = dimensions or self.matryoshka_dim
-
-         # Compute embeddings in PyTorch format
-         embeddings = self.embedding_model.encode(
-             text, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False,
-         ).to(self.torch_device)
-
-         if truncate_dim is not None:
-             # Truncate and re-normalize the embeddings
-             embeddings = embeddings[:, :truncate_dim]
-             embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
-
-         # Move all embeddings to CPU at once before conversion
-         embeddings_list = embeddings.cpu().tolist()
-
-         # don't wait for GC
-         del embeddings
-
-         return embeddings_list
-
-     def wait_for_cuda(self, wait: int = 10):
-         if self.init_device == "cuda" and not torch.cuda.is_available():
-             time.sleep(wait)
-             self.check_health()
-
-     def check_health(self):
-         if self.init_device != "cuda":
-             return
-
-         try:
-             # Even though CUDA was available at init time,
-             # CUDA can become unavailable - this is a known problem in AWS EC2+Docker
-             # https://github.com/ray-project/ray/issues/49594
-             nvmlInit()
-             count = nvmlDeviceGetCount()
-             assert count >= 1, "No CUDA devices found"
-
-             # replicas only have access to GPU 0
-             handle = nvmlDeviceGetHandleByIndex(0)
-             mem_info = nvmlDeviceGetMemoryInfo(handle)
-         except Exception as e:
-             raise RuntimeError(f"CUDA health check failed: {e}")
-
-         reserved = torch.cuda.memory_reserved() # bytes currently reserved by CUDA cache
-         threshold_bytes = self.cuda_memory_flush_threshold * mem_info.total
-
-         if reserved > threshold_bytes:
-             # flush only when cache exceeds the percentage threshold
-             torch.cuda.empty_cache()
-
-     def __del__(self):
-         # Clean up and free any remaining GPU memory
-         try:
-             if hasattr(self, 'embedding_model'):
-                 del self.embedding_model
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-         except Exception as e:
-             self.logger.warning(f"Error during cleanup: {e}")
+ import logging
+ import os.path
+ import time
+ from typing import Optional, Dict, Any, List, Union
+
+ import torch
+ from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+ from ray import serve
+ from ray.serve.handle import DeploymentHandle
+ from sentence_transformers import SentenceTransformer
+
+ from ray_embedding.utils import report_unhealthy_replica
+
+
+ @serve.deployment
+ class EmbeddingModel:
+     def __init__(self, model: str, served_model_name: Optional[str] = None,
+                  device: Optional[str] = None, backend: Optional[str] = "torch",
+                  matryoshka_dim: Optional[int] = None, trust_remote_code: Optional[bool] = False,
+                  model_kwargs: Dict[str, Any] = None, cuda_memory_flush_threshold: Optional[float] = 0.8,
+                  node_reaper: Optional[DeploymentHandle] = None):
+         logging.basicConfig(level=logging.INFO)
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.model = model
+         self.served_model_name = served_model_name or os.path.basename(self.model)
+         self.init_device = device
+         self.cuda_memory_flush_threshold = cuda_memory_flush_threshold
+         if self.init_device is None or self.init_device == "auto":
+             self.init_device = "cuda" if torch.cuda.is_available() else "cpu"
+         if self.init_device == "cuda":
+             self.wait_for_cuda()
+         self.torch_device = torch.device(self.init_device)
+         self.backend = backend or "torch"
+         self.matryoshka_dim = matryoshka_dim
+         self.trust_remote_code = trust_remote_code or False
+         self.model_kwargs = model_kwargs or {}
+         self.node_reaper = node_reaper
+
+         self.logger.info(f"Initializing embedding model: {self.model}")
+         self.embedding_model = SentenceTransformer(self.model, device=self.init_device, backend=self.backend,
+                                                    trust_remote_code=self.trust_remote_code,
+                                                    model_kwargs=self.model_kwargs)
+
+         self.logger.info(f"Successfully initialized model {self.model} using device {self.torch_device}")
+
+     async def __call__(self, text: Union[str, List[str]], dimensions: Optional[int] = None) -> List[List[float]]:
+         """Compute embeddings for the input text using the current model."""
+         if not text or (isinstance(text, list) and not all(text)):
+             raise ValueError("Input text is empty or invalid")
+
+         text = [text] if isinstance(text, str) else text
+         truncate_dim = dimensions or self.matryoshka_dim
+
+         # Compute embeddings in PyTorch format
+         embeddings = self.embedding_model.encode(
+             text, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False,
+         ).to(self.torch_device)
+
+         if truncate_dim is not None:
+             # Truncate and re-normalize the embeddings
+             embeddings = embeddings[:, :truncate_dim]
+             embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
+
+         # Move all embeddings to CPU at once before conversion
+         embeddings_list = embeddings.cpu().tolist()
+
+         # don't wait for GC
+         del embeddings
+
+         return embeddings_list
+
+     def wait_for_cuda(self, wait: int = 10):
+         if self.init_device == "cuda" and not torch.cuda.is_available():
+             time.sleep(wait)
+             self.check_health()
+
+     def check_health(self):
+         if self.init_device != "cuda":
+             return
+
+         try:
+             # Even though CUDA was available at init time,
+             # CUDA can become unavailable - this is a known problem in AWS EC2+Docker
+             # https://github.com/ray-project/ray/issues/49594
+             nvmlInit()
+             count = nvmlDeviceGetCount()
+             assert count >= 1, "No CUDA devices found"
+
+             # replicas only have access to GPU 0
+             handle = nvmlDeviceGetHandleByIndex(0)
+             mem_info = nvmlDeviceGetMemoryInfo(handle)
+         except Exception as e:
+             error_message = f"CUDA health check failed: {e}"
+             report_unhealthy_replica(error=error_message, node_reaper=self.node_reaper)
+             raise RuntimeError(error_message)
+
+         reserved = torch.cuda.memory_reserved() # bytes currently reserved by CUDA cache
+         threshold_bytes = self.cuda_memory_flush_threshold * mem_info.total
+
+         if reserved > threshold_bytes:
+             # flush only when cache exceeds the percentage threshold
+             torch.cuda.empty_cache()
+
+     def __del__(self):
+         # Clean up and free any remaining GPU memory
+         try:
+             if hasattr(self, 'embedding_model'):
+                 del self.embedding_model
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+         except Exception as e:
+             self.logger.warning(f"Error during cleanup: {e}")
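
The truncate-and-renormalize step in __call__ is the usual Matryoshka-embedding trick: keep the leading dimensions, then rescale each row back to unit length. A self-contained sketch of the same arithmetic on fake data:

import torch

# A fake batch of 4 unit-normalized 768-dim embeddings.
embeddings = torch.nn.functional.normalize(torch.randn(4, 768), dim=1)

truncate_dim = 256
truncated = embeddings[:, :truncate_dim]  # keep the leading dimensions
truncated = truncated / torch.norm(truncated, dim=1, keepdim=True)  # re-normalize

print(torch.norm(truncated, dim=1))  # ~1.0 for every row
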
ray_embedding/model_router.py CHANGED
@@ -4,17 +4,20 @@ import time
  from typing import Optional, Dict, List, Tuple
 
  from fastapi import FastAPI, HTTPException
+ import ray
  from ray import serve
  from ray.serve.handle import DeploymentHandle
 
  from ray_embedding.dto import DeployedModel, EmbeddingRequest, EmbeddingResponse
+ from ray_embedding.utils import get_current_node_ip
 
  web_api = FastAPI(title="Ray Embeddings - OpenAI-compatible API")
 
  @serve.deployment
  @serve.ingress(web_api)
  class ModelRouter:
-     def __init__(self, deployed_models: Dict[str, DeployedModel], path_prefix: List[str], max_concurrency: Optional[int] = 32):
+     def __init__(self, deployed_models: Dict[str, DeployedModel], path_prefix: List[str],
+                  max_concurrency: Optional[int] = 32, node_reaper: Optional[DeploymentHandle] = None):
          assert deployed_models, "models cannot be empty"
          assert path_prefix, "path_prefix cannot be empty"
 
@@ -32,6 +35,7 @@ class ModelRouter:
              "permission": []} for item in self.deployed_models.keys()
          ]
          self.logger.info(f"Successfully registered models: {self.available_models}")
+         self.node_reaper = node_reaper
 
      async def _compute_embeddings_from_resized_batches(self, model: str, inputs: List[str], dimensions: Optional[int] = None):
          deployed_model = self.deployed_models[model]
@@ -111,3 +115,20 @@ class ModelRouter:
          if path_prefix not in self.path_prefix:
              raise HTTPException(status_code=400, detail=f"The API path prefix specified is invalid: '{path_prefix}'")
          return {"object": "list", "data": self.available_models}
+
+     def check_health(self):
+         if not self.node_reaper:
+             return
+
+         try:
+             unhealthy_node_ips = ray.get(self.node_reaper.get_unhealthy_node_ips.remote())
+         except Exception as exc:
+             self.logger.warning(f"Unable to fetch node reaper data: {exc}")
+             return
+
+         if not unhealthy_node_ips:
+             return
+
+         node_ip = get_current_node_ip()
+         if node_ip and node_ip in unhealthy_node_ips:
+             raise RuntimeError("Model router replica is colocated with an unhealthy embedding replica node.")
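
Since the request/response schemas are OpenAI-compatible, a client call would look roughly like this. The route itself is not visible in this diff, so the URL below assumes the conventional /v1/embeddings path with a path_prefix containing "/v1", and the response layout follows the OpenAI embeddings format:

import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",  # hypothetical host and route
    json={
        "model": "all-MiniLM-L6-v2",         # must match a served_model_name
        "input": ["hello world"],
        "dimensions": 256,                   # optional Matryoshka truncation
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["data"][0])  # one entry per input, per the OpenAI format
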
ray_embedding/node_reaper.py ADDED
@@ -0,0 +1,130 @@
+ import asyncio
+ import logging
+ import time
+ from pathlib import Path
+ from typing import Dict, Any, List, Optional, Set
+
+ from ray import serve
+
+
+ NODE_REAPER_DEPLOYMENT_NAME = "NodeReaper"
+
+
+ @serve.deployment(
+     name=NODE_REAPER_DEPLOYMENT_NAME,
+     route_prefix=None,
+     num_replicas=1,
+     ray_actor_options={"num_cpus": 0.25},
+     autoscaling_config={"min_replicas": 1, "max_replicas": 1},
+ )
+ class NodeReaper:
+     def __init__(
+         self,
+         ssh_user: str,
+         ssh_private_key: str,
+         retention_seconds: int = 900,
+         reap_interval_seconds: int = 60,
+     ):
+         logging.basicConfig(level=logging.INFO)
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.ssh_user = ssh_user
+         key_path = Path(ssh_private_key).expanduser()
+         if not key_path.exists():
+             raise FileNotFoundError(f"SSH private key not found: {key_path}")
+         self.ssh_private_key = key_path.as_posix()
+         self.retention_seconds = retention_seconds
+         self.reap_interval_seconds = max(30, reap_interval_seconds)
+
+         self._unhealthy_replicas: Dict[str, Dict[str, Any]] = {}
+         self._nodes_marked_for_reap: Dict[str, float] = {}
+         self._nodes_inflight: Set[str] = set()
+
+         loop = asyncio.get_event_loop()
+         self._reaper_task = loop.create_task(self._reap_loop())
+         self.logger.info("NodeReaper initialized; monitoring unhealthy nodes for recycling")
+
+     def __del__(self):
+         if hasattr(self, "_reaper_task") and self._reaper_task and not self._reaper_task.done():
+             self._reaper_task.cancel()
+
+     def report_failure(self, replica_id: str, node_ip: str, error: Optional[str] = None):
+         self._unhealthy_replicas[replica_id] = {
+             "node_ip": node_ip,
+             "error": error,
+             "timestamp": time.time(),
+         }
+         self._nodes_marked_for_reap[node_ip] = self._nodes_marked_for_reap.get(node_ip, time.time())
+         self.logger.warning(f"Replica {replica_id} on {node_ip} marked for reaping: {error}")
+         self._purge_stale()
+
+     def get_unhealthy_node_ips(self) -> List[str]:
+         self._purge_stale()
+         return list(self._nodes_marked_for_reap.keys())
+
+     async def _reap_loop(self):
+         while True:
+             try:
+                 await asyncio.sleep(self.reap_interval_seconds)
+                 await self._reap_pending_nodes()
+             except asyncio.CancelledError:
+                 break
+             except Exception as exc:
+                 self.logger.warning(f"Unexpected error in reap loop: {exc}")
+
+     async def _reap_pending_nodes(self):
+         nodes = self.get_unhealthy_node_ips()
+         for node_ip in nodes:
+             if node_ip in self._nodes_inflight:
+                 continue
+             self._nodes_inflight.add(node_ip)
+             try:
+                 await self._reap_node(node_ip)
+                 self._clear_node(node_ip)
+                 self.logger.info(f"Successfully reaped node {node_ip}")
+             except Exception as exc:
+                 self.logger.error(f"Failed to reap node {node_ip}: {exc}")
+             finally:
+                 self._nodes_inflight.discard(node_ip)
+
+     async def _reap_node(self, node_ip: str):
+         ssh_command = [
+             "ssh",
+             "-i",
+             self.ssh_private_key,
+             "-o",
+             "StrictHostKeyChecking=no",
+             f"{self.ssh_user}@{node_ip}",
+             "docker stop ray_container",
+         ]
+
+         self.logger.info(f"Reaping node {node_ip} via SSH")
+         process = await asyncio.create_subprocess_exec(
+             *ssh_command,
+             stdout=asyncio.subprocess.PIPE,
+             stderr=asyncio.subprocess.PIPE,
+         )
+         stdout, stderr = await process.communicate()
+         if process.returncode != 0:
+             stdout_text = stdout.decode().strip()
+             stderr_text = stderr.decode().strip()
+             raise RuntimeError(
+                 f"SSH command failed with code {process.returncode}. stdout={stdout_text} stderr={stderr_text}"
+             )
+
+     def _clear_node(self, node_ip: str):
+         to_delete = [replica for replica, data in self._unhealthy_replicas.items() if data.get("node_ip") == node_ip]
+         for replica in to_delete:
+             self._unhealthy_replicas.pop(replica, None)
+         self._nodes_marked_for_reap.pop(node_ip, None)
+
+     def _purge_stale(self):
+         if not self.retention_seconds:
+             return
+         cutoff = time.time() - self.retention_seconds
+         replica_ids = [replica_id for replica_id, data in self._unhealthy_replicas.items()
+                        if data.get("timestamp", 0) < cutoff]
+         for replica_id in replica_ids:
+             node_ip = self._unhealthy_replicas[replica_id]["node_ip"]
+             self._unhealthy_replicas.pop(replica_id, None)
+             if node_ip in self._nodes_marked_for_reap and self._nodes_marked_for_reap[node_ip] < cutoff:
+                 self._nodes_marked_for_reap.pop(node_ip, None)
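
To make the reaping flow concrete, here is a hypothetical driver-side sketch that mirrors the calls model_router.py and utils.py make in this diff (the replica ID and IP are fabricated; the handle lookup matches get_node_reaper_handle below):

import ray
from ray import serve

# Assumes the application built by deploy.py is already running.
reaper = serve.context.get_deployment_handle("NodeReaper")

# Flag a replica's node; the background loop will later SSH in and
# run `docker stop ray_container` on it.
reaper.report_failure.remote("EmbeddingModel#abc123", "10.0.0.42", "CUDA health check failed")

# What a router replica polls for in its check_health:
bad_ips = ray.get(reaper.get_unhealthy_node_ips.remote())
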
ray_embedding/utils.py ADDED
@@ -0,0 +1,71 @@
+ from typing import Optional
+
+ import ray
+ from ray import serve
+ from ray.serve.handle import DeploymentHandle
+ from ray.util import get_node_ip_address, state
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, NotIn
+
+ from ray_embedding.node_reaper import NODE_REAPER_DEPLOYMENT_NAME
+
+
+
+ def get_head_node_id() -> str:
+     try:
+         nodes = state.list_nodes(filters=[("is_head_node", "=", True)])
+         if not nodes:
+             raise RuntimeError("Unable to locate head node for NodeReaper deployment.")
+         head_node = nodes[0]
+         return head_node["node_id"]
+     except Exception as exc:
+         raise RuntimeError("Unable to locate the head node ID for NodeReaper deployment.") from exc
+
+
+ HEAD_NODE_ID = get_head_node_id()
+
+
+ def node_affinity_for_head() -> NodeAffinitySchedulingStrategy:
+     return NodeAffinitySchedulingStrategy(node_id=HEAD_NODE_ID, soft=False)
+
+
+ def node_affinity_for_worker() -> NodeAffinitySchedulingStrategy:
+     return NodeAffinitySchedulingStrategy(node_id=NotIn(HEAD_NODE_ID), soft=False)
+
+
+ def get_node_reaper_handle() -> DeploymentHandle:
+     try:
+         return serve.context.get_deployment_handle(NODE_REAPER_DEPLOYMENT_NAME)
+     except Exception:
+         return serve.get_deployment(NODE_REAPER_DEPLOYMENT_NAME).get_handle(sync=False)
+
+
+ def get_current_replica_tag() -> Optional[str]:
+     try:
+         context = serve.context.get_current_replica_context()
+     except Exception:
+         context = None
+     if context is None:
+         return None
+     return getattr(context, "replica_tag", None)
+
+
+ def get_current_node_ip() -> Optional[str]:
+     try:
+         return get_node_ip_address()
+     except Exception:
+         return None
+
+
+ def report_unhealthy_replica(error: Optional[str] = None,
+                              node_reaper: Optional[DeploymentHandle] = None) -> None:
+     replica_id = get_current_replica_tag()
+     node_ip = get_current_node_ip()
+     if not (replica_id and node_ip):
+         return
+     handle = node_reaper
+     if handle is None:
+         try:
+             handle = get_node_reaper_handle()
+         except Exception:
+             return
+     handle.report_failure.remote(replica_id, node_ip, error)
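
As a usage sketch, any other Serve deployment could reuse report_unhealthy_replica from its own check_health, following the same pattern as EmbeddingModel above (the deployment and probe below are hypothetical):

from ray import serve

from ray_embedding.utils import report_unhealthy_replica

@serve.deployment
class MyGpuDeployment:
    """Hypothetical deployment reusing the failure-reporting hook."""

    def check_health(self):
        healthy = True  # replace with a real GPU/model probe
        if not healthy:
            # Flag this replica's node with the NodeReaper (resolved by
            # deployment name when no handle is passed), then raise so
            # Serve restarts the replica.
            report_unhealthy_replica(error="probe failed")
            raise RuntimeError("probe failed")
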
ray_embedding-0.12.5.dist-info/METADATA → ray_embedding-0.13.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ray-embedding
- Version: 0.12.5
+ Version: 0.13.0
  Summary: Deploy SentenceTransformers embedding models to a ray cluster
  Author: Crispin Almodovar
  Author-email:
ray_embedding-0.12.5.dist-info/RECORD → ray_embedding-0.13.0.dist-info/RECORD RENAMED
@@ -0,0 +1,11 @@
+ ray_embedding/__init__.py,sha256=YS5LAZfRIwwVvE3C9g7hsauvjgIkqKtHyxkwMFFfAGY,46
+ ray_embedding/deploy.py,sha256=N3YOin0sh5_HxXoyY3laLiaOn1vA-JkBI6SjeOg3HDw,3871
+ ray_embedding/dto.py,sha256=l0zVrzwjlvYqWsmY8LcoizKrW4AB3eSrGVXUNrIZp3o,1722
+ ray_embedding/embedding_model.py,sha256=Zr5lxVuy60y8-JgsOmKDD44FZlbTL1tiiY-3_72sTR4,4905
+ ray_embedding/model_router.py,sha256=W2c0hvqwDe1iCfNx4ee2UT7wKduywMP8dY0Ggb8xBvU,6658
+ ray_embedding/node_reaper.py,sha256=kxxkelMFVpVo827m3JL-2Wat0HSk71Tux3cgNMivX5w,4989
+ ray_embedding/utils.py,sha256=IE0uyBRIW9HVQjO1I46qvtxGQfdorPwU4FsAlYHjI9g,2186
+ ray_embedding-0.13.0.dist-info/METADATA,sha256=BL9-K72yoNDNBAdtL3jmlK_gBSWNTb093ngaTvGjqdI,1094
+ ray_embedding-0.13.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ ray_embedding-0.13.0.dist-info/top_level.txt,sha256=ziCblpJq1YsrryshFqxTRuRMgNuO1_tgvAAkGShATNA,14
+ ray_embedding-0.13.0.dist-info/RECORD,,
@@ -1,9 +0,0 @@
- ray_embedding/__init__.py,sha256=YS5LAZfRIwwVvE3C9g7hsauvjgIkqKtHyxkwMFFfAGY,46
- ray_embedding/deploy.py,sha256=1Nzb39OylxBEMDqCcfD-ByTefhANXbjLMzLo_YAkCfw,2710
- ray_embedding/dto.py,sha256=l0hxz_fdGjZtLMZS3BzQ1tLzAOiO_8NpX4i5Wdyuk6Q,1519
- ray_embedding/embedding_model.py,sha256=6iEaIg_mCpGEY-5F0uff2wTOMH1wI42u2N8DnaZE3mA,4670
- ray_embedding/model_router.py,sha256=BsOEz24ttvpDD4LZsDVg9rLhn26FxgUsDAvcjI0Feao,5917
- ray_embedding-0.12.5.dist-info/METADATA,sha256=EvWeadexmzrfUATF6dYl8c54cGCfdp5EEeW23vkni38,1094
- ray_embedding-0.12.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- ray_embedding-0.12.5.dist-info/top_level.txt,sha256=ziCblpJq1YsrryshFqxTRuRMgNuO1_tgvAAkGShATNA,14
- ray_embedding-0.12.5.dist-info/RECORD,,