reasonbench 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reasonbench/__init__.py +126 -0
- reasonbench/datasets.py +53 -0
- reasonbench/methods/__init__.py +9 -0
- reasonbench/methods/cot.py +55 -0
- reasonbench/methods/cot_sc.py +54 -0
- reasonbench/methods/foa.py +127 -0
- reasonbench/methods/got.py +123 -0
- reasonbench/methods/io.py +63 -0
- reasonbench/methods/rap.py +214 -0
- reasonbench/methods/react.py +54 -0
- reasonbench/methods/tot_bfs.py +93 -0
- reasonbench/methods/tot_dfs.py +169 -0
- reasonbench/models/__init__.py +4 -0
- reasonbench/models/anthropic.py +58 -0
- reasonbench/models/api.py +176 -0
- reasonbench/models/online.py +228 -0
- reasonbench/models/qroq.py +132 -0
- reasonbench/tasks/__init__.py +8 -0
- reasonbench/tasks/game24/__init__.py +3 -0
- reasonbench/tasks/game24/agents.py +440 -0
- reasonbench/tasks/game24/benchmark.py +51 -0
- reasonbench/tasks/game24/environment.py +85 -0
- reasonbench/tasks/game24/prompts.py +393 -0
- reasonbench/tasks/game24/state.py +49 -0
- reasonbench/tasks/hle/__init__.py +3 -0
- reasonbench/tasks/hle/agents.py +332 -0
- reasonbench/tasks/hle/benchmark.py +128 -0
- reasonbench/tasks/hle/environment.py +188 -0
- reasonbench/tasks/hle/prompts.py +200 -0
- reasonbench/tasks/hle/state.py +80 -0
- reasonbench/tasks/hotpotqa/__init__.py +3 -0
- reasonbench/tasks/hotpotqa/agents.py +411 -0
- reasonbench/tasks/hotpotqa/benchmark.py +104 -0
- reasonbench/tasks/hotpotqa/environment.py +123 -0
- reasonbench/tasks/hotpotqa/prompts.py +489 -0
- reasonbench/tasks/hotpotqa/state.py +61 -0
- reasonbench/tasks/humaneval/__init__.py +3 -0
- reasonbench/tasks/humaneval/agents.py +361 -0
- reasonbench/tasks/humaneval/benchmark.py +53 -0
- reasonbench/tasks/humaneval/environment.py +503 -0
- reasonbench/tasks/humaneval/prompts.py +152 -0
- reasonbench/tasks/humaneval/state.py +62 -0
- reasonbench/tasks/logiqa/__init__.py +3 -0
- reasonbench/tasks/logiqa/agents.py +258 -0
- reasonbench/tasks/logiqa/benchmark.py +81 -0
- reasonbench/tasks/logiqa/environment.py +78 -0
- reasonbench/tasks/logiqa/prompts.py +211 -0
- reasonbench/tasks/logiqa/state.py +66 -0
- reasonbench/tasks/matharena/__init__.py +0 -0
- reasonbench/tasks/matharena/agents.py +154 -0
- reasonbench/tasks/matharena/benchmark.py +105 -0
- reasonbench/tasks/matharena/environment.py +110 -0
- reasonbench/tasks/matharena/prompts.py +79 -0
- reasonbench/tasks/matharena/state.py +66 -0
- reasonbench/tasks/scibench/__init__.py +3 -0
- reasonbench/tasks/scibench/agents.py +460 -0
- reasonbench/tasks/scibench/benchmark.py +89 -0
- reasonbench/tasks/scibench/environment.py +103 -0
- reasonbench/tasks/scibench/prompts.py +242 -0
- reasonbench/tasks/scibench/state.py +61 -0
- reasonbench/tasks/sonnetwriting/__init__.py +3 -0
- reasonbench/tasks/sonnetwriting/agents.py +338 -0
- reasonbench/tasks/sonnetwriting/benchmark.py +92 -0
- reasonbench/tasks/sonnetwriting/environment.py +302 -0
- reasonbench/tasks/sonnetwriting/prompts.py +382 -0
- reasonbench/tasks/sonnetwriting/state.py +52 -0
- reasonbench/typedefs.py +154 -0
- reasonbench/utils.py +430 -0
- reasonbench-0.0.1.dist-info/METADATA +242 -0
- reasonbench-0.0.1.dist-info/RECORD +73 -0
- reasonbench-0.0.1.dist-info/WHEEL +5 -0
- reasonbench-0.0.1.dist-info/licenses/LICENSE +21 -0
- reasonbench-0.0.1.dist-info/top_level.txt +1 -0
reasonbench/__init__.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from typing import TypedDict
|
|
2
|
+
from .typedefs import DecodingParameters
|
|
3
|
+
from .datasets import get_dataset_path
|
|
4
|
+
|
|
5
|
+
class BenchmarkFactory:
    """Registry mapping task names to benchmark classes.

    Classes are stored under their lower-cased class name, and ``get`` looks
    up the key ``"benchmark" + task`` (lower-cased).
    """

    registry = {}

    @classmethod
    def register(cls, benchmark_cls):
        """Store ``benchmark_cls`` under its lower-cased class name.

        Returns ``benchmark_cls`` unchanged, so this can be used as a class
        decorator.
        """
        cls.registry[benchmark_cls.__name__.lower()] = benchmark_cls
        return benchmark_cls

    @classmethod
    def get(cls, task: str, *args, **kwargs):
        """Instantiate the benchmark registered for ``task``.

        Args:
            task: Task name; combined with the ``benchmark`` prefix and
                lower-cased to form the registry key.
            *args, **kwargs: Forwarded to the benchmark constructor together
                with ``path=<dataset path for task>``.

        Raises:
            ValueError: If no benchmark is registered for ``task``. Errors
                raised by ``get_dataset_path`` or by the benchmark
                constructor propagate unchanged.
        """
        key = f"benchmark{task}".lower()
        # Guard only the registry lookup: a KeyError raised while building the
        # benchmark must not be reported as "No benchmark found".
        try:
            benchmark_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No benchmark found for task={task}") from None

        # Resolve the dataset only once the benchmark is known to exist, so an
        # unknown task never triggers a dataset lookup/download.
        path = get_dataset_path(task)
        return benchmark_cls(*args, path=path, **kwargs)
