ai-nk-cce 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
- ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
- ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
- api/__init__.py +0 -0
- api/mpcdf_vllm.py +94 -0
- evals/nk_model.py +277 -0
- model/README.md +64 -0
- model/config/dataset_conv_v1.yml +9 -0
- model/config/dataset_conv_v2_m2.yml +9 -0
- model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
- model/config/dataset_debug.yml +9 -0
- model/config/dataset_v4_int_format.yml +9 -0
- model/config/dataset_v5.yml +9 -0
- model/config/inference.yml +7 -0
- model/config/train.yml +24 -0
- model/config/train_debug.yml +19 -0
- model/config/train_from_checkpoint.yml +24 -0
- model/config/train_from_checkpoint_debug.yml +19 -0
- model/config/train_grpo.yml +30 -0
- model/config/train_grpo_debug.yml +30 -0
- model/config/train_grpo_debug_vllm.yml +32 -0
- model/config.py +54 -0
- model/dataset.py +324 -0
- model/inference.py +51 -0
- model/nk_assistant.py +207 -0
- model/parser.py +70 -0
- model/run_slurm.py +335 -0
- model/score.ipynb +596 -0
- model/scripts/template.slurm +54 -0
- model/scripts/template_rl.slurm +54 -0
- model/train.py +293 -0
- nk_model/__init__.py +0 -0
- nk_model/assembler.py +112 -0
- nk_model/biased_prediction_agent.py +389 -0
- nk_model/dataset.py +434 -0
- nk_model/enums.py +21 -0
- nk_model/landscape_cache.py +149 -0
- nk_model/models.py +172 -0
- nk_model/nk_landscape.py +498 -0
- simulation/hill_climber_simulation.py +211 -0
- simulation/hill_climber_vs_ai_simulation.py +132 -0
- simulation/landscape_selection.py +179 -0
- utils/__init__.py +0 -0
- utils/binary_conversion.py +128 -0
- utils/logging.py +33 -0
- utils/utils.py +51 -0
model/config/train.yml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
eval_strategy: "steps"
|
|
6
|
+
per_device_train_batch_size: 20
|
|
7
|
+
per_device_eval_batch_size: 20
|
|
8
|
+
save_steps: 0.1
|
|
9
|
+
save_total_limit: 3
|
|
10
|
+
logging_dir: "./logs"
|
|
11
|
+
logging_steps: 10
|
|
12
|
+
eval_steps: 100
|
|
13
|
+
do_eval: true
|
|
14
|
+
num_train_epochs: 1
|
|
15
|
+
learning_rate: "<<float: lr>>"
|
|
16
|
+
report_to: "wandb"
|
|
17
|
+
run_name: <<job_name>>/<<job_id>>
|
|
18
|
+
lr_scheduler_type: constant
|
|
19
|
+
assistant_config:
|
|
20
|
+
parser_config:
|
|
21
|
+
include_payoff: true
|
|
22
|
+
model_path: gpt2
|
|
23
|
+
generation_params:
|
|
24
|
+
temperature: 0.7
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
evaluation_strategy: epoch
|
|
6
|
+
per_device_train_batch_size: 1
|
|
7
|
+
per_device_eval_batch_size: 1
|
|
8
|
+
save_steps: 10000
|
|
9
|
+
save_total_limit: 2
|
|
10
|
+
logging_dir: ./logs
|
|
11
|
+
logging_steps: 500
|
|
12
|
+
num_train_epochs: 1
|
|
13
|
+
report_to: none
|
|
14
|
+
assistant_config:
|
|
15
|
+
parser_config:
|
|
16
|
+
include_payoff: true
|
|
17
|
+
model_path: gpt2
|
|
18
|
+
generation_params:
|
|
19
|
+
temperature: 0.7
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
eval_strategy: "steps"
|
|
6
|
+
per_device_train_batch_size: 20
|
|
7
|
+
per_device_eval_batch_size: 20
|
|
8
|
+
save_steps: 0.1
|
|
9
|
+
save_total_limit: 3
|
|
10
|
+
logging_dir: "./logs"
|
|
11
|
+
logging_steps: 10
|
|
12
|
+
eval_steps: 100
|
|
13
|
+
do_eval: true
|
|
14
|
+
num_train_epochs: 0.2
|
|
15
|
+
learning_rate: "<<float: lr>>"
|
|
16
|
+
report_to: "wandb"
|
|
17
|
+
run_name: <<job_name>>/<<job_id>>
|
|
18
|
+
lr_scheduler_type: cosine
|
|
19
|
+
assistant_config:
|
|
20
|
+
parser_config:
|
|
21
|
+
include_payoff: true
|
|
22
|
+
model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
|
|
23
|
+
generation_params:
|
|
24
|
+
temperature: 0.7
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
evaluation_strategy: epoch
|
|
6
|
+
per_device_train_batch_size: 1
|
|
7
|
+
per_device_eval_batch_size: 1
|
|
8
|
+
save_steps: 10000
|
|
9
|
+
save_total_limit: 2
|
|
10
|
+
logging_dir: ./logs
|
|
11
|
+
logging_steps: 500
|
|
12
|
+
num_train_epochs: 1
|
|
13
|
+
report_to: none
|
|
14
|
+
assistant_config:
|
|
15
|
+
parser_config:
|
|
16
|
+
include_payoff: true
|
|
17
|
+
model_path: models/gpt2_v5/1e-5/2025_05_20__10_27_38
|
|
18
|
+
generation_params:
|
|
19
|
+
temperature: 0.7
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
eval_strategy: "steps"
|
|
6
|
+
per_device_train_batch_size: 24 # must be divisible by generations per prompt
|
|
7
|
+
per_device_eval_batch_size: 24
|
|
8
|
+
save_steps: 0.1
|
|
9
|
+
save_total_limit: 3
|
|
10
|
+
logging_dir: "./logs"
|
|
11
|
+
logging_steps: 0.01
|
|
12
|
+
eval_steps: 0.1
|
|
13
|
+
do_eval: true
|
|
14
|
+
num_train_epochs: 1.0
|
|
15
|
+
learning_rate: "<<float: lr>>"
|
|
16
|
+
report_to: "wandb"
|
|
17
|
+
run_name: <<job_name>>/<<job_id>>
|
|
18
|
+
lr_scheduler_type: constant
|
|
19
|
+
# GRPO specific args
|
|
20
|
+
num_generations: 8
|
|
21
|
+
beta: 0.04
|
|
22
|
+
epsilon: 0.2
|
|
23
|
+
max_prompt_length: 512
|
|
24
|
+
max_completion_length: 15
|
|
25
|
+
assistant_config:
|
|
26
|
+
parser_config:
|
|
27
|
+
include_payoff: true
|
|
28
|
+
model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
|
|
29
|
+
generation_params:
|
|
30
|
+
temperature: 0.7
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
eval_strategy: "steps"
|
|
6
|
+
per_device_train_batch_size: 8 # must be divisible by generations per prompt
|
|
7
|
+
per_device_eval_batch_size: 8
|
|
8
|
+
save_steps: 0.5 # Values < 1 are a ratio of total steps: save every 50% of the debug run to avoid frequent saves
|
|
9
|
+
save_total_limit: 1 # Keep only one checkpoint
|
|
10
|
+
logging_dir: "./logs"
|
|
11
|
+
logging_steps: 0.01
|
|
12
|
+
eval_steps: 0.1
|
|
13
|
+
do_eval: true
|
|
14
|
+
num_train_epochs: 0.005
|
|
15
|
+
learning_rate: "<<float: lr>>"
|
|
16
|
+
report_to: "wandb"
|
|
17
|
+
run_name: <<job_name>>/<<job_id>>
|
|
18
|
+
lr_scheduler_type: constant
|
|
19
|
+
# GRPO specific args
|
|
20
|
+
num_generations: 8
|
|
21
|
+
beta: 0.04
|
|
22
|
+
epsilon: 0.2
|
|
23
|
+
max_prompt_length: 512
|
|
24
|
+
max_completion_length: 15
|
|
25
|
+
assistant_config:
|
|
26
|
+
parser_config:
|
|
27
|
+
include_payoff: true
|
|
28
|
+
model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
|
|
29
|
+
generation_params:
|
|
30
|
+
temperature: 0.7
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
dataset_file: data/samples/<<ds>>
|
|
2
|
+
final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
|
|
3
|
+
trainer_args:
|
|
4
|
+
output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
|
|
5
|
+
eval_strategy: "steps"
|
|
6
|
+
per_device_train_batch_size: 8 # must be divisible by generations per prompt
|
|
7
|
+
per_device_eval_batch_size: 8
|
|
8
|
+
save_steps: 1 # NOTE: values >= 1 are absolute step counts, so this saves after every step; use a ratio < 1 (e.g. 0.5) for fewer saves
|
|
9
|
+
save_total_limit: 1 # Keep only one checkpoint
|
|
10
|
+
logging_dir: "./logs"
|
|
11
|
+
logging_steps: 0.1
|
|
12
|
+
eval_steps: 0.1
|
|
13
|
+
do_eval: true
|
|
14
|
+
num_train_epochs: 0.001
|
|
15
|
+
learning_rate: "<<float: lr>>"
|
|
16
|
+
report_to: "wandb"
|
|
17
|
+
run_name: <<job_name>>/<<job_id>>
|
|
18
|
+
lr_scheduler_type: constant
|
|
19
|
+
# GRPO specific args
|
|
20
|
+
beta: 0.04
|
|
21
|
+
epsilon: 0.2
|
|
22
|
+
max_prompt_length: 512
|
|
23
|
+
max_completion_length: 15
|
|
24
|
+
use_vllm: true
|
|
25
|
+
vllm_mode: "colocate"
|
|
26
|
+
vllm_gpu_memory_utilization: 0.10
|
|
27
|
+
assistant_config:
|
|
28
|
+
parser_config:
|
|
29
|
+
include_payoff: true
|
|
30
|
+
model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
|
|
31
|
+
generation_params:
|
|
32
|
+
temperature: 0.7
|
model/config.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ParserConfig(BaseModel):
    """Settings for the prompt parser used by the NK assistant."""

    # Presumably toggles whether payoff values appear in the rendered prompt
    # (dataset.py forwards an ``include_payoff`` flag to create_context).
    include_payoff: bool = Field(default=True)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NKAssistantConfig(BaseModel):
    """Settings needed to construct the NK assistant."""

    # Nested parser settings (the YAML ``assistant_config.parser_config`` key).
    parser_config: ParserConfig
    # NOTE(review): presumably switches generation to the MPCDF vLLM API
    # client (api/mpcdf_vllm.py) — confirm in nk_assistant.py.
    use_mpcdf_vllm: bool = Field(default=False)
    # Checkpoint to load; the YAML configs use "gpt2" or a local models/... dir.
    model_path: Optional[str] = Field(default=None)
    # Generation settings; the YAML configs set e.g. ``temperature: 0.7``.
    generation_params: Optional[Dict[str, Any]] = Field(default=None)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TrainConfig(BaseModel):
    """Schema mirroring the top-level keys of the train*.yml configs."""

    # Free-form ``trainer_args`` mapping from the YAML; not validated here.
    # NOTE(review): presumably handed to the trainer's argument class in
    # train.py — confirm there.
    trainer_args: Any
    dataset_file: str = Field(
        default=..., description="The path to the dataset file"
    )
    final_model_path: str = Field(
        default=..., description="The path to save the final model"
    )
    assistant_config: NKAssistantConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class InferenceConfig(BaseModel):
    """Schema of the YAML config loaded by ``model/inference.py``."""

    dataset_file: str = Field(
        default=..., description="The path to the dataset file"
    )
    model_path: str = Field(
        default=..., description="The path to the fine-tuned model"
    )
    output_dataset_file: str = Field(
        default=..., description="The path to save the dataset with suggestions"
    )
    # run_inference only applies these when the mapping is non-empty.
    generation_params: Optional[Dict[str, Any]] = Field(
        description="Generation parameters", default=None
    )
    # Splits that are absent from the loaded dataset are skipped.
    splits: List[str] = Field(
        description="Dataset splits to process", default=["train", "test"]
    )
    # NOTE(review): despite the name, run_inference caps every processed
    # split with this value, not only "test".
    max_test_samples: Optional[int] = Field(
        description="Maximum number of test samples to process", default=None
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class DataSetConfig(BaseModel):
    """Settings for dataset generation in ``model/dataset.py``."""

    # Parquet file with one row per landscape point (read via pd.read_parquet).
    input_file: str
    # Destination directory for DatasetDict.save_to_disk.
    output_file: str
    # Number of sampling rounds; each round draws one example per landscape.
    samples_per_landscapes: int = 1
    # (low, high) bounds of the Hamming search radius handed to create_sample,
    # which draws the radius with np.random.randint (upper bound exclusive)
    # and asserts low > 0 — so the default must start at 1, not 0.
    constraints_range: Tuple[int, int] = (1, 8)
    # Candidate numbers of observed points per example; one is picked at
    # random for every example. Each must not exceed the landscape size.
    n_samples: List[int] = [8]
    # Forwarded to create_context(include_payoff=...) when rendering prompts.
    include_payoff: bool = True
    # Fractions of landscapes assigned to the "test" and "exp" splits; the
    # remaining landscapes form the "train" split.
    test_ratio: float = 0.1
    exp_ratio: float = 0.1
    # If set, keep at most this many landscapes in each split (debug runs).
    debug_size: Optional[int] = None
|
model/dataset.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import random
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from datasets import Dataset, DatasetDict
|
|
8
|
+
|
|
9
|
+
from src.model.config import DataSetConfig
|
|
10
|
+
from src.model.parser import create_context, create_target
|
|
11
|
+
from src.utils.utils import BIN_ARRAY, load_config_from_yaml
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def binary_vectors_within_radius(
    origin: int,
    radius: int,
    n_dim: int = 8,
) -> np.ndarray:
    """
    Find every binary vector lying within a Hamming ball around ``origin``.

    Vectors are identified by their row index in ``BIN_ARRAY``; only the
    first ``2**n_dim`` rows are considered.

    Args:
        origin (int): Row index of the centre vector.
        radius (int): Largest Hamming distance (inclusive) to keep.
        n_dim (int, optional): Bit width of the vectors; must not exceed 8.
            Defaults to 8.

    Returns:
        np.ndarray: Ascending row indices whose Hamming distance to the
        origin row is at most ``radius`` (the origin itself is included
        whenever ``radius >= 0``).
    """
    assert n_dim <= 8, "n_dim must be less than or equal to 8"
    candidates = BIN_ARRAY[: 2**n_dim]

    # Count, per candidate row, the coordinates that differ from the origin.
    hamming = (candidates != candidates[origin]).sum(axis=1)

    (within,) = np.nonzero(hamming <= radius)
    return within
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_rank_score(
    payoff: np.ndarray,
    constraint_idx: np.ndarray,
) -> np.ndarray:
    """
    Score the selected positions of ``payoff`` by the rank of their payoff.

    Positions selected by ``constraint_idx`` receive evenly spaced scores in
    [0, 1] according to the rank of their payoff among the selected payoffs
    (0 for the lowest, 1 for the highest). Every other position gets -1.
    Ties are broken by their order in ``constraint_idx``.

    Args:
        payoff (np.ndarray): Array of payoff values (assumed 1-D).
        constraint_idx (np.ndarray): Positions to score, given either as an
            integer index array or as a boolean mask over ``payoff``.

    Returns:
        np.ndarray: Float array shaped like ``payoff``. A single selected
        position scores 1.0; with no selection every entry is -1.
    """
    # Initialize all scores to -1
    rank_score = np.full_like(payoff, fill_value=-1, dtype=float)

    # Payoffs of the selected positions, in constraint_idx order.
    constrained_payoffs = payoff[constraint_idx]
    n_constrained = len(constrained_payoffs)

    if n_constrained > 1:
        ranks = np.linspace(0, 1, n_constrained)
    else:
        ranks = np.array([1.0])  # If only one point, give it rank 1

    # argsort yields the positions of the smallest..largest payoffs; invert
    # that permutation so each position is mapped to its *own* rank. Indexing
    # ``ranks`` with the argsort directly would hand element i the rank of
    # wherever the i-th smallest payoff sits, which scrambles the scores.
    order = np.argsort(constrained_payoffs, kind="stable")
    own_rank = np.empty_like(order)
    own_rank[order] = np.arange(n_constrained)

    rank_score[constraint_idx] = ranks[own_rank]

    return rank_score
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def create_sample(
    payoff: np.ndarray, hamming_distance: Tuple[int, int], n_samples: List[int]
) -> Tuple[np.ndarray, int, int, int, np.ndarray]:
    """
    Draw one random training example from a single landscape.

    Args:
        payoff (np.ndarray): Payoff of every point, indexed by the point's
            integer id. The bit width is derived as int(log2(len(payoff))).
        hamming_distance (Tuple[int, int]): ``(low, high)`` bounds of the
            search radius. The radius is drawn uniformly from ``[low, high)``;
            when ``low == high`` it is exactly ``low``. ``low`` must be >= 1.
        n_samples (List[int]): Candidate numbers of observed points; one is
            chosen uniformly at random.

    Returns:
        Tuple of
            sample_idx: ids of the observed points, sorted by ascending payoff;
            target_idx: id of the highest-payoff point within the radius,
                excluding the origin;
            origin_idx: id of the randomly chosen origin;
            radius: the drawn Hamming radius;
            rank_score: per-point scores from ``get_rank_score`` over the
                points within the radius (origin excluded).
    """
    assert all(size <= len(payoff) for size in n_samples), (
        "n_samples must be less than or equal to the length of the "
        "payoff array"
    )
    low, high = hamming_distance[0], hamming_distance[1]
    assert low <= high, (
        "constraints_range must be a tuple of two integers where the first "
        "is less than or equal to the second"
    )
    # Radius 0 would leave only the origin, which is removed below, so there
    # would be no candidate target.
    assert low > 0, (
        "constraints_range must be a tuple of two integers where the first "
        "is greater than 0"
    )

    # Every point id [0, 1, ..., len(payoff) - 1]; an id is the integer
    # representation of the point's binary coordinates.
    all_idx = np.arange(len(payoff))

    # Random sample size and search radius. np.random.randint excludes its
    # upper bound and raises when low == high, which the check above allows.
    sample_size = np.random.choice(n_samples)
    radius = np.random.randint(low, high) if low < high else low

    # Observed points, ordered by ascending payoff.
    sample_idx = np.random.choice(a=all_idx, size=sample_size, replace=False)
    sample_idx = sample_idx[np.argsort(payoff[sample_idx])]

    # Random origin.
    origin_idx = np.random.choice(all_idx)

    # Points within the radius of the origin, excluding the origin itself.
    constraint_idx = binary_vectors_within_radius(
        origin=origin_idx,
        radius=radius,
        n_dim=int(np.log2(len(payoff))),
    )
    constraint_idx = constraint_idx[constraint_idx != origin_idx]

    # The target is the best point in the constrained neighbourhood.
    target_idx = constraint_idx[np.argmax(payoff[constraint_idx])]

    rank_score = get_rank_score(payoff, constraint_idx)

    return sample_idx, target_idx, origin_idx, radius, rank_score
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def create_raw_dataset(df: pd.DataFrame, config: DataSetConfig) -> Dataset:
    """
    Build the raw (not yet text-rendered) examples for a set of landscapes.

    Args:
        df (pd.DataFrame): One row per landscape point. Must contain the
            columns ``landscape_uuid``, ``n``, ``k``, ``power_scale`` and
            ``payoff``; the last ``n`` columns must be the binary coordinates
            (see ``sort_by_binary_coords``).
        config (DataSetConfig): Supplies ``samples_per_landscapes``,
            ``constraints_range`` and ``n_samples``.

    Returns:
        Dataset: One row per (sampling round, landscape), ordered by round
        and then by ``landscape_uuid``.
    """
    from tqdm import tqdm

    # Per-landscape data is identical in every sampling round, so group, sort
    # and extract it once instead of once per round.
    landscapes = []
    for landscape_id, group in df.groupby("landscape_uuid"):
        n = group["n"].values[0].item()
        k = group["k"].values[0].item()
        power_scale = group["power_scale"].values[0].item()

        # Order rows so that position i holds the point with integer id i.
        group = sort_by_binary_coords(group, n)
        # NOTE(review): the int32 cast truncates fractional payoffs — confirm
        # payoffs are integer-valued in the input file.
        payoffs = group["payoff"].values.astype(np.int32)

        landscapes.append((landscape_id, n, k, power_scale, payoffs))

    examples = {
        "landscape_id": [],
        "sample": [],
        "n": [],
        "k": [],
        "power_scale": [],
        "payoffs": [],
        "sample_idx": [],
        "target_idx": [],
        "origin_idx": [],
        "hamming_distance": [],
        "ranks": [],
    }

    total_iterations = len(landscapes) * config.samples_per_landscapes
    with tqdm(total=total_iterations, desc="Creating dataset") as pbar:
        for s in range(config.samples_per_landscapes):
            for landscape_id, n, k, power_scale, payoffs in landscapes:
                (
                    sample_idxs,
                    target_idx,
                    origin_idx,
                    hamming_distance,
                    ranks,
                ) = create_sample(
                    payoff=payoffs,
                    hamming_distance=config.constraints_range,
                    n_samples=config.n_samples,
                )

                examples["landscape_id"].append(landscape_id)
                examples["sample"].append(s)
                examples["n"].append(n)
                examples["k"].append(k)
                examples["power_scale"].append(power_scale)
                examples["payoffs"].append(payoffs)
                examples["sample_idx"].append(sample_idxs)
                examples["target_idx"].append(target_idx)
                examples["origin_idx"].append(origin_idx)
                examples["hamming_distance"].append(hamming_distance)
                examples["ranks"].append(ranks)

                pbar.update(1)

    return Dataset.from_dict(examples)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def create_context_from_row(row: dict, include_payoff: bool = True) -> dict:
    """
    Render the prompt context for one raw dataset row.

    Args:
        row (dict): A row produced by ``create_raw_dataset``; the fields
            ``payoffs``, ``sample_idx``, ``n``, ``k``, ``power_scale``,
            ``origin_idx`` and ``hamming_distance`` are read.
        include_payoff (bool): Forwarded unchanged to ``create_context``.

    Returns:
        dict: ``{"context": ...}`` holding whatever ``create_context`` returns.
    """
    observed = row["sample_idx"]
    # Payoffs of the observed points only, in the same order as ``observed``.
    observed_payoffs = np.array(row["payoffs"])[observed]

    rendered = create_context(
        n=row["n"],
        k=row["k"],
        power_scale=row["power_scale"],
        sample_idxs=observed,
        origin_idx=row["origin_idx"],
        hamming_distance=row["hamming_distance"],
        payoff=observed_payoffs,
        include_payoff=include_payoff,
    )
    return dict(context=rendered)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def create_target_from_row(row: dict) -> dict:
    """Return ``{"target": ...}`` built by ``create_target`` from ``row["target_idx"]``."""
    best = row["target_idx"]
    return dict(target=create_target(target_idx=best))
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def create_dataset_from_file(config: DataSetConfig) -> DatasetDict:
    """
    Load landscapes from parquet and build train/test/exp raw datasets.

    Landscapes (not individual points) are shuffled and partitioned, so each
    landscape appears in exactly one split.

    Args:
        config (DataSetConfig): Supplies ``input_file``, ``test_ratio``,
            ``exp_ratio`` and ``debug_size`` plus the sampling settings used
            by ``create_raw_dataset``.

    Returns:
        DatasetDict: Keys ``"train"``, ``"test"`` and ``"exp"``.
    """
    df = pd.read_parquet(config.input_file)
    uuids = df["landscape_uuid"].unique()
    random.shuffle(uuids)

    total = len(uuids)
    train_end = int((1 - config.test_ratio - config.exp_ratio) * total)
    test_end = int((1 - config.exp_ratio) * total)
    train_ids, test_ids, exp_ids = np.split(uuids, [train_end, test_end])

    # Sanity checks: the three splits cover every landscape exactly once.
    assert len(train_ids) + len(test_ids) + len(exp_ids) == total
    assert not set(train_ids) & set(test_ids)
    assert not set(train_ids) & set(exp_ids)
    assert not set(test_ids) & set(exp_ids)

    if config.debug_size is not None:
        limit = config.debug_size
        train_ids = train_ids[:limit]
        test_ids = test_ids[:limit]
        exp_ids = exp_ids[:limit]

    split_ids = {"train": train_ids, "test": test_ids, "exp": exp_ids}
    return DatasetDict(
        {
            name: create_raw_dataset(
                df[df["landscape_uuid"].isin(ids)], config
            )
            for name, ids in split_ids.items()
        }
    )
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def apply_context_target(
    ds: DatasetDict, config: DataSetConfig
) -> DatasetDict:
    """
    Add the rendered ``context`` and ``target`` columns to every split.

    Args:
        ds (DatasetDict): Raw datasets from ``create_dataset_from_file``.
        config (DataSetConfig): ``include_payoff`` is forwarded to
            ``create_context_from_row``.

    Returns:
        DatasetDict: ``ds`` with the two extra columns, mapped row by row.
    """

    def _add_context(example):
        return create_context_from_row(example, config.include_payoff)

    with_context = ds.map(_add_context, batched=False)
    return with_context.map(create_target_from_row, batched=False)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def sort_by_binary_coords(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """
    Sort rows by the integer value of their binary coordinates.

    The last ``n`` columns of ``df`` are read as the bits of a binary number,
    most significant bit first. A ``sort_key`` column already present from an
    earlier call is ignored when picking those columns. The input frame is not
    modified.

    Args:
        df (pd.DataFrame): Input dataframe whose last ``n`` columns hold 0/1
            coordinates (ints, bools or integral floats).
        n (int): Number of binary coordinate columns; must be positive.

    Returns:
        pd.DataFrame: A sorted copy of ``df`` with an integer ``sort_key``
        column.

    Raises:
        ValueError: If ``n`` is not positive or a coordinate is not 0/1.
    """
    if n <= 0:
        # df.columns[-0:] would silently select every column.
        raise ValueError(f"n must be positive, got {n}")

    coord_cols = [col for col in df.columns if col != "sort_key"][-n:]
    bits = df[coord_cols].to_numpy().astype(np.int64)
    if not np.isin(bits, (0, 1)).all():
        raise ValueError("binary coordinate columns must contain only 0 and 1")

    if n <= 62:
        # Vectorised MSB-first conversion: weights are [2**(n-1), ..., 2, 1].
        weights = 2 ** np.arange(n - 1, -1, -1, dtype=np.int64)
        sort_key = bits @ weights
    else:
        # Too wide for int64; fall back to arbitrary-precision Python ints.
        sort_key = [int("".join(map(str, row)), 2) for row in bits]

    # assign() returns a new frame, so the caller's (possibly groupby-slice)
    # dataframe is never mutated.
    return df.assign(sort_key=sort_key).sort_values("sort_key")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
if __name__ == "__main__":
    print("Starting dataset creation...")

    cli = argparse.ArgumentParser(description="Generate dataset from config.")
    cli.add_argument(
        "--config", type=str, required=True, help="Path to YAML config file"
    )
    config = load_config_from_yaml(cli.parse_args().config, DataSetConfig)
    print(config)

    # Sample the raw splits, then render the prompt/target columns and save.
    raw_dataset = create_dataset_from_file(config)
    print(raw_dataset)
    apply_context_target(raw_dataset, config).save_to_disk(config.output_file)

    print(f"Dataset saved to {config.output_file}")
|
model/inference.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
from datasets import load_from_disk
|
|
4
|
+
|
|
5
|
+
from src.model.config import InferenceConfig
|
|
6
|
+
from src.model.nk_assistant import NKAssistant
|
|
7
|
+
from src.utils.utils import load_config_from_yaml
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def run_inference(config: InferenceConfig):
    """
    Generate assistant suggestions for selected splits and save the result.

    Loads the dataset at ``config.dataset_file`` and the assistant at
    ``config.model_path``. It maps ``suggest_from_row`` over every split in
    ``config.splits`` that exists in the dataset and writes the whole
    dataset, with processed splits replaced, to
    ``config.output_dataset_file``.

    If ``config.max_test_samples`` is set, each processed split is first
    truncated to at most that many leading rows.

    Args:
        config (InferenceConfig): Paths, generation settings and split limits.
    """
    ds = load_from_disk(config.dataset_file)

    assistant = NKAssistant.from_pretrained(config.model_path)
    if config.generation_params:
        assistant.generation_params = config.generation_params

    for split in config.splits:
        if split not in ds:
            continue
        print(f"Processing {split} split...")

        split_ds = ds[split]
        if config.max_test_samples is not None:
            # Clamp so a limit larger than the split does not make select()
            # fail on out-of-range indices; 0 is honoured as "no rows".
            limit = min(config.max_test_samples, len(split_ds))
            print(f"Processing {limit} test samples...")
            split_ds = split_ds.select(range(limit))

        ds[split] = split_ds.map(assistant.suggest_from_row)

    # Save the dataset with suggestions
    ds.save_to_disk(config.output_dataset_file)
    print(f"Dataset with suggestions saved to {config.output_dataset_file}")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Run inference with a fine-tuned model."
    )
    cli.add_argument(
        "--config", type=str, required=True, help="Path to YAML config file"
    )
    config_path = cli.parse_args().config

    # Load the YAML config and generate suggestions.
    run_inference(load_config_from_yaml(config_path, InferenceConfig))

    print("Inference completed successfully.")
|