ai_nk_cce-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (46)
  1. ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
  2. ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
  3. ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
  4. api/__init__.py +0 -0
  5. api/mpcdf_vllm.py +94 -0
  6. evals/nk_model.py +277 -0
  7. model/README.md +64 -0
  8. model/config/dataset_conv_v1.yml +9 -0
  9. model/config/dataset_conv_v2_m2.yml +9 -0
  10. model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
  11. model/config/dataset_debug.yml +9 -0
  12. model/config/dataset_v4_int_format.yml +9 -0
  13. model/config/dataset_v5.yml +9 -0
  14. model/config/inference.yml +7 -0
  15. model/config/train.yml +24 -0
  16. model/config/train_debug.yml +19 -0
  17. model/config/train_from_checkpoint.yml +24 -0
  18. model/config/train_from_checkpoint_debug.yml +19 -0
  19. model/config/train_grpo.yml +30 -0
  20. model/config/train_grpo_debug.yml +30 -0
  21. model/config/train_grpo_debug_vllm.yml +32 -0
  22. model/config.py +54 -0
  23. model/dataset.py +324 -0
  24. model/inference.py +51 -0
  25. model/nk_assistant.py +207 -0
  26. model/parser.py +70 -0
  27. model/run_slurm.py +335 -0
  28. model/score.ipynb +596 -0
  29. model/scripts/template.slurm +54 -0
  30. model/scripts/template_rl.slurm +54 -0
  31. model/train.py +293 -0
  32. nk_model/__init__.py +0 -0
  33. nk_model/assembler.py +112 -0
  34. nk_model/biased_prediction_agent.py +389 -0
  35. nk_model/dataset.py +434 -0
  36. nk_model/enums.py +21 -0
  37. nk_model/landscape_cache.py +149 -0
  38. nk_model/models.py +172 -0
  39. nk_model/nk_landscape.py +498 -0
  40. simulation/hill_climber_simulation.py +211 -0
  41. simulation/hill_climber_vs_ai_simulation.py +132 -0
  42. simulation/landscape_selection.py +179 -0
  43. utils/__init__.py +0 -0
  44. utils/binary_conversion.py +128 -0
  45. utils/logging.py +33 -0
  46. utils/utils.py +51 -0
model/nk_assistant.py ADDED
@@ -0,0 +1,207 @@
+ import json
+ from pathlib import Path
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from src.api.mpcdf_vllm import mpcdf_vllm_request
+ from src.model.config import NKAssistantConfig, TrainConfig
+ from src.model.parser import create_context, create_target, target_to_int
+
+
+ def load_model_and_tokenizer(model_path: str):
+     """Load the model and tokenizer from the given path."""
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+     except Exception:
+         # Fall back to the GPT-2 tokenizer if the checkpoint ships without one
+         tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     model = AutoModelForCausalLM.from_pretrained(model_path)
+
+     # Ensure pad_token is set
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Move model to GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     return model, tokenizer, device
+
+
+ class NKAssistant:
+     def __init__(
+         self, config: NKAssistantConfig, metadata: Optional[TrainConfig] = None
+     ):
+         if not config.use_mpcdf_vllm:
+             model, tokenizer, device = load_model_and_tokenizer(
+                 config.model_path
+             )
+         else:
+             # Generation is delegated to the remote vLLM endpoint
+             model = None
+             tokenizer = None
+             device = None
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = device
+         self.config = config
+         self.metadata = metadata
+         self.generation_params = config.generation_params
+
+     def _save_pretrained(self, save_directory: Path) -> None:
+         save_directory.mkdir(parents=True, exist_ok=True)
+         self.model.save_pretrained(save_directory)
+         self.tokenizer.save_pretrained(save_directory)
+         if self.metadata:
+             with open(save_directory / "metadata.json", "w") as f:
+                 json.dump(self.metadata.model_dump(), f)
+
+     @classmethod
+     def from_pretrained(
+         cls, pretrained_model_name_or_path: Union[str, Path], **kwargs
+     ):
+         # Use Hugging Face logic to locate the local or remote files
+         model_path = Path(pretrained_model_name_or_path)
+
+         if not model_path.exists():
+             # Download files from the Hub into the cache
+             from huggingface_hub import snapshot_download
+
+             model_path = Path(
+                 snapshot_download(repo_id=pretrained_model_name_or_path)
+             )
+
+         metadata_path = model_path / "metadata.json"
+         metadata = None
+         if metadata_path.exists():
+             with open(metadata_path) as f:
+                 metadata = TrainConfig(**json.load(f))
+
+         if metadata is None:
+             raise FileNotFoundError(
+                 f"metadata.json not found in {model_path}; cannot "
+                 "reconstruct the assistant config"
+             )
+         config = metadata.assistant_config
+         config.model_path = str(model_path)
+
+         return cls(config=config, metadata=metadata)
+
+     @staticmethod
+     def create_text_from_row(row, include_payoff=True, include_target=True):
+         sample_payoff = np.array(row["payoffs"])[row["sample_idx"]]
+
+         context = create_context(
+             n=row["n"],
+             k=row["k"],
+             power_scale=row["power_scale"],
+             sample_idxs=row["sample_idx"],
+             origin_idx=row["origin_idx"],
+             hamming_distance=row["hamming_distance"],
+             payoff=sample_payoff,
+             include_payoff=include_payoff,
+         )
+         if include_target:
+             target = create_target(target_idx=row["target_idx"])
+             return {"context": context, "target": target}
+         return {"context": context}
+
+     @staticmethod
+     def create_prompt_for_rl_from_row(row):
+         sample_payoff = np.array(row["payoffs"])[row["sample_idx"]]
+         context = create_context(
+             n=row["n"],
+             k=row["k"],
+             power_scale=row["power_scale"],
+             sample_idxs=row["sample_idx"],
+             origin_idx=row["origin_idx"],
+             hamming_distance=row["hamming_distance"],
+             payoff=sample_payoff,
+             include_payoff=True,
+         )
+         return {"prompt": context}
+
+     def suggest_target_from_context(
+         self, context: str, target_length: int
+     ) -> str:
+         # Tokenize input
+         input_ids = self.tokenizer.encode(context, return_tensors="pt").to(
+             self.device
+         )
+         n_input = input_ids.shape[1]
+
+         # Generate output
+         with torch.no_grad():
+             output = self.model.generate(
+                 input_ids,
+                 max_length=n_input + target_length,
+                 pad_token_id=self.tokenizer.eos_token_id,
+                 **self.generation_params,
+             )
+
+         # Decode only the new tokens (exclude the input tokens)
+         suggestion = self.tokenizer.decode(
+             output[0, n_input:], skip_special_tokens=True
+         )
+
+         return suggestion
+
+     def suggest_from_row(self, row):
+         context = self.create_text_from_row(row, include_target=False)[
+             "context"
+         ]
+         n = row["n"]
+         # n bits plus n - 1 separating commas
+         target_length = n * 2 - 1
+         if self.config.use_mpcdf_vllm:
+             binary_string = mpcdf_vllm_request(
+                 prompt=context,
+                 max_tokens=target_length,
+                 **self.generation_params,
+             )
+         else:
+             binary_string = self.suggest_target_from_context(
+                 context, target_length
+             )
+         return {"suggestion": binary_string}
+
+     def suggest(
+         self,
+         n: int,
+         k: int,
+         power_scale: float,
+         hamming_distance: int,
+         sample_idxs: List[int],
+         origin_idx: int,
+         payoffs: List[float],
+     ) -> int:
+         # Create context from input data
+         context = create_context(
+             n=n,
+             k=k,
+             power_scale=power_scale,
+             sample_idxs=sample_idxs,
+             origin_idx=origin_idx,
+             hamming_distance=hamming_distance,
+             payoff=payoffs,
+         )
+
+         # Max suggestion length: n bits plus n - 1 separating commas
+         target_length = n * 2 - 1
+
+         if self.config.use_mpcdf_vllm:
+             binary_string = mpcdf_vllm_request(
+                 prompt=context,
+                 max_tokens=target_length,
+                 **self.generation_params,
+             )
+         else:
+             binary_string = self.suggest_target_from_context(
+                 context, target_length
+             )
+
+         # Parse the suggestion to an integer
+         try:
+             suggestion_int = target_to_int(binary_string)
+             return suggestion_int
+         except Exception as e:
+             print(f"Error parsing suggestion '{binary_string}': {e}")
+             return -1
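A minimal usage sketch, assuming NKAssistantConfig accepts the fields referenced above as keyword arguments (its definition lives in model/config.py and is not shown in this diff); the checkpoint path and generation parameters are illustrative:

    from src.model.config import NKAssistantConfig
    from src.model.nk_assistant import NKAssistant

    # Hypothetical config: a local checkpoint, no remote vLLM endpoint
    config = NKAssistantConfig(
        model_path="experiments/debug/v1/checkpoint",
        use_mpcdf_vllm=False,
        generation_params={"do_sample": True, "temperature": 0.7},
    )
    assistant = NKAssistant(config)

    # Ask for a move on a small n=4 landscape; returns the suggested
    # configuration as an integer index, or -1 if parsing fails.
    suggestion = assistant.suggest(
        n=4,
        k=1,
        power_scale=1.0,
        hamming_distance=1,
        sample_idxs=[3, 5, 9],
        origin_idx=5,
        payoffs=[0.42, 0.61, 0.37],
    )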
model/parser.py ADDED
@@ -0,0 +1,70 @@
+ import logging
+ from typing import List, Optional, Union
+
+ from src.utils.binary_conversion import binary_str_to_int, int_to_binary_str
+
+ logger = logging.getLogger()
+
+
+ def create_context(
+     *,
+     n: int,
+     k: int,
+     power_scale: float,
+     sample_idxs: List[int],
+     origin_idx: int,
+     hamming_distance: int,
+     payoff: Optional[List[Union[float, int]]] = None,
+     include_payoff=True,
+     **kwargs,
+ ) -> str:
+     # Header with the landscape parameters
+     doc = ""
+     doc += f"n: {n}\n"
+     doc += f"k: {k}\n"
+     doc += f"p: {power_scale:.2f}\n"
+     doc += f"h: {hamming_distance}\n"
+     doc += "\n"
+     # Sample rows
+     if include_payoff:
+         assert (
+             payoff is not None
+         ), "Payoff values must be provided when include_payoff is True"
+         doc += "sample,payoff\n"
+         doc += "\n".join(
+             [
+                 f"{int_to_binary_str(idx)},{p}"
+                 for idx, p in zip(sample_idxs, payoff)
+             ]
+         )
+     else:
+         doc += "sample\n"
+         doc += "\n".join([int_to_binary_str(idx) for idx in sample_idxs])
+     doc += "\n"
+     doc += "\n"
+     # The user's current (origin) position
+     doc += "user\n"
+     doc += int_to_binary_str(origin_idx) + "\n"
+     doc += "\n"
+     # Assistant turn; the target is appended after this header
+     doc += "assistant\n"
+     return doc
+
+
+ def create_target(*, target_idx: int) -> str:
+     """
+     Create a target string from a target index.
+     """
+     return int_to_binary_str(target_idx)
+
+
+ def target_to_int(target):
+     """Convert target string to integer (deprecated, use binary_str_to_int)."""
+     int_target = binary_str_to_int(target)
+     logger.debug(f"target: {target}, int_target: {int_target}")
+     return int_target
+
+
+ def eval_target(target, ranks):
+     """Return the rank of the configuration encoded by the target string."""
+     target_idx = binary_str_to_int(target)
+     return ranks[target_idx]
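For intuition, a sketch of the prompt create_context renders, assuming int_to_binary_str (defined in utils/binary_conversion.py, not shown here) prints an index as comma-separated bits, consistent with the target_length = n * 2 - 1 used in nk_assistant.py:

    context = create_context(
        n=4,
        k=1,
        power_scale=1.0,
        sample_idxs=[3, 5],
        origin_idx=5,
        hamming_distance=1,
        payoff=[0.42, 0.61],
    )
    # Hypothetical rendering (bit width assumed to be n = 4):
    # n: 4
    # k: 1
    # p: 1.00
    # h: 1
    #
    # sample,payoff
    # 0,0,1,1,0.42
    # 0,1,0,1,0.61
    #
    # user
    # 0,1,0,1
    #
    # assistant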
model/run_slurm.py ADDED
@@ -0,0 +1,335 @@
+ import argparse
+ import os
+ import re
+
+ # import shutil
+ import subprocess
+ import sys
+ from datetime import datetime
+
+ # from collections import OrderedDict
+ import yaml
+
+
+ def generate_local_job_id():
+     """
+     Generate a local job ID from the current timestamp.
+     """
+     return datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
+
+
+ def parse_unknown_args(unknown_args, argv):
+     # Collect unknown --key value pairs into a dictionary
+     dynamic_args = {}
+     for arg in unknown_args:
+         if arg.startswith("--"):
+             key = arg.lstrip("-")
+             # Assume the next item in argv is the value
+             idx = argv.index(arg)
+             if idx + 1 < len(argv):
+                 dynamic_args[key] = argv[idx + 1]
+
+     return dynamic_args
+
+
+ def define_output(args):
+     # Derive the output directory if none was given
+     if args["output_dir"] is None:
+         args["output_dir"] = os.path.join(
+             "experiments",
+             args["group_name"],
+             args["job_name"],
+             args["job_id"],
+         )
+     if not os.path.exists(args["output_dir"]):
+         os.makedirs(args["output_dir"])
+     return args
+
+
+ def define_compute_resources(args):
+     # Jobs larger than one node are split into nodes of 4 GPUs each
+     if args["n_gpu"] > 4:
+         assert args["n_gpu"] % 4 == 0
+         n_nodes = args["n_gpu"] // 4
+         n_gpu = 4
+     else:
+         n_nodes = 1
+         n_gpu = args["n_gpu"]
+
+     if n_gpu >= 4:
+         memory = 0  # --mem=0 requests all memory on the node in SLURM
+     else:
+         memory = 125000 * n_gpu  # 125 GB per GPU, in MB
+     partition = "gpu"
+     cpu = n_gpu * 18  # 18 cores per GPU
+
+     args = {
+         **args,
+         "n_nodes": n_nodes,
+         "n_gpu": n_gpu,
+         "n_cpu": cpu,
+         "partition": partition,
+         "memory": memory,
+     }
+     return args
+
+
+ def deep_merge_configs(main_config, included_config):
+     """
+     Deep-merge two configurations in place. The main config keeps its key
+     order; values from the included config overwrite colliding ones.
+     """
+     for key, value in included_config.items():
+         if (
+             key in main_config
+             and isinstance(main_config[key], dict)
+             and isinstance(value, dict)
+         ):
+             deep_merge_configs(main_config[key], value)
+         else:
+             main_config[key] = value
+
+
+ def find_include_value(data, target_key):
+     """Recursively search a nested dict for target_key and return its value."""
+     if isinstance(data, dict):
+         for key, value in data.items():
+             if key == target_key:
+                 return value
+             elif isinstance(value, dict):
+                 result = find_include_value(value, target_key)
+                 if result is not None:
+                     return result
+     return None
+
+
+ def _convert_to_float(value):
+     """Convert value to float, return None on error."""
+     try:
+         return float(value)
+     except ValueError:
+         return None
+
+
+ def _convert_to_int(value):
+     """Convert value to int, return None on error."""
+     try:
+         return int(value)
+     except ValueError:
+         return None
+
+
+ def _convert_to_list_int(value):
+     """Convert comma-separated string to list of ints, return None on error."""
+     try:
+         return list(map(int, value.split(",")))
+     except ValueError:
+         return None
+
+
+ def _convert_replacement_value(placeholder_type, replacement_value):
+     """Convert replacement value to the specified type."""
+     if replacement_value is None:
+         return replacement_value
+     if placeholder_type == "float":
+         return _convert_to_float(replacement_value)
+     if placeholder_type == "int":
+         return _convert_to_int(replacement_value)
+     if placeholder_type == "bool":
+         return replacement_value.lower() == "true"
+     if placeholder_type == "list_int":
+         return _convert_to_list_int(replacement_value)
+     return replacement_value
+
+
+ def replace_placeholder(element, replacements=None):
+     if replacements is None:
+         replacements = {}
+
+     if isinstance(element, dict):
+         for key, value in element.items():
+             element[key] = replace_placeholder(value, replacements)
+     elif isinstance(element, list):
+         return [replace_placeholder(item, replacements) for item in element]
+     elif isinstance(element, str):
+
+         def replacement_function(match):
+             placeholder_type = match.group(1) if match.group(1) else "str"
+             placeholder_variable = match.group(2)
+             replacement_value = replacements.get(
+                 placeholder_variable, match.group(0)
+             )
+             return _convert_replacement_value(
+                 placeholder_type, replacement_value
+             )
+
+         pattern = r"<<(?:(\w+): )?(\w+)>>"
+         matches = list(re.finditer(pattern, element))
+
+         # If exactly one match spans the entire string, perform type
+         # conversion
+         if len(matches) == 1 and matches[0].span() == (0, len(element)):
+             return replacement_function(matches[0])
+
+         # For strings with multiple placeholders or additional text,
+         # replace without type conversion
+         def string_replacement_function(match):
+             return str(replacement_function(match))
+
+         result = re.sub(pattern, string_replacement_function, element)
+         return result
+
+     return element
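# A minimal sketch of the "<<type: name>>" placeholder mechanism above
# (the config keys and values are hypothetical, not taken from the package):
example_config = {
    "lr": "<<float: lr>>",
    "devices": "<<list_int: devices>>",
    "run_name": "run-<<job_name>>",
}
replace_placeholder(
    example_config, {"lr": "3e-4", "devices": "0,1", "job_name": "v1"}
)
# example_config is mutated in place to:
# {"lr": 0.0003, "devices": [0, 1], "run_name": "run-v1"}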
+
+
+ def load_and_merge_configs(config_path):
+     """
+     Loads configuration from the main file and merges included configurations
+     while preserving the order.
+     """
+     with open(config_path, "r") as file:
+         # Load the main configuration with FullLoader to preserve the order
+         main_config = yaml.load(file, Loader=yaml.FullLoader)
+
+     # Check if there are included configs and process them
+     includes = find_include_value(main_config, "__include")
+     if includes is not None:
+         for include_path in includes:
+             with open(include_path, "r") as inc_file:
+                 included_config = yaml.load(inc_file, Loader=yaml.FullLoader)
+             deep_merge_configs(main_config, included_config)
+
+     return main_config
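# A minimal sketch of the __include mechanism (the YAML keys shown are
# hypothetical; only the file names appear in this package):
#
#   # model/config/train.yml
#   __include:
#     - model/config/dataset_v5.yml
#   trainer:
#     epochs: 10
#
# Keys from each included file are deep-merged into the main config, with
# included values overwriting colliding ones:
merged_config = load_and_merge_configs("model/config/train.yml")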
+
+
+ def copy_config(config, args):
+     """
+     Write the merged configuration to the job-specific directory,
+     preserving the order of parameters from the main config file.
+     """
+     dest_filename = f"{args['job_id']}.yml"
+     dest_path = os.path.join(args["output_dir"], dest_filename)
+
+     with open(dest_path, "w") as file:
+         # sort_keys=False prevents sorting keys on dump
+         yaml.dump(config, file, sort_keys=False)
+     args["copied_config_file"] = dest_path
+     return args
+
+
+ def generate_bash_script(args):
+     """
+     Read the bash template file, fill its placeholders from args, and
+     write the resulting script to the job-specific directory.
+     """
+     output_path = os.path.join(args["output_dir"], f"{args['job_id']}.sh")
+
+     with open(args["template"], "r") as file:
+         script = file.read().format(**args)
+
+     with open(output_path, "w") as file:
+         file.write(script)
+     return output_path
+
+
+ def submit_script(script_path):
+     """
+     Submit the given bash script via sbatch.
+     """
+     subprocess.run(["sbatch", script_path])
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Submit jobs with documentation of the YAML configuration."
+     )
+     parser.add_argument(
+         "--config_file",
+         type=str,
+         required=True,
+         help="Path to the YAML configuration file.",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default=None,
+         help=(
+             "Path for the output data. If None, a directory derived from "
+             "the group name, job name, and job ID is used."
+         ),
+     )
+     parser.add_argument(
+         "--template",
+         type=str,
+         default="src/model/scripts/template.slurm",
+         help="Path to the bash script template.",
+     )
+     parser.add_argument(
+         "--dry",
+         action="store_true",
+         help="Only create files, do not submit the job.",
+     )
+     parser.add_argument(
+         "--n_gpu", type=int, default=1, help="Number of GPUs to use."
+     )
+     parser.add_argument(
+         "--time",
+         type=str,
+         default="00:10:00",
+         help="Expected runtime in HH:MM:SS format.",
+     )
+     parser.add_argument(
+         "--script",
+         type=str,
+         default="src/model/train.py",
+         help="Script to run.",
+     )
+     parser.add_argument(
+         "--job_name", type=str, default="v1", help="Name of the job."
+     )
+     parser.add_argument(
+         "--group_name",
+         type=str,
+         default="debug",
+         help="Name of the experiment group.",
+     )
+     parser.add_argument(
+         "--project_name",
+         type=str,
+         default="NK-Landscape",
+         help="Project to charge.",
+     )
+     parser.add_argument(
+         "--image",
+         type=str,
+         default="/u/lumi/projects/llm-strategic-tuning/images/ai_nk_rl.sif",
+         help="Apptainer image to use.",
+     )
+     argv = sys.argv[1:]
+
+     known_args, unknown_args = parser.parse_known_args(argv)
+     dynamic_args = parse_unknown_args(unknown_args, argv)
+     args_dict = vars(known_args)
+     args_dict.update(dynamic_args)
+
+     args_dict = define_compute_resources(args_dict)
+     args_dict["job_id"] = generate_local_job_id()
+
+     args_dict = define_output(args_dict)
+     config = load_and_merge_configs(args_dict["config_file"])
+     config = replace_placeholder(config, args_dict)
+     args_dict = copy_config(config, args_dict)
+
+     args_dict["config_file"] = args_dict["copied_config_file"]
+     generated_script = generate_bash_script(args_dict)
+
+     if not args_dict["dry"]:
+         submit_script(generated_script)
+     else:
+         print(f"Generated script at {generated_script} without submission.")
+
+
+ if __name__ == "__main__":
+     main()
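A hypothetical invocation (the template and training script are the shipped defaults; the trailing --lr flag only illustrates how unknown arguments are forwarded into the placeholder replacement):

    python src/model/run_slurm.py --config_file model/config/train.yml \
        --n_gpu 8 --time 12:00:00 --group_name nk --job_name v1 --lr 3e-4 --dry

With --dry set, the merged config and the generated .sh script are written under experiments/nk/v1/<job_id>/ without being submitted to sbatch; the 8 GPUs are split into 2 nodes of 4 GPUs each by define_compute_resources.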