|
|
21
|
+
|
|
22
|
+
class EnvironmentFactory:
    """Registry mapping task names to environment classes.

    Classes are stored under their lower-cased class name, and ``get`` looks
    up the key ``"environment" + task`` (lower-cased).
    """

    registry = {}

    @classmethod
    def register(cls, env_cls):
        """Store ``env_cls`` under its lower-cased class name and return it.

        Returning the class unchanged allows use as a class decorator.
        """
        cls.registry[env_cls.__name__.lower()] = env_cls
        return env_cls

    @classmethod
    def get(cls, task: str, *args, **kwargs):
        """Instantiate the environment registered for ``task``.

        Args:
            task: Task name; prefixed with ``environment`` and lower-cased to
                form the registry key.
            *args, **kwargs: Forwarded to the environment constructor.

        Raises:
            ValueError: If no environment is registered for ``task``.
                Exceptions raised by the constructor propagate unchanged.
        """
        key = f"environment{task}".lower()
        # Guard only the lookup so a KeyError raised inside the constructor is
        # not mistaken for a missing registration.
        try:
            env_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No environment found for task={task}") from None
        return env_cls(*args, **kwargs)
|
|
37
|
+
|
|
38
|
+
class AgentFactory:
    """Registry mapping (agent type, benchmark) pairs to agent classes.

    Classes are stored under their lower-cased class name, and ``get`` looks
    up the key ``"agent" + agent_type + benchmark`` (lower-cased).
    """

    registry = {}

    @classmethod
    def register(cls, agent_cls):
        """Store ``agent_cls`` under its lower-cased class name and return it.

        Returning the class unchanged allows use as a class decorator.
        """
        cls.registry[agent_cls.__name__.lower()] = agent_cls
        return agent_cls

    @classmethod
    def get(cls, agent_type: str, benchmark: str, *args, **kwargs):
        """Return the agent class registered for ``agent_type`` and ``benchmark``.

        The class itself is returned, NOT an instance: ``*args`` and
        ``**kwargs`` are accepted for signature compatibility with the other
        factories but are not used.

        Raises:
            ValueError: If no agent class is registered under the derived key.
        """
        key = f"agent{agent_type}{benchmark}".lower()
        try:
            return cls.registry[key]
        except KeyError:
            # "from None" hides the internal KeyError from the traceback; the
            # ValueError message already carries the relevant information.
            raise ValueError(f"No agent found for type={agent_type}, benchmark={benchmark}") from None
