eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130):
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,271 @@
1
+ import importlib
2
+ import json
3
+ import logging
4
+ import os
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import uvicorn
8
+ from fastapi import FastAPI, HTTPException, Request
9
+ from pydantic import BaseModel, Field
10
+
11
+ from .models import EvaluateResult
12
+
13
+ # Setup logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class Message(BaseModel):
    """Model for a single conversation message in a reward request.

    Only ``role`` and ``content`` are declared; any additional keys sent by the
    client are accepted and kept on the model rather than rejected.
    """

    role: str
    content: str

    # Pydantic v2 configuration (the class-based ``Config`` is deprecated in v2,
    # which this module already relies on via ``model_dump``).
    model_config = {"extra": "allow"}  # Allow extra fields
26
+
27
+
28
class RewardRequest(BaseModel):
    """Request model for reward endpoints.

    ``messages`` is required. ``ground_truth`` is optional and may be either a
    plain string or a list of messages. Any other top-level keys in the request
    body are accepted and later forwarded to the reward function as keyword
    arguments by the ``/reward`` handlers.
    """

    messages: List[Message] = Field(..., description="List of conversation messages")
    ground_truth: Optional[Union[str, List[Message]]] = Field(
        None, description="Ground truth data (string or list of messages) for context"
    )

    # Pydantic v2 configuration (the class-based ``Config`` is deprecated in v2,
    # which this module already relies on via ``model_dump``).
    model_config = {"extra": "allow"}  # Allow extra fields for arbitrary kwargs
