ai-nk-cce 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
  2. ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
  3. ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
  4. api/__init__.py +0 -0
  5. api/mpcdf_vllm.py +94 -0
  6. evals/nk_model.py +277 -0
  7. model/README.md +64 -0
  8. model/config/dataset_conv_v1.yml +9 -0
  9. model/config/dataset_conv_v2_m2.yml +9 -0
  10. model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
  11. model/config/dataset_debug.yml +9 -0
  12. model/config/dataset_v4_int_format.yml +9 -0
  13. model/config/dataset_v5.yml +9 -0
  14. model/config/inference.yml +7 -0
  15. model/config/train.yml +24 -0
  16. model/config/train_debug.yml +19 -0
  17. model/config/train_from_checkpoint.yml +24 -0
  18. model/config/train_from_checkpoint_debug.yml +19 -0
  19. model/config/train_grpo.yml +30 -0
  20. model/config/train_grpo_debug.yml +30 -0
  21. model/config/train_grpo_debug_vllm.yml +32 -0
  22. model/config.py +54 -0
  23. model/dataset.py +324 -0
  24. model/inference.py +51 -0
  25. model/nk_assistant.py +207 -0
  26. model/parser.py +70 -0
  27. model/run_slurm.py +335 -0
  28. model/score.ipynb +596 -0
  29. model/scripts/template.slurm +54 -0
  30. model/scripts/template_rl.slurm +54 -0
  31. model/train.py +293 -0
  32. nk_model/__init__.py +0 -0
  33. nk_model/assembler.py +112 -0
  34. nk_model/biased_prediction_agent.py +389 -0
  35. nk_model/dataset.py +434 -0
  36. nk_model/enums.py +21 -0
  37. nk_model/landscape_cache.py +149 -0
  38. nk_model/models.py +172 -0
  39. nk_model/nk_landscape.py +498 -0
  40. simulation/hill_climber_simulation.py +211 -0
  41. simulation/hill_climber_vs_ai_simulation.py +132 -0
  42. simulation/landscape_selection.py +179 -0
  43. utils/__init__.py +0 -0
  44. utils/binary_conversion.py +128 -0
  45. utils/logging.py +33 -0
  46. utils/utils.py +51 -0
@@ -0,0 +1,54 @@
+ #!/bin/bash -l
+ #SBATCH --output {output_dir}/slurm-%x-%j.out
+ #SBATCH --error {output_dir}/slurm-%x-%j.out
+ #SBATCH --chdir ./
+ #SBATCH --job-name {job_name}/{job_id}
+ #
+ #SBATCH --nodes={n_nodes}
+ #SBATCH --tasks-per-node=1
+ #SBATCH --cpus-per-task={n_cpu}
+ #SBATCH --mem={memory}
+ #
+ #SBATCH --constraint="gpu"
+ #SBATCH --gres=gpu:a100:{n_gpu}
+ #SBATCH --partition=gpu
+
+ # Wall clock limit (max is 24 hours):
+ #SBATCH --time={time}
+
+ module purge
+ module load apptainer
+
+ source .env
+
+ # Create the Hugging Face cache directory if it doesn't exist
+ mkdir -p ~/.cache/huggingface
+
+ echo "Running training using the image: {image}"
+ echo "Running training using the config: {config_file}"
+
+ srun apptainer exec \
+     --nv \
+     --contain \
+     --cleanenv \
+     --pwd /root/llm-strategic-tuning \
+     --bind .:/root/llm-strategic-tuning \
+     --bind ~/.cache/huggingface:/root/.cache/huggingface \
+     --bind /ptmp:/ptmp \
+     --env HUGGING_FACE_HUB_TOKEN="$HUGGINGFACE_TOKEN" \
+     --env WANDB_API_KEY="$WANDB_API_KEY" \
+     --env WANDB_ENTITY="chm-ml" \
+     --env WANDB_PROJECT="{project_name}" \
+     --env WANDB_RUN_GROUP="{group_name}" \
+     --env WANDB_NAME="{job_name}/{job_id}" \
+     --env NCCL_DEBUG="INFO" \
+     --env NCCL_BLOCKING_WAIT="0" \
+     --env HF_HOME="/root/.cache/huggingface" \
+     {image} \
+     bash -c "python -m torch.distributed.run \
+         --nnodes=\"$SLURM_NNODES\" \
+         --nproc-per-node=gpu \
+         --rdzv-id=\"$SLURM_JOBID\" \
+         --rdzv-endpoint=\$(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | head -n 1) \
+         --rdzv-backend=\"c10d\" \
+         {script} --config {config_file} --rl"
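
The curly-brace fields above are Python str.format placeholders; model/run_slurm.py (listed above) presumably renders them before submission. A minimal sketch of such a rendering step, with illustrative values that are assumptions rather than defaults shipped in the package:

    from pathlib import Path

    # Hypothetical rendering of the template; the real logic lives in
    # model/run_slurm.py, and these field values are illustrative only.
    template = Path("model/scripts/template_rl.slurm").read_text()
    job_script = template.format(
        output_dir="./runs/exp1",
        job_name="grpo",
        job_id="001",
        n_nodes=1,
        n_cpu=18,
        memory="120G",
        n_gpu=4,
        time="24:00:00",
        image="/ptmp/containers/train.sif",
        project_name="nk-grpo",
        group_name="debug",
        script="model/train.py",
        config_file="model/config/train_grpo.yml",
    )
    Path("job.slurm").write_text(job_script)  # then: sbatch job.slurm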
model/train.py ADDED
@@ -0,0 +1,293 @@
+ import argparse
+ import logging
+ from pathlib import Path
+ from typing import List
+
+ import numpy as np
+ from datasets import load_from_disk
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import Trainer, TrainingArguments
+ from trl import GRPOConfig, GRPOTrainer
+
+ from src.model.config import TrainConfig
+ from src.model.nk_assistant import NKAssistant
+ from src.utils.utils import load_config_from_yaml
+
+ logger = logging.getLogger(__name__)
+
+
+ def tokenize_function(row, tokenizer):
+     context = row["context"]
+     target = row["target"]
+     full_input = context + target
+
+     # Tokenize the context separately so its length is known for masking
+     context_tokens = tokenizer(context, add_special_tokens=False)["input_ids"]
+
+     # Tokenize the combined text; padding is deferred to the data collator
+     tokenized = tokenizer(full_input, padding=False, return_tensors="np")
+
+     # Create labels and mask the context part
+     labels = tokenized["input_ids"].copy()
+     context_length = len(context_tokens)
+
+     # Mask context tokens by setting them to -100, so the loss is computed
+     # only on the target tokens
+     labels[:, :context_length] = -100
+
+     tokenized["labels"] = labels
+
+     return {
+         "input_ids": tokenized["input_ids"].squeeze(0),
+         "attention_mask": tokenized["attention_mask"].squeeze(0),
+         "labels": labels.squeeze(0),
+     }
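
As a quick illustration of the masking, a minimal sketch assuming any HuggingFace tokenizer ("gpt2" here is an arbitrary stand-in, not the model used by this package). Note the approach assumes the context's tokenization is a prefix of the combined tokenization, which holds when the context ends on a clean token boundary:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    row = {"context": "state: ", "target": "1,0,1"}
    out = tokenize_function(row, tokenizer)
    # Context positions are -100 and ignored by the cross-entropy loss;
    # only the target positions contribute to training.
    print(out["labels"])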
+
+
+ def custom_data_collator(features, tokenizer):
+     input_ids = [f["input_ids"].clone().detach() for f in features]
+     attention_mask = [f["attention_mask"].clone().detach() for f in features]
+     labels = [f["labels"].clone().detach() for f in features]
+
+     batch = {
+         "input_ids": pad_sequence(
+             input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
+         ),
+         "attention_mask": pad_sequence(
+             attention_mask, batch_first=True, padding_value=0
+         ),
+         "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
+     }
+     return batch
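
For reference, a toy batch run through the collator above, showing how ragged sequences are padded (the tokenizer is stubbed here purely for illustration, exposing only the one attribute the collator reads):

    import torch

    class _Tok:
        # Stand-in exposing only the attribute the collator needs
        pad_token_id = 0

    features = [
        {
            "input_ids": torch.tensor([5, 6]),
            "attention_mask": torch.tensor([1, 1]),
            "labels": torch.tensor([-100, 6]),
        },
        {
            "input_ids": torch.tensor([7]),
            "attention_mask": torch.tensor([1]),
            "labels": torch.tensor([7]),
        },
    ]
    batch = custom_data_collator(features, _Tok())
    print(batch["input_ids"])  # tensor([[5, 6], [7, 0]])
    print(batch["labels"])     # tensor([[-100, 6], [7, -100]])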
+
+
+ def target_rank_reward_function(
+     completions: List[str], ranks: List[List[float]], **kwargs
+ ) -> List[float]:
+     """
+     Reward function that evaluates the rank of a target completion based on
+     the provided ranks.
+
+     Args:
+         completions (List[str]): List of generated target completions in
+             binary string format
+         ranks (List[List[float]]): One list of rank scores per completion,
+             holding one value for each possible target
+         **kwargs: Additional arguments passed from the trainer
+
+     Returns:
+         List[float]: List of reward values for each completion, where each
+             reward is:
+             - The rank value from ranks if the completion is valid
+             - A negative value whose magnitude depends on the type of
+               format violation
+     """
+
+     rewards = []
+     for completion, rank_scores in zip(completions, ranks):
+         try:
+             int_target = int(completion.replace(",", ""), 2)
+             reward = rank_scores[int_target]
+             rewards.append(reward)
+         except Exception as e:
+             # First, check whether the completion contains only valid
+             # characters (0, 1, comma)
+             valid_chars = set(["0", "1", ","])
+             if not all(c in valid_chars for c in completion):
+                 logger.error(
+                     f"Error evaluating completion: contains invalid "
+                     f"characters: {completion[:50]}..."
+                 )
+                 rewards.append(-2.00)  # Severe penalty for invalid chars
+                 continue
+
+             # Check whether it is a valid comma-separated binary string.
+             # The number of parts should match the number of bits (e.g.,
+             # 8 bits for n=8), which can be inferred from the size of the
+             # rank_scores array
+             parts = completion.split(",")
+             expected_bits = (
+                 int(np.log2(len(rank_scores))) if len(rank_scores) > 0 else 8
+             )
+             if len(parts) != expected_bits:
+                 logger.error(
+                     f"Error evaluating completion: wrong length: "
+                     f"{len(parts)} vs expected {expected_bits}"
+                 )
+                 rewards.append(-1.50)  # Penalty for wrong length
+                 continue
+
+             logger.error(f"Error evaluating completion: {e}")
+             rewards.append(-1.00)  # Generic penalty for other errors
+     return rewards
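
A worked example of the reward scheme on a toy 2-bit landscape (four possible targets, so each per-completion rank list has four entries):

    completions = ["1,0", "2,0", "1,0,1"]
    ranks = [[0.0, 0.25, 0.5, 1.0]] * 3

    # "1,0"   -> int("10", 2) = 2          -> reward 0.5
    # "2,0"   -> invalid character "2"     -> reward -2.0
    # "1,0,1" -> 3 parts vs 2 expected     -> reward -1.5
    print(target_rank_reward_function(completions, ranks))
    # [0.5, -2.0, -1.5]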
+
+
+ def train(config: TrainConfig):
+     print("Starting supervised learning training...")
+     print(f"Training with config: {config}")
+     print(f"Dataset file: {config.dataset_file}")
+     print(f"Trainer args: {config.trainer_args}")
+     print(f"Assistant config: {config.assistant_config}")
+     print(f"Final model path: {config.final_model_path}")
+
+     ds = load_from_disk(config.dataset_file)
+
+     # A num_train_epochs value below 1 doubles as a dataset fraction
+     if config.trainer_args["num_train_epochs"] < 1:
+         for split in ds.keys():
+             print(
+                 f"Taking fraction "
+                 f"{config.trainer_args['num_train_epochs']} of dataset "
+                 f"partition {split}"
+             )
+             fraction = int(
+                 len(ds[split]) * config.trainer_args["num_train_epochs"]
+             )
+             ds[split] = ds[split].take(fraction)
+             print(
+                 f"Subsampled dataset partition {split} has "
+                 f"{len(ds[split])} samples"
+             )
+
+     ass = NKAssistant(config.assistant_config, metadata=config)
+
+     # Create text documents to train on
+     print("Creating and adding prompts to dataset")
+     ds = ds.map(ass.create_text_from_row)
+
+     # Tokenize text
+     print("Tokenizing prompts")
+
+     def tokenize_function_with_tokenizer(row):
+         return tokenize_function(row, ass.tokenizer)
+
+     ds = ds.map(tokenize_function_with_tokenizer, batched=False)
+
+     # Turn into torch tensors
+     ds.set_format(
+         type="torch",
+         columns=["input_ids", "attention_mask", "labels"],
+     )
+
+     # Set up the collator for padding of batches
+     def data_collator(features):
+         return custom_data_collator(features, ass.tokenizer)
+
+     # Initialize trainer with default values
+     trainer = Trainer(
+         model=ass.model,
+         args=TrainingArguments(**config.trainer_args),
+         train_dataset=ds["train"],
+         eval_dataset=ds["test"],
+         data_collator=data_collator,
+         tokenizer=ass.tokenizer,
+     )
+
+     # Train model
+     trainer.train()
+
+     ass._save_pretrained(Path(config.final_model_path))
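
Note the dual role of num_train_epochs here: a value below 1 both subsamples each split and, because trainer_args is forwarded unchanged to TrainingArguments, also shortens the training run itself. A hedged sketch of what a config's trainer_args might contain (field values are illustrative assumptions, not taken from the shipped YAML files):

    trainer_args = {
        "output_dir": "./runs/sft",
        # < 1: take 25% of each split AND train a quarter epoch over it
        "num_train_epochs": 0.25,
        "per_device_train_batch_size": 8,
    }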
+
+
+ def train_reinforcement(config: TrainConfig):
+     print("Starting reinforcement learning training...")
+     print(f"Training with config: {config}")
+     print(f"Dataset file: {config.dataset_file}")
+     print(f"Trainer args: {config.trainer_args}")
+     print(f"Assistant config: {config.assistant_config}")
+     print(f"Final model path: {config.final_model_path}")
+
+     ds = load_from_disk(config.dataset_file)
+
+     for split in ds.keys():
+         logger.debug(
+             f"Configured fraction for {split} split: "
+             f"{config.trainer_args['num_train_epochs']}"
+         )
+         fraction = int(
+             len(ds[split]) * config.trainer_args["num_train_epochs"]
+         )
+         if split != "train":
+             # steps = samples * grpo_creation_steps / (batch_size * devices)
+             # samples = wished_steps * batch_size * devices /
+             #     grpo_creation_steps
+             fraction = int(min(600 * 24 * 4 / 8, fraction))
+
+         logger.debug(f"Fraction of {fraction} samples for {split} split")
+         ds[split] = ds[split].take(fraction)
+         logger.debug(
+             f"Subsampled dataset partition {split} has "
+             f"{len(ds[split])} samples"
+         )
+
+     ass = NKAssistant(config.assistant_config, metadata=config)
+
+     # Create prompts for reinforcement learning
+     print("Creating and adding prompts to dataset for RL")
+     ds = ds.map(ass.create_prompt_for_rl_from_row)
+
+     print("Successfully completed prompt creation mapping.")
+
+     # Turn into torch tensors while keeping the prompt field
+     ds.set_format(
+         type="torch",
+         columns=["prompt", "ranks"],
+     )
+
+     # Initialize GRPOConfig from the configured trainer args
+     grpo_config = GRPOConfig(**config.trainer_args)
+
+     print("Initializing GRPOTrainer...")
+     # Initialize GRPOTrainer
+     trainer = GRPOTrainer(
+         model=ass.model,
+         args=grpo_config,
+         train_dataset=ds["train"],
+         eval_dataset=ds["test"],
+         reward_funcs=target_rank_reward_function,
+         processing_class=ass.tokenizer,
+     )
+
+     print("Starting training...")
+     # Train model
+     trainer.train()
+
+     print("Saving model...")
+     ass._save_pretrained(Path(config.final_model_path))
+
+
+ if __name__ == "__main__":
+     # Parse command-line arguments
+     parser = argparse.ArgumentParser(
+         description="Train a model with the given config."
+     )
+     parser.add_argument(
+         "--config", type=str, required=True, help="Path to YAML config file"
+     )
+     parser.add_argument(
+         "--rl",
+         action="store_true",
+         help=(
+             "Use reinforcement learning training instead of supervised "
+             "learning"
+         ),
+     )
+     parser.add_argument(
+         "--rl_method",
+         type=str,
+         default="grpo",
+         choices=["grpo"],
+         help="Reinforcement learning method to use (default: grpo)",
+     )
+     args = parser.parse_args()
+
+     # Load config from YAML file
+     config = load_config_from_yaml(args.config, TrainConfig)
+
+     # Train the model using the appropriate method
+     if args.rl:
+         print("Training with reinforcement learning")
+         train_reinforcement(config)
+     else:
+         print("Training with supervised learning")
+         train(config)
+
+     print("Training completed successfully.")
nk_model/__init__.py ADDED
File without changes
nk_model/assembler.py ADDED
@@ -0,0 +1,112 @@
+ """Utilities for assembling sequences (vectors) and simple bit operations.
+
+ This module provides:
+ - Bit helper (`reverse01`) for 0/1 bit values
+ - Random selection helper (`random_elem`)
+ - Simple assemblers (`assembler_v1`, `assembler_v1_v2`, `assembler_v1_sym`)
+ - `ensemble_builder` to grow a set of vectors up to a target length
+   while avoiding duplicates and invalid outputs.
+ """
+
+ from random import choice
+ from typing import Callable, Optional, TypeVar
+
+ T = TypeVar("T")
+
+
+ def reverse01(bits: list[int]) -> list[int]:
+     """Return a new list where each bit is flipped."""
+     return [1 - bit for bit in bits]
+
+
+ def random_elem(items: list[T]) -> Optional[T]:
+     """Return a random element from the list, or None if the list is empty."""
+     return choice(items) if items else None
+
+
+ def assembler_v1(vectors: list[list[T]]) -> Optional[list[T]]:
+     """
+     Concatenate a randomly chosen vector with itself.
+
+     Returns None if the input list is empty.
+     """
+     if not vectors:
+         return None
+     v1 = random_elem(vectors)
+     return None if v1 is None else v1 + v1
+
+
+ def assembler_v1_v2(vectors: list[list[T]]) -> Optional[list[T]]:
+     """
+     Concatenate two randomly chosen vectors from the input.
+
+     Returns None if the input list is empty.
+     """
+     if not vectors:
+         return None
+     v1 = random_elem(vectors)
+     v2 = random_elem(vectors)
+     if v1 is None or v2 is None:
+         return None
+     return v1 + v2
+
+
+ def assembler_v1_sym(
+     vectors: list[list[int]],
+ ) -> Optional[list[int]]:
+     """
+     For integer vectors: pick v1 at random, then concatenate it with either
+     itself or its bit-flipped copy.
+
+     Returns None if the input list is empty.
+     """
+     if not vectors:
+         return None
+     v1 = random_elem(vectors)
+     if v1 is None:
+         return None
+     v2 = random_elem([v1, reverse01(v1)])
+     return None if v2 is None else v1 + v2
+
+
+ def ensemble_builder(
+     assembler: Callable[[list[list[T]]], Optional[list[T]]],
+     n: int,
+     max_len: int,
+     vectors: list[list[T]],
+ ) -> list[list[T]]:
+     """
+     Grow the starting ensemble of vectors by repeatedly assembling new
+     vectors.
+
+     Args:
+         assembler:
+             Function that assembles a new vector from the current ensemble.
+             Takes a list of vectors and returns a new vector (or None).
+         n:
+             Number of new vectors of length max_len the ensemble should
+             gain before returning.
+         max_len:
+             Maximum length of vectors in the ensemble; assembled vectors
+             longer than this are rejected and retried.
+         vectors:
+             Starting ensemble of vectors.
+
+     Returns:
+         The grown ensemble: the starting vectors plus all newly assembled
+         vectors, including n new vectors of length max_len.
+     """
+     if n == 0 or not vectors:
+         return vectors
+
+     new_vector = assembler(vectors)
+     # Retry on invalid outputs: None, duplicates, or over-long vectors
+     if (
+         new_vector is None
+         or new_vector in vectors
+         or len(new_vector) > max_len
+     ):
+         return ensemble_builder(assembler, n, max_len, vectors)
+
+     new_vectors = vectors + [new_vector]
+     # Vectors shorter than max_len are kept as building material but do
+     # not count towards n
+     if len(new_vector) < max_len:
+         return ensemble_builder(assembler, n, max_len, new_vectors)
+
+     return ensemble_builder(assembler, n - 1, max_len, new_vectors)
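
A minimal usage sketch: growing an ensemble from a single 2-bit seed with the symmetric assembler until two new max_len vectors exist. Duplicate and over-long assemblies are simply retried, so the call may recurse several times before returning:

    result = ensemble_builder(assembler_v1_sym, n=2, max_len=4, vectors=[[0, 1]])
    print(result)
    # One possible output: [[0, 1], [0, 1, 0, 1], [0, 1, 1, 0]]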