|
|
53
|
+
|
|
54
|
+
class AgentDictFactory:
    """Registry mapping method names to agent-dictionary classes.

    Classes are stored under their lower-cased class name, and ``get`` looks
    up the key ``"agentdict" + method`` (lower-cased).
    """

    registry = {}

    @classmethod
    def register(cls, agent_dict_cls):
        """Store ``agent_dict_cls`` under its lower-cased class name and return it.

        Returning the class unchanged allows use as a class decorator.
        """
        cls.registry[agent_dict_cls.__name__.lower()] = agent_dict_cls
        return agent_dict_cls

    @classmethod
    def get(cls, method: str, *args, **kwargs):
        """Instantiate the agent dictionary registered for ``method``.

        Args:
            method: Method name; prefixed with ``agentdict`` and lower-cased
                to form the registry key.
            *args, **kwargs: Forwarded to the registered class.

        Raises:
            ValueError: If nothing is registered for ``method``. Exceptions
                raised while constructing the object propagate unchanged.
        """
        key = f"agentdict{method}".lower()
        # Guard only the lookup so a KeyError raised during construction is
        # not mistaken for a missing registration.
        try:
            agent_dict_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No agent dict found for method={method}") from None
        return agent_dict_cls(*args, **kwargs)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class MethodFactory:
    """Registry mapping method names to method classes, plus agent wiring.

    Classes are stored under their lower-cased class name, and ``get`` looks
    up the key ``"method" + method`` (lower-cased).
    """

    registry = {}

    # For each supported method: agent role -> agent type requested from
    # ``AgentFactory.get(agent_type, benchmark)``.
    _AGENT_ROLES = {
        "io": {"step": "io"},
        "cot": {"step": "cot"},
        "cot_sc": {"step": "cot"},
        "foa": {"step": "act", "evaluate": "evaluate"},
        "tot_bfs": {"step": "bfs", "evaluate": "evaluate"},
        "tot_dfs": {"step": "bfs", "evaluate": "evaluate"},
        "got": {"step": "act", "aggregate": "aggregate", "evaluate": "evaluate"},
        "rap": {"step": "react", "evaluate": "selfevaluate"},
        "react": {"step": "react"},
    }

    @classmethod
    def register(cls, method_cls):
        """Store ``method_cls`` under its lower-cased class name and return it.

        Returning the class unchanged allows use as a class decorator.
        """
        cls.registry[method_cls.__name__.lower()] = method_cls
        return method_cls

    @classmethod
    def get(cls, method: str, benchmark: str, params: "DecodingParameters", *args, **kwargs):
        """Build the method registered for ``method`` with its agents attached.

        The agents dictionary passed to the method constructor holds one agent
        class per role (from ``AgentFactory``) and, for every role, a
        ``<role>_params`` entry set to ``params``.

        Args:
            method: Method name; matched case-insensitively, consistent with
                the lower-cased registry key.
            benchmark: Benchmark name used to select the agent classes.
            params: Decoding parameters shared by all agents.
            *args, **kwargs: Forwarded to the method constructor together
                with ``agents=<agents dictionary>``.

        Raises:
            NotImplementedError: If ``method`` has no agent wiring.
            ValueError: If no method class is registered for ``method`` (or
                ``AgentFactory`` cannot find a required agent).
        """
        key = f"method{method}".lower()

        try:
            roles = cls._AGENT_ROLES[method.lower()]
        except KeyError:
            raise NotImplementedError(f"Method {method} is not implemented yet.") from None

        # Guard only the lookup so a KeyError raised inside the method
        # constructor is not mistaken for a missing registration.
        try:
            method_cls = cls.registry[key]
        except KeyError:
            raise ValueError(f"No method found for name={method}") from None

        agents = {role: AgentFactory.get(agent_type, benchmark) for role, agent_type in roles.items()}
        # For the moment, only supporting same params for all agents
        agents.update({f"{role}_params": params for role in roles})

        return method_cls(*args, agents=agents, **kwargs)
