reasonbench 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. reasonbench/__init__.py +126 -0
  2. reasonbench/datasets.py +53 -0
  3. reasonbench/methods/__init__.py +9 -0
  4. reasonbench/methods/cot.py +55 -0
  5. reasonbench/methods/cot_sc.py +54 -0
  6. reasonbench/methods/foa.py +127 -0
  7. reasonbench/methods/got.py +123 -0
  8. reasonbench/methods/io.py +63 -0
  9. reasonbench/methods/rap.py +214 -0
  10. reasonbench/methods/react.py +54 -0
  11. reasonbench/methods/tot_bfs.py +93 -0
  12. reasonbench/methods/tot_dfs.py +169 -0
  13. reasonbench/models/__init__.py +4 -0
  14. reasonbench/models/anthropic.py +58 -0
  15. reasonbench/models/api.py +176 -0
  16. reasonbench/models/online.py +228 -0
  17. reasonbench/models/qroq.py +132 -0
  18. reasonbench/tasks/__init__.py +8 -0
  19. reasonbench/tasks/game24/__init__.py +3 -0
  20. reasonbench/tasks/game24/agents.py +440 -0
  21. reasonbench/tasks/game24/benchmark.py +51 -0
  22. reasonbench/tasks/game24/environment.py +85 -0
  23. reasonbench/tasks/game24/prompts.py +393 -0
  24. reasonbench/tasks/game24/state.py +49 -0
  25. reasonbench/tasks/hle/__init__.py +3 -0
  26. reasonbench/tasks/hle/agents.py +332 -0
  27. reasonbench/tasks/hle/benchmark.py +128 -0
  28. reasonbench/tasks/hle/environment.py +188 -0
  29. reasonbench/tasks/hle/prompts.py +200 -0
  30. reasonbench/tasks/hle/state.py +80 -0
  31. reasonbench/tasks/hotpotqa/__init__.py +3 -0
  32. reasonbench/tasks/hotpotqa/agents.py +411 -0
  33. reasonbench/tasks/hotpotqa/benchmark.py +104 -0
  34. reasonbench/tasks/hotpotqa/environment.py +123 -0
  35. reasonbench/tasks/hotpotqa/prompts.py +489 -0
  36. reasonbench/tasks/hotpotqa/state.py +61 -0
  37. reasonbench/tasks/humaneval/__init__.py +3 -0
  38. reasonbench/tasks/humaneval/agents.py +361 -0
  39. reasonbench/tasks/humaneval/benchmark.py +53 -0
  40. reasonbench/tasks/humaneval/environment.py +503 -0
  41. reasonbench/tasks/humaneval/prompts.py +152 -0
  42. reasonbench/tasks/humaneval/state.py +62 -0
  43. reasonbench/tasks/logiqa/__init__.py +3 -0
  44. reasonbench/tasks/logiqa/agents.py +258 -0
  45. reasonbench/tasks/logiqa/benchmark.py +81 -0
  46. reasonbench/tasks/logiqa/environment.py +78 -0
  47. reasonbench/tasks/logiqa/prompts.py +211 -0
  48. reasonbench/tasks/logiqa/state.py +66 -0
  49. reasonbench/tasks/matharena/__init__.py +0 -0
  50. reasonbench/tasks/matharena/agents.py +154 -0
  51. reasonbench/tasks/matharena/benchmark.py +105 -0
  52. reasonbench/tasks/matharena/environment.py +110 -0
  53. reasonbench/tasks/matharena/prompts.py +79 -0
  54. reasonbench/tasks/matharena/state.py +66 -0
  55. reasonbench/tasks/scibench/__init__.py +3 -0
  56. reasonbench/tasks/scibench/agents.py +460 -0
  57. reasonbench/tasks/scibench/benchmark.py +89 -0
  58. reasonbench/tasks/scibench/environment.py +103 -0
  59. reasonbench/tasks/scibench/prompts.py +242 -0
  60. reasonbench/tasks/scibench/state.py +61 -0
  61. reasonbench/tasks/sonnetwriting/__init__.py +3 -0
  62. reasonbench/tasks/sonnetwriting/agents.py +338 -0
  63. reasonbench/tasks/sonnetwriting/benchmark.py +92 -0
  64. reasonbench/tasks/sonnetwriting/environment.py +302 -0
  65. reasonbench/tasks/sonnetwriting/prompts.py +382 -0
  66. reasonbench/tasks/sonnetwriting/state.py +52 -0
  67. reasonbench/typedefs.py +154 -0
  68. reasonbench/utils.py +430 -0
  69. reasonbench-0.0.1.dist-info/METADATA +242 -0
  70. reasonbench-0.0.1.dist-info/RECORD +73 -0
  71. reasonbench-0.0.1.dist-info/WHEEL +5 -0
  72. reasonbench-0.0.1.dist-info/licenses/LICENSE +21 -0
  73. reasonbench-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,126 @@
