ai_nk_cce-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
  2. ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
  3. ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
  4. api/__init__.py +0 -0
  5. api/mpcdf_vllm.py +94 -0
  6. evals/nk_model.py +277 -0
  7. model/README.md +64 -0
  8. model/config/dataset_conv_v1.yml +9 -0
  9. model/config/dataset_conv_v2_m2.yml +9 -0
  10. model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
  11. model/config/dataset_debug.yml +9 -0
  12. model/config/dataset_v4_int_format.yml +9 -0
  13. model/config/dataset_v5.yml +9 -0
  14. model/config/inference.yml +7 -0
  15. model/config/train.yml +24 -0
  16. model/config/train_debug.yml +19 -0
  17. model/config/train_from_checkpoint.yml +24 -0
  18. model/config/train_from_checkpoint_debug.yml +19 -0
  19. model/config/train_grpo.yml +30 -0
  20. model/config/train_grpo_debug.yml +30 -0
  21. model/config/train_grpo_debug_vllm.yml +32 -0
  22. model/config.py +54 -0
  23. model/dataset.py +324 -0
  24. model/inference.py +51 -0
  25. model/nk_assistant.py +207 -0
  26. model/parser.py +70 -0
  27. model/run_slurm.py +335 -0
  28. model/score.ipynb +596 -0
  29. model/scripts/template.slurm +54 -0
  30. model/scripts/template_rl.slurm +54 -0
  31. model/train.py +293 -0
  32. nk_model/__init__.py +0 -0
  33. nk_model/assembler.py +112 -0
  34. nk_model/biased_prediction_agent.py +389 -0
  35. nk_model/dataset.py +434 -0
  36. nk_model/enums.py +21 -0
  37. nk_model/landscape_cache.py +149 -0
  38. nk_model/models.py +172 -0
  39. nk_model/nk_landscape.py +498 -0
  40. simulation/hill_climber_simulation.py +211 -0
  41. simulation/hill_climber_vs_ai_simulation.py +132 -0
  42. simulation/landscape_selection.py +179 -0
  43. utils/__init__.py +0 -0
  44. utils/binary_conversion.py +128 -0
  45. utils/logging.py +33 -0
  46. utils/utils.py +51 -0
model/config/train.yml ADDED
@@ -0,0 +1,24 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   eval_strategy: "steps"
+   per_device_train_batch_size: 20
+   per_device_eval_batch_size: 20
+   save_steps: 0.1
+   save_total_limit: 3
+   logging_dir: "./logs"
+   logging_steps: 10
+   eval_steps: 100
+   do_eval: true
+   num_train_epochs: 1
+   learning_rate: "<<float: lr>>"
+   report_to: "wandb"
+   run_name: <<job_name>>/<<job_id>>
+   lr_scheduler_type: constant
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: gpt2
+   generation_params:
+     temperature: 0.7
model/config/train_debug.yml ADDED
@@ -0,0 +1,19 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   evaluation_strategy: epoch
+   per_device_train_batch_size: 1
+   per_device_eval_batch_size: 1
+   save_steps: 10000
+   save_total_limit: 2
+   logging_dir: ./logs
+   logging_steps: 500
+   num_train_epochs: 1
+   report_to: none
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: gpt2
+   generation_params:
+     temperature: 0.7
model/config/train_from_checkpoint.yml ADDED
@@ -0,0 +1,24 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   eval_strategy: "steps"
+   per_device_train_batch_size: 20
+   per_device_eval_batch_size: 20
+   save_steps: 0.1
+   save_total_limit: 3
+   logging_dir: "./logs"
+   logging_steps: 10
+   eval_steps: 100
+   do_eval: true
+   num_train_epochs: 0.2
+   learning_rate: "<<float: lr>>"
+   report_to: "wandb"
+   run_name: <<job_name>>/<<job_id>>
+   lr_scheduler_type: cosine
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
+   generation_params:
+     temperature: 0.7
model/config/train_from_checkpoint_debug.yml ADDED
@@ -0,0 +1,19 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   evaluation_strategy: epoch
+   per_device_train_batch_size: 1
+   per_device_eval_batch_size: 1
+   save_steps: 10000
+   save_total_limit: 2
+   logging_dir: ./logs
+   logging_steps: 500
+   num_train_epochs: 1
+   report_to: none
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: models/gpt2_v5/1e-5/2025_05_20__10_27_38
+   generation_params:
+     temperature: 0.7
model/config/train_grpo.yml ADDED
@@ -0,0 +1,30 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   eval_strategy: "steps"
+   per_device_train_batch_size: 24 # must be divisible by generations per prompt
+   per_device_eval_batch_size: 24
+   save_steps: 0.1
+   save_total_limit: 3
+   logging_dir: "./logs"
+   logging_steps: 0.01
+   eval_steps: 0.1
+   do_eval: true
+   num_train_epochs: 1.0
+   learning_rate: "<<float: lr>>"
+   report_to: "wandb"
+   run_name: <<job_name>>/<<job_id>>
+   lr_scheduler_type: constant
+   # GRPO specific args
+   num_generations: 8
+   beta: 0.04
+   epsilon: 0.2
+   max_prompt_length: 512
+   max_completion_length: 15
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
+   generation_params:
+     temperature: 0.7
model/config/train_grpo_debug.yml ADDED
@@ -0,0 +1,30 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   eval_strategy: "steps"
+   per_device_train_batch_size: 8 # must be divisible by generations per prompt
+   per_device_eval_batch_size: 8
+   save_steps: 0.5 # Very high number to avoid frequent saves during debug
+   save_total_limit: 1 # Keep only one checkpoint
+   logging_dir: "./logs"
+   logging_steps: 0.01
+   eval_steps: 0.1
+   do_eval: true
+   num_train_epochs: 0.005
+   learning_rate: "<<float: lr>>"
+   report_to: "wandb"
+   run_name: <<job_name>>/<<job_id>>
+   lr_scheduler_type: constant
+   # GRPO specific args
+   num_generations: 8
+   beta: 0.04
+   epsilon: 0.2
+   max_prompt_length: 512
+   max_completion_length: 15
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
+   generation_params:
+     temperature: 0.7
model/config/train_grpo_debug_vllm.yml ADDED
@@ -0,0 +1,32 @@
+ dataset_file: data/samples/<<ds>>
+ final_model_path: models/<<group_name>>/<<job_name>>/<<job_id>>
+ trainer_args:
+   output_dir: checkpoints/<<group_name>>/<<job_name>>/<<job_id>>
+   eval_strategy: "steps"
+   per_device_train_batch_size: 8 # must be divisible by generations per prompt
+   per_device_eval_batch_size: 8
+   save_steps: 1 # Very high number to avoid frequent saves during debug
+   save_total_limit: 1 # Keep only one checkpoint
+   logging_dir: "./logs"
+   logging_steps: 0.1
+   eval_steps: 0.1
+   do_eval: true
+   num_train_epochs: 0.001
+   learning_rate: "<<float: lr>>"
+   report_to: "wandb"
+   run_name: <<job_name>>/<<job_id>>
+   lr_scheduler_type: constant
+   # GRPO specific args
+   beta: 0.04
+   epsilon: 0.2
+   max_prompt_length: 512
+   max_completion_length: 15
+   use_vllm: true
+   vllm_mode: "colocate"
+   vllm_gpu_memory_utilization: 0.10
+ assistant_config:
+   parser_config:
+     include_payoff: true
+   model_path: models/gpt2_v5/1e-5/2025_05_26__18_02_35
+   generation_params:
+     temperature: 0.7
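Editor's note: the <<...>> markers in these configs (<<ds>>, <<group_name>>, <<job_name>>, <<job_id>>, <<float: lr>>) are template placeholders that must be substituted before the YAML is parsed; the package's own substitution logic presumably lives in model/run_slurm.py, which is not reproduced here. The snippet below is only a minimal sketch of such templating: fill_placeholders is a hypothetical helper name, PyYAML is assumed for parsing, and the handling of the "float:" type hint is simplified to plain string substitution.

import re

import yaml  # assumption: PyYAML is available


def fill_placeholders(template_text: str, values: dict) -> dict:
    """Hypothetical helper: replace <<key>> / <<float: key>> markers, then parse the YAML."""

    def _sub(match: re.Match) -> str:
        # Look up the placeholder name, ignoring the optional "float:" type hint.
        return str(values[match.group("key").strip()])

    pattern = re.compile(r"<<\s*(?:float:\s*)?(?P<key>[^>]+?)\s*>>")
    return yaml.safe_load(pattern.sub(_sub, template_text))


# Tiny usage example on an inline template (values are made up):
template = 'dataset_file: data/samples/<<ds>>\nlearning_rate: "<<float: lr>>"\n'
print(fill_placeholders(template, {"ds": "dataset_v5.parquet", "lr": 1e-5}))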
model/config.py ADDED
@@ -0,0 +1,54 @@
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from pydantic import BaseModel, Field
+
+
+ class ParserConfig(BaseModel):
+     include_payoff: bool = True
+
+
+ class NKAssistantConfig(BaseModel):
+     parser_config: ParserConfig
+     use_mpcdf_vllm: bool = False
+     model_path: Optional[str] = None
+     generation_params: Optional[Dict[str, Any]] = None
+
+
+ class TrainConfig(BaseModel):
+     trainer_args: Any
+     dataset_file: str = Field(..., description="The path to the dataset file")
+     final_model_path: str = Field(
+         ..., description="The path to save the final model"
+     )
+     assistant_config: NKAssistantConfig
+
+
+ class InferenceConfig(BaseModel):
+     dataset_file: str = Field(..., description="The path to the dataset file")
+     model_path: str = Field(
+         ..., description="The path to the fine-tuned model"
+     )
+     output_dataset_file: str = Field(
+         ..., description="The path to save the dataset with suggestions"
+     )
+     generation_params: Optional[Dict[str, Any]] = Field(
+         default=None, description="Generation parameters"
+     )
+     splits: List[str] = Field(
+         default=["train", "test"], description="Dataset splits to process"
+     )
+     max_test_samples: Optional[int] = Field(
+         default=None, description="Maximum number of test samples to process"
+     )
+
+
+ class DataSetConfig(BaseModel):
+     input_file: str
+     output_file: str
+     samples_per_landscapes: int = 1
+     constraints_range: Tuple[int, int] = (0, 8)
+     n_samples: List[int] = [8]
+     include_payoff: bool = True
+     test_ratio: float = 0.1
+     exp_ratio: float = 0.1
+     debug_size: Optional[int] = None
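Editor's note: these pydantic models presumably back the YAML files above through load_config_from_yaml in utils/utils.py, which is not reproduced in this diff. As a hedged sketch, a loader along the following lines would turn a rendered train config into a validated TrainConfig (load_train_config is a hypothetical name, and PyYAML is assumed):

import yaml  # assumption: PyYAML

from src.model.config import TrainConfig


def load_train_config(path: str) -> TrainConfig:
    # Parse the rendered YAML and let pydantic validate the nested structure;
    # keyword construction works with both pydantic v1 and v2.
    with open(path) as fh:
        data = yaml.safe_load(fh)
    return TrainConfig(**data)


# cfg = load_train_config("rendered_train.yml")
# cfg.assistant_config.model_path  # -> e.g. "gpt2"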
model/dataset.py ADDED
@@ -0,0 +1,324 @@
+ import argparse
+ import random
+ from typing import List, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from datasets import Dataset, DatasetDict
+
+ from src.model.config import DataSetConfig
+ from src.model.parser import create_context, create_target
+ from src.utils.utils import BIN_ARRAY, load_config_from_yaml
+
+
+ def binary_vectors_within_radius(
+     origin: int,
+     radius: int,
+     n_dim: int = 8,
+ ) -> np.ndarray:
+     """
+     Create the list of indices of all binary vectors in a ball of radius r
+     around origin.
+     Args:
+         origin (int): Integer representation of the binary vector to center
+             around.
+         radius (int): Maximum Hamming distance from origin to include in the
+             ball.
+         n_dim (int, optional): Number of dimensions of binary vectors. Must be
+             <= 8. Defaults to 8.
+
+     Returns:
+         np.ndarray: Array of indices corresponding to binary vectors within
+             radius of origin.
+     """
+     assert n_dim <= 8, "n_dim must be less than or equal to 8"
+     # create a list of all binary vectors of length n_dim
+     all_vectors = BIN_ARRAY[: 2**n_dim]
+     origin_vector = all_vectors[origin]
+
+     # calculate hamming distances to origin using element-wise comparison
+     distances = np.sum(all_vectors != origin_vector, axis=1)
+     # return the indices of vectors within the radius
+     ball_indices = np.where(distances <= radius)[0]
+
+     return ball_indices
+
+
+ def get_rank_score(
+     payoff: np.ndarray,
+     constraint_idx: np.ndarray,
+ ) -> np.ndarray:
+     """
+     Calculates a ranking score for the indices within the constraint indices.
+     It ranks the indices within the constraint area and gives them a score
+     according to their rank. The indices outside the constraint area are given
+     a score of -1.
+
+     Args:
+         payoff (np.ndarray): Array of payoff values.
+         constraint_idx (np.ndarray): Boolean array indicating which indices
+             are constrained.
+
+     Returns:
+         np.ndarray: Array where constrained indices have values from 0 to 1
+             based on their rank (0 for lowest payoff, 1 for highest), and
+             unconstrained indices have -1.
+     """
+     # Initialize all scores to -1
+     rank_score = np.full_like(payoff, fill_value=-1, dtype=float)
+
+     # Get payoffs only for constrained indices
+     constrained_payoffs = payoff[constraint_idx]
+
+     # Calculate rank scores from 0 to 1
+     n_constrained = len(constrained_payoffs)
+     if n_constrained > 1:
+         ranks = np.linspace(0, 1, n_constrained)
+     else:
+         ranks = np.array([1.0])  # If only one point, give it rank 1
+
+     # Assign ranks to constrained indices
+     rank_score[constraint_idx] = ranks[np.argsort(constrained_payoffs)]
+
+     return rank_score
+
+
+ def create_sample(
+     payoff: np.ndarray, hamming_distance: Tuple[int, int], n_samples: List[int]
+ ) -> Tuple[np.ndarray, int, int, int, np.ndarray]:
+
+     assert all(map(lambda x: x <= len(payoff), n_samples)), (
+         "n_samples must be less than or equal to the length of the "
+         "payoff array"
+     )
+     assert hamming_distance[0] <= hamming_distance[1], (
+         "constraints_range must be a tuple of two integers where the first "
+         "is less than or equal to the second"
+     )
+     assert hamming_distance[0] > 0, (
+         "constraints_range must be a tuple of two integers where the first "
+         "is greater than 0"
+     )
+
+     # create a list of all idx [0, 1, 2, ..., len(payoff) - 1]
+     # as integer representation of a binary vector
+     all_idx = np.arange(len(payoff))
+
+     # random choice of sample size and number of constraints
+     sample_size = np.random.choice(n_samples)
+     hamming_distance = np.random.randint(
+         hamming_distance[0], hamming_distance[1]
+     )
+
+     # creating a list of random sample indices
+     sample_idx = np.random.choice(a=all_idx, size=sample_size, replace=False)
+
+     sample_payoffs = payoff[sample_idx]
+     sample_idx = sample_idx[np.argsort(sample_payoffs)]
+
+     # random choice of origin
+     origin_idx = np.random.choice(all_idx)
+
+     # create a list of indices fulfilling the constraints
+     # to get the target index
+     constraint_idx = binary_vectors_within_radius(
+         origin=origin_idx,
+         radius=hamming_distance,
+         n_dim=int(np.log2(len(payoff))),
+     )
+     constraint_idx = constraint_idx[constraint_idx != origin_idx]
+
+     # get the target index as the index with the highest payoff
+     # that is not the origin
+     target_idx = constraint_idx[np.argmax(payoff[constraint_idx])]
+
+     # get the rank score for all indices
+     rank_score = get_rank_score(payoff, constraint_idx)
+
+     return sample_idx, target_idx, origin_idx, hamming_distance, rank_score
+
+
+ def create_raw_dataset(df: pd.DataFrame, config: DataSetConfig) -> Dataset:
+     examples = {
+         "landscape_id": [],
+         "sample": [],
+         "n": [],
+         "k": [],
+         "power_scale": [],
+         "payoffs": [],
+         "sample_idx": [],
+         "target_idx": [],
+         "origin_idx": [],
+         "hamming_distance": [],
+         "ranks": [],
+     }
+
+     from tqdm import tqdm
+
+     total_iterations = (
+         len(df.groupby("landscape_uuid")) * config.samples_per_landscapes
+     )
+     pbar = tqdm(total=total_iterations, desc="Creating dataset")
+     for s in range(config.samples_per_landscapes):
+         for id, group in df.groupby("landscape_uuid"):
+
+             n = group["n"].values[0].item()
+             k = group["k"].values[0].item()
+             power_scale = group["power_scale"].values[0].item()
+
+             # Sort the group by binary coordinates
+             group = sort_by_binary_coords(group, n)
+
+             # create a sample
+             payoffs = group["payoff"].values.astype(np.int32)
+
+             (
+                 sample_idxs,
+                 target_idx,
+                 origin_idx,
+                 hamming_distance,
+                 ranks,
+             ) = create_sample(
+                 payoff=payoffs,
+                 hamming_distance=config.constraints_range,
+                 n_samples=config.n_samples,
+             )
+
+             examples["landscape_id"].append(id)
+             examples["sample"].append(s)
+             examples["n"].append(n)
+             examples["k"].append(k)
+             examples["power_scale"].append(power_scale)
+             examples["payoffs"].append(payoffs)
+             examples["sample_idx"].append(sample_idxs)
+             examples["target_idx"].append(target_idx)
+             examples["origin_idx"].append(origin_idx)
+             examples["hamming_distance"].append(hamming_distance)
+             examples["ranks"].append(ranks)
+
+             pbar.update(1)
+
+     pbar.close()
+
+     return Dataset.from_dict(examples)
+
+
+ def create_context_from_row(row: dict, include_payoff: bool = True) -> dict:
+     sample_payoff = np.array(row["payoffs"])[row["sample_idx"]]
+
+     context = create_context(
+         n=row["n"],
+         k=row["k"],
+         power_scale=row["power_scale"],
+         sample_idxs=row["sample_idx"],
+         origin_idx=row["origin_idx"],
+         hamming_distance=row["hamming_distance"],
+         payoff=sample_payoff,
+         include_payoff=include_payoff,
+     )
+     return {"context": context}
+
+
+ def create_target_from_row(row: dict) -> dict:
+     target = create_target(target_idx=row["target_idx"])
+     return {"target": target}
+
+
+ def create_dataset_from_file(config: DataSetConfig) -> DatasetDict:
+     df = pd.read_parquet(config.input_file)
+     landscape_uuid = df["landscape_uuid"].unique()
+     random.shuffle(landscape_uuid)
+
+     # Split landscape IDs into train, test, exp
+     split_points = [
+         int((1 - config.test_ratio - config.exp_ratio) * len(landscape_uuid)),
+         int((1 - config.exp_ratio) * len(landscape_uuid)),
+     ]
+     train_ids, test_ids, exp_ids = np.split(landscape_uuid, split_points)
+
+     # Assertions to ensure splits are correct
+     assert len(train_ids) + len(test_ids) + len(exp_ids) == len(landscape_uuid)
+     assert len(set(train_ids) & set(test_ids)) == 0
+     assert len(set(train_ids) & set(exp_ids)) == 0
+     assert len(set(test_ids) & set(exp_ids)) == 0
+
+     # Apply debug size limitation if specified
+     if config.debug_size is not None:
+         dbs = config.debug_size
+         train_ids, test_ids, exp_ids = (
+             train_ids[:dbs],
+             test_ids[:dbs],
+             exp_ids[:dbs],
+         )
+
+     # Create datasets for each split
+     datasets = {}
+     for split_name, split_ids in [
+         ("train", train_ids),
+         ("test", test_ids),
+         ("exp", exp_ids),
+     ]:
+         # Create raw dataset and apply context/target transformation
+         split_df = df[df["landscape_uuid"].isin(split_ids)]
+         datasets[split_name] = create_raw_dataset(split_df, config)
+
+     ds = DatasetDict(datasets)
+     return ds
+
+
+ def apply_context_target(
+     ds: DatasetDict, config: DataSetConfig
+ ) -> DatasetDict:
+     def create_context_from_row_(row):
+         return create_context_from_row(row, config.include_payoff)
+
+     ds = ds.map(create_context_from_row_, batched=False)
+     ds = ds.map(create_target_from_row, batched=False)
+     return ds
+
+
+ def sort_by_binary_coords(df: pd.DataFrame, n: int) -> pd.DataFrame:
+     """
+     Sort dataframe by converting binary coordinates to integer representation.
+     Args:
+         df (pd.DataFrame): Input dataframe with binary coordinates
+         n (int): Number of binary coordinate columns
+     Returns:
+         pd.DataFrame: Sorted dataframe with sort_key column
+     """
+     # Get the last n columns as binary coordinates
+     coord_cols = df.columns[-n:]
+     # Convert binary coordinates to integer and sort
+     df["sort_key"] = df[coord_cols].apply(
+         lambda x: int("".join(map(str, x)), 2), axis=1
+     )
+     df = df.sort_values("sort_key")
+     return df
+
+
+ if __name__ == "__main__":
+     print("Starting dataset creation...")
+     # Parse command-line arguments
+     parser = argparse.ArgumentParser(
+         description="Generate dataset from config."
+     )
+     parser.add_argument(
+         "--config", type=str, required=True, help="Path to YAML config file"
+     )
+     args = parser.parse_args()
+
+     # Load config from YAML file
+     config = load_config_from_yaml(args.config, DataSetConfig)
+
+     print(config)
+
+     # Create and save dataset
+     dataset = create_dataset_from_file(config)
+
+     print(dataset)
+
+     dataset = apply_context_target(dataset, config)
+
+     dataset.save_to_disk(config.output_file)
+
+     print(f"Dataset saved to {config.output_file}")
model/inference.py ADDED
@@ -0,0 +1,51 @@
+ import argparse
+
+ from datasets import load_from_disk
+
+ from src.model.config import InferenceConfig
+ from src.model.nk_assistant import NKAssistant
+ from src.utils.utils import load_config_from_yaml
+
+
+ def run_inference(config: InferenceConfig):
+     # Load dataset and model
+     ds = load_from_disk(config.dataset_file)
+
+     # Load model and tokenizer
+     ass = NKAssistant.from_pretrained(config.model_path)
+
+     if config.generation_params:
+         ass.generation_params = config.generation_params
+
+     # Process each split
+     result_ds = ds
+     for split in config.splits:
+         if split in ds:
+             print(f"Processing {split} split...")
+             if config.max_test_samples:
+                 print(f"Processing {config.max_test_samples} test samples...")
+                 ds[split] = ds[split].select(range(config.max_test_samples))
+             result_ds[split] = ds[split].map(ass.suggest_from_row)
+
+     # Save the dataset with suggestions
+     result_ds.save_to_disk(config.output_dataset_file)
+     print(f"Dataset with suggestions saved to {config.output_dataset_file}")
+
+
+ if __name__ == "__main__":
+     # Parse command-line arguments
+     parser = argparse.ArgumentParser(
+         description="Run inference with a fine-tuned model."
+     )
+     parser.add_argument(
+         "--config", type=str, required=True, help="Path to YAML config file"
+     )
+     args = parser.parse_args()
+
+     # Load config from YAML file
+     config = load_config_from_yaml(args.config, InferenceConfig)
+
+     # Run inference
+     run_inference(config)
+
+     print("Inference completed successfully.")