|
reasonbench/datasets.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Utility for downloading datasets from HuggingFace Hub."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from huggingface_hub import hf_hub_download
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
HF_REPO_ID = "potamitisn/ReasonBench"
|
|
11
|
+
LOCAL_DATASETS_DIR = Path("datasets")
|
|
12
|
+
|
|
13
|
+
# Maps task name to the actual filename on HF Hub
|
|
14
|
+
DATASET_FILES = {
|
|
15
|
+
"game24": "dataset_game24.csv.gz",
|
|
16
|
+
"hle": "dataset_hle.jsonl.gz",
|
|
17
|
+
"hotpotqa": "dataset_hotpotqa.csv.gz",
|
|
18
|
+
"humaneval": "dataset_humaneval.csv.gz",
|
|
19
|
+
"logiqa": "dataset_logiqa.csv.gz",
|
|
20
|
+
"matharena": "dataset_matharena.jsonl.gz",
|
|
21
|
+
"scibench": "dataset_scibench.csv.gz",
|
|
22
|
+
"sonnetwriting": "dataset_sonnetwriting.jsonl.gz",
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_dataset_path(task: str, local_dir: str | Path = LOCAL_DATASETS_DIR) -> str:
|
|
27
|
+
"""Return the local path to a dataset file, downloading from HF Hub if needed.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
task: Task name (e.g., "game24", "hle").
|
|
31
|
+
local_dir: Local directory to store/look for datasets.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
Path to the local dataset file.
|
|
35
|
+
"""
|
|
36
|
+
if task not in DATASET_FILES:
|
|
37
|
+
raise ValueError(f"Unknown task: {task}. Available: {list(DATASET_FILES.keys())}")
|
|
38
|
+
|
|
39
|
+
filename = DATASET_FILES[task]
|
|
40
|
+
local_path = Path(local_dir) / filename
|
|
41
|
+
|
|
42
|
+
if local_path.exists():
|
|
43
|
+
return str(local_path)
|
|
44
|
+
|
|
45
|
+
logger.info(f"Dataset for '{task}' not found locally. Downloading from HuggingFace Hub...")
|
|
46
|
+
downloaded = hf_hub_download(
|
|
47
|
+
repo_id=HF_REPO_ID,
|
|
48
|
+
filename=filename,
|
|
49
|
+
repo_type="dataset",
|
|
50
|
+
local_dir=str(local_dir),
|
|
51
|
+
)
|
|
52
|
+
logger.info(f"Downloaded to {downloaded}")
|
|
53
|
+
return str(downloaded)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .foa import AgentDictFOA, MethodFOA
|
|
2
|
+
from .tot_bfs import AgentDictTOT, MethodTOT_BFS
|
|
3
|
+
from .tot_dfs import AgentDictTOT, MethodTOT_DFS
|
|
4
|
+
from .got import AgentDictGOT, MethodGOT
|
|
5
|
+
from .rap import MethodRAP, AgentDictRAP
|
|
6
|
+
from .react import AgentDictReact, MethodReact
|
|
7
|
+
from .io import AgentDictIO, MethodIO
|
|
8
|
+
from .cot import AgentDictCOT, MethodCOT
|
|
9
|
+
from .cot_sc import AgentDictCoT, MethodCOT_SC
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
from omegaconf import OmegaConf
|
|
6
|
+
from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
|
|
7
|
+
from .. import MethodFactory, AgentDictFactory
|
|
8
|
+
from ..utils import Resampler
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
@AgentDictFactory.register
class AgentDictCOT(TypedDict):
    """Agent and decoding parameters expected by ``MethodCOT``."""

    # Agent asked for the action to take (ActAgent).
    step: Agent
    # Decoding parameters passed along with every ``step`` request.
    step_params: DecodingParameters
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@MethodFactory.register
class MethodCOT(Method):
    """Single-shot method: one step-agent action per cloned state, applied once."""

    def __init__(self,
                 model: Model,
                 agents: AgentDictCOT,
                 env: Environment,
                 config: OmegaConf,
                 n: int = 1
                 ):
        """Store the step agent and its parameters; require ``config.n == 1``.

        NOTE: the ``n`` argument is not read; the count comes from ``config.n``.
        """
        super().__init__(model, agents, env, config)

        self.step_agent, self.step_params = agents["step"], agents["step_params"]

        self.n = config.n
        assert self.n == 1, "CoT has only 1 output"

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict=None):
        """Clone ``state`` ``self.n`` times, request one action each and apply it.

        ``idx`` seeds the global ``random`` module, which supplies the
        ``randomness`` of every clone. ``value_cache`` is accepted but unused.

        Returns:
            The list of states produced by ``env.step`` on each clone.
        """
        random.seed(idx)

        clones = [state.clone(randomness=random.randint(0, MAX_SEED)) for _ in range(self.n)]

        pending = []
        for agent_no, clone in enumerate(clones):
            pending.append(
                self.step_agent.act(
                    model=self.model,
                    state=clone,
                    n=1,
                    namespace=namespace,
                    request_id=f"idx{idx}-{hash(clone)}-agent{agent_no}",
                    params=self.step_params
                )
            )
        actions = await asyncio.gather(*pending)

        # Each agent reply is a sequence; only its first action is executed.
        return [self.env.step(clone, reply[0]) for clone, reply in zip(clones, actions)]
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from omegaconf import OmegaConf
|
|
7
|
+
from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
|
|
8
|
+
from .. import MethodFactory, AgentDictFactory
|
|
9
|
+
from ..utils import Resampler
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
@AgentDictFactory.register
class AgentDictCOT_SC(TypedDict):
    """Agent and decoding parameters expected by ``MethodCOT_SC``.

    Named after the method so that its lower-cased registry key
    (``agentdictcot_sc``) does not collide with any other agent dictionary
    whose name only differs in letter case.
    """

    # Agent asked for the candidate actions (ActAgent).
    step: Agent
    # Decoding parameters passed along with the ``step`` request.
    step_params: DecodingParameters