1
+ from typing import TypedDict
2
+ from .typedefs import DecodingParameters
3
+ from .datasets import get_dataset_path
4
+
5
class BenchmarkFactory:
    """Registry mapping task names to benchmark classes.

    Classes are stored under their lower-cased class name, so a class named
    ``BenchmarkGame24`` is retrieved with ``get("game24")``.
    """

    registry = {}

    @classmethod
    def register(cls, benchmark_cls):
        """Class decorator: record *benchmark_cls* under its lower-cased name."""
        cls.registry[benchmark_cls.__name__.lower()] = benchmark_cls
        return benchmark_cls

    @classmethod
    def get(cls, task: str, *args, **kwargs):
        """Instantiate the benchmark registered for *task*.

        The path returned by ``get_dataset_path(task)`` is passed as the
        ``path`` keyword; all other arguments are forwarded to the constructor.

        Raises:
            ValueError: if no benchmark is registered for *task*, or (raised by
                ``get_dataset_path``) if *task* has no known dataset file.
        """
        key = f"benchmark{task}".lower()
        # Resolve the class before touching the dataset: an unregistered task
        # fails fast instead of first triggering a download, and a KeyError
        # raised inside the constructor is no longer misreported as a missing
        # benchmark.
        try:
            benchmark_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No benchmark found for task={task}") from None
        path = get_dataset_path(task)
        return benchmark_cls(*args, path=path, **kwargs)
21
+
22
class EnvironmentFactory:
    """Registry mapping task names to environment classes.

    Classes are stored under their lower-cased class name, so a class named
    ``EnvironmentGame24`` is retrieved with ``get("game24")``.
    """

    registry = {}

    @classmethod
    def register(cls, env_cls):
        """Class decorator: record *env_cls* under its lower-cased name."""
        cls.registry[env_cls.__name__.lower()] = env_cls
        return env_cls

    @classmethod
    def get(cls, task: str, *args, **kwargs):
        """Instantiate the environment registered for *task* with the given arguments.

        Raises:
            ValueError: if no environment is registered for *task*.
        """
        key = f"environment{task}".lower()
        try:
            env_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No environment found for task={task}") from None
        # Constructed outside the try so errors raised by the constructor
        # (including KeyError) surface unchanged.
        return env_cls(*args, **kwargs)
37
+
38
class AgentFactory:
    """Registry mapping (agent type, benchmark) pairs to agent classes.

    Classes are stored under their lower-cased class name, so a class named
    ``AgentActGame24`` is retrieved with ``get("act", "game24")``.
    """

    registry = {}

    @classmethod
    def register(cls, agent_cls):
        """Class decorator: record *agent_cls* under its lower-cased name."""
        cls.registry[agent_cls.__name__.lower()] = agent_cls
        return agent_cls

    @classmethod
    def get(cls, agent_type: str, benchmark: str, *args, **kwargs):
        """Return the agent *class* for *agent_type* on *benchmark*.

        The class is returned without being instantiated; *args* and *kwargs*
        are accepted but currently unused.

        Raises:
            ValueError: if no matching agent class is registered.
        """
        key = f"agent{agent_type}{benchmark}".lower()
        try:
            return cls.registry[key]
        except KeyError:
            raise ValueError(f"No agent found for type={agent_type}, benchmark={benchmark}") from None
