prehend 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prehend/__init__.py +19 -0
- prehend/clients/__init__.py +59 -0
- prehend/clients/anthropic.py +120 -0
- prehend/clients/azure_openai.py +152 -0
- prehend/clients/base_lm.py +43 -0
- prehend/clients/coordination.py +164 -0
- prehend/clients/gemini.py +172 -0
- prehend/clients/openai.py +564 -0
- prehend/clients/portkey.py +104 -0
- prehend/clients/scheduler.py +321 -0
- prehend/core/__init__.py +0 -0
- prehend/core/comms_utils.py +270 -0
- prehend/core/lm_handler.py +430 -0
- prehend/core/rlm.py +1270 -0
- prehend/core/srlm.py +459 -0
- prehend/core/types.py +303 -0
- prehend/core/verifier.py +215 -0
- prehend/environments/__init__.py +82 -0
- prehend/environments/base_env.py +388 -0
- prehend/environments/constants.py +32 -0
- prehend/environments/daytona_repl.py +708 -0
- prehend/environments/docker_repl.py +355 -0
- prehend/environments/e2b_repl.py +515 -0
- prehend/environments/ipython_repl.py +1521 -0
- prehend/environments/local_repl.py +765 -0
- prehend/environments/modal_repl.py +518 -0
- prehend/environments/prime_repl.py +604 -0
- prehend/logger/__init__.py +4 -0
- prehend/logger/rlm_logger.py +91 -0
- prehend/logger/verbose.py +538 -0
- prehend/memory/__init__.py +54 -0
- prehend/memory/bank.py +95 -0
- prehend/memory/distill.py +147 -0
- prehend/memory/embed.py +67 -0
- prehend/memory/embed_openai.py +35 -0
- prehend/memory/factory.py +94 -0
- prehend/memory/harness.py +116 -0
- prehend/memory/inject.py +56 -0
- prehend/memory/pruning_rules.py +57 -0
- prehend/memory/reflect.py +62 -0
- prehend/memory/retrieve.py +102 -0
- prehend/memory/tagger.py +25 -0
- prehend/metrics.py +404 -0
- prehend/utils/__init__.py +0 -0
- prehend/utils/exceptions.py +73 -0
- prehend/utils/parsing.py +122 -0
- prehend/utils/prompts.py +195 -0
- prehend/utils/rlm_utils.py +12 -0
- prehend/utils/token_utils.py +143 -0
- prehend-0.2.0.dist-info/METADATA +229 -0
- prehend-0.2.0.dist-info/RECORD +54 -0
- prehend-0.2.0.dist-info/WHEEL +5 -0
- prehend-0.2.0.dist-info/licenses/LICENSE +21 -0
- prehend-0.2.0.dist-info/top_level.txt +1 -0
prehend/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from prehend.core.rlm import RLM
|
|
2
|
+
from prehend.core.srlm import SRLM
|
|
3
|
+
from prehend.utils.exceptions import (
|
|
4
|
+
BudgetExceededError,
|
|
5
|
+
CancellationError,
|
|
6
|
+
ErrorThresholdExceededError,
|
|
7
|
+
TimeoutExceededError,
|
|
8
|
+
TokenLimitExceededError,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Public API: the two entry-point classes, followed by the exception types
# re-exported from prehend.utils.exceptions.
__all__ = ["RLM", "SRLM"] + [
    "BudgetExceededError",
    "TimeoutExceededError",
    "TokenLimitExceededError",
    "ErrorThresholdExceededError",
    "CancellationError",
]
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from dotenv import load_dotenv
|
|
4
|
+
|
|
5
|
+
from prehend.clients.base_lm import BaseLM
|
|
6
|
+
from prehend.core.types import ClientBackend
|
|
7
|
+
|
|
8
|
+
# Load variables from a local .env file (if one exists) into os.environ when
# this module is imported.
load_dotenv()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_client(
    backend: "ClientBackend",
    backend_kwargs: dict[str, Any],
) -> "BaseLM":
    """Route a backend name and its constructor kwargs to the matching client.

    Supported backends: 'openai', 'vllm', 'portkey', 'openrouter', 'vercel',
    'anthropic', 'gemini', 'azure_openai'. 'vllm', 'openrouter' and 'vercel'
    are all served by ``OpenAIClient`` against an OpenAI-compatible endpoint;
    'openrouter' and 'vercel' get a default ``base_url`` when none is given.

    Client modules are imported lazily, so only the SDK of the selected
    backend has to be importable.

    Args:
        backend: Name of the backend to use.
        backend_kwargs: Keyword arguments forwarded to the client constructor.
            The dict is copied and never modified in place; ``None`` is
            treated as an empty dict.

    Returns:
        A constructed client instance for ``backend``.

    Raises:
        ValueError: If ``backend`` is unknown, or if ``backend == "vllm"`` and
            no non-empty ``base_url`` is supplied.
    """
    # Copy so defaults injected below (e.g. base_url) never leak back into a
    # dict the caller may reuse for another backend.
    kwargs = dict(backend_kwargs or {})

    # OpenAI-compatible gateways that differ only in their default endpoint.
    default_base_urls = {
        "openrouter": "https://openrouter.ai/api/v1",
        "vercel": "https://ai-gateway.vercel.sh/v1",
    }

    if backend in ("openai", "vllm", "openrouter", "vercel"):
        if backend == "vllm" and not kwargs.get("base_url"):
            # Explicit raise rather than assert: asserts vanish under `python -O`.
            raise ValueError(
                "base_url is required to be set to local vLLM server address for vLLM"
            )
        if backend in default_base_urls:
            kwargs.setdefault("base_url", default_base_urls[backend])
        from prehend.clients.openai import OpenAIClient

        return OpenAIClient(**kwargs)

    if backend == "portkey":
        from prehend.clients.portkey import PortkeyClient

        return PortkeyClient(**kwargs)

    if backend == "anthropic":
        from prehend.clients.anthropic import AnthropicClient

        return AnthropicClient(**kwargs)

    if backend == "gemini":
        from prehend.clients.gemini import GeminiClient

        return GeminiClient(**kwargs)

    if backend == "azure_openai":
        from prehend.clients.azure_openai import AzureOpenAIClient

        return AzureOpenAIClient(**kwargs)

    raise ValueError(
        f"Unknown backend: {backend}. Supported backends: ['openai', 'vllm', 'portkey', 'openrouter', 'anthropic', 'azure_openai', 'gemini', 'vercel']"
    )
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import anthropic
|
|
5
|
+
|
|
6
|
+
from prehend.clients.base_lm import BaseLM
|
|
7
|
+
from prehend.core.types import ModelUsageSummary, UsageSummary
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AnthropicClient(BaseLM):
    """
    LM client for running models with the Anthropic Messages API.

    Holds a sync and an async SDK client and accumulates per-model call and
    token counts for ``get_usage_summary`` / ``get_last_usage``.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str | None = None,
        max_tokens: int = 32768,
        **kwargs,
    ):
        """Create the Anthropic SDK clients.

        Args:
            api_key: Anthropic API key.
            model_name: Default model used when a call does not pass ``model``.
            max_tokens: ``max_tokens`` value sent with every request.
            **kwargs: Forwarded to ``BaseLM`` (e.g. ``timeout``).
        """
        super().__init__(model_name=model_name, **kwargs)
        self.client = anthropic.Anthropic(api_key=api_key, timeout=self.timeout)
        self.async_client = anthropic.AsyncAnthropic(api_key=api_key, timeout=self.timeout)
        self.model_name = model_name
        self.max_tokens = max_tokens

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

        # Usage of the most recent call. Initialised so get_last_usage() does
        # not raise AttributeError before the first request.
        self.last_prompt_tokens: int = 0
        self.last_completion_tokens: int = 0

    def completion(
        self,
        prompt: str | list[dict[str, Any]],
        model: str | None = None,
        priority: str | int | None = None,  # accepted for interface parity; no scheduler here
    ) -> str:
        """Run a synchronous completion and return the response text.

        Raises:
            ValueError: If the prompt type is invalid or no model is configured.
        """
        request, model = self._build_request(prompt, model)
        response = self.client.messages.create(**request)
        self._track_cost(response, model)
        return self._extract_text(response)

    async def acompletion(
        self,
        prompt: str | list[dict[str, Any]],
        model: str | None = None,
        priority: str | int | None = None,
    ) -> str:
        """Async counterpart of ``completion``."""
        request, model = self._build_request(prompt, model)
        response = await self.async_client.messages.create(**request)
        self._track_cost(response, model)
        return self._extract_text(response)

    def _build_request(
        self, prompt: str | list[dict[str, Any]], model: str | None
    ) -> tuple[dict[str, Any], str]:
        """Build ``messages.create`` kwargs and resolve the model to use.

        The prompt is validated before the model, so an invalid prompt is
        reported first.
        """
        messages, system = self._prepare_messages(prompt)

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Anthropic client.")

        request: dict[str, Any] = {
            "model": model,
            "max_tokens": self.max_tokens,
            "messages": messages,
        }
        if system:
            request["system"] = system
        return request, model

    @staticmethod
    def _extract_text(response: anthropic.types.Message) -> str:
        """Concatenate the text blocks of a Messages API response.

        A response may contain non-text blocks (for example thinking or
        tool-use blocks), possibly ahead of the text, so ``content[0].text``
        is not reliable. Returns "" when there is no text block.
        """
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    def _prepare_messages(
        self, prompt: str | list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Split a prompt into Anthropic ``messages`` and a ``system`` prompt.

        A string becomes a single user message. For a list of message dicts,
        system-role messages are pulled out, because Anthropic takes the system
        prompt as a separate parameter. If several are present, the last one
        wins.

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list of dicts.
        """
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}], None

        if not (isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt)):
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        system = None
        messages = []
        for msg in prompt:
            if msg.get("role") == "system":
                system = msg.get("content")
            else:
                messages.append(msg)
        return messages, system

    def _track_cost(self, response: anthropic.types.Message, model: str):
        """Accumulate the token usage of ``response`` under ``model``."""
        usage = response.usage
        self.model_call_counts[model] += 1
        self.model_input_tokens[model] += usage.input_tokens
        self.model_output_tokens[model] += usage.output_tokens
        self.model_total_tokens[model] += usage.input_tokens + usage.output_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = usage.input_tokens
        self.last_completion_tokens = usage.output_tokens

    def get_usage_summary(self) -> UsageSummary:
        """Return accumulated usage for every model called so far."""
        model_summaries = {
            model: ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
            for model in self.model_call_counts
        }
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        """Return the usage of the most recent call (zeros before any call)."""
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import openai
|
|
6
|
+
from dotenv import load_dotenv
|
|
7
|
+
|
|
8
|
+
from prehend.clients.base_lm import BaseLM
|
|
9
|
+
from prehend.core.types import ModelUsageSummary, UsageSummary
|
|
10
|
+
|
|
11
|
+
# Load variables from a local .env file (if one exists) into os.environ, so
# the module-level AZURE_OPENAI_API_KEY lookup below can see them.
load_dotenv()
|
|
12
|
+
|
|
13
|
+
# Fallback API key, read once at import time; None when the variable is unset.
DEFAULT_AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AzureOpenAIClient(BaseLM):
    """
    LM client for running models with the Azure OpenAI API.

    Connection settings that are not passed explicitly fall back to the
    environment: AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION (default
    "2024-02-01") and AZURE_OPENAI_DEPLOYMENT. The API key falls back to
    DEFAULT_AZURE_OPENAI_API_KEY. Per-model call and token counts are
    accumulated for usage reporting.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model_name: str | None = None,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        azure_deployment: str | None = None,
        **kwargs,
    ):
        """Create the sync and async Azure OpenAI SDK clients.

        Args:
            api_key: Azure OpenAI key; defaults to DEFAULT_AZURE_OPENAI_API_KEY.
            model_name: Default model used when a call does not pass ``model``.
            azure_endpoint: Resource endpoint; defaults to AZURE_OPENAI_ENDPOINT.
            api_version: API version; defaults to AZURE_OPENAI_API_VERSION or
                "2024-02-01".
            azure_deployment: Deployment name; defaults to AZURE_OPENAI_DEPLOYMENT.
                Also used as the request ``model`` when no model name is set.
            **kwargs: Forwarded to ``BaseLM`` (e.g. ``timeout``).

        Raises:
            ValueError: If no endpoint is given and AZURE_OPENAI_ENDPOINT is unset.
        """
        super().__init__(model_name=model_name, **kwargs)

        if api_key is None:
            api_key = DEFAULT_AZURE_OPENAI_API_KEY

        if azure_endpoint is None:
            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

        if api_version is None:
            api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")

        if azure_deployment is None:
            azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")

        if azure_endpoint is None:
            raise ValueError(
                "azure_endpoint is required for Azure OpenAI client. "
                "Set it via argument or AZURE_OPENAI_ENDPOINT environment variable."
            )

        # Both SDK clients share identical connection settings.
        client_options = {
            "api_key": api_key,
            "azure_endpoint": azure_endpoint,
            "api_version": api_version,
            "azure_deployment": azure_deployment,
            "timeout": self.timeout,
        }
        self.client = openai.AzureOpenAI(**client_options)
        self.async_client = openai.AsyncAzureOpenAI(**client_options)
        self.model_name = model_name
        self.azure_deployment = azure_deployment

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

        # Usage of the most recent call. Initialised so get_last_usage() does
        # not raise AttributeError before the first request.
        self.last_prompt_tokens: int = 0
        self.last_completion_tokens: int = 0

    def completion(
        self,
        prompt: str | list[dict[str, Any]],
        model: str | None = None,
        priority: str | int | None = None,  # accepted for interface parity; no scheduler here
    ) -> str:
        """Run a synchronous chat completion and return the response text.

        Returns "" when the service returns no message content.

        Raises:
            ValueError: If the prompt type is invalid, no model can be resolved,
                or the response carries no usage data.
        """
        messages = self._prepare_messages(prompt)
        model = self._resolve_model(model)
        response = self.client.chat.completions.create(
            model=model,
            messages=messages,
        )
        self._track_cost(response, model)
        return self._extract_text(response)

    async def acompletion(
        self,
        prompt: str | list[dict[str, Any]],
        model: str | None = None,
        priority: str | int | None = None,
    ) -> str:
        """Async counterpart of ``completion``."""
        messages = self._prepare_messages(prompt)
        model = self._resolve_model(model)
        response = await self.async_client.chat.completions.create(
            model=model,
            messages=messages,
        )
        self._track_cost(response, model)
        return self._extract_text(response)

    @staticmethod
    def _prepare_messages(prompt: str | list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Normalise a prompt into a chat ``messages`` list.

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list of dicts.
        """
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        if isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            return prompt
        raise ValueError(f"Invalid prompt type: {type(prompt)}")

    def _resolve_model(self, model: str | None) -> str:
        """Pick the request ``model``: explicit argument, then the default model
        name, then the configured deployment name.

        Raises:
            ValueError: If none of them is set.
        """
        resolved = model or self.model_name or self.azure_deployment
        if not resolved:
            raise ValueError("Model name is required for Azure OpenAI client.")
        return resolved

    @staticmethod
    def _extract_text(response: "openai.types.chat.ChatCompletion") -> str:
        """Return the first choice's text. ``message.content`` may be None
        (for example for filtered output); that is mapped to ""."""
        return response.choices[0].message.content or ""

    def _track_cost(self, response: "openai.types.chat.ChatCompletion", model: str):
        """Accumulate the token usage of ``response`` under ``model``.

        Raises:
            ValueError: If the response carries no usage data.
        """
        self.model_call_counts[model] += 1

        usage = getattr(response, "usage", None)
        if usage is None:
            raise ValueError("No usage data received. Tracking tokens not possible.")

        self.model_input_tokens[model] += usage.prompt_tokens
        self.model_output_tokens[model] += usage.completion_tokens
        self.model_total_tokens[model] += usage.total_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = usage.prompt_tokens
        self.last_completion_tokens = usage.completion_tokens

    def get_usage_summary(self) -> UsageSummary:
        """Return accumulated usage for every model called so far."""
        model_summaries = {
            model: ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
            for model in self.model_call_counts
        }
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        """Return the usage of the most recent call (zeros before any call)."""
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from prehend.core.types import ModelUsageSummary, UsageSummary
|
|
5
|
+
|
|
6
|
+
# Default timeout for LM API calls, in seconds (five minutes).
DEFAULT_TIMEOUT: float = 5 * 60.0
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseLM(ABC):
    """
    Base class for all language model routers / clients.

    When the RLM makes sub-calls it does so in a model-agnostic way, so every
    client implements this interface: synchronous and asynchronous completion
    plus usage reporting.
    """

    def __init__(self, model_name: str | None, timeout: float = DEFAULT_TIMEOUT, **kwargs):
        """Store settings shared by all clients.

        Args:
            model_name: Default model for calls. Subclasses pass ``None`` when
                no default model is configured.
            timeout: Per-request timeout in seconds, exposed as ``self.timeout``.
            **kwargs: Extra backend options, kept verbatim on ``self.kwargs``.
        """
        self.model_name = model_name
        self.timeout = timeout
        self.kwargs = kwargs

    @abstractmethod
    def completion(
        self, prompt: str | list[dict[str, Any]], priority: str | int | None = None
    ) -> str:
        """Run a completion and return the generated text.

        Args:
            prompt: A plain user prompt, or a list of chat messages given as
                ``{"role": ..., "content": ...}`` dicts.
            priority: Scheduling hint ("high"/"low"/"normal" or 1-5); backends
                without a request scheduler may ignore it.

        Concrete clients may also accept an optional ``model`` override.
        """
        raise NotImplementedError

    @abstractmethod
    async def acompletion(
        self, prompt: str | list[dict[str, Any]], priority: str | int | None = None
    ) -> str:
        """Async counterpart of ``completion``; same arguments and result."""
        raise NotImplementedError

    @abstractmethod
    def get_usage_summary(self) -> UsageSummary:
        """Get the accumulated usage summary for all model calls."""
        raise NotImplementedError

    @abstractmethod
    def get_last_usage(self) -> ModelUsageSummary:
        """Get the usage summary of the most recent model call."""
        raise NotImplementedError
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Cross-process admission gate for RequestScheduler (two-flock gate+pool).
|
|
2
|
+
|
|
3
|
+
Design: docs/superpowers/specs/2026-06-10-cross-process-coordination-design.md.
|
|
4
|
+
|
|
5
|
+
Two lock files per server key in a shared coordination directory:
|
|
6
|
+
|
|
7
|
+
<dir>/<key>.gate - doorway. Normal requests hold SH momentarily on the
|
|
8
|
+
way in; a p1 holds EX for its whole run, which freezes
|
|
9
|
+
new admissions machine-wide (the cross-process
|
|
10
|
+
_waiting_p1 rule) and serializes p1s globally.
|
|
11
|
+
<dir>/<key>.pool - the in-flight set. Normal requests hold SH for the
|
|
12
|
+
request duration; a p1 takes EX, granted only when
|
|
13
|
+
every holder drains (the cross-process _active == 0
|
|
14
|
+
rule).
|
|
15
|
+
|
|
16
|
+
Crash cleanup is the kernel's: flock drops when an fd closes, including on
|
|
17
|
+
process death. The gate distinguishes only p1 vs everything else; p2-p5
|
|
18
|
+
ordering stays in-process. Same-host processes only (flock does not span
|
|
19
|
+
machines, and network filesystems are explicitly out of scope).
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import asyncio
|
|
23
|
+
import fcntl
|
|
24
|
+
import logging
|
|
25
|
+
import os
|
|
26
|
+
import threading
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
|
|
29
|
+
from prehend.clients.scheduler import Priority
|
|
30
|
+
|
|
31
|
+
log: logging.Logger = logging.getLogger(__name__)