# Backward-compatible alias: existing imports and annotations use this name.
AgentDictCoT = AgentDictCOT_SC
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@MethodFactory.register
class MethodCOT_SC(Method):
    """Majority-vote method: sample ``n`` actions at once and apply the most frequent."""

    def __init__(self,
                 model: Model,
                 agents: AgentDictCoT,
                 env: Environment,
                 config: OmegaConf,
                 ):
        """Store the step agent and its parameters; require ``config.n > 1``."""
        super().__init__(model, agents, env, config)

        self.step_agent, self.step_params = agents["step"], agents["step_params"]

        self.n = config.n
        assert self.n > 1, "CoT-SC needs at least 2 outputs"

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict=None):
        """Request ``self.n`` actions for ``state`` and execute the most common one.

        ``idx`` seeds the global ``random`` module. ``value_cache`` is accepted
        but unused.

        Returns:
            A one-element list holding the state returned by ``env.step``.
        """
        random.seed(idx)

        sampled = await self.step_agent.act(
            model=self.model,
            state=state,
            n=self.n,
            namespace=namespace,
            request_id=f"idx{idx}-{hash(state)}",
            params=self.step_params
        )

        # Tally identical actions; ties are resolved by Counter.most_common.
        tally = Counter(list(sampled))
        winner, _ = tally.most_common(1)[0]
        return [self.env.step(state, winner)]