53
+
54
class AgentDictFactory:
    """Registry mapping method names to agent-dict classes.

    Classes are stored under their lower-cased class name, so a class named
    ``AgentDictFOA`` is retrieved with ``get("foa")``.
    """

    registry = {}

    @classmethod
    def register(cls, agent_dict_cls):
        """Class decorator: record *agent_dict_cls* under its lower-cased name."""
        cls.registry[agent_dict_cls.__name__.lower()] = agent_dict_cls
        return agent_dict_cls

    @classmethod
    def get(cls, method: str, *args, **kwargs):
        """Instantiate the agent dict registered for *method* with the given arguments.

        Raises:
            ValueError: if no agent dict is registered for *method*.
        """
        key = f"agentdict{method}".lower()
        try:
            agent_dict_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No agent dict found for method={method}") from None
        # Constructed outside the try so constructor errors surface unchanged.
        return agent_dict_cls(*args, **kwargs)
69
+
70
+
71
class MethodFactory:
    """Registry mapping method names to method classes.

    ``get`` also assembles the ``agents`` dict each method expects, fetching
    agent classes from ``AgentFactory`` for the requested benchmark.
    """

    registry = {}

    # Per method: role in the agents dict -> agent type passed to
    # AgentFactory.get. Insertion order fixes both the order of the lookups
    # and the key order of the resulting agents dict.
    _AGENT_ROLES = {
        "io": {"step": "io"},
        "cot": {"step": "cot"},
        "cot_sc": {"step": "cot"},
        "foa": {"step": "act", "evaluate": "evaluate"},
        "tot_bfs": {"step": "bfs", "evaluate": "evaluate"},
        "tot_dfs": {"step": "bfs", "evaluate": "evaluate"},
        "got": {"step": "act", "aggregate": "aggregate", "evaluate": "evaluate"},
        "rap": {"step": "react", "evaluate": "selfevaluate"},
        "react": {"step": "react"},
    }

    @classmethod
    def register(cls, method_cls):
        """Class decorator: record *method_cls* under its lower-cased name."""
        cls.registry[method_cls.__name__.lower()] = method_cls
        return method_cls

    @classmethod
    def get(cls, method: str, benchmark: str, params: "DecodingParameters", *args, **kwargs):
        """Build the agents dict for *method* on *benchmark* and instantiate the method.

        Every agent role gets a matching ``<role>_params`` entry set to *params*.
        *args* and *kwargs* are forwarded to the method constructor together
        with ``agents=``.

        Raises:
            NotImplementedError: if *method* has no agent layout defined.
            ValueError: if no method class is registered for *method*, or
                (from ``AgentFactory.get``) if a required agent is missing.
        """
        key = f"method{method}".lower()

        roles = cls._AGENT_ROLES.get(method)
        if roles is None:
            raise NotImplementedError(f"Method {method} is not implemented yet.")

        agents = {role: AgentFactory.get(agent_type, benchmark) for role, agent_type in roles.items()}
        # For the moment, only supporting same params for all agents
        agents.update({f"{role}_params": params for role in roles})

        try:
            method_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No method found for name={method}") from None
        # Constructed outside the try so a KeyError raised in __init__ (e.g. a
        # missing config key) is not misreported as an unknown method.
        return method_cls(*args, agents=agents, **kwargs)
