adaptive-harmony 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
@@ -0,0 +1,110 @@
1
+ import base64
2
+ import html
3
+ import io
4
+ import os
5
+
6
+ from PIL import Image
7
+
8
+ from adaptive_harmony import StringThread
9
+
10
+
11
def save_inlined_string_to_png(image_str: str, path_to_save: str):
    """Decode an inline base64 image and write it to ``path_to_save``.

    ``image_str`` is normally a data URI such as ``data:image/png;base64,<payload>``.
    Everything up to the first comma is treated as the header and discarded. A
    string with no comma is treated as a bare base64 payload. Pillow picks the
    output format from the extension of ``path_to_save``.

    Args:
        image_str: Data URI or bare base64 string holding the encoded image.
        path_to_save: Destination file path.
    """
    # partition (not split) tolerates a missing header instead of failing to unpack.
    _, separator, payload = image_str.partition(",")
    if not separator:
        payload = image_str
    image_data = base64.b64decode(payload)
    # Context manager releases Pillow's handle on the decoded buffer promptly.
    with Image.open(io.BytesIO(image_data)) as image:
        image.save(path_to_save)
16
+
17
+
18
def save_string_thread_to_markdown_dir(thread: "StringThread", directory_path: str):
    """Export a StringThread as ``main.md`` plus one image file per image fragment.

    ``directory_path`` is created if needed. Each turn is written as a
    ``# <role>`` heading followed by its fragments:

    - Text fragments are written as-is, with a newline appended if they lack one.
    - Each image fragment is decoded from its inline base64 ``url`` into
      ``image_<n>.png`` next to ``main.md``. It is then linked with a relative
      Markdown image reference.

    Args:
        thread: Conversation whose ``get_fragments()`` yields ``(role, fragments)``
            pairs. Each fragment has ``type`` ``"text"`` (with key ``text``) or
            ``"image"`` (with key ``url``).
        directory_path: Output directory.

    Raises:
        ValueError: If a fragment's type is neither ``"text"`` nor ``"image"``.
    """
    os.makedirs(directory_path, exist_ok=True)
    image_counter = 0
    with open(os.path.join(directory_path, "main.md"), "w", encoding="utf-8") as w:
        for role, fragments in thread.get_fragments():
            w.write(f"# {role}\n")
            for fragment in fragments:
                fragment_type = fragment["type"]
                if fragment_type == "text":
                    text = fragment["text"]
                    w.write(text)
                    # Terminate the text so a following heading or image link
                    # starts on its own line.
                    if not text.endswith("\n"):
                        w.write("\n")
                elif fragment_type == "image":
                    # The counter must advance per image. Otherwise every image
                    # overwrites the same file and all links point at it.
                    # PNG (not JPEG) so images with an alpha channel can be saved.
                    filename = f"image_{image_counter}.png"
                    image_counter += 1
                    save_inlined_string_to_png(fragment["url"], os.path.join(directory_path, filename))
                    w.write(f"![image info]({filename})\n")
                else:
                    # Explicit error instead of `assert`, which `python -O` strips.
                    raise ValueError(f"Unsupported fragment type: {fragment_type!r}")
