ai-nk-cce 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
- ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
- ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
- api/__init__.py +0 -0
- api/mpcdf_vllm.py +94 -0
- evals/nk_model.py +277 -0
- model/README.md +64 -0
- model/config/dataset_conv_v1.yml +9 -0
- model/config/dataset_conv_v2_m2.yml +9 -0
- model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
- model/config/dataset_debug.yml +9 -0
- model/config/dataset_v4_int_format.yml +9 -0
- model/config/dataset_v5.yml +9 -0
- model/config/inference.yml +7 -0
- model/config/train.yml +24 -0
- model/config/train_debug.yml +19 -0
- model/config/train_from_checkpoint.yml +24 -0
- model/config/train_from_checkpoint_debug.yml +19 -0
- model/config/train_grpo.yml +30 -0
- model/config/train_grpo_debug.yml +30 -0
- model/config/train_grpo_debug_vllm.yml +32 -0
- model/config.py +54 -0
- model/dataset.py +324 -0
- model/inference.py +51 -0
- model/nk_assistant.py +207 -0
- model/parser.py +70 -0
- model/run_slurm.py +335 -0
- model/score.ipynb +596 -0
- model/scripts/template.slurm +54 -0
- model/scripts/template_rl.slurm +54 -0
- model/train.py +293 -0
- nk_model/__init__.py +0 -0
- nk_model/assembler.py +112 -0
- nk_model/biased_prediction_agent.py +389 -0
- nk_model/dataset.py +434 -0
- nk_model/enums.py +21 -0
- nk_model/landscape_cache.py +149 -0
- nk_model/models.py +172 -0
- nk_model/nk_landscape.py +498 -0
- simulation/hill_climber_simulation.py +211 -0
- simulation/hill_climber_vs_ai_simulation.py +132 -0
- simulation/landscape_selection.py +179 -0
- utils/__init__.py +0 -0
- utils/binary_conversion.py +128 -0
- utils/logging.py +33 -0
- utils/utils.py +51 -0
model/nk_assistant.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
8
|
+
|
|
9
|
+
from src.api.mpcdf_vllm import mpcdf_vllm_request
|
|
10
|
+
from src.model.config import NKAssistantConfig, TrainConfig
|
|
11
|
+
from src.model.parser import create_context, create_target, target_to_int
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_model_and_tokenizer(model_path: str):
    """Load a causal-LM model and its tokenizer from ``model_path``.

    If no tokenizer can be loaded from ``model_path``, the ``"gpt2"``
    tokenizer is used instead and a warning is emitted so the fallback is
    not silent. If the tokenizer has no pad token, its EOS token is reused
    as pad token. The model is moved to CUDA when available, else to CPU.

    Args:
        model_path: Path or identifier passed to ``from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer, device)``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    except Exception as err:
        # Fallback if no tokenizer is found at model_path; warn so that a
        # mismatched tokenizer does not go unnoticed.
        import warnings

        warnings.warn(
            f"Could not load tokenizer from {model_path!r} ({err}); "
            "falling back to the 'gpt2' tokenizer.",
            stacklevel=2,
        )
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Ensure pad_token is set
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model, tokenizer, device
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class NKAssistant:
    """Language-model assistant that proposes target points on NK landscapes.

    Prompts are built with ``create_context``. Completions come either from
    a locally loaded Hugging Face model, or from ``mpcdf_vllm_request`` when
    ``config.use_mpcdf_vllm`` is true.
    """

    def __init__(
        self, config: NKAssistantConfig, metadata: Optional[TrainConfig] = None
    ):
        """Initialise the assistant.

        Args:
            config: Assistant settings. ``use_mpcdf_vllm`` selects remote
                generation; otherwise the model and tokenizer at
                ``config.model_path`` are loaded locally.
                ``generation_params`` are forwarded to every generation call.
            metadata: Optional training config, written next to the model by
                ``_save_pretrained``.
        """
        if config.use_mpcdf_vllm:
            # Remote generation: nothing is loaded locally.
            model, tokenizer, device = None, None, None
        else:
            model, tokenizer, device = load_model_and_tokenizer(
                config.model_path
            )
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.config = config
        self.metadata = metadata
        self.generation_params = config.generation_params

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save model, tokenizer and (if set) metadata into ``save_directory``.

        Raises:
            RuntimeError: if no local model is loaded
                (``use_mpcdf_vllm`` was true at construction).
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "No local model to save: the assistant was created with "
                "use_mpcdf_vllm=True."
            )
        save_directory.mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.metadata:
            with open(save_directory / "metadata.json", "w") as f:
                json.dump(self.metadata.model_dump(), f)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, Path], **kwargs
    ):
        """Build an assistant from a local directory or a Hub repository.

        If the path does not exist locally, it is treated as a Hub repo id and
        downloaded with ``snapshot_download``. The directory must contain a
        ``metadata.json`` (as written by ``_save_pretrained``). Its
        ``assistant_config`` becomes the assistant config, with ``model_path``
        pointed at the resolved directory. Extra ``kwargs`` are accepted and
        ignored.

        Raises:
            FileNotFoundError: if ``metadata.json`` is missing.
        """
        model_path = Path(pretrained_model_name_or_path)

        if not model_path.exists():
            # Download files from Hub into cache
            from huggingface_hub import snapshot_download

            model_path = Path(
                snapshot_download(repo_id=str(pretrained_model_name_or_path))
            )

        metadata_path = model_path / "metadata.json"
        if not metadata_path.exists():
            # Without metadata there is no assistant config to build from.
            raise FileNotFoundError(
                f"{metadata_path} not found; cannot build the assistant "
                "configuration without training metadata."
            )
        with open(metadata_path) as f:
            metadata = TrainConfig(**json.load(f))

        config = metadata.assistant_config
        config.model_path = str(model_path)

        return cls(config=config, metadata=metadata)

    @staticmethod
    def create_text_from_row(row, include_payoff=True, include_target=True):
        """Build the prompt (and optionally the target) for a dataset row.

        Reads ``payoffs``, ``sample_idx``, ``n``, ``k``, ``power_scale``,
        ``origin_idx`` and ``hamming_distance`` from ``row``, and
        ``target_idx`` when ``include_target`` is true. The payoffs shown
        are ``payoffs`` indexed by ``sample_idx``.

        Returns:
            ``{"context": str}`` plus ``"target"`` when ``include_target``.
        """
        sample_payoff = np.array(row["payoffs"])[row["sample_idx"]]

        context = create_context(
            n=row["n"],
            k=row["k"],
            power_scale=row["power_scale"],
            sample_idxs=row["sample_idx"],
            origin_idx=row["origin_idx"],
            hamming_distance=row["hamming_distance"],
            payoff=sample_payoff,
            include_payoff=include_payoff,
        )
        if include_target:
            target = create_target(target_idx=row["target_idx"])
            return {"context": context, "target": target}
        return {"context": context}

    @staticmethod
    def create_prompt_for_rl_from_row(row):
        """Return ``{"prompt": context}`` for RL training (payoffs included)."""
        context = NKAssistant.create_text_from_row(
            row, include_payoff=True, include_target=False
        )["context"]
        return {"prompt": context}

    @staticmethod
    def _target_length(n: int) -> int:
        """Max generation length for an n-bit answer.

        Length needed for binary string with commas. Assumes roughly one
        token per character; TODO confirm for the tokenizer in use.
        """
        return n * 2 - 1

    def _generate(self, context: str, target_length: int) -> str:
        """Complete ``context`` via MPCDF vLLM or the local model."""
        if self.config.use_mpcdf_vllm:
            return mpcdf_vllm_request(
                prompt=context,
                max_tokens=target_length,
                **self.generation_params,
            )
        return self.suggest_target_from_context(context, target_length)

    def suggest_target_from_context(
        self, context: str, target_length: int
    ) -> str:
        """Generate with the local model and decode only the new tokens.

        ``max_length`` is the prompt length plus ``target_length``, so at
        most ``target_length`` new tokens are produced. Special tokens are
        skipped when decoding.
        """
        # Tokenize input
        input_ids = self.tokenizer.encode(context, return_tensors="pt").to(
            self.device
        )
        n_input = input_ids.shape[1]

        # Generate output
        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_length=n_input + target_length,
                pad_token_id=self.tokenizer.eos_token_id,
                **self.generation_params,
            )

        # Decode only the new tokens (exclude the input tokens)
        suggestion = self.tokenizer.decode(
            output[0, n_input:], skip_special_tokens=True
        )

        return suggestion

    def suggest_from_row(self, row):
        """Return ``{"suggestion": raw_generated_text}`` for a dataset row."""
        context = self.create_text_from_row(row, include_target=False)[
            "context"
        ]
        target_length = self._target_length(row["n"])
        return {"suggestion": self._generate(context, target_length)}

    def suggest(
        self,
        n: int,
        k: int,
        power_scale: float,
        hamming_distance: int,
        sample_idxs: List[int],
        origin_idx: int,
        payoffs: List[float],
    ) -> int:
        """Suggest a target index for the given landscape query.

        ``payoffs`` are paired positionally with ``sample_idxs`` in the
        prompt.

        Returns:
            The suggested index parsed with ``target_to_int``, or ``-1`` if
            the generated text cannot be parsed (the error is printed).
        """
        # Create context from input data
        context = create_context(
            n=n,
            k=k,
            power_scale=power_scale,
            sample_idxs=sample_idxs,
            origin_idx=origin_idx,
            hamming_distance=hamming_distance,
            payoff=payoffs,
        )

        binary_string = self._generate(context, self._target_length(n))

        # Parse the suggestion to an integer
        try:
            suggestion_int = target_to_int(binary_string)
            return suggestion_int
        except Exception as e:
            print(f"Error parsing suggestion '{binary_string}': {e}")
            return -1
|
model/parser.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from src.utils.binary_conversion import binary_str_to_int, int_to_binary_str
|
|
5
|
+
|
|
6
|
+
# Module-level logger named after the module (not the root logger); records
# still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def create_context(
    *,
    n: int,
    k: int,
    power_scale: float,
    sample_idxs: List[int],
    origin_idx: int,
    hamming_distance: int,
    payoff: Optional[List[Union[float, int]]] = None,
    include_payoff=True,
    **kwargs,
) -> str:
    """Render the text prompt for one NK-landscape query.

    The prompt has this shape (``<b(i)>`` = ``int_to_binary_str(i)``)::

        n: <n>
        k: <k>
        p: <power_scale with 2 decimals>
        h: <hamming_distance>

        sample,payoff            ("sample" when include_payoff is False)
        <b(sample_idxs[0])>,<payoff[0]>
        ...

        user
        <b(origin_idx)>

        assistant

    Args:
        n, k, power_scale, hamming_distance: Header values.
        sample_idxs: Sample indices, one prompt line each.
        origin_idx: Index shown in the ``user`` section.
        payoff: Values paired positionally with ``sample_idxs`` (via ``zip``,
            so surplus entries on either side are dropped). Required when
            ``include_payoff`` is True.
        include_payoff: Whether to emit the payoff column.
        **kwargs: Ignored.

    Returns:
        The prompt, ending with ``"assistant\\n"``.

    Raises:
        ValueError: if ``include_payoff`` is True and ``payoff`` is None.
    """
    # Validate up front with a real exception (asserts vanish under -O).
    if include_payoff and payoff is None:
        raise ValueError(
            "Payoff values must be provided when include_payoff is True"
        )

    # Header
    doc = (
        f"n: {n}\n"
        f"k: {k}\n"
        f"p: {power_scale:.2f}\n"
        f"h: {hamming_distance}\n"
        "\n"
    )
    # Sample rows
    if include_payoff:
        doc += "sample,payoff\n"
        doc += "\n".join(
            f"{int_to_binary_str(idx)},{p}"
            for idx, p in zip(sample_idxs, payoff)
        )
    else:
        doc += "sample\n"
        doc += "\n".join(int_to_binary_str(idx) for idx in sample_idxs)
    doc += "\n\n"
    # Constraints
    doc += "user\n" + int_to_binary_str(origin_idx) + "\n\n"
    # Target
    doc += "assistant\n"
    return doc
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def create_target(*, target_idx: int) -> str:
    """Return the expected assistant answer for ``target_idx``.

    The index is rendered with ``int_to_binary_str``.
    """
    target_text = int_to_binary_str(target_idx)
    return target_text
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def target_to_int(target):
    """Convert target string to integer (deprecated, use binary_str_to_int).

    Args:
        target: Binary string to decode.

    Returns:
        ``binary_str_to_int(target)``; any exception it raises propagates.
    """
    int_target = binary_str_to_int(target)
    # Lazy %-args: the message is only formatted when DEBUG is enabled.
    logger.debug("target: %s, int_target: %s", target, int_target)
    return int_target
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def eval_target(target, ranks):
    """Look up the rank of a binary-string target.

    Args:
        target: Binary string decoded with ``binary_str_to_int``.
        ranks: Indexable by the decoded integer.

    Returns:
        ``ranks`` at the decoded index.
    """
    decoded_index = binary_str_to_int(target)
    rank = ranks[decoded_index]
    return rank
|
model/run_slurm.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
# import shutil
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
|
|
10
|
+
# from collections import OrderedDict
|
|
11
|
+
import yaml
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def generate_local_job_id():
    """
    Return a job ID built from the current local time.

    Format is ``YYYY_MM_DD__HH_MM_SS`` (one-second resolution, so calls in
    the same second yield the same ID).
    """
    stamp_format = "%Y_%m_%d__%H_%M_%S"
    current_time = datetime.now()
    return current_time.strftime(stamp_format)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def parse_unknown_args(unknown_args, argv):
    """Turn argparse's leftover ``--key value`` arguments into a dict.

    Both ``--key value`` and ``--key=value`` are supported. For the
    space-separated form, the value is the element right after *that
    occurrence* of the flag in ``argv``. A flag that is last, or that is
    followed by another ``--`` flag, gets no entry. A repeated key keeps its
    last value. Values stay strings.

    Args:
        unknown_args: Leftover arguments from ``parse_known_args``.
        argv: The full argument list that was parsed.

    Returns:
        Dict mapping flag names (leading dashes stripped) to string values.
    """
    dynamic_args = {}
    search_from = 0  # argv position just after the previously matched flag
    for arg in unknown_args:
        if not arg.startswith("--"):
            continue
        key, has_inline_value, inline_value = arg.lstrip("-").partition("=")
        if has_inline_value:
            dynamic_args[key] = inline_value
            continue
        try:
            # Search forward so repeated flags map to their own occurrence.
            flag_pos = argv.index(arg, search_from)
        except ValueError:
            continue
        search_from = flag_pos + 1
        value_pos = flag_pos + 1
        if value_pos < len(argv) and not argv[value_pos].startswith("--"):
            dynamic_args[key] = argv[value_pos]

    return dynamic_args
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def define_output(args):
    """Resolve and create the job's output directory.

    When ``args["output_dir"]`` is None it defaults to
    ``experiments/<group_name>/<job_name>/<job_id>``, relative to the
    current working directory. The directory and its parents are created
    if missing.

    Args:
        args: Argument dict. Reads ``output_dir``; when it is None, also
            reads ``group_name``, ``job_name`` and ``job_id``. Mutated in
            place.

    Returns:
        The same ``args`` dict with ``output_dir`` set.
    """
    # Compute additional arguments
    if args["output_dir"] is None:
        args["output_dir"] = os.path.join(
            "experiments",
            args["group_name"],
            args["job_name"],
            args["job_id"],
        )
    # exist_ok avoids the race between a separate exists() check and mkdir.
    os.makedirs(args["output_dir"], exist_ok=True)
    return args
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def define_compute_resources(args):
    """Derive SLURM resource settings from the requested GPU count.

    Requests above 4 GPUs are split into whole nodes of 4 GPUs each, so
    ``n_gpu`` must then be a multiple of 4. The settings are:

    - CPUs: 18 cores per GPU on a node.
    - Memory: 0 for a full 4-GPU node, otherwise ``125000 * n_gpu``.
    - Partition: always ``"gpu"``.

    Args:
        args: Argument dict containing an integer ``n_gpu``.

    Returns:
        A new dict with everything in ``args`` plus ``n_nodes``, ``n_gpu``
        (GPUs per node), ``n_cpu``, ``partition`` and ``memory``.

    Raises:
        ValueError: if ``n_gpu`` > 4 and is not a multiple of 4.
    """
    gpus_per_node = 4
    requested = args["n_gpu"]
    if requested > gpus_per_node:
        # Explicit check instead of assert (asserts vanish under -O).
        if requested % gpus_per_node != 0:
            raise ValueError(
                f"n_gpu > {gpus_per_node} must be a multiple of "
                f"{gpus_per_node} (got {requested})"
            )
        n_nodes = requested // gpus_per_node
        n_gpu = gpus_per_node
    else:
        n_nodes = 1
        n_gpu = requested

    # NOTE(review): units of `memory` and the meaning of 0 depend on how the
    # slurm template uses {memory} (for sbatch --mem, 0 means all node
    # memory) -- confirm against the template.
    memory = 0 if n_gpu >= gpus_per_node else 125000 * n_gpu
    partition = "gpu"
    cpu = n_gpu * 18  # 18 cores per GPU

    return {
        **args,
        "n_nodes": n_nodes,
        "n_gpu": n_gpu,
        "n_cpu": cpu,
        "partition": partition,
        "memory": memory,
    }
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def deep_merge_configs(main_config, included_config):
    """
    Recursively merge ``included_config`` into ``main_config`` in place.

    Nothing is returned. Keys already in ``main_config`` keep their
    position; new keys are appended in ``included_config`` order. When both
    sides hold a dict for the same key, the dicts are merged recursively.
    Otherwise the value from ``included_config`` replaces the existing one.
    """
    for name, incoming in included_config.items():
        existing = main_config.get(name)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            deep_merge_configs(existing, incoming)
            continue
        main_config[name] = incoming
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def find_include_value(data, target_key):
    """Depth-first search of nested dicts for ``target_key``.

    Keys are visited in insertion order and nested dicts are searched as
    soon as they are met. A match at the current level is returned as-is,
    even if its value is None. A None coming back from a nested dict is
    treated as "not found" and the search continues. Lists are not
    searched, and non-dict input yields None.
    """
    if not isinstance(data, dict):
        return None
    for name, entry in data.items():
        if name == target_key:
            return entry
        if not isinstance(entry, dict):
            continue
        nested = find_include_value(entry, target_key)
        if nested is not None:
            return nested
    return None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _convert_to_float(value):
    """Return ``float(value)``, or None if that raises ValueError."""
    try:
        converted = float(value)
    except ValueError:
        converted = None
    return converted
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _convert_to_int(value):
    """Return ``int(value)``, or None if that raises ValueError."""
    try:
        converted = int(value)
    except ValueError:
        converted = None
    return converted
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _convert_to_list_int(value):
    """Convert a comma-separated string to a list of ints.

    Non-string inputs are also handled, so the "return None on error"
    contract holds:

    - A single int becomes a one-element list.
    - A list or tuple has each item converted with ``int``.
    - Any other value is converted via its string form.

    Returns:
        The list of ints, or None when any item cannot be converted.
    """
    if isinstance(value, int):
        return [int(value)]
    if isinstance(value, (list, tuple)):
        items = value
    else:
        # str() guards against non-str values, which have no .split().
        items = str(value).split(",")
    try:
        return [int(item) for item in items]
    except (TypeError, ValueError):
        return None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _convert_replacement_value(placeholder_type, replacement_value):
    """Convert replacement value to specified type.

    Args:
        placeholder_type: ``"float"``, ``"int"``, ``"bool"`` or
            ``"list_int"``; any other name returns the value unchanged.
        replacement_value: Value to convert; None is returned as-is.

    Returns:
        The converted value. Numeric/list conversions yield None on failure.
        For ``"bool"``, real bools pass through unchanged; any other value
        is True only if its string form equals ``"true"``
        (case-insensitive).
    """
    if replacement_value is None:
        return replacement_value
    if placeholder_type == "float":
        return _convert_to_float(replacement_value)
    if placeholder_type == "int":
        return _convert_to_int(replacement_value)
    if placeholder_type == "bool":
        # Values such as argparse store_true flags are already bools, and
        # bool has no .lower().
        if isinstance(replacement_value, bool):
            return replacement_value
        return str(replacement_value).lower() == "true"
    if placeholder_type == "list_int":
        return _convert_to_list_int(replacement_value)
    return replacement_value
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def replace_placeholder(element, replacements=None):
    """Recursively substitute ``<<name>>`` / ``<<type: name>>`` placeholders.

    Behaviour by type of ``element``:

    - dict: every value is processed and written back in place; the same
      dict is returned.
    - list: a new list of processed items is returned.
    - str: if the whole string is one placeholder, it becomes the typed
      result of ``_convert_replacement_value`` (type defaults to ``"str"``).
      Otherwise each placeholder is replaced by the ``str()`` of its
      converted value.
    - anything else: returned unchanged.

    A placeholder whose name is missing from ``replacements`` uses its own
    placeholder text as the value before type conversion.
    """
    mapping = {} if replacements is None else replacements

    if isinstance(element, dict):
        for key in element:
            element[key] = replace_placeholder(element[key], mapping)
        return element
    if isinstance(element, list):
        return [replace_placeholder(item, mapping) for item in element]
    if not isinstance(element, str):
        return element

    placeholder_re = re.compile(r"<<(?:(\w+): )?(\w+)>>")

    def resolve(match):
        type_name = match.group(1) or "str"
        raw_value = mapping.get(match.group(2), match.group(0))
        return _convert_replacement_value(type_name, raw_value)

    # A string that is exactly one placeholder keeps the converted type.
    whole = placeholder_re.fullmatch(element)
    if whole is not None:
        return resolve(whole)

    # Mixed text / several placeholders: substitute string forms.
    return placeholder_re.sub(lambda found: str(resolve(found)), element)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def load_and_merge_configs(
    config_path,
):
    """
    Loads configuration from the main file and merges included configurations
    while preserving the order.

    The first ``__include`` entry found anywhere in the main config
    (depth-first) names one path or a list of paths. Each file is loaded and
    deep-merged into the main config in turn; on conflicting keys the
    included file's value wins. Include paths are opened as given, so
    relative paths resolve against the current working directory. Empty
    include files are skipped. The ``__include`` entry stays in the result.
    """
    with open(config_path, "r") as file:
        # Key order is kept because Python dicts preserve insertion order.
        # NOTE: FullLoader can build some Python-specific objects; only use
        # this with trusted, local config files.
        main_config = yaml.load(file, Loader=yaml.FullLoader)

    # Check if there are included configs and process them
    includes = find_include_value(main_config, "__include")
    if isinstance(includes, str):
        # A bare path would otherwise be iterated character by character.
        includes = [includes] if includes else []
    for include_path in includes or []:
        with open(include_path, "r") as inc_file:
            included_config = yaml.load(inc_file, Loader=yaml.FullLoader)
        if included_config:
            deep_merge_configs(main_config, included_config)

    return main_config
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def copy_config(config, args):
    """
    Write ``config`` as YAML to ``<output_dir>/<job_id>.yml``.

    Keys are dumped in their current order (``sort_keys=False``). The
    written path is stored under ``args["copied_config_file"]``.

    Returns:
        The same ``args`` dict, updated in place.
    """
    output_dir, job_id = args["output_dir"], args["job_id"]
    target_file = os.path.join(output_dir, f"{job_id}.yml")
    with open(target_file, "w") as handle:
        # Keep the merged config's key order in the dump.
        yaml.dump(config, handle, sort_keys=False)
    args["copied_config_file"] = target_file
    return args
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def generate_bash_script(args):
    """
    Render the template at ``args["template"]`` and save it as a job script.

    The template text is filled with ``str.format(**args)``, so every
    ``{name}`` field must be a key of ``args`` and literal braces must be
    doubled. The result is written to ``<output_dir>/<job_id>.sh``.

    Returns:
        Path of the written script.
    """
    script_path = os.path.join(args["output_dir"], f"{args['job_id']}.sh")

    with open(args["template"], "r") as template_file:
        template_text = template_file.read()
    rendered = template_text.format(**args)

    with open(script_path, "w") as script_file:
        script_file.write(rendered)
    return script_path
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def submit_script(script_path):
    """
    Submits the given bash script to sbatch.

    Raises:
        subprocess.CalledProcessError: if ``sbatch`` exits with a non-zero
            status.
        FileNotFoundError: if the ``sbatch`` executable cannot be found.
    """
    # check=True so a rejected submission is not silently ignored.
    subprocess.run(["sbatch", script_path], check=True)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def main():
    """Prepare and (unless ``--dry``) submit a SLURM job.

    Steps:

    1. Parse the known CLI options; extra ``--key value`` pairs are merged
       into the argument dict.
    2. Derive compute resources, a timestamp job id and the output
       directory.
    3. Load and merge the YAML config, then substitute its ``<<...>>``
       placeholders from the argument dict.
    4. Write the resolved config and the rendered script into the output
       directory, then run ``sbatch`` on the script.
    """
    parser = argparse.ArgumentParser(
        description="Submit jobs with documentation of the YAML configuration."
    )
    parser.add_argument(
        "--config_file",
        type=str,
        required=True,
        help="Path to the YAML configuration file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help=(
            "Directory for job outputs. Defaults to "
            "experiments/<group_name>/<job_name>/<job_id>."
        ),
    )
    parser.add_argument(
        "--template",
        type=str,
        default="src/model/scripts/template.slurm",
        help="Path to the bash script template.",
    )
    parser.add_argument(
        "--dry",
        action="store_true",
        help="Only create files, do not submit the job.",
    )
    parser.add_argument(
        "--n_gpu", type=int, default=1, help="Number of GPUs to use."
    )
    parser.add_argument(
        "--time",
        type=str,
        default="00:10:00",
        help="Expected runtime in HH:MM:SS format.",
    )
    parser.add_argument(
        "--script",
        type=str,
        default="src/model/train.py",
        help="Script to run.",
    )
    parser.add_argument(
        "--job_name", type=str, default="v1", help="Name of the job."
    )
    parser.add_argument(
        "--group_name",
        type=str,
        default="debug",
        help="Name of the experiment group.",
    )
    parser.add_argument(
        "--project_name",
        type=str,
        default="NK-Landscape",
        help="Project to charge.",
    )
    parser.add_argument(
        "--image",
        type=str,
        default="/u/lumi/projects/llm-strategic-tuning/images/ai_nk_rl.sif",
        help="Apptainer image to use",
    )
    argv = sys.argv[1:]

    known_args, unknown_args = parser.parse_known_args(argv)
    args_dict = vars(known_args)
    args_dict.update(parse_unknown_args(unknown_args, argv))

    args_dict = define_compute_resources(args_dict)
    args_dict["job_id"] = generate_local_job_id()
    args_dict = define_output(args_dict)

    config = load_and_merge_configs(args_dict["config_file"])
    config = replace_placeholder(config, args_dict)
    args_dict = copy_config(config, args_dict)

    # Point {config_file} in the template at the written, resolved copy.
    args_dict["config_file"] = args_dict["copied_config_file"]
    generated_script = generate_bash_script(args_dict)

    if args_dict["dry"]:
        print(f"Generated script at {generated_script} without submission.")
    else:
        submit_script(generated_script)


if __name__ == "__main__":
    main()
|