@@ -0,0 +1,53 @@
1
+ """Utility for downloading datasets from HuggingFace Hub."""
2
+
3
+ import os
4
+ import logging
5
+ from pathlib import Path
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ HF_REPO_ID = "potamitisn/ReasonBench"
11
+ LOCAL_DATASETS_DIR = Path("datasets")
12
+
13
+ # Maps task name to the actual filename on HF Hub
14
+ DATASET_FILES = {
15
+ "game24": "dataset_game24.csv.gz",
16
+ "hle": "dataset_hle.jsonl.gz",
17
+ "hotpotqa": "dataset_hotpotqa.csv.gz",
18
+ "humaneval": "dataset_humaneval.csv.gz",
19
+ "logiqa": "dataset_logiqa.csv.gz",
20
+ "matharena": "dataset_matharena.jsonl.gz",
21
+ "scibench": "dataset_scibench.csv.gz",
22
+ "sonnetwriting": "dataset_sonnetwriting.jsonl.gz",
23
+ }
24
+
25
+
26
+ def get_dataset_path(task: str, local_dir: str | Path = LOCAL_DATASETS_DIR) -> str:
27
+ """Return the local path to a dataset file, downloading from HF Hub if needed.
28
+
29
+ Args:
30
+ task: Task name (e.g., "game24", "hle").
31
+ local_dir: Local directory to store/look for datasets.
32
+
33
+ Returns:
34
+ Path to the local dataset file.
35
+ """
36
+ if task not in DATASET_FILES:
37
+ raise ValueError(f"Unknown task: {task}. Available: {list(DATASET_FILES.keys())}")
38
+
39
+ filename = DATASET_FILES[task]
40
+ local_path = Path(local_dir) / filename
41
+
42
+ if local_path.exists():
43
+ return str(local_path)
44
+
45
+ logger.info(f"Dataset for '{task}' not found locally. Downloading from HuggingFace Hub...")
46
+ downloaded = hf_hub_download(
47
+ repo_id=HF_REPO_ID,
48
+ filename=filename,
49
+ repo_type="dataset",
50
+ local_dir=str(local_dir),
51
+ )
52
+ logger.info(f"Downloaded to {downloaded}")
53
+ return str(downloaded)
@@ -0,0 +1,9 @@
1
+ from .foa import AgentDictFOA, MethodFOA
2
+ from .tot_bfs import AgentDictTOT, MethodTOT_BFS
3
+ from .tot_dfs import AgentDictTOT, MethodTOT_DFS
4
+ from .got import AgentDictGOT, MethodGOT
5
+ from .rap import MethodRAP, AgentDictRAP
6
+ from .react import AgentDictReact, MethodReact
7
+ from .io import AgentDictIO, MethodIO
8
+ from .cot import AgentDictCOT, MethodCOT
9
+ from .cot_sc import AgentDictCoT, MethodCOT_SC
@@ -0,0 +1,55 @@
1
+ import random
2
+ import logging
3
+ import asyncio
4
+ from typing import TypedDict
5
+ from omegaconf import OmegaConf
6
+ from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
7
+ from .. import MethodFactory, AgentDictFactory
8
+ from ..utils import Resampler
9
+ logger = logging.getLogger(__name__)
10
+
11
@AgentDictFactory.register
class AgentDictCOT(TypedDict):
    """Agents consumed by ``MethodCOT``: one step agent plus its decoding parameters."""

    step: Agent  # ActAgent
    step_params: DecodingParameters
15
+
16
+
17
@MethodFactory.register
class MethodCOT(Method):
    """Chain-of-thought: one step-agent call per cloned state, applied once.

    ``config.n`` must equal 1.
    """

    def __init__(self, model: Model, agents: AgentDictCOT, env: Environment, config: OmegaConf, n: int = 1):
        super().__init__(model, agents, env, config)

        self.step_agent, self.step_params = agents["step"], agents["step_params"]

        # NOTE(review): the ``n`` argument is unused; the count comes from config.n.
        self.n = config.n
        assert self.n == 1, "CoT has only 1 output"

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Clone *state* ``self.n`` times (seeded by *idx*) and apply one generated action to each clone."""
        random.seed(idx)
        clones = [state.clone(randomness=random.randint(0, MAX_SEED)) for _ in range(self.n)]

        results = await asyncio.gather(*(
            self.step_agent.act(
                model=self.model,
                state=clone,
                n=1,
                namespace=namespace,
                request_id=f"idx{idx}-{hash(clone)}-agent{j}",
                params=self.step_params,
            )
            for j, clone in enumerate(clones)
        ))

        # Each result is a list of actions; only the first is executed.
        return [self.env.step(clone, result[0]) for clone, result in zip(clones, results)]
@@ -0,0 +1,54 @@
1
+ import random
2
+ import logging
3
+ import asyncio
4
+ from typing import TypedDict
5
+ from collections import Counter
6
+ from omegaconf import OmegaConf
7
+ from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
8
+ from .. import MethodFactory, AgentDictFactory
9
+ from ..utils import Resampler
10
+ logger = logging.getLogger(__name__)
11
+
12
@AgentDictFactory.register
class AgentDictCoT(TypedDict):
    """Agents consumed by ``MethodCOT_SC``: one step agent plus its decoding parameters."""

    step: Agent  # ActAgent
    step_params: DecodingParameters
16
+
17
+
18
@MethodFactory.register
class MethodCOT_SC(Method):
    """Chain-of-thought with self-consistency.

    Samples ``config.n`` (> 1) actions from the step agent in a single call and
    applies the most frequent one to the input state.
    """

    def __init__(self,
                 model: Model,
                 agents: AgentDictCoT,
                 env: Environment,
                 config: OmegaConf,
                 ):
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.step_params = agents["step_params"]

        self.n = config.n
        assert self.n > 1, "CoT-SC needs at least 2 outputs"

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Return a one-element list holding *state* advanced by the majority-vote action.

        If the step agent returns no actions there is nothing to vote on. In that
        case the input state is returned unchanged and a warning is logged,
        rather than failing with ``IndexError``.
        """
        random.seed(idx)

        actions = await self.step_agent.act(
            model=self.model,
            state=state,
            n=self.n,
            namespace=namespace,
            request_id=f"idx{idx}-{hash(state)}",
            params=self.step_params,
        )

        if not actions:
            logger.warning(f"No actions returned for idx {idx}; keeping the input state.")
            return [state]

        # Majority vote; Counter.most_common orders equal counts by first occurrence.
        most_common_action, _ = Counter(actions).most_common(1)[0]
        return [self.env.step(state, most_common_action)]