35
+
36
+
37
def string_thread_to_html_string(thread: "StringThread") -> str:
    """Render a StringThread as a standalone, styled HTML document.

    Rendering rules:

    - Each turn becomes an ``<h2>`` heading with the capitalized role.
    - Text fragments become ``<p>`` elements; newlines are preserved via CSS
      ``white-space: pre-wrap``.
    - Image fragments become ``<img>`` tags whose ``src`` is the fragment's
      ``url`` (typically a base64 data URI).
    - Fragments of any other type are skipped.

    Role names, text and image URLs are HTML-escaped, so conversation content
    cannot inject markup or break out of the ``src`` attribute.

    Args:
        thread: Conversation whose ``get_fragments()`` yields ``(role, fragments)``
            pairs.

    Returns:
        The complete HTML document as a string.
    """
    html_parts = []
    for role, fragments in thread.get_fragments():
        html_parts.append(f"<h2>{html.escape(role.capitalize())}</h2>")
        for fragment in fragments:
            if fragment["type"] == "text":
                html_parts.append(f"<p>{html.escape(fragment['text'])}</p>")
            elif fragment["type"] == "image":
                # html.escape quotes '"' by default, so the URL stays inside src="...".
                # Valid base64 data URIs contain none of the escaped characters.
                image_url = html.escape(fragment["url"])
                html_parts.append(
                    f'<img src="{image_url}" alt="Embedded image content" '
                    'style="max-width: 500px; height: auto; display: block; margin: 10px 0;">'
                )

    body_content = "\n".join(html_parts)

    # CSS braces are doubled because this is an f-string.
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>StringThread Conversation</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
            line-height: 1.6;
            color: #000000;
            margin: 20px auto;
            padding: 0 20px;
        }}
        h2 {{
            border-bottom: 2px solid #eee;
            padding-bottom: 10px;
            margin-top: 40px;
            margin-left: 40px;
            width: 600px;
            color: #000000;
        }}
        p {{
            width: 600px;
            margin: 16px 0;
            margin-left: 40px;
            word-break: break-word; /* Helps with breaking long words if needed */
            white-space: pre-wrap; /* Preserve newlines and wrap long lines */
        }}
        img {{
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 5px;
            background-color: #ffffff;
        }}
    </style>
</head>
<body>
    <div style="background-color: #ffffff; width: 680px;">
        <h1>Conversation Thread</h1>
        {body_content}
    </div>
