ai-nk-cce 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
- ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
- ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
- api/__init__.py +0 -0
- api/mpcdf_vllm.py +94 -0
- evals/nk_model.py +277 -0
- model/README.md +64 -0
- model/config/dataset_conv_v1.yml +9 -0
- model/config/dataset_conv_v2_m2.yml +9 -0
- model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
- model/config/dataset_debug.yml +9 -0
- model/config/dataset_v4_int_format.yml +9 -0
- model/config/dataset_v5.yml +9 -0
- model/config/inference.yml +7 -0
- model/config/train.yml +24 -0
- model/config/train_debug.yml +19 -0
- model/config/train_from_checkpoint.yml +24 -0
- model/config/train_from_checkpoint_debug.yml +19 -0
- model/config/train_grpo.yml +30 -0
- model/config/train_grpo_debug.yml +30 -0
- model/config/train_grpo_debug_vllm.yml +32 -0
- model/config.py +54 -0
- model/dataset.py +324 -0
- model/inference.py +51 -0
- model/nk_assistant.py +207 -0
- model/parser.py +70 -0
- model/run_slurm.py +335 -0
- model/score.ipynb +596 -0
- model/scripts/template.slurm +54 -0
- model/scripts/template_rl.slurm +54 -0
- model/train.py +293 -0
- nk_model/__init__.py +0 -0
- nk_model/assembler.py +112 -0
- nk_model/biased_prediction_agent.py +389 -0
- nk_model/dataset.py +434 -0
- nk_model/enums.py +21 -0
- nk_model/landscape_cache.py +149 -0
- nk_model/models.py +172 -0
- nk_model/nk_landscape.py +498 -0
- simulation/hill_climber_simulation.py +211 -0
- simulation/hill_climber_vs_ai_simulation.py +132 -0
- simulation/landscape_selection.py +179 -0
- utils/__init__.py +0 -0
- utils/binary_conversion.py +128 -0
- utils/logging.py +33 -0
- utils/utils.py +51 -0
@@ -0,0 +1,54 @@
+#!/bin/bash -l
+#SBATCH --output {output_dir}/slurm-%x-%j.out
+#SBATCH --error {output_dir}/slurm-%x-%j.out
+#SBATCH --chdir ./
+#SBATCH --job-name {job_name}/{job_id}
+#
+#SBATCH --nodes={n_nodes}
+#SBATCH --tasks-per-node=1
+#SBATCH --cpus-per-task={n_cpu}
+#SBATCH --mem={memory}
+#
+#SBATCH --constraint="gpu"
+#SBATCH --gres=gpu:a100:{n_gpu}
+#SBATCH --partition=gpu
+
+# Wall clock limit (max is 24 hours):
+#SBATCH --time={time}
+
+module purge
+module load apptainer
+
+source .env
+
+# create huggingface cache directory if it doesn't exist
+mkdir -p ~/.cache/huggingface
+
+echo "Running training using the image: {image}"
+echo "Running training using the config: {config_file}"
+
+srun apptainer exec \
+    --nv \
+    --contain \
+    --cleanenv \
+    --pwd /root/llm-strategic-tuning \
+    --bind .:/root/llm-strategic-tuning \
+    --bind ~/.cache/huggingface:/root/.cache/huggingface \
+    --bind /ptmp:/ptmp \
+    --env HUGGING_FACE_HUB_TOKEN="$HUGGINGFACE_TOKEN" \
+    --env WANDB_API_KEY="$WANDB_API_KEY" \
+    --env WANDB_ENTITY="chm-ml" \
+    --env WANDB_PROJECT="{project_name}" \
+    --env WANDB_RUN_GROUP="{group_name}" \
+    --env WANDB_NAME="{job_name}/{job_id}" \
+    --env NCCL_DEBUG="INFO" \
+    --env NCCL_BLOCKING_WAIT="0" \
+    --env HF_HOME="/root/.cache/huggingface" \
+    {image} \
+    bash -c "python -m torch.distributed.run \
+        --nnodes=\"$SLURM_NNODES\" \
+        --nproc-per-node=gpu \
+        --rdzv-id=\"$SLURM_JOBID\" \
+        --rdzv-endpoint=\$(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | head -n 1) \
+        --rdzv-backend=\"c10d\" \
+        {script} --config {config_file} --rl"
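The `--rl` flag in the final line suggests this hunk is the RL variant of the template. Its `{placeholder}` fields are presumably filled in by model/run_slurm.py before submission; a minimal editorial sketch assuming plain str.format substitution (all values are hypothetical, not taken from the package):

from pathlib import Path

template = Path("model/scripts/template_rl.slurm").read_text()
rendered = template.format(
    output_dir="runs/grpo", job_name="nk-grpo", job_id="001",
    n_nodes=1, n_cpu=18, memory="120G", n_gpu=4, time="24:00:00",
    image="train.sif", script="src/model/train.py",
    config_file="model/config/train_grpo.yml",
    project_name="nk-cce", group_name="grpo",
)
Path("job.slurm").write_text(rendered)  # then: sbatch job.slurm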
model/train.py
ADDED
@@ -0,0 +1,293 @@
+import argparse
+import logging
+from pathlib import Path
+from typing import List
+
+import numpy as np
+from datasets import load_from_disk
+from torch.nn.utils.rnn import pad_sequence
+from transformers import Trainer, TrainingArguments
+from trl import GRPOConfig, GRPOTrainer
+
+from src.model.config import TrainConfig
+from src.model.nk_assistant import NKAssistant
+from src.utils.utils import load_config_from_yaml
+
+logger = logging.getLogger(__name__)
+
+
+def tokenize_function(row, tokenizer):
+    context = row["context"]
+    target = row["target"]
+    full_input = context + target
+
+    # Tokenize the context alone to find how many leading tokens to mask
+    context_tokens = tokenizer(context, add_special_tokens=False)["input_ids"]
+
+    # Tokenize the combined text (padding is deferred to the collator)
+    tokenized = tokenizer(full_input, padding=False, return_tensors="np")
+
+    # Create labels and mask the context part
+    labels = tokenized["input_ids"].copy()
+    context_length = len(context_tokens)
+
+    # Mask context tokens by setting them to -100
+    labels[:, :context_length] = -100
+
+    tokenized["labels"] = labels
+
+    return {
+        "input_ids": tokenized["input_ids"].squeeze(0),
+        "attention_mask": tokenized["attention_mask"].squeeze(0),
+        "labels": labels.squeeze(0),
+    }
+
+
+def custom_data_collator(features, tokenizer):
+    input_ids = [f["input_ids"].clone().detach() for f in features]
+    attention_mask = [f["attention_mask"].clone().detach() for f in features]
+    labels = [f["labels"].clone().detach() for f in features]
+
+    batch = {
+        "input_ids": pad_sequence(
+            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
+        ),
+        "attention_mask": pad_sequence(
+            attention_mask, batch_first=True, padding_value=0
+        ),
+        "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
+    }
+    return batch
+
+
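This is the standard completion-only fine-tuning pattern: positions labelled -100 are skipped by the cross-entropy loss (transformers' default ignore_index), so the model is trained to predict only the target tokens. A minimal editorial illustration with made-up token ids, not part of the package:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two examples: 3 context tokens + 2 target tokens, and 2 + 1.
labels = [torch.tensor([-100, -100, -100, 7, 8]), torch.tensor([-100, -100, 5])]
padded = pad_sequence(labels, batch_first=True, padding_value=-100)
print(padded)
# tensor([[-100, -100, -100,    7,    8],
#         [-100, -100,    5, -100, -100]])
# Loss is computed only at the positions holding 7, 8 and 5.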
+def target_rank_reward_function(
+    completions: List[str], ranks: List[List[float]], **kwargs
+) -> List[float]:
+    """
+    Reward function that evaluates the rank of a target completion based on
+    the provided ranks.
+
+    Args:
+        completions (List[str]): List of generated target completions in
+            binary string format
+        ranks (List[List[float]]): Per-completion list of rank values for
+            each possible target
+        **kwargs: Additional arguments passed from the trainer
+
+    Returns:
+        List[float]: List of reward values for each completion, where each
+            reward is:
+            - The rank value from ranks if the completion is valid
+            - Different negative values based on the type of format violation
+    """
+
+    rewards = []
+    for completion, rank_scores in zip(completions, ranks):
+        try:
+            int_target = int(completion.replace(",", ""), 2)
+            reward = rank_scores[int_target]
+            rewards.append(reward)
+        except Exception as e:
+            # First, check if the completion contains only valid characters
+            # (0, 1, comma)
+            valid_chars = set(["0", "1", ","])
+            if not all(c in valid_chars for c in completion):
+                logger.error(
+                    f"Error evaluating completion: Contains invalid "
+                    f"characters: {completion[:50]}..."
+                )
+                rewards.append(-2.00)  # Severe penalty for invalid chars
+                continue
+
+            # Check if it's a valid binary string with commas
+            parts = completion.split(",")
+            # The length should be consistent with a valid binary string
+            # (e.g., 8 bits for n=8). We can infer the expected length from
+            # the size of the rank_scores array
+            expected_bits = (
+                int(np.log2(len(rank_scores))) if len(rank_scores) > 0 else 8
+            )
+            if len(parts) != expected_bits:
+                logger.error(
+                    f"Error evaluating completion: Wrong length: "
+                    f"{len(parts)} vs expected {expected_bits}"
+                )
+                rewards.append(-1.50)  # Penalty for wrong length
+                continue
+
+            logger.error(f"Error evaluating completion: {e}")
+            rewards.append(-1.00)  # Generic penalty for other errors
+    return rewards
+
+
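To make the reward concrete, an editorial worked example with hypothetical values (N = 4 bits, so 2^4 = 16 ranks per sample):

ranks = [[i / 16 for i in range(16)]]                    # one sample
print(target_rank_reward_function(["1,0,1,0"], ranks))   # [0.625]  -> ranks[0][int("1010", 2)] = ranks[0][10]
print(target_rank_reward_function(["x,y,z,w"], ranks))   # [-2.0]   invalid characters
print(target_rank_reward_function(["1,1,1,1,1"], ranks)) # [-1.5]   5 parts vs the expected 4 (index 31 overflows)

Note that the length check only runs when decoding or indexing fails, so a too-short completion such as "1,0,1" decodes to index 5 and is still rewarded with ranks[0][5].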
+def train(config: TrainConfig):
+    print("Starting supervised learning training...")
+    print(f"Training with config: {config}")
+    print(f"Dataset file: {config.dataset_file}")
+    print(f"Trainer args: {config.trainer_args}")
+    print(f"Assistant config: {config.assistant_config}")
+    print(f"Final model path: {config.final_model_path}")
+
+    ds = load_from_disk(config.dataset_file)
+
+    if config.trainer_args["num_train_epochs"] < 1:
+        for split in ds.keys():
+            print(
+                f"Split of fraction "
+                f"{config.trainer_args['num_train_epochs']} from dataset "
+                f"partition {split}"
+            )
+            fraction = int(
+                len(ds[split]) * config.trainer_args["num_train_epochs"]
+            )
+            ds[split] = ds[split].take(fraction)
+            print(
+                f"Fractioned dataset partition {split} has "
+                f"{len(ds[split])} samples"
+            )
+
+    ass = NKAssistant(config.assistant_config, metadata=config)
+
+    # Create text documents to train on
+    print("Create and add prompts to dataset")
+    ds = ds.map(ass.create_text_from_row)
+
+    # Tokenize text
+    print("Tokenize prompt")
+
+    def tokenize_function_with_tokenizer(row):
+        return tokenize_function(row, ass.tokenizer)
+
+    ds = ds.map(tokenize_function_with_tokenizer, batched=False)
+
+    # Turn into torch tensor
+    ds.set_format(
+        type="torch",
+        columns=["input_ids", "attention_mask", "labels"],
+    )
+
+    # Setup collator for padding of batches
+    def data_collator(features):
+        return custom_data_collator(features, ass.tokenizer)
+
+    # Initialize trainer with default values
+    trainer = Trainer(
+        model=ass.model,
+        args=TrainingArguments(**config.trainer_args),
+        train_dataset=ds["train"],
+        eval_dataset=ds["test"],
+        data_collator=data_collator,
+        tokenizer=ass.tokenizer,
+    )
+
+    # Train model
+    trainer.train()
+
+    ass._save_pretrained(Path(config.final_model_path))
+
+
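From the accesses above, TrainConfig evidently exposes at least dataset_file, trainer_args (unpacked into TrainingArguments or GRPOConfig), assistant_config, and final_model_path; the real definition lives in model/config.py. A hypothetical sketch of such a config object (constructor style and all values are assumptions, not taken from the package):

# Hypothetical values; see model/config.py for the actual fields and defaults.
config = TrainConfig(
    dataset_file="data/nk_dataset_v5",      # directory readable by load_from_disk
    trainer_args={                          # unpacked into TrainingArguments(**...)
        "output_dir": "runs/nk_sl_v5",
        "num_train_epochs": 1,
        "per_device_train_batch_size": 8,
    },
    assistant_config={"model_name": "some-base-model"},  # placeholder
    final_model_path="models/nk_sl_v5",
)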
+def train_reinforcement(config: TrainConfig):
+    print("Starting reinforcement learning training...")
+    print(f"Training with config: {config}")
+    print(f"Dataset file: {config.dataset_file}")
+    print(f"Trainer args: {config.trainer_args}")
+    print(f"Assistant config: {config.assistant_config}")
+    print(f"Final model path: {config.final_model_path}")
+
+    ds = load_from_disk(config.dataset_file)
+
+    for split in ds.keys():
+        logger.debug(
+            f"Configs fraction for {split} split: "
+            f"{config.trainer_args['num_train_epochs']}"
+        )
+        fraction = int(
+            len(ds[split]) * config.trainer_args["num_train_epochs"]
+        )
+        if split != "train":
+            # steps = samples * grpo_creation_steps / batch_size * devices
+            # samples = wished_steps * batch_size * devices /
+            #     grpo_creation_steps
+            fraction = int(min(600 * 24 * 4 / 8, fraction))
+
+        logger.debug(f"Fraction of {fraction} samples for {split} split")
+        ds[split] = ds[split].take(fraction)
+        logger.debug(
+            f"Fractioned dataset partition {split} has "
+            f"{len(ds[split])} samples"
+        )
+
+    ass = NKAssistant(config.assistant_config, metadata=config)
+
+    # Create prompts for reinforcement learning
+    print("Create and add prompts to dataset for RL")
+    ds = ds.map(ass.create_prompt_for_rl_from_row)
+
+    print("Successfully completed prompts creation mapping.")
+
+    # Turn into torch tensor while keeping the prompt field
+    ds.set_format(
+        type="torch",
+        columns=["prompt", "ranks"],
+    )
+
+    # Initialize GRPOConfig with default values
+    grpo_config = GRPOConfig(
+        **config.trainer_args  # Override with any custom trainer args
+    )
+
+    print("Initializing GRPOTrainer...")
+    # Initialize GRPOTrainer
+    trainer = GRPOTrainer(
+        model=ass.model,
+        args=grpo_config,
+        train_dataset=ds["train"],
+        eval_dataset=ds["test"],
+        reward_funcs=target_rank_reward_function,
+        processing_class=ass.tokenizer,
+    )
+
+    print("Starting training...")
+    # Train model
+    trainer.train()
+
+    print("Saving model...")
+    ass._save_pretrained(Path(config.final_model_path))
+
+
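trl's GRPOTrainer samples completions from each row's prompt and forwards the remaining dataset columns (here, ranks) to the reward function as keyword arguments, which is how target_rank_reward_function receives ranks directly. An editorial sketch of the expected row shape (hypothetical values, N = 4):

row = {
    "prompt": "<landscape description ...>",   # built by ass.create_prompt_for_rl_from_row
    "ranks": [i / 16 for i in range(16)],      # reward table indexed by the decoded target
}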
+if __name__ == "__main__":
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(
+        description="Train a model with the given config."
+    )
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to YAML config file"
+    )
+    parser.add_argument(
+        "--rl",
+        action="store_true",
+        help=(
+            "Use reinforcement learning training instead of supervised "
+            "learning"
+        ),
+    )
+    parser.add_argument(
+        "--rl_method",
+        type=str,
+        default="grpo",
+        choices=["grpo"],
+        help="Reinforcement learning method to use (default: grpo)",
+    )
+    args = parser.parse_args()
+
+    # Load config from YAML file
+    config = load_config_from_yaml(args.config, TrainConfig)
+
+    # Train the model using the appropriate method
+    if args.rl:
+        print("Training with reinforcement learning")
+        train_reinforcement(config)
+    else:
+        print("Training with supervised learning")
+        train(config)
+
+    print("Training completed successfully.")
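On the cluster, the SLURM template above launches this entry point under torch.distributed.run; outside it, the __main__ block can be driven directly. Hedged examples (module path inferred from the src. import prefix, config names taken from the file list above):

python -m src.model.train --config model/config/train.yml
python -m src.model.train --config model/config/train_grpo.yml --rl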
nk_model/__init__.py
ADDED
File without changes
nk_model/assembler.py
ADDED
@@ -0,0 +1,112 @@
+"""Utilities for assembling sequences (vectors) and simple bit operations.
+
+This module provides:
+- Bit helper (`reverse01`) for boolean bit values
+- Random selection helper (`random_elem`)
+- Simple assemblers (`assembler_v1`, `assembler_v1_v2`, `assembler_v1_sym`)
+- `ensemble_builder` to grow a set of vectors up to a target length
+  while avoiding duplicates and invalid outputs.
+"""
+
+from random import choice
+from typing import Callable, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+def reverse01(bits: list[int]) -> list[int]:
+    """Return a new list where each bit is flipped."""
+    return [1 - bit for bit in bits]
+
+
+def random_elem(items: list[T]) -> Optional[T]:
+    """Return a random element from the list, or None if the list is empty."""
+    return choice(items) if items else None
+
+
+def assembler_v1(vectors: list[list[T]]) -> Optional[list[T]]:
+    """
+    Concatenate a randomly chosen vector with itself.
+
+    Returns None if the input list is empty.
+    """
+    if not vectors:
+        return None
+    v1 = random_elem(vectors)
+    return None if v1 is None else v1 + v1
+
+
+def assembler_v1_v2(vectors: list[list[T]]) -> Optional[list[T]]:
+    """
+    Concatenate two randomly chosen vectors from the input.
+
+    Returns None if the input list is empty.
+    """
+    if not vectors:
+        return None
+    v1 = random_elem(vectors)
+    v2 = random_elem(vectors)
+    if v1 is None or v2 is None:
+        return None
+    return v1 + v2
+
+
+def assembler_v1_sym(
+    vectors: list[list[int]],
+) -> Optional[list[int]]:
+    """
+    For integer vectors: pick v1 at random, then concatenate with v1
+    or its flip.
+
+    Returns None if the input list is empty.
+    """
+    if not vectors:
+        return None
+    v1 = random_elem(vectors)
+    if v1 is None:
+        return None
+    v2 = random_elem([v1, reverse01(v1)])
+    return None if v2 is None else v1 + v2
+
+
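The assemblers only propose candidates; duplicate and over-long results are filtered by ensemble_builder below. An editorial example of the possible outputs from a single seed:

# With vectors = [[1, 0]]:
# assembler_v1      -> [1, 0, 1, 0]                  (v1 + v1)
# assembler_v1_v2   -> [1, 0, 1, 0]                  (only one choice available)
# assembler_v1_sym  -> [1, 0, 1, 0] or [1, 0, 0, 1]  (v1 + v1, or v1 + reverse01(v1))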
+def ensemble_builder(
+    assembler: Callable[[list[list[T]]], Optional[list[T]]],
+    n: int,
+    max_len: int,
+    vectors: list[list[T]],
+) -> list[list[T]]:
+    """
+    Grow the starting ensemble of vectors by repeatedly assembling new vectors.
+
+    Args:
+        assembler:
+            Function to assemble new vectors from the current ensemble. Should
+            take a list of vectors and return a new vector.
+        n:
+            Number of vectors of length max_len to add to the ensemble.
+        max_len:
+            Maximum length of the vectors in the ensemble.
+        vectors:
+            Starting ensemble of vectors.
+            The ensemble will start with these vectors.
+
+    Returns:
+        List of all vectors in the grown ensemble, including intermediate
+        vectors shorter than max_len.
+    """
+    if n == 0 or not vectors:
+        return vectors
+
+    new_vector = assembler(vectors)
+    if (
+        new_vector is None
+        or new_vector in vectors
+        or len(new_vector) > max_len
+    ):
+        # Rejected candidate: retry without consuming the budget
+        return ensemble_builder(assembler, n, max_len, vectors)
+
+    new_vectors = vectors + [new_vector]
+    if len(new_vector) < max_len:
+        # Intermediate-length vector: keep it and continue growing
+        return ensemble_builder(assembler, n, max_len, new_vectors)
+
+    # Reached max_len: one more full-length vector collected
+    return ensemble_builder(assembler, n - 1, max_len, new_vectors)