53
+
54
+
@@ -0,0 +1,127 @@
1
+ import random
2
+ import logging
3
+ import asyncio
4
+ from typing import TypedDict
5
+ from omegaconf import OmegaConf
6
+ from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
7
+ from .. import MethodFactory, AgentDictFactory
8
+ from ..utils import Resampler
9
+ logger = logging.getLogger(__name__)
10
+
11
@AgentDictFactory.register
class AgentDictFOA(TypedDict):
    """Agents consumed by ``MethodFOA``.

    ``step`` proposes actions and ``evaluate`` scores states. Each ``*_params``
    entry holds the decoding parameters for the agent of the same name.
    """

    step: Agent  # ActAgent
    evaluate: Agent  # EvaluateAgent
    step_params: DecodingParameters
    evaluate_params: DecodingParameters
17
+
18
@MethodFactory.register
class MethodFOA(Method):
    """FOA search: a population of ``num_agents`` states stepped in parallel.

    Failed (final) states are replaced by resampling from a record of visited
    states. Every ``k`` steps the population is evaluated and fully resampled.

    Config fields read:
        num_agents, num_steps: population size and step budget.
        k: evaluation period; a falsy value disables evaluation.
        backtrack: multiplier applied to every recorded value each step.
        resampling: strategy passed to ``Resampler.resample``.
        origin: value of the initial state's record.
        min_steps: records whose state cannot reach this depth in the
            remaining budget are dropped.
        num_evaluations: ``n`` passed to the evaluation agent.
    """

    def __init__(self,
                 agents: AgentDictFOA,
                 model: Model,
                 env: Environment,
                 config: OmegaConf
                 ):
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.eval_agent = agents["evaluate"]

        self.step_params = agents["step_params"]
        self.evaluate_params = agents["evaluate_params"]

        self.num_agents = config.num_agents
        self.num_steps = config.num_steps
        self.k = config.k
        self.backtrack = config.backtrack
        self.resampling = config.resampling
        self.origin = config.origin
        self.min_steps = config.min_steps
        self.num_evaluations = config.num_evaluations

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Run the search for problem *idx* and return the final population.

        Returns ``[state]`` if resampling left the population empty.
        """
        seed = idx
        random.seed(seed)
        resampler = Resampler(seed)

        initial_state = state
        # Visited-state records: (identifier, value, state).
        visited = [("INIT", self.origin, initial_state)]

        population = [initial_state.clone(randomness=random.randint(0, MAX_SEED)) for _ in range(self.num_agents)]

        for step in range(self.num_steps):
            proposals = await asyncio.gather(*[
                self.step_agent.act(
                    model=self.model,
                    state=member,
                    n=1,
                    namespace=namespace,
                    request_id=f"idx{idx}-step{step}-{hash(member)}-agent{i}",
                    params=self.step_params,
                )
                for i, member in enumerate(population)
            ])
            population = [self.env.step(member, proposal[0]) for member, proposal in zip(population, proposals)]

            # Stop as soon as any member reaches a solved state.
            if any(self.env.evaluate(member)[1] == 1 for member in population):
                break

            # Decay recorded values, then drop records too shallow to reach
            # min_steps within the remaining budget.
            remaining = self.num_steps - step - 1
            visited = [(ident, val * self.backtrack, rec) for ident, val, rec in visited]
            visited = [record for record in visited if remaining >= self.min_steps - len(record[2].steps)]

            # Replace members that reached a final state.
            failed = [i for i, member in enumerate(population) if self.env.is_final(member)]
            if visited:
                replacements, _ = resampler.resample(visited.copy(), len(failed), self.resampling)
            else:
                fallback = [("", 1, member) for member in population]
                replacements, _ = resampler.resample(fallback, len(failed), resampling_method="linear")
            population = [replacements.pop(0) if i in failed else member for i, member in enumerate(population)]

            # Periodic evaluation is skipped on the last step.
            if not (step < self.num_steps - 1 and self.k and step % self.k == 0):
                continue

            values = await asyncio.gather(*[
                self.eval_agent.act(
                    model=self.model,
                    state=member,
                    n=self.num_evaluations,
                    namespace=namespace,
                    request_id=f"idx{idx}-evaluation{step}-{hash(member)}-agent{i}",
                    params=self.evaluate_params,
                    cache=value_cache,
                )
                for i, member in enumerate(population)
            ])

            # Record freshly evaluated members; replacements are already recorded.
            visited.extend(
                (f"{i}.{step}", val, member)
                for i, (member, val) in enumerate(zip(population, values))
                if i not in failed
            )

            population, _ = resampler.resample(visited, self.num_agents, self.resampling)

        return [initial_state] if len(population) == 0 else population
127
+
@@ -0,0 +1,123 @@
1
+ import random
2
+ import asyncio
3
+ import logging
4
+ from typing import TypedDict
5
+ from omegaconf import OmegaConf
6
+ from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
7
+ from .. import MethodFactory, AgentDictFactory
8
+ logger = logging.getLogger(__name__)
9
+
10
@AgentDictFactory.register
class AgentDictGOT(TypedDict):
    """Agents consumed by ``MethodGOT``.

    - ``step`` generates candidate actions.
    - ``aggregate`` picks which candidates to execute.
    - ``evaluate`` scores the resulting states.

    Each ``*_params`` entry holds the decoding parameters for the agent of the
    same name.
    """

    step: Agent
    aggregate: Agent
    evaluate: Agent
    step_params: DecodingParameters
    aggregate_params: DecodingParameters
    evaluate_params: DecodingParameters
18
+
19
@MethodFactory.register
class MethodGOT(Method):
    """Generate / aggregate / evaluate search that keeps the best states each step.

    Each step runs four stages:

    1. Every kept state gets ``num_generate`` candidate actions.
    2. The aggregate agent selects up to ``num_selections`` of them.
    3. The selected actions are executed.
    4. The ``num_best`` highest-valued resulting states are kept.

    Config fields read: num_selections, num_steps, num_generate, num_best,
    num_evaluations.
    """

    def __init__(self,
                 model: Model,
                 agents: AgentDictGOT,
                 env: Environment,
                 config: OmegaConf
                 ):
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.aggregate_agent = agents["aggregate"]
        self.eval_agent = agents["evaluate"]

        self.step_params = agents["step_params"]
        self.aggregate_params = agents["aggregate_params"]
        self.evaluate_params = agents["evaluate_params"]

        self.num_selections = config.num_selections
        self.num_steps = config.num_steps
        self.num_generate = config.num_generate
        self.num_best = config.num_best
        self.num_evaluations = config.num_evaluations

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Run the search for problem *idx* and return the states kept after the last step.

        When a step produces a solved state, that step still finishes its
        evaluation/selection, and the loop stops at the start of the next step.
        """
        randomness = idx
        random.seed(randomness)
        states = [state.clone(randomness=random.randint(0, MAX_SEED))]
        logger.debug(f"Solving game: {idx}")

        solved = False
        for step in range(self.num_steps):
            if solved:
                logger.debug(f"Task {idx} solved at step {step - 1}.")
                break

            logger.debug(f"Step: {step} ({idx})")
            # Generate candidate actions for each state
            action_coroutines = [
                self.step_agent.act(
                    model=self.model,
                    state=parent,
                    n=self.num_generate,
                    namespace=namespace,
                    request_id=f"idx{idx}-step{step}-{hash(parent)}-agent{i}",
                    params=self.step_params,
                )
                for i, parent in enumerate(states)
            ]
            generated_actions = await asyncio.gather(*action_coroutines)
            logger.debug(f"{len(generated_actions)} Actions generated for task {idx}; \n {generated_actions}")

            # Aggregate: select which candidates to execute for each state
            aggregate_coroutines = [
                self.aggregate_agent.act(
                    model=self.model,
                    state=parent,
                    actions=candidates,
                    k=self.num_selections,
                    namespace=namespace,
                    request_id=f"idx{idx}-aggregate{step}-{hash(parent)}-agent{i}",
                    params=self.aggregate_params,
                )
                for i, (parent, candidates) in enumerate(zip(states, generated_actions))
            ]
            selected_actions = await asyncio.gather(*aggregate_coroutines)
            logger.debug(f"{len(selected_actions)} Actions selected for task {idx}: \n{selected_actions}")

            # Execute actions on environment
            proposed_states = [
                self.env.step(parent, action)
                for parent, chosen in zip(states, selected_actions)
                for action in chosen
            ]

            if not proposed_states:
                break

            # Early stop: check the states produced in THIS step. Checking the
            # parent states detected a solution one step late, after it had
            # already been stepped past and possibly dropped by the selection.
            if any(self.env.evaluate(candidate)[1] == 1 for candidate in proposed_states):
                solved = True

            logger.debug(f"Env step for task {idx}: \n{proposed_states}")
            # Evaluate all proposals
            value_coroutines = [
                self.eval_agent.act(
                    model=self.model,
                    state=candidate,
                    n=self.num_evaluations,
                    namespace=namespace,
                    request_id=f"idx{idx}-evaluation{step}-{hash(candidate)}-agent{i}",
                    params=self.evaluate_params,
                    cache=value_cache
                )
                for i, candidate in enumerate(proposed_states)
            ]
            values = await asyncio.gather(*value_coroutines)
            logger.debug(f"Values given for task {idx}: \n{values}")

            # Keep the num_best states by value (stable sort: ties keep proposal order)
            ranked = sorted(zip(proposed_states, values), key=lambda pair: pair[1], reverse=True)
            states = [candidate for candidate, _ in ranked[:self.num_best]]

        return states