</body>
</html>
"""
@@ -0,0 +1,12 @@
1
+ import re
2
+
3
+
4
# Matches a full <think>...</think> span, or everything from the start of the
# text to the first </think> (for when the opening tag came from the chat template).
_THINK_BLOCK = re.compile(r"<think>.*?</think>|^.*?</think>", flags=re.DOTALL)


def remove_reasoning(completion: str) -> str:
    """Strip Qwen3-style ``<think>`` reasoning from a completion.

    Only the first match is removed. That match is either a full
    ``<think>...</think>`` span, or everything from the start of the text up to
    the first ``</think>``. The latter covers the case where the chat template
    already emitted the opening tag. Matching starts at the leftmost position,
    so any text before a ``<think>`` tag is removed along with it. The result is
    whitespace-trimmed.
    """
    without_reasoning = _THINK_BLOCK.sub("", completion, count=1)
    return without_reasoning.strip()
@@ -0,0 +1,19 @@
1
+ from adaptive_harmony.core.reward_client.client import RewardClient
2
+ from adaptive_harmony.core.reward_client.reward_types import (
3
+ MetadataValidationResponse,
4
+ Request,
5
+ Response,
6
+ ServerInfo,
7
+ Turn,
8
+ )
9
+ from adaptive_harmony.core.reward_client.websocket_utils import ResponseAccumulator
10
+
11
# Public names of the reward_client package, re-exported from its
# client, reward_types and websocket_utils submodules.
__all__ = [
    "RewardClient",
    "Turn",
    "Request",
    "Response",
    "MetadataValidationResponse",
    "ServerInfo",
    "ResponseAccumulator",
]
@@ -0,0 +1,160 @@
1
+ import asyncio
2
+ import json
3
+ from typing import Any, Final
4
+
5
+ import httpx
6
+ from httpx import Limits
7
+ from jsonschema import ValidationError as JsonSchemaValidationError
8
+ from jsonschema import validate
9
+ from loguru import logger
10
+ from websockets.asyncio.client import ClientConnection, connect
11
+
12
+ from adaptive_harmony.core.reward_client.reward_types import (
13
+ MetadataValidationResponse,
14
+ Request,
15
+ Response,
16
+ ServerInfo,
17
+ Turn,
18
+ )
19
+ from adaptive_harmony.core.reward_client.websocket_utils import ResponseAccumulator
20
+
21
# Reward-server endpoints, relative to the RewardClient's HTTP base URL.
SCORE_PATH: Final = "/score"  # POST a Request, receive a Response (HTTP mode of RewardClient.score)
INFO_PATH: Final = "/info"  # GET the ServerInfo (version, name, description)
METADATA_SCHEMA_PATH: Final = "/metadata_schema"  # GET the JSON Schema used by RewardClient.validate_metadata
24
+
25
+
26
async def read_task(client: ClientConnection, responses: dict[int, Response | asyncio.Event]):
    """Background loop that routes chunked websocket replies back to waiting requests.

    Protocol handled here, per request id:

    - The server first sends a header message ``{"id": ..., "total_num_chunks": N}``.
    - It then sends N messages ``{"id": ..., "chunk": "..."}``.
    - Once all chunks for an id have arrived, they are joined and parsed as a
      :class:`Response`.
    - The Response is stored in ``responses[id]`` in place of the waiter's
      ``asyncio.Event``, and that event is set.

    The loop runs until an error occurs (closed connection, malformed message,
    ...) or the task is cancelled. On exit, every still-pending waiter's event is
    set so no caller blocks forever. Such callers find their ``asyncio.Event``,
    not a ``Response``, still stored under their id.

    Args:
        client: Open websocket connection to read from.
        responses: Map shared with the senders, from request id to either the
            waiter's ``asyncio.Event`` (pending) or the delivered ``Response``.
    """
    response_accumulators: dict[int, ResponseAccumulator] = {}
    try:
        while True:
            obj = json.loads(await client.recv())
            request_id = obj["id"]

            if total_num_chunks := obj.get("total_num_chunks"):
                # Header message: start collecting chunks for this id.
                if request_id in response_accumulators:
                    raise RuntimeError(f"Duplicate response header for request {request_id}")
                response_accumulators[request_id] = ResponseAccumulator(total_num_chunks)
                continue

            # Chunk message. Explicit check rather than `assert (acc := ...)`:
            # under `python -O` that assert, including the assignment, is removed.
            acc = response_accumulators.get(request_id)
            if acc is None:
                raise RuntimeError(f"Received chunk for unknown request {request_id}")
            acc.add_chunk(obj["chunk"])
            if not acc.is_complete():
                continue

            del response_accumulators[request_id]
            response = Response.model_validate_json(acc.get_full_data())  # type: ignore
            # Sanity check: the payload must belong to the id it was framed with.
            if response.id != request_id:
                raise RuntimeError(f"Response id {response.id} does not match message id {request_id}")

            waiter = responses.get(request_id)
            if not isinstance(waiter, asyncio.Event):
                # Nobody is waiting on this id. Drop it rather than tearing down the
                # reader, which would also fail every other in-flight request.
                logger.warning(f"Dropping response for request {request_id} with no pending waiter")
                continue
            responses[request_id] = response
            waiter.set()
    except Exception as e:
        logger.exception(f"Reward websocket read task stopped: {e}")
    finally:
        # Runs on errors and on cancellation. This reader will never set these
        # events again, so wake the waiters instead of leaving them blocked.
        for pending in list(responses.values()):
            if isinstance(pending, asyncio.Event):
                pending.set()
50
+
51
+
52
class RewardClient:
    """Async client for a reward server.

    :meth:`score` sends a :class:`Request` and returns the server's
    :class:`Response`. By default, each call is an HTTP POST to ``/score``.
    After :meth:`setup` / :meth:`connect_websocket`, calls are instead
    multiplexed over a single websocket. Each call is tagged with a
    client-assigned id, and replies are delivered by a background
    :func:`read_task`.
    """

    def __init__(self, base_url: str, max_connections: int = 32, timeout: float | None = None):
        """Create the client (no network I/O happens here).

        Args:
            base_url: Server URL, which must start with ``http://`` or ``https://``.
                ``https`` selects TLS for both HTTP and websocket (``wss``).
            max_connections: HTTP connection-pool limit. Also caps the number of
                in-flight websocket requests.
            timeout: Timeout passed through to ``httpx``.

        Raises:
            ValueError: If ``base_url`` has neither an ``http://`` nor an ``https://`` prefix.
        """
        if base_url.startswith("https://"):
            self.use_secure_protocol = True
            self.base_url = base_url.removeprefix("https://")
        elif base_url.startswith("http://"):
            self.use_secure_protocol = False
            self.base_url = base_url.removeprefix("http://")
        else:
            # Explicit raise rather than `assert`, which `python -O` strips.
            raise ValueError(f"Unknown url format {base_url}")

        self._client = httpx.AsyncClient(
            headers=dict(),
            base_url=self._get_http_url(),
            timeout=timeout,
            limits=Limits(max_connections=max_connections),
        )
        self._metadata_json_schema: None | dict[str, Any] = None
        self.use_websocket = False
        self.max_connections = max_connections

    def _get_http_url(self):
        """Return the HTTP(S) base URL."""
        if self.use_secure_protocol:
            return f"https://{self.base_url}"
        else:
            return f"http://{self.base_url}"

    def _get_ws_url(self):
        """Return the websocket endpoint URL (``/ws``)."""
        if self.use_secure_protocol:
            return f"wss://{self.base_url}/ws"
        else:
            return f"ws://{self.base_url}/ws"

    async def setup(self):
        """Switch the client to websocket mode."""
        await self.connect_websocket()

    async def connect_websocket(self):
        """Open the websocket, start the background reader and route :meth:`score` through it.

        Returns:
            ``self``, so construction and connection can be chained.
        """
        self.ws_client: ClientConnection = await connect(self._get_ws_url(), ping_timeout=None)
        self.ws_responses: dict[int, Response | asyncio.Event] = dict()
        self.request_id = 0
        # Caps in-flight websocket requests at the same limit as the HTTP pool.
        self.semaphore = asyncio.Semaphore(self.max_connections)
        logger.info("Spawning_read_task")
        self.read_task = asyncio.create_task(read_task(self.ws_client, self.ws_responses))
        # Flip the mode only once the connection and reader exist. If connect()
        # raised, score() keeps using HTTP instead of a missing socket.
        self.use_websocket = True
        return self

    async def drop_websocket(self):
        """Stop the background reader, close the websocket and fall back to HTTP.

        In-flight websocket requests are woken and fail with ``ConnectionError``.
        Calling this while no websocket is active does nothing.
        """
        if not self.use_websocket:
            return
        # Switch first, so score() calls made from now on go over HTTP.
        self.use_websocket = False
        logger.info("Cancelling_read_task")
        self.read_task.cancel()
        await self.ws_client.close()

    async def _post(self, path: str, data: dict) -> httpx.Response:
        """POST ``data`` as JSON to ``path``.

        Raises:
            httpx.HTTPStatusError: If the server replies with an error status.
        """
        response = await self._client.post(path, json=data)
        response.raise_for_status()
        return response

    async def _ws_post(self, req: Request) -> Response:
        """Send ``req`` over the websocket and wait for its matching response.

        ``req.id`` is overwritten with a fresh per-client request id. The
        background reader uses that id to route the reply back here.

        Raises:
            ConnectionError: If the background reader stopped before this
                request's response was delivered. This covers a lost connection,
                a malformed message, or :meth:`drop_websocket`.
        """
        async with self.semaphore:
            reader = self.read_task
            if reader.done():
                raise ConnectionError("Reward websocket reader is not running")

            request_id = self.request_id
            self.request_id += 1
            req.id = request_id
            event = asyncio.Event()
            self.ws_responses[request_id] = event
            try:
                await self.ws_client.send(req.model_dump_json())
            except BaseException:
                # The request never went out, so no reply will ever claim this slot.
                self.ws_responses.pop(request_id, None)
                raise

            # Wait for our reply OR the reader finishing. Waiting on the event
            # alone blocks forever if the reader exits without setting it.
            # If this coroutine is cancelled here, the slot stays in place and
            # the reader overwrites it when the late reply arrives.
            event_waiter = asyncio.ensure_future(event.wait())
            try:
                await asyncio.wait({event_waiter, reader}, return_when=asyncio.FIRST_COMPLETED)
            finally:
                event_waiter.cancel()

            response = self.ws_responses.pop(request_id, None)
            if not isinstance(response, Response):
                raise ConnectionError("Reward websocket reader stopped before a response was received")
            assert response.id == request_id
            return response

    async def score(self, req: Request) -> Response:
        """Score ``req``, over the websocket if connected, otherwise via HTTP POST."""
        if not self.use_websocket:
            response = await self._post(SCORE_PATH, req.model_dump())
            return Response(**response.json())
        else:
            return await self._ws_post(req)

    async def validate_metadata(self, metadata: dict[Any, Any]):
        """Validate ``metadata`` against the server's metadata JSON Schema.

        The schema is fetched on the first call and cached afterwards.

        Returns:
            A ``MetadataValidationResponse``. On failure it carries the
            validator's message.

        Raises:
            httpx.HTTPStatusError: If fetching the schema fails.
        """
        if self._metadata_json_schema is None:
            response = await self._client.get(METADATA_SCHEMA_PATH)
            # Fail loudly rather than caching an error body as the schema.
            response.raise_for_status()
            self._metadata_json_schema = response.json()

        try:
            validate(instance=metadata, schema=self._metadata_json_schema)  # type: ignore
            return MetadataValidationResponse(is_valid=True)
        except JsonSchemaValidationError as e:
            return MetadataValidationResponse(is_valid=False, error_message=str(e))

    async def info(self) -> ServerInfo:
        """Fetch the server's name, version and description."""
        response = await self._client.get(INFO_PATH)
        response.raise_for_status()
        return ServerInfo(**response.json())

    def blocking_info(self) -> ServerInfo:
        """Synchronous variant of :meth:`info`, for use outside an event loop."""
        response = httpx.get(self._client.base_url.join(INFO_PATH), timeout=self._client.timeout)
        # Same error behaviour as info(): an HTTP error is raised, not parsed as ServerInfo.
        response.raise_for_status()
        return ServerInfo(**response.json())