# Async acquisition retries a LOCK_NB flock every POLL_INTERVAL seconds
# (25 ms) rather than blocking inside an executor thread. A cancelled task
# cannot interrupt a flock that is blocking in a worker thread, and that
# thread would later take a lock nobody ever releases.
POLL_INTERVAL = 25e-3
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class CrossProcessGate:
    """Two-flock readers-writer gate with writer preference.

    enter()/aenter() acquire for one request; exit() releases one acquisition
    (non-blocking fd closes, so both sync and async paths use it). Normal
    requests' pool fds are fungible: exit(NORMAL) closes any one of this
    process's SH holds, which the kernel treats identically.

    Every acquisition opens its own fds, so requests inside one process
    contend with each other through flock exactly as requests in different
    processes do.
    """

    def __init__(self, coordination_dir: str | Path, server_key: str):
        self._dir = Path(coordination_dir)
        self._gate_path = self._dir / f"{server_key}.gate"
        self._pool_path = self._dir / f"{server_key}.pool"
        self._mu = threading.Lock()
        self._pool_fds: list[int] = []  # one SH fd per in-flight normal request
        self._p1_fds: tuple[int, int] | None = None  # (gate_fd, pool_fd) of the active p1
        # Fail fast: surface an unwritable dir or a no-flock filesystem at
        # construction, not on request N.
        try:
            self._dir.mkdir(parents=True, exist_ok=True)
            for path in (self._gate_path, self._pool_path):
                fd = self._open(path)
                try:
                    fcntl.flock(fd, fcntl.LOCK_SH | fcntl.LOCK_NB)
                    fcntl.flock(fd, fcntl.LOCK_UN)
                except BlockingIOError:
                    pass  # held EX by a live p1 elsewhere: flock works here
                finally:
                    os.close(fd)
        except OSError as e:
            raise RuntimeError(
                f"cross-process coordination unavailable at {self._dir}: {e}"
            ) from e

    @staticmethod
    def _open(path: Path) -> int:
        """Open (creating if needed) a lock file and return a fresh fd."""
        return os.open(path, os.O_RDWR | os.O_CREAT, 0o644)

    def enter(self, priority: int) -> None:
        """Blocking acquisition for one request. Releases partial holds and
        re-raises on failure, leaving no lock behind."""
        op = fcntl.LOCK_EX if priority == Priority.CONTENTION_RETRY else fcntl.LOCK_SH
        gate_fd = self._open(self._gate_path)
        try:
            fcntl.flock(gate_fd, op)
            pool_fd = self._open(self._pool_path)
            try:
                fcntl.flock(pool_fd, op)
            except BaseException:
                os.close(pool_fd)
                raise
        except BaseException:
            os.close(gate_fd)
            raise
        self._commit(priority, gate_fd, pool_fd)

    async def aenter(self, priority: int) -> None:
        """Async acquisition: LOCK_NB poll loop (POLL_INTERVAL) instead of a
        blocking flock in an executor thread, so task cancellation can never
        strand a lock in a thread nobody joins. On any failure, including
        CancelledError, partial holds are released before re-raising."""
        op = fcntl.LOCK_EX if priority == Priority.CONTENTION_RETRY else fcntl.LOCK_SH
        gate_fd = self._open(self._gate_path)
        try:
            await self._apoll(gate_fd, op)
            pool_fd = self._open(self._pool_path)
            try:
                await self._apoll(pool_fd, op)
            except BaseException:
                os.close(pool_fd)
                raise
        except BaseException:
            os.close(gate_fd)
            raise
        # No await past this point, so cancellation cannot land between the
        # acquisition and its bookkeeping.
        self._commit(priority, gate_fd, pool_fd)

    def _commit(self, priority: int, gate_fd: int, pool_fd: int) -> None:
        """Record a completed acquisition so exit() can release it.

        A p1 keeps both fds: its gate EX freezes admissions for its whole run.
        A normal request drops the gate immediately, because the gate is only
        the doorway, and keeps just its pool SH fd.
        """
        if priority == Priority.CONTENTION_RETRY:
            with self._mu:
                self._p1_fds = (gate_fd, pool_fd)
        else:
            os.close(gate_fd)
            with self._mu:
                self._pool_fds.append(pool_fd)

    @staticmethod
    async def _apoll(fd: int, op: int) -> None:
        """Retry a non-blocking flock(op) on ``fd`` until it is granted."""
        while True:
            try:
                fcntl.flock(fd, op | fcntl.LOCK_NB)
                return
            except BlockingIOError:
                await asyncio.sleep(POLL_INTERVAL)

    def exit(self, priority: int) -> None:
        """Release one acquisition. Never raises on close failures: it sits in
        finally paths, and the locks are released by the fd close regardless."""
        with self._mu:
            if priority == Priority.CONTENTION_RETRY:
                p1_fds, self._p1_fds = self._p1_fds, None
                # Release the pool before the gate.
                to_close = () if p1_fds is None else (p1_fds[1], p1_fds[0])
            else:
                to_close = (self._pool_fds.pop(),) if self._pool_fds else ()
        # Close each fd independently. If one close fails, the next must still
        # run; otherwise a p1's gate EX would stay held and freeze admissions
        # machine-wide.
        for fd in to_close:
            try:
                os.close(fd)
            except OSError as e:
                log.warning("gate exit failed (locks still released on close): %s", e)
|