hud-python: hud_python-0.4.45-py3-none-any.whl → hud_python-0.4.46-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/rl/config.py +1 -0
- hud/rl/learner.py +20 -10
- hud/rl/train.py +12 -0
- hud/tools/base.py +37 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.45.dist-info → hud_python-0.4.46.dist-info}/METADATA +1 -1
- {hud_python-0.4.45.dist-info → hud_python-0.4.46.dist-info}/RECORD +11 -11
- {hud_python-0.4.45.dist-info → hud_python-0.4.46.dist-info}/WHEEL +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.4.46.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.4.46.dist-info}/licenses/LICENSE +0 -0
hud/rl/config.py
CHANGED
hud/rl/learner.py
CHANGED
|
@@ -146,17 +146,27 @@ class GRPOLearner:
|
|
|
146
146
|
policy.gradient_checkpointing_enable()
|
|
147
147
|
self.log("Gradient checkpointing enabled for memory efficiency")
|
|
148
148
|
|
|
149
|
-
# Add LoRA adapters
|
|
150
|
-
lora_config = LoraConfig(
|
|
151
|
-
r=model_cfg.lora_r,
|
|
152
|
-
lora_alpha=model_cfg.lora_alpha,
|
|
153
|
-
lora_dropout=model_cfg.lora_dropout,
|
|
154
|
-
task_type="CAUSAL_LM",
|
|
155
|
-
bias="none",
|
|
156
|
-
target_modules=list(model_cfg.target_modules),
|
|
157
|
-
)
|
|
149
|
+
# Add LoRA adapters or load existing adapter
|
|
158
150
|
policy.config.use_cache = False
|
|
159
|
-
|
|
151
|
+
|
|
152
|
+
if model_cfg.adapter_path:
|
|
153
|
+
# Load existing adapter as baseline
|
|
154
|
+
self.log(f"Loading existing LoRA adapter from: {model_cfg.adapter_path}")
|
|
155
|
+
from peft import PeftModel
|
|
156
|
+
policy = PeftModel.from_pretrained(policy, model_cfg.adapter_path)
|
|
157
|
+
# Enable adapter training
|
|
158
|
+
policy.train()
|
|
159
|
+
else:
|
|
160
|
+
# Create new LoRA adapter
|
|
161
|
+
lora_config = LoraConfig(
|
|
162
|
+
r=model_cfg.lora_r,
|
|
163
|
+
lora_alpha=model_cfg.lora_alpha,
|
|
164
|
+
lora_dropout=model_cfg.lora_dropout,
|
|
165
|
+
task_type="CAUSAL_LM",
|
|
166
|
+
bias="none",
|
|
167
|
+
target_modules=list(model_cfg.target_modules),
|
|
168
|
+
)
|
|
169
|
+
policy = get_peft_model(policy, lora_config)
|
|
160
170
|
|
|
161
171
|
# Wrap with DDP if in distributed mode
|
|
162
172
|
if self.world_size > 1:
|
hud/rl/train.py
CHANGED
|
@@ -95,6 +95,18 @@ async def train(config: Config, tasks: list[Task]) -> None:
|
|
|
95
95
|
if is_main_process()
|
|
96
96
|
else None
|
|
97
97
|
)
|
|
98
|
+
|
|
99
|
+
# Load initial adapter if provided
|
|
100
|
+
if is_main_process() and config.model.adapter_path and vllm:
|
|
101
|
+
hud_console.info(f"Loading baseline adapter from: {config.model.adapter_path}")
|
|
102
|
+
success = vllm.load_adapter(config.model.base_model, config.model.adapter_path)
|
|
103
|
+
if success and actor is not None:
|
|
104
|
+
hud_console.info("Successfully loaded baseline adapter as 'base_model'")
|
|
105
|
+
# Update actor to use the loaded adapter
|
|
106
|
+
actor.update_adapter(config.model.base_model)
|
|
107
|
+
else:
|
|
108
|
+
hud_console.error("Failed to load baseline adapter")
|
|
109
|
+
exit(1)
|
|
98
110
|
|
|
99
111
|
# Training state
|
|
100
112
|
step = 0
|
hud/tools/base.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import TYPE_CHECKING, Any, cast
|
|
4
|
+
from typing import TYPE_CHECKING, Any, cast, Awaitable
|
|
5
5
|
|
|
6
6
|
from fastmcp import FastMCP
|
|
7
7
|
|
|
@@ -16,6 +16,8 @@ if TYPE_CHECKING:
|
|
|
16
16
|
# Basic result types for tools
|
|
17
17
|
BaseResult = list[ContentBlock] | EvaluationResult
|
|
18
18
|
|
|
19
|
+
import logging
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
19
21
|
|
|
20
22
|
class BaseTool(ABC):
|
|
21
23
|
"""
|
|
@@ -58,6 +60,10 @@ class BaseTool(ABC):
|
|
|
58
60
|
self.title = title or self.__class__.__name__.replace("Tool", "").replace("_", " ").title()
|
|
59
61
|
self.description = description or (self.__doc__.strip() if self.__doc__ else None)
|
|
60
62
|
self.meta = meta
|
|
63
|
+
self._callbacks: dict[
|
|
64
|
+
str,
|
|
65
|
+
list[Callable[..., Awaitable[Any]]],
|
|
66
|
+
] = {} # {"event_name": [callback_functions]}
|
|
61
67
|
|
|
62
68
|
# Expose attributes FastMCP expects when registering an instance directly
|
|
63
69
|
self.__name__ = self.name # FastMCP uses fn.__name__ if name param omitted
|
|
@@ -100,6 +106,36 @@ class BaseTool(ABC):
|
|
|
100
106
|
)
|
|
101
107
|
return self._mcp_tool
|
|
102
108
|
|
|
109
|
+
def add_callback(self, event_type: str, callback: Callable[..., Awaitable[Any]]):
|
|
110
|
+
"""Register a callback function for specific event
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
event_type: (Required) Specific event name to trigger callback
|
|
114
|
+
e.g. "after_click", "before_navigate"
|
|
115
|
+
callback: (Required) Async function to call. Must be defined by `async def f(...)`
|
|
116
|
+
"""
|
|
117
|
+
if event_type not in self._callbacks:
|
|
118
|
+
self._callbacks[event_type] = []
|
|
119
|
+
self._callbacks[event_type].append(callback)
|
|
120
|
+
|
|
121
|
+
def remove_callback(self, event_type: str, callback: Callable[..., Awaitable[Any]]):
|
|
122
|
+
"""Remove a registered callback
|
|
123
|
+
Args:
|
|
124
|
+
event_type: (Required) Specific event name to trigger callback
|
|
125
|
+
e.g. "after_click", "before_navigate"
|
|
126
|
+
callback: (Required) Function to remove from callback list.
|
|
127
|
+
"""
|
|
128
|
+
if (event_type in self._callbacks) and (callback in self._callbacks[event_type]):
|
|
129
|
+
self._callbacks[event_type].remove(callback)
|
|
130
|
+
|
|
131
|
+
async def _trigger_callbacks(self, event_type: str, **kwargs):
|
|
132
|
+
"""Trigger all registered callback functions of an event type"""
|
|
133
|
+
callback_list = self._callbacks.get(event_type, [])
|
|
134
|
+
for callback in callback_list:
|
|
135
|
+
try:
|
|
136
|
+
await callback(**kwargs)
|
|
137
|
+
except Exception as e:
|
|
138
|
+
logger.warning(f"Callback failed for {event_type}: {e}")
|
|
103
139
|
|
|
104
140
|
# Prefix for internal tool names
|
|
105
141
|
_INTERNAL_PREFIX = "int_"
|
hud/utils/tests/test_version.py
CHANGED
hud/version.py
CHANGED
|
@@ -2,7 +2,7 @@ hud/__init__.py,sha256=JMDFUE1pP0J1Xl_miBdt7ERvoffZmTzSFe8yxz512A8,552
|
|
|
2
2
|
hud/__main__.py,sha256=YR8Dq8OhINOsVfQ55PmRXXg4fEK84Rt_-rMtJ5rvhWo,145
|
|
3
3
|
hud/settings.py,sha256=disObWa-DgXzoDcCDp3y1dTPaNsbR0IvoMJL9Eg4zyo,3947
|
|
4
4
|
hud/types.py,sha256=RVwfx9rIF-D6P5HPwz9WuCzcbNhWHd_wId4uqanjah4,11170
|
|
5
|
-
hud/version.py,sha256=
|
|
5
|
+
hud/version.py,sha256=aha9n6Uks_Ql6r4xnI3U-csrKn4jndncgvM0Ko-l91c,105
|
|
6
6
|
hud/agents/__init__.py,sha256=UoIkljWdbq4bM0LD-mSaw6w826EqdEjOk7r6glNYwYQ,286
|
|
7
7
|
hud/agents/base.py,sha256=_u1zR3gXzZ1RlTCUYdMcvgHqdJBC4-AB1lZt0yBx8lg,35406
|
|
8
8
|
hud/agents/claude.py,sha256=TGhm5gE2ltINDAdEsDxKuT9iGMQ5G87R6kmabU3KPt8,16101
|
|
@@ -121,10 +121,10 @@ hud/rl/__init__.py,sha256=yYL7U1WV6L3mr3Hig48-4lhnryTaWj4nCXm4lG5vrYI,25
|
|
|
121
121
|
hud/rl/actor.py,sha256=H6gwRGRY1YpkOyiaJ9yai8yQwcI-Gx0dFxd18jpLx_Q,6950
|
|
122
122
|
hud/rl/buffer.py,sha256=z47HOjOBJx3umUzzUfdtq_N4ZoJ8FMBPkX8YQKBtd3A,15457
|
|
123
123
|
hud/rl/chat_template.jinja,sha256=XTdzI8oFGEcSA-exKxyHaprwRDmX5Am1KEb0VxvUc6U,4965
|
|
124
|
-
hud/rl/config.py,sha256=
|
|
124
|
+
hud/rl/config.py,sha256=sCU56mjtgJpu_C0TXqpT14v1LmZv0ntmUjgNkFamTPA,5713
|
|
125
125
|
hud/rl/distributed.py,sha256=Mr3NEj3rbS9FgpHofC_GrqpkvNQSpPFOqLQc2NXPNXs,3678
|
|
126
|
-
hud/rl/learner.py,sha256=
|
|
127
|
-
hud/rl/train.py,sha256
|
|
126
|
+
hud/rl/learner.py,sha256=xlCF5eJkeUIwhGErlv8YnCN1l4UFYrE4oSSLIQWWyx0,27230
|
|
127
|
+
hud/rl/train.py,sha256=0FScXz-5mCrL7H-auipZoVfeI43IrJMR5rrLz_iOGg4,15593
|
|
128
128
|
hud/rl/types.py,sha256=lrLKo7iaqodYth2EyeuOQfLiuzXfYM2eJjPmpObrD7c,3965
|
|
129
129
|
hud/rl/utils.py,sha256=IsgVUUibxnUzb32a4mu1sYrgJC1CwoG9E-Dd5y5VDOA,19115
|
|
130
130
|
hud/rl/vllm_adapter.py,sha256=2wnTfoXPI4C9EzhVxk0GU-ArLjX7hgXS0BndMwN8Ppg,4751
|
|
@@ -163,7 +163,7 @@ hud/telemetry/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSu
|
|
|
163
163
|
hud/telemetry/tests/test_replay.py,sha256=eREc6qgSJDRT1pOPdyhiEoEJ9H2yT1ospaU1RvTKlvg,1328
|
|
164
164
|
hud/telemetry/tests/test_trace.py,sha256=0rxR77CjcStat3ILA9QAswieOJ3J_386QmjmNDp34oA,2486
|
|
165
165
|
hud/tools/__init__.py,sha256=i6lE0GxYcPnlLLd-55ryCCHo7o9anC4RfqkuYFXvzMQ,1009
|
|
166
|
-
hud/tools/base.py,sha256=
|
|
166
|
+
hud/tools/base.py,sha256=KJfkhwWV6IQKBW1kc5yw1YMJSUSUifHgXXHN0NMANFw,17517
|
|
167
167
|
hud/tools/bash.py,sha256=LJViMGb3lTGBm_gequVVTM7ySh1Xh9bOOIZXU29Lmrw,5209
|
|
168
168
|
hud/tools/edit.py,sha256=N0AYFXp07-vAJy2li7lvHOL6hfgJOU4LL3iLSZrbRWU,12745
|
|
169
169
|
hud/tools/playwright.py,sha256=iyMrQ-ZKyeFia2fBp0yguXswTcXfGqdZcTXXCfUupFU,14988
|
|
@@ -219,10 +219,10 @@ hud/utils/tests/test_init.py,sha256=2QLQSGgyP9wJhOvPCusm_zjJad0qApOZi1BXpxcdHXQ,
|
|
|
219
219
|
hud/utils/tests/test_mcp.py,sha256=0pUa16mL-bqbZDXp5NHBnt1gO5o10BOg7zTMHZ1DNPM,4023
|
|
220
220
|
hud/utils/tests/test_progress.py,sha256=QSF7Kpi03Ff_l3mAeqW9qs1nhK50j9vBiSobZq7T4f4,7394
|
|
221
221
|
hud/utils/tests/test_telemetry.py,sha256=5jl7bEx8C8b-FfFUko5pf4UY-mPOR-9HaeL98dGtVHM,2781
|
|
222
|
-
hud/utils/tests/test_version.py,sha256=
|
|
222
|
+
hud/utils/tests/test_version.py,sha256=_sCmpdXghujnfjw34TWJs-QsalOI2Yl0pSMqhfdFKio,160
|
|
223
223
|
hud/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
224
|
-
hud_python-0.4.
|
|
225
|
-
hud_python-0.4.
|
|
226
|
-
hud_python-0.4.
|
|
227
|
-
hud_python-0.4.
|
|
228
|
-
hud_python-0.4.
|
|
224
|
+
hud_python-0.4.46.dist-info/METADATA,sha256=HD0Epvlb5lMuTxSGxJnVGdmfHeBIcn-hFgs1BOdpe84,22275
|
|
225
|
+
hud_python-0.4.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
226
|
+
hud_python-0.4.46.dist-info/entry_points.txt,sha256=jJbodNFg1m0-CDofe5AHvB4zKBq7sSdP97-ohaQ3ae4,63
|
|
227
|
+
hud_python-0.4.46.dist-info/licenses/LICENSE,sha256=yIzBheVUf86FC1bztAcr7RYWWNxyd3B-UJQ3uddg1HA,1078
|
|
228
|
+
hud_python-0.4.46.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|