147
+
148
+
149
async def main():
    """Manual smoke test: send 1024 concurrent score requests to a local reward server over websocket."""
    # RewardClient requires an explicit scheme; a bare "host:port" fails URL validation.
    client = await RewardClient("http://0.0.0.0:50056").connect_websocket()
    try:
        await asyncio.gather(
            *(client.score(Request(turns=[Turn(role="assistant", content="hello")])) for _ in range(1024))
        )
    finally:
        # Stop the background reader and close the socket even if a request failed.
        await client.drop_websocket()


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,49 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from pydantic import BaseModel as PydanticBaseModel
4
+ from pydantic import ConfigDict
5
+
6
+
7
class BaseModel(PydanticBaseModel):
    # Base class for the request/response models in this module.
    # extra="forbid" makes pydantic reject any field not declared on the model.
    model_config = ConfigDict(extra="forbid")
9
+
10
+
11
# Type variable for metadata models; any binding must be a subclass of BaseModel above.
# NOTE(review): not referenced in this module — confirm where it is used.
MetadataType = TypeVar("MetadataType", bound=BaseModel)
12
+
13
+
14
class Turn(BaseModel):
    """@public"""

    # One message of the conversation sent for scoring.
    role: str  # speaker role (free-form string; not validated here)
    content: str  # message text