@@ -0,0 +1,63 @@
1
+ import random
2
+ import logging
3
+ import asyncio
4
+ from typing import TypedDict
5
+ from omegaconf import OmegaConf
6
+ from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
7
+ from .. import MethodFactory, AgentDictFactory
8
+ from ..utils import Resampler
9
+ logger = logging.getLogger(__name__)
10
+
11
@AgentDictFactory.register
class AgentDictIO(TypedDict):
    """Agents consumed by ``MethodIO``: one step agent plus its decoding parameters."""

    step: Agent  # ActAgent
    step_params: DecodingParameters
15
+
16
+
17
@MethodFactory.register
class MethodIO(Method):
    """Input-output prompting: one step-agent call applied to one cloned state.

    ``config.n`` must equal 1.
    """

    def __init__(self,
                 model: Model,
                 agents: AgentDictIO,
                 env: Environment,
                 config: OmegaConf,
                 ):
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.step_params = agents["step_params"]

        self.n = config.n
        assert self.n == 1, "IO has only 1 output"

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Clone *state* (seeded by *idx*), apply one generated action and return the results.

        A clone whose step raises is returned unchanged. The failure is logged
        through the module logger, with traceback, instead of printed.
        """
        randomness = idx
        random.seed(randomness)

        states = [state.clone(randomness=random.randint(0, MAX_SEED)) for _ in range(self.n)]

        action_coroutines = [
            self.step_agent.act(
                model=self.model,
                state=clone,
                n=1,
                namespace=namespace,
                request_id=f"idx{idx}-{hash(clone)}-agent{i}",
                params=self.step_params
            )
            for i, clone in enumerate(states)
        ]
        actions = await asyncio.gather(*action_coroutines)

        # Best-effort execution: one bad action must not abort the whole run,
        # so a failing step keeps its pre-step state.
        new_states = []
        for clone, action in zip(states, actions):
            try:
                new_states.append(self.env.step(clone, action[0]))
            except Exception as e:
                logger.warning("Step failed: %s", e, exc_info=True)
                new_states.append(clone)

        return new_states