38
+
39
+
40
class RewardServer:
    """
    Server for hosting reward functions.

    This class creates a FastAPI server that can host a reward function. The
    function is imported from ``func_path`` when the server is constructed and
    exposed through ``POST /reward``; ``GET /`` and ``GET /health`` report status.

    Args:
        func_path: Path to the reward function to host (e.g., "module.path:function_name")
        host: Host to bind the server to
        port: Port to bind the server to
    """

    def __init__(
        self,
        func_path: str,
        host: str = "0.0.0.0",
        port: int = 8000,
    ):
        self.func_path = func_path
        self.host = host
        self.port = port
        self.app = FastAPI(title="Reward Function Server")

        # Load the reward function eagerly so a bad func_path fails at
        # construction time rather than on the first request.
        self.reward_func = self._load_function()

        # Register the endpoints
        self._setup_routes()

    def _load_function(self):
        """Load the reward function from ``self.func_path``.

        Returns:
            The attribute named after the colon, looked up on the imported module.

        Raises:
            ValueError: If ``func_path`` does not contain a ``:`` separator.
            ImportError: If the module cannot be imported or does not define the
                requested attribute. The underlying error is chained as the cause.
        """
        if ":" not in self.func_path:
            raise ValueError(f"Invalid func_path format: {self.func_path}, expected 'module.path:function_name'")

        # Split on the first colon only, so everything after it is the attribute name.
        module_path, func_name = self.func_path.split(":", 1)
        try:
            module = importlib.import_module(module_path)
            func = getattr(module, func_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Failed to load function from path {self.func_path}: {str(e)}") from e

        logger.info(f"Loaded reward function {func_name} from {module_path}")
        return func

    def _setup_routes(self):
        """Set up the API routes on ``self.app``."""
        # Local import: only needed to detect coroutine results from async reward functions.
        import inspect

        @self.app.get("/")
        async def root():
            """Get server info."""
            return {
                "status": "ok",
                "reward_function": self.func_path,
                "endpoints": ["/reward"],
            }

        @self.app.post("/reward")
        async def reward(request: RewardRequest):
            """
            Get reward score for messages.

            Args:
                request: RewardRequest object with messages and optional parameters

            Returns:
                A dictionary with at least a ``score`` key (and metrics when provided).

            Raises:
                HTTPException: With status 500 if the reward function fails or
                    returns an unsupported type.
            """
            try:
                # Hand the reward function plain dicts, not this module's request
                # models: the request ``Message`` class is local to the server and
                # is not the type reward functions are written against.
                messages = [msg.model_dump() for msg in request.messages]

                ground_truth_data = request.ground_truth
                if isinstance(ground_truth_data, list):
                    ground_truth_data = [msg.model_dump() for msg in ground_truth_data]
                elif ground_truth_data is None:
                    # This default applies if ground_truth is expected to be a list of messages for context
                    ground_truth_data = messages[:-1] if messages else []

                # Every extra request field is forwarded as a keyword argument.
                kwargs = request.model_dump(exclude={"messages", "ground_truth"})

                # Call the reward function
                result = self.reward_func(
                    messages=messages,
                    ground_truth=ground_truth_data,
                    **kwargs,
                )
                # Async reward functions hand back an awaitable; resolve it here.
                if inspect.isawaitable(result):
                    result = await result

                # Handle different return types
                if isinstance(result, dict) and "score" in result:
                    return result
                elif isinstance(result, EvaluateResult):
                    logger.warning("Reward function returned EvaluateResult object directly to server; expected dict.")
                    return result.model_dump()
                elif isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
                    logger.warning("Reward function returned legacy tuple format to server.")
                    score, components = result
                    return {"score": score, "metrics": components}
                else:
                    raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")

            except Exception as e:
                # Boundary handler: log with traceback, then surface as HTTP 500.
                logger.exception(f"Error processing reward request: {str(e)}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.get("/health")
        async def health():
            """Health check endpoint."""
            return {"status": "ok"}

    def run(self):
        """Run the server with uvicorn on the configured host and port (blocking)."""
        logger.info(f"Starting reward server on {self.host}:{self.port}")
        uvicorn.run(self.app, host=self.host, port=self.port)
152
+
153
+
154
def serve(func_path: str, host: str = "0.0.0.0", port: int = 8000):
    """
    Host a reward function over HTTP.

    Builds a ``RewardServer`` for the given function path and runs it
    immediately; the call returns only when the server stops.

    Args:
        func_path: Path to the reward function to serve (e.g., "module.path:function_name")
        host: Host to bind the server to
        port: Port to bind the server to
    """
    RewardServer(func_path=func_path, host=host, port=port).run()
165
+
166
+
167
+ # ngrok-based serve_tunnel is deprecated in favor of Serveo via subprocess_manager.
168
+ # def serve_tunnel(func_path: str, port: int = 8000):
169
+ # """
170
+ # Serve a reward function with an ngrok tunnel.
171
+ # DEPRECATED.
172
+ # """
173
+ # try:
174
+ # import pyngrok.ngrok as ngrok # type: ignore
175
+ # except ImportError:
176
+ # raise ImportError(
177
+ # "The 'pyngrok' package is required to use serve_tunnel. "
178
+ # "Please install it with 'pip install pyngrok'."
179
+ # )
180
+ #
181
+ # # Open the tunnel
182
+ # tunnel = ngrok.connect(port)
183
+ # public_url = tunnel.public_url
184
+ #
185
+ # # Print the tunnel URL
186
+ # logger.info(f"Reward function available at: {public_url}/reward")
187
+ #
188
+ # # Start the server
189
+ # serve(func_path=func_path, host="0.0.0.0", port=port)
190
+
191
+
192
def create_app(reward_func: Callable[..., EvaluateResult]) -> FastAPI:
    """
    Create a FastAPI app for the given reward function.

    This function creates a FastAPI app that can be used to serve a reward function.
    It's particularly useful for testing or when you want to manage the lifecycle
    of the app yourself. The app exposes ``GET /``, ``POST /reward`` and
    ``GET /health``.

    Args:
        reward_func: The reward function to serve. It is called with
            ``messages`` and ``ground_truth`` keyword arguments plus any extra
            request fields; it may be synchronous or async.

    Returns:
        A FastAPI app instance
    """
    # Local import: only needed to detect coroutine results from async reward functions.
    import inspect

    app = FastAPI(title="Reward Function Server")

    @app.get("/")
    async def root():
        """Get server info."""
        return {"status": "ok", "endpoints": ["/reward"]}

    @app.post("/reward")
    async def reward(request_data: RewardRequest):
        """
        Get reward score for messages.

        Args:
            request_data: RewardRequest object with messages and optional parameters

        Returns:
            A dictionary with at least a ``score`` key (and metrics when provided).

        Raises:
            HTTPException: With status 500 if the reward function fails or
                returns an unsupported type.
        """
        try:
            # Convert Pydantic models to dictionaries using model_dump (Pydantic v2)
            messages = [msg.model_dump() for msg in request_data.messages]

            raw_ground_truth = request_data.ground_truth
            ground_truth_data: Optional[Union[str, List[Dict[str, Any]]]] = None
            if isinstance(raw_ground_truth, str):
                ground_truth_data = raw_ground_truth
            elif isinstance(raw_ground_truth, list):
                ground_truth_data = [msg.model_dump() for msg in raw_ground_truth]

            # Every extra request field is forwarded as a keyword argument.
            kwargs = request_data.model_dump(exclude={"messages", "ground_truth"})

            if ground_truth_data is None:
                # This default applies if ground_truth is expected to be a list of messages for context
                ground_truth_data = messages[:-1] if messages else []

            # Call the reward function
            result = reward_func(messages=messages, ground_truth=ground_truth_data, **kwargs)
            # Async reward functions hand back an awaitable; resolve it here.
            if inspect.isawaitable(result):
                result = await result

            # Handle different return types
            if isinstance(result, dict) and "score" in result:
                return result
            elif isinstance(result, EvaluateResult):
                logger.warning(
                    "Reward function passed to create_app returned EvaluateResult object directly; expected dict after decoration."
                )
                return result.model_dump()
            elif isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
                logger.warning("Reward function passed to create_app returned legacy tuple format.")
                score, components = result
                return {"score": score, "metrics": components}
            else:
                raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")

        except Exception as e:
            # Boundary handler: log with traceback, then surface as HTTP 500.
            logger.exception(f"Error processing reward request: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/health")
    async def health():
        """Health check endpoint."""
        return {"status": "ok"}

    return app
@@ -0,0 +1,260 @@
1
+ import inspect
2
+ from functools import wraps
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ List,
8
+ Literal,
9
+ Optional,
10
+ Protocol,
11
+ TypeVar,
12
+ Union,
13
+ cast,
14
+ get_args,
15
+ get_origin,
16
+ )
17
+
18
+ from pydantic import TypeAdapter, ValidationError
19
+
20
+ # EvaluateResult and StepOutput are now extended/defined in models.py
21
+ from .models import ( # Removed StepOutput as it's not used here directly
22
+ EvaluateResult,
23
+ Message,
24
+ )
25
+
26
+ # Import resource types
27
+ from .resources import ResourceDict
28
+
29
+ _single_res_adapter = TypeAdapter(EvaluateResult)
30
+ _list_res_adapter = TypeAdapter(List[EvaluateResult])
31
+
32
+ # Define a type for the mode parameter
33
+ EvaluationMode = Literal["pointwise", "batch"]
34
+
35
+ # TypeVar for the function being decorated, to preserve its signature as much as possible.
36
+ F = TypeVar("F", bound=Callable[..., Any])
37
+
38
+
39
def reward_function(
    _func: Optional[F] = None,
    *,
    mode: EvaluationMode = "pointwise",
    id: Optional[str] = None,
    requirements: Optional[List[str]] = None,
    resources: Optional[ResourceDict] = None,  # Resource management
    concurrency: Optional[int] = None,
    timeout: Optional[int] = None,
) -> Union[F, Callable[[F], F]]:
    """
    Decorator for user-defined reward and evaluation functions with resource management.

    It handles:
    - Coercing input messages (and ground truths if applicable) to Pydantic `Message` objects
      if the decorated function is type-hinted to receive them. This part currently targets
      parameters named 'messages' (pointwise), 'rollouts_messages' (batch) and 'ground_truth'.
    - Validating that the output conforms to `EvaluateResult` (for pointwise) or `List[EvaluateResult]` (for batch).
    - Calling `setup()` on every declared resource once, when the decorator is applied, and
      passing each resource's `get_client()` result to the function as `resources` on every call.
      No cleanup is performed by this decorator.

    Args:
        _func: The user's reward/evaluation function. Optional for decorator usage with args.
        mode: Specifies the operational mode. Defaults to "pointwise".
            - "pointwise": Function processes one rollout. Expected output: `EvaluateResult`.
            - "batch": Function processes a batch of rollouts. Expected output: `List[EvaluateResult]`.
        id: Optional identifier for the reward function, used for deployment
        requirements: Optional list of requirement strings, recorded on the wrapper for deployment
        resources: Optional dictionary of resource types to resource instances.
            Example: {"llms": [llm_resource]}
            The clients are delivered through the function's `resources` parameter when it
            declares one, and through its `**kwargs` otherwise.
        concurrency: Optional number of concurrent requests to the reward function. Only recorded
            on the wrapper here; presumably enforced by the caller for async functions or
            functions bound to async resources (e.g. LLM resource).
        timeout: Optional timeout for the reward function. Only recorded on the wrapper here;
            presumably enforced by the caller for async functions or functions bound to async
            resources (e.g. LLM resource).

    Returns:
        A decorator if `_func` is None, or the decorated function.

    Raises:
        ValueError: At decoration time if the function does not accept `**kwargs`; at call time
            if an input fails coercion or the return value fails validation.
    """

    def decorator(func: F) -> F:
        sig = inspect.signature(func)
        params = sig.parameters

        # Validate that the function accepts **kwargs, and remember that
        # parameter's name so extra values can be routed through it later.
        var_keyword_name = next(
            (name for name, param in params.items() if param.kind == inspect.Parameter.VAR_KEYWORD),
            None,
        )
        if var_keyword_name is None:
            raise ValueError(
                f"Function '{func.__name__}' must accept **kwargs parameter. "
                f"Please add '**kwargs' to the function signature."
            )

        # Setup resources once when the decorator is applied
        resource_managers = {}
        if resources:
            for resource_type, resource_list in resources.items():
                managers = []
                for resource in resource_list:
                    resource.setup()
                    managers.append(resource)
                resource_managers[resource_type] = managers

        # Detect if the user supplied function is a coroutine (async def)
        _is_async_function = inspect.iscoroutinefunction(func)

        def _is_list_of_message(annotation: Any) -> bool:
            """Return True when *annotation* is ``List[Message]`` / ``list[Message]``."""
            type_args = get_args(annotation)
            return get_origin(annotation) in (list, List) and bool(type_args) and type_args[0] == Message

        def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Message]:
            """Convert a list of dicts and/or `Message` objects into a list of `Message`."""
            if not isinstance(data_list, list):
                raise TypeError(f"Expected a list for '{arg_name_for_error}', got {type(data_list)}")
            typed_list = []
            for i, item_data in enumerate(data_list):
                if isinstance(item_data, Message):
                    typed_list.append(item_data)
                elif isinstance(item_data, dict):
                    typed_list.append(Message(**item_data))
                else:
                    raise TypeError(f"Unexpected type for item {i} in '{arg_name_for_error}': {type(item_data)}")
            return typed_list

        def _prepare_final_args(*args: Any, **kwargs: Any):
            """Prepare final positional and keyword arguments for the user function call.
            This includes Pydantic coercion and resource injection. Returns a tuple of
            (call_args, call_kwargs).
            """
            # Bind arguments to handle *args and **kwargs correctly for the wrapped function
            bound_args = sig.bind_partial(*args, **kwargs)
            bound_args.apply_defaults()

            # Create a mutable copy of arguments to modify
            final_func_args = dict(bound_args.arguments)

            # 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
            if mode == "pointwise" and "messages" in params and "messages" in final_func_args:
                if _is_list_of_message(params["messages"].annotation):
                    try:
                        final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
                    except Exception as err:
                        raise ValueError(f"Input 'messages' failed Pydantic validation: {err}") from None

            elif mode == "batch" and "rollouts_messages" in params and "rollouts_messages" in final_func_args:
                param_annotation = params["rollouts_messages"].annotation
                inner = get_args(param_annotation)[0] if get_args(param_annotation) else None
                # Only coerce when the hint is List[List[Message]].
                if get_origin(param_annotation) == list and inner and get_origin(inner) == list:
                    if get_args(inner) and get_args(inner)[0] == Message:
                        try:
                            final_func_args["rollouts_messages"] = [
                                _coerce_to_list_message(rollout_data, f"rollouts_messages[{i}]")
                                for i, rollout_data in enumerate(final_func_args["rollouts_messages"])
                            ]
                        except Exception as err:
                            raise ValueError(f"Input 'rollouts_messages' failed Pydantic validation: {err}") from None

            # Ground truth coercion (if needed)
            if "ground_truth" in params and "ground_truth" in final_func_args:
                if _is_list_of_message(params["ground_truth"].annotation):
                    if final_func_args["ground_truth"] is not None:
                        try:
                            final_func_args["ground_truth"] = _coerce_to_list_message(
                                final_func_args["ground_truth"], "ground_truth"
                            )
                        except Exception as err:
                            raise ValueError(
                                f"Input 'ground_truth' failed Pydantic validation for List[Message]: {err}"
                            ) from None

            # Inject resource clients (resources are already setup)
            if resource_managers:
                resource_clients = {
                    resource_type: [manager.get_client() for manager in managers]
                    for resource_type, managers in resource_managers.items()
                }
                resources_param = params.get("resources")
                if resources_param is not None and resources_param.kind not in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    # The function names 'resources' explicitly: fill that parameter.
                    final_func_args["resources"] = resource_clients
                else:
                    # No explicit 'resources' parameter: the reconstruction loop below
                    # only forwards names found in the signature, so the clients must
                    # travel inside the **kwargs dict or they would be dropped.
                    extra_kwargs = dict(final_func_args.get(var_keyword_name, {}))
                    extra_kwargs["resources"] = resource_clients
                    final_func_args[var_keyword_name] = extra_kwargs

            # Reconstruct args and kwargs for the call to func from the
            # (potentially modified) arguments dictionary.
            call_args: List[Any] = []
            call_kwargs: Dict[str, Any] = {}
            for p_name, p_obj in params.items():
                if p_obj.kind == inspect.Parameter.VAR_POSITIONAL:
                    # If original func had *pos_args, final_func_args might contain it as a tuple
                    call_args.extend(final_func_args.get(p_name, ()))
                elif p_obj.kind == inspect.Parameter.VAR_KEYWORD:  # **kwargs
                    # If original func had **kw_args, final_func_args contains the dict of these
                    call_kwargs.update(final_func_args.get(p_name, {}))
                elif p_name in final_func_args:  # Named parameters
                    if p_obj.kind == inspect.Parameter.POSITIONAL_ONLY:
                        call_args.append(final_func_args[p_name])
                    else:  # POSITIONAL_OR_KEYWORD, KEYWORD_ONLY
                        call_kwargs[p_name] = final_func_args[p_name]

            return call_args, call_kwargs

        def _validate_output(result: Any):
            """Return *result* as `EvaluateResult` (pointwise) or `List[EvaluateResult]` (batch)."""
            if mode == "pointwise":
                if isinstance(result, EvaluateResult):
                    return result
                return _single_res_adapter.validate_python(result)
            elif mode == "batch":
                if isinstance(result, list) and all(isinstance(item, EvaluateResult) for item in result):
                    return result
                return _list_res_adapter.validate_python(result)
            else:
                raise ValueError(f"Internal error: Invalid mode '{mode}' in wrapper.")

        if _is_async_function:

            @wraps(func)
            async def async_wrapper(
                *args: Any,
                **kwargs: Any,
            ) -> Union[EvaluateResult, List[EvaluateResult]]:
                call_args, call_kwargs = _prepare_final_args(*args, **kwargs)
                result = await func(*call_args, **call_kwargs)  # type: ignore[misc]
                try:
                    return _validate_output(result)
                except ValidationError as err:
                    raise ValueError(
                        f"Return value from function '{func.__name__}' failed Pydantic validation for mode '{mode}':\n{err}"
                    ) from None

            wrapper_fn = async_wrapper

        else:

            @wraps(func)
            def sync_wrapper(
                *args: Any,
                **kwargs: Any,
            ) -> Union[EvaluateResult, List[EvaluateResult]]:
                call_args, call_kwargs = _prepare_final_args(*args, **kwargs)
                result = func(*call_args, **call_kwargs)
                try:
                    return _validate_output(result)
                except ValidationError as err:
                    raise ValueError(
                        f"Return value from function '{func.__name__}' failed Pydantic validation for mode '{mode}':\n{err}"
                    ) from None

            wrapper_fn = sync_wrapper

        # Set attributes for introspection and deployment
        wrapper_fn._reward_function_id = id  # type: ignore[attr-defined]
        wrapper_fn._reward_function_requirements = requirements  # type: ignore[attr-defined]
        wrapper_fn._reward_function_mode = mode  # type: ignore[attr-defined]
        wrapper_fn._reward_function_resources = resources  # type: ignore[attr-defined]
        wrapper_fn._reward_function_timeout = timeout  # type: ignore[attr-defined]
        wrapper_fn._reward_function_concurrency = concurrency  # type: ignore[attr-defined]

        return cast(F, wrapper_fn)

    if _func is None:  # Decorator called with arguments, e.g., @reward_function(mode="batch")
        return decorator
    else:  # Decorator called without arguments, e.g., @reward_function (defaults to pointwise)
        return decorator(_func)
@@ -0,0 +1,8 @@
1
+ # This file makes the 'utils' directory a Python package.
2
+
3
+ # You can selectively expose functions or classes from modules within 'utils' here
4
+ # for easier access, e.g.:
5
+ # from .dataset_helpers import load_jsonl_to_hf_dataset
6
+
7
+ # For now, allow direct import of modules like:
8
+ # from eval_protocol.utils.dataset_helpers import ...