19
+
20
+
21
class Request(BaseModel):
    """@public"""

    # Payload sent to the reward server for scoring.
    turns: list[Turn]  # the conversation to score
    metadata: dict[str, Any] | None = None  # optional extra data for the server
    # Correlation id: RewardClient overwrites it with a per-client counter when sending over the websocket.
    id: int | None = None
27
+
28
+
29
class Response(BaseModel):
    """@public"""

    # Reward returned by the server for a Request.
    reward: float  # the score
    metadata: dict[str, Any]  # extra data returned by the server (required)
    # Echo of Request.id; RewardClient checks it matches the request in websocket mode.
    id: int | None = None
35
+
36
+
37
class MetadataValidationResponse(BaseModel):
    """@public"""

    # Result of RewardClient.validate_metadata.
    is_valid: bool  # True when the metadata satisfies the server's JSON Schema
    error_message: str | None = None  # validator message when is_valid is False
42
+
43
+
44
class ServerInfo(BaseModel):
    """@public"""

    # Server self-description returned by the /info endpoint.
    version: str
    name: str
    description: str
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class ResponseAccumulator:
6
+ total_num_chunks: int = 0
7
+ chunks_received: list[str] = field(default_factory=list)
8
+
9
+ def add_chunk(self, chunk: str):
10
+ self.chunks_received.append(chunk)
11
+
12
+ def is_complete(self) -> bool:
13
+ return len(self.chunks_received) == self.total_num_chunks
14
+
15
+ def get_full_data(self) -> str | None:
16
+ if self.is_complete():
17
+ return "".join(self.chunks_received)
18
+ return None