adaptive-harmony 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import html
|
|
3
|
+
import io
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
from PIL import Image
|
|
7
|
+
|
|
8
|
+
from adaptive_harmony import StringThread
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def save_inlined_string_to_png(image_str: str, path_to_save: str):
    """Decode an inlined base64 image string and write it to ``path_to_save``.

    ``image_str`` is expected to look like a data URI (``"<header>,<base64 payload>"``);
    only the text after the first comma is decoded. A string with no comma is treated
    as a bare base64 payload.

    Despite the function name, PIL picks the output format from the extension of
    ``path_to_save``. Images whose mode JPEG cannot store (e.g. RGBA or palette
    images) are converted to RGB when the target is a ``.jpg``/``.jpeg`` file,
    instead of letting ``save`` raise.
    """
    _, separator, payload = image_str.partition(",")
    if not separator:
        payload = image_str
    image_data = base64.b64decode(payload)
    with Image.open(io.BytesIO(image_data)) as image:
        extension = os.path.splitext(path_to_save)[1].lower()
        if extension in (".jpg", ".jpeg") and image.mode not in ("1", "L", "RGB", "CMYK"):
            # JPEG has no alpha/palette support; flatten to RGB so the save succeeds.
            image = image.convert("RGB")
        image.save(path_to_save)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def save_string_thread_to_markdown_dir(thread: StringThread, directory_path: str):
    """Export a StringThread as ``main.md`` plus one image file per image fragment.

    Creates ``directory_path`` if needed. Each ``(role, fragments)`` pair from
    ``thread.get_fragments()`` becomes a ``# role`` heading followed by its fragments:

    - text fragments are written verbatim, each followed by a newline so the next
      heading or image link starts on its own line;
    - image fragments (whose ``url`` is an inlined base64 string) are decoded to
      ``image_<n>.png`` next to ``main.md``, with ``n`` counting up from 0 across the
      whole thread, and linked from the Markdown.

    Raises:
        AssertionError: if a fragment's type is neither ``"text"`` nor ``"image"``.
    """
    all_fragments = thread.get_fragments()

    os.makedirs(directory_path, exist_ok=True)
    image_counter = 0
    with open(os.path.join(directory_path, "main.md"), "w", encoding="utf-8") as w:
        for role, fragments in all_fragments:
            w.write(f"# {role}\n")
            for fragment in fragments:
                if fragment["type"] == "text":
                    w.write(fragment["text"])
                    w.write("\n")
                else:
                    assert fragment["type"] == "image"
                    # PNG matches the helper's intent and stores any image mode
                    # (JPEG cannot store alpha channels).
                    filename = f"image_{image_counter}.png"
                    # Advance the counter so later images don't overwrite this one.
                    image_counter += 1
                    image_path = os.path.join(directory_path, filename)
                    save_inlined_string_to_png(fragment["url"], image_path)
                    w.write(f"![{filename}]({filename})\n")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def string_thread_to_html_string(thread: StringThread) -> str:
    """
    Converts a StringThread into a single, valid HTML string.
    Text is converted to paragraphs, and images are embedded directly
    using their base64 data URIs.

    Each ``(role, fragments)`` pair from ``thread.get_fragments()`` becomes an
    ``<h2>`` heading with the capitalized role, followed by one ``<p>`` per text
    fragment and one ``<img>`` per image fragment; fragments of any other type are
    skipped. Role names, text and image URLs are all HTML-escaped, so characters
    such as ``<`` or ``"`` cannot break the generated markup. Base64 data URIs
    contain none of the escaped characters and pass through unchanged.
    """
    html_parts = []
    all_fragments = thread.get_fragments()

    for role, fragments in all_fragments:
        # Use a heading for the role (e.g., 'user', 'model')
        html_parts.append(f"<h2>{html.escape(role.capitalize())}</h2>")
        for fragment in fragments:
            if fragment["type"] == "text":
                # Escape HTML special characters. CSS will handle wrapping and newlines.
                text_content = html.escape(fragment["text"])
                html_parts.append(f"<p>{text_content}</p>")
            elif fragment["type"] == "image":
                # The 'url' is normally a data URI and is used directly as the <img> src.
                # It is escaped (quotes included) because it sits inside an attribute value.
                image_url = html.escape(fragment["url"], quote=True)
                html_parts.append(
                    f'<img src="{image_url}" alt="Embedded image content" style="max-width: 500px; height: auto; display: block; margin: 10px 0;">'
                )

    # Combine all parts into a single string
    body_content = "\n".join(html_parts)

    # Wrap the content in a complete, styled HTML document
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>StringThread Conversation</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
            line-height: 1.6;
            color: #000000;
            margin: 20px auto;
            padding: 0 20px;
        }}
        h2 {{
            border-bottom: 2px solid #eee;
            padding-bottom: 10px;
            margin-top: 40px;
            margin-left: 40px;
            width: 600px;
            color: #000000;
        }}
        p {{
            width: 600px;
            margin: 16px 0;
            margin-left: 40px;
            word-break: break-word; /* Helps with breaking long words if needed */
            white-space: pre-wrap; /* Preserve newlines and wrap long lines */
        }}
        img {{
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 5px;
            background-color: #ffffff; /* Changed to white */
        }}
    </style>