|
|
53
|
+
|
|
54
|
+
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
from omegaconf import OmegaConf
|
|
6
|
+
from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
|
|
7
|
+
from .. import MethodFactory, AgentDictFactory
|
|
8
|
+
from ..utils import Resampler
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
@AgentDictFactory.register
class AgentDictFOA(TypedDict):
    """Agents and decoding parameters expected by ``MethodFOA``."""

    # Agent asked for the next action of each state (ActAgent).
    step: Agent
    # Agent asked to value states (EvaluateAgent).
    evaluate: Agent
    # Decoding parameters for ``step`` requests.
    step_params: DecodingParameters
    # Decoding parameters for ``evaluate`` requests.
    evaluate_params: DecodingParameters
|
|
17
|
+
|
|
18
|
+
@MethodFactory.register
class MethodFOA(Method):
    """Multi-agent search: step several states in parallel, prune, value and resample them."""

    def __init__(self,
                 agents: AgentDictFOA,
                 model: Model,
                 env: Environment,
                 config: OmegaConf
                 ):
        """Store agents, their decoding parameters and the search settings from ``config``."""
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.eval_agent = agents["evaluate"]

        self.step_params = agents["step_params"]
        self.evaluate_params = agents["evaluate_params"]

        self.num_agents = config.num_agents
        self.num_steps = config.num_steps
        self.k = config.k
        self.backtrack = config.backtrack
        self.resampling = config.resampling
        self.origin = config.origin
        self.min_steps = config.min_steps
        self.num_evaluations = config.num_evaluations

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Run up to ``num_steps`` rounds of act / prune / evaluate / resample.

        ``idx`` seeds both the global ``random`` module and the ``Resampler``.

        Returns:
            The current list of states, or ``[state]`` (the initial state) if
            that list ended up empty.
        """
        random.seed(idx)
        resampler = Resampler(idx)

        # Records of previously visited states: (identifier, value, state)
        visited_states = [("INIT", self.origin, state)]
        initial_state = state

        # One independently seeded clone per agent
        states = [state.clone(randomness=random.randint(0, MAX_SEED)) for _ in range(self.num_agents)]

        solved = False
        for step in range(self.num_steps):

            if solved:
                break

            # Ask the step agent for one action per state
            action_coroutines = [
                self.step_agent.act(
                    model=self.model,
                    state=current,
                    n=1,
                    namespace=namespace,
                    request_id=f"idx{idx}-step{step}-{hash(current)}-agent{agent_no}",
                    params=self.step_params
                )
                for agent_no, current in enumerate(states)
            ]
            actions = await asyncio.gather(*action_coroutines)

            # Apply the first returned action to each state
            states = [self.env.step(current, reply[0]) for current, reply in zip(states, actions)]

            # Stop as soon as the environment scores any state as 1
            if any(self.env.evaluate(current)[1] == 1 for current in states):
                solved = True
                break

            # Decay stored values, then drop records that can no longer reach
            # ``min_steps`` within the remaining steps
            remaining_steps = self.num_steps - (step + 1)
            visited_states = [
                (identifier, value*self.backtrack, visited)
                for identifier, value, visited in visited_states
            ]
            visited_states = [
                record for record in visited_states
                if remaining_steps >= self.min_steps - len(record[2].steps)
            ]

            # Pruning: states that are final here finished without being solved
            failed = [pos for pos, current in enumerate(states) if self.env.is_final(current)]
            if visited_states:
                replacements, _ = resampler.resample(visited_states.copy(), len(failed), self.resampling)
            else:
                replacements, _ = resampler.resample(
                    [("", 1, current) for current in states],
                    len(failed),
                    resampling_method="linear",
                )
            states = [replacements.pop(0) if pos in failed else current for pos, current in enumerate(states)]

            # Evaluation phase: every k-th step, except on the last step
            if step < self.num_steps-1 and self.k and step % self.k == 0:

                value_coroutines = [
                    self.eval_agent.act(
                        model=self.model,
                        state=current,
                        n=self.num_evaluations,
                        namespace=namespace,
                        request_id=f"idx{idx}-evaluation{step}-{hash(current)}-agent{agent_no}",
                        params=self.evaluate_params,
                        cache=value_cache
                    )
                    for agent_no, current in enumerate(states)
                ]
                values = await asyncio.gather(*value_coroutines)

                # Record the valued states, skipping positions that were pruned
                for pos, (current, value) in enumerate(zip(states, values)):
                    if pos not in failed:
                        visited_states.append((f"{pos}.{step}", value, current))

                # Resampling
                states, resampled_idxs = resampler.resample(visited_states, self.num_agents, self.resampling)

        if len(states) == 0:
            return [initial_state]
        else:
            return states
|
|
127
|
+
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import asyncio
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
from omegaconf import OmegaConf
|
|
6
|
+
from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
|
|
7
|
+
from .. import MethodFactory, AgentDictFactory
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
@AgentDictFactory.register
class AgentDictGOT(TypedDict):
    """Agents and decoding parameters expected by ``MethodGOT``."""

    # Agent asked to generate candidate actions.
    step: Agent
    # Agent asked to select among the generated actions.
    aggregate: Agent
    # Agent asked to value the proposed states.
    evaluate: Agent
    # Decoding parameters for ``step`` requests.
    step_params: DecodingParameters
    # Decoding parameters for ``aggregate`` requests.
    aggregate_params: DecodingParameters
    # Decoding parameters for ``evaluate`` requests.
    evaluate_params: DecodingParameters
|
|
18
|
+
|
|
19
|
+
@MethodFactory.register
class MethodGOT(Method):
    """Generate / aggregate / evaluate search that keeps the best-valued states each step."""

    def __init__(self,
                 model: Model,
                 agents: AgentDictGOT,
                 env: Environment,
                 config: OmegaConf
                 ):
        """Store agents, their decoding parameters and the search settings from ``config``."""
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.aggregate_agent = agents["aggregate"]
        self.eval_agent = agents["evaluate"]

        self.step_params = agents["step_params"]
        self.aggregate_params = agents["aggregate_params"]
        self.evaluate_params = agents["evaluate_params"]

        self.num_selections = config.num_selections
        self.num_steps = config.num_steps
        self.num_generate = config.num_generate
        self.num_best = config.num_best
        self.num_evaluations = config.num_evaluations

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict = None):
        """Run up to ``num_steps`` rounds and return the surviving states.

        Each round generates ``num_generate`` actions per state, lets the
        aggregate agent select among them, executes the selected actions,
        values the resulting states and keeps the ``num_best`` highest valued.
        ``idx`` seeds the global ``random`` module.
        """
        random.seed(idx)
        states = [state.clone(randomness=random.randint(0, MAX_SEED))]
        logger.debug(f"Solving game: {idx}")

        solved = False
        for step in range(self.num_steps):
            if solved:
                logger.debug(f"Task {idx} solved at step {step - 1}.")
                break

            logger.debug(f"Step: {step} ({idx})")

            # Generate candidate actions for every current state
            action_coroutines = [
                self.step_agent.act(
                    model=self.model,
                    state=current,
                    n=self.num_generate,
                    namespace=namespace,
                    request_id=f"idx{idx}-step{step}-{hash(current)}-agent{agent_no}",
                    params=self.step_params,
                )
                for agent_no, current in enumerate(states)
            ]
            generated_actions = await asyncio.gather(*action_coroutines)
            logger.debug(f"{len(generated_actions)} Actions generated for task {idx}; \n {generated_actions}")

            # Let the aggregate agent select among each state's candidates
            aggregate_coroutines = [
                self.aggregate_agent.act(
                    model=self.model,
                    state=current,
                    actions=candidates,
                    k=self.num_selections,
                    namespace=namespace,
                    request_id=f"idx{idx}-aggregate{step}-{hash(current)}-agent{agent_no}",
                    params=self.aggregate_params,
                )
                for agent_no, (current, candidates) in enumerate(zip(states, generated_actions))
            ]

            actions = await asyncio.gather(*aggregate_coroutines)
            logger.debug(f"{len(actions)} Actions selected for task {idx}: \n{actions}")

            # Execute every selected action on its parent state
            proposed_states = [
                self.env.step(parent, chosen)
                for parent, selection in zip(states, actions)
                for chosen in selection
            ]

            if not proposed_states:
                break

            # NOTE(review): this inspects the pre-step ``states``, not
            # ``proposed_states`` - confirm that this is intended.
            if any(self.env.evaluate(current)[1] == 1 for current in states):
                solved = True

            logger.debug(f"Env step for task {idx}: \n{proposed_states}")

            # Value all proposals
            value_coroutines = [
                self.eval_agent.act(
                    model=self.model,
                    state=proposal,
                    n=self.num_evaluations,
                    namespace=namespace,
                    request_id=f"idx{idx}-evaluation{step}-{hash(proposal)}-agent{agent_no}",
                    params=self.evaluate_params,
                    cache=value_cache
                )
                for agent_no, proposal in enumerate(proposed_states)
            ]
            values = await asyncio.gather(*value_coroutines)
            logger.debug(f"Values given for task {idx}: \n{values}")

            # Keep the ``num_best`` proposals with the highest value
            ranked = sorted(
                list(zip(proposed_states, values)),
                key=lambda pair: pair[1],
                reverse=True,
            )
            states, values = map(list, zip(*ranked[:self.num_best]))

        return states
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
from omegaconf import OmegaConf
|
|
6
|
+
from ..typedefs import Method, Model, Agent, Environment, DecodingParameters, State, Benchmark, MAX_SEED
|
|
7
|
+
from .. import MethodFactory, AgentDictFactory
|
|
8
|
+
from ..utils import Resampler
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
@AgentDictFactory.register
class AgentDictIO(TypedDict):
    """Agent and decoding parameters expected by ``MethodIO``."""

    # Agent asked for the action to take (ActAgent).
    step: Agent
    # Decoding parameters passed along with every ``step`` request.
    step_params: DecodingParameters
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@MethodFactory.register
class MethodIO(Method):
    """Single-shot method: one step-agent action per cloned state, applied once.

    A state whose environment step fails is kept unchanged instead of
    aborting the whole solve call.
    """

    def __init__(self,
                 model: Model,
                 agents: AgentDictIO,
                 env: Environment,
                 config: OmegaConf,
                 ):
        """Store the step agent and its parameters; require ``config.n == 1``."""
        super().__init__(model, agents, env, config)

        self.step_agent = agents["step"]
        self.step_params = agents["step_params"]

        self.n = config.n
        assert self.n == 1, "IO has only 1 output"

    async def solve(self, idx: int, state: State, namespace: str, value_cache: dict=None):
        """Clone ``state`` ``self.n`` times, request one action each and apply it.

        ``idx`` seeds the global ``random`` module, which supplies the
        ``randomness`` of every clone. ``value_cache`` is accepted but unused.

        Returns:
            One state per clone: the result of ``env.step``, or the unchanged
            clone when the step raised.
        """
        random.seed(idx)

        states = [state.clone(randomness=random.randint(0, MAX_SEED)) for _ in range(self.n)]

        action_coroutines = [
            self.step_agent.act(
                model=self.model,
                state=current,
                n=1,
                namespace=namespace,
                request_id=f"idx{idx}-{hash(current)}-agent{i}",
                params=self.step_params
            )
            for i, current in enumerate(states)
        ]
        actions = await asyncio.gather(*action_coroutines)

        # Execute the actions
        new_states = []
        for current, action in zip(states, actions):
            try:
                new_states.append(self.env.step(current, action[0]))
            except Exception as exc:
                # Deliberate best-effort: report through the module logger
                # (not print) and fall back to the un-stepped state.
                logger.warning("Step failed: %s", exc)
                new_states.append(current)

        return new_states
|