</head>
<body>
    <div style="background-color: #ffffff; width: 680px;">
        <h1>Conversation Thread</h1>
        {body_content}
    </div>
</body>
</html>
"""
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def remove_reasoning(completion: str) -> str:
    """Get the completion without the reasoning.

    This is a quick implementation for Qwen3 thinking tags only at the moment.

    Removes the first ``<think>...</think>`` block. If instead the completion
    contains a ``</think>`` with no ``<think>`` before it (the opening tag came
    from the chat template, as part of the prompt), everything from the start up
    to and including that ``</think>`` is removed. The result is stripped of
    surrounding whitespace; a completion without think tags is only stripped.
    """
    # Remove either <think>...</think> tags or content from start to </think> if <think> is a prefix from the chat template.
    # The second alternative may not run across a "<think>". Without that guard it
    # matches at position 0 whenever any "</think>" exists, which shadows the first
    # alternative and also deletes text that precedes a well-formed think block.
    result = re.sub(
        r"<think>.*?</think>|^(?:(?!<think>).)*?</think>",
        "",
        completion,
        count=1,
        flags=re.DOTALL,
    )

    return result.strip()
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from adaptive_harmony.core.reward_client.client import RewardClient
|
|
2
|
+
from adaptive_harmony.core.reward_client.reward_types import (
|
|
3
|
+
MetadataValidationResponse,
|
|
4
|
+
Request,
|
|
5
|
+
Response,
|
|
6
|
+
ServerInfo,
|
|
7
|
+
Turn,
|
|
8
|
+
)
|
|
9
|
+
from adaptive_harmony.core.reward_client.websocket_utils import ResponseAccumulator
|
|
10
|
+
|
|
11
|
+
# Public API of the reward_client package, re-exported from its submodules:
# the client, the request/response message models, and the chunk accumulator.
__all__ = [
    "RewardClient",
    "Turn", "Request", "Response", "MetadataValidationResponse", "ServerInfo",
    "ResponseAccumulator",
]
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Final
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
from httpx import Limits
|
|
7
|
+
from jsonschema import ValidationError as JsonSchemaValidationError
|
|
8
|
+
from jsonschema import validate
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from websockets.asyncio.client import ClientConnection, connect
|
|
11
|
+
|
|
12
|
+
from adaptive_harmony.core.reward_client.reward_types import (
|
|
13
|
+
MetadataValidationResponse,
|
|
14
|
+
Request,
|
|
15
|
+
Response,
|
|
16
|
+
ServerInfo,
|
|
17
|
+
Turn,
|
|
18
|
+
)
|
|
19
|
+
from adaptive_harmony.core.reward_client.websocket_utils import ResponseAccumulator
|
|
20
|
+
|
|
21
|
+
# HTTP routes on the reward server; RewardClient requests them relative to its base URL.
INFO_PATH: Final = "/info"  # server name/version/description
METADATA_SCHEMA_PATH: Final = "/metadata_schema"  # JSON schema for Request.metadata
SCORE_PATH: Final = "/score"  # POST target for scoring when not using the websocket
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def read_task(client: ClientConnection, responses: dict[int, Response | asyncio.Event]):
    """Dispatch responses arriving on the reward websocket to their waiting senders.

    ``responses`` maps each in-flight request id to the ``asyncio.Event`` its sender
    waits on. Every incoming JSON message carries an ``id``:

    - a message with a truthy ``total_num_chunks`` announces a response split into
      that many chunks;
    - any other message carries one ``chunk`` of an already-announced response.

    When the last chunk arrives, the joined payload is parsed as a ``Response`` and
    stored in ``responses`` in place of the event, and the event is set.

    The loop runs until an error occurs (a closed connection included) or the task
    is cancelled. On the way out, every event still pending in ``responses`` is set
    so senders are woken instead of waiting forever. Such a sender finds its event,
    not a ``Response``, under its id.
    """
    response_accumulators: dict[int, ResponseAccumulator] = {}
    try:
        while True:
            msg = await client.recv()
            obj = json.loads(msg)
            request_id = obj["id"]

            if total_num_chunks := obj.get("total_num_chunks"):
                # Header message: start collecting chunks for this id.
                if request_id in response_accumulators:
                    raise RuntimeError(f"Duplicate response header for request id {request_id}")
                response_accumulators[request_id] = ResponseAccumulator(total_num_chunks)
                continue

            # Explicit check rather than `assert (acc := ...)`: `python -O` strips
            # asserts, and the assignment inside one would vanish with it.
            acc = response_accumulators.get(request_id)
            if acc is None:
                raise RuntimeError(f"Received a chunk for unannounced request id {request_id}")
            acc.add_chunk(obj["chunk"])
            if not acc.is_complete():
                continue

            del response_accumulators[request_id]
            response = Response.model_validate_json(acc.get_full_data())  # type: ignore
            # sanity check
            if response.id != request_id:
                raise RuntimeError(f"Response id {response.id} does not match message id {request_id}")

            event = responses.get(request_id)
            if not isinstance(event, asyncio.Event):
                # Nobody is waiting for this id; drop the response rather than
                # stopping the reader and stranding every other sender.
                logger.warning(f"Dropping response for request id {request_id}: no pending waiter")
                continue
            responses[request_id] = response
            event.set()
    except Exception as e:
        logger.exception(f"Reward websocket read task stopped: {e}")
    finally:
        # Wake senders whose responses will now never arrive.
        for pending in responses.values():
            if isinstance(pending, asyncio.Event):
                pending.set()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class RewardClient:
    """Async client for a reward server, over HTTP(S) and optionally a websocket.

    By default :meth:`score` POSTs each request to ``/score``. Once :meth:`setup` or
    :meth:`connect_websocket` has been awaited, :meth:`score` instead multiplexes
    requests over a single websocket connection at ``/ws``. At most
    ``max_connections`` requests await a response at any one time.

    Args:
        base_url: Server URL; must start with ``http://`` or ``https://``.
        max_connections: Connection-pool limit for HTTP, and the number of
            websocket requests allowed in flight at once.
        timeout: Timeout passed to httpx for HTTP requests (``None`` disables it).
    """

    def __init__(self, base_url: str, max_connections: int = 32, timeout: float | None = None):
        if base_url.startswith("https://"):
            self.use_secure_protocol = True
            self.base_url = base_url.removeprefix("https://")
        else:
            assert base_url.startswith("http://"), f"Unknown url format {base_url}"
            self.use_secure_protocol = False
            self.base_url = base_url.removeprefix("http://")

        self._client = httpx.AsyncClient(
            headers=dict(),
            base_url=self._get_http_url(),
            timeout=timeout,
            limits=Limits(max_connections=max_connections),
        )
        # Fetched from the server on the first validate_metadata() call, then reused.
        self._metadata_json_schema: None | dict[str, Any] = None
        self.use_websocket = False
        self.max_connections = max_connections

    def _get_http_url(self):
        """Return the server URL with the http(s) scheme matching ``base_url``."""
        if self.use_secure_protocol:
            return f"https://{self.base_url}"
        else:
            return f"http://{self.base_url}"

    def _get_ws_url(self):
        """Return the websocket endpoint URL (``ws://`` or ``wss://`` + ``/ws``)."""
        if self.use_secure_protocol:
            return f"wss://{self.base_url}/ws"
        else:
            return f"ws://{self.base_url}/ws"

    async def setup(self):
        """Switch this client to websocket scoring."""
        await self.connect_websocket()

    async def connect_websocket(self):
        """Open the websocket, start the background reader, and return ``self``.

        Once this returns, :meth:`score` sends requests over the websocket.
        """
        # ping_timeout=None disables the keepalive ping timeout, so an unanswered
        # ping never closes the connection.
        self.ws_client: ClientConnection = await connect(self._get_ws_url(), ping_timeout=None)
        self.ws_responses: dict[int, Response | asyncio.Event] = dict()
        logger.info("Spawning_read_task")
        self.read_task = asyncio.create_task(read_task(self.ws_client, self.ws_responses))
        self.request_id = 0
        # Cap the number of websocket requests awaiting a response at once.
        self.semaphore = asyncio.Semaphore(self.max_connections)
        # Enable the websocket path last: score() must not use it before the socket,
        # reader and semaphore exist (or at all, if connect() raised).
        self.use_websocket = True
        return self

    async def drop_websocket(self):
        """Stop the background reader and close the websocket; scoring falls back to HTTP."""
        assert self.use_websocket
        # Route new score() calls to HTTP right away, not to the closing socket.
        self.use_websocket = False
        logger.info("Cancelling_read_task")
        self.read_task.cancel()
        await self.ws_client.close()

    async def aclose(self):
        """Release all connections: the websocket (if open) and the HTTP client."""
        if self.use_websocket:
            await self.drop_websocket()
        await self._client.aclose()

    async def _post(self, path: str, data: dict) -> httpx.Response:
        """POST ``data`` as JSON to ``path``; raise ``httpx.HTTPStatusError`` on a 4xx/5xx."""
        response = await self._client.post(path, json=data)
        response.raise_for_status()
        return response

    async def _ws_post(self, req: Request) -> Response:
        """Send ``req`` over the websocket and wait for its response.

        Assigns ``req.id`` a fresh id so the reader can route the response back here.

        Raises:
            ConnectionError: if the background reader is not running, or stopped
                before this request's response arrived.
        """
        async with self.semaphore:
            if self.read_task.done():
                raise ConnectionError("Reward websocket reader has stopped; no response can be received")
            request_id = self.request_id
            self.request_id += 1
            req.id = request_id
            event = asyncio.Event()
            self.ws_responses[request_id] = event
            try:
                await self.ws_client.send(req.model_dump_json())
            except Exception:
                # The request never went out, so nothing will answer this id.
                self.ws_responses.pop(request_id, None)
                raise
            await event.wait()
            response = self.ws_responses.pop(request_id)
            if not isinstance(response, Response):
                # The reader set our event while shutting down, without a response.
                raise ConnectionError(
                    f"Reward websocket reader stopped before a response to request {request_id} arrived"
                )
            assert response.id == request_id
            return response

    async def score(self, req: Request) -> Response:
        """Score ``req`` over the websocket if connected, otherwise via ``POST /score``."""
        if not self.use_websocket:
            response = await self._post(SCORE_PATH, req.model_dump())
            return Response(**response.json())
        else:
            return await self._ws_post(req)

    async def validate_metadata(self, metadata: dict[Any, Any]):
        """Validate ``metadata`` against the server's JSON schema.

        The schema is fetched from ``/metadata_schema`` on the first call and cached.

        Returns:
            MetadataValidationResponse with ``is_valid`` set, and the jsonschema error
            text in ``error_message`` when validation fails.

        Raises:
            httpx.HTTPStatusError: if fetching the schema fails. Without this check an
                error body would be cached as the "schema" and accept everything.
        """
        if self._metadata_json_schema is None:
            response = await self._client.get(METADATA_SCHEMA_PATH)
            response.raise_for_status()
            self._metadata_json_schema = response.json()

        try:
            validate(instance=metadata, schema=self._metadata_json_schema)  # type: ignore
            return MetadataValidationResponse(is_valid=True)
        except JsonSchemaValidationError as e:
            return MetadataValidationResponse(is_valid=False, error_message=str(e))

    async def info(self) -> ServerInfo:
        """Fetch the server's ``/info`` description."""
        response = await self._client.get(INFO_PATH)
        response.raise_for_status()
        return ServerInfo(**response.json())

    def blocking_info(self) -> ServerInfo:
        """Synchronous variant of :meth:`info`.

        Uses a short-lived sync client with the same base URL and timeout. Any path
        prefix in the base URL is therefore kept, exactly as for the async requests.
        """
        with httpx.Client(base_url=self._client.base_url, timeout=self._client.timeout) as client:
            response = client.get(INFO_PATH)
            response.raise_for_status()
            return ServerInfo(**response.json())
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def main():
    """Smoke test: send 1024 concurrent score requests over the websocket.

    Expects a reward server listening on port 50056 of this machine.
    """
    # RewardClient requires an explicit http:// or https:// scheme.
    client = await RewardClient("http://0.0.0.0:50056").connect_websocket()
    try:
        tasks = [
            client.score(Request(turns=[Turn(role="assistant", content="hello")]))
            for _ in range(1024)
        ]
        await asyncio.gather(*tasks)
    finally:
        await client.drop_websocket()


if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import Any, TypeVar
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel as PydanticBaseModel
|
|
4
|
+
from pydantic import ConfigDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseModel(PydanticBaseModel):
    # Shared base for the reward-server message models.
    # extra="forbid": validation rejects unknown fields. This applies in both
    # directions, so a server response carrying fields not declared here fails to
    # parse on this client.
    model_config = ConfigDict(extra="forbid")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Type variable for metadata models built on BaseModel. It is not used within this
# module — presumably referenced by importers; verify before removing.
MetadataType = TypeVar("MetadataType", bound=BaseModel)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Turn(BaseModel):
    """@public

    One message of the conversation sent to the reward server for scoring.
    """

    role: str  # who produced the message, e.g. "assistant"
    content: str  # the message text
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Request(BaseModel):
    """@public

    A scoring request: the conversation to score plus optional metadata.
    """

    turns: list[Turn]
    # Optional free-form metadata; RewardClient.validate_metadata can check a dict
    # against the server's JSON schema before it is sent.
    metadata: dict[str, Any] | None = None
    # Overwritten by RewardClient with a fresh id when the request is sent over the
    # websocket, so the matching Response can be routed back to the sender.
    id: int | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Response(BaseModel):
    """@public

    The reward server's answer to a Request.
    """

    reward: float  # the score assigned by the server
    metadata: dict[str, Any]  # extra information returned by the server (required field)
    # On the websocket path, RewardClient checks that this equals the id it gave the Request.
    id: int | None = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MetadataValidationResponse(BaseModel):
    """@public

    Result of RewardClient.validate_metadata.
    """

    is_valid: bool  # True when the metadata satisfied the server's JSON schema
    error_message: str | None = None  # the jsonschema validation error text when is_valid is False
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ServerInfo(BaseModel):
    """@public

    Description of a reward server, as returned by its /info endpoint
    (see RewardClient.info and RewardClient.blocking_info).
    """

    version: str
    name: str
    description: str
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
|
|
5
|
+
class ResponseAccumulator:
|
|
6
|
+
total_num_chunks: int = 0
|
|
7
|
+
chunks_received: list[str] = field(default_factory=list)
|
|
8
|
+
|
|
9
|
+
def add_chunk(self, chunk: str):
|
|
10
|
+
self.chunks_received.append(chunk)
|
|
11
|
+
|
|
12
|
+
def is_complete(self) -> bool:
|
|
13
|
+
return len(self.chunks_received) == self.total_num_chunks
|
|
14
|
+
|
|
15
|
+
def get_full_data(self) -> str | None:
|
|
16
|
+
if self.is_complete():
|
|
17
|
+
return "".join(self.chunks_received)
|
|
18